2025-08-13 23:50:20 +00:00
"""
2025-08-28 17:57:59 +00:00
2025.8.9
2025.8.10
4.55.4
2025-08-13 23:50:20 +00:00
0.21.0
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch . nn as nn
from torch . nn import functional as F
from typing import Any , List , Optional , Tuple , Union , Dict , Set , Callable
from trl . trainer . ppo_trainer import ( Accelerator , BaseImageProcessor , CallbackHandler , DEFAULT_CALLBACKS , DEFAULT_PROGRESS_CALLBACK , DataCollatorWithPadding , DataLoader , Dataset , ExportableState , FeatureExtractionMixin , GenerationConfig , INVALID_LOGPROB , OnlineTrainerState , Optional , PPOConfig , PPOTrainer , Path , PeftConfig , PeftModel , PolicyAndValueWrapper , PreTrainedTokenizerBase , PrinterCallback , ProcessorMixin , Trainer , TrainerCallback , TrainerControl , Union , batch_generation , broadcast , contextmanager , create_reference_model , defaultdict , disable_dropout_in_model , empty_cache , exact_div , first_true_indices , forward , gather_object , gc , generate_model_card , get_comet_experiment_url , get_peft_model , get_reporting_integration_callbacks , get_reward , is_peft_available , is_rich_available , is_wandb_available , log_table_to_comet_experiment , masked_mean , masked_whiten , math , nn , np , nullcontext , os , pd , peft_module_casting_to_bf16 , prepare_deepspeed , print_rich_table , selective_log_softmax , textwrap , time , torch , truncate_response , unwrap_model_for_generation , Optional , PeftModel , Trainer , is_peft_available , os , torch )
import os
from typing import *
from dataclasses import dataclass , field
from packaging . version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch . nn import functional as F
from transformers import DataCollatorForSeq2Seq , DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
# Options forwarded to `torch.compile(..., options=...)` for the compiled
# helpers in this module. Keys must match the option names exactly, so they
# carry no surrounding whitespace.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
    "shape_padding": True,
    "trace.enabled": False,
    "triton.cudagraphs": False,
}
@torch.compile(dynamic=True, fullgraph=True, options=torch_compile_options)
def chunked_selective_log_softmax(logits, index):
    """Return the log-probability of each selected index, computed in chunks.

    The logits are flattened to rows of size `logits.shape[-1]` and split into
    (at most) 4 chunks along the row dimension, so the float32 upcast and the
    `logsumexp` are only materialised for one chunk at a time.

    Args:
        logits: Tensor whose last dimension holds the per-class scores.
        index: Integer tensor with one class index per row of `logits`
            (i.e. `index.numel()` equals the number of rows after flattening).

    Returns:
        Float32 tensor of shape `logits.shape[:-1]` holding
        `logits[..., index] - logsumexp(logits, dim=-1)`.
    """
    # Split into 4 chunks only
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_index = index.reshape(-1)
    logit_chunks = torch.chunk(flat_logits, chunks=4, dim=0)
    index_chunks = torch.chunk(flat_index, chunks=4, dim=0)

    per_chunk_logps = []
    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
    for chunk_logits, chunk_index in zip(logit_chunks, index_chunks):
        # Upcast per chunk so the float32 copy never covers the whole batch.
        chunk_logits = chunk_logits.to(torch.float32)
        selected_logits = torch.gather(
            chunk_logits, dim=-1, index=chunk_index.unsqueeze(-1)
        ).squeeze(-1)
        logsumexp_values = torch.logsumexp(chunk_logits, dim=-1)
        per_chunk_logps.append(selected_logits - logsumexp_values)

    all_per_token_logps = torch.concat(per_chunk_logps)
    # Restore every leading dimension of `logits` (not just the first two), so
    # inputs of any rank >= 2 are supported; identical to the old result for 3-D.
    return all_per_token_logps.reshape(logits.shape[:-1])
@dataclass
class UnslothPPOConfig(PPOConfig):
    """
    Configuration class for the [`PPOTrainer`].

    This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
    values in this class may differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
            Name of this experiment.
        reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
            Path to the reward model.
        model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
        num_ppo_epochs (`int`, *optional*, defaults to `4`):
            Number of epochs to train.
        whiten_rewards (`bool`, *optional*, defaults to `False`):
            Whether to whiten the rewards.
        kl_coef (`float`, *optional*, defaults to `0.05`):
            KL coefficient.
        kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
            Which estimator for KL-Divergence to use from [Approximating KL
            Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
            estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
            better estimator". Cannot be set to "k2", as it is used for logging purposes.
        cliprange (`float`, *optional*, defaults to `0.2`):
            Clip range.
        vf_coef (`float`, *optional*, defaults to `0.1`):
            Value function coefficient.
        cliprange_value (`float`, *optional*, defaults to `0.2`):
            Clip range for the value function.
        gamma (`float`, *optional*, defaults to `1.0`):
            Discount factor.
        lam (`float`, *optional*, defaults to `0.95`):
            Lambda value for GAE.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation.
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            vLLM SamplingParams. Stored on the instance; not forwarded to `PPOConfig`.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Chunk size to reduce memory usage. -1 is most efficient. Stored on the instance; not forwarded to
            `PPOConfig`.

    Raises (from `__init__`):
        FloatingPointError: if `learning_rate < 1e-7`.
        OverflowError: if `learning_rate > 1`.
        ValueError: if `temperature <= 0` or `temperature >= 10`.
    """

    vllm_sampling_params: Optional[Any] = field(
        default=None,
        metadata={'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks: Optional[int] = field(
        default=-1,
        metadata={'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        output_dir=None,
        overwrite_output_dir=None,
        do_train=False,
        do_eval=False,
        do_predict=False,
        eval_strategy='no',
        prediction_loss_only=False,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        per_gpu_train_batch_size=None,
        per_gpu_eval_batch_size=None,
        gradient_accumulation_steps=2,
        eval_accumulation_steps=2,
        eval_delay=0,
        torch_empty_cache_steps=250,
        learning_rate=5e-05,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-08,
        max_grad_norm=1.0,
        num_train_epochs=3.0,
        max_steps=-1,
        lr_scheduler_type='linear',
        warmup_ratio=0.1,
        warmup_steps=0,
        log_level='passive',
        log_level_replica='warning',
        log_on_each_node=True,
        logging_dir=None,
        logging_strategy='steps',
        logging_first_step=False,
        logging_steps=1,
        logging_nan_inf_filter=False,
        save_strategy='steps',
        save_steps=500,
        save_total_limit=None,
        save_safetensors=True,
        save_on_each_node=False,
        save_only_model=False,
        restore_callback_states_from_checkpoint=False,
        no_cuda=False,
        use_cpu=False,
        use_mps_device=False,
        seed=3407,
        data_seed=3407,
        jit_mode_eval=False,
        use_ipex=False,
        bf16=False,
        fp16=False,
        fp16_opt_level='O1',
        half_precision_backend='auto',
        bf16_full_eval=False,
        fp16_full_eval=False,
        tf32=None,
        local_rank=-1,
        ddp_backend=None,
        tpu_num_cores=None,
        tpu_metrics_debug=False,
        debug='',
        dataloader_drop_last=False,
        eval_steps=None,
        dataloader_num_workers=0,
        dataloader_prefetch_factor=None,
        past_index=-1,
        run_name=None,
        disable_tqdm=None,
        remove_unused_columns=True,
        label_names=None,
        load_best_model_at_end=False,
        metric_for_best_model=None,
        greater_is_better=None,
        ignore_data_skip=False,
        fsdp='',
        fsdp_min_num_params=0,
        fsdp_config=None,
        fsdp_transformer_layer_cls_to_wrap=None,
        accelerator_config=None,
        deepspeed=None,
        label_smoothing_factor=0.0,
        optim='adamw_8bit',
        optim_args=None,
        adafactor=False,
        group_by_length=False,
        length_column_name='length',
        report_to=None,
        ddp_find_unused_parameters=None,
        ddp_bucket_cap_mb=None,
        ddp_broadcast_buffers=None,
        dataloader_pin_memory=True,
        dataloader_persistent_workers=False,
        skip_memory_metrics=True,
        use_legacy_prediction_loop=False,
        push_to_hub=False,
        resume_from_checkpoint=None,
        hub_model_id=None,
        hub_strategy='every_save',
        hub_token=None,
        hub_private_repo=None,
        hub_always_push=False,
        hub_revision=None,
        gradient_checkpointing=False,
        gradient_checkpointing_kwargs=None,
        include_inputs_for_metrics=False,
        eval_do_concat_batches=True,
        fp16_backend='auto',
        push_to_hub_model_id=None,
        push_to_hub_organization=None,
        push_to_hub_token=None,
        mp_parameters='',
        auto_find_batch_size=True,
        full_determinism=False,
        torchdynamo=None,
        ray_scope='last',
        ddp_timeout=1800,
        torch_compile=False,
        torch_compile_backend=None,
        torch_compile_mode=None,
        include_tokens_per_second=False,
        include_num_input_tokens_seen=False,
        neftune_noise_alpha=None,
        optim_target_modules=None,
        batch_eval_metrics=False,
        eval_on_start=False,
        use_liger_kernel=False,
        liger_kernel_config=None,
        eval_use_gather_object=False,
        average_tokens_across_devices=True,
        dataset_num_proc=None,
        num_mini_batches=1,
        total_episodes=None,
        local_rollout_forward_batch_size=64,
        num_sample_generations=10,
        response_length=53,
        stop_token=None,
        stop_token_id=None,
        temperature=0.7,
        missing_eos_penalty=None,
        sft_model_path='EleutherAI/pythia-160m',
        world_size=None,
        num_total_batches=None,
        micro_batch_size=None,
        local_batch_size=None,
        batch_size=None,
        local_mini_batch_size=None,
        mini_batch_size=None,
        exp_name='ppo_config',
        reward_model_path='EleutherAI/pythia-160m',
        model_adapter_name=None,
        ref_adapter_name=None,
        num_ppo_epochs=4,
        whiten_rewards=False,
        kl_coef=0.05,
        kl_estimator='k1',
        cliprange=0.2,
        vf_coef=0.1,
        cliprange_value=0.2,
        gamma=1.0,
        lam=0.95,
        ds3_gather_for_generation=True,
        vllm_sampling_params=None,
        unsloth_num_chunks=-1,
        **kwargs,
    ):
        # Reject learning rates that are almost certainly a mistake.
        if learning_rate < 1e-7:
            raise FloatingPointError(
                f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '
                f'Consider increasing it, otherwise gradient updates will be close to 0!'
            )
        if learning_rate > 1:
            raise OverflowError(
                f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '
                f'Consider decreasing it to 1e-1, otherwise gradient updates will explode!'
            )

        # When no output directory was given and the save settings are still the
        # signature defaults, fall back to a default directory and turn saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'

        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = max(cpu_count() + 4, 2)

        # The original code raised an undefined `MathError`, which surfaced as a
        # NameError; ValueError carries the intended message to the caller.
        if temperature <= 0:
            raise ValueError(
                'Unsloth: Please set a positive non-zero temperature since your results will be wrong.'
            )
        elif temperature >= 10:
            raise ValueError(
                'Unsloth: Please set a positive non-zero temperature less than 10, '
                'since sampling will be quite erratic.'
            )

        super().__init__(
            output_dir=output_dir,
            overwrite_output_dir=overwrite_output_dir,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
            eval_strategy=eval_strategy,
            prediction_loss_only=prediction_loss_only,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            per_gpu_train_batch_size=per_gpu_train_batch_size,
            per_gpu_eval_batch_size=per_gpu_eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            eval_accumulation_steps=eval_accumulation_steps,
            eval_delay=eval_delay,
            torch_empty_cache_steps=torch_empty_cache_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_epsilon=adam_epsilon,
            max_grad_norm=max_grad_norm,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            warmup_steps=warmup_steps,
            log_level=log_level,
            log_level_replica=log_level_replica,
            log_on_each_node=log_on_each_node,
            logging_dir=logging_dir,
            logging_strategy=logging_strategy,
            logging_first_step=logging_first_step,
            logging_steps=logging_steps,
            logging_nan_inf_filter=logging_nan_inf_filter,
            save_strategy=save_strategy,
            save_steps=save_steps,
            save_total_limit=save_total_limit,
            save_safetensors=save_safetensors,
            save_on_each_node=save_on_each_node,
            save_only_model=save_only_model,
            restore_callback_states_from_checkpoint=restore_callback_states_from_checkpoint,
            no_cuda=no_cuda,
            use_cpu=use_cpu,
            use_mps_device=use_mps_device,
            seed=seed,
            data_seed=data_seed,
            jit_mode_eval=jit_mode_eval,
            use_ipex=use_ipex,
            bf16=bf16,
            fp16=fp16,
            fp16_opt_level=fp16_opt_level,
            half_precision_backend=half_precision_backend,
            bf16_full_eval=bf16_full_eval,
            fp16_full_eval=fp16_full_eval,
            tf32=tf32,
            local_rank=local_rank,
            ddp_backend=ddp_backend,
            tpu_num_cores=tpu_num_cores,
            tpu_metrics_debug=tpu_metrics_debug,
            debug=debug,
            dataloader_drop_last=dataloader_drop_last,
            eval_steps=eval_steps,
            dataloader_num_workers=dataloader_num_workers,
            dataloader_prefetch_factor=dataloader_prefetch_factor,
            past_index=past_index,
            run_name=run_name,
            disable_tqdm=disable_tqdm,
            remove_unused_columns=remove_unused_columns,
            label_names=label_names,
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model=metric_for_best_model,
            greater_is_better=greater_is_better,
            ignore_data_skip=ignore_data_skip,
            fsdp=fsdp,
            fsdp_min_num_params=fsdp_min_num_params,
            fsdp_config=fsdp_config,
            fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
            accelerator_config=accelerator_config,
            deepspeed=deepspeed,
            label_smoothing_factor=label_smoothing_factor,
            optim=optim,
            optim_args=optim_args,
            adafactor=adafactor,
            group_by_length=group_by_length,
            length_column_name=length_column_name,
            report_to=report_to,
            ddp_find_unused_parameters=ddp_find_unused_parameters,
            ddp_bucket_cap_mb=ddp_bucket_cap_mb,
            ddp_broadcast_buffers=ddp_broadcast_buffers,
            dataloader_pin_memory=dataloader_pin_memory,
            dataloader_persistent_workers=dataloader_persistent_workers,
            skip_memory_metrics=skip_memory_metrics,
            use_legacy_prediction_loop=use_legacy_prediction_loop,
            push_to_hub=push_to_hub,
            resume_from_checkpoint=resume_from_checkpoint,
            hub_model_id=hub_model_id,
            hub_strategy=hub_strategy,
            hub_token=hub_token,
            hub_private_repo=hub_private_repo,
            hub_always_push=hub_always_push,
            hub_revision=hub_revision,
            gradient_checkpointing=gradient_checkpointing,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            include_inputs_for_metrics=include_inputs_for_metrics,
            eval_do_concat_batches=eval_do_concat_batches,
            fp16_backend=fp16_backend,
            push_to_hub_model_id=push_to_hub_model_id,
            push_to_hub_organization=push_to_hub_organization,
            push_to_hub_token=push_to_hub_token,
            mp_parameters=mp_parameters,
            auto_find_batch_size=auto_find_batch_size,
            full_determinism=full_determinism,
            torchdynamo=torchdynamo,
            ray_scope=ray_scope,
            ddp_timeout=ddp_timeout,
            torch_compile=torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_mode=torch_compile_mode,
            include_tokens_per_second=include_tokens_per_second,
            include_num_input_tokens_seen=include_num_input_tokens_seen,
            neftune_noise_alpha=neftune_noise_alpha,
            optim_target_modules=optim_target_modules,
            batch_eval_metrics=batch_eval_metrics,
            eval_on_start=eval_on_start,
            use_liger_kernel=use_liger_kernel,
            liger_kernel_config=liger_kernel_config,
            eval_use_gather_object=eval_use_gather_object,
            average_tokens_across_devices=average_tokens_across_devices,
            dataset_num_proc=dataset_num_proc,
            num_mini_batches=num_mini_batches,
            total_episodes=total_episodes,
            local_rollout_forward_batch_size=local_rollout_forward_batch_size,
            num_sample_generations=num_sample_generations,
            response_length=response_length,
            stop_token=stop_token,
            stop_token_id=stop_token_id,
            temperature=temperature,
            missing_eos_penalty=missing_eos_penalty,
            sft_model_path=sft_model_path,
            world_size=world_size,
            num_total_batches=num_total_batches,
            micro_batch_size=micro_batch_size,
            local_batch_size=local_batch_size,
            batch_size=batch_size,
            local_mini_batch_size=local_mini_batch_size,
            mini_batch_size=mini_batch_size,
            exp_name=exp_name,
            reward_model_path=reward_model_path,
            model_adapter_name=model_adapter_name,
            ref_adapter_name=ref_adapter_name,
            num_ppo_epochs=num_ppo_epochs,
            whiten_rewards=whiten_rewards,
            kl_coef=kl_coef,
            kl_estimator=kl_estimator,
            cliprange=cliprange,
            vf_coef=vf_coef,
            cliprange_value=cliprange_value,
            gamma=gamma,
            lam=lam,
            ds3_gather_for_generation=ds3_gather_for_generation,
            **kwargs,
        )
        # Unsloth-only settings are kept on the instance rather than passed up.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
class _UnslothPPOTrainer ( Trainer ) :
_tag_names = [ " trl " , " ppo " ]
def __init__(
    self,
    args: PPOConfig,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ],
    model: nn.Module,
    ref_model: Optional[nn.Module],
    reward_model: nn.Module,
    train_dataset: Dataset,
    value_model: nn.Module,
    data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    # less commonly used
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    callbacks: Optional[list[TrainerCallback]] = None,
    peft_config: Optional["PeftConfig"] = None,
) -> None:
    """Set up models, batch sizes, optimizer, callbacks and dataloaders for PPO.

    Args:
        args: PPO configuration. Several derived fields (`world_size`, the
            various batch sizes, `num_total_batches`, `run_name` and, when
            unset, `total_episodes`) are written back onto this object.
        processing_class: Tokenizer/processor; used for the default collator
            and for the EOS id when `args.stop_token == "eos"`.
        model: Policy model to train. Wrapped with `peft_config` when given.
        ref_model: Reference model. When `None`, a copy is created with
            `create_reference_model` unless the policy is a PEFT model.
        reward_model: Model used to score responses.
        train_dataset: Training dataset; its length is read with `len()`.
        value_model: Value model, combined with the policy in a
            `PolicyAndValueWrapper`.
        data_collator: Collator; defaults to `DataCollatorWithPadding`.
        eval_dataset: Dataset wrapped in the evaluation dataloader.
        optimizers: `(optimizer, lr_scheduler)` pair; `(None, None)` lets
            `create_optimizer_and_scheduler` build them.
        callbacks: Extra callbacks appended to the default ones.
        peft_config: Optional PEFT configuration applied to the policy.

    Raises:
        ValueError: if `ref_model is model`, if both `stop_token` and
            `stop_token_id` are set, if `stop_token` or `kl_estimator` has an
            unsupported value, or if there is no reference model and the
            policy is not a PEFT model.
        ImportError: if `peft_config` is given but PEFT is not installed.
    """
    if ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must make a copy of it, or `None` if you use peft."
        )

    self.args = args
    self.processing_class = processing_class
    self.policy_model = model

    # Define the collator if not provided
    if data_collator is None:
        data_collator = DataCollatorWithPadding(self.processing_class)

    # Handle stop token settings: update policy model's generation_config to use provided stop token
    # `stop_token_id` is compared against None so that a token id of 0 still counts as "set".
    if args.stop_token and args.stop_token_id is not None:
        raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
    elif args.stop_token:
        if args.stop_token == "eos":
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
        else:
            raise ValueError(
                f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
            )
    else:
        self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

    # Check that the kl estimator is valid
    if self.args.kl_estimator not in {"k1", "k3"}:
        raise ValueError(
            "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
            "appears to be a strictly better estimator). See "
            "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
        )

    # peft support
    if not is_peft_available() and peft_config is not None:
        raise ImportError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(self.policy_model, PeftModel):
            self.policy_model = self.policy_model.merge_and_unload()

        # get peft model with the given config
        self.policy_model = get_peft_model(self.policy_model, peft_config)
        if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(self.policy_model)

    self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
    self.model_adapter_name = args.model_adapter_name
    self.ref_adapter_name = args.ref_adapter_name

    # Identity check rather than truthiness: a module that defines `__len__`
    # (e.g. an empty container) would otherwise be treated as "not provided".
    if ref_model is not None:
        self.ref_model = ref_model
    elif self.is_peft_model:
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(self.policy_model)

    self.reward_model = reward_model
    self.train_dataset = train_dataset
    self.train_dataset_len = len(train_dataset)
    self.value_model = value_model
    self.data_collator = data_collator
    self.eval_dataset = eval_dataset
    self.optimizer, self.lr_scheduler = optimizers
    self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

    #########
    # calculate various batch sizes
    #########
    if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
        args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    self.accelerator = accelerator
    args.world_size = accelerator.num_processes
    args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
    args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
    args.batch_size = int(args.local_batch_size * args.world_size)
    args.mini_batch_size = exact_div(
        args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
    )
    args.local_mini_batch_size = exact_div(
        args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
    )
    if args.whiten_rewards:
        assert args.local_mini_batch_size >= 8, (
            f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
        )
    # `per_rank_rollout_batch_size` is our `args.local_batch_size`
    # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
    args.num_total_batches = math.ceil(
        args.total_episodes / args.batch_size
    )  # we may train for more than `total_episodes`
    time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
    time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
    args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
    self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
    if args.num_sample_generations > 0:
        self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
    self.local_dataloader_batch_size = args.local_batch_size

    #########
    # setup model, optimizer, and others
    #########
    for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
        if module is not None:
            disable_dropout_in_model(module)
    self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
    self.model.config = self.policy_model.config  # needed for pushing to hub
    self.create_optimizer_and_scheduler(
        num_training_steps=args.num_total_batches
    )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

    #########
    ### trainer specifics
    #########
    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
    self.control = TrainerControl()
    self.state = OnlineTrainerState(
        is_local_process_zero=self.is_local_process_zero(),
        is_world_process_zero=self.is_world_process_zero(),
        stateful_callbacks=[
            cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
        ],
    )
    self.current_flos = 0
    self.hp_search_backend = None
    self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
    # Create distant repo and output directory if needed
    self.hub_model_id = None
    if self.args.push_to_hub:
        self.init_hf_repo()
    if self.args.should_save:
        os.makedirs(self.args.output_dir, exist_ok=True)

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    #########
    ### setup dataloader
    #########
    self.dataloader = DataLoader(
        self.train_dataset,
        batch_size=self.local_dataloader_batch_size,
        shuffle=True,
        collate_fn=self.data_collator,
        drop_last=True,  # needed; otherwise the last batch will be of ragged shape
    )
    # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
    # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
    torch.manual_seed(args.seed)
    self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
    torch.manual_seed(self.local_seed)  # reset the local seed again

    self.eval_dataloader = DataLoader(
        self.eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=self.data_collator,
        drop_last=True,
    )  # no need to shuffle eval dataset
    self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

    if self.is_deepspeed_enabled:
        self.reward_model = prepare_deepspeed(
            self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
        )

        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = prepare_deepspeed(
                self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )
    else:
        if self.ref_model is None:
            if not self.is_peft_model:
                raise ValueError("No reference model and model is not a Peft model.")
        else:
            self.ref_model = self.ref_model.to(self.accelerator.device)
        self.reward_model = self.reward_model.to(self.accelerator.device)
def get_train_dataloader(self) -> DataLoader:
    """Return the training dataloader held in `self.dataloader`."""
    train_loader = self.dataloader
    return train_loader
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader held in `self.eval_dataloader`."""
    eval_loader = self.eval_dataloader
    return eval_loader
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    While the context is active:
      * if the policy is a PEFT model and no `ref_adapter_name` is set, the
        unwrapped policy's adapter is disabled via `disable_adapter()`;
      * if `ref_adapter_name` is set, the policy is switched to that adapter.

    On exit the policy is switched back to `model_adapter_name` (or
    `"default"` when that is unset). The switch back runs in a `finally`
    block so it also happens when the body of the `with` raises.
    """
    with (
        self.accelerator.unwrap_model(self.model.policy).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            self.model.policy.set_adapter(self.ref_adapter_name)
        try:
            yield
        finally:
            # Always restore the training adapter, even on error, so the
            # policy is never left pointing at the reference adapter.
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save only the policy part of the wrapped model.

    `self.model` (and, when DeepSpeed is enabled, `self.deepspeed`) is
    temporarily pointed at `self.model.policy` while the parent class's
    `save_model` runs, then restored. Restoration happens in a `finally`
    block so a failed save does not leave the trainer holding the bare policy.

    Args:
        output_dir: Passed through to the parent `save_model`.
        _internal_call: Passed through to the parent `save_model`.
    """
    backup_model = self.model
    self.model = self.model.policy  # save only the policy

    deepspeed_swapped = False
    backup_deepspeed = None
    try:
        if self.is_deepspeed_enabled:
            backup_deepspeed = self.deepspeed
            self.deepspeed = self.model
            deepspeed_swapped = True

        super().save_model(output_dir, _internal_call)
    finally:
        self.model = backup_model
        if deepspeed_swapped:
            self.deepspeed = backup_deepspeed
def train(self):
    """Run the PPO training loop for ``args.num_total_batches`` updates.

    Each update:

    1. draws a batch of queries and generates responses with the policy;
    2. computes policy and reference log-probs, value estimates and reward-model
       scores in chunks of ``args.local_rollout_forward_batch_size``;
    3. builds per-token rewards (KL penalty plus the score at the last response
       position) and GAE advantages/returns;
    4. runs ``args.num_ppo_epochs`` epochs of clipped PPO updates over shuffled
       mini-batches split into micro-batches;
    5. logs metrics, steps the LR scheduler and fires the trainer callbacks
       (including checkpointing when ``self.control.should_save`` is set).

    Takes no arguments and returns ``None``; progress is tracked on
    ``self.state`` and ``self.control``.
    """
    args = self.args
    accelerator = self.accelerator
    optimizer = self.optimizer
    model = self.model
    ref_policy = self.ref_model
    reward_model = self.reward_model
    processing_class = self.processing_class
    dataloader = self.dataloader
    device = accelerator.device

    def repeat_generator():
        # Cycle over the dataloader indefinitely so `next()` never exhausts it.
        while True:
            yield from dataloader

    iter_dataloader = iter(repeat_generator())
    generation_config = GenerationConfig(
        max_new_tokens=args.response_length,
        temperature=(args.temperature + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    accelerator.print("===training policy===")
    start_time = time.time()
    # One slot per (ppo epoch, mini-batch, gradient-accumulation step).
    stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
    approxkl_stats = torch.zeros(stats_shape, device=device)
    pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
    pg_loss_stats = torch.zeros(stats_shape, device=device)
    vf_loss_stats = torch.zeros(stats_shape, device=device)
    vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
    entropy_stats = torch.zeros(stats_shape, device=device)
    ratio_stats = torch.zeros(stats_shape, device=device)
    model.train()

    # trainer state initialization
    self.state.global_step = 0
    self.state.episode = 0
    self.state.max_steps = args.num_total_batches
    self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model
        self.model_wrapped = self.model

    for update in range(1, args.num_total_batches + 1):
        self.state.episode += 1 * args.batch_size
        data = next(iter_dataloader)
        with torch.no_grad():
            queries = data["input_ids"].to(device)
            context_length = queries.shape[1]
            responses = []
            postprocessed_responses = []
            logprobs = []
            ref_logprobs = []
            scores = []
            sequence_lengths = []
            values = []
            with unwrap_model_for_generation(
                self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ) as unwrapped_model:
                query_responses, logitss = batch_generation(
                    unwrapped_model.policy,
                    queries,
                    args.local_rollout_forward_batch_size,
                    processing_class.pad_token_id,
                    generation_config,
                )

            for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                query = queries[i : i + args.local_rollout_forward_batch_size]
                query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                response = query_response[:, context_length:]
                logits = logitss[i : i + args.local_rollout_forward_batch_size]
                logprob = selective_log_softmax(logits, response)
                del logits
                empty_cache()

                if ref_policy is None:
                    # No separate reference model: reuse the policy inside the
                    # null-reference context.
                    with self.null_ref_context():
                        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                else:
                    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                ref_logits = ref_output.logits[:, context_length - 1 : -1]
                ref_logits /= args.temperature + 1e-7
                ref_logprob = selective_log_softmax(ref_logits, response)
                del ref_output, ref_logits
                empty_cache()

                # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                postprocessed_response = response
                if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                # Response Processing 2. run reward model on the truncated responses
                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                unwrapped_value_model = accelerator.unwrap_model(model).value_model
                full_value, _, _ = get_reward(
                    unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                )
                value = full_value[:, context_length - 1 : -1].squeeze(-1)
                _, score, _ = get_reward(
                    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )

                responses.append(response)
                postprocessed_responses.append(postprocessed_response)
                logprobs.append(logprob)
                ref_logprobs.append(ref_logprob)
                sequence_lengths.append(sequence_length)
                scores.append(score)
                values.append(value)
            responses = torch.cat(responses, 0)
            postprocessed_responses = torch.cat(postprocessed_responses, 0)
            logprobs = torch.cat(logprobs, 0)
            ref_logprobs = torch.cat(ref_logprobs, 0)
            sequence_lengths = torch.cat(sequence_lengths, 0)
            scores = torch.cat(scores, 0)
            values = torch.cat(values, 0)
            del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
            empty_cache()
            gc.collect()

            # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
            # Completions not passing that filter will receive a lower score.
            contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
            if self.args.missing_eos_penalty is not None:
                scores[~contain_eos_token] -= self.args.missing_eos_penalty
            # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

            # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
            response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
            padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
            logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
            ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
            sequence_lengths_p1 = sequence_lengths + 1
            padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
            values = torch.masked_fill(values, padding_mask_p1, 0)

            # 4. compute rewards
            # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
            logr = ref_logprobs - logprobs
            kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
            non_score_reward = -args.kl_coef * kl
            rewards = non_score_reward.clone()
            actual_start = torch.arange(rewards.size(0), device=rewards.device)
            actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
            # Tuple indexing (row, column): add each row's score at its end
            # position. A list of index tensors is deprecated multidimensional
            # indexing in PyTorch, so the pair is passed as a tuple instead.
            rewards[actual_start, actual_end] += scores

            # 5. whiten rewards
            if args.whiten_rewards:
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

            # 6. compute advantages and returns
            lastgaelam = 0
            advantages_reversed = []
            gen_length = responses.shape[1]
            for t in reversed(range(gen_length)):
                nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                lastgaelam = delta + args.gamma * args.lam * lastgaelam
                advantages_reversed.append(lastgaelam)
            advantages = torch.stack(advantages_reversed[::-1], axis=1)
            returns = advantages + values
            advantages = masked_whiten(advantages, ~padding_mask)
            advantages = torch.masked_fill(advantages, padding_mask, 0)
            empty_cache()

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(args.num_ppo_epochs):
            b_inds = np.random.permutation(args.local_batch_size)
            minibatch_idx = 0
            for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                mini_batch_end = mini_batch_start + args.local_mini_batch_size
                mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                gradient_accumulation_idx = 0
                for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                    with accelerator.accumulate(model):
                        micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                        micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                        mb_advantage = advantages[micro_batch_inds]
                        mb_responses = responses[micro_batch_inds]
                        mb_query_responses = query_responses[micro_batch_inds]
                        mb_logprobs = logprobs[micro_batch_inds]
                        mb_return = returns[micro_batch_inds]
                        mb_values = values[micro_batch_inds]

                        output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                        logits = output.logits[:, context_length - 1 : -1]
                        logits /= args.temperature + 1e-7
                        new_logprobs = selective_log_softmax(logits, mb_responses)
                        new_logprobs = torch.masked_fill(
                            new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                        )
                        vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                        vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                        # Clipped value loss.
                        vpredclipped = torch.clamp(
                            vpred,
                            mb_values - args.cliprange_value,
                            mb_values + args.cliprange_value,
                        )
                        vf_losses1 = torch.square(vpred - mb_return)
                        vf_losses2 = torch.square(vpredclipped - mb_return)
                        vf_loss_max = torch.max(vf_losses1, vf_losses2)
                        vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                        vf_clipfrac = masked_mean(
                            (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                        )
                        # Clipped policy-gradient loss.
                        logprobs_diff = new_logprobs - mb_logprobs
                        ratio = torch.exp(logprobs_diff)
                        pg_losses = -mb_advantage * ratio
                        pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                        pg_loss_max = torch.max(pg_losses, pg_losses2)
                        pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                        loss = pg_loss + args.vf_coef * vf_loss
                        accelerator.backward(loss)
                        optimizer.step()
                        optimizer.zero_grad()
                        with torch.no_grad():
                            pg_clipfrac = masked_mean(
                                (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                            )
                            # Softmax is computed in float32, then cast back to the logits dtype.
                            prob_dist = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)
                            entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                            approxkl = 0.5 * (logprobs_diff ** 2).mean()
                            approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                            pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                pg_clipfrac
                            )
                            pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                            vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                            vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                vf_clipfrac
                            )
                            entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                            ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1
                minibatch_idx += 1
                # del everything and empty cache
                # fmt: off
                del (
                    output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                    vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                    pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                    mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                )
                # fmt: on
                empty_cache()
        with torch.no_grad():
            mean_kl = kl.sum(1).mean()
            mean_entropy = (-logprobs).sum(1).mean()
            mean_non_score_reward = non_score_reward.sum(1).mean()
            rlhf_reward = mean_non_score_reward + scores.mean()
            eps = int(self.state.episode / (time.time() - start_time))
            metrics = {}
            metrics["eps"] = eps
            metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
            metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
            metrics["objective/non_score_reward"] = (
                self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
            )
            metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
            metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
            metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
            metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
            metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
            metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
            metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
            metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
            metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
            metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
            metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
            metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            metrics["episode"] = self.state.episode
            self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
            self.state.global_step += 1
            self.log(metrics)

        self.lr_scheduler.step()
        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
        del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
        empty_cache()
        gc.collect()

        if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
            self.generate_completions(sampling=True)
            empty_cache()
        del (
            query_responses,
            responses,
            postprocessed_responses,
            logprobs,
            ref_logprobs,
            values,
            sequence_lengths,
            contain_eos_token,
            sequence_lengths_p1,
            response_idxs,
            padding_mask,
            padding_mask_p1,
            rewards,
            actual_start,
            actual_end,
            advantages,
            returns,
        )
        empty_cache()

    # HF trainer specifics
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
    if self.control.should_save:
        # `_save_checkpoint` in this class takes only (model, trial); passing
        # `metrics=` here would raise TypeError, so call it as in the loop above.
        self._save_checkpoint(model, trial=None)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def generate_completions(self, sampling: bool = False):
    """Generate completions for the eval dataloader, score them, and report a table.

    For each eval batch the policy generates a response, the response is
    optionally truncated at ``self.stop_token_id``, and the reward model scores
    the query plus truncated response. Decoded queries, responses and scores are
    collected into a DataFrame which, on the main process, is printed (first
    five rows, when rich is available) and logged to wandb / comet_ml when they
    appear in ``args.report_to``.

    Args:
        sampling: When true, stop after the first eval batch.
    """
    args = self.args
    processing_class = self.processing_class
    generation_config = GenerationConfig(
        max_new_tokens=self.args.response_length,
        temperature=(0.01 + 1e-7),
        top_k=0.0,
        top_p=1.0,
        do_sample=True,
    )

    rows = defaultdict(list)
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        for eval_batch in self.eval_dataloader:
            query = eval_batch["input_ids"]
            with torch.no_grad():
                context_length = query.shape[1]
                # Generate for the whole batch in a single chunk.
                query_response, _ = batch_generation(
                    unwrapped_model.policy,
                    query,
                    query.shape[0],
                    processing_class.pad_token_id,
                    generation_config,
                )
                response = query_response[:, context_length:]

                # `is not None` keeps a stop token id of 0 working.
                if self.stop_token_id is None:
                    postprocessed_response = response
                else:
                    postprocessed_response = truncate_response(
                        self.stop_token_id, processing_class.pad_token_id, response
                    )

                decoded_queries = processing_class.batch_decode(query, skip_special_tokens=True)
                rows["query"].extend(gather_object(decoded_queries))
                decoded_responses = processing_class.batch_decode(postprocessed_response)
                rows["model response"].extend(gather_object(decoded_responses))

                postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                _, score, _ = get_reward(
                    self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                )
                gathered_scores = self.accelerator.gather_for_metrics(score)
                rows["score"].extend(gathered_scores.float().cpu().numpy())

            if sampling:
                break

    completions_df = pd.DataFrame(rows)

    if not self.accelerator.is_main_process:
        return

    if is_rich_available():
        print_rich_table(completions_df.iloc[0 : 0 + 5])

    if "wandb" in args.report_to:
        import wandb

        if wandb.run is not None:
            wandb.log({"completions": wandb.Table(dataframe=completions_df)})

    if "comet_ml" in args.report_to:
        log_table_to_comet_experiment(
            name="completions.csv",
            table=completions_df,
        )
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial, metrics=None):
    """Write the model card, then delegate checkpointing to the parent class.

    Args:
        model: Forwarded unchanged to the parent ``_save_checkpoint``.
        trial: Forwarded unchanged to the parent ``_save_checkpoint``.
        metrics: Accepted for backward compatibility with callers that still
            pass ``metrics=`` (``train`` does so after the loop); it is not
            forwarded, matching the two-argument parent call used here.

    The model name is the last path component of ``args.output_dir`` when
    ``args.hub_model_id`` is unset, otherwise the part of ``hub_model_id`` after
    the final ``/``.
    """
    hub_model_id = self.args.hub_model_id
    if hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    The card is written to ``README.md`` inside ``args.output_dir``. Nothing is
    done on processes for which ``is_world_process_zero()`` is false.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    if not self.is_world_process_zero():
        return

    # Only report a base model when `_name_or_path` is not a local directory.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # normalize `tags` to a mutable set
    if tags is None:
        tags = set()
    elif isinstance(tags, str):
        tags = {tags}
    else:
        tags = set(tags)

    if hasattr(self.model.config, "unsloth_version"):
        tags.add("unsloth")

    tags.update(self._tag_names)

    citation = textwrap.dedent("""\
        @article { mziegler2019fine-tuning,
        title = {{ Fine-Tuning Language Models from Human Preferences}},
        author = { Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
        year = 2019,
        eprint = {arXiv:1909.08593}
        }""")

    # Import wandb locally (as `generate_completions` does) so this method does
    # not depend on a module-level `wandb` binding being present.
    wandb_url = None
    if is_wandb_available():
        import wandb

        if wandb.run is not None:
            wandb_url = wandb.run.url

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="PPO",
        trainer_citation=citation,
        paper_title="Fine-Tuning Language Models from Human Preferences",
        paper_id="1909.08593",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothPPOTrainer(_UnslothPPOTrainer):
    """PPO trainer that normalises ``args`` and the data collator before delegating.

    ``__init__`` reconciles the fp16/bf16 flags on ``args`` with the model's
    dtype, adjusts evaluation settings, forces right padding on the processing
    class, swaps the data collator where the dataset columns or tokenizer
    require it, and then calls ``_UnslothPPOTrainer.__init__``. Afterwards it
    removes any NEFTune hook handle and attaches the accelerator's grad scaler
    to the model and each nested ``.model``.
    """

    def __init__(
        self,
        args,
        processing_class,
        model,
        ref_model,
        reward_model,
        train_dataset,
        value_model,
        data_collator=None,
        eval_dataset=None,
        callbacks=None,
        peft_config=None,
        **kwargs
    ):
        """Prepare ``args``/collator, then initialise the underlying PPO trainer.

        All parameters are forwarded to ``_UnslothPPOTrainer.__init__``; ``args``
        is mutated in place and ``data_collator`` may be replaced.

        Raises:
            TypeError: If the requested fp16/bf16 flag contradicts the model's
                dtype (unless ``UNSLOTH_FORCE_FLOAT32`` is ``'1'``).
        """
        if args is None:
            args = UnslothPPOConfig()

        # Only genuine booleans count; anything else is treated as False.
        use_bf16 = getattr(args, 'bf16', False)
        if not isinstance(use_bf16, bool):
            use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if not isinstance(use_fp16, bool):
            use_fp16 = False

        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')

        # Resolve the model dtype: config.dtype, then config.torch_dtype, then
        # the input embedding dtype.
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None:
            dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16

        if not force_float32 and (float16 and use_bf16):
            raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16):
            raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')

        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag was requested: pick the one matching the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # NOTE(review): this inspects `args.eval_dataset`, not the `eval_dataset`
        # parameter — confirm that is intended.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None:
                args.eval_steps = 0.1

        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth! \n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')

        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz:
                args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None:
                args.eval_accumulation_steps = ga_steps

        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if not isinstance(fp16_full_eval, bool):
            fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if not isinstance(bf16_full_eval, bool):
            bf16_full_eval = False

        # Keep the full-eval precision consistent with the training precision.
        if args.fp16 and bf16_full_eval:
            args.bf16_full_eval = False
            args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval:
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif mixed_precision_dtype == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        _output_logits = False
        if locals().get('compute_metrics', None) is not None:
            _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None:
            _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'):
                    args.max_seq_length = max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()

        # Force right padding on whichever tokenizer-like objects are present.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'):
            tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'):
                processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'):
                processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator type so it matches whether the dataset has a
            # `labels` column.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm=False,
                    mlm_probability=0.0,
                    pad_to_multiple_of=getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of=getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            if hasattr(args, 'remove_unused_columns'):
                args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'):
                args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'):
                args.dataset_kwargs = {'skip_prepare_dataset': True}

        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without `pad` but with an inner `.tokenizer`: rebuild
            # the collator around the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of=getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm=False,
                        mlm_probability=0.0,
                        pad_to_multiple_of=getattr(args, 'pad_to_multiple_of', None),
                    )

        other_metrics = []
        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ppo_trainer', other_metrics)

        super().__init__(
            args=args,
            processing_class=processing_class,
            model=model,
            ref_model=ref_model,
            reward_model=reward_model,
            train_dataset=train_dataset,
            value_model=value_model,
            data_collator=data_collator,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
            peft_config=peft_config,
            **kwargs
        )

        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'):
                del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha

        if hasattr(self, 'accelerator'):
            # Attach the accelerator's grad scaler to the model and to every
            # nested `.model` beneath it.
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler