2025-08-13 23:50:20 +00:00
"""
2025-08-28 17:57:59 +00:00
2025.8.9
2025.8.10
4.55.4
2025-08-13 23:50:20 +00:00
0.21.0
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch . nn as nn
from torch . nn import functional as F
from typing import Any , List , Optional , Tuple , Union , Dict , Set , Callable
from trl . trainer . gkd_trainer import ( Any , AutoModelForCausalLM , BaseImageProcessor , Callable , DataCollator , DataCollatorForChatML , Dataset , EvalPrediction , F , FeatureExtractionMixin , GKDConfig , GKDTrainer , GenerationConfig , Optional , PeftConfig , PreTrainedModel , PreTrainedTokenizerBase , ProcessorMixin , SFTTrainer , TrainerCallback , Union , disable_dropout_in_model , empty_cache , generate_model_card , get_comet_experiment_url , is_wandb_available , nn , os , prepare_deepspeed , random , textwrap , torch , unwrap_model_for_generation )
import os
from typing import *
from dataclasses import dataclass , field
from packaging . version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch . nn import functional as F
from transformers import DataCollatorForSeq2Seq , DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
torch_compile_options = {
" epilogue_fusion " : True ,
" max_autotune " : False ,
" shape_padding " : True ,
" trace.enabled " : False ,
" triton.cudagraphs " : False ,
}
@torch.compile ( dynamic = True , fullgraph = True , options = torch_compile_options , )
def chunked_selective_log_softmax ( logits , index ) :
# Split into 4 chunks only
chunked_logits = torch . chunk ( logits . reshape ( - 1 , logits . shape [ - 1 ] ) , chunks = 4 , dim = 0 )
chunked_index = torch . chunk ( index . reshape ( - 1 ) , chunks = 4 , dim = 0 )
all_per_token_logps = [ ]
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
for chunk_logits , chunk_index in zip ( chunked_logits , chunked_index ) :
chunk_logits = chunk_logits . to ( torch . float32 )
selected_logits = torch . gather ( chunk_logits , dim = - 1 , index = chunk_index . unsqueeze ( - 1 ) ) . squeeze ( - 1 )
logsumexp_values = torch . logsumexp ( chunk_logits , dim = - 1 )
per_token_logps = selected_logits - logsumexp_values
all_per_token_logps . append ( per_token_logps )
pass
all_per_token_logps = torch . concat ( all_per_token_logps )
all_per_token_logps = all_per_token_logps . reshape ( ( logits . shape [ 0 ] , logits . shape [ 1 ] ) )
return all_per_token_logps
@dataclass
class UnslothGKDConfig(GKDConfig):
    """
    Configuration class for [`GKDTrainer`].

    This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.

    Args:
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
            Must satisfy `0 < temperature < 10`.
        lmbda (`float`, *optional*, defaults to `0.5`):
            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
            student-generated outputs).
        beta (`float`, *optional*, defaults to `0.5`):
            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
        max_new_tokens (`int`, *optional*, defaults to `128`):
            Maximum number of tokens to generate per completion.
        teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
            trained.
        teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        seq_kd (`bool`, *optional*, defaults to `False`):
            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
            teacher-generated output).
        vllm_sampling_params (`Any` or `None`, *optional*, defaults to `None`):
            Unsloth extension: vLLM `SamplingParams`, stored on the config as given.
        unsloth_num_chunks (`int`, *optional*, defaults to `-1`):
            Unsloth extension: chunk size used to reduce memory usage. `-1` is most efficient.
        max_seq_length (`int` or `None`, *optional*, defaults to `None`):
            Unsloth extension: maximum sequence length to truncate to. Stored on the config only; it is not
            forwarded to the parent constructor.

    Constructor side effects (Unsloth defaults):
        - If `output_dir` is `None` and the save settings are left at `save_strategy='steps'`, `save_steps=500`,
          then `output_dir` becomes `'unsloth_training_checkpoints'` and `save_strategy` becomes `'no'`.
        - If `dataset_num_proc` is `None`, it is set to `max(cpu_count() + 4, 2)`.
        - If the environment variable `UNSLOTH_ENABLE_FLEX_ATTENTION` is `'1'`, flex attention is available and
          `pad_to_multiple_of` is `None`, then `pad_to_multiple_of` is set to the flex-attention block size.

    Raises:
        FloatingPointError: If `learning_rate < 1e-7`.
        OverflowError: If `learning_rate > 1`.
        ValueError: If `temperature <= 0` or `temperature >= 10`.
    """
    # Unsloth extension: vLLM SamplingParams object (opaque here).
    vllm_sampling_params : Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    # Unsloth extension: chunking knob for memory reduction; -1 selects the default.
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    # Unsloth extension: truncation length; stored only, not passed to GKDConfig.
    max_seq_length : Optional[int] = field(
        default = None,
        metadata = {'help': 'Maximum sequence length to truncate to.'},
    )

    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        hub_revision = None,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = True,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        liger_kernel_config = None,
        eval_use_gather_object = False,
        average_tokens_across_devices = True,
        model_init_kwargs = None,
        chat_template_path = None,
        dataset_text_field = 'text',
        dataset_kwargs = None,
        dataset_num_proc = None,
        eos_token = None,
        pad_token = None,
        max_length = 1024,
        packing = False,
        packing_strategy = 'bfd',
        padding_free = False,
        pad_to_multiple_of = None,
        eval_packing = None,
        completion_only_loss = None,
        assistant_only_loss = False,
        activation_offloading = False,
        temperature = 0.9,
        lmbda = 0.5,
        beta = 0.5,
        max_new_tokens = 128,
        teacher_model_name_or_path = None,
        teacher_model_init_kwargs = None,
        disable_dropout = True,
        seq_kd = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        max_seq_length = None,
        **kwargs,
    ):
        # Guard against learning rates that would make updates vanish or explode.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # No output_dir and untouched save settings: use a fixed folder and turn checkpoint saving off.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = max(cpu_count() + 4, 2)
        # Opt-in flex attention: pad batches to its block size unless the user chose a padding multiple.
        if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
            from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
            if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
                from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
                pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
        # Validate the sampling temperature. (These previously raised the undefined name `MathError`,
        # which surfaced as a NameError instead of this message.)
        if temperature <= 0:
            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
        elif temperature >= 10:
            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            hub_revision = hub_revision,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            liger_kernel_config = liger_kernel_config,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            chat_template_path = chat_template_path,
            dataset_text_field = dataset_text_field,
            dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            eos_token = eos_token,
            pad_token = pad_token,
            max_length = max_length,
            packing = packing,
            packing_strategy = packing_strategy,
            padding_free = padding_free,
            pad_to_multiple_of = pad_to_multiple_of,
            eval_packing = eval_packing,
            completion_only_loss = completion_only_loss,
            assistant_only_loss = assistant_only_loss,
            activation_offloading = activation_offloading,
            temperature = temperature,
            lmbda = lmbda,
            beta = beta,
            max_new_tokens = max_new_tokens,
            teacher_model_name_or_path = teacher_model_name_or_path,
            teacher_model_init_kwargs = teacher_model_init_kwargs,
            disable_dropout = disable_dropout,
            seq_kd = seq_kd, **kwargs)
        # Unsloth-only settings are attached after the parent constructor has run.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        self.max_seq_length = max_seq_length
pass
class _UnslothGKDTrainer(SFTTrainer):
    """
    Trainer for Generalized Knowledge Distillation (GKD), https://huggingface.co/papers/2306.13649.

    The student (`model`) is trained to match a teacher model's per-token distribution on the
    completion part of each example, using the generalized Jensen-Shannon divergence
    (`generalized_jsd_loss`).
    - With `args.seq_kd`, the completion is first regenerated by the teacher (sequence-level KD).
    - With probability `args.lmbda`, it is (re)generated by the student (on-policy).
    """

    _tag_names = ["trl", "gkd"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
        args: Optional[GKDConfig] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Callable] = None,
    ):
        """
        Args:
            model: Student model (instance or hub id), forwarded to `SFTTrainer`.
            teacher_model: Teacher model instance, or a hub id / path loaded with
                `AutoModelForCausalLM.from_pretrained(teacher_model, **args.teacher_model_init_kwargs)`.
            args: GKD configuration. `args.remove_unused_columns` is forced to `False`.
            data_collator: Ignored; always replaced by `DataCollatorForChatML`.
            Remaining arguments are forwarded unchanged to `SFTTrainer`.

        Raises:
            ValueError: If `args.teacher_model_init_kwargs` is set while `teacher_model` is already
                instantiated.
        """
        # add remove_unused_columns=False to the dataclass args
        args.remove_unused_columns = False
        # compute_loss / generation read inputs["prompts"]; presumably produced by this collator.
        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}
        elif not isinstance(teacher_model, str):
            raise ValueError(
                "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
            )
        else:
            # Work on a copy: resolving the dtype in place would mutate the user's config, and a second
            # trainer built from the same args would then call getattr(torch, <torch.dtype>) and fail.
            teacher_model_init_kwargs = dict(args.teacher_model_init_kwargs)
            # `.get` so that omitting "torch_dtype" is allowed (indexing raised KeyError).
            torch_dtype = teacher_model_init_kwargs.get("torch_dtype")
            # "auto", None and already-resolved torch.dtype values are passed through unchanged.
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                teacher_model_init_kwargs["torch_dtype"] = getattr(torch, torch_dtype)

        if isinstance(teacher_model, str):
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        # The teacher is wrapped for inference only (never optimized).
        if self.is_deepspeed_enabled:
            self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
        else:
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
        self.seq_kd = args.seq_kd

        # Sampling config for on-policy / seq-KD generation; top_k=0 disables top-k filtering.
        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=True,
            top_k=0,
            use_cache=False if args.gradient_checkpointing else True,
            pad_token_id=self.processing_class.pad_token_id,
        )
        # Set custom EOS tokens if they are specified by the model's generation
        # config. This is important for models with the Llama 3 chat template,
        # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
        # turns or messages.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
        """
        Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
        of https://huggingface.co/papers/2306.13649 for the definition.

        Args:
            student_logits:
                Tensor of shape (batch_size, sequence_length, vocab_size)
            teacher_logits:
                Tensor of shape (batch_size, sequence_length, vocab_size)
            labels:
                Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
                loss
            beta:
                Interpolation coefficient between 0 and 1 (default: 0.5). 0 gives KL(teacher || student), 1 gives
                KL(student || teacher).
            temperature:
                Softmax temperature (default: 1.0)
            reduction:
                Specifies the reduction to apply to the output (default: 'batchmean'). With labels, 'batchmean'
                averages over non-ignored tokens (0 if every token is ignored); 'sum', 'mean' or anything else
                (no reduction) apply to the masked, flattened values.

        Returns:
            loss: Scalar tensor with the generalized JSD loss (or the unreduced values for an unknown reduction)
        """
        # Apply temperature scaling
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # Compute log probabilities for student and probabilities for teacher
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        if beta == 0:
            jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
            jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # Compute the log of the mixture distribution
            # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
            beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
                dim=0,
            )

            # Compute KL divergences using F.kl_div
            # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            # Compute the Generalized Jensen-Shannon Divergence
            jsd = beta * kl_teacher + (1 - beta) * kl_student

        # Masking: boolean indexing keeps only non-ignored positions (flattens the batch/seq dims).
        if labels is not None:
            mask = labels != -100
            jsd = jsd[mask]

        # Apply reduction
        if reduction == "batchmean":
            if labels is not None:
                # Clamp so a batch with every label ignored yields 0 instead of 0/0 = NaN.
                return jsd.sum() / mask.sum().clamp(min=1)
            return jsd.sum() / (jsd.size(0) * jsd.size(1))
        elif reduction == "sum":
            return jsd.sum()
        elif reduction == "mean":
            return jsd.mean()
        else:
            return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the generalized JSD between student and teacher on the completion tokens.

        `inputs` must provide "input_ids", "attention_mask", "labels" and "prompts"; every position after the
        (padded) prompt width `inputs["prompts"].shape[1]` is treated as completion. `num_items_in_batch` is
        accepted for Trainer compatibility and not used.
        """
        # compute student output
        outputs_student = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # compute teacher output in eval mode
        self.teacher_model.eval()
        with torch.no_grad():
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        # slice the logits for the generated tokens using the inputs["prompts"] lengths;
        # logits at position t predict token t + 1, hence the one-step shift relative to labels.
        prompt_lengths = inputs["prompts"].shape[1]
        shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
        shifted_labels = inputs["labels"][:, prompt_lengths:]

        # compute loss
        # NOTE(review): self.temperature is not forwarded, so the loss uses temperature=1.0 — confirm intended.
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
            teacher_logits=shifted_teacher_logits,
            labels=shifted_labels,
            beta=self.beta,
        )

        # empty cache
        empty_cache()

        # Return loss
        return (loss, outputs_student) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        """Generate completions for `inputs["prompts"]` and build matching attention mask and labels.

        Returns `(generated_tokens, attention_mask, labels)`; `generated_tokens` are the full sequences returned
        by `model.generate` (prompt + completion). Positions equal to `pad_token_id` get attention 0 and label -100.
        """
        # Generate output with respect to the prompt-only
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        # Get the generated token IDs
        generated_tokens = generated_outputs.sequences
        # Calculate new attention mask
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        # If there's pad_token_id, set attention mask to 0 for padding tokens
        # NOTE(review): if pad_token_id equals the EOS id, EOS tokens are masked out as well — confirm intended.
        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD paper. With probability
        `self.lmbda`, it generates new responses using the student model, which are then used for training instead of
        the original inputs. If `self.seq_kd` is set, responses are first regenerated by the teacher.
        `inputs` is updated in place before delegating to the parent `training_step`.
        """
        if self.seq_kd:
            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels
        # random.random() lies in [0.0, 1.0), so `<` gives on-policy generation with probability exactly
        # `lmbda` (with `<=`, lmbda=0 could still trigger generation when random() returns 0.0).
        if random.random() < self.lmbda:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Only the main process writes the card, to `<output_dir>/README.md`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        # Only reference a hub id, not a local checkpoint directory.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @inproceedings{agarwal2024on-policy,
            title        = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
            author       = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=3zKtaqxLhW},
        }""")

        # `wandb` is not imported at module level in this file; import it lazily so the
        # lookup below does not raise NameError when wandb is installed.
        wandb_url = None
        if is_wandb_available():
            import wandb

            if wandb.run is not None:
                wandb_url = wandb.run.url

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb_url,
            comet_url=get_comet_experiment_url(),
            trainer_name="GKD",
            trainer_citation=citation,
            paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
            paper_id="2306.13649",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothGKDTrainer(_UnslothGKDTrainer):
    """
    Unsloth entry point for GKD training.

    Before delegating to `_UnslothGKDTrainer.__init__`, this constructor adjusts the config and
    the process environment:
    - mixed-precision flags (fp16/bf16 and their *_full_eval variants);
    - eval batch size and eval accumulation steps;
    - `max_seq_length` taken from the model;
    - tokenizer padding side;
    - data collator choice.

    Extra keyword arguments are forwarded unchanged to the parent constructor.

    Environment variables read: UNSLOTH_FORCE_FLOAT32, UNSLOTH_MIXED_PRECISION.
    Environment variables written: ACCELERATE_MIXED_PRECISION, and UNSLOTH_RETURN_LOGITS when metric
    callbacks are supplied.

    Raises:
        TypeError: If the model's dtype conflicts with the requested fp16/bf16 setting (unless
            UNSLOTH_FORCE_FLOAT32 is '1').
    """
    def __init__(
        self,
        model = None,
        teacher_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        # Fall back to a default Unsloth config.
        if args is None: args = UnslothGKDConfig()
        # Non-bool precision flags are treated as "not requested".
        use_bf16 = getattr(args, 'bf16', False)
        if type(use_bf16) is not bool: use_bf16 = False
        use_fp16 = getattr(args, 'fp16', False)
        if type(use_fp16) is not bool: use_fp16 = False
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Model dtype: config.dtype, then config.torch_dtype, then the input embeddings' `dtype` attribute.
        # NOTE(review): a plain nn.Embedding has no `.dtype` attribute — confirm the fallback is reachable safely.
        dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject precision flags that contradict the model's dtype.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # No flag requested: pick the mixed-precision mode from the model's dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # NOTE(review): this reads `args.eval_dataset`, not the `eval_dataset` argument; if the config has no
        # such attribute this auto-enable-eval branch never runs — confirm intent.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Warn on old transformers versions when gradient accumulation is used.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            # If the eval batch size is 8 and the train batch size is smaller, use the train batch size.
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        # Align full-eval precision flags with the training precision.
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        if type(fp16_full_eval) is not bool: fp16_full_eval = False
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if type(bf16_full_eval) is not bool: bf16_full_eval = False
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metric callbacks need logits, so ask Unsloth to return them.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # `max_seq_length` is not yet bound at this point, so `locals()` only lacks it; the branch is driven by
        # whether `args` has the attribute. Copies model.max_seq_length to args when args leaves it None.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right-pad for training (tokenizer, processor, or a processor's inner tokenizer).
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        # NOTE: `__tokenizer` is name-mangled inside the class body (to `_UnslothGKDTrainer__tokenizer`).
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        # NOTE(review): _UnslothGKDTrainer.__init__ replaces data_collator with DataCollatorForChatML, so the
        # collator selection below may have no effect for GKD — confirm.
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap collator type depending on whether the dataset already has a 'labels' column.
            # NOTE(review): `train_dataset.column_names` fails if train_dataset is None on these paths.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = TransformersDataCollatorForLanguageModeling(
                    __tokenizer,
                    mlm = False,
                    mlm_probability = 0.0,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(
                    __tokenizer,
                    pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                )
        else:
            # Vision collator: keep all columns and skip text dataset preparation.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # A processor without `.pad` but with an inner `.tokenizer`: build the collator on the inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(
                        __tokenizer.tokenizer,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
                else:
                    data_collator = TransformersDataCollatorForLanguageModeling(
                        __tokenizer.tokenizer,
                        mlm = False,
                        mlm_probability = 0.0,
                        pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
                    )
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('gkd_trainer', other_metrics)

        super().__init__(
            model = model,
            teacher_model = teacher_model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func, **kwargs)
        # Remove any NEFTune hook the parent registered and instead store the noise alpha on the input
        # embedding module (presumably consumed by Unsloth's own embedding patch — confirm).
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
        # Expose the accelerator's grad scaler on every nested `.model` wrapper, down to the innermost model.
        if hasattr(self, 'accelerator'):
            scaler = self.accelerator.scaler
            current_model = model
            while hasattr(current_model, 'model'):
                current_model.accelerator_scaler = scaler
                current_model = current_model.model
            current_model.accelerator_scaler = scaler
            pass
        pass
pass