updated style mimciking fine tuning
This commit is contained in:
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, Path, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warnings)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothAlignPropConfig(AlignPropConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`AlignPropTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
||||
Name of this experiment (defaults to the file name without the extension).
|
||||
run_name (`str`, *optional*, defaults to `""`):
|
||||
Name of this run.
|
||||
seed (`int`, *optional*, defaults to `0`):
|
||||
Random seed for reproducibility.
|
||||
log_with (`str` or `None`, *optional*, defaults to `None`):
|
||||
Log with either `"wandb"` or `"tensorboard"`. Check
|
||||
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
||||
log_image_freq (`int`, *optional*, defaults to `1`):
|
||||
Frequency for logging images.
|
||||
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the tracker (e.g., `wandb_project`).
|
||||
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator.
|
||||
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
||||
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
||||
Name of project to use for tracking.
|
||||
logdir (`str`, *optional*, defaults to `"logs"`):
|
||||
Top-level logging directory for checkpoint saving.
|
||||
num_epochs (`int`, *optional*, defaults to `100`):
|
||||
Number of epochs to train.
|
||||
save_freq (`int`, *optional*, defaults to `1`):
|
||||
Number of epochs between saving model checkpoints.
|
||||
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
||||
Number of checkpoints to keep before overwriting old ones.
|
||||
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
||||
Mixed precision training.
|
||||
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
||||
Allow `tf32` on Ampere GPUs.
|
||||
resume_from (`str`, *optional*, defaults to `""`):
|
||||
Path to resume training from a checkpoint.
|
||||
sample_num_steps (`int`, *optional*, defaults to `50`):
|
||||
Number of sampler inference steps.
|
||||
sample_eta (`float`, *optional*, defaults to `1.0`):
|
||||
Eta parameter for the DDIM sampler.
|
||||
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Classifier-free guidance weight.
|
||||
train_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size for training.
|
||||
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
||||
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
||||
Learning rate.
|
||||
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
||||
Beta1 for Adam optimizer.
|
||||
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
||||
Beta2 for Adam optimizer.
|
||||
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
||||
Weight decay for Adam optimizer.
|
||||
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
||||
Epsilon value for Adam optimizer.
|
||||
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
||||
Number of gradient accumulation steps.
|
||||
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
||||
Maximum gradient norm for gradient clipping.
|
||||
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
||||
Comma-separated list of prompts to use as negative examples.
|
||||
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, randomized truncation to different diffusion timesteps is used.
|
||||
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
||||
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
||||
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
||||
Range of diffusion timesteps for randomized truncated backpropagation.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the final model to the Hub.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
exp_name = 'inference',
|
||||
run_name = '',
|
||||
seed = 3407,
|
||||
log_with = None,
|
||||
log_image_freq = 1,
|
||||
tracker_project_name = 'trl',
|
||||
logdir = 'logs',
|
||||
num_epochs = 100,
|
||||
save_freq = 1,
|
||||
num_checkpoint_limit = 5,
|
||||
mixed_precision = 'fp16',
|
||||
allow_tf32 = True,
|
||||
resume_from = '',
|
||||
sample_num_steps = 50,
|
||||
sample_eta = 1.0,
|
||||
sample_guidance_scale = 5.0,
|
||||
train_batch_size = 1,
|
||||
train_use_8bit_adam = False,
|
||||
train_learning_rate = 5e-05,
|
||||
train_adam_beta1 = 0.9,
|
||||
train_adam_beta2 = 0.999,
|
||||
train_adam_weight_decay = 0.01,
|
||||
train_adam_epsilon = 1e-08,
|
||||
train_gradient_accumulation_steps = 2,
|
||||
train_max_grad_norm = 1.0,
|
||||
negative_prompts = None,
|
||||
truncated_backprop_rand = True,
|
||||
truncated_backprop_timestep = 49,
|
||||
push_to_hub = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
exp_name = exp_name,
|
||||
run_name = run_name,
|
||||
seed = seed,
|
||||
log_with = log_with,
|
||||
log_image_freq = log_image_freq,
|
||||
tracker_project_name = tracker_project_name,
|
||||
logdir = logdir,
|
||||
num_epochs = num_epochs,
|
||||
save_freq = save_freq,
|
||||
num_checkpoint_limit = num_checkpoint_limit,
|
||||
mixed_precision = mixed_precision,
|
||||
allow_tf32 = allow_tf32,
|
||||
resume_from = resume_from,
|
||||
sample_num_steps = sample_num_steps,
|
||||
sample_eta = sample_eta,
|
||||
sample_guidance_scale = sample_guidance_scale,
|
||||
train_batch_size = train_batch_size,
|
||||
train_use_8bit_adam = train_use_8bit_adam,
|
||||
train_learning_rate = train_learning_rate,
|
||||
train_adam_beta1 = train_adam_beta1,
|
||||
train_adam_beta2 = train_adam_beta2,
|
||||
train_adam_weight_decay = train_adam_weight_decay,
|
||||
train_adam_epsilon = train_adam_epsilon,
|
||||
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
||||
train_max_grad_norm = train_max_grad_norm,
|
||||
negative_prompts = negative_prompts,
|
||||
truncated_backprop_rand = truncated_backprop_rand,
|
||||
truncated_backprop_timestep = truncated_backprop_timestep,
|
||||
push_to_hub = push_to_hub,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "alignprop"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AlignPropConfig,
|
||||
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
||||
prompt_function: Callable[[], tuple[str, Any]],
|
||||
sd_pipeline: DDPOStableDiffusionPipeline,
|
||||
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
||||
):
|
||||
warnings.warn(
|
||||
"AlignPropTrainer is deprecated and will be removed in version 0.23.0.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if image_samples_hook is None:
|
||||
warnings.warn("No image_samples_hook provided; no images will be logged")
|
||||
|
||||
self.prompt_fn = prompt_function
|
||||
self.reward_fn = reward_function
|
||||
self.config = config
|
||||
self.image_samples_callback = image_samples_hook
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
||||
|
||||
if self.config.resume_from:
|
||||
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(
|
||||
filter(
|
||||
lambda x: "checkpoint_" in x,
|
||||
os.listdir(self.config.resume_from),
|
||||
)
|
||||
)
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
||||
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
||||
self.config.resume_from = os.path.join(
|
||||
self.config.resume_from,
|
||||
f"checkpoint_{checkpoint_numbers[-1]}",
|
||||
)
|
||||
|
||||
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=self.config.log_with,
|
||||
mixed_precision=self.config.mixed_precision,
|
||||
project_config=accelerator_project_config,
|
||||
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
||||
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
||||
# the total number of optimizer steps to accumulate across.
|
||||
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
||||
**self.config.accelerator_kwargs,
|
||||
)
|
||||
|
||||
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.accelerator.init_trackers(
|
||||
self.config.tracker_project_name,
|
||||
config=dict(alignprop_trainer_config=config.to_dict())
|
||||
if not is_using_tensorboard
|
||||
else config.to_dict(),
|
||||
init_kwargs=self.config.tracker_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
set_seed(self.config.seed, device_specific=True)
|
||||
|
||||
self.sd_pipeline = sd_pipeline
|
||||
|
||||
self.sd_pipeline.set_progress_bar_config(
|
||||
position=1,
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
leave=False,
|
||||
desc="Timestep",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
if self.accelerator.mixed_precision == "fp16":
|
||||
inference_dtype = torch.float16
|
||||
elif self.accelerator.mixed_precision == "bf16":
|
||||
inference_dtype = torch.bfloat16
|
||||
else:
|
||||
inference_dtype = torch.float32
|
||||
|
||||
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
||||
|
||||
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if self.config.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
self.optimizer = self._setup_optimizer(
|
||||
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
||||
)
|
||||
|
||||
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
||||
self.sd_pipeline.tokenizer(
|
||||
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
|
||||
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
||||
# more memory
|
||||
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
||||
|
||||
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
||||
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
else:
|
||||
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
self.accelerator.load_state(config.resume_from)
|
||||
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
self.first_epoch = 0
|
||||
|
||||
def compute_rewards(self, prompt_image_pairs):
|
||||
reward, reward_metadata = self.reward_fn(
|
||||
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
||||
)
|
||||
return reward
|
||||
|
||||
def step(self, epoch: int, global_step: int):
|
||||
"""
|
||||
Perform a single step of training.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch.
|
||||
global_step (int): The current global step.
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
|
||||
and the accelerator tracker.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step.
|
||||
"""
|
||||
info = defaultdict(list)
|
||||
|
||||
self.sd_pipeline.unet.train()
|
||||
|
||||
for _ in range(self.config.train_gradient_accumulation_steps):
|
||||
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
||||
prompt_image_pairs = self._generate_samples(
|
||||
batch_size=self.config.train_batch_size,
|
||||
)
|
||||
|
||||
rewards = self.compute_rewards(prompt_image_pairs)
|
||||
|
||||
prompt_image_pairs["rewards"] = rewards
|
||||
|
||||
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
||||
|
||||
loss = self.calculate_loss(rewards)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if self.accelerator.sync_gradients:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.trainable_layers.parameters()
|
||||
if not isinstance(self.trainable_layers, list)
|
||||
else self.trainable_layers,
|
||||
self.config.train_max_grad_norm,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info["reward_mean"].append(rewards_vis.mean())
|
||||
info["reward_std"].append(rewards_vis.std())
|
||||
info["loss"].append(loss.item())
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if self.accelerator.sync_gradients:
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
||||
info = self.accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch})
|
||||
self.accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
||||
)
|
||||
# Logs generated images
|
||||
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
||||
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
||||
|
||||
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
||||
self.accelerator.save_state()
|
||||
|
||||
return global_step
|
||||
|
||||
def calculate_loss(self, rewards):
|
||||
"""
|
||||
Calculate the loss for a batch of an unpacked sample
|
||||
|
||||
Args:
|
||||
rewards (torch.Tensor):
|
||||
Differentiable reward scalars for each generated image, shape: [batch_size]
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor) (all of these are of shape (1,))
|
||||
"""
|
||||
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
||||
loss = 10.0 - (rewards).mean()
|
||||
return loss
|
||||
|
||||
def loss(
|
||||
self,
|
||||
advantages: torch.Tensor,
|
||||
clip_range: float,
|
||||
ratio: torch.Tensor,
|
||||
):
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(
|
||||
ratio,
|
||||
1.0 - clip_range,
|
||||
1.0 + clip_range,
|
||||
)
|
||||
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
||||
|
||||
def _setup_optimizer(self, trainable_layers_parameters):
|
||||
if self.config.train_use_8bit_adam:
|
||||
import bitsandbytes
|
||||
|
||||
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
return optimizer_cls(
|
||||
trainable_layers_parameters,
|
||||
lr=self.config.train_learning_rate,
|
||||
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
||||
weight_decay=self.config.train_adam_weight_decay,
|
||||
eps=self.config.train_adam_epsilon,
|
||||
)
|
||||
|
||||
def _save_model_hook(self, models, weights, output_dir):
|
||||
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def _load_model_hook(self, models, input_dir):
|
||||
self.sd_pipeline.load_checkpoint(models, input_dir)
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
||||
"""
|
||||
Generate samples from the model
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size to use for sampling
|
||||
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
||||
|
||||
Returns:
|
||||
prompt_image_pairs (dict[Any])
|
||||
"""
|
||||
prompt_image_pairs = {}
|
||||
|
||||
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
||||
|
||||
if prompts is None:
|
||||
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
||||
else:
|
||||
prompt_metadata = [{} for _ in range(batch_size)]
|
||||
|
||||
prompt_ids = self.sd_pipeline.tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
|
||||
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
||||
|
||||
if with_grad:
|
||||
sd_output = self.sd_pipeline.rgb_with_grad(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
||||
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
||||
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
||||
output_type="pt",
|
||||
)
|
||||
else:
|
||||
sd_output = self.sd_pipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
images = sd_output.images
|
||||
|
||||
prompt_image_pairs["images"] = images
|
||||
prompt_image_pairs["prompts"] = prompts
|
||||
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
||||
|
||||
return prompt_image_pairs
|
||||
|
||||
def train(self, epochs: Optional[int] = None):
|
||||
"""
|
||||
Train the model for a given number of epochs
|
||||
"""
|
||||
global_step = 0
|
||||
if epochs is None:
|
||||
epochs = self.config.num_epochs
|
||||
for epoch in range(self.first_epoch, epochs):
|
||||
global_step = self.step(epoch, global_step)
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
self.sd_pipeline.save_pretrained(save_directory)
|
||||
self.create_model_card()
|
||||
|
||||
# Ensure the model card is saved along with the checkpoint
|
||||
def _save_checkpoint(self, model, trial):
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name)
|
||||
super()._save_checkpoint(model, trial)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@article{prabhudesai2024aligning,
|
||||
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
||||
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
||||
year = 2024,
|
||||
eprint = {arXiv:2310.03739}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="AlignProp",
|
||||
trainer_citation=citation,
|
||||
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
||||
paper_id="2310.03739",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
||||
"""
|
||||
|
||||
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is
|
||||
heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/ As of now only Stable Diffusion based
|
||||
pipelines are supported
|
||||
|
||||
Attributes:
|
||||
config (`AlignPropConfig`):
|
||||
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
||||
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
||||
Reward function to be used
|
||||
prompt_function (`Callable[[], tuple[str, Any]]`):
|
||||
Function to generate prompts to guide model
|
||||
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
||||
Stable Diffusion pipeline to be used for training.
|
||||
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
||||
Hook to be called to log images
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
reward_function,
|
||||
prompt_function,
|
||||
sd_pipeline,
|
||||
image_samples_hook = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothAlignPropConfig()
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('alignprop_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
config = config,
|
||||
reward_function = reward_function,
|
||||
prompt_function = prompt_function,
|
||||
sd_pipeline = sd_pipeline,
|
||||
image_samples_hook = image_samples_hook,**kwargs)
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,901 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, Path, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warnings)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothDDPOConfig(DDPOConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`DDPOTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
||||
Name of this experiment (by default is the file name without the extension name).
|
||||
run_name (`str`, *optional*, defaults to `""`):
|
||||
Name of this run.
|
||||
seed (`int`, *optional*, defaults to `0`):
|
||||
Random seed.
|
||||
log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
|
||||
Log with either 'wandb' or 'tensorboard', check
|
||||
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
|
||||
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the tracker (e.g. wandb_project).
|
||||
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator.
|
||||
project_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
||||
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
||||
Name of project to use for tracking.
|
||||
logdir (`str`, *optional*, defaults to `"logs"`):
|
||||
Top-level logging directory for checkpoint saving.
|
||||
num_epochs (`int`, *optional*, defaults to `100`):
|
||||
Number of epochs to train.
|
||||
save_freq (`int`, *optional*, defaults to `1`):
|
||||
Number of epochs between saving model checkpoints.
|
||||
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
||||
Number of checkpoints to keep before overwriting old ones.
|
||||
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
||||
Mixed precision training.
|
||||
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
||||
Allow `tf32` on Ampere GPUs.
|
||||
resume_from (`str`, *optional*, defaults to `""`):
|
||||
Resume training from a checkpoint.
|
||||
sample_num_steps (`int`, *optional*, defaults to `50`):
|
||||
Number of sampler inference steps.
|
||||
sample_eta (`float`, *optional*, defaults to `1.0`):
|
||||
Eta parameter for the DDIM sampler.
|
||||
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Classifier-free guidance weight.
|
||||
sample_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size (per GPU) to use for sampling.
|
||||
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
|
||||
Number of batches to sample per epoch.
|
||||
train_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size (per GPU) to use for training.
|
||||
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
||||
Use 8bit Adam optimizer from bitsandbytes.
|
||||
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
|
||||
Learning rate.
|
||||
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
||||
Adam beta1.
|
||||
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
||||
Adam beta2.
|
||||
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
||||
Adam weight decay.
|
||||
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
||||
Adam epsilon.
|
||||
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
||||
Number of gradient accumulation steps.
|
||||
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
||||
Maximum gradient norm for gradient clipping.
|
||||
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
|
||||
Number of inner epochs per outer epoch.
|
||||
train_cfg (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier-free guidance during training.
|
||||
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
|
||||
Clip advantages to the range.
|
||||
train_clip_range (`float`, *optional*, defaults to `1e-4`):
|
||||
PPO clip range.
|
||||
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
|
||||
Fraction of timesteps to train on.
|
||||
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
|
||||
Whether to track statistics for each prompt separately.
|
||||
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
|
||||
Number of reward values to store in the buffer for each prompt.
|
||||
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
|
||||
Minimum number of reward values to store in the buffer.
|
||||
async_reward_computation (`bool`, *optional*, defaults to `False`):
|
||||
Whether to compute rewards asynchronously.
|
||||
max_workers (`int`, *optional*, defaults to `2`):
|
||||
Maximum number of workers to use for async reward computation.
|
||||
negative_prompts (`str`, *optional*, defaults to `""`):
|
||||
Comma-separated list of prompts to use as negative examples.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the final model checkpoint to the Hub.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
exp_name = 'inference',
|
||||
run_name = '',
|
||||
seed = 3407,
|
||||
log_with = None,
|
||||
tracker_project_name = 'trl',
|
||||
logdir = 'logs',
|
||||
num_epochs = 100,
|
||||
save_freq = 1,
|
||||
num_checkpoint_limit = 5,
|
||||
mixed_precision = 'fp16',
|
||||
allow_tf32 = True,
|
||||
resume_from = '',
|
||||
sample_num_steps = 50,
|
||||
sample_eta = 1.0,
|
||||
sample_guidance_scale = 5.0,
|
||||
sample_batch_size = 1,
|
||||
sample_num_batches_per_epoch = 2,
|
||||
train_batch_size = 1,
|
||||
train_use_8bit_adam = False,
|
||||
train_learning_rate = 5e-05,
|
||||
train_adam_beta1 = 0.9,
|
||||
train_adam_beta2 = 0.999,
|
||||
train_adam_weight_decay = 0.01,
|
||||
train_adam_epsilon = 1e-08,
|
||||
train_gradient_accumulation_steps = 2,
|
||||
train_max_grad_norm = 1.0,
|
||||
train_num_inner_epochs = 1,
|
||||
train_cfg = True,
|
||||
train_adv_clip_max = 5.0,
|
||||
train_clip_range = 0.0001,
|
||||
train_timestep_fraction = 1.0,
|
||||
per_prompt_stat_tracking = False,
|
||||
per_prompt_stat_tracking_buffer_size = 16,
|
||||
per_prompt_stat_tracking_min_count = 16,
|
||||
async_reward_computation = False,
|
||||
max_workers = 2,
|
||||
negative_prompts = '',
|
||||
push_to_hub = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
exp_name = exp_name,
|
||||
run_name = run_name,
|
||||
seed = seed,
|
||||
log_with = log_with,
|
||||
tracker_project_name = tracker_project_name,
|
||||
logdir = logdir,
|
||||
num_epochs = num_epochs,
|
||||
save_freq = save_freq,
|
||||
num_checkpoint_limit = num_checkpoint_limit,
|
||||
mixed_precision = mixed_precision,
|
||||
allow_tf32 = allow_tf32,
|
||||
resume_from = resume_from,
|
||||
sample_num_steps = sample_num_steps,
|
||||
sample_eta = sample_eta,
|
||||
sample_guidance_scale = sample_guidance_scale,
|
||||
sample_batch_size = sample_batch_size,
|
||||
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
||||
train_batch_size = train_batch_size,
|
||||
train_use_8bit_adam = train_use_8bit_adam,
|
||||
train_learning_rate = train_learning_rate,
|
||||
train_adam_beta1 = train_adam_beta1,
|
||||
train_adam_beta2 = train_adam_beta2,
|
||||
train_adam_weight_decay = train_adam_weight_decay,
|
||||
train_adam_epsilon = train_adam_epsilon,
|
||||
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
||||
train_max_grad_norm = train_max_grad_norm,
|
||||
train_num_inner_epochs = train_num_inner_epochs,
|
||||
train_cfg = train_cfg,
|
||||
train_adv_clip_max = train_adv_clip_max,
|
||||
train_clip_range = train_clip_range,
|
||||
train_timestep_fraction = train_timestep_fraction,
|
||||
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
||||
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
||||
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
||||
async_reward_computation = async_reward_computation,
|
||||
max_workers = max_workers,
|
||||
negative_prompts = negative_prompts,
|
||||
push_to_hub = push_to_hub,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "ddpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DDPOConfig,
|
||||
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
||||
prompt_function: Callable[[], tuple[str, Any]],
|
||||
sd_pipeline: DDPOStableDiffusionPipeline,
|
||||
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
||||
):
|
||||
warnings.warn(
|
||||
"DDPOTrainer is deprecated and will be removed in version 0.23.0.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if image_samples_hook is None:
|
||||
warnings.warn("No image_samples_hook provided; no images will be logged")
|
||||
|
||||
self.prompt_fn = prompt_function
|
||||
self.reward_fn = reward_function
|
||||
self.config = config
|
||||
self.image_samples_callback = image_samples_hook
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
||||
|
||||
if self.config.resume_from:
|
||||
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(
|
||||
filter(
|
||||
lambda x: "checkpoint_" in x,
|
||||
os.listdir(self.config.resume_from),
|
||||
)
|
||||
)
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
||||
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
||||
self.config.resume_from = os.path.join(
|
||||
self.config.resume_from,
|
||||
f"checkpoint_{checkpoint_numbers[-1]}",
|
||||
)
|
||||
|
||||
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
||||
|
||||
# number of timesteps within each trajectory to train on
|
||||
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=self.config.log_with,
|
||||
mixed_precision=self.config.mixed_precision,
|
||||
project_config=accelerator_project_config,
|
||||
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
||||
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
||||
# the total number of optimizer steps to accumulate across.
|
||||
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
||||
**self.config.accelerator_kwargs,
|
||||
)
|
||||
|
||||
is_okay, message = self._config_check()
|
||||
if not is_okay:
|
||||
raise ValueError(message)
|
||||
|
||||
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.accelerator.init_trackers(
|
||||
self.config.tracker_project_name,
|
||||
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
||||
init_kwargs=self.config.tracker_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
set_seed(self.config.seed, device_specific=True)
|
||||
|
||||
self.sd_pipeline = sd_pipeline
|
||||
|
||||
self.sd_pipeline.set_progress_bar_config(
|
||||
position=1,
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
leave=False,
|
||||
desc="Timestep",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
if self.accelerator.mixed_precision == "fp16":
|
||||
inference_dtype = torch.float16
|
||||
elif self.accelerator.mixed_precision == "bf16":
|
||||
inference_dtype = torch.bfloat16
|
||||
else:
|
||||
inference_dtype = torch.float32
|
||||
|
||||
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
||||
|
||||
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if self.config.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
self.optimizer = self._setup_optimizer(
|
||||
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
||||
)
|
||||
|
||||
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
||||
self.sd_pipeline.tokenizer(
|
||||
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
|
||||
if config.per_prompt_stat_tracking:
|
||||
self.stat_tracker = PerPromptStatTracker(
|
||||
config.per_prompt_stat_tracking_buffer_size,
|
||||
config.per_prompt_stat_tracking_min_count,
|
||||
)
|
||||
|
||||
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
||||
# more memory
|
||||
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
||||
|
||||
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
||||
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
else:
|
||||
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
|
||||
if self.config.async_reward_computation:
|
||||
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
self.accelerator.load_state(config.resume_from)
|
||||
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
self.first_epoch = 0
|
||||
|
||||
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
||||
if not is_async:
|
||||
rewards = []
|
||||
for images, prompts, prompt_metadata in prompt_image_pairs:
|
||||
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
||||
rewards.append(
|
||||
(
|
||||
torch.as_tensor(reward, device=self.accelerator.device),
|
||||
reward_metadata,
|
||||
)
|
||||
)
|
||||
else:
|
||||
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
||||
rewards = [
|
||||
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
||||
for reward, reward_metadata in rewards
|
||||
]
|
||||
|
||||
return zip(*rewards)
|
||||
|
||||
def step(self, epoch: int, global_step: int):
|
||||
"""
|
||||
Perform a single step of training.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch.
|
||||
global_step (int): The current global step.
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
|
||||
and the accelerator tracker.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step.
|
||||
|
||||
"""
|
||||
samples, prompt_image_data = self._generate_samples(
|
||||
iterations=self.config.sample_num_batches_per_epoch,
|
||||
batch_size=self.config.sample_batch_size,
|
||||
)
|
||||
|
||||
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
||||
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
||||
rewards, rewards_metadata = self.compute_rewards(
|
||||
prompt_image_data, is_async=self.config.async_reward_computation
|
||||
)
|
||||
|
||||
for i, image_data in enumerate(prompt_image_data):
|
||||
image_data.extend([rewards[i], rewards_metadata[i]])
|
||||
|
||||
if self.image_samples_callback is not None:
|
||||
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
||||
|
||||
rewards = torch.cat(rewards)
|
||||
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
||||
|
||||
self.accelerator.log(
|
||||
{
|
||||
"reward": rewards,
|
||||
"epoch": epoch,
|
||||
"reward_mean": rewards.mean(),
|
||||
"reward_std": rewards.std(),
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if self.config.per_prompt_stat_tracking:
|
||||
# gather the prompts across processes
|
||||
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
||||
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
||||
advantages = self.stat_tracker.update(prompts, rewards)
|
||||
else:
|
||||
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
||||
|
||||
# ungather advantages; keep the entries corresponding to the samples on this process
|
||||
samples["advantages"] = (
|
||||
torch.as_tensor(advantages)
|
||||
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
||||
.to(self.accelerator.device)
|
||||
)
|
||||
|
||||
del samples["prompt_ids"]
|
||||
|
||||
total_batch_size, num_timesteps = samples["timesteps"].shape
|
||||
|
||||
for inner_epoch in range(self.config.train_num_inner_epochs):
|
||||
# shuffle samples along batch dimension
|
||||
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
||||
samples = {k: v[perm] for k, v in samples.items()}
|
||||
|
||||
# shuffle along time dimension independently for each sample
|
||||
# still trying to understand the code below
|
||||
perms = torch.stack(
|
||||
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
||||
)
|
||||
|
||||
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
||||
samples[key] = samples[key][
|
||||
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
||||
perms,
|
||||
]
|
||||
|
||||
original_keys = samples.keys()
|
||||
original_values = samples.values()
|
||||
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
||||
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
||||
|
||||
# Transpose the list of original values
|
||||
transposed_values = zip(*reshaped_values)
|
||||
# Create new dictionaries for each row of transposed values
|
||||
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
||||
|
||||
self.sd_pipeline.unet.train()
|
||||
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
||||
# ensure optimization step at the end of the inner epoch
|
||||
if not self.accelerator.sync_gradients:
|
||||
raise ValueError(
|
||||
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
||||
)
|
||||
|
||||
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
||||
self.accelerator.save_state()
|
||||
|
||||
return global_step
|
||||
|
||||
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
||||
"""
|
||||
Calculate the loss for a batch of an unpacked sample
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
||||
timesteps (torch.Tensor):
|
||||
The timesteps sampled from the diffusion model, shape: [batch_size]
|
||||
next_latents (torch.Tensor):
|
||||
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
|
||||
width]
|
||||
log_probs (torch.Tensor):
|
||||
The log probabilities of the latents, shape: [batch_size]
|
||||
advantages (torch.Tensor):
|
||||
The advantages of the latents, shape: [batch_size]
|
||||
embeds (torch.Tensor):
|
||||
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
|
||||
train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
|
||||
"""
|
||||
with self.autocast():
|
||||
if self.config.train_cfg:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
torch.cat([latents] * 2),
|
||||
torch.cat([timesteps] * 2),
|
||||
embeds,
|
||||
).sample
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
else:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
latents,
|
||||
timesteps,
|
||||
embeds,
|
||||
).sample
|
||||
# compute the log prob of next_latents given latents under the current model
|
||||
|
||||
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
||||
noise_pred,
|
||||
timesteps,
|
||||
latents,
|
||||
eta=self.config.sample_eta,
|
||||
prev_sample=next_latents,
|
||||
)
|
||||
|
||||
log_prob = scheduler_step_output.log_probs
|
||||
|
||||
advantages = torch.clamp(
|
||||
advantages,
|
||||
-self.config.train_adv_clip_max,
|
||||
self.config.train_adv_clip_max,
|
||||
)
|
||||
|
||||
ratio = torch.exp(log_prob - log_probs)
|
||||
|
||||
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
||||
|
||||
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
||||
|
||||
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
||||
|
||||
return loss, approx_kl, clipfrac
|
||||
|
||||
def loss(
|
||||
self,
|
||||
advantages: torch.Tensor,
|
||||
clip_range: float,
|
||||
ratio: torch.Tensor,
|
||||
):
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(
|
||||
ratio,
|
||||
1.0 - clip_range,
|
||||
1.0 + clip_range,
|
||||
)
|
||||
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
||||
|
||||
def _setup_optimizer(self, trainable_layers_parameters):
|
||||
if self.config.train_use_8bit_adam:
|
||||
import bitsandbytes
|
||||
|
||||
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
return optimizer_cls(
|
||||
trainable_layers_parameters,
|
||||
lr=self.config.train_learning_rate,
|
||||
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
||||
weight_decay=self.config.train_adam_weight_decay,
|
||||
eps=self.config.train_adam_epsilon,
|
||||
)
|
||||
|
||||
def _save_model_hook(self, models, weights, output_dir):
|
||||
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def _load_model_hook(self, models, input_dir):
|
||||
self.sd_pipeline.load_checkpoint(models, input_dir)
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
def _generate_samples(self, iterations, batch_size):
|
||||
"""
|
||||
Generate samples from the model
|
||||
|
||||
Args:
|
||||
iterations (int): Number of iterations to generate samples for
|
||||
batch_size (int): Batch size to use for sampling
|
||||
|
||||
Returns:
|
||||
samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
|
||||
"""
|
||||
samples = []
|
||||
prompt_image_pairs = []
|
||||
self.sd_pipeline.unet.eval()
|
||||
|
||||
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
||||
|
||||
for _ in range(iterations):
|
||||
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
||||
|
||||
prompt_ids = self.sd_pipeline.tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
||||
|
||||
with self.autocast():
|
||||
sd_output = self.sd_pipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
images = sd_output.images
|
||||
latents = sd_output.latents
|
||||
log_probs = sd_output.log_probs
|
||||
|
||||
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
||||
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
||||
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"timesteps": timesteps,
|
||||
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
||||
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
||||
"log_probs": log_probs,
|
||||
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
||||
}
|
||||
)
|
||||
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
||||
|
||||
return samples, prompt_image_pairs
|
||||
|
||||
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
||||
"""
|
||||
Train on a batch of samples. Main training segment
|
||||
|
||||
Args:
|
||||
inner_epoch (int): The current inner epoch
|
||||
epoch (int): The current epoch
|
||||
global_step (int): The current global step
|
||||
batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step
|
||||
"""
|
||||
info = defaultdict(list)
|
||||
for _i, sample in enumerate(batched_samples):
|
||||
if self.config.train_cfg:
|
||||
# concat negative prompts to sample prompts to avoid two forward passes
|
||||
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
||||
else:
|
||||
embeds = sample["prompt_embeds"]
|
||||
|
||||
for j in range(self.num_train_timesteps):
|
||||
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
||||
loss, approx_kl, clipfrac = self.calculate_loss(
|
||||
sample["latents"][:, j],
|
||||
sample["timesteps"][:, j],
|
||||
sample["next_latents"][:, j],
|
||||
sample["log_probs"][:, j],
|
||||
sample["advantages"],
|
||||
embeds,
|
||||
)
|
||||
info["approx_kl"].append(approx_kl)
|
||||
info["clipfrac"].append(clipfrac)
|
||||
info["loss"].append(loss)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
if self.accelerator.sync_gradients:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.trainable_layers.parameters()
|
||||
if not isinstance(self.trainable_layers, list)
|
||||
else self.trainable_layers,
|
||||
self.config.train_max_grad_norm,
|
||||
)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if self.accelerator.sync_gradients:
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
||||
info = self.accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
||||
self.accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
return global_step
|
||||
|
||||
def _config_check(self) -> tuple[bool, str]:
|
||||
samples_per_epoch = (
|
||||
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
||||
)
|
||||
total_train_batch_size = (
|
||||
self.config.train_batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.config.train_gradient_accumulation_steps
|
||||
)
|
||||
|
||||
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not samples_per_epoch % total_train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
||||
)
|
||||
return True, ""
|
||||
|
||||
def train(self, epochs: Optional[int] = None):
|
||||
"""
|
||||
Train the model for a given number of epochs
|
||||
"""
|
||||
global_step = 0
|
||||
if epochs is None:
|
||||
epochs = self.config.num_epochs
|
||||
for epoch in range(self.first_epoch, epochs):
|
||||
global_step = self.step(epoch, global_step)
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
self.sd_pipeline.save_pretrained(save_directory)
|
||||
self.create_model_card()
|
||||
|
||||
# Ensure the model card is saved along with the checkpoint
|
||||
def _save_checkpoint(self, model, trial):
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name)
|
||||
super()._save_checkpoint(model, trial)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@inproceedings{black2024training,
|
||||
title = {{Training Diffusion Models with Reinforcement Learning}},
|
||||
author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
|
||||
year = 2024,
|
||||
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
||||
publisher = {OpenReview.net},
|
||||
url = {https://openreview.net/forum?id=YCWjhGrJFD},
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="DDPO",
|
||||
trainer_citation=citation,
|
||||
paper_title="Training Diffusion Models with Reinforcement Learning",
|
||||
paper_id="2305.13301",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
|
||||
"""
|
||||
|
||||
The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is heavily
|
||||
inspired by the work here: https://github.com/kvablack/ddpo-pytorch As of now only Stable Diffusion based pipelines
|
||||
are supported
|
||||
|
||||
Attributes:
|
||||
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more:
|
||||
details.
|
||||
**reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used:
|
||||
**prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
|
||||
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
||||
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
reward_function,
|
||||
prompt_function,
|
||||
sd_pipeline,
|
||||
image_samples_hook = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothDDPOConfig()
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('ddpo_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
config = config,
|
||||
reward_function = reward_function,
|
||||
prompt_function = prompt_function,
|
||||
sd_pipeline = sd_pipeline,
|
||||
image_samples_hook = image_samples_hook,**kwargs)
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,874 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothGKDConfig(GKDConfig):
|
||||
"""
|
||||
|
||||
Configuration class for [`GKDTrainer`].
|
||||
|
||||
This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
|
||||
please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
|
||||
|
||||
Args:
|
||||
temperature (`float`, *optional*, defaults to `0.9`):
|
||||
Temperature for sampling. The higher the temperature, the more random the completions.
|
||||
lmbda (`float`, *optional*, defaults to `0.5`):
|
||||
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
||||
student-generated outputs).
|
||||
beta (`float`, *optional*, defaults to `0.5`):
|
||||
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
||||
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
||||
max_new_tokens (`int`, *optional*, defaults to `128`):
|
||||
Maximum number of tokens to generate per completion.
|
||||
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
||||
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
|
||||
trained.
|
||||
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
||||
from a string.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
seq_kd (`bool`, *optional*, defaults to `False`):
|
||||
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
|
||||
teacher-generated output).
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
hub_revision = None,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = True,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
liger_kernel_config = None,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = True,
|
||||
model_init_kwargs = None,
|
||||
chat_template_path = None,
|
||||
dataset_text_field = 'text',
|
||||
dataset_kwargs = None,
|
||||
dataset_num_proc = None,
|
||||
eos_token = None,
|
||||
pad_token = None,
|
||||
max_length = 1024,
|
||||
packing = False,
|
||||
packing_strategy = 'bfd',
|
||||
padding_free = False,
|
||||
pad_to_multiple_of = None,
|
||||
eval_packing = None,
|
||||
completion_only_loss = None,
|
||||
assistant_only_loss = False,
|
||||
activation_offloading = False,
|
||||
temperature = 0.9,
|
||||
lmbda = 0.5,
|
||||
beta = 0.5,
|
||||
max_new_tokens = 128,
|
||||
teacher_model_name_or_path = None,
|
||||
teacher_model_init_kwargs = None,
|
||||
disable_dropout = True,
|
||||
seq_kd = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = min(cpu_count()*2, 2)
|
||||
if temperature <= 0:
|
||||
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
||||
elif temperature >= 10:
|
||||
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
||||
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
hub_revision = hub_revision,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
liger_kernel_config = liger_kernel_config,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
model_init_kwargs = model_init_kwargs,
|
||||
chat_template_path = chat_template_path,
|
||||
dataset_text_field = dataset_text_field,
|
||||
dataset_kwargs = dataset_kwargs,
|
||||
dataset_num_proc = dataset_num_proc,
|
||||
eos_token = eos_token,
|
||||
pad_token = pad_token,
|
||||
max_length = max_length,
|
||||
packing = packing,
|
||||
packing_strategy = packing_strategy,
|
||||
padding_free = padding_free,
|
||||
pad_to_multiple_of = pad_to_multiple_of,
|
||||
eval_packing = eval_packing,
|
||||
completion_only_loss = completion_only_loss,
|
||||
assistant_only_loss = assistant_only_loss,
|
||||
activation_offloading = activation_offloading,
|
||||
temperature = temperature,
|
||||
lmbda = lmbda,
|
||||
beta = beta,
|
||||
max_new_tokens = max_new_tokens,
|
||||
teacher_model_name_or_path = teacher_model_name_or_path,
|
||||
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
||||
disable_dropout = disable_dropout,
|
||||
seq_kd = seq_kd,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothGKDTrainer(SFTTrainer):
|
||||
_tag_names = ["trl", "gkd"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
args: Optional[GKDConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None, # type: ignore
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
formatting_func: Optional[Callable] = None,
|
||||
):
|
||||
# add remove_unused_columns=False to the dataclass args
|
||||
args.remove_unused_columns = False
|
||||
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
|
||||
|
||||
super().__init__(
|
||||
model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
peft_config=peft_config,
|
||||
formatting_func=formatting_func,
|
||||
)
|
||||
|
||||
if args.teacher_model_init_kwargs is None:
|
||||
teacher_model_init_kwargs = {}
|
||||
elif not isinstance(teacher_model, str):
|
||||
raise ValueError(
|
||||
"You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
|
||||
)
|
||||
else:
|
||||
teacher_model_init_kwargs = args.teacher_model_init_kwargs
|
||||
teacher_model_init_kwargs["torch_dtype"] = (
|
||||
teacher_model_init_kwargs["torch_dtype"]
|
||||
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
|
||||
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
|
||||
)
|
||||
|
||||
if isinstance(teacher_model, str):
|
||||
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
||||
|
||||
# Disable dropout in the model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(self.model)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
|
||||
else:
|
||||
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
||||
|
||||
self.lmbda = args.lmbda
|
||||
self.beta = args.beta
|
||||
self.temperature = args.temperature
|
||||
self.seq_kd = args.seq_kd
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
use_cache=False if args.gradient_checkpointing else True,
|
||||
pad_token_id=self.processing_class.pad_token_id,
|
||||
)
|
||||
# Set custom EOS tokens if they are specified by the model's generation
|
||||
# config. This is important for models with the Llama 3 chat template,
|
||||
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
||||
# turns or messages.
|
||||
if (
|
||||
hasattr(self.model.generation_config, "eos_token_id")
|
||||
and self.model.generation_config.eos_token_id is not None
|
||||
):
|
||||
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
||||
|
||||
@staticmethod
|
||||
def generalized_jsd_loss(
|
||||
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
||||
):
|
||||
"""
|
||||
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
|
||||
of https://huggingface.co/papers/2306.13649 for the definition.
|
||||
|
||||
Args:
|
||||
student_logits:
|
||||
Tensor of shape (batch_size, sequence_length, vocab_size)
|
||||
teacher_logits:
|
||||
Tensor of shape (batch_size, sequence_length, vocab_size)
|
||||
labels:
|
||||
Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing
|
||||
loss
|
||||
beta:
|
||||
Interpolation coefficient between 0 and 1 (default: 0.5)
|
||||
temperature:
|
||||
Softmax temperature (default: 1.0)
|
||||
reduction:
|
||||
Specifies the reduction to apply to the output (default: 'batchmean')
|
||||
|
||||
Returns:
|
||||
loss: Scalar tensor with the generalized JSD loss
|
||||
"""
|
||||
|
||||
# Apply temperature scaling
|
||||
student_logits = student_logits / temperature
|
||||
teacher_logits = teacher_logits / temperature
|
||||
|
||||
# Compute log probabilities for student and probabilities for teacher
|
||||
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
||||
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
||||
|
||||
if beta == 0:
|
||||
jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
||||
elif beta == 1:
|
||||
jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
||||
else:
|
||||
# Compute the log of the mixture distribution
|
||||
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
||||
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
||||
mixture_log_probs = torch.logsumexp(
|
||||
torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Compute KL divergences using F.kl_div
|
||||
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
|
||||
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
||||
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
||||
|
||||
# Compute the Generalized Jensen-Shannon Divergence
|
||||
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
||||
|
||||
# Masking
|
||||
if labels is not None:
|
||||
mask = labels != -100
|
||||
jsd = jsd[mask]
|
||||
|
||||
# Apply reduction
|
||||
if reduction == "batchmean":
|
||||
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
|
||||
elif reduction == "sum":
|
||||
return jsd.sum()
|
||||
elif reduction == "mean":
|
||||
return jsd.mean()
|
||||
else:
|
||||
return jsd
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
# compute student output
|
||||
outputs_student = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
# compute teacher output in eval mode
|
||||
self.teacher_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs_teacher = self.teacher_model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
# slice the logits for the generated tokens using the inputs["prompts"] lengths
|
||||
prompt_lengths = inputs["prompts"].shape[1]
|
||||
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
|
||||
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
|
||||
shifted_labels = inputs["labels"][:, prompt_lengths:]
|
||||
|
||||
# compute loss
|
||||
loss = self.generalized_jsd_loss(
|
||||
student_logits=shifted_student_logits,
|
||||
teacher_logits=shifted_teacher_logits,
|
||||
labels=shifted_labels,
|
||||
beta=self.beta,
|
||||
)
|
||||
|
||||
# empty cache
|
||||
empty_cache()
|
||||
|
||||
# Return loss
|
||||
return (loss, outputs_student) if return_outputs else loss
|
||||
|
||||
@staticmethod
|
||||
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
||||
# Generate output with respect to the prompt-only
|
||||
generated_outputs = model.generate(
|
||||
input_ids=inputs["prompts"],
|
||||
attention_mask=inputs.get("prompt_attention_mask", None),
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Get the generated token IDs
|
||||
generated_tokens = generated_outputs.sequences
|
||||
# Calculate new attention mask
|
||||
new_attention_mask = torch.ones_like(generated_tokens)
|
||||
new_labels = generated_tokens.clone()
|
||||
|
||||
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
||||
if pad_token_id is not None:
|
||||
new_labels[new_labels == pad_token_id] = -100
|
||||
new_attention_mask[generated_tokens == pad_token_id] = 0
|
||||
|
||||
return generated_tokens, new_attention_mask, new_labels
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
||||
|
||||
This method implements the on-policy learning approach described in the GKD paper. With probability
|
||||
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
|
||||
the original inputs.
|
||||
"""
|
||||
if self.seq_kd:
|
||||
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
if random.random() <= self.lmbda:
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
return loss
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@inproceedings{agarwal2024on-policy,
|
||||
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
|
||||
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
|
||||
year = 2024,
|
||||
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
||||
publisher = {OpenReview.net},
|
||||
url = {https://openreview.net/forum?id=3zKtaqxLhW},
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="GKD",
|
||||
trainer_citation=citation,
|
||||
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
|
||||
paper_id="2306.13649",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothGKDTrainer(_UnslothGKDTrainer):
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
teacher_model = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
peft_config = None,
|
||||
formatting_func = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothGKDConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
if type(use_bf16) is not bool: use_bf16 = False
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
if type(use_fp16) is not bool: use_fp16 = False
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('gkd_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
teacher_model = teacher_model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
peft_config = peft_config,
|
||||
formatting_func = formatting_func,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
if hasattr(self, 'accelerator'):
|
||||
scaler = self.accelerator.scaler
|
||||
current_model = model
|
||||
while hasattr(current_model, 'model'):
|
||||
current_model.accelerator_scaler = scaler
|
||||
current_model = current_model.model
|
||||
current_model.accelerator_scaler = scaler
|
||||
pass
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,918 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.iterative_sft_trainer import (AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataLoader, Dataset, EvalLoopOutput, FeatureExtractionMixin, IterativeSFTConfig, IterativeSFTTrainer, Optional, PPODecorators, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainingArguments, Union, generate_model_card, get_comet_experiment_url, is_peft_available, is_wandb_available, os, torch, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothIterativeSFTConfig(IterativeSFTConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`IterativeSFTTrainer`].
|
||||
|
||||
This class includes only the parameters that are specific to Iterative SFT training. For a full list of training
|
||||
arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
|
||||
class may differ from those in [`~transformers.TrainingArguments`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
> Parameters that control the model
|
||||
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`IterativeSFTTrainer`] is provided as a string.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
max_length (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
|
||||
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
||||
The truncation mode to use, either `"keep_end"` or `"keep_start"`.
|
||||
optimize_device_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether to optimize accelerator cache for slightly more memory-efficient training.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
hub_revision = None,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = True,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
liger_kernel_config = None,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = True,
|
||||
model_init_kwargs = None,
|
||||
max_length = None,
|
||||
truncation_mode = 'keep_end',
|
||||
optimize_device_cache = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
hub_revision = hub_revision,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
liger_kernel_config = liger_kernel_config,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
model_init_kwargs = model_init_kwargs,
|
||||
max_length = max_length,
|
||||
truncation_mode = truncation_mode,
|
||||
optimize_device_cache = optimize_device_cache,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothIterativeSFTTrainer(Trainer):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "iterative-sft"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
):
|
||||
# Args
|
||||
model_id = model if isinstance(model, str) else model.config._name_or_path
|
||||
if args is None:
|
||||
model_name = model_id.split("/")[-1]
|
||||
args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
|
||||
elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
|
||||
dict_args = args.to_dict()
|
||||
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
||||
dict_args.pop("push_to_hub_token")
|
||||
args = IterativeSFTConfig(**dict_args)
|
||||
|
||||
# Handle the tokenizer
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Model
|
||||
if args.model_init_kwargs is not None and not isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
|
||||
"The `model_init_kwargs` will be ignored."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model = self._create_model_from_path(model, args)
|
||||
|
||||
# PEFT configuration and model wrapping
|
||||
if is_peft_available() and isinstance(model, PeftModel):
|
||||
self.is_peft_model = True
|
||||
else:
|
||||
self.is_peft_model = False
|
||||
|
||||
self.processing_class = processing_class
|
||||
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
|
||||
|
||||
if data_collator is None:
|
||||
if self.is_encoder_decoder:
|
||||
self.data_collator = DataCollatorForSeq2Seq(
|
||||
processing_class, label_pad_token_id=-100, pad_to_multiple_of=8
|
||||
)
|
||||
else:
|
||||
self.data_collator = DataCollatorForLanguageModeling(self.processing_class, mlm=False)
|
||||
else:
|
||||
self.data_collator = data_collator
|
||||
|
||||
self.max_length = args.max_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.optimize_device_cache = args.optimize_device_cache
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=self.data_collator,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
compute_metrics=compute_metrics,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
self.create_optimizer_and_scheduler(self.args.max_steps)
|
||||
|
||||
# prepare model, optimizer and lr_scheduler
|
||||
self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
|
||||
self.processing_class.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"
|
||||
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
PPODecorators.optimize_device_cache = self.optimize_device_cache
|
||||
|
||||
def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
|
||||
"""Creates a model from a path or model identifier."""
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
|
||||
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
|
||||
if attention_mask is None:
|
||||
attention_mask = [torch.ones_like(ids) for ids in input_ids]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
input_data = self.data_collator(
|
||||
[
|
||||
{"input_ids": ids, "attention_mask": att, "labels": lab}
|
||||
for ids, att, lab in zip(input_ids, attention_mask, labels)
|
||||
]
|
||||
).to(self.model.device)
|
||||
|
||||
input_data.pop("decoder_input_ids", None) # This is directly computed inside the model
|
||||
|
||||
input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100
|
||||
|
||||
else:
|
||||
input_data = self.data_collator(
|
||||
[{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
|
||||
).to(self.model.device)
|
||||
|
||||
# truncate in case the user has provided input_ids, attention_mask and labels
|
||||
if self.max_length is not None:
|
||||
if self.truncation_mode == "keep_start":
|
||||
input_data = {k: v[: self.max_length] for k, v in input_data.items()}
|
||||
elif self.truncation_mode == "keep_end":
|
||||
input_data = {k: v[-self.max_length :] for k, v in input_data.items()}
|
||||
else:
|
||||
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
||||
|
||||
return input_data
|
||||
|
||||
@staticmethod
|
||||
def _step_safety_checker(
|
||||
input_ids: list[torch.LongTensor],
|
||||
attention_mask: list[torch.LongTensor],
|
||||
labels: list[torch.LongTensor],
|
||||
texts: list[str],
|
||||
texts_labels: list[str],
|
||||
):
|
||||
"""
|
||||
Check if the input data is valid for training.
|
||||
|
||||
Args:
|
||||
input_ids (list[`torch.LongTensor`]):
|
||||
List of tensors containing the input_ids
|
||||
attention_mask (list[`torch.LongTensor`]):
|
||||
List of tensors containing the attention_mask
|
||||
labels (list[`torch.FloatTensor`]):
|
||||
List of tensors containing the labels
|
||||
texts (list[`str`]):
|
||||
List of string containing the text input.
|
||||
texts_labels (list[`str`]):
|
||||
List of string containing the text labels.
|
||||
|
||||
Returns:
|
||||
`tuple`: The input data.
|
||||
"""
|
||||
if texts is None:
|
||||
if attention_mask is None:
|
||||
for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
|
||||
if not isinstance(tensor_list, list):
|
||||
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
|
||||
if not isinstance(tensor_list[0], torch.Tensor):
|
||||
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
|
||||
else:
|
||||
for name, tensor_list in zip(
|
||||
["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]
|
||||
):
|
||||
if not isinstance(tensor_list, list):
|
||||
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
|
||||
if not isinstance(tensor_list[0], torch.Tensor):
|
||||
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
|
||||
else:
|
||||
if not isinstance(texts, list):
|
||||
raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
|
||||
if not isinstance(texts[0], str):
|
||||
raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
|
||||
if texts_labels is not None:
|
||||
if not isinstance(texts_labels, list):
|
||||
raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
|
||||
if not isinstance(texts_labels[0], str):
|
||||
raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")
|
||||
|
||||
return input_ids, attention_mask, labels, texts, texts_labels
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
def step(
|
||||
self,
|
||||
input_ids: Optional[list[torch.LongTensor]] = None,
|
||||
attention_mask: Optional[list[torch.LongTensor]] = None,
|
||||
labels: Optional[list[torch.LongTensor]] = None,
|
||||
texts: Optional[list[str]] = None,
|
||||
texts_labels: Optional[list[str]] = None,
|
||||
):
|
||||
"""
|
||||
Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and
|
||||
text_labels.
|
||||
|
||||
Args:
|
||||
input_ids (list[`torch.LongTensor`]):
|
||||
List of tensors containing the input_ids (if not provided, text will be used)
|
||||
attention_mask (list[`torch.LongTensor`], , *optional*):
|
||||
List of tensors containing the attention_mask
|
||||
labels (list[`torch.FloatTensor`], *optional*):
|
||||
List of tensors containing the labels (if set to None, will default to input_ids)
|
||||
texts (list[`str`], *optional*):
|
||||
List of strings containing the text input (if not provided, input_ids will directly be used)
|
||||
texts_labels (list[`str`], *optional*):
|
||||
List of strings containing the text labels (if set to None, will default to text)
|
||||
|
||||
Returns:
|
||||
`dict[str, Any]`: A summary of the training statistics
|
||||
"""
|
||||
self.model.train()
|
||||
|
||||
if self.state.global_step == 0:
|
||||
self.tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
if input_ids is None and texts is None:
|
||||
raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
|
||||
elif input_ids is not None and texts is not None:
|
||||
warnings.warn(
|
||||
"Both `input_ids` and `texts` argument are provided. `input_ids` will be ignored. "
|
||||
"Please provide only one of the two.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if labels is None and texts_labels is None and self.is_encoder_decoder:
|
||||
raise ValueError(
|
||||
"No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
|
||||
)
|
||||
|
||||
# Convert Column to list if not already
|
||||
input_ids = input_ids[:] if input_ids is not None else None
|
||||
attention_mask = attention_mask[:] if attention_mask is not None else None
|
||||
labels = labels[:] if labels is not None else None
|
||||
texts = texts[:] if texts is not None else None
|
||||
texts_labels = texts_labels[:] if texts_labels is not None else None
|
||||
|
||||
input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
|
||||
input_ids, attention_mask, labels, texts, texts_labels
|
||||
)
|
||||
|
||||
if texts is not None:
|
||||
model_inputs = self.processing_class(
|
||||
texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
|
||||
if texts_labels is not None:
|
||||
labels = self.processing_class(
|
||||
texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)["input_ids"]
|
||||
|
||||
if labels is None:
|
||||
labels = input_ids
|
||||
|
||||
model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
|
||||
|
||||
model_inputs_names = list(model_inputs.keys())
|
||||
|
||||
batch_dict = {}
|
||||
batch_dict.update(model_inputs)
|
||||
|
||||
def collator(data):
|
||||
return_dict = dict()
|
||||
for key in data[0]:
|
||||
if key in ["input_ids", "attention_mask", "labels"]:
|
||||
return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
|
||||
return return_dict
|
||||
|
||||
batch_data = Dataset.from_dict(batch_dict)
|
||||
batch_data.set_format("torch")
|
||||
|
||||
step_dataloader = DataLoader(
|
||||
batch_data,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
)
|
||||
|
||||
for _, batch in enumerate(step_dataloader):
|
||||
with self.accelerator.accumulate(self.model):
|
||||
model_inputs = {k: batch[k] for k in model_inputs_names}
|
||||
loss = self.compute_loss(self.model, model_inputs)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
|
||||
tr_loss_step = loss.detach()
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.args.max_grad_norm,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
self.state.global_step += 1
|
||||
|
||||
# update stats etc
|
||||
self.tr_loss += tr_loss_step
|
||||
|
||||
self._maybe_log_save_evaluate()
|
||||
|
||||
def _maybe_log_save_evaluate(self):
|
||||
# check if eval is required
|
||||
if self.args.eval_steps is not None:
|
||||
if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
|
||||
self.evaluate(self.eval_dataset)
|
||||
|
||||
# check if logging is required
|
||||
if self.args.logging_steps is not None:
|
||||
if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
|
||||
logs: dict[str, float] = {}
|
||||
|
||||
tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()
|
||||
|
||||
# reset tr_loss to zero
|
||||
self.tr_loss -= self.tr_loss
|
||||
|
||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||
logs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
self.log(logs)
|
||||
|
||||
# Ensure the model card is saved along with the checkpoint
|
||||
def _save_checkpoint(self, model, trial):
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name)
|
||||
super()._save_checkpoint(model, trial)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="Iterative SFT",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothIterativeSFTTrainer(_UnslothIterativeSFTTrainer):
|
||||
"""
|
||||
|
||||
The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization.
|
||||
|
||||
Args:
|
||||
model (`Union[str, PreTrainedModel]`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
||||
path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
||||
using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
|
||||
`args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
data_collator (`DataCollator`, *optional*):
|
||||
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
|
||||
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
|
||||
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
|
||||
tokenizer.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
||||
with [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
|
||||
metric values.
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
compute_metrics = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothIterativeSFTConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
if type(use_bf16) is not bool: use_bf16 = False
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
if type(use_fp16) is not bool: use_fp16 = False
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('iterative_sft_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
compute_metrics = compute_metrics,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
if hasattr(self, 'accelerator'):
|
||||
scaler = self.accelerator.scaler
|
||||
current_model = model
|
||||
while hasattr(current_model, 'model'):
|
||||
current_model.accelerator_scaler = scaler
|
||||
current_model = current_model.model
|
||||
current_model.accelerator_scaler = scaler
|
||||
pass
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,842 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothPRMConfig(PRMConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`PRMTrainer`].
|
||||
|
||||
This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
|
||||
please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
|
||||
differ from those in [`~transformers.TrainingArguments`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
||||
Maximum length of the sequences (prompt + completion) used for truncation.
|
||||
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
||||
Maximum length of the prompt used for truncation.
|
||||
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
step_separator (`str`, *optional*, defaults to `"\n"`):
|
||||
Separator used to separate each step of the reasoning process.
|
||||
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to train only on the last step.
|
||||
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the dataset.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
hub_revision = None,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = True,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
liger_kernel_config = None,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = True,
|
||||
max_length = 1024,
|
||||
max_prompt_length = 512,
|
||||
max_completion_length = None,
|
||||
disable_dropout = True,
|
||||
step_separator = '\
|
||||
',
|
||||
train_on_last_step_only = False,
|
||||
dataset_num_proc = None,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = min(cpu_count()*2, 2)
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
hub_revision = hub_revision,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
liger_kernel_config = liger_kernel_config,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
max_length = max_length,
|
||||
max_prompt_length = max_prompt_length,
|
||||
max_completion_length = max_completion_length,
|
||||
disable_dropout = disable_dropout,
|
||||
step_separator = step_separator,
|
||||
train_on_last_step_only = train_on_last_step_only,
|
||||
dataset_num_proc = dataset_num_proc,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothPRMTrainer(Trainer):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "prm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
args: Optional[PRMConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
):
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
if not isinstance(model, PeftModel):
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
||||
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
warnings.warn(
|
||||
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
||||
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
||||
)
|
||||
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
|
||||
model = model
|
||||
|
||||
# Disable dropout in the model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
|
||||
if compute_metrics is None:
|
||||
compute_metrics = compute_accuracy
|
||||
|
||||
if data_collator is None:
|
||||
if processing_class is None:
|
||||
raise ValueError(
|
||||
"A processing_class must be specified when using the default DataCollatorForTokenClassification"
|
||||
)
|
||||
data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
|
||||
|
||||
if "input_ids" not in train_dataset.column_names:
|
||||
with PartialState().main_process_first():
|
||||
fn_kwargs = {
|
||||
"tokenizer": processing_class,
|
||||
"step_separator": args.step_separator,
|
||||
"max_length": args.max_length,
|
||||
"max_prompt_length": args.max_prompt_length,
|
||||
"max_completion_length": args.max_completion_length,
|
||||
"train_on_last_step_only": args.train_on_last_step_only,
|
||||
}
|
||||
train_fn_kwargs = {**fn_kwargs, "is_eval": False}
|
||||
train_dataset = train_dataset.map(
|
||||
self.tokenize_row,
|
||||
fn_kwargs=train_fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=train_dataset.features,
|
||||
desc="Tokenizing train dataset",
|
||||
features=features.Features( # needed to avoid map to cast labels to bool
|
||||
{
|
||||
"labels": features.Sequence(features.Value("int64")),
|
||||
"input_ids": features.Sequence(features.Value("int64")),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(
|
||||
self.tokenize_row,
|
||||
fn_kwargs=eval_fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=eval_dataset.features,
|
||||
desc="Tokenizing eval dataset",
|
||||
features=features.Features( # needed to avoid map to cast labels to bool
|
||||
{
|
||||
"labels": features.Sequence(features.Value("int64")),
|
||||
"input_ids": features.Sequence(features.Value("int64")),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
tokenizer,
|
||||
step_separator,
|
||||
max_length,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
train_on_last_step_only,
|
||||
is_eval,
|
||||
):
|
||||
r"""
|
||||
Tokenize a row of the dataset.
|
||||
|
||||
Args:
|
||||
features (`dict[str, str]`):
|
||||
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
||||
tokenizer (`PreTrainedTokenizerBase`):
|
||||
Tokenizer used to process the data.
|
||||
step_separator (`str`):
|
||||
Separator between steps in the completion.
|
||||
max_length (`int` or `None`):
|
||||
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
||||
max_prompt_length (`int` or `None`):
|
||||
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
||||
max_completion_length (`int` or `None`):
|
||||
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
||||
train_on_last_step_only (`bool`):
|
||||
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
||||
token of the completion.
|
||||
is_eval (`bool`):
|
||||
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if
|
||||
`train_on_last_step_only` is set to `True`.
|
||||
|
||||
Returns:
|
||||
`dict[str, list[int]]`:
|
||||
Tokenized sequences with the keys `"input_ids"`, and `"labels".
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
>>> features = {
|
||||
... "prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."],
|
||||
... "labels": [True, False],
|
||||
... }
|
||||
>>> PRMTrainer.tokenize_row(
|
||||
... features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False
|
||||
... )
|
||||
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
||||
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
||||
```
|
||||
"""
|
||||
# Tokenize the prompt and completions
|
||||
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
||||
completions_ids = [
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
||||
]
|
||||
if train_on_last_step_only and not is_eval:
|
||||
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
||||
else:
|
||||
labels = [int(label) for label in features["labels"]]
|
||||
|
||||
# Get the ID of the separator token and add it to the completions
|
||||
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
||||
completions_ids = [completion + separator_ids for completion in completions_ids]
|
||||
|
||||
# Create the label
|
||||
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
||||
|
||||
# Join the completions and labels steps
|
||||
completion_ids = list(chain(*completions_ids))
|
||||
labels = list(chain(*labels))
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
||||
|
||||
# Truncate prompt and completion sequences
|
||||
if max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[-max_prompt_length:]
|
||||
if max_completion_length is not None:
|
||||
completion_ids = completion_ids[:max_completion_length]
|
||||
labels = labels[:max_completion_length]
|
||||
|
||||
input_ids = prompt_ids + completion_ids
|
||||
labels = [-100] * len(prompt_ids) + labels
|
||||
|
||||
if max_length is not None:
|
||||
input_ids = input_ids[:max_length]
|
||||
labels = labels[:max_length]
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
# Ensure the model card is saved along with the checkpoint
|
||||
def _save_checkpoint(self, model, trial):
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name)
|
||||
super()._save_checkpoint(model, trial)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@article{uesato2022solving,
|
||||
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
|
||||
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
|
||||
year = 2022,
|
||||
journal = {arXiv preprint arXiv:2211.14275}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
trainer_name="PRM",
|
||||
trainer_citation=citation,
|
||||
paper_title="Solving math word problems with process-and outcome-based feedback",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
||||
"""
|
||||
|
||||
Initialize PRMTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForTokenClassification`.
|
||||
args (`PRMConfig`):
|
||||
The arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator
|
||||
(`DataCollatorForTokenClassification`) will be used which will pad the sequences to the maximum length of
|
||||
the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be
|
||||
used.
|
||||
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
||||
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
|
||||
will be used.
|
||||
callbacks (`list[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
peft_config (`dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
|
||||
a PEFT model.
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
model_init = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
peft_config = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothPRMConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
if type(use_bf16) is not bool: use_bf16 = False
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
if type(use_fp16) is not bool: use_fp16 = False
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('prm_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
model_init = model_init,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
peft_config = peft_config,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
if hasattr(self, 'accelerator'):
|
||||
scaler = self.accelerator.scaler
|
||||
current_model = model
|
||||
while hasattr(current_model, 'model'):
|
||||
current_model.accelerator_scaler = scaler
|
||||
current_model = current_model.model
|
||||
current_model.accelerator_scaler = scaler
|
||||
pass
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,860 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, Path, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_rich_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, warnings, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def chunked_selective_log_softmax(logits, index):
|
||||
# Split into 4 chunks only
|
||||
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
||||
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
||||
all_per_token_logps = []
|
||||
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
||||
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
||||
chunk_logits = chunk_logits.to(torch.float32)
|
||||
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
||||
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
all_per_token_logps.append(per_token_logps)
|
||||
pass
|
||||
all_per_token_logps = torch.concat(all_per_token_logps)
|
||||
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
||||
return all_per_token_logps
|
||||
@dataclass
|
||||
class UnslothRewardConfig(RewardConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`RewardTrainer`].
|
||||
|
||||
This class includes only the parameters that are specific to Reward training. For a full list of training
|
||||
arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this
|
||||
class may differ from those in [`~transformers.TrainingArguments`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
||||
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
||||
limit. This argument is required if you want to use the default data collator.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the dataset.
|
||||
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
|
||||
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
||||
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
||||
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the
|
||||
dataset is pretokenized.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = False,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
hub_revision = None,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = True,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
liger_kernel_config = None,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = True,
|
||||
max_length = 1024,
|
||||
disable_dropout = True,
|
||||
dataset_num_proc = None,
|
||||
center_rewards_coefficient = None,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = min(cpu_count()*2, 2)
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
hub_revision = hub_revision,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
liger_kernel_config = liger_kernel_config,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
max_length = max_length,
|
||||
disable_dropout = disable_dropout,
|
||||
dataset_num_proc = dataset_num_proc,
|
||||
center_rewards_coefficient = center_rewards_coefficient,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothRewardTrainer(Trainer):
|
||||
_tag_names = ["trl", "reward-trainer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
args: Optional[RewardConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize RewardTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
args (`RewardConfig`):
|
||||
The arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator
|
||||
(`RewardDataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of
|
||||
the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the
|
||||
inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted
|
||||
training or reuse the fine-tuned model.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be
|
||||
used.
|
||||
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
||||
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`)
|
||||
will be used.
|
||||
callbacks (`list[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
peft_config (`dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped
|
||||
in a PEFT model.
|
||||
"""
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
if not isinstance(model, PeftModel):
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
||||
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
warnings.warn(
|
||||
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
||||
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
|
||||
UserWarning,
|
||||
)
|
||||
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
|
||||
model = model
|
||||
|
||||
# Disable dropout in the model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
|
||||
if compute_metrics is None:
|
||||
compute_metrics = compute_accuracy
|
||||
|
||||
if data_collator is None:
|
||||
if processing_class is None:
|
||||
raise ValueError(
|
||||
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
|
||||
)
|
||||
|
||||
max_length = args.max_length
|
||||
|
||||
data_collator = RewardDataCollatorWithPadding(processing_class)
|
||||
|
||||
if args.remove_unused_columns:
|
||||
try: # for bc before https://github.com/huggingface/transformers/pull/25435
|
||||
args.remove_unused_columns = False
|
||||
except FrozenInstanceError:
|
||||
args = replace(args, remove_unused_columns=False)
|
||||
# warn users
|
||||
warnings.warn(
|
||||
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
|
||||
" we have set it for you, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.use_reward_data_collator = True
|
||||
else:
|
||||
self.use_reward_data_collator = False
|
||||
|
||||
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
|
||||
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
||||
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
||||
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
||||
# issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
if "input_ids_chosen" not in train_dataset.column_names:
|
||||
with PartialState().main_process_first():
|
||||
fn_kwargs = {"tokenizer": processing_class}
|
||||
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
|
||||
train_dataset = train_dataset.map(
|
||||
_tokenize,
|
||||
batched=True,
|
||||
fn_kwargs=fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
)
|
||||
# This filter is important because otherwise you get samples that exceed the model's context length and
|
||||
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
||||
# user might get surprised if N samples are missing from training.
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
|
||||
num_proc=args.dataset_num_proc,
|
||||
)
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(
|
||||
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
|
||||
)
|
||||
eval_dataset = eval_dataset.map(
|
||||
_tokenize,
|
||||
fn_kwargs=fn_kwargs,
|
||||
batched=True,
|
||||
num_proc=args.dataset_num_proc,
|
||||
)
|
||||
# This filter is important because otherwise you get samples that exceed the model's context length and
|
||||
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
||||
# user might get surprised if N samples are missing from training.
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= max_length
|
||||
and len(x["input_ids_rejected"]) <= max_length,
|
||||
num_proc=args.dataset_num_proc,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
||||
rewards_chosen = model(
|
||||
input_ids=inputs["input_ids_chosen"],
|
||||
attention_mask=inputs["attention_mask_chosen"],
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
rewards_rejected = model(
|
||||
input_ids=inputs["input_ids_rejected"],
|
||||
attention_mask=inputs["attention_mask_rejected"],
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
# calculate loss, optionally modulate with margin
|
||||
if "margin" in inputs:
|
||||
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
|
||||
else:
|
||||
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
||||
|
||||
if self.args.center_rewards_coefficient is not None:
|
||||
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
|
||||
|
||||
if return_outputs:
|
||||
return loss, {
|
||||
"rewards_chosen": rewards_chosen,
|
||||
"rewards_rejected": rewards_rejected,
|
||||
}
|
||||
return loss
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
if ignore_keys is None:
|
||||
if hasattr(self.model, "config"):
|
||||
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
with torch.no_grad():
|
||||
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
loss = loss.detach()
|
||||
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
|
||||
logits = nested_detach(logits)
|
||||
# Stack accepted against rejected, mean over logits
|
||||
# and softmax to get preferences between accepted and rejected to sum to 1
|
||||
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
|
||||
|
||||
labels = torch.zeros(logits.shape[0])
|
||||
labels = self._prepare_inputs(labels)
|
||||
|
||||
return loss, logits, labels
|
||||
|
||||
def evaluate(self, *args, **kwargs):
|
||||
num_print_samples = kwargs.pop("num_print_samples", 4)
|
||||
self.visualize_samples(num_print_samples)
|
||||
return super().evaluate(*args, **kwargs)
|
||||
|
||||
def visualize_samples(self, num_print_samples: int):
|
||||
"""
|
||||
Visualize the reward model logits prediction
|
||||
|
||||
Args:
|
||||
num_print_samples (`int`, defaults to `4`):
|
||||
The number of samples to print. Set to `-1` to print all samples.
|
||||
"""
|
||||
eval_dataloader = self.get_eval_dataloader()
|
||||
table = defaultdict(list)
|
||||
for _, inputs in enumerate(eval_dataloader):
|
||||
_, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
|
||||
chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
|
||||
rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
|
||||
table["chosen_text"].extend(gather_object(chosen_text))
|
||||
table["rejected_text"].extend(gather_object(rejected_text))
|
||||
table["logits"].extend(
|
||||
gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
|
||||
)
|
||||
if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
|
||||
break
|
||||
df = pd.DataFrame(table)
|
||||
if self.accelerator.process_index == 0:
|
||||
if is_rich_available():
|
||||
print_rich_table(df[:num_print_samples])
|
||||
if "wandb" in self.args.report_to:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
wandb.log({"completions": wandb.Table(dataframe=df)})
|
||||
|
||||
if "comet_ml" in self.args.report_to:
|
||||
log_table_to_comet_experiment(
|
||||
name="completions.csv",
|
||||
table=df,
|
||||
)
|
||||
|
||||
# Ensure the model card is saved along with the checkpoint
|
||||
def _save_checkpoint(self, model, trial):
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name)
|
||||
super()._save_checkpoint(model, trial)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
# normalize `tags` to a mutable set
|
||||
if tags is None:
|
||||
tags = set()
|
||||
elif isinstance(tags, str):
|
||||
tags = {tags}
|
||||
else:
|
||||
tags = set(tags)
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.add("unsloth")
|
||||
|
||||
tags.update(self._tag_names)
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="Reward",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothRewardTrainer(_UnslothRewardTrainer):
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
model_init = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
peft_config = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothRewardConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
if type(use_bf16) is not bool: use_bf16 = False
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
if type(use_fp16) is not bool: use_fp16 = False
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('reward_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
model_init = model_init,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
peft_config = peft_config,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
if hasattr(self, 'accelerator'):
|
||||
scaler = self.accelerator.scaler
|
||||
current_model = model
|
||||
while hasattr(current_model, 'model'):
|
||||
current_model.accelerator_scaler = scaler
|
||||
current_model = current_model.model
|
||||
current_model.accelerator_scaler = scaler
|
||||
pass
|
||||
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user