instruct model setup
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
2025.8.4
|
||||
2025.8.5
|
||||
4.55.1
|
||||
2025.8.9
|
||||
2025.8.10
|
||||
4.55.4
|
||||
0.21.0
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
||||
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, dataclass, dataclasses, defaultdict, generate_model_card, get_act_offloading_ctx_manager, get_comet_experiment_url, get_peft_model, is_conversational, is_peft_available, is_wandb_available, nn, os, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, version, warnings, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pad, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, peft, torch, os)
|
||||
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, Path, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Union, clone_chat_template, contextlib, dataclass, dataclasses, defaultdict, generate_model_card, get_act_offloading_ctx_manager, get_comet_experiment_url, get_peft_model, is_conversational, is_peft_available, is_wandb_available, nn, os, pack_dataset, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, version, warnings, Callable, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_dataset, pad, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, peft, torch, os)
|
||||
|
||||
|
||||
import os
|
||||
@@ -132,6 +132,10 @@ class UnslothSFTConfig(SFTConfig):
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
max_seq_length : Optional[int] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'Maximum sequence length to truncate to.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
@@ -280,6 +284,7 @@ class UnslothSFTConfig(SFTConfig):
|
||||
activation_offloading = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
max_seq_length = None,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
@@ -289,7 +294,13 @@ class UnslothSFTConfig(SFTConfig):
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = min(cpu_count()*2, 2)
|
||||
dataset_num_proc = max(cpu_count()+4, 2)
|
||||
if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':
|
||||
from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION
|
||||
if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:
|
||||
from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE
|
||||
pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE
|
||||
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
@@ -438,6 +449,7 @@ class UnslothSFTConfig(SFTConfig):
|
||||
activation_offloading = activation_offloading,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
self.max_seq_length = max_seq_length
|
||||
pass
|
||||
|
||||
class _UnslothSFTTrainer(Trainer):
|
||||
@@ -868,7 +880,11 @@ class _UnslothSFTTrainer(Trainer):
|
||||
pass
|
||||
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
|
||||
dataset_num_proc = getattr(args, "dataset_num_proc", None)
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = max(cpu_count()+4, 2)
|
||||
map_kwargs["num_proc"] = dataset_num_proc
|
||||
else:
|
||||
map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
|
||||
|
||||
@@ -882,18 +898,22 @@ class _UnslothSFTTrainer(Trainer):
|
||||
pass
|
||||
pass
|
||||
if packing:
|
||||
print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
|
||||
return dataset
|
||||
# Try using new packing which works in TRL
|
||||
try:
|
||||
pack_dataset
|
||||
except:
|
||||
print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
|
||||
return dataset
|
||||
|
||||
if max_seq_length == 0:
|
||||
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
|
||||
|
||||
if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
|
||||
dataset = dataset.select_columns(used_column_names).map(
|
||||
pack_examples,
|
||||
batched = True,
|
||||
fn_kwargs = {"seq_length": max_seq_length,},
|
||||
**map_kwargs,
|
||||
dataset = pack_dataset(
|
||||
dataset.select_columns(used_column_names),
|
||||
max_seq_length,
|
||||
getattr(args, "packing_strategy", "bfd"),
|
||||
map_kwargs,
|
||||
)
|
||||
pass
|
||||
return dataset
|
||||
@@ -1101,7 +1121,7 @@ class UnslothSFTTrainer(_UnslothSFTTrainer):
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
dtype = getattr(model.config, 'dtype', None) or getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
@@ -1166,9 +1186,7 @@ class UnslothSFTTrainer(_UnslothSFTTrainer):
|
||||
max_length = args.max_length
|
||||
else:
|
||||
model_max_length = getattr(model, 'max_seq_length', None)
|
||||
# print(model_max_length, 'mml1')
|
||||
if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
|
||||
# print(model_max_length, 'mml2')
|
||||
if model_max_length is not None:
|
||||
args.max_length = model_max_length
|
||||
max_length = args.max_length
|
||||
@@ -1189,9 +1207,17 @@ class UnslothSFTTrainer(_UnslothSFTTrainer):
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(
|
||||
__tokenizer,
|
||||
mlm = False,
|
||||
mlm_probability = 0.0,
|
||||
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
||||
)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
__tokenizer,
|
||||
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
||||
)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
@@ -1199,9 +1225,17 @@ class UnslothSFTTrainer(_UnslothSFTTrainer):
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
__tokenizer.tokenizer,
|
||||
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
||||
)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(
|
||||
__tokenizer.tokenizer,
|
||||
mlm = False,
|
||||
mlm_probability = 0.0,
|
||||
pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),
|
||||
)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
|
||||
Reference in New Issue
Block a user