instruct mode added to pipiline
This commit is contained in:
+82
-89
@@ -7,7 +7,6 @@ Supports instruction fine-tuning with conversational data and LoRA fine-tuning
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
@@ -21,13 +20,11 @@ from utils.config.config_manager import ConfigManager
|
||||
# Training imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel, is_bfloat16_supported
|
||||
from unsloth import FastLanguageModel #is_bfloat16_supported
|
||||
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InstructTrainer:
|
||||
"""Instruction fine-tuning trainer using Unsloth and SFTTrainer"""
|
||||
|
||||
@@ -39,7 +36,7 @@ class InstructTrainer:
|
||||
|
||||
# Set device
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {self.device}")
|
||||
print(f"Using device: {self.device}")
|
||||
|
||||
# Model parameters
|
||||
self.model_name = config.get('model_name', 'unsloth/Qwen2.5-72B-Instruct')
|
||||
@@ -76,7 +73,7 @@ class InstructTrainer:
|
||||
|
||||
def load_model_and_tokenizer(self):
|
||||
"""Load the pre-trained model and tokenizer"""
|
||||
logger.info("Loading model and tokenizer...")
|
||||
print("Loading model and tokenizer...")
|
||||
|
||||
try:
|
||||
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
||||
@@ -87,16 +84,16 @@ class InstructTrainer:
|
||||
token=self.hf_token
|
||||
)
|
||||
|
||||
logger.info(f"✅ Model loaded: {self.model_name}")
|
||||
logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
|
||||
print(f"✅ Model loaded: {self.model_name}")
|
||||
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error loading model: {e}")
|
||||
print(f"❌ Error loading model: {e}")
|
||||
raise
|
||||
|
||||
def setup_lora(self):
|
||||
"""Setup LoRA for efficient fine-tuning"""
|
||||
logger.info("Setting up LoRA configuration...")
|
||||
print("Setting up LoRA configuration...")
|
||||
|
||||
try:
|
||||
self.model = FastLanguageModel.get_peft_model(
|
||||
@@ -112,15 +109,15 @@ class InstructTrainer:
|
||||
loftq_config=None
|
||||
)
|
||||
|
||||
logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
|
||||
print(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error setting up LoRA: {e}")
|
||||
print(f"❌ Error setting up LoRA: {e}")
|
||||
raise
|
||||
|
||||
def setup_chat_template(self):
|
||||
"""Setup chat template for conversation formatting"""
|
||||
logger.info("Setting up chat template...")
|
||||
print("Setting up chat template...")
|
||||
|
||||
try:
|
||||
self.tokenizer = get_chat_template(
|
||||
@@ -128,15 +125,15 @@ class InstructTrainer:
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
|
||||
logger.info(f"✅ Chat template configured: {self.chat_template}")
|
||||
print(f"✅ Chat template configured: {self.chat_template}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error setting up chat template: {e}")
|
||||
print(f"❌ Error setting up chat template: {e}")
|
||||
raise
|
||||
|
||||
def load_dataset(self, dataset_path: str) -> Dataset:
|
||||
"""Load the conversation training dataset"""
|
||||
logger.info(f"Loading conversation dataset from: {dataset_path}")
|
||||
print(f"Loading conversation dataset from: {dataset_path}")
|
||||
|
||||
try:
|
||||
if Path(dataset_path).exists():
|
||||
@@ -144,10 +141,10 @@ class InstructTrainer:
|
||||
if (Path(dataset_path) / "dataset_info.json").exists():
|
||||
# Load from HuggingFace dataset directory
|
||||
dataset = load_from_disk(dataset_path)
|
||||
logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
|
||||
print(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
|
||||
else:
|
||||
# Load from processed conversation data files (JSONL format)
|
||||
logger.info("Loading from processed conversation data files...")
|
||||
print("Loading from processed conversation data files...")
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
@@ -158,7 +155,7 @@ class InstructTrainer:
|
||||
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
|
||||
file_path = data_dir / split_file
|
||||
if file_path.exists():
|
||||
logger.info(f"Loading {split_file}...")
|
||||
print(f"Loading {split_file}...")
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
@@ -170,15 +167,15 @@ class InstructTrainer:
|
||||
|
||||
# Create HuggingFace dataset
|
||||
dataset = Dataset.from_list(all_data)
|
||||
logger.info(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
|
||||
print(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
|
||||
else:
|
||||
# Try loading from HuggingFace Hub
|
||||
logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
|
||||
print(f"Attempting to load from HuggingFace Hub: {dataset_path}")
|
||||
dataset = Dataset.load_dataset(dataset_path, split="train")
|
||||
logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
|
||||
print(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
|
||||
|
||||
logger.info(f"Dataset loaded: {len(dataset)} samples")
|
||||
logger.info(f"Dataset features: {dataset.features}")
|
||||
print(f"Dataset loaded: {len(dataset)} samples")
|
||||
print(f"Dataset features: {dataset.features}")
|
||||
|
||||
# Verify required fields exist for conversation data
|
||||
required_fields = ["conversation"]
|
||||
@@ -189,16 +186,16 @@ class InstructTrainer:
|
||||
return dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading conversation dataset: {e}")
|
||||
print(f"Error loading conversation dataset: {e}")
|
||||
raise
|
||||
|
||||
def format_dataset_for_training(self, dataset: Dataset) -> Dataset:
|
||||
"""Format conversation dataset for training using standardize_sharegpt and apply_chat_template"""
|
||||
logger.info("Formatting conversation dataset for training...")
|
||||
print("Formatting conversation dataset for training...")
|
||||
|
||||
try:
|
||||
# Standardize the ShareGPT format
|
||||
logger.info("Standardizing ShareGPT format...")
|
||||
print("Standardizing ShareGPT format...")
|
||||
dataset = standardize_sharegpt(dataset)
|
||||
|
||||
# Define the formatting function for chat templates
|
||||
@@ -214,21 +211,21 @@ class InstructTrainer:
|
||||
return {"text": texts}
|
||||
|
||||
# Apply the formatting function
|
||||
logger.info("Applying chat template formatting...")
|
||||
print("Applying chat template formatting...")
|
||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||
|
||||
logger.info(f"✅ Dataset formatted for training with {len(dataset)} samples")
|
||||
logger.info(f"Sample formatted text: {dataset[0]['text'][:200]}...")
|
||||
print(f"✅ Dataset formatted for training with {len(dataset)} samples")
|
||||
print(f"Sample formatted text: {dataset[0]['text'][:200]}...")
|
||||
|
||||
return dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error formatting dataset: {e}")
|
||||
print(f"❌ Error formatting dataset: {e}")
|
||||
raise
|
||||
|
||||
def setup_trainer(self, train_dataset: Dataset):
|
||||
"""Setup the SFTTrainer for instruction fine-tuning"""
|
||||
logger.info("Setting up SFTTrainer for instruction fine-tuning...")
|
||||
print("Setting up SFTTrainer for instruction fine-tuning...")
|
||||
|
||||
try:
|
||||
# SFT Configuration
|
||||
@@ -247,12 +244,12 @@ class InstructTrainer:
|
||||
report_to="none", # Disable wandb for now
|
||||
)
|
||||
|
||||
logger.info("SFT Configuration:")
|
||||
logger.info(f" batch_size: {self.batch_size}")
|
||||
logger.info(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
|
||||
logger.info(f" warmup_steps: {self.warmup_steps}")
|
||||
logger.info(f" max_steps: {self.max_steps}")
|
||||
logger.info(f" learning_rate: {self.learning_rate}")
|
||||
print("SFT Configuration:")
|
||||
print(f" batch_size: {self.batch_size}")
|
||||
print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
|
||||
print(f" warmup_steps: {self.warmup_steps}")
|
||||
print(f" max_steps: {self.max_steps}")
|
||||
print(f" learning_rate: {self.learning_rate}")
|
||||
|
||||
# Create SFTTrainer
|
||||
self.trainer = SFTTrainer(
|
||||
@@ -266,18 +263,18 @@ class InstructTrainer:
|
||||
args=sft_config,
|
||||
)
|
||||
|
||||
logger.info("✅ SFTTrainer configured successfully")
|
||||
print("✅ SFTTrainer configured successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error setting up trainer: {e}")
|
||||
print(f"❌ Error setting up trainer: {e}")
|
||||
import traceback
|
||||
logger.error("Full error traceback:")
|
||||
print("Full error traceback:")
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def setup_response_only_training(self):
|
||||
"""Setup training to only learn from assistant responses"""
|
||||
logger.info("Setting up response-only training...")
|
||||
print("Setting up response-only training...")
|
||||
|
||||
try:
|
||||
# Configure trainer to only train on responses
|
||||
@@ -287,51 +284,51 @@ class InstructTrainer:
|
||||
response_part="<|im_start|>assistant\n",
|
||||
)
|
||||
|
||||
logger.info("✅ Response-only training configured")
|
||||
print("✅ Response-only training configured")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error setting up response-only training: {e}")
|
||||
print(f"❌ Error setting up response-only training: {e}")
|
||||
raise
|
||||
|
||||
def train(self, dataset_path: str):
|
||||
"""Run the instruction fine-tuning process"""
|
||||
logger.info("🚀 Starting instruction fine-tuning process...")
|
||||
print("🚀 Starting instruction fine-tuning process...")
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
logger.info("Step 1: Loading model and tokenizer...")
|
||||
print("Step 1: Loading model and tokenizer...")
|
||||
self.load_model_and_tokenizer()
|
||||
|
||||
# Setup LoRA
|
||||
logger.info("Step 2: Setting up LoRA...")
|
||||
print("Step 2: Setting up LoRA...")
|
||||
self.setup_lora()
|
||||
|
||||
# Setup chat template
|
||||
logger.info("Step 3: Setting up chat template...")
|
||||
print("Step 3: Setting up chat template...")
|
||||
self.setup_chat_template()
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Step 4: Loading conversation dataset from: {dataset_path}")
|
||||
print(f"Step 4: Loading conversation dataset from: {dataset_path}")
|
||||
train_dataset = self.load_dataset(dataset_path)
|
||||
|
||||
# Format dataset for training
|
||||
logger.info("Step 5: Formatting dataset for training...")
|
||||
print("Step 5: Formatting dataset for training...")
|
||||
formatted_dataset = self.format_dataset_for_training(train_dataset)
|
||||
|
||||
# Setup trainer
|
||||
logger.info("Step 6: Setting up trainer...")
|
||||
print("Step 6: Setting up trainer...")
|
||||
self.setup_trainer(formatted_dataset)
|
||||
|
||||
# Setup response-only training (optional but recommended for chat models)
|
||||
logger.info("Step 7: Setting up response-only training...")
|
||||
print("Step 7: Setting up response-only training...")
|
||||
self.setup_response_only_training()
|
||||
|
||||
# Start training
|
||||
logger.info("Step 8: Starting training...")
|
||||
print("Step 8: Starting training...")
|
||||
trainer_stats = self.trainer.train()
|
||||
|
||||
logger.info("✅ Instruction fine-tuning completed successfully!")
|
||||
logger.info(f"Training stats: {trainer_stats}")
|
||||
print("✅ Instruction fine-tuning completed successfully!")
|
||||
print(f"Training stats: {trainer_stats}")
|
||||
|
||||
# Save the model
|
||||
self.save_model()
|
||||
@@ -339,15 +336,15 @@ class InstructTrainer:
|
||||
return trainer_stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Instruction fine-tuning failed: {e}")
|
||||
print(f"❌ Instruction fine-tuning failed: {e}")
|
||||
import traceback
|
||||
logger.error("Full error traceback:")
|
||||
print("Full error traceback:")
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def save_model(self):
|
||||
"""Save the trained instruction model"""
|
||||
logger.info("Saving trained instruction model...")
|
||||
print("Saving trained instruction model...")
|
||||
|
||||
try:
|
||||
# Create output directory
|
||||
@@ -362,23 +359,23 @@ class InstructTrainer:
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(self.config, f, indent=2)
|
||||
|
||||
logger.info(f"✅ Instruction model saved to: {self.model_output_dir}")
|
||||
logger.info(f"✅ You can now use this model for inference")
|
||||
print(f"✅ Instruction model saved to: {self.model_output_dir}")
|
||||
print(f"✅ You can now use this model for inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error saving model: {e}")
|
||||
print(f"❌ Error saving model: {e}")
|
||||
raise
|
||||
|
||||
def prepare_for_inference(self):
|
||||
"""Prepare model for inference"""
|
||||
logger.info("Preparing model for inference...")
|
||||
print("Preparing model for inference...")
|
||||
|
||||
try:
|
||||
FastLanguageModel.for_inference(self.model)
|
||||
logger.info("✅ Model prepared for inference")
|
||||
print("✅ Model prepared for inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error preparing for inference: {e}")
|
||||
print(f"❌ Error preparing for inference: {e}")
|
||||
raise
|
||||
|
||||
def load_training_config(yaml_path: str) -> Dict[str, Any]:
|
||||
@@ -403,16 +400,16 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
|
||||
# Training configuration - extract from training section
|
||||
if 'training' in config:
|
||||
training_data = config['training']
|
||||
logger.info("Training data from YAML:")
|
||||
logger.info(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
|
||||
logger.info(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
|
||||
logger.info(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
|
||||
logger.info(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
|
||||
logger.info(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
|
||||
logger.info(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
|
||||
logger.info(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
|
||||
logger.info(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
|
||||
logger.info(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
|
||||
print("Training data from YAML:")
|
||||
print(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
|
||||
print(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
|
||||
print(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
|
||||
print(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
|
||||
print(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
|
||||
print(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
|
||||
print(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
|
||||
print(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
|
||||
print(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
|
||||
|
||||
training_config.update({
|
||||
'num_epochs': int(training_data.get('num_epochs', 1)),
|
||||
@@ -450,14 +447,14 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
|
||||
'chat_template': 'qwen-2.5' # Use Qwen chat template by default
|
||||
})
|
||||
|
||||
logger.info("Final training_config:")
|
||||
print("Final training_config:")
|
||||
for key, value in training_config.items():
|
||||
logger.info(f" {key}: {value} (type: {type(value)})")
|
||||
print(f" {key}: {value} (type: {type(value)})")
|
||||
|
||||
return training_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading training config: {e}")
|
||||
print(f"Error loading training config: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
@@ -475,15 +472,11 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
# Setup logging replaced with print statements
|
||||
|
||||
try:
|
||||
# Load configuration
|
||||
logger.info(f"Loading configuration from: {args.config}")
|
||||
print(f"Loading configuration from: {args.config}")
|
||||
training_config = load_training_config(args.config)
|
||||
|
||||
# Override with CLI arguments
|
||||
@@ -501,13 +494,13 @@ def main():
|
||||
# Determine dataset path: CLI argument takes precedence, then YAML config
|
||||
dataset_path = args.dataset or training_config.get('dataset_path')
|
||||
if not dataset_path:
|
||||
logger.error("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
|
||||
print("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Training configuration:")
|
||||
print("Training configuration:")
|
||||
for key, value in training_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
logger.info(f" Dataset path: {dataset_path}")
|
||||
print(f" {key}: {value}")
|
||||
print(f" Dataset path: {dataset_path}")
|
||||
|
||||
# Initialize trainer
|
||||
trainer = InstructTrainer(training_config)
|
||||
@@ -515,10 +508,10 @@ def main():
|
||||
# Start training
|
||||
trainer.train(dataset_path)
|
||||
|
||||
logger.info("Instruction fine-tuning completed successfully!")
|
||||
print("Instruction fine-tuning completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Instruction fine-tuning failed: {e}")
|
||||
print(f"Instruction fine-tuning failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user