2025-08-13 21:17:01 +01:00
#!/usr/bin/env python3
"""
Styling Training Pipeline using Unsloth and SFTTrainer
Supports style transfer tasks with LoRA fine-tuning
"""
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict , Any , Optional
import yaml
# Add the project root to the path
sys . path . append ( str ( Path ( __file__ ) . parent . parent . parent ) )
from utils . config . config_manager import ConfigManager
#from utils.logging.logging import setup_logging
# Training imports
import torch
from datasets import load_from_disk , Dataset
from unsloth import FastLanguageModel , is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
logger = logging.getLogger(__name__)


class StylingTrainer:
    """Fine-tune a language model for style transfer with Unsloth LoRA and SFTTrainer.

    The trainer is driven by a flat configuration dictionary. ``train`` runs
    the full pipeline: load model, attach LoRA adapters, load the dataset,
    build the ``SFTTrainer``, train, and save the result.

    Recognised configuration keys (defaults in parentheses):
        model_name (required), max_seq_length (2048), dtype (None),
        load_in_4bit (True), hf_token (None), lora_r (16), lora_alpha (16),
        lora_dropout (0), target_modules (attention and MLP projections),
        batch_size (2), gradient_accumulation_steps (4), learning_rate (2e-4),
        num_epochs (1), max_steps (None), warmup_ratio (0.1),
        warmup_steps (10), weight_decay (0.01), lr_scheduler_type ("linear"),
        seed (3407), output_dir ("./outputs"),
        model_output_dir ("./models/styling").
    """

    # Prompt layout used to render every training record. The three slots are
    # filled with instruction, input and output, in that order.
    _PROMPT_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}"""

    # Columns every training record must provide.
    _REQUIRED_FIELDS = ("instruction", "input", "output")

    # JSONL files looked up inside a processed-data directory.
    _SPLIT_FILES = ("train.jsonl", "validation.jsonl", "test.jsonl")

    # Step budget used when the configuration does not set ``max_steps``.
    _DEFAULT_MAX_STEPS = 60

    # Configuration keys that hold credentials and must never be persisted.
    _SECRET_KEYS = ("hf_token",)

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and resolve every hyper-parameter default.

        Args:
            config: Flat training configuration; see the class docstring for
                the recognised keys.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Model loading parameters
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # LoRA parameters
        self.lora_r = config.get('lora_r', 16)
        self.lora_alpha = config.get('lora_alpha', 16)
        self.lora_dropout = config.get('lora_dropout', 0)
        self.target_modules = config.get('target_modules', [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ])

        # Training arguments
        self.batch_size = config.get('batch_size', 2)
        self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
        self.learning_rate = config.get('learning_rate', 2e-4)
        self.num_epochs = config.get('num_epochs', 1)
        self.max_steps = config.get('max_steps', None)
        # warmup_ratio is kept for reference only; setup_trainer passes
        # warmup_steps to TrainingArguments.
        self.warmup_ratio = config.get('warmup_ratio', 0.1)
        self.warmup_steps = config.get('warmup_steps', 10)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.lr_scheduler_type = config.get('lr_scheduler_type', 'linear')
        self.seed = config.get('seed', 3407)

        # Output paths
        self.output_dir = config.get('output_dir', './outputs')
        self.model_output_dir = config.get('model_output_dir', './models/styling')

    def load_model_and_tokenizer(self):
        """Load the pre-trained model and tokenizer named by ``config['model_name']``.

        Raises:
            Exception: Any loading error (including a missing ``model_name``
                key) is logged and re-raised unchanged.
        """
        logger.info("Loading model and tokenizer...")
        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config['model_name'],
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )
            logger.info(f"✅ Model loaded: {self.config['model_name']}")
            logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise

    def setup_lora(self):
        """Wrap the loaded model with LoRA adapters for efficient fine-tuning.

        Raises:
            Exception: Any error from the PEFT setup is logged and re-raised.
        """
        logger.info("Setting up LoRA configuration...")
        try:
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.lora_r,
                target_modules=self.target_modules,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=self.seed,
                use_rslora=False,
                loftq_config=None,
            )
            logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
        except Exception as e:
            logger.error(f"❌ Error setting up LoRA: {e}")
            raise

    def _load_jsonl_splits(self, dataset_path: str) -> "Dataset":
        """Build one dataset from the JSONL split files found in a directory."""
        logger.info("Loading from processed data files...")
        data_dir = Path(dataset_path)
        records = []
        # NOTE(review): validation and test records are merged into the
        # training set here, which leaves no held-out data; confirm that this
        # is intended.
        for split_file in self._SPLIT_FILES:
            file_path = data_dir / split_file
            if not file_path.exists():
                continue
            logger.info(f"Loading {split_file}...")
            with open(file_path, 'r', encoding='utf-8') as f:
                # Blank lines are skipped; every other line is one JSON record.
                records.extend(json.loads(line) for line in f if line.strip())

        if not records:
            raise ValueError(f"No data found in {dataset_path}")

        dataset = Dataset.from_list(records)
        logger.info(f"Created HuggingFace dataset from {len(records)} samples")
        return dataset

    def load_dataset(self, dataset_path: str) -> "Dataset":
        """Load the training dataset from disk or from the HuggingFace Hub.

        Resolution order:
            1. An existing path containing ``dataset_info.json`` is loaded
               with ``load_from_disk``.
            2. Any other existing path is scanned for ``train.jsonl``,
               ``validation.jsonl`` and ``test.jsonl``.
            3. A path that does not exist is treated as a Hub dataset name
               and its ``train`` split is downloaded.

        Args:
            dataset_path: Local path or Hub dataset identifier.

        Returns:
            The loaded dataset, guaranteed to expose the ``instruction``,
            ``input`` and ``output`` columns.

        Raises:
            ValueError: If no records are found or a required column is missing.
            Exception: Any other loading error is logged and re-raised.
        """
        logger.info(f"Loading dataset from: {dataset_path}")
        try:
            local_path = Path(dataset_path)
            if not local_path.exists():
                # Imported under an alias so it cannot be confused with this
                # method; the module-level Dataset class has no load_dataset.
                from datasets import load_dataset as hf_load_dataset

                logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
                dataset = hf_load_dataset(dataset_path, split="train")
                logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
            elif (local_path / "dataset_info.json").exists():
                dataset = load_from_disk(dataset_path)
                logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
            else:
                dataset = self._load_jsonl_splits(dataset_path)

            logger.info(f"Dataset loaded: {len(dataset)} samples")
            logger.info(f"Dataset features: {dataset.features}")

            # Verify required fields exist
            missing_fields = [
                field for field in self._REQUIRED_FIELDS
                if field not in dataset.features
            ]
            if missing_fields:
                raise ValueError(f"Missing required fields in dataset: {missing_fields}")
            return dataset
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise

    def setup_trainer(self, train_dataset: "Dataset"):
        """Render the dataset into prompts and build the ``SFTTrainer``.

        Args:
            train_dataset: Dataset with ``instruction``, ``input`` and
                ``output`` columns.

        Raises:
            ValueError: If ``train_dataset`` contains no records.
            Exception: Any other setup error is logged and re-raised.
        """
        print("Setting up SFTTrainer...")
        try:
            if len(train_dataset) == 0:
                raise ValueError("Training dataset is empty; nothing to train on")

            # Bind plain strings so the mapping function does not capture
            # ``self`` (and with it the model).
            template = self._PROMPT_TEMPLATE
            eos_token = self.tokenizer.eos_token

            def formatting_prompts_func(examples):
                """Render a batch of records into the single ``text`` column."""
                rows = zip(examples["instruction"], examples["input"], examples["output"])
                # The EOS token must be appended, otherwise generation never
                # learns where a response ends.
                return {
                    "text": [
                        template.format(instruction, input_text, output) + eos_token
                        for instruction, input_text, output in rows
                    ]
                }

            print("Mapping dataset to create text field with EOS token...")
            formatted_dataset = train_dataset.map(
                formatting_prompts_func,
                batched=True,
                remove_columns=train_dataset.column_names,
            )
            print(f"Dataset mapped successfully. New features: {formatted_dataset.features}")
            print(f"Sample text field: {formatted_dataset[0]['text'][:100]}...")

            # A positive max_steps overrides num_train_epochs in
            # TrainingArguments; set max_steps to -1 in the configuration to
            # train by epochs instead.
            if self.max_steps is None:
                max_steps = self._DEFAULT_MAX_STEPS
            else:
                max_steps = self.max_steps

            logger.debug(
                "TrainingArguments inputs: batch_size=%r, "
                "gradient_accumulation_steps=%r, warmup_steps=%r, "
                "num_epochs=%r, max_steps=%r, learning_rate=%r, "
                "weight_decay=%r, lr_scheduler_type=%r, seed=%r",
                self.batch_size, self.gradient_accumulation_steps,
                self.warmup_steps, self.num_epochs, max_steps,
                self.learning_rate, self.weight_decay,
                self.lr_scheduler_type, self.seed,
            )

            use_bf16 = is_bfloat16_supported()
            print("Creating TrainingArguments...")
            training_args = TrainingArguments(
                per_device_train_batch_size=self.batch_size,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                warmup_steps=self.warmup_steps,
                num_train_epochs=self.num_epochs,
                max_steps=max_steps,
                learning_rate=self.learning_rate,
                fp16=not use_bf16,
                bf16=use_bf16,
                logging_steps=1,
                optim="adamw_8bit",
                weight_decay=self.weight_decay,
                lr_scheduler_type=self.lr_scheduler_type,
                seed=self.seed,
                output_dir=self.output_dir,
                report_to="none",  # Disable wandb for now
            )
            print("TrainingArguments created successfully!")

            if self.max_seq_length is not None:
                max_seq_length = int(self.max_seq_length)
            else:
                max_seq_length = 2048

            logger.debug(
                "SFTTrainer inputs: model=%s, tokenizer=%s, samples=%d, "
                "max_seq_length=%d",
                type(self.model), type(self.tokenizer),
                len(formatted_dataset), max_seq_length,
            )

            print("Creating SFTTrainer...")
            self.trainer = SFTTrainer(
                model=self.model,
                tokenizer=self.tokenizer,
                train_dataset=formatted_dataset,
                dataset_text_field="text",  # The column created by the mapping above
                max_seq_length=max_seq_length,
                dataset_num_proc=2,
                packing=False,  # Can make training 5x faster for short sequences
                args=training_args,
            )
            print("SFTTrainer configured successfully")
        except Exception as e:
            # The traceback is reported once, by train().
            logger.error(f"Error setting up trainer: {e}")
            raise

    def train(self, dataset_path: str):
        """Run the whole pipeline: load, configure, train and save.

        Args:
            dataset_path: Local path or Hub dataset identifier passed to
                ``load_dataset``.

        Returns:
            The statistics object returned by ``SFTTrainer.train()``.

        Raises:
            Exception: Any failure is logged with its traceback and re-raised.
        """
        print("🚀 Starting training process...")
        try:
            print("Loading model and tokenizer...")
            self.load_model_and_tokenizer()

            print("Setting up LoRA...")
            self.setup_lora()

            print(f"Loading dataset from: {dataset_path}")
            train_dataset = self.load_dataset(dataset_path)

            print("Setting up trainer...")
            self.setup_trainer(train_dataset)

            print("Starting training...")
            trainer_stats = self.trainer.train()
            print("✅ Training completed successfully!")
            print(f"Training stats: {trainer_stats}")

            self.save_model()
            return trainer_stats
        except Exception:
            logger.exception("❌ Training failed")
            raise

    def save_model(self):
        """Save the model, the tokenizer and a snapshot of the configuration.

        The snapshot is written to ``training_config.json`` inside
        ``model_output_dir``. Credential entries are blanked in the snapshot
        so that a shared model directory does not leak them; ``self.config``
        itself is left untouched.

        Raises:
            Exception: Any error while saving is logged and re-raised.
        """
        print("Saving trained model...")
        try:
            target_dir = Path(self.model_output_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            self.model.save_pretrained(self.model_output_dir)
            self.tokenizer.save_pretrained(self.model_output_dir)

            # None is what the configuration readers use for "no token".
            snapshot = {
                key: (None if key in self._SECRET_KEYS else value)
                for key, value in self.config.items()
            }
            with open(target_dir / "training_config.json", 'w', encoding='utf-8') as f:
                # default=str is deliberate: the snapshot is informational and
                # must not abort a finished training run over a value such as
                # a dtype object that JSON cannot encode.
                json.dump(snapshot, f, indent=2, default=str)

            print(f"✅ Model saved to: {self.model_output_dir}")
            print(f"✅ You can now use this model for inference with: --config {self.model_output_dir}")
        except Exception as e:
            logger.error(f"❌ Error saving model: {e}")
            raise

    def prepare_for_inference(self):
        """Switch the model to inference mode via ``FastLanguageModel.for_inference``.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        logger.info("Preparing model for inference...")
        try:
            FastLanguageModel.for_inference(self.model)
            logger.info("✅ Model prepared for inference")
        except Exception as e:
            logger.error(f"❌ Error preparing for inference: {e}")
            raise
2025-08-13 23:50:20 +00:00
def _coerce_setting(section: Dict[str, Any], key: str, default: Any, cast):
    """Return ``cast(section[key])``, using ``default`` when the key is absent or null."""
    value = section.get(key)
    return cast(default if value is None else value)


def load_training_config(yaml_path: str) -> Dict[str, Any]:
    """Load a YAML file and flatten it into the trainer's configuration dict.

    Only the sections present in the file contribute keys:
        * ``model``    -> model_name, max_seq_length, dtype, load_in_4bit,
          hf_token
        * ``training`` -> num_epochs, batch_size, learning_rate, weight_decay,
          warmup_steps, max_steps, gradient_accumulation_steps,
          lr_scheduler_type, seed, model_output_dir
        * ``data``     -> data_output_dir, dataset_path, style_instruction

    The LoRA settings and ``output_dir`` are always added with fixed values.
    ``model_output_dir`` falls back to ``./models/styling`` only when the
    ``training`` section did not provide one.

    Args:
        yaml_path: Path to the YAML configuration file.

    Returns:
        The flat configuration dictionary.

    Raises:
        ValueError: If the YAML document is not a mapping at the top level.
        Exception: Any other read, parse or conversion error is logged and
            re-raised.
    """
    try:
        with open(yaml_path, 'r', encoding='utf-8') as f:
            # An empty file parses to None; treat it as "no sections".
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(f"Top-level YAML structure in {yaml_path} must be a mapping")

        training_config: Dict[str, Any] = {}

        # Model configuration - extract from model section
        if 'model' in config:
            model_config = config['model'] or {}
            training_config.update({
                'model_name': model_config.get('name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': _coerce_setting(model_config, 'max_seq_length', 2048, int),
                'dtype': model_config.get('dtype', None),
                'load_in_4bit': model_config.get('load_in_4bit', True),
                'hf_token': model_config.get('token', None),
            })

        # Training configuration - extract from training section
        if 'training' in config:
            training_data = config['training'] or {}
            logger.debug("Training section from YAML: %r", training_data)
            # The explicit casts matter: values such as ``2e-4`` can arrive
            # from YAML as strings.
            training_config.update({
                'num_epochs': _coerce_setting(training_data, 'num_epochs', 1, int),
                'batch_size': _coerce_setting(training_data, 'batch_size', 2, int),
                'learning_rate': _coerce_setting(training_data, 'learning_rate', 2e-4, float),
                'weight_decay': _coerce_setting(training_data, 'weight_decay', 0.01, float),
                'warmup_steps': _coerce_setting(training_data, 'warmup_steps', 5, int),
                'max_steps': _coerce_setting(training_data, 'max_steps', 60, int),
                'gradient_accumulation_steps': _coerce_setting(
                    training_data, 'gradient_accumulation_steps', 4, int),
                'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'),
                'seed': _coerce_setting(training_data, 'seed', 3407, int),
                'model_output_dir': training_data.get('model_output_dir', './models/styling'),
            })

        # Data configuration - use output_dir from data section
        if 'data' in config:
            data_config = config['data'] or {}
            output_dir = data_config.get('output_dir', './data/processed/styling')
            training_config.update({
                'data_output_dir': output_dir,
                'dataset_path': output_dir,  # Default dataset path is the output_dir
                'style_instruction': data_config.get(
                    'instruction', 'Rewrite the following text in a formal style'),
            })

        # LoRA configuration
        training_config.update({
            'lora_r': 16,
            'lora_alpha': 16,
            'lora_dropout': 0,
            'target_modules': [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            'output_dir': './outputs',
        })
        # Only fill in the default: an unconditional assignment here would
        # discard the model_output_dir read from the training section.
        if not training_config.get('model_output_dir'):
            training_config['model_output_dir'] = './models/styling'

        # Never echo the access token, not even at debug level.
        for key, value in training_config.items():
            shown = '***' if key == 'hf_token' and value else value
            logger.debug("  %s: %r (type: %s)", key, shown, type(value).__name__)

        return training_config
    except Exception as e:
        logger.error(f"Error loading training config: {e}")
        raise
def main():
    """Parse the command line, build the configuration and run training.

    The YAML file given by ``--config`` supplies the base configuration;
    ``--output-dir``, ``--epochs``, ``--batch-size``, ``--learning-rate`` and
    ``--max-steps`` override individual values. The dataset location comes
    from ``--dataset`` or, failing that, from the configuration's
    ``dataset_path``.

    Exits with status 1 when no dataset path is available or when training
    raises.
    """
    parser = argparse.ArgumentParser(description="Styling Training Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--dataset", type=str, help="Path to training dataset (HF dataset path or local path)")
    parser.add_argument("--output-dir", type=str, help="Output directory for model")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-steps", type=int, help="Maximum training steps")
    args = parser.parse_args()

    # Without a handler the INFO messages below are silently dropped.
    # basicConfig does nothing if the root logger is already configured.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        # Load configuration
        logger.info(f"Loading configuration from: {args.config}")
        training_config = load_training_config(args.config)

        # Override with CLI arguments. argparse has already converted the
        # numeric options; compare against None so an explicit 0 is honoured.
        if args.output_dir:
            training_config['model_output_dir'] = args.output_dir
        numeric_overrides = {
            'num_epochs': args.epochs,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'max_steps': args.max_steps,
        }
        for key, value in numeric_overrides.items():
            if value is not None:
                training_config[key] = value

        # Determine dataset path: CLI argument takes precedence, then YAML config
        dataset_path = args.dataset or training_config.get('dataset_path')
        if not dataset_path:
            logger.error("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
            sys.exit(1)

        logger.info("Training configuration:")
        for key, value in training_config.items():
            # Keep the access token out of the logs.
            shown = '***' if key == 'hf_token' and value else value
            logger.info(f"  {key}: {shown}")
        logger.info(f"Dataset path: {dataset_path}")

        # Initialize trainer
        trainer = StylingTrainer(training_config)

        # Start training
        trainer.train(dataset_path)
        logger.info("Training completed successfully!")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()