494 lines
20 KiB
Python
494 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Instruct Training Pipeline using Unsloth and SFTTrainer
|
|
Supports instruction fine-tuning with conversational data and LoRA fine-tuning
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import argparse
|
|
from pathlib import Path
|
|
from typing import Dict, Any, Optional, List
|
|
import yaml
|
|
|
|
# Add the project root to the path
|
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
|
|
|
from utils.config.config_manager import ConfigManager
|
|
|
|
# Training imports
|
|
import torch
|
|
from datasets import load_from_disk, Dataset
|
|
from unsloth import FastLanguageModel, is_bfloat16_supported
|
|
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
|
|
from trl import SFTTrainer, SFTConfig
|
|
from transformers import DataCollatorForSeq2Seq
|
|
|
|
class InstructTrainer:
    """Instruction fine-tuning trainer using Unsloth and SFTTrainer.

    Typical use is ``InstructTrainer(config).train(dataset_path)``, which
    loads the base model, attaches LoRA adapters, applies the chat template,
    formats a JSONL conversation dataset, trains, and saves the result.

    Args:
        config: Flat mapping of settings. Every key is optional and falls
            back to the default visible in ``__init__``.
    """

    # (instruction_part, response_part) turn prefixes handed to
    # train_on_responses_only. They must match the chat template that was
    # applied to the training text, otherwise no response span is found.
    _CHATML_MARKERS = ("<|im_start|>user\n", "<|im_start|>assistant\n")
    _LLAMA3_MARKERS = (
        "<|start_header_id|>user<|end_header_id|>\n\n",
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
    )

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_name = config.get('model_name', 'unsloth/Qwen2.5-72B-Instruct')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # LoRA parameters
        self.lora_r = config.get('lora_r', 32)
        self.lora_alpha = config.get('lora_alpha', 16)
        self.lora_dropout = config.get('lora_dropout', 0)
        self.target_modules = config.get('target_modules', [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ])

        # Training arguments
        self.batch_size = config.get('batch_size', 1)
        self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
        self.learning_rate = config.get('learning_rate', 2e-4)
        self.num_epochs = config.get('num_epochs', 1)
        self.max_steps = config.get('max_steps', 30)
        self.warmup_steps = config.get('warmup_steps', 5)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.lr_scheduler_type = config.get('lr_scheduler_type', 'linear')
        self.seed = config.get('seed', 3407)

        # Output paths
        self.output_dir = config.get('output_dir', './outputs')
        self.model_output_dir = config.get('model_output_dir', './models/instruct')

        # Chat template
        # NOTE(review): this default is a Llama template while the default
        # model is a Qwen checkpoint; callers normally override it.
        self.chat_template = config.get('chat_template', 'llama-3.1')

    def load_model_and_tokenizer(self):
        """Load the pre-trained model and tokenizer into ``self``.

        Raises:
            Exception: whatever ``FastLanguageModel.from_pretrained`` raises,
                re-raised after the error is printed.
        """
        print("Loading model and tokenizer...")

        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_name,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )

            print(f"✅ Model loaded: {self.model_name}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def setup_lora(self):
        """Wrap ``self.model`` with LoRA adapters for efficient fine-tuning.

        Raises:
            Exception: whatever ``FastLanguageModel.get_peft_model`` raises,
                re-raised after the error is printed.
        """
        print("Setting up LoRA configuration...")

        try:
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.lora_r,
                target_modules=self.target_modules,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=self.seed,
                use_rslora=False,
                loftq_config=None,
            )

            print(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")

        except Exception as e:
            print(f"❌ Error setting up LoRA: {e}")
            raise

    def setup_chat_template(self):
        """Replace ``self.tokenizer`` with one carrying ``self.chat_template``.

        Raises:
            Exception: whatever ``get_chat_template`` raises, re-raised after
                the error is printed.
        """
        print("Setting up chat template...")

        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template=self.chat_template,
            )

            print(f"✅ Chat template configured: {self.chat_template}")

        except Exception as e:
            print(f"❌ Error setting up chat template: {e}")
            raise

    @staticmethod
    def _read_jsonl(dataset_path: str) -> List[Dict[str, Any]]:
        """Parse a JSONL file into a list, skipping blank lines.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON; the
                message names the offending line number and file.
        """
        records = []
        with open(dataset_path, "r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank or trailing lines are common in JSONL exports.
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Same exception type as before, with location added.
                    raise json.JSONDecodeError(
                        f"{err.msg} (line {line_no} of {dataset_path})",
                        err.doc,
                        err.pos,
                    ) from err
        return records

    def load_dataset(self, dataset_path: str) -> "Dataset":
        """Load the conversation training dataset directly from a JSONL file.

        Args:
            dataset_path: Path to a UTF-8 file with one JSON object per line.

        Returns:
            A ``Dataset`` built from the parsed records.

        Raises:
            ValueError: if the file contains no records.
            Exception: any I/O or parse error, re-raised after being printed.
        """
        print(f"Loading conversation dataset from: {dataset_path}")

        try:
            data = self._read_jsonl(dataset_path)

            print(f"Loaded {len(data)} examples")

            if not data:
                raise ValueError(f"No examples found in dataset file: {dataset_path}")

            # Convert to HuggingFace Dataset
            dataset = Dataset.from_list(data)

            print(dataset)
            print(dataset[0])  # Show first example

            return dataset

        except Exception as e:
            print(f"Error loading conversation dataset: {e}")
            raise

    def format_dataset_for_training(self, dataset: "Dataset") -> "Dataset":
        """Render each conversation to a single string in a ``"text"`` column.

        The dataset is passed through ``standardize_sharegpt`` and then the
        tokenizer's chat template is applied to every entry of the
        ``"conversation"`` column.

        Args:
            dataset: Dataset with a ``"conversation"`` column.

        Returns:
            The mapped dataset, which gains a ``"text"`` column.

        Raises:
            Exception: any formatting error, re-raised after being printed.
        """
        print("Formatting conversation dataset for training...")

        try:
            def formatting_prompts_func(examples):
                convos = examples["conversation"]
                texts = [
                    self.tokenizer.apply_chat_template(
                        convo, tokenize=False, add_generation_prompt=False
                    )
                    for convo in convos
                ]
                return {"text": texts}

            # Standardize the ShareGPT format
            # NOTE(review): confirm standardize_sharegpt operates on the same
            # column ("conversation") that is read above.
            print("Standardizing ShareGPT format...")
            dataset = standardize_sharegpt(dataset)

            # Apply the formatting function
            print("Applying chat template formatting...")
            dataset = dataset.map(formatting_prompts_func, batched=True)

            print(f"✅ Dataset formatted for training with {len(dataset)} samples")
            if len(dataset) > 0:
                # Guarded so an empty dataset does not fail on the preview.
                print(f"Sample formatted text: {dataset[0]['text'][:200]}...")

            return dataset

        except Exception as e:
            print(f"❌ Error formatting dataset: {e}")
            raise

    def setup_trainer(self, train_dataset: "Dataset"):
        """Build ``self.trainer`` (an ``SFTTrainer``) for ``train_dataset``.

        Args:
            train_dataset: Dataset with a ``"text"`` column.

        Raises:
            Exception: any construction error, re-raised after the message
                and traceback are printed.
        """
        print("Setting up SFTTrainer for instruction fine-tuning...")

        try:
            # SFT Configuration
            sft_config = SFTConfig(
                per_device_train_batch_size=self.batch_size,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                warmup_steps=self.warmup_steps,
                # A positive max_steps takes precedence over the epoch count;
                # num_train_epochs only applies when max_steps is not positive.
                num_train_epochs=self.num_epochs,
                max_steps=self.max_steps,
                learning_rate=self.learning_rate,
                logging_steps=1,
                optim="paged_adamw_8bit",
                weight_decay=self.weight_decay,
                lr_scheduler_type=self.lr_scheduler_type,
                seed=self.seed,
                output_dir=self.output_dir,
                report_to="none",  # Disable wandb for now
            )

            print("SFT Configuration:")
            print(f" batch_size: {self.batch_size}")
            print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
            print(f" warmup_steps: {self.warmup_steps}")
            print(f" max_steps: {self.max_steps}")
            print(f" learning_rate: {self.learning_rate}")

            # Create SFTTrainer
            self.trainer = SFTTrainer(
                model=self.model,
                tokenizer=self.tokenizer,
                train_dataset=train_dataset,
                dataset_text_field="text",
                max_seq_length=self.max_seq_length,
                data_collator=DataCollatorForSeq2Seq(tokenizer=self.tokenizer),
                packing=False,  # Disable packing for conversation data
                args=sft_config,
            )

            print("✅ SFTTrainer configured successfully")

        except Exception as e:
            print(f"❌ Error setting up trainer: {e}")
            import traceback
            print("Full error traceback:")
            traceback.print_exc()
            raise

    def _response_markers(self) -> tuple:
        """Return ``(instruction_part, response_part)`` for the chat template.

        Llama-3 family template names get the header-id markers; every other
        template falls back to the ChatML markers used by Qwen.
        """
        template = str(self.chat_template or "").lower()
        if template.startswith(("llama-3", "llama3")):
            return self._LLAMA3_MARKERS
        return self._CHATML_MARKERS

    def setup_response_only_training(self):
        """Restrict the loss to assistant responses (best effort).

        Any failure is printed and swallowed on purpose, so training then
        proceeds on all tokens with the existing trainer.
        """
        print("Setting up response-only training...")

        try:
            # The markers must match the template that produced the text.
            instruction_part, response_part = self._response_markers()

            # Configure trainer to only train on responses
            self.trainer = train_on_responses_only(
                self.trainer,
                instruction_part=instruction_part,
                response_part=response_part,
            )

            print("✅ Response-only training configured")

        except Exception as e:
            print(f"❌ Error setting up response-only training: {e}")
            print("Skipping response-only training and proceeding with full training...")
            # Don't raise the exception, just continue with regular training

    def train(self, dataset_path: str):
        """Run the full instruction fine-tuning process and save the model.

        Args:
            dataset_path: Path to the JSONL conversation dataset.

        Returns:
            The value returned by ``self.trainer.train()``.

        Raises:
            Exception: any failure in a step, re-raised after the message and
                traceback are printed.
        """
        print("🚀 Starting instruction fine-tuning process...")

        try:
            print("Step 1: Loading model and tokenizer...")
            self.load_model_and_tokenizer()

            print("Step 2: Setting up LoRA...")
            self.setup_lora()

            print("Step 3: Setting up chat template...")
            self.setup_chat_template()

            print(f"Step 4: Loading conversation dataset from: {dataset_path}")
            train_dataset = self.load_dataset(dataset_path)

            print("Step 5: Formatting dataset for training...")
            formatted_dataset = self.format_dataset_for_training(train_dataset)

            print("Step 6: Setting up trainer...")
            self.setup_trainer(formatted_dataset)

            # Optional but recommended for chat models. The method handles
            # its own failures, so no extra guard is needed here.
            print("Step 7: Setting up response-only training...")
            self.setup_response_only_training()

            print("Step 8: Starting training...")
            trainer_stats = self.trainer.train()

            print("✅ Instruction fine-tuning completed successfully!")
            print(f"Training stats: {trainer_stats}")

            # Save the model
            self.save_model()

            return trainer_stats

        except Exception as e:
            print(f"❌ Instruction fine-tuning failed: {e}")
            import traceback
            print("Full error traceback:")
            traceback.print_exc()
            raise

    def save_model(self):
        """Save the model, tokenizer and training config to ``model_output_dir``.

        The config is written to ``training_config.json`` without the
        ``hf_token`` entry, so the credential is not stored with the model.

        Raises:
            Exception: any save error, re-raised after being printed.
        """
        print("Saving trained instruction model...")

        try:
            # Create output directory
            target_dir = Path(self.model_output_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            # Save model and tokenizer
            self.model.save_pretrained(self.model_output_dir)
            self.tokenizer.save_pretrained(self.model_output_dir)

            # Save training config, minus the access token.
            saved_config = {
                key: value
                for key, value in self.config.items()
                if key != 'hf_token'
            }
            config_path = target_dir / "training_config.json"
            with open(config_path, 'w', encoding='utf-8') as f:
                # default=str is deliberate: this file is an informational
                # record, and a non-JSON value must not fail the save.
                json.dump(saved_config, f, indent=2, default=str)

            print(f"✅ Instruction model saved to: {self.model_output_dir}")
            print(f"✅ You can now use this model for inference")

        except Exception as e:
            print(f"❌ Error saving model: {e}")
            raise

    def prepare_for_inference(self):
        """Switch ``self.model`` to inference mode via ``FastLanguageModel``.

        Raises:
            Exception: whatever ``FastLanguageModel.for_inference`` raises,
                re-raised after the error is printed.
        """
        print("Preparing model for inference...")

        try:
            FastLanguageModel.for_inference(self.model)
            print("✅ Model prepared for inference")

        except Exception as e:
            print(f"❌ Error preparing for inference: {e}")
            raise
|
|
|
|
def load_training_config(yaml_path: str) -> Dict[str, Any]:
    """Load a YAML file and flatten it into the trainer's config mapping.

    The ``model``, ``training`` and ``data`` sections are read when present.
    A section that is present but empty is treated as an empty mapping, so
    its defaults apply. ``output_dir`` and ``chat_template`` are always set
    to fixed values by this function.

    Args:
        yaml_path: Path to the YAML configuration file.

    Returns:
        Flat mapping suitable for ``InstructTrainer``; includes
        ``dataset_path`` only when the YAML has a ``data`` section.

    Raises:
        ValueError: if the YAML top level is not a mapping.
        Exception: any read, parse or conversion error, re-raised after
            being printed.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as f:
            # safe_load yields None for an empty document.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Top level of {yaml_path} must be a mapping, got {type(config).__name__}"
            )

        training_config = {}

        # Model configuration - extract from model section
        if 'model' in config:
            model_config = config['model'] or {}
            training_config.update({
                'model_name': model_config.get('name', 'unsloth/Qwen2.5-72B-Instruct'),
                'max_seq_length': int(model_config.get('max_seq_length', 2048)),
                'dtype': model_config.get('dtype', None),
                'load_in_4bit': model_config.get('load_in_4bit', True),
                'hf_token': model_config.get('token', None),
            })

        # Training configuration - extract from training section
        if 'training' in config:
            training_data = config['training'] or {}

            # Echo the raw values and their types before conversion.
            print("Training data from YAML:")
            for key in (
                'num_epochs', 'batch_size', 'learning_rate', 'weight_decay',
                'warmup_steps', 'max_steps', 'gradient_accumulation_steps',
                'seed', 'model_output_dir',
            ):
                raw_value = training_data.get(key)
                print(f" {key}: {raw_value} (type: {type(raw_value)})")

            training_config.update({
                'num_epochs': int(training_data.get('num_epochs', 1)),
                'batch_size': int(training_data.get('batch_size', 1)),
                'learning_rate': float(training_data.get('learning_rate', 2e-4)),
                'weight_decay': float(training_data.get('weight_decay', 0.01)),
                'warmup_steps': int(training_data.get('warmup_steps', 5)),
                'max_steps': int(training_data.get('max_steps', 30)),
                'gradient_accumulation_steps': int(training_data.get('gradient_accumulation_steps', 4)),
                'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'),
                'seed': int(training_data.get('seed', 3407)),
                'model_output_dir': training_data.get('model_output_dir', './models/instruct'),
                # LoRA configuration
                'lora_r': int(training_data.get('lora_r', 32)),
                'lora_alpha': int(training_data.get('lora_alpha', 16)),
                'lora_dropout': float(training_data.get('lora_dropout', 0)),
                'target_modules': training_data.get('target_modules', [
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",
                ]),
            })

        # Data configuration - use data_path from data section
        if 'data' in config:
            data_config = config['data'] or {}
            data_path = data_config.get('data_path', './data/raw/instruct/code_reasoning.jsonl')
            training_config.update({
                'dataset_path': data_path,  # Use data_path directly for JSONL file
            })

        # Output configuration (fixed values, not read from the YAML)
        training_config.update({
            'output_dir': './outputs',
            'chat_template': 'qwen-2.5',  # Use Qwen chat template by default
        })

        print("Final training_config:")
        for key, value in training_config.items():
            # Never echo the access token to the console or logs.
            shown = '***' if key == 'hf_token' and value else value
            print(f" {key}: {shown} (type: {type(value)})")

        return training_config

    except Exception as e:
        print(f"Error loading training config: {e}")
        raise
|
|
|
|
def main():
    """Parse CLI arguments, build the training config, and run training.

    CLI options override the values loaded from the YAML file. The process
    exits with status 1 when no dataset path is available or when any step
    raises.
    """
    parser = argparse.ArgumentParser(description="Instruction Fine-tuning Training Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--dataset", type=str, help="Path to training dataset (conversation data path)")
    parser.add_argument("--output-dir", type=str, help="Output directory for model")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-steps", type=int, help="Maximum training steps")

    args = parser.parse_args()

    # Setup logging replaced with print statements

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        training_config = load_training_config(args.config)

        # Override with CLI arguments
        if args.output_dir:
            training_config['model_output_dir'] = args.output_dir

        # argparse already converted these; compare against None so an
        # explicit 0 / 0.0 on the command line is not silently dropped.
        numeric_overrides = {
            'num_epochs': args.epochs,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'max_steps': args.max_steps,
        }
        for key, value in numeric_overrides.items():
            if value is not None:
                training_config[key] = value

        # Determine dataset path: CLI argument takes precedence, then YAML config
        dataset_path = args.dataset or training_config.get('dataset_path')
        if not dataset_path:
            print("No dataset path provided. Use --dataset or set data.data_path in the YAML config.")
            # SystemExit is not an Exception, so it bypasses the handler below.
            sys.exit(1)

        print("Training configuration:")
        for key, value in training_config.items():
            # Never echo the access token to the console or logs.
            shown = '***' if key == 'hf_token' and value else value
            print(f" {key}: {shown}")
        print(f" Dataset path: {dataset_path}")

        # Initialize trainer
        trainer = InstructTrainer(training_config)

        # Start training
        trainer.train(dataset_path)

        print("Instruction fine-tuning completed successfully!")

    except Exception as e:
        print(f"Instruction fine-tuning failed: {e}")
        sys.exit(1)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|