# 2025-08-13 23:50:20 +00:00
"""
2025-08-28 17:57:59 +00:00
2025.8.9
2025.8.10
4.55.4
2025-08-13 23:50:20 +00:00
0.21.0
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch . nn as nn
from torch . nn import functional as F
from typing import Any , List , Optional , Tuple , Union , Dict , Set , Callable
from trl . trainer . rloo_trainer import ( Accelerator , BaseImageProcessor , Callable , CallbackHandler , DEFAULT_CALLBACKS , DEFAULT_PROGRESS_CALLBACK , DataCollatorWithPadding , DataLoader , Dataset , ExportableState , FeatureExtractionMixin , GenerationConfig , INVALID_LOGPROB , OnlineTrainerState , Optional , Path , PreTrainedTokenizerBase , PrinterCallback , ProcessorMixin , RLOOConfig , RLOOTrainer , Trainer , TrainerCallback , TrainerControl , Union , batch_generation , broadcast , defaultdict , disable_dropout_in_model , empty_cache , exact_div , first_true_indices , forward , gather_object , gc , generate_model_card , get_comet_experiment_url , get_reporting_integration_callbacks , get_reward , is_rich_available , is_wandb_available , log_table_to_comet_experiment , math , nn , np , os , pd , prepare_deepspeed , print_rich_table , selective_log_softmax , textwrap , time , torch , truncate_response , unwrap_model_for_generation , Optional , Trainer , os , torch )
import os
from typing import *
from dataclasses import dataclass , field
from packaging . version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch . nn import functional as F
from transformers import DataCollatorForSeq2Seq , DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
# Options handed to `torch.compile(..., options=...)` for the compiled helper
# defined right below. Keys must match the backend's option names exactly, so
# they must not carry leading/trailing whitespace.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
@torch.compile(dynamic=True, fullgraph=True, options=torch_compile_options)
def chunked_selective_log_softmax(logits, index):
    """Return the log-softmax value of `logits` at the positions given by `index`.

    The logits are flattened to rows of size `logits.shape[-1]`, processed in
    (at most) 4 chunks, and each chunk is upcast to float32 before computing
    `logit[index] - logsumexp(logits)`.

    Args:
        logits: Tensor whose last dimension is gathered over. The result is
            reshaped to `(logits.shape[0], logits.shape[1])`, so a 3-D input is
            expected -- presumably (batch, sequence, vocab); confirm with callers.
        index: Integer tensor with one entry per flattened row of `logits`.

    Returns:
        Float32 tensor of shape `(logits.shape[0], logits.shape[1])`.
    """
    n_rows, n_cols = logits.shape[0], logits.shape[1]
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)

    def _chunk_logps(rows, ids):
        # Same computation as selective_log_softmax(rows, ids), done in float32.
        rows = rows.to(torch.float32)
        picked = torch.gather(rows, dim=-1, index=ids.unsqueeze(-1)).squeeze(-1)
        return picked - torch.logsumexp(rows, dim=-1)

    # Split into 4 chunks only
    pieces = [
        _chunk_logps(rows, ids)
        for rows, ids in zip(
            torch.chunk(flat_logits, chunks=4, dim=0),
            torch.chunk(flat_index, chunks=4, dim=0),
        )
    ]
    return torch.concat(pieces).reshape((n_rows, n_cols))
@dataclass
class UnslothRLOOConfig(RLOOConfig):
    """
    Configuration class for the [`RLOOTrainer`].

    This class includes only the parameters that are specific to RLOO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
    values in this class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
            Name of this experiment.
        reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
            Path to the reward model.
        num_ppo_epochs (`int`, *optional*, defaults to `4`):
            Number of epochs to train.
        whiten_rewards (`bool`, *optional*, defaults to `False`):
            Whether to whiten the rewards.
        kl_coef (`float`, *optional*, defaults to `0.05`):
            KL coefficient.
        cliprange (`float`, *optional*, defaults to `0.2`):
            Clip range.
        rloo_k (`int`, *optional*, defaults to `2`):
            REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
        normalize_reward (`bool`, *optional*, defaults to `False`):
            Whether to normalize rewards.
        reward_clip_range (`float`, *optional*, defaults to `10.0`):
            Clip range for rewards.
        normalize_advantage (`bool`, *optional*, defaults to `False`):
            Whether to normalize advantages.
        token_level_kl (`bool`, *optional*, defaults to `False`):
            Whether to use token-level KL penalty or sequence-level KL penalty.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation.
        vllm_sampling_params (`Any`, *optional*, defaults to `None`):
            vLLM SamplingParams.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        hub_revision=None,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=True,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        liger_kernel_config=None,
        eval_use_gather_object=False,
        average_tokens_across_devices=True,
        dataset_num_proc=None,
        num_mini_batches=1,
        total_episodes=None,
        local_rollout_forward_batch_size=64,
        num_sample_generations=10,
        response_length=53,
        stop_token=None,
        stop_token_id=None,
        temperature=0.7,
        missing_eos_penalty=None,
        sft_model_path='EleutherAI/pythia-160m',
        world_size=None,
        num_total_batches=None,
        micro_batch_size=None,
        local_batch_size=None,
        batch_size=None,
        local_mini_batch_size=None,
        mini_batch_size=None,
        exp_name='rloo_config',
        reward_model_path='EleutherAI/pythia-160m',
        num_ppo_epochs=4,
        whiten_rewards=False,
        kl_coef=0.05,
        cliprange=0.2,
        rloo_k=2,
        normalize_reward=False,
        reward_clip_range=10.0,
        normalize_advantage=False,
        token_level_kl=False,
        ds3_gather_for_generation=True,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        """Validate a few hyper-parameters, then forward everything to `RLOOConfig`.

        All arguments except `vllm_sampling_params` and `unsloth_num_chunks` are
        passed to the parent constructor by keyword; those two are stored on the
        instance afterwards. Extra `**kwargs` are forwarded unchanged.

        Raises:
            FloatingPointError: if `learning_rate` is below 1e-7.
            OverflowError: if `learning_rate` is above 1.
            ValueError: if `temperature` is not in the open interval (0, 10).
        """
        if learning_rate < 1e-7:
            raise FloatingPointError(
                f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '
                f'Consider increasing it, otherwise gradient updates will be close to 0!'
            )
        if learning_rate > 1:
            raise OverflowError(
                f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '
                f'Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
            )

        # No output directory and untouched save settings: use a fallback
        # directory name and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = max(cpu_count() + 4, 2)

        # The original raised an undefined `MathError` (i.e. a NameError at
        # runtime); ValueError is the builtin that matches an out-of-range argument.
        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            hub_revision=hub_revision,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            liger_kernel_config=liger_kernel_config,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            dataset_num_proc=dataset_num_proc,
            num_mini_batches=num_mini_batches,
            total_episodes=total_episodes,
            local_rollout_forward_batch_size=local_rollout_forward_batch_size,
            num_sample_generations=num_sample_generations,
            response_length=response_length,
            stop_token=stop_token,
            stop_token_id=stop_token_id,
            temperature=temperature,
            missing_eos_penalty=missing_eos_penalty,
            sft_model_path=sft_model_path,
            world_size=world_size,
            num_total_batches=num_total_batches,
            micro_batch_size=micro_batch_size,
            local_batch_size=local_batch_size,
            batch_size=batch_size,
            local_mini_batch_size=local_mini_batch_size,
            mini_batch_size=mini_batch_size,
            exp_name=exp_name,
            reward_model_path=reward_model_path,
            num_ppo_epochs=num_ppo_epochs,
            whiten_rewards=whiten_rewards,
            kl_coef=kl_coef,
            cliprange=cliprange,
            rloo_k=rloo_k,
            normalize_reward=normalize_reward,
            reward_clip_range=reward_clip_range,
            normalize_advantage=normalize_advantage,
            token_level_kl=token_level_kl,
            ds3_gather_for_generation=ds3_gather_for_generation,
            **kwargs,
        )
        # Unsloth-specific settings that the parent constructor does not accept.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
class _UnslothRLOOTrainer ( Trainer ) :
_tag_names = [ " trl " , " rloo " ]
def __init__(
    self,
    config: RLOOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    policy: nn.Module,
    ref_policy: nn.Module,
    reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
    train_dataset: Dataset,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
) -> None:
    """Set up models, derived batch sizes, callbacks and dataloaders for RLOO training.

    Args:
        config: Training configuration. It is stored as `self.args` and mutated
            in place: `total_episodes` (when `None`), `world_size`, the various
            `*batch_size` fields, `num_total_batches`, `run_name` and possibly
            `stop_token_id` are overwritten here.
        processing_class: Used to build the default collator and to look up
            `eos_token_id` when `config.stop_token == "eos"`.
        policy: Model being trained; becomes `self.model`. Its generation config
            has `eos_token_id` and `pad_token_id` cleared.
        ref_policy: Reference model; must be a different object than `policy`.
        reward_model: Either an `nn.Module` or a callable mapping a list of
            strings to a list of floats (per the annotation).
        train_dataset: Training data; `len()` is taken and it is wrapped in a
            shuffling `DataLoader` with `drop_last=True`.
        data_collator: Collator for both dataloaders. Defaults to
            `DataCollatorWithPadding(processing_class)`.
        eval_dataset: Evaluation data, wrapped in a `DataLoader` unconditionally.
            NOTE(review): defaults to `None` but is never checked before being
            handed to `DataLoader` -- confirm whether `None` is really supported.
        optimizers: `(optimizer, lr_scheduler)` pair assigned to `self.optimizer`
            and `self.lr_scheduler` before `create_optimizer_and_scheduler` runs.
        callbacks: Extra callbacks appended after the default ones.

    Raises:
        ValueError: if `ref_policy` and `policy` are the same object.
    """
    if ref_policy is policy:
        raise ValueError(
            "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
            "same as `policy`, you must pass a copy of it, or `None` if you use peft."
        )

    self.args = config
    args = config
    self.processing_class = processing_class
    self.policy = policy

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    # Disable `pad_token_id` and `eos_token_id` because we just want to
    # generate tokens without truncation / padding.
    self.policy.generation_config.eos_token_id = None
    self.policy.generation_config.pad_token_id = None

    self.ref_policy = ref_policy
    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    args.local_batch_size = (
        args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
    )
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = exact_div(
        args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
    )  # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times

    #########
    # setup model, optimizer, and others
    #########
    for module in [policy, ref_policy, reward_model]:
        if isinstance(module, nn.Module):
            disable_dropout_in_model(module)
    if args.stop_token and args.stop_token == "eos":
        args.stop_token_id = self.processing_class.eos_token_id
    self.model = policy
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.current_flos = 0
    self.hp_search_backend = None
    # Attribute names must be exact: a whitespace-padded name would never be
    # found and these flags would always be False.
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)
    self.backup_model = None

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    if self.is_deepspeed_enabled:
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
        self.ref_policy = prepare_deepspeed(
            self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
        )
        self.deepspeed = self.model
    else:
        self.ref_policy = self.ref_policy.to(self.accelerator.device)
        if isinstance(self.reward_model, nn.Module):
            self.reward_model = self.reward_model.to(self.accelerator.device)
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader stored on `self.dataloader`."""
    return self.dataloader
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader stored on `self.eval_dataloader`."""
    return self.eval_dataloader
def train ( self ) :
args = self . args
accelerator = self . accelerator
optimizer = self . optimizer
model = self . model
self . model_wrapped = self . model
ref_policy = self . ref_policy
reward_model = self . reward_model
processing_class = self . processing_class
dataloader = self . dataloader
device = accelerator . device
def repeat_generator ( ) :
while True :
yield from dataloader
iter_dataloader = iter ( repeat_generator ( ) )
generation_config = GenerationConfig (
max_new_tokens = args . response_length ,
temperature = ( args . temperature + 1e-7 ) ,
top_k = 0.0 ,
top_p = 1.0 ,
do_sample = True ,
)
accelerator . print ( " ===training policy=== " )
start_time = time . time ( )
stats_shape = ( args . num_ppo_epochs , args . num_mini_batches , args . gradient_accumulation_steps )
approxkl_stats = torch . zeros ( stats_shape , device = device )
pg_clipfrac_stats = torch . zeros ( stats_shape , device = device )
pg_loss_stats = torch . zeros ( stats_shape , device = device )
vf_clipfrac_stats = torch . zeros ( stats_shape , device = device )
entropy_stats = torch . zeros ( stats_shape , device = device )
ratio_stats = torch . zeros ( stats_shape , device = device )
model . train ( )
# trainer state initialization
self . state . global_step = 0
self . state . episode = 0
self . state . max_steps = ( args . num_total_batches * args . num_mini_batches ) / / 2
self . state . num_train_epochs = args . total_episodes / self . train_dataset_len
# Compute absolute values for logging, eval, and save if given as ratio
if args . logging_steps is not None :
if args . logging_steps < 1 :
self . state . logging_steps = math . ceil ( self . state . max_steps * args . logging_steps )
else :
self . state . logging_steps = args . logging_steps
if args . eval_steps is not None :
if args . eval_steps < 1 :
self . state . eval_steps = math . ceil ( self . state . max_steps * args . eval_steps )
else :
self . state . eval_steps = args . eval_steps
if args . save_steps is not None :
if args . save_steps < 1 :
self . state . save_steps = math . ceil ( self . state . max_steps * args . save_steps )
else :
self . state . save_steps = args . save_steps
self . control = self . callback_handler . on_train_begin ( args , self . state , self . control )
for update in range ( 1 , args . num_total_batches + 1 ) :
self . state . episode + = 1 * args . batch_size
data = next ( iter_dataloader )
with torch . no_grad ( ) :
queries = data [ " input_ids " ] . to ( device )
queries = queries . repeat ( args . rloo_k , 1 )
context_length = queries . shape [ 1 ]
responses = [ ]
postprocessed_responses = [ ]
logprobs = [ ]
ref_logprobs = [ ]
scores = [ ]
sequence_lengths = [ ]
# Generate responses and compute logprobs
with unwrap_model_for_generation (
self . model , self . accelerator , gather_deepspeed3_params = self . args . ds3_gather_for_generation
) as unwrapped_model :
query_responses , logitss = batch_generation (
unwrapped_model ,
queries ,
args . local_rollout_forward_batch_size ,
processing_class . pad_token_id ,
generation_config ,
)
# Process responses in batches
for i in range ( 0 , queries . shape [ 0 ] , args . local_rollout_forward_batch_size ) :
query = queries [ i : i + args . local_rollout_forward_batch_size ]
query_response = query_responses [ i : i + args . local_rollout_forward_batch_size ]
response = query_response [ : , context_length : ]
logits = logitss [ i : i + args . local_rollout_forward_batch_size ]
logprob = selective_log_softmax ( logits , response )
del logits
empty_cache ( )
ref_output = forward ( ref_policy , query_response , processing_class . pad_token_id )
ref_logits = ref_output . logits [ : , context_length - 1 : - 1 ]
ref_logits / = args . temperature + 1e-7
ref_logprob = selective_log_softmax ( ref_logits , response )
del ref_output , ref_logits
empty_cache ( )
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args . stop_token_id is not None : # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response (
args . stop_token_id , processing_class . pad_token_id , response
)
# Response Processing 2. run reward model on the truncated responses
postprocessed_query_response = torch . cat ( ( query , postprocessed_response ) , 1 )
sequence_length = first_true_indices ( postprocessed_response == processing_class . pad_token_id ) - 1
if isinstance ( reward_model , nn . Module ) :
_ , score , _ = get_reward (
reward_model , postprocessed_query_response , processing_class . pad_token_id , context_length
)
else :
score = torch . tensor (
reward_model (
processing_class . batch_decode ( postprocessed_query_response , skip_special_tokens = True )
) ,
dtype = torch . float ,
) . to ( device )
# Store batch results
responses . append ( response )
postprocessed_responses . append ( postprocessed_response )
logprobs . append ( logprob )
ref_logprobs . append ( ref_logprob )
sequence_lengths . append ( sequence_length )
scores . append ( score )
# Concatenate all batched results
responses = torch . cat ( responses , 0 )
postprocessed_responses = torch . cat ( postprocessed_responses , 0 )
logprobs = torch . cat ( logprobs , 0 )
ref_logprobs = torch . cat ( ref_logprobs , 0 )
sequence_lengths = torch . cat ( sequence_lengths , 0 )
scores = torch . cat ( scores , 0 )
del ( logprob , ref_logprob , score )
empty_cache ( )
gc . collect ( )
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
# responses not passing that filter will receive a low (fixed) score
# only query humans on responses that pass that filter
contain_eos_token = torch . any ( postprocessed_responses == processing_class . eos_token_id , dim = - 1 )
if args . missing_eos_penalty is not None :
scores [ ~ contain_eos_token ] - = self . args . missing_eos_penalty
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch . arange ( responses . shape [ 1 ] , device = responses . device ) . repeat ( responses . shape [ 0 ] , 1 )
padding_mask = response_idxs > sequence_lengths . unsqueeze ( 1 )
logprobs = torch . masked_fill ( logprobs , padding_mask , INVALID_LOGPROB )
ref_logprobs = torch . masked_fill ( ref_logprobs , padding_mask , INVALID_LOGPROB )
# 4. compute rewards
# Compute KL divergence
kl = logprobs - ref_logprobs
# Normalize rewards
if args . normalize_reward :
scores = ( scores - scores . mean ( ) ) / ( scores . std ( ) + 1e-8 )
scores = torch . clamp ( scores , - args . reward_clip_range , args . reward_clip_range )
# Compute total reward with KL penalty
if args . token_level_kl :
# Token-level KL penalty: apply KL penalty per token
kl_reward = - args . kl_coef * kl
# Get the index of the last non-padded token for each sequence
eos_indices = padding_mask . size ( 1 ) - 1 - padding_mask . long ( ) . fliplr ( ) . argmax ( dim = 1 , keepdim = True )
last_reward = torch . zeros_like ( kl )
# Ensure scores has correct shape and type
scores_shaped = scores . reshape ( - 1 , 1 ) . to ( kl . dtype )
last_reward . scatter_ ( dim = 1 , index = eos_indices , src = scores_shaped )
# Combine KL reward and last reward
non_score_reward = kl_reward . sum ( 1 ) # Keep this for logging
reward = last_reward + kl_reward
rlhf_reward = reward . sum ( 1 ) # Sum across sequence length
else :
# Sequence-level KL penalty: sum KL across tokens first
sequence_kl = kl . sum ( 1 )
non_score_reward = - args . kl_coef * sequence_kl
rlhf_reward = non_score_reward + scores
# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward . reshape ( args . rloo_k , - 1 )
baseline = ( rlhf_reward . sum ( 0 ) - rlhf_reward ) / ( args . rloo_k - 1 )
advantages = rlhf_reward - baseline
advantages = advantages . flatten ( )
# Normalize advantages
if args . normalize_advantage :
advantages = ( advantages - advantages . mean ( ) ) / ( advantages . std ( ) + 1e-8 )
empty_cache ( )
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range ( args . num_ppo_epochs ) :
b_inds = np . random . permutation ( args . local_batch_size )
minibatch_idx = 0
for mini_batch_start in range ( 0 , args . local_batch_size , args . local_mini_batch_size ) :
mini_batch_end = mini_batch_start + args . local_mini_batch_size
mini_batch_inds = b_inds [ mini_batch_start : mini_batch_end ]
gradient_accumulation_idx = 0
for micro_batch_start in range ( 0 , args . local_mini_batch_size , args . per_device_train_batch_size ) :
with accelerator . accumulate ( model ) :
micro_batch_end = micro_batch_start + args . per_device_train_batch_size
micro_batch_inds = mini_batch_inds [ micro_batch_start : micro_batch_end ]
# Get batch data
mb_advantage = advantages [ micro_batch_inds ]
mb_responses = responses [ micro_batch_inds ]
mb_query_responses = query_responses [ micro_batch_inds ]
mb_logprobs = logprobs [ micro_batch_inds ]
# Forward pass
output = forward ( model , mb_query_responses , processing_class . pad_token_id )
logits = output . logits [ : , context_length - 1 : - 1 ]
logits / = args . temperature + 1e-7
# Compute new logprobs
new_logprobs = selective_log_softmax ( logits , mb_responses )
new_logprobs = torch . masked_fill (
new_logprobs , padding_mask [ micro_batch_inds ] , INVALID_LOGPROB
)
# Compute probability ratios
new_ratio = ( new_logprobs - mb_logprobs ) . exp ( )
new_logprobs = new_logprobs . sum ( 1 )
mb_logprobs = mb_logprobs . sum ( 1 )
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch . exp ( logprobs_diff )
# PPO clipped loss
pg_losses = - mb_advantage * ratio
pg_losses2 = - mb_advantage * torch . clamp ( ratio , 1.0 - args . cliprange , 1.0 + args . cliprange )
pg_loss_max = torch . max ( pg_losses , pg_losses2 )
pg_loss = pg_loss_max . mean ( )
# Final loss
loss = pg_loss
# Optimization step
accelerator . backward ( loss )
optimizer . step ( )
optimizer . zero_grad ( )
with torch . no_grad ( ) :
pg_clipfrac = ( pg_losses2 > pg_losses ) . float ( ) . mean ( )
2025-08-28 17:57:59 +00:00
prob_dist = torch . nn . functional . softmax ( logits , dim = - 1 , dtype = torch . float32 ) . to ( logits . dtype )
2025-08-13 23:50:20 +00:00
entropy = torch . logsumexp ( logits , dim = - 1 ) - torch . sum ( prob_dist * logits , dim = - 1 )
approxkl = 0.5 * ( logprobs_diff * * 2 ) . mean ( )
approxkl_stats [ ppo_epoch_idx , minibatch_idx , gradient_accumulation_idx ] = approxkl
pg_clipfrac_stats [ ppo_epoch_idx , minibatch_idx , gradient_accumulation_idx ] = (
pg_clipfrac
)
pg_loss_stats [ ppo_epoch_idx , minibatch_idx , gradient_accumulation_idx ] = pg_loss
entropy_stats [ ppo_epoch_idx , minibatch_idx , gradient_accumulation_idx ] = entropy . mean ( )
ratio_stats [ ppo_epoch_idx , minibatch_idx , gradient_accumulation_idx ] = new_ratio . mean ( )
gradient_accumulation_idx + = 1
minibatch_idx + = 1
# del everything and empty cache
# fmt: off
del (
output , logits , new_logprobs , logprobs_diff , ratio , pg_losses ,
pg_losses2 , pg_loss , loss , pg_clipfrac , prob_dist , entropy , approxkl ,
mb_advantage , mb_responses , mb_query_responses , mb_logprobs ,
)
# fmt: on
empty_cache ( )
# Compute metrics
with torch . no_grad ( ) :
mean_kl = kl . sum ( 1 ) . mean ( )
mean_entropy = ( - logprobs ) . sum ( 1 ) . mean ( )
mean_non_score_reward = non_score_reward . mean ( )
eps = int ( self . state . episode / ( time . time ( ) - start_time ) )
metrics = { }
metrics [ " eps " ] = eps
metrics [ " objective/kl " ] = self . accelerator . gather_for_metrics ( mean_kl ) . mean ( ) . item ( )
metrics [ " objective/entropy " ] = self . accelerator . gather_for_metrics ( mean_entropy ) . mean ( ) . item ( )
metrics [ " objective/non_score_reward " ] = (
self . accelerator . gather_for_metrics ( mean_non_score_reward ) . mean ( ) . item ( )
)
metrics [ " objective/rlhf_reward " ] = self . accelerator . gather_for_metrics ( rlhf_reward ) . mean ( ) . item ( )
metrics [ " objective/scores " ] = self . accelerator . gather_for_metrics ( scores . mean ( ) ) . mean ( ) . item ( )
metrics [ " policy/approxkl_avg " ] = self . accelerator . gather_for_metrics ( approxkl_stats ) . mean ( ) . item ( )
metrics [ " policy/clipfrac_avg " ] = self . accelerator . gather_for_metrics ( pg_clipfrac_stats ) . mean ( ) . item ( )
metrics [ " loss/policy_avg " ] = self . accelerator . gather_for_metrics ( pg_loss_stats ) . mean ( ) . item ( )
metrics [ " val/clipfrac_avg " ] = self . accelerator . gather_for_metrics ( vf_clipfrac_stats ) . mean ( ) . item ( )
metrics [ " policy/entropy_avg " ] = self . accelerator . gather_for_metrics ( entropy_stats ) . mean ( ) . item ( )
metrics [ " val/ratio " ] = self . accelerator . gather_for_metrics ( ratio_stats ) . mean ( ) . item ( )
metrics [ " val/ratio_var " ] = self . accelerator . gather_for_metrics ( ratio_stats ) . var ( ) . item ( )
metrics [ " val/num_eos_tokens " ] = ( responses == processing_class . eos_token_id ) . sum ( ) . item ( )
metrics [ " lr " ] = self . lr_scheduler . get_last_lr ( ) [ 0 ]
metrics [ " episode " ] = self . state . episode
self . state . epoch = self . state . episode / ( args . rloo_k * self . train_dataset_len ) # used by self.log
self . log ( metrics )
del kl , mean_kl , mean_entropy , scores
self . lr_scheduler . step ( )
self . state . global_step + = 1
self . control = self . callback_handler . on_step_end ( args , self . state , self . control )
if self . control . should_save :
self . _save_checkpoint ( model , trial = None )
self . control = self . callback_handler . on_save ( self . args , self . state , self . control )
empty_cache ( )
gc . collect ( )
if args . num_sample_generations > 0 and ( update - 1 ) % self . sample_generations_freq == 0 :
self . generate_completions ( sampling = True )
# HF trainer specifics
self . control = self . callback_handler . on_train_end ( args , self . state , self . control )
if self . control . should_save :
self . _save_checkpoint ( model , trial = None , metrics = None )
self . control = self . callback_handler . on_save ( self . args , self . state , self . control )
def generate_completions ( self , sampling : bool = False ) :
"""
Generate completions for batches of the evaluation dataloader, score them, and report them as a table.

For each batch the prompt and the generated continuation are decoded into text, a score is obtained
from `self.reward_model`, and all three are accumulated into a table. On the main process the table
is optionally printed and logged to the reporting backends named in `args.report_to`.

Args:
    sampling (`bool`, *optional*, defaults to `False`):
        When true, the loop is left via `break` instead of consuming the whole dataloader.

NOTE(review): indentation is not preserved in this copy of the file, so the exact nesting of the
statements below (in particular of the `if sampling : break` pair) cannot be read off the text.
"""
args = self . args
processing_class = self . processing_class
# Sampling is enabled, but with a very small temperature (0.01 + 1e-7).
generation_config = GenerationConfig (
max_new_tokens = self . args . response_length ,
temperature = ( 0.01 + 1e-7 ) ,
top_k = 0.0 ,
top_p = 1.0 ,
do_sample = True ,
)
# Column name -> list of values, extended batch by batch.
table = defaultdict ( list )
# Generation runs on the unwrapped model; `ds3_gather_for_generation` is forwarded as
# `gather_deepspeed3_params`.
with unwrap_model_for_generation (
self . model , self . accelerator , gather_deepspeed3_params = self . args . ds3_gather_for_generation
) as unwrapped_model :
for batch in self . eval_dataloader :
query = batch [ " input_ids " ]
with torch . no_grad ( ) :
# Number of prompt tokens; everything after this column is generated text.
context_length = query . shape [ 1 ]
query_response , _ = batch_generation (
unwrapped_model ,
query ,
query . shape [ 0 ] ,
processing_class . pad_token_id ,
generation_config ,
)
# Drop the prompt columns, keeping only the generated continuation.
response = query_response [ : , context_length : ]
postprocessed_response = response
if args . stop_token_id is not None : # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response (
args . stop_token_id , processing_class . pad_token_id , response
)
# Decode locally, then collect the decoded strings from every process.
table [ " query " ] . extend (
gather_object ( processing_class . batch_decode ( query , skip_special_tokens = True ) )
)
table [ " model response " ] . extend (
gather_object ( processing_class . batch_decode ( postprocessed_response ) )
)
# Re-attach the prompt so the reward is computed on prompt + (truncated) response.
postprocessed_query_response = torch . cat ( ( query , postprocessed_response ) , 1 )
# The reward model is either an `nn.Module` scored through `get_reward`, or a callable
# that receives the decoded texts and whose result is converted to a float tensor.
if isinstance ( self . reward_model , nn . Module ) :
_ , score , _ = get_reward (
self . reward_model ,
postprocessed_query_response ,
processing_class . pad_token_id ,
context_length ,
)
else :
score = torch . tensor (
self . reward_model (
processing_class . batch_decode ( postprocessed_query_response , skip_special_tokens = True )
) ,
dtype = torch . float ,
) . to ( postprocessed_query_response . device )
table [ " score " ] . extend ( self . accelerator . gather_for_metrics ( score ) . float ( ) . cpu ( ) . numpy ( ) )
# Presumably this sits directly inside the `for batch` loop, so that only the first batch is
# processed when `sampling` is true — TODO confirm against the properly indented file.
if sampling :
break
df = pd . DataFrame ( table )
# Printing and logging are restricted to the main process.
if self . accelerator . is_main_process :
if is_rich_available ( ) :
# Only the first five rows are printed.
print_rich_table ( df . iloc [ 0 : 0 + 5 ] )
if " wandb " in args . report_to :
import wandb
# Log only when a wandb run is active.
if wandb . run is not None :
wandb . log ( { " completions " : wandb . Table ( dataframe = df ) } )
if " comet_ml " in args . report_to :
log_table_to_comet_experiment (
name = " completions.csv " ,
table = df ,
)
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Write the model card, then delegate the checkpoint save to the parent trainer.

    Args:
        model:
            Model to checkpoint; forwarded unchanged to the parent implementation.
        trial:
            Hyperparameter-search trial or `None`; forwarded unchanged to the parent implementation.
        metrics (*optional*, defaults to `None`):
            Accepted for call-site compatibility and otherwise ignored. The end-of-training path in
            `train` calls `self._save_checkpoint(model, trial=None, metrics=None)`, which raised a
            `TypeError` while this method only accepted `model` and `trial`.
    """
    hub_model_id = self.args.hub_model_id
    if hub_model_id is None:
        # No hub repository configured: name the card after the output directory.
        model_name = Path(self.args.output_dir).name
    else:
        # Keep only the repository name from an "owner/name" hub id.
        model_name = hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    # `metrics` is deliberately not forwarded; the parent is called with the same two
    # arguments as before.
    super()._save_checkpoint(model, trial)
def create_model_card (
self ,
model_name : Optional [ str ] = None ,
dataset_name : Optional [ str ] = None ,
tags : Union [ str , list [ str ] , None ] = None ,
) :
"""
Creates a draft of a model card using the information available to the `Trainer`.
The card is saved as a README file inside `self.args.output_dir`; nothing is returned.
Args:
model_name (`str` or `None`, *optional*, defaults to `None`):
Name of the model.
dataset_name (`str` or `None`, *optional*, defaults to `None`):
Name of the dataset used for training.
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
# Only the process for which `is_world_process_zero()` is true writes the card.
if not self . is_world_process_zero ( ) :
return
# A base model is reported only when the config's `_name_or_path` is not an existing local
# directory (presumably it is a hub id in that case — TODO confirm).
if hasattr ( self . model . config , " _name_or_path " ) and not os . path . isdir ( self . model . config . _name_or_path ) :
base_model = self . model . config . _name_or_path
else :
base_model = None
# normalize `tags` to a mutable set
if tags is None :
tags = set ( )
elif isinstance ( tags , str ) :
tags = { tags }
else :
tags = set ( tags )
# Mark models whose config carries an `unsloth_version` attribute.
if hasattr ( self . model . config , " unsloth_version " ) :
tags . add ( " unsloth " )
# Merge in the trainer's own `_tag_names`.
tags . update ( self . _tag_names )
# BibTeX entry passed to the card generator as the trainer citation.
citation = textwrap . dedent ( """ \
@inproceedings { ahmadian2024back,
title = {{ Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
author = { Arash Ahmadian and Chris Cremer and Matthias Gall { \' {e} } and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet { \" {U} }st { \" {u} }n and Sara Hooker},
year = 2024,
booktitle = { Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
publisher = { Association for Computational Linguistics},
pages = { 12248--12267},
editor = { Lun { -}Wei Ku and Andre Martins and Vivek Srikumar},
} """ )
# NOTE(review): `wandb` is referenced below without a local import, whereas
# `generate_completions` imports it inside the function. Confirm that `wandb` is bound at
# module scope; otherwise this expression raises `NameError` whenever `is_wandb_available()`
# returns true.
model_card = generate_model_card (
base_model = base_model ,
model_name = model_name ,
hub_model_id = self . hub_model_id ,
dataset_name = dataset_name ,
tags = tags ,
wandb_url = wandb . run . url if is_wandb_available ( ) and wandb . run is not None else None ,
comet_url = get_comet_experiment_url ( ) ,
trainer_name = " RLOO " ,
trainer_citation = citation ,
paper_title = " Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs " ,
paper_id = " 2402.14740 " ,
)
# Write the generated card into the output directory.
model_card . save ( os . path . join ( self . args . output_dir , " README.md " ) )
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
    """
    RLOO trainer entry point that applies Unsloth's constructor-time adjustments.

    Before delegating to `_UnslothRLOOTrainer.__init__`, the constructor:

    - falls back to a default `UnslothRLOOConfig` when `config` is `None`;
    - copies the policy's `max_seq_length` into the config when the config leaves it unset;
    - calls `policy.for_training()` when the policy provides it;
    - forces right-side padding on the processing class (and on its inner tokenizer, if any);
    - swaps the data collator for one that matches the dataset columns and the tokenizer type.

    After the parent constructor has run, it removes any NEFTune hook handle and stores the
    accelerator's `scaler` on every module along the policy's nested `.model` chain.
    """

    def __init__(
        self,
        config,
        processing_class,
        policy,
        ref_policy,
        reward_model,
        train_dataset,
        data_collator=None,
        eval_dataset=None,
        callbacks=None,
        **kwargs
    ):
        """
        Args:
            config: Training configuration; a default `UnslothRLOOConfig` is created when `None`.
            processing_class: Tokenizer or processor used for padding and for the data collators.
            policy: Model being trained.
            ref_policy: Reference model, forwarded unchanged to the parent constructor.
            reward_model: Reward model or callable, forwarded unchanged to the parent constructor.
            train_dataset: Training dataset; its `column_names` are consulted when choosing a collator.
            data_collator: Optional collator; may be replaced, see the class docstring.
            eval_dataset: Optional evaluation dataset, forwarded unchanged.
            callbacks: Optional trainer callbacks, forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        # The setup code below is written in terms of `args` and `model`, while this trainer's
        # parameters are named `config` and `policy`. Without these aliases the very first
        # statement read an unbound local `args` and raised `UnboundLocalError`.
        args = config
        if args is None:
            args = UnslothRLOOConfig()
        model = policy

        # `compute_metrics` / `preprocess_logits_for_metrics` are not named parameters here, so
        # they can only arrive through `kwargs`.
        wants_logits = (
            kwargs.get('compute_metrics') is not None
            or kwargs.get('preprocess_logits_for_metrics') is not None
        )
        if wants_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Inherit the model's maximum sequence length when the config has the field but left it unset.
        if hasattr(args, 'max_seq_length'):
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            if args.max_seq_length is None and model_max_seq_length is not None:
                args.max_seq_length = model_max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right-side padding on the processing class and, for processors, on the wrapped tokenizer.
        if hasattr(processing_class, 'padding_side'):
            processing_class.padding_side = 'right'
        inner_tokenizer = getattr(processing_class, 'tokenizer', None)
        if inner_tokenizer is not None and hasattr(inner_tokenizer, 'padding_side'):
            inner_tokenizer.padding_side = 'right'
        tokenizer = processing_class

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None)
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Pick the collator that matches whether the dataset already carries `labels`.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    tokenizer,
                    mlm=False,
                    mlm_probability=0.0,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
            elif (
                isinstance(data_collator, TransformersDataCollatorForLanguageModeling)
                and 'labels' in train_dataset.column_names
            ):
                data_collator = DataCollatorForSeq2Seq(
                    tokenizer,
                    pad_to_multiple_of=pad_to_multiple_of,
                )
        else:
            # Vision collator: keep all columns and skip dataset preparation.
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without its own `pad` cannot be used by the collators directly;
            # rebuild the collator around the tokenizer it wraps.
            if not hasattr(tokenizer, 'pad') and hasattr(tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        tokenizer.tokenizer,
                        pad_to_multiple_of=pad_to_multiple_of,
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        tokenizer.tokenizer,
                        mlm=False,
                        mlm_probability=0.0,
                        pad_to_multiple_of=pad_to_multiple_of,
                    )

        other_metrics = []
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('rloo_trainer', other_metrics)

        super().__init__(
            # Pass the resolved config so that a default created above is actually used.
            config=args,
            processing_class=processing_class,
            policy=policy,
            ref_policy=ref_policy,
            reward_model=reward_model,
            train_dataset=train_dataset,
            data_collator=data_collator,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
            **kwargs
        )

        # Remove any NEFTune forward hook registered by the parent and drop the handle.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        # Store the accelerator's scaler on every wrapper along the `.model` chain,
        # including the innermost module.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler