instruct mode added to pipiline

This commit is contained in:
Your Name
2025-08-28 16:46:24 +00:00
parent 78d519efbf
commit 77c563f358
16 changed files with 19404 additions and 161 deletions
@@ -9,3 +9,4 @@ numpy>=1.24.0
scikit-learn>=1.3.0
pyyaml>=6.0
huggingface-hub>=0.15.0
unsloth
+46
View File
@@ -0,0 +1,46 @@
#!/usr/bin/env bash
set -e # exit on error
# Step 1: Download Miniconda installer
echo ">>> Downloading Miniconda..."
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
# Step 2: Run the installer
echo ">>> Installing Miniconda..."
bash ~/miniconda.sh -b -p $HOME/miniconda
# Step 3: Initialize conda for bash
echo ">>> Initializing Conda..."
eval "$($HOME/miniconda/bin/conda shell.bash hook)"
# Step 4: Add conda to PATH permanently
if ! grep -q "miniconda/bin" ~/.bashrc; then
echo "export PATH=\$HOME/miniconda/bin:\$PATH" >> ~/.bashrc
fi
source ~/.bashrc
# Step 5: Verify installation
echo ">>> Conda version:"
conda --version
# Step 6: Create environment
echo ">>> Creating environment: llm-env (python=3.10)..."
conda create -n llm-env python=3.10 -y
# Step 7: Activate environment
echo ">>> Activating environment..."
conda activate llm-env
# Step 8: Install dependencies
echo ">>> Installing scikit-learn..."
conda install scikit-learn -y
if [ -f requirements.txt ]; then
echo ">>> Installing Python requirements..."
pip install -r requirements.txt
else
echo ">>> No requirements.txt found, skipping pip install."
fi
echo ">>> Setup complete. To activate your environment run:"
echo "conda activate llm-env"
@@ -10,7 +10,7 @@ task:
# Data Processing Configuration
data:
source: "custom" # Data source: "huggingface" or "custom"
data_path: "./data/raw/instruct/code_reasoning.jsonl" # Path to conversation data file
data_path: "data/raw/swe_reasoning_dataset (3).jsonl" # Path to conversation data file
data_format: "jsonl" # Data format: "jsonl", "json"
# Field Mapping for Conversation Data
@@ -34,7 +34,7 @@ data:
# Model Configuration
model:
name: "unsloth/Qwen2.5-72B-Instruct" # Model name from HuggingFace Hub (optimized for instruction following)
name: "unsloth/Qwen2.5-Coder-7B" # Model name from HuggingFace Hub (optimized for instruction following)
max_length: 2048 # Maximum sequence length for tokenization
max_seq_length: 2048 # Maximum sequence length for training (RoPE scaling supported)
dtype: null # Data type: null for auto detection, float16 for Tesla T4/V100, bfloat16 for Ampere+
@@ -42,7 +42,7 @@ model:
token: null # HuggingFace token for gated models (e.g., "hf_...")
# Training Model Parameters
training_model: "unsloth/Qwen2.5-72B-Instruct" # Model to use for training
training_model: "unsloth/Qwen2.5-Coder-7B" # Model to use for training
training_max_seq_length: 2048 # Max sequence length for training
training_dtype: null # Data type for training
training_load_in_4bit: true # 4bit quantization for training
+3 -3
View File
@@ -10,7 +10,7 @@ task:
# Data Processing Configuration
data:
source: "custom" # Data source: "huggingface" or "custom"
data_path: "./data/raw/instruct/code_reasoning.jsonl" # Path to conversation data file
data_path: "data/raw/swe_reasoning_dataset (3).jsonl" # Path to conversation data file
data_format: "jsonl" # Data format: "jsonl", "json"
# Field Mapping for Conversation Data
@@ -34,7 +34,7 @@ data:
# Model Configuration
model:
name: "unsloth/Qwen2.5-72B-Instruct" # Model name from HuggingFace Hub (optimized for instruction following)
name: "unsloth/Qwen2.5-Coder-7B" # Model name from HuggingFace Hub (optimized for instruction following)
max_length: 2048 # Maximum sequence length for tokenization
max_seq_length: 2048 # Maximum sequence length for training (RoPE scaling supported)
dtype: null # Data type: null for auto detection, float16 for Tesla T4/V100, bfloat16 for Ampere+
@@ -42,7 +42,7 @@ model:
token: null # HuggingFace token for gated models (e.g., "hf_...")
# Training Model Parameters
training_model: "unsloth/Qwen2.5-72B-Instruct" # Model to use for training
training_model: "unsloth/Qwen2.5-Coder-7B" # Model to use for training
training_max_seq_length: 2048 # Max sequence length for training
training_dtype: null # Data type for training
training_load_in_4bit: true # 4bit quantization for training
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,913 @@
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
from sklearn.model_selection import train_test_split
import re
import argparse
import sys
import yaml
@dataclass
class InstructConfig:
"""Configuration for instruction fine-tuning tasks"""
# Data source configuration
data_source: str = "custom" # "huggingface" or "custom"
dataset_name: Optional[str] = None # For Hugging Face datasets
data_path: Optional[str] = None # For custom datasets
data_format: str = "jsonl" # jsonl, json
# Field mapping - conversation data specific
conversation_field: str = "conversation" # Field containing conversation array
# Data processing
max_samples: Optional[int] = None
train_split: float = 0.8
validation_split: float = 0.1
test_split: float = 0.1
# Text preprocessing
clean_text: bool = True
min_length: int = 10
max_length: int = 2048
# Output configuration
output_format: str = "conversation" # conversation, alpaca
output_dir: str = "./data/processed/instruct"
# Hugging Face specific
hf_split: str = "train"
hf_cache_dir: Optional[str] = None
# Split configuration
test_split_from: str = "train"
val_split_from: str = "train"
# Custom data specific
encoding: str = "utf-8"
class ConversationValidator:
"""Validates conversation data quality and format"""
@staticmethod
def validate_conversation_data(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
"""Validate conversation dataset splits"""
errors = []
# Check if we have the expected splits
expected_splits = ["train", "validation", "test"]
for split in expected_splits:
if split not in data:
errors.append(f"Missing '{split}' split")
elif split == "train" and not data[split]:
errors.append(f"Train split cannot be empty")
if errors:
return False, errors
total_samples = sum(len(split_data) for split_data in data.values())
print(f"Validating {total_samples} total samples across all splits...")
# Determine field names based on whether data is processed or not
conversation_field = "conversation" if not is_processed else "conversation"
# Validate each split
for split_name, split_data in data.items():
if not split_data:
print(f"Skipping validation for empty {split_name} split")
continue
print(f"Validating {split_name} split with {len(split_data)} samples...")
# Check required fields
missing_conversation_count = 0
for i, item in enumerate(split_data):
if conversation_field not in item:
errors.append(f"Missing conversation field '{conversation_field}' in {split_name} split, item {i}")
missing_conversation_count += 1
else:
# Validate conversation structure
conversation = item[conversation_field]
if not isinstance(conversation, list):
errors.append(f"Conversation field must be a list in {split_name} split, item {i}")
else:
# Validate each turn in conversation
for j, turn in enumerate(conversation):
if not isinstance(turn, dict):
errors.append(f"Each conversation turn must be a dict in {split_name} split, item {i}, turn {j}")
continue
# Check for required fields in conversation turn
if "role" not in turn:
errors.append(f"Missing 'role' field in conversation turn {j}, {split_name} split, item {i}")
if "content" not in turn:
errors.append(f"Missing 'content' field in conversation turn {j}, {split_name} split, item {i}")
# Validate role values
if "role" in turn and turn["role"] not in ["user", "assistant", "system"]:
errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
# Show sample of processed data for debugging
if split_data:
print(f"Sample conversation from {split_name}:")
for i in range(min(2, len(split_data))):
item = split_data[i]
conversation = item.get(conversation_field, [])
print(f" Item {i} conversation length: {len(conversation)} turns")
for j, turn in enumerate(conversation[:3]): # Show first 3 turns
role = turn.get("role", "unknown")
content = turn.get("content", "")[:100] + "..." if len(turn.get("content", "")) > 100 else turn.get("content", "")
print(f" Turn {j}: {role} -> '{content}'")
return len(errors) == 0, errors
@staticmethod
def analyze_conversation_dataset(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Dict[str, Any]:
"""Analyze conversation dataset characteristics across all splits"""
analysis = {
"splits": {},
"overall": {
"total_samples": 0,
"split_sizes": {},
"conversation_stats": {
"total_turns": 0,
"avg_turns_per_conversation": 0,
"role_distribution": {"user": 0, "assistant": 0, "system": 0}
}
}
}
conversation_field = "conversation" if not is_processed else "conversation"
total_turns = 0
total_conversations = 0
role_counts = {"user": 0, "assistant": 0, "system": 0}
# Analyze each split
for split_name, split_data in data.items():
if not split_data:
split_analysis = {
"total_samples": 0,
"conversation_stats": {},
"missing_values": {}
}
analysis["splits"][split_name] = split_analysis
analysis["overall"]["split_sizes"][split_name] = 0
continue
split_analysis = {
"total_samples": len(split_data),
"conversation_stats": {},
"missing_values": {}
}
# Conversation statistics
split_turns = 0
split_conversations = len(split_data)
split_role_counts = {"user": 0, "assistant": 0, "system": 0}
conversation_lengths = []
for item in split_data:
conversation = item.get(conversation_field, [])
if isinstance(conversation, list):
conversation_lengths.append(len(conversation))
split_turns += len(conversation)
for turn in conversation:
if isinstance(turn, dict) and "role" in turn:
role = turn["role"]
if role in split_role_counts:
split_role_counts[role] += 1
if conversation_lengths:
split_analysis["conversation_stats"] = {
"total_turns": split_turns,
"avg_turns_per_conversation": np.mean(conversation_lengths),
"min_turns": min(conversation_lengths),
"max_turns": max(conversation_lengths),
"median_turns": np.median(conversation_lengths),
"role_distribution": split_role_counts
}
# Missing values
missing_count = sum(1 for item in split_data if not item.get(conversation_field))
split_analysis["missing_values"][conversation_field] = missing_count
analysis["splits"][split_name] = split_analysis
analysis["overall"]["total_samples"] += len(split_data)
analysis["overall"]["split_sizes"][split_name] = len(split_data)
# Accumulate overall stats
total_turns += split_turns
total_conversations += split_conversations
for role, count in split_role_counts.items():
role_counts[role] += count
# Calculate overall conversation stats
if total_conversations > 0:
analysis["overall"]["conversation_stats"]["total_turns"] = total_turns
analysis["overall"]["conversation_stats"]["avg_turns_per_conversation"] = total_turns / total_conversations
analysis["overall"]["conversation_stats"]["role_distribution"] = role_counts
return analysis
class BaseInstructDataLoader(ABC):
"""Abstract base class for instruction data loaders"""
@abstractmethod
def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
"""Load data and return dictionary with train/val/test splits"""
pass
@abstractmethod
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Apply preprocessing steps to all splits"""
pass
class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
"""Load conversation datasets from Hugging Face Hub"""
def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
"""Load dataset from Hugging Face Hub with flexible split handling"""
if not config.dataset_name:
raise ValueError("Dataset name is required for Hugging Face datasets")
print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
try:
dataset = load_dataset(
config.dataset_name,
cache_dir=config.hf_cache_dir
)
available_splits = list(dataset.keys())
print(f"Available splits in dataset: {available_splits}")
splits_data = {
"train": [],
"validation": [],
"test": []
}
# Handle train split
if "train" in available_splits:
train_dataset = dataset["train"]
print(f"Using 'train' split with {len(train_dataset)} samples")
splits_data["train"] = list(train_dataset)
else:
print("No 'train' split found in dataset!")
raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
# Handle validation and test splits (similar logic to styling pipeline)
# ... [validation and test split handling logic similar to styling pipeline]
# Apply max_samples limit if specified
if config.max_samples:
for split_name in splits_data:
if splits_data[split_name]:
original_size = len(splits_data[split_name])
splits_data[split_name] = splits_data[split_name][:config.max_samples]
print(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
print(f"Successfully loaded dataset {config.dataset_name}")
return splits_data
except Exception as e:
print(f"Error loading dataset {config.dataset_name}: {e}")
raise
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
print(f"=== PREPROCESSING CONVERSATION DATA ===")
for split_name, split_data in data.items():
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
skipped_count = 0
for i, item in enumerate(split_data):
processed_item = self._preprocess_item(item, config)
if processed_item is not None:
processed_data.append(processed_item)
processed_count += 1
else:
skipped_count += 1
processed_splits[split_name] = processed_data
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
"""Preprocess a single conversation item"""
conversation = item.get(config.conversation_field, [])
if not isinstance(conversation, list) or not conversation:
return None
# Validate conversation structure
valid_conversation = []
for turn in conversation:
if not isinstance(turn, dict):
continue
if "role" not in turn or "content" not in turn:
continue
if turn["role"] not in ["user", "assistant", "system"]:
continue
content = str(turn["content"]).strip()
if len(content) < config.min_length or len(content) > config.max_length:
continue
if config.clean_text:
content = self._clean_text(content)
valid_conversation.append({
"role": turn["role"],
"content": content
})
if len(valid_conversation) < 2: # Need at least 2 turns for a conversation
return None
return {"conversation": valid_conversation}
def _clean_text(self, text: str) -> str:
"""Clean and normalize text"""
if not isinstance(text, str):
return ""
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
class CustomInstructDataLoader(BaseInstructDataLoader):
"""Load custom conversation datasets from local files"""
def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
"""Load custom conversation dataset from local file and create splits"""
if not config.data_path:
raise ValueError("Data path is required for custom datasets")
file_path = Path(config.data_path)
if not file_path.exists():
raise FileNotFoundError(f"Data file not found: {file_path}")
print(f"Loading custom conversation dataset: {file_path}")
if config.data_format == "jsonl":
raw_data = self._load_jsonl(file_path, config)
elif config.data_format == "json":
raw_data = self._load_json(file_path, config)
else:
raise ValueError(f"Unsupported format: {config.data_format}")
if config.max_samples:
raw_data = raw_data[:config.max_samples]
print(f"Loaded {len(raw_data)} conversation samples from {file_path}")
# Create splits from the raw data
splits_data = self._create_splits(raw_data, config)
return splits_data
def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Create train/validation/test splits from raw data"""
print(f"Creating splits from {len(data)} conversation samples...")
# Handle very small datasets
if len(data) < 3:
print(f"Dataset has only {len(data)} samples. Using all data for training.")
return {
"train": data,
"validation": [],
"test": []
}
# Calculate split sizes
total_samples = len(data)
# Adjust split ratios if dataset is too small
if total_samples < 10:
config.train_split = 0.6
config.validation_split = 0.2
config.test_split = 0.2
print(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
val_size = max(1, int(total_samples * config.validation_split))
test_size = max(1, int(total_samples * config.test_split))
train_size = total_samples - val_size - test_size
# Ensure train split has at least 1 sample
if train_size < 1:
if val_size > 1:
val_size -= 1
train_size += 1
elif test_size > 1:
test_size -= 1
train_size += 1
print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
# Create splits
if val_size == 0 and test_size == 0:
splits_data = {
"train": data,
"validation": [],
"test": []
}
elif val_size == 0:
train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)
splits_data = {
"train": train_data,
"validation": [],
"test": test_data
}
elif test_size == 0:
train_data, val_data = train_test_split(data, test_size=val_size, random_state=42)
splits_data = {
"train": train_data,
"validation": val_data,
"test": []
}
else:
# Full three-way split
train_data, temp_data = train_test_split(
data,
test_size=val_size + test_size,
random_state=42
)
val_data, test_data = train_test_split(
temp_data,
test_size=test_size,
random_state=42
)
splits_data = {
"train": train_data,
"validation": val_data,
"test": test_data
}
print(f"Created conversation splits:")
print(f" Train: {len(splits_data['train'])} samples")
print(f" Validation: {len(splits_data['validation'])} samples")
print(f" Test: {len(splits_data['test'])} samples")
return splits_data
def _load_jsonl(self, file_path: Path, config: InstructConfig) -> List[Dict]:
"""Load JSONL file"""
data = []
with open(file_path, 'r', encoding=config.encoding) as f:
for line_num, line in enumerate(f, 1):
if line.strip():
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
print(f"Invalid JSON at line {line_num}: {e}")
return data
def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
"""Load JSON file"""
with open(file_path, 'r', encoding=config.encoding) as f:
data = json.load(f)
if isinstance(data, list):
return data
elif isinstance(data, dict) and "data" in data:
return data["data"]
else:
return [data]
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
print(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
for split_name, split_data in data.items():
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
skipped_count = 0
for i, item in enumerate(split_data):
processed_item = self._preprocess_item(item, config)
if processed_item is not None:
processed_data.append(processed_item)
processed_count += 1
else:
skipped_count += 1
if skipped_count <= 3: # Log first few skipped items
print(f"Skipped item {i} from {split_name}: {item}")
processed_splits[split_name] = processed_data
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
"""Preprocess a single conversation item"""
conversation = item.get(config.conversation_field, [])
if not isinstance(conversation, list) or not conversation:
return None
# Validate conversation structure
valid_conversation = []
for turn in conversation:
if not isinstance(turn, dict):
continue
if "role" not in turn or "content" not in turn:
continue
if turn["role"] not in ["user", "assistant", "system"]:
continue
content = str(turn["content"]).strip()
if len(content) < config.min_length or len(content) > config.max_length:
continue
if config.clean_text:
content = self._clean_text(content)
valid_conversation.append({
"role": turn["role"],
"content": content
})
if len(valid_conversation) < 2: # Need at least 2 turns for a conversation
return None
return {"conversation": valid_conversation}
def _clean_text(self, text: str) -> str:
"""Clean and normalize text"""
if not isinstance(text, str):
return ""
# Remove extra whitespace
text = re.sub(r'\s+', ' ', text).strip()
return text
class InstructDataPipeline:
"""Main instruction fine-tuning data pipeline"""
def __init__(self):
self.validator = ConversationValidator()
self.hf_loader = HuggingFaceInstructDataLoader()
self.custom_loader = CustomInstructDataLoader()
def create_config(
self,
data_source: str,
dataset_name: Optional[str] = None,
data_path: Optional[str] = None,
conversation_field: str = "conversation",
**kwargs
) -> InstructConfig:
"""Create instruction configuration"""
return InstructConfig(
data_source=data_source,
dataset_name=dataset_name,
data_path=data_path,
conversation_field=conversation_field,
**kwargs
)
def load_config_from_yaml(self, yaml_path: str) -> InstructConfig:
"""Load configuration from YAML file"""
try:
config_dict = load_yaml_config(yaml_path)
# Create configuration object from YAML data
config = InstructConfig(
data_source=config_dict.get('data_source', 'custom'),
dataset_name=config_dict.get('dataset_name'),
data_path=config_dict.get('data_path'),
data_format=config_dict.get('data_format', 'jsonl'),
conversation_field=config_dict.get('conversation_field', 'conversation'),
max_samples=config_dict.get('max_samples'),
train_split=config_dict.get('train_split', 0.8),
validation_split=config_dict.get('validation_split', 0.1),
test_split=config_dict.get('test_split', 0.1),
clean_text=config_dict.get('clean_text', True),
min_length=config_dict.get('min_length', 10),
max_length=config_dict.get('max_length', 2048),
output_format=config_dict.get('output_format', 'conversation'),
output_dir=config_dict.get('output_dir', './data/processed/instruct'),
hf_split=config_dict.get('hf_split', 'train'),
hf_cache_dir=config_dict.get('hf_cache_dir'),
test_split_from=config_dict.get('test_split_from', 'train'),
val_split_from=config_dict.get('val_split_from', 'train'),
encoding=config_dict.get('encoding', 'utf-8')
)
print(f"Configuration loaded from YAML: {yaml_path}")
print(f"Output directory: {config.output_dir}")
print(f"Conversation field: {config.conversation_field}")
return config
except Exception as e:
print(f"Error loading configuration from YAML {yaml_path}: {e}")
raise
def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
"""Load and preprocess conversation data"""
print(f"Starting conversation data loading and preprocessing...")
print(f"Data source: {config.data_source}")
try:
# Load data
if config.data_source == "huggingface":
print("Loading HuggingFace conversation dataset...")
raw_splits = self.hf_loader.load(config)
print("Preprocessing HuggingFace conversation dataset...")
processed_splits = self.hf_loader.preprocess(raw_splits, config)
elif config.data_source == "custom":
print("Loading custom conversation dataset...")
raw_splits = self.custom_loader.load(config)
print("Preprocessing custom conversation dataset...")
processed_splits = self.custom_loader.preprocess(raw_splits, config)
else:
raise ValueError(f"Unsupported data source: {config.data_source}")
print(f"Conversation data loading and preprocessing completed successfully")
# Validate processed data
print("Validating processed conversation data...")
is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
if not is_valid:
print("Conversation data validation failed:")
for error in errors:
print(f" - {error}")
raise ValueError("Conversation data validation failed")
print("Conversation data validation passed")
# Analyze dataset
print("Analyzing conversation dataset...")
analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
print("Conversation dataset analysis completed")
return processed_splits, analysis
except Exception as e:
print(f"Error in load_and_preprocess: {e}")
raise
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
"""Save processed conversation data splits to files"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
for split_name, split_data in data.items():
if format == "jsonl":
output_file = output_path / f"{split_name}.jsonl"
with open(output_file, 'w', encoding='utf-8') as f:
for item in split_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
elif format == "json":
output_file = output_path / f"{split_name}.json"
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(split_data, f, ensure_ascii=False, indent=2)
print(f"Saved {len(split_data)} conversation samples to {output_file}")
def run_pipeline(
self,
config: InstructConfig,
save_splits: bool = True
) -> Dict[str, Any]:
"""Run complete instruction data pipeline"""
print("Starting instruction data pipeline...")
# Load and preprocess data
processed_splits, analysis = self.load_and_preprocess(config)
# Save data if requested
if save_splits:
output_dir = Path(config.output_dir)
self.save_data(processed_splits, str(output_dir))
# Create result summary
result = {
"config": config,
"analysis": analysis,
"splits": {
split_name: len(split_data) for split_name, split_data in processed_splits.items()
},
"output_format": config.output_format,
"output_dir": config.output_dir,
"data": processed_splits, # Include the actual processed data
}
print("Instruction data pipeline completed successfully!")
return result
def load_yaml_config(config_path: str) -> Dict[str, Any]:
"""Load and parse YAML configuration file with proper structure handling"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
# Extract configuration from YAML structure
config_dict = {}
# Handle task section
if 'task' in yaml_data:
task_data = yaml_data['task']
config_dict.update({
'task_name': task_data.get('name'),
'task_type': task_data.get('type')
})
# Handle data section
if 'data' in yaml_data:
data_config = yaml_data['data']
config_dict.update({
'data_source': data_config.get('source'),
'dataset_name': data_config.get('dataset_name'),
'data_path': data_config.get('data_path'),
'data_format': data_config.get('data_format'),
'conversation_field': data_config.get('conversation_field'),
'max_samples': data_config.get('max_samples'),
'train_split': data_config.get('train_split'),
'validation_split': data_config.get('validation_split'),
'test_split': data_config.get('test_split'),
'clean_text': data_config.get('clean_text'),
'min_length': data_config.get('min_length'),
'max_length': data_config.get('max_length'),
'output_format': data_config.get('output_format'),
'output_dir': data_config.get('output_dir'),
'encoding': data_config.get('encoding')
})
print(f"Successfully parsed YAML configuration from: {config_path}")
print(f"Extracted {len(config_dict)} configuration parameters")
return config_dict
except Exception as e:
print(f"Error loading YAML config from {config_path}: {e}")
raise
def main():
"""Main function with YAML configuration support"""
parser = argparse.ArgumentParser(description="Instruction Data Processing Pipeline")
# YAML configuration
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
# Data source arguments
parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
parser.add_argument("--data-path", type=str, help="Path to custom data file")
parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")
# Field mapping
parser.add_argument("--conversation-field", type=str, help="Conversation field name")
# Data processing
parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
parser.add_argument("--train-split", type=float, help="Training split ratio")
parser.add_argument("--validation-split", type=float, help="Validation split ratio")
parser.add_argument("--test-split", type=float, help="Test split ratio")
# Output configuration
parser.add_argument("--output-dir", type=str, help="Output directory")
# Logging
parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
args = parser.parse_args()
# Set up logging
# logging.basicConfig(
# level=getattr(logging, args.log_level),
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# )
# Load configuration
config_dict = {}
# Load YAML config if provided
if args.config:
try:
config_dict = load_yaml_config(args.config)
except Exception as e:
print(f"Error loading YAML config: {e}")
sys.exit(1)
# Override YAML config with CLI arguments (similar to styling pipeline)
cli_overrides = {}
if args.data_source:
cli_overrides['data_source'] = args.data_source
if args.dataset_name:
cli_overrides['dataset_name'] = args.dataset_name
if args.data_path:
cli_overrides['data_path'] = args.data_path
if args.data_format:
cli_overrides['data_format'] = args.data_format
if args.conversation_field:
cli_overrides['conversation_field'] = args.conversation_field
if args.max_samples:
cli_overrides['max_samples'] = args.max_samples
if args.train_split:
cli_overrides['train_split'] = args.train_split
if args.validation_split:
cli_overrides['validation_split'] = args.validation_split
if args.test_split:
cli_overrides['test_split'] = args.test_split
if args.output_dir:
cli_overrides['output_dir'] = args.output_dir
# Merge configurations
for key, value in cli_overrides.items():
if key in config_dict:
print(f"Overriding YAML config '{key}' with CLI value: {value}")
config_dict[key] = value
# Validate required arguments
if not config_dict.get('data_source'):
parser.error("--data-source is required (either in YAML config or CLI)")
if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
parser.error("--dataset-name is required for HuggingFace datasets")
if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
parser.error("--data-path is required for custom datasets")
# Create configuration object
config = InstructConfig(
data_source=config_dict.get('data_source', 'custom'),
dataset_name=config_dict.get('dataset_name'),
data_path=config_dict.get('data_path'),
data_format=config_dict.get('data_format', 'jsonl'),
conversation_field=config_dict.get('conversation_field', 'conversation'),
max_samples=config_dict.get('max_samples'),
train_split=config_dict.get('train_split', 0.8),
validation_split=config_dict.get('validation_split', 0.1),
test_split=config_dict.get('test_split', 0.1),
clean_text=config_dict.get('clean_text', True),
min_length=config_dict.get('min_length', 10),
max_length=config_dict.get('max_length', 2048),
output_format=config_dict.get('output_format', 'conversation'),
output_dir=config_dict.get('output_dir', './data/processed/instruct'),
hf_split=config_dict.get('hf_split', 'train'),
hf_cache_dir=config_dict.get('hf_cache_dir'),
test_split_from=config_dict.get('test_split_from', 'train'),
val_split_from=config_dict.get('val_split_from', 'train'),
encoding=config_dict.get('encoding', 'utf-8')
)
# Initialize pipeline
pipeline = InstructDataPipeline()
try:
print(f"Starting instruction data pipeline with {config.data_source} data source...")
if args.config:
print(f"Using YAML configuration: {args.config}")
print(f"Conversation field: {config.conversation_field}")
print()
result = pipeline.run_pipeline(config, save_splits=True)
print(f"✅ Pipeline completed successfully!")
print(f" Data source: {config.data_source}")
if config.data_source == "huggingface":
print(f" Dataset: {config.dataset_name}")
else:
print(f" Data file: {config.data_path}")
print(f" Total samples: {result['analysis']['overall']['total_samples']}")
print(f" Split sizes: {result['analysis']['overall']['split_sizes']}")
print(f" Output directory: {config.output_dir}")
print(f" Conversation stats: {result['analysis']['overall']['conversation_stats']}")
except Exception as e:
print(f"❌ Error running pipeline: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
@@ -0,0 +1,393 @@
#!/usr/bin/env python3
"""
Instruct Inference Pipeline using Trained Models
Supports conversational inference with streaming and batch processing
"""
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer
class InstructInference:
"""Instruction fine-tuning inference using trained models"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.model = None
self.tokenizer = None
# Set device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {self.device}")
# Model parameters
self.model_output_dir = config.get('model_output_dir', './models/instruct')
self.base_model_name = config.get('base_model_name', 'unsloth/Qwen2.5-72B-Instruct')
self.max_seq_length = config.get('max_seq_length', 2048)
self.dtype = config.get('dtype', None)
self.load_in_4bit = config.get('load_in_4bit', True)
self.hf_token = config.get('hf_token', None)
# Inference parameters
self.batch_size = config.get('batch_size', 1)
self.max_new_tokens = config.get('max_new_tokens', 128)
self.temperature = config.get('temperature', 1.5)
self.min_p = config.get('min_p', 0.1)
self.use_cache = config.get('use_cache', True)
# Chat template
self.chat_template = config.get('chat_template', 'qwen-2.5')
def load_model_and_tokenizer(self):
"""Load the trained model and tokenizer"""
print("Loading trained instruction model and tokenizer...")
try:
# Load the saved LoRA model
model_path = self.model_output_dir
print(f"Loading model from: {model_path}")
# Check if the model directory exists
if not Path(model_path).exists():
raise FileNotFoundError(f"Model directory not found: {model_path}")
# Load the model directly from the saved path
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=self.max_seq_length,
dtype=self.dtype,
load_in_4bit=self.load_in_4bit,
)
# Enable native 2x faster inference
FastLanguageModel.for_inference(self.model)
print(f"✅ Model loaded from: {model_path}")
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
except Exception as e:
print(f"❌ Error loading model: {e}")
raise
def setup_chat_template(self):
"""Setup chat template for conversation formatting"""
print("Setting up chat template...")
try:
self.tokenizer = get_chat_template(
self.tokenizer,
chat_template=self.chat_template,
)
print(f"✅ Chat template configured: {self.chat_template}")
except Exception as e:
print(f"❌ Error setting up chat template: {e}")
raise
def format_messages(self, messages: List[Dict[str, str]]) -> str:
"""Format messages using chat template"""
try:
# Apply chat template to format the conversation
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True, # Add generation prompt for inference
)
return formatted_prompt
except Exception as e:
print(f"❌ Error formatting messages: {e}")
raise
def generate_response(
self,
messages: List[Dict[str, str]],
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stream: bool = False
) -> str:
"""Generate response using the trained instruction model"""
try:
# Use default values if not provided
max_tokens = max_new_tokens or self.max_new_tokens
temp = temperature or self.temperature
# Format the messages
formatted_prompt = self.format_messages(messages)
print(f"Formatted prompt: {formatted_prompt[:200]}...")
# Tokenize the input
inputs = self.tokenizer(
[formatted_prompt],
return_tensors="pt"
).to(self.device)
if stream:
# Streaming generation
text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
print("Generating with streaming...")
_ = self.model.generate(
input_ids=inputs.input_ids,
streamer=text_streamer,
max_new_tokens=max_tokens,
use_cache=self.use_cache,
temperature=temp,
min_p=self.min_p
)
return "" # Streaming output is handled by streamer
else:
# Non-streaming generation
print("Generating response...")
outputs = self.model.generate(
input_ids=inputs.input_ids,
max_new_tokens=max_tokens,
use_cache=self.use_cache,
temperature=temp,
min_p=self.min_p,
pad_token_id=self.tokenizer.eos_token_id
)
# Decode the generated text
full_response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Extract only the new generated response (remove the input prompt)
prompt_length = len(formatted_prompt)
response = full_response[prompt_length:].strip()
return response
except Exception as e:
print(f"❌ Error generating response: {e}")
raise
def chat(self, user_input: str, conversation_history: Optional[List[Dict[str, str]]] = None, stream: bool = False) -> str:
"""Have a chat conversation with the model"""
try:
# Initialize conversation history if not provided
if conversation_history is None:
conversation_history = []
# Add user input to conversation
messages = conversation_history + [{"role": "user", "content": user_input}]
print(f"User: {user_input}")
if stream:
print("Assistant: ", end="", flush=True)
self.generate_response(messages, stream=True)
return ""
else:
# Generate response
response = self.generate_response(messages, stream=False)
print(f"Assistant: {response}")
return response
except Exception as e:
print(f"❌ Error in chat: {e}")
raise
def batch_inference(
self,
conversations: List[List[Dict[str, str]]],
max_new_tokens: Optional[int] = None
) -> List[str]:
"""Perform batch inference on multiple conversations"""
responses = []
for i, messages in enumerate(conversations):
print(f"Processing conversation {i+1}/{len(conversations)}")
response = self.generate_response(messages, max_new_tokens)
responses.append(response)
return responses
def interactive_chat(self):
"""Start an interactive chat session"""
print("🤖 Starting interactive chat session...")
print("Type 'quit', 'exit', or 'bye' to end the conversation.")
print("Type 'clear' to clear conversation history.")
print("Type 'stream on' or 'stream off' to toggle streaming.")
print("-" * 50)
conversation_history = []
streaming = False
while True:
try:
user_input = input("\n👤 You: ").strip()
if user_input.lower() in ['quit', 'exit', 'bye']:
print("👋 Goodbye!")
break
elif user_input.lower() == 'clear':
conversation_history = []
print("🗑️ Conversation history cleared.")
continue
elif user_input.lower() == 'stream on':
streaming = True
print("🔄 Streaming enabled.")
continue
elif user_input.lower() == 'stream off':
streaming = False
print("⏸️ Streaming disabled.")
continue
elif not user_input:
continue
# Generate response
if streaming:
print("🤖 Assistant: ", end="", flush=True)
self.chat(user_input, conversation_history, stream=True)
# Add to history (we don't have the actual response text for streaming)
conversation_history.extend([
{"role": "user", "content": user_input},
{"role": "assistant", "content": "[Streamed response]"}
])
else:
response = self.chat(user_input, conversation_history, stream=False)
# Add to history
conversation_history.extend([
{"role": "user", "content": user_input},
{"role": "assistant", "content": response}
])
except KeyboardInterrupt:
print("\n👋 Goodbye!")
break
except Exception as e:
print(f"❌ Error: {e}")
continue
def load_inference_config(config_path: str) -> Dict[str, Any]:
"""Load inference configuration from YAML file"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# Extract inference configuration
inference_config = {}
# Model configuration
if 'model' in config:
model_data = config['model']
inference_config.update({
'base_model_name': model_data.get('training_model', 'unsloth/Qwen2.5-72B-Instruct'),
'max_seq_length': model_data.get('training_max_seq_length', 2048),
'dtype': model_data.get('training_dtype'),
'load_in_4bit': model_data.get('training_load_in_4bit', True),
'hf_token': model_data.get('training_token')
})
# Training configuration - to get model_output_dir
if 'training' in config:
training_data = config['training']
inference_config.update({
'model_output_dir': training_data.get('model_output_dir', './models/instruct')
})
# Inference configuration
if 'inference' in config:
inference_data = config['inference']
inference_config.update({
'batch_size': inference_data.get('batch_size', 1),
'max_new_tokens': inference_data.get('max_new_tokens', 128),
'temperature': inference_data.get('temperature', 1.5),
'min_p': inference_data.get('min_p', 0.1),
'use_cache': inference_data.get('use_cache', True)
})
# Chat template
inference_config.update({
'chat_template': 'qwen-2.5' # Use Qwen chat template by default
})
return inference_config
except Exception as e:
print(f"Error loading inference config: {e}")
raise
def main():
"""Main inference function"""
parser = argparse.ArgumentParser(description="Instruction Inference Pipeline")
# Configuration
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
parser.add_argument("--interactive", action="store_true", help="Start interactive chat session")
parser.add_argument("--message", type=str, help="Single message to send to the model")
parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
parser.add_argument("--temperature", type=float, help="Sampling temperature")
args = parser.parse_args()
try:
# Load configuration
print(f"Loading configuration from: {args.config}")
inference_config = load_inference_config(args.config)
# Override with CLI arguments
if args.max_tokens:
inference_config['max_new_tokens'] = args.max_tokens
if args.temperature:
inference_config['temperature'] = args.temperature
print("Inference configuration:")
for key, value in inference_config.items():
print(f" {key}: {value}")
# Initialize inference
inference = InstructInference(inference_config)
# Load model and tokenizer
inference.load_model_and_tokenizer()
# Setup chat template
inference.setup_chat_template()
# Run inference based on mode
if args.interactive:
# Interactive chat mode
inference.interactive_chat()
elif args.message:
# Single message mode
print("Running single message inference...")
messages = [{"role": "user", "content": args.message}]
if args.stream:
print("User:", args.message)
print("Assistant: ", end="", flush=True)
inference.generate_response(messages, stream=True)
else:
response = inference.generate_response(messages, stream=False)
print(f"User: {args.message}")
print(f"Assistant: {response}")
else:
# Default to interactive mode if no specific mode is chosen
print("No specific mode chosen. Starting interactive chat...")
inference.interactive_chat()
except Exception as e:
print(f"Inference failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()
@@ -0,0 +1,518 @@
#!/usr/bin/env python3
"""
Instruct Training Pipeline using Unsloth and SFTTrainer
Supports instruction fine-tuning with conversational data and LoRA fine-tuning
"""
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
from utils.config.config_manager import ConfigManager
# Training imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel #is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForSeq2Seq
class InstructTrainer:
"""Instruction fine-tuning trainer using Unsloth and SFTTrainer"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.model = None
self.tokenizer = None
self.trainer = None
# Set device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {self.device}")
# Model parameters
self.model_name = config.get('model_name', 'unsloth/Qwen2.5-72B-Instruct')
self.max_seq_length = config.get('max_seq_length', 2048)
self.dtype = config.get('dtype', None)
self.load_in_4bit = config.get('load_in_4bit', True)
self.hf_token = config.get('hf_token', None)
# LoRA parameters
self.lora_r = config.get('lora_r', 32)
self.lora_alpha = config.get('lora_alpha', 16)
self.lora_dropout = config.get('lora_dropout', 0)
self.target_modules = config.get('target_modules', [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
])
# Training arguments
self.batch_size = config.get('batch_size', 1)
self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
self.learning_rate = config.get('learning_rate', 2e-4)
self.num_epochs = config.get('num_epochs', 1)
self.max_steps = config.get('max_steps', 30)
self.warmup_steps = config.get('warmup_steps', 5)
self.weight_decay = config.get('weight_decay', 0.01)
self.seed = config.get('seed', 3407)
# Output paths
self.output_dir = config.get('output_dir', './outputs')
self.model_output_dir = config.get('model_output_dir', './models/instruct')
# Chat template
self.chat_template = config.get('chat_template', 'qwen-2.5')
def load_model_and_tokenizer(self):
"""Load the pre-trained model and tokenizer"""
print("Loading model and tokenizer...")
try:
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=self.model_name,
max_seq_length=self.max_seq_length,
dtype=self.dtype,
load_in_4bit=self.load_in_4bit,
token=self.hf_token
)
print(f"✅ Model loaded: {self.model_name}")
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
except Exception as e:
print(f"❌ Error loading model: {e}")
raise
def setup_lora(self):
"""Setup LoRA for efficient fine-tuning"""
print("Setting up LoRA configuration...")
try:
self.model = FastLanguageModel.get_peft_model(
self.model,
r=self.lora_r,
target_modules=self.target_modules,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=self.seed,
use_rslora=False,
loftq_config=None
)
print(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
except Exception as e:
print(f"❌ Error setting up LoRA: {e}")
raise
def setup_chat_template(self):
"""Setup chat template for conversation formatting"""
print("Setting up chat template...")
try:
self.tokenizer = get_chat_template(
self.tokenizer,
chat_template=self.chat_template,
)
print(f"✅ Chat template configured: {self.chat_template}")
except Exception as e:
print(f"❌ Error setting up chat template: {e}")
raise
def load_dataset(self, dataset_path: str) -> Dataset:
"""Load the conversation training dataset"""
print(f"Loading conversation dataset from: {dataset_path}")
try:
if Path(dataset_path).exists():
# Check if it's a HuggingFace dataset directory
if (Path(dataset_path) / "dataset_info.json").exists():
# Load from HuggingFace dataset directory
dataset = load_from_disk(dataset_path)
print(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
else:
# Load from processed conversation data files (JSONL format)
print("Loading from processed conversation data files...")
from datasets import Dataset
import json
all_data = []
data_dir = Path(dataset_path)
# Look for train.jsonl, validation.jsonl, test.jsonl
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
file_path = data_dir / split_file
if file_path.exists():
print(f"Loading {split_file}...")
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
data = json.loads(line)
all_data.append(data)
if not all_data:
raise ValueError(f"No conversation data found in {dataset_path}")
# Create HuggingFace dataset
dataset = Dataset.from_list(all_data)
print(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
else:
# Try loading from HuggingFace Hub
print(f"Attempting to load from HuggingFace Hub: {dataset_path}")
dataset = Dataset.load_dataset(dataset_path, split="train")
print(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
print(f"Dataset loaded: {len(dataset)} samples")
print(f"Dataset features: {dataset.features}")
# Verify required fields exist for conversation data
required_fields = ["conversation"]
missing_fields = [field for field in required_fields if field not in dataset.features]
if missing_fields:
raise ValueError(f"Missing required fields in conversation dataset: {missing_fields}")
return dataset
except Exception as e:
print(f"Error loading conversation dataset: {e}")
raise
def format_dataset_for_training(self, dataset: Dataset) -> Dataset:
"""Format conversation dataset for training using standardize_sharegpt and apply_chat_template"""
print("Formatting conversation dataset for training...")
try:
# Standardize the ShareGPT format
print("Standardizing ShareGPT format...")
dataset = standardize_sharegpt(dataset)
# Define the formatting function for chat templates
def formatting_prompts_func(examples):
convos = examples["conversation"]
texts = [
self.tokenizer.apply_chat_template(
convo,
tokenize=False,
add_generation_prompt=False
) for convo in convos
]
return {"text": texts}
# Apply the formatting function
print("Applying chat template formatting...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print(f"✅ Dataset formatted for training with {len(dataset)} samples")
print(f"Sample formatted text: {dataset[0]['text'][:200]}...")
return dataset
except Exception as e:
print(f"❌ Error formatting dataset: {e}")
raise
def setup_trainer(self, train_dataset: Dataset):
"""Setup the SFTTrainer for instruction fine-tuning"""
print("Setting up SFTTrainer for instruction fine-tuning...")
try:
# SFT Configuration
sft_config = SFTConfig(
per_device_train_batch_size=self.batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
warmup_steps=self.warmup_steps,
max_steps=self.max_steps,
learning_rate=self.learning_rate,
logging_steps=1,
optim="paged_adamw_8bit",
weight_decay=self.weight_decay,
lr_scheduler_type="linear",
seed=self.seed,
output_dir=self.output_dir,
report_to="none", # Disable wandb for now
)
print("SFT Configuration:")
print(f" batch_size: {self.batch_size}")
print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
print(f" warmup_steps: {self.warmup_steps}")
print(f" max_steps: {self.max_steps}")
print(f" learning_rate: {self.learning_rate}")
# Create SFTTrainer
self.trainer = SFTTrainer(
model=self.model,
tokenizer=self.tokenizer,
train_dataset=train_dataset,
dataset_text_field="text",
max_seq_length=self.max_seq_length,
data_collator=DataCollatorForSeq2Seq(tokenizer=self.tokenizer),
packing=False, # Disable packing for conversation data
args=sft_config,
)
print("✅ SFTTrainer configured successfully")
except Exception as e:
print(f"❌ Error setting up trainer: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
raise
def setup_response_only_training(self):
"""Setup training to only learn from assistant responses"""
print("Setting up response-only training...")
try:
# Configure trainer to only train on responses
self.trainer = train_on_responses_only(
self.trainer,
instruction_part="<|im_start|>user\n",
response_part="<|im_start|>assistant\n",
)
print("✅ Response-only training configured")
except Exception as e:
print(f"❌ Error setting up response-only training: {e}")
raise
def train(self, dataset_path: str):
"""Run the instruction fine-tuning process"""
print("🚀 Starting instruction fine-tuning process...")
try:
# Load model and tokenizer
print("Step 1: Loading model and tokenizer...")
self.load_model_and_tokenizer()
# Setup LoRA
print("Step 2: Setting up LoRA...")
self.setup_lora()
# Setup chat template
print("Step 3: Setting up chat template...")
self.setup_chat_template()
# Load dataset
print(f"Step 4: Loading conversation dataset from: {dataset_path}")
train_dataset = self.load_dataset(dataset_path)
# Format dataset for training
print("Step 5: Formatting dataset for training...")
formatted_dataset = self.format_dataset_for_training(train_dataset)
# Setup trainer
print("Step 6: Setting up trainer...")
self.setup_trainer(formatted_dataset)
# Setup response-only training (optional but recommended for chat models)
print("Step 7: Setting up response-only training...")
self.setup_response_only_training()
# Start training
print("Step 8: Starting training...")
trainer_stats = self.trainer.train()
print("✅ Instruction fine-tuning completed successfully!")
print(f"Training stats: {trainer_stats}")
# Save the model
self.save_model()
return trainer_stats
except Exception as e:
print(f"❌ Instruction fine-tuning failed: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
raise
def save_model(self):
"""Save the trained instruction model"""
print("Saving trained instruction model...")
try:
# Create output directory
Path(self.model_output_dir).mkdir(parents=True, exist_ok=True)
# Save model and tokenizer
self.model.save_pretrained(self.model_output_dir)
self.tokenizer.save_pretrained(self.model_output_dir)
# Save training config
config_path = Path(self.model_output_dir) / "training_config.json"
with open(config_path, 'w') as f:
json.dump(self.config, f, indent=2)
print(f"✅ Instruction model saved to: {self.model_output_dir}")
print(f"✅ You can now use this model for inference")
except Exception as e:
print(f"❌ Error saving model: {e}")
raise
def prepare_for_inference(self):
"""Prepare model for inference"""
print("Preparing model for inference...")
try:
FastLanguageModel.for_inference(self.model)
print("✅ Model prepared for inference")
except Exception as e:
print(f"❌ Error preparing for inference: {e}")
raise
def load_training_config(yaml_path: str) -> Dict[str, Any]:
"""Load training configuration from YAML file"""
try:
with open(yaml_path, 'r') as f:
config = yaml.safe_load(f)
training_config = {}
# Model configuration - extract from model section
if 'model' in config:
model_config = config['model']
training_config.update({
'model_name': model_config.get('name', 'unsloth/Qwen2.5-72B-Instruct'),
'max_seq_length': int(model_config.get('max_seq_length', 2048)),
'dtype': model_config.get('dtype', None),
'load_in_4bit': model_config.get('load_in_4bit', True),
'hf_token': model_config.get('token', None)
})
# Training configuration - extract from training section
if 'training' in config:
training_data = config['training']
print("Training data from YAML:")
print(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
print(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
print(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
print(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
print(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
print(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
print(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
print(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
print(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
training_config.update({
'num_epochs': int(training_data.get('num_epochs', 1)),
'batch_size': int(training_data.get('batch_size', 1)),
'learning_rate': float(training_data.get('learning_rate', 2e-4)),
'weight_decay': float(training_data.get('weight_decay', 0.01)),
'warmup_steps': int(training_data.get('warmup_steps', 5)),
'max_steps': int(training_data.get('max_steps', 30)),
'gradient_accumulation_steps': int(training_data.get('gradient_accumulation_steps', 4)),
'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'),
'seed': int(training_data.get('seed', 3407)),
'model_output_dir': training_data.get('model_output_dir', './models/instruct'),
# LoRA configuration
'lora_r': int(training_data.get('lora_r', 32)),
'lora_alpha': int(training_data.get('lora_alpha', 16)),
'lora_dropout': float(training_data.get('lora_dropout', 0)),
'target_modules': training_data.get('target_modules', [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
])
})
# Data configuration - use output_dir from data section
if 'data' in config:
data_config = config['data']
output_dir = data_config.get('output_dir', './data/processed/instruct')
training_config.update({
'data_output_dir': output_dir,
'dataset_path': output_dir, # Default dataset path is the output_dir
})
# Output configuration
training_config.update({
'output_dir': './outputs',
'chat_template': 'qwen-2.5' # Use Qwen chat template by default
})
print("Final training_config:")
for key, value in training_config.items():
print(f" {key}: {value} (type: {type(value)})")
return training_config
except Exception as e:
print(f"Error loading training config: {e}")
raise
def main():
"""Main training function"""
parser = argparse.ArgumentParser(description="Instruction Fine-tuning Training Pipeline")
# Configuration
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
parser.add_argument("--dataset", type=str, help="Path to training dataset (conversation data path)")
parser.add_argument("--output-dir", type=str, help="Output directory for model")
parser.add_argument("--epochs", type=int, help="Number of training epochs")
parser.add_argument("--batch-size", type=int, help="Training batch size")
parser.add_argument("--learning-rate", type=float, help="Learning rate")
parser.add_argument("--max-steps", type=int, help="Maximum training steps")
args = parser.parse_args()
# Setup logging replaced with print statements
try:
# Load configuration
print(f"Loading configuration from: {args.config}")
training_config = load_training_config(args.config)
# Override with CLI arguments
if args.output_dir:
training_config['model_output_dir'] = args.output_dir
if args.epochs:
training_config['num_epochs'] = int(args.epochs)
if args.batch_size:
training_config['batch_size'] = int(args.batch_size)
if args.learning_rate:
training_config['learning_rate'] = float(args.learning_rate)
if args.max_steps:
training_config['max_steps'] = int(args.max_steps)
# Determine dataset path: CLI argument takes precedence, then YAML config
dataset_path = args.dataset or training_config.get('dataset_path')
if not dataset_path:
print("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
sys.exit(1)
print("Training configuration:")
for key, value in training_config.items():
print(f" {key}: {value}")
print(f" Dataset path: {dataset_path}")
# Initialize trainer
trainer = InstructTrainer(training_config)
# Start training
trainer.train(dataset_path)
print("Instruction fine-tuning completed successfully!")
except Exception as e:
print(f"Instruction fine-tuning failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
+62 -66
View File
@@ -7,16 +7,12 @@ from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
from sklearn.model_selection import train_test_split
import re
import argparse
import sys
import yaml
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@dataclass
class InstructConfig:
"""Configuration for instruction fine-tuning tasks"""
@@ -75,7 +71,7 @@ class ConversationValidator:
return False, errors
total_samples = sum(len(split_data) for split_data in data.values())
logger.info(f"Validating {total_samples} total samples across all splits...")
print(f"Validating {total_samples} total samples across all splits...")
# Determine field names based on whether data is processed or not
conversation_field = "conversation" if not is_processed else "conversation"
@@ -83,10 +79,10 @@ class ConversationValidator:
# Validate each split
for split_name, split_data in data.items():
if not split_data:
logger.info(f"Skipping validation for empty {split_name} split")
print(f"Skipping validation for empty {split_name} split")
continue
logger.info(f"Validating {split_name} split with {len(split_data)} samples...")
print(f"Validating {split_name} split with {len(split_data)} samples...")
# Check required fields
missing_conversation_count = 0
@@ -117,19 +113,19 @@ class ConversationValidator:
if "role" in turn and turn["role"] not in ["user", "assistant", "system"]:
errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
logger.info(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
# Show sample of processed data for debugging
if split_data:
logger.info(f"Sample conversation from {split_name}:")
print(f"Sample conversation from {split_name}:")
for i in range(min(2, len(split_data))):
item = split_data[i]
conversation = item.get(conversation_field, [])
logger.info(f" Item {i} conversation length: {len(conversation)} turns")
print(f" Item {i} conversation length: {len(conversation)} turns")
for j, turn in enumerate(conversation[:3]): # Show first 3 turns
role = turn.get("role", "unknown")
content = turn.get("content", "")[:100] + "..." if len(turn.get("content", "")) > 100 else turn.get("content", "")
logger.info(f" Turn {j}: {role} -> '{content}'")
print(f" Turn {j}: {role} -> '{content}'")
return len(errors) == 0, errors
@@ -242,7 +238,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
if not config.dataset_name:
raise ValueError("Dataset name is required for Hugging Face datasets")
logger.info(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
try:
dataset = load_dataset(
@@ -251,7 +247,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
)
available_splits = list(dataset.keys())
logger.info(f"Available splits in dataset: {available_splits}")
print(f"Available splits in dataset: {available_splits}")
splits_data = {
"train": [],
@@ -262,10 +258,10 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
# Handle train split
if "train" in available_splits:
train_dataset = dataset["train"]
logger.info(f"Using 'train' split with {len(train_dataset)} samples")
print(f"Using 'train' split with {len(train_dataset)} samples")
splits_data["train"] = list(train_dataset)
else:
logger.error("No 'train' split found in dataset!")
print("No 'train' split found in dataset!")
raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
# Handle validation and test splits (similar logic to styling pipeline)
@@ -277,23 +273,23 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
if splits_data[split_name]:
original_size = len(splits_data[split_name])
splits_data[split_name] = splits_data[split_name][:config.max_samples]
logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
print(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
logger.info(f"Successfully loaded dataset {config.dataset_name}")
print(f"Successfully loaded dataset {config.dataset_name}")
return splits_data
except Exception as e:
logger.error(f"Error loading dataset {config.dataset_name}: {e}")
print(f"Error loading dataset {config.dataset_name}: {e}")
raise
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
logger.info(f"=== PREPROCESSING CONVERSATION DATA ===")
print(f"=== PREPROCESSING CONVERSATION DATA ===")
for split_name, split_data in data.items():
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
@@ -308,7 +304,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
skipped_count += 1
processed_splits[split_name] = processed_data
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
@@ -368,7 +364,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
if not file_path.exists():
raise FileNotFoundError(f"Data file not found: {file_path}")
logger.info(f"Loading custom conversation dataset: {file_path}")
print(f"Loading custom conversation dataset: {file_path}")
if config.data_format == "jsonl":
raw_data = self._load_jsonl(file_path, config)
@@ -380,7 +376,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
if config.max_samples:
raw_data = raw_data[:config.max_samples]
logger.info(f"Loaded {len(raw_data)} conversation samples from {file_path}")
print(f"Loaded {len(raw_data)} conversation samples from {file_path}")
# Create splits from the raw data
splits_data = self._create_splits(raw_data, config)
@@ -389,11 +385,11 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Create train/validation/test splits from raw data"""
logger.info(f"Creating splits from {len(data)} conversation samples...")
print(f"Creating splits from {len(data)} conversation samples...")
# Handle very small datasets
if len(data) < 3:
logger.warning(f"Dataset has only {len(data)} samples. Using all data for training.")
print(f"Dataset has only {len(data)} samples. Using all data for training.")
return {
"train": data,
"validation": [],
@@ -408,7 +404,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
config.train_split = 0.6
config.validation_split = 0.2
config.test_split = 0.2
logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
print(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
val_size = max(1, int(total_samples * config.validation_split))
test_size = max(1, int(total_samples * config.test_split))
@@ -423,7 +419,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
test_size -= 1
train_size += 1
logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
# Create splits
if val_size == 0 and test_size == 0:
@@ -466,10 +462,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
"test": test_data
}
logger.info(f"Created conversation splits:")
logger.info(f" Train: {len(splits_data['train'])} samples")
logger.info(f" Validation: {len(splits_data['validation'])} samples")
logger.info(f" Test: {len(splits_data['test'])} samples")
print(f"Created conversation splits:")
print(f" Train: {len(splits_data['train'])} samples")
print(f" Validation: {len(splits_data['validation'])} samples")
print(f" Test: {len(splits_data['test'])} samples")
return splits_data
@@ -482,7 +478,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {line_num}: {e}")
print(f"Invalid JSON at line {line_num}: {e}")
return data
def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
@@ -501,10 +497,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
logger.info(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
print(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
for split_name, split_data in data.items():
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
@@ -518,10 +514,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
else:
skipped_count += 1
if skipped_count <= 3: # Log first few skipped items
logger.info(f"Skipped item {i} from {split_name}: {item}")
print(f"Skipped item {i} from {split_name}: {item}")
processed_splits[split_name] = processed_data
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
@@ -621,59 +617,59 @@ class InstructDataPipeline:
encoding=config_dict.get('encoding', 'utf-8')
)
logger.info(f"Configuration loaded from YAML: {yaml_path}")
logger.info(f"Output directory: {config.output_dir}")
logger.info(f"Conversation field: {config.conversation_field}")
print(f"Configuration loaded from YAML: {yaml_path}")
print(f"Output directory: {config.output_dir}")
print(f"Conversation field: {config.conversation_field}")
return config
except Exception as e:
logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
print(f"Error loading configuration from YAML {yaml_path}: {e}")
raise
def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
"""Load and preprocess conversation data"""
logger.info(f"Starting conversation data loading and preprocessing...")
logger.info(f"Data source: {config.data_source}")
print(f"Starting conversation data loading and preprocessing...")
print(f"Data source: {config.data_source}")
try:
# Load data
if config.data_source == "huggingface":
logger.info("Loading HuggingFace conversation dataset...")
print("Loading HuggingFace conversation dataset...")
raw_splits = self.hf_loader.load(config)
logger.info("Preprocessing HuggingFace conversation dataset...")
print("Preprocessing HuggingFace conversation dataset...")
processed_splits = self.hf_loader.preprocess(raw_splits, config)
elif config.data_source == "custom":
logger.info("Loading custom conversation dataset...")
print("Loading custom conversation dataset...")
raw_splits = self.custom_loader.load(config)
logger.info("Preprocessing custom conversation dataset...")
print("Preprocessing custom conversation dataset...")
processed_splits = self.custom_loader.preprocess(raw_splits, config)
else:
raise ValueError(f"Unsupported data source: {config.data_source}")
logger.info(f"Conversation data loading and preprocessing completed successfully")
print(f"Conversation data loading and preprocessing completed successfully")
# Validate processed data
logger.info("Validating processed conversation data...")
print("Validating processed conversation data...")
is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
if not is_valid:
logger.error("Conversation data validation failed:")
print("Conversation data validation failed:")
for error in errors:
logger.error(f" - {error}")
print(f" - {error}")
raise ValueError("Conversation data validation failed")
logger.info("Conversation data validation passed")
print("Conversation data validation passed")
# Analyze dataset
logger.info("Analyzing conversation dataset...")
print("Analyzing conversation dataset...")
analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
logger.info("Conversation dataset analysis completed")
print("Conversation dataset analysis completed")
return processed_splits, analysis
except Exception as e:
logger.error(f"Error in load_and_preprocess: {e}")
print(f"Error in load_and_preprocess: {e}")
raise
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
@@ -692,7 +688,7 @@ class InstructDataPipeline:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(split_data, f, ensure_ascii=False, indent=2)
logger.info(f"Saved {len(split_data)} conversation samples to {output_file}")
print(f"Saved {len(split_data)} conversation samples to {output_file}")
def run_pipeline(
self,
@@ -701,7 +697,7 @@ class InstructDataPipeline:
) -> Dict[str, Any]:
"""Run complete instruction data pipeline"""
logger.info("Starting instruction data pipeline...")
print("Starting instruction data pipeline...")
# Load and preprocess data
processed_splits, analysis = self.load_and_preprocess(config)
@@ -723,7 +719,7 @@ class InstructDataPipeline:
"data": processed_splits, # Include the actual processed data
}
logger.info("Instruction data pipeline completed successfully!")
print("Instruction data pipeline completed successfully!")
return result
def load_yaml_config(config_path: str) -> Dict[str, Any]:
@@ -764,13 +760,13 @@ def load_yaml_config(config_path: str) -> Dict[str, Any]:
'encoding': data_config.get('encoding')
})
logger.info(f"Successfully parsed YAML configuration from: {config_path}")
logger.info(f"Extracted {len(config_dict)} configuration parameters")
print(f"Successfully parsed YAML configuration from: {config_path}")
print(f"Extracted {len(config_dict)} configuration parameters")
return config_dict
except Exception as e:
logger.error(f"Error loading YAML config from {config_path}: {e}")
print(f"Error loading YAML config from {config_path}: {e}")
raise
def main():
@@ -805,10 +801,10 @@ def main():
args = parser.parse_args()
# Set up logging
logging.basicConfig(
level=getattr(logging, args.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# logging.basicConfig(
# level=getattr(logging, args.log_level),
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# )
# Load configuration
config_dict = {}
@@ -818,7 +814,7 @@ def main():
try:
config_dict = load_yaml_config(args.config)
except Exception as e:
logger.error(f"Error loading YAML config: {e}")
print(f"Error loading YAML config: {e}")
sys.exit(1)
# Override YAML config with CLI arguments (similar to styling pipeline)
@@ -847,7 +843,7 @@ def main():
# Merge configurations
for key, value in cli_overrides.items():
if key in config_dict:
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
print(f"Overriding YAML config '{key}' with CLI value: {value}")
config_dict[key] = value
# Validate required arguments
+82 -89
View File
@@ -7,7 +7,6 @@ Supports instruction fine-tuning with conversational data and LoRA fine-tuning
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List
@@ -21,13 +20,11 @@ from utils.config.config_manager import ConfigManager
# Training imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth import FastLanguageModel #is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from trl import SFTTrainer, SFTConfig
from transformers import DataCollatorForSeq2Seq
logger = logging.getLogger(__name__)
class InstructTrainer:
"""Instruction fine-tuning trainer using Unsloth and SFTTrainer"""
@@ -39,7 +36,7 @@ class InstructTrainer:
# Set device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
print(f"Using device: {self.device}")
# Model parameters
self.model_name = config.get('model_name', 'unsloth/Qwen2.5-72B-Instruct')
@@ -76,7 +73,7 @@ class InstructTrainer:
def load_model_and_tokenizer(self):
"""Load the pre-trained model and tokenizer"""
logger.info("Loading model and tokenizer...")
print("Loading model and tokenizer...")
try:
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
@@ -87,16 +84,16 @@ class InstructTrainer:
token=self.hf_token
)
logger.info(f"✅ Model loaded: {self.model_name}")
logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
print(f"✅ Model loaded: {self.model_name}")
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
except Exception as e:
logger.error(f"❌ Error loading model: {e}")
print(f"❌ Error loading model: {e}")
raise
def setup_lora(self):
"""Setup LoRA for efficient fine-tuning"""
logger.info("Setting up LoRA configuration...")
print("Setting up LoRA configuration...")
try:
self.model = FastLanguageModel.get_peft_model(
@@ -112,15 +109,15 @@ class InstructTrainer:
loftq_config=None
)
logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
print(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
except Exception as e:
logger.error(f"❌ Error setting up LoRA: {e}")
print(f"❌ Error setting up LoRA: {e}")
raise
def setup_chat_template(self):
"""Setup chat template for conversation formatting"""
logger.info("Setting up chat template...")
print("Setting up chat template...")
try:
self.tokenizer = get_chat_template(
@@ -128,15 +125,15 @@ class InstructTrainer:
chat_template=self.chat_template,
)
logger.info(f"✅ Chat template configured: {self.chat_template}")
print(f"✅ Chat template configured: {self.chat_template}")
except Exception as e:
logger.error(f"❌ Error setting up chat template: {e}")
print(f"❌ Error setting up chat template: {e}")
raise
def load_dataset(self, dataset_path: str) -> Dataset:
"""Load the conversation training dataset"""
logger.info(f"Loading conversation dataset from: {dataset_path}")
print(f"Loading conversation dataset from: {dataset_path}")
try:
if Path(dataset_path).exists():
@@ -144,10 +141,10 @@ class InstructTrainer:
if (Path(dataset_path) / "dataset_info.json").exists():
# Load from HuggingFace dataset directory
dataset = load_from_disk(dataset_path)
logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
print(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
else:
# Load from processed conversation data files (JSONL format)
logger.info("Loading from processed conversation data files...")
print("Loading from processed conversation data files...")
from datasets import Dataset
import json
@@ -158,7 +155,7 @@ class InstructTrainer:
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
file_path = data_dir / split_file
if file_path.exists():
logger.info(f"Loading {split_file}...")
print(f"Loading {split_file}...")
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
@@ -170,15 +167,15 @@ class InstructTrainer:
# Create HuggingFace dataset
dataset = Dataset.from_list(all_data)
logger.info(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
print(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
else:
# Try loading from HuggingFace Hub
logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
print(f"Attempting to load from HuggingFace Hub: {dataset_path}")
dataset = Dataset.load_dataset(dataset_path, split="train")
logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
print(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
logger.info(f"Dataset loaded: {len(dataset)} samples")
logger.info(f"Dataset features: {dataset.features}")
print(f"Dataset loaded: {len(dataset)} samples")
print(f"Dataset features: {dataset.features}")
# Verify required fields exist for conversation data
required_fields = ["conversation"]
@@ -189,16 +186,16 @@ class InstructTrainer:
return dataset
except Exception as e:
logger.error(f"Error loading conversation dataset: {e}")
print(f"Error loading conversation dataset: {e}")
raise
def format_dataset_for_training(self, dataset: Dataset) -> Dataset:
"""Format conversation dataset for training using standardize_sharegpt and apply_chat_template"""
logger.info("Formatting conversation dataset for training...")
print("Formatting conversation dataset for training...")
try:
# Standardize the ShareGPT format
logger.info("Standardizing ShareGPT format...")
print("Standardizing ShareGPT format...")
dataset = standardize_sharegpt(dataset)
# Define the formatting function for chat templates
@@ -214,21 +211,21 @@ class InstructTrainer:
return {"text": texts}
# Apply the formatting function
logger.info("Applying chat template formatting...")
print("Applying chat template formatting...")
dataset = dataset.map(formatting_prompts_func, batched=True)
logger.info(f"✅ Dataset formatted for training with {len(dataset)} samples")
logger.info(f"Sample formatted text: {dataset[0]['text'][:200]}...")
print(f"✅ Dataset formatted for training with {len(dataset)} samples")
print(f"Sample formatted text: {dataset[0]['text'][:200]}...")
return dataset
except Exception as e:
logger.error(f"❌ Error formatting dataset: {e}")
print(f"❌ Error formatting dataset: {e}")
raise
def setup_trainer(self, train_dataset: Dataset):
"""Setup the SFTTrainer for instruction fine-tuning"""
logger.info("Setting up SFTTrainer for instruction fine-tuning...")
print("Setting up SFTTrainer for instruction fine-tuning...")
try:
# SFT Configuration
@@ -247,12 +244,12 @@ class InstructTrainer:
report_to="none", # Disable wandb for now
)
logger.info("SFT Configuration:")
logger.info(f" batch_size: {self.batch_size}")
logger.info(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
logger.info(f" warmup_steps: {self.warmup_steps}")
logger.info(f" max_steps: {self.max_steps}")
logger.info(f" learning_rate: {self.learning_rate}")
print("SFT Configuration:")
print(f" batch_size: {self.batch_size}")
print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
print(f" warmup_steps: {self.warmup_steps}")
print(f" max_steps: {self.max_steps}")
print(f" learning_rate: {self.learning_rate}")
# Create SFTTrainer
self.trainer = SFTTrainer(
@@ -266,18 +263,18 @@ class InstructTrainer:
args=sft_config,
)
logger.info("✅ SFTTrainer configured successfully")
print("✅ SFTTrainer configured successfully")
except Exception as e:
logger.error(f"❌ Error setting up trainer: {e}")
print(f"❌ Error setting up trainer: {e}")
import traceback
logger.error("Full error traceback:")
print("Full error traceback:")
traceback.print_exc()
raise
def setup_response_only_training(self):
"""Setup training to only learn from assistant responses"""
logger.info("Setting up response-only training...")
print("Setting up response-only training...")
try:
# Configure trainer to only train on responses
@@ -287,51 +284,51 @@ class InstructTrainer:
response_part="<|im_start|>assistant\n",
)
logger.info("✅ Response-only training configured")
print("✅ Response-only training configured")
except Exception as e:
logger.error(f"❌ Error setting up response-only training: {e}")
print(f"❌ Error setting up response-only training: {e}")
raise
def train(self, dataset_path: str):
"""Run the instruction fine-tuning process"""
logger.info("🚀 Starting instruction fine-tuning process...")
print("🚀 Starting instruction fine-tuning process...")
try:
# Load model and tokenizer
logger.info("Step 1: Loading model and tokenizer...")
print("Step 1: Loading model and tokenizer...")
self.load_model_and_tokenizer()
# Setup LoRA
logger.info("Step 2: Setting up LoRA...")
print("Step 2: Setting up LoRA...")
self.setup_lora()
# Setup chat template
logger.info("Step 3: Setting up chat template...")
print("Step 3: Setting up chat template...")
self.setup_chat_template()
# Load dataset
logger.info(f"Step 4: Loading conversation dataset from: {dataset_path}")
print(f"Step 4: Loading conversation dataset from: {dataset_path}")
train_dataset = self.load_dataset(dataset_path)
# Format dataset for training
logger.info("Step 5: Formatting dataset for training...")
print("Step 5: Formatting dataset for training...")
formatted_dataset = self.format_dataset_for_training(train_dataset)
# Setup trainer
logger.info("Step 6: Setting up trainer...")
print("Step 6: Setting up trainer...")
self.setup_trainer(formatted_dataset)
# Setup response-only training (optional but recommended for chat models)
logger.info("Step 7: Setting up response-only training...")
print("Step 7: Setting up response-only training...")
self.setup_response_only_training()
# Start training
logger.info("Step 8: Starting training...")
print("Step 8: Starting training...")
trainer_stats = self.trainer.train()
logger.info("✅ Instruction fine-tuning completed successfully!")
logger.info(f"Training stats: {trainer_stats}")
print("✅ Instruction fine-tuning completed successfully!")
print(f"Training stats: {trainer_stats}")
# Save the model
self.save_model()
@@ -339,15 +336,15 @@ class InstructTrainer:
return trainer_stats
except Exception as e:
logger.error(f"❌ Instruction fine-tuning failed: {e}")
print(f"❌ Instruction fine-tuning failed: {e}")
import traceback
logger.error("Full error traceback:")
print("Full error traceback:")
traceback.print_exc()
raise
def save_model(self):
"""Save the trained instruction model"""
logger.info("Saving trained instruction model...")
print("Saving trained instruction model...")
try:
# Create output directory
@@ -362,23 +359,23 @@ class InstructTrainer:
with open(config_path, 'w') as f:
json.dump(self.config, f, indent=2)
logger.info(f"✅ Instruction model saved to: {self.model_output_dir}")
logger.info(f"✅ You can now use this model for inference")
print(f"✅ Instruction model saved to: {self.model_output_dir}")
print(f"✅ You can now use this model for inference")
except Exception as e:
logger.error(f"❌ Error saving model: {e}")
print(f"❌ Error saving model: {e}")
raise
def prepare_for_inference(self):
"""Prepare model for inference"""
logger.info("Preparing model for inference...")
print("Preparing model for inference...")
try:
FastLanguageModel.for_inference(self.model)
logger.info("✅ Model prepared for inference")
print("✅ Model prepared for inference")
except Exception as e:
logger.error(f"❌ Error preparing for inference: {e}")
print(f"❌ Error preparing for inference: {e}")
raise
def load_training_config(yaml_path: str) -> Dict[str, Any]:
@@ -403,16 +400,16 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
# Training configuration - extract from training section
if 'training' in config:
training_data = config['training']
logger.info("Training data from YAML:")
logger.info(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
logger.info(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
logger.info(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
logger.info(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
logger.info(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
logger.info(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
logger.info(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
logger.info(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
logger.info(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
print("Training data from YAML:")
print(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
print(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
print(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
print(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
print(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
print(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
print(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
print(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
print(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
training_config.update({
'num_epochs': int(training_data.get('num_epochs', 1)),
@@ -450,14 +447,14 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
'chat_template': 'qwen-2.5' # Use Qwen chat template by default
})
logger.info("Final training_config:")
print("Final training_config:")
for key, value in training_config.items():
logger.info(f" {key}: {value} (type: {type(value)})")
print(f" {key}: {value} (type: {type(value)})")
return training_config
except Exception as e:
logger.error(f"Error loading training config: {e}")
print(f"Error loading training config: {e}")
raise
def main():
@@ -475,15 +472,11 @@ def main():
args = parser.parse_args()
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Setup logging replaced with print statements
try:
# Load configuration
logger.info(f"Loading configuration from: {args.config}")
print(f"Loading configuration from: {args.config}")
training_config = load_training_config(args.config)
# Override with CLI arguments
@@ -501,13 +494,13 @@ def main():
# Determine dataset path: CLI argument takes precedence, then YAML config
dataset_path = args.dataset or training_config.get('dataset_path')
if not dataset_path:
logger.error("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
print("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
sys.exit(1)
logger.info("Training configuration:")
print("Training configuration:")
for key, value in training_config.items():
logger.info(f" {key}: {value}")
logger.info(f" Dataset path: {dataset_path}")
print(f" {key}: {value}")
print(f" Dataset path: {dataset_path}")
# Initialize trainer
trainer = InstructTrainer(training_config)
@@ -515,10 +508,10 @@ def main():
# Start training
trainer.train(dataset_path)
logger.info("Instruction fine-tuning completed successfully!")
print("Instruction fine-tuning completed successfully!")
except Exception as e:
logger.error(f"Instruction fine-tuning failed: {e}")
print(f"Instruction fine-tuning failed: {e}")
sys.exit(1)
if __name__ == "__main__":
+1
View File
@@ -9,3 +9,4 @@ numpy>=1.24.0
scikit-learn>=1.3.0
pyyaml>=6.0
huggingface-hub>=0.15.0
unsloth
+46
View File
@@ -0,0 +1,46 @@
#!/usr/bin/env bash
set -e # exit on error
# Step 1: Download Miniconda installer
echo ">>> Downloading Miniconda..."
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh
# Step 2: Run the installer
echo ">>> Installing Miniconda..."
bash ~/miniconda.sh -b -p $HOME/miniconda
# Step 3: Initialize conda for bash
echo ">>> Initializing Conda..."
eval "$($HOME/miniconda/bin/conda shell.bash hook)"
# Step 4: Add conda to PATH permanently
if ! grep -q "miniconda/bin" ~/.bashrc; then
echo "export PATH=\$HOME/miniconda/bin:\$PATH" >> ~/.bashrc
fi
source ~/.bashrc
# Step 5: Verify installation
echo ">>> Conda version:"
conda --version
# Step 6: Create environment
echo ">>> Creating environment: llm-env (python=3.10)..."
conda create -n llm-env python=3.10 -y
# Step 7: Activate environment
echo ">>> Activating environment..."
conda activate llm-env
# Step 8: Install dependencies
echo ">>> Installing scikit-learn..."
conda install scikit-learn -y
if [ -f requirements.txt ]; then
echo ">>> Installing Python requirements..."
pip install -r requirements.txt
else
echo ">>> No requirements.txt found, skipping pip install."
fi
echo ">>> Setup complete. To activate your environment run:"
echo "conda activate llm-env"
Binary file not shown.
Binary file not shown.