added style mimicking piepelines
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Inference Pipeline using Trained Models
|
||||
Supports style transfer inference with streaming and batch processing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
import yaml
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from utils.config.config_manager import ConfigManager
|
||||
from utils.logging.logging import setup_logging
|
||||
|
||||
# Inference imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel
|
||||
from transformers import TextStreamer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StylingInference:
|
||||
"""Styling task inference using trained models"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
|
||||
# Set device
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
# Model parameters
|
||||
self.model_path = config.get('model_path')
|
||||
self.max_seq_length = config.get('max_seq_length', 2048)
|
||||
self.dtype = config.get('dtype', None)
|
||||
self.load_in_4bit = config.get('load_in_4bit', True)
|
||||
self.hf_token = config.get('hf_token', None)
|
||||
|
||||
# Inference parameters
|
||||
self.batch_size = config.get('batch_size', 1)
|
||||
self.max_new_tokens = config.get('max_new_tokens', 128)
|
||||
self.temperature = config.get('temperature', 0.8)
|
||||
self.top_p = config.get('top_p', 0.9)
|
||||
self.do_sample = config.get('do_sample', True)
|
||||
|
||||
# Alpaca prompt template
|
||||
self.alpaca_prompt = config.get('alpaca_prompt', """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
|
||||
|
||||
### Instruction:
|
||||
{}
|
||||
|
||||
### Input:
|
||||
{}
|
||||
|
||||
### Response:
|
||||
{}""")
|
||||
|
||||
# Style instruction
|
||||
self.style_instruction = config.get('style_instruction', 'Rewrite the following text in a formal style')
|
||||
|
||||
def load_model_and_tokenizer(self):
|
||||
"""Load the trained model and tokenizer"""
|
||||
logger.info("Loading model and tokenizer...")
|
||||
|
||||
try:
|
||||
if self.model_path and Path(self.model_path).exists():
|
||||
# Load local trained model
|
||||
logger.info(f"Loading local model from: {self.model_path}")
|
||||
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=self.model_path,
|
||||
max_seq_length=self.max_seq_length,
|
||||
dtype=self.dtype,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
token=self.hf_token
|
||||
)
|
||||
else:
|
||||
# Load base model from HuggingFace Hub
|
||||
logger.info(f"Loading base model: {self.config.get('base_model_name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit')}")
|
||||
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=self.config.get('base_model_name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
|
||||
max_seq_length=self.max_seq_length,
|
||||
dtype=self.dtype,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
token=self.hf_token
|
||||
)
|
||||
|
||||
# Prepare for inference
|
||||
FastLanguageModel.for_inference(self.model)
|
||||
|
||||
logger.info(f"✅ Model loaded successfully")
|
||||
logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error loading model: {e}")
|
||||
raise
|
||||
|
||||
def format_prompt(self, instruction: str, input_text: str, output: str = "") -> str:
|
||||
"""Format the prompt using Alpaca template"""
|
||||
return self.alpaca_prompt.format(instruction, input_text, output)
|
||||
|
||||
def generate_text(self, prompt: str, max_new_tokens: Optional[int] = None) -> str:
|
||||
"""Generate text from a single prompt"""
|
||||
try:
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
||||
|
||||
# Set generation parameters
|
||||
gen_kwargs = {
|
||||
"max_new_tokens": max_new_tokens or self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"do_sample": self.do_sample,
|
||||
"use_cache": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id
|
||||
}
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
# Decode
|
||||
generated_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||
|
||||
# Extract only the generated part (remove input prompt)
|
||||
if prompt in generated_text:
|
||||
generated_text = generated_text[len(prompt):].strip()
|
||||
|
||||
return generated_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error generating text: {e}")
|
||||
return ""
|
||||
|
||||
def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
|
||||
"""Perform style transfer on input text"""
|
||||
if instruction is None:
|
||||
instruction = self.style_instruction
|
||||
|
||||
# Format prompt
|
||||
prompt = self.format_prompt(instruction, input_text, "")
|
||||
|
||||
logger.info(f"Style transfer prompt: {prompt}")
|
||||
|
||||
if streaming:
|
||||
logger.info("Generating with streaming...")
|
||||
self.generate_text_streaming(prompt)
|
||||
return ""
|
||||
else:
|
||||
logger.info("Generating text...")
|
||||
result = self.generate_text(prompt)
|
||||
logger.info(f"Generated result: {result}")
|
||||
return result
|
||||
|
||||
def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
|
||||
"""Generate text with streaming output"""
|
||||
try:
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
||||
|
||||
# Setup text streamer
|
||||
text_streamer = TextStreamer(self.tokenizer)
|
||||
|
||||
# Set generation parameters
|
||||
gen_kwargs = {
|
||||
"max_new_tokens": max_new_tokens or self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"do_sample": self.do_sample,
|
||||
"use_cache": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id
|
||||
}
|
||||
|
||||
# Generate with streaming
|
||||
with torch.no_grad():
|
||||
_ = self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in streaming generation: {e}")
|
||||
|
||||
def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
|
||||
"""Perform style transfer on multiple input texts"""
|
||||
results = []
|
||||
|
||||
for i, input_text in enumerate(input_texts):
|
||||
logger.info(f"Processing text {i+1}/{len(input_texts)}")
|
||||
result = self.style_transfer(input_text, instruction)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def load_inference_config(config_path: str) -> Dict[str, Any]:
|
||||
"""Load inference configuration from YAML file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# Extract inference configuration
|
||||
inference_config = {}
|
||||
|
||||
# Model configuration
|
||||
if 'model' in config:
|
||||
model_data = config['model']
|
||||
inference_config.update({
|
||||
'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
|
||||
'max_seq_length': model_data.get('training_max_seq_length', 2048),
|
||||
'dtype': model_data.get('training_dtype'),
|
||||
'load_in_4bit': model_data.get('training_load_in_4bit', True),
|
||||
'hf_token': model_data.get('training_token')
|
||||
})
|
||||
|
||||
# Inference configuration
|
||||
if 'inference' in config:
|
||||
inference_data = config['inference']
|
||||
inference_config.update({
|
||||
'batch_size': inference_data.get('batch_size', 1),
|
||||
'max_new_tokens': inference_data.get('max_new_tokens', 128),
|
||||
'temperature': inference_data.get('temperature', 0.8)
|
||||
})
|
||||
|
||||
# Style configuration
|
||||
if 'data' in config:
|
||||
data_config = config['data']
|
||||
inference_config.update({
|
||||
'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style')
|
||||
})
|
||||
|
||||
return inference_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading inference config: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
"""Main inference function"""
|
||||
parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
|
||||
|
||||
# Configuration
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
|
||||
parser.add_argument("--model-path", type=str, help="Path to trained model (optional, uses base model if not provided)")
|
||||
|
||||
# Inference modes
|
||||
parser.add_argument("--text", type=str, help="Single text to style transfer")
|
||||
parser.add_argument("--input-file", type=str, help="File containing texts to process (one per line)")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
|
||||
parser.add_argument("--temperature", type=float, help="Sampling temperature")
|
||||
parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
|
||||
parser.add_argument("--instruction", type=str, help="Custom style instruction")
|
||||
|
||||
# Output
|
||||
parser.add_argument("--output-file", type=str, help="Output file for results")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
|
||||
try:
|
||||
# Load configuration
|
||||
logger.info(f"Loading configuration from: {args.config}")
|
||||
inference_config = load_inference_config(args.config)
|
||||
|
||||
# Override with CLI arguments
|
||||
if args.model_path:
|
||||
inference_config['model_path'] = args.model_path
|
||||
if args.max_tokens:
|
||||
inference_config['max_new_tokens'] = args.max_tokens
|
||||
if args.temperature:
|
||||
inference_config['temperature'] = args.temperature
|
||||
if args.instruction:
|
||||
inference_config['style_instruction'] = args.instruction
|
||||
|
||||
logger.info("Inference configuration:")
|
||||
for key, value in inference_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
# Initialize inference
|
||||
inferencer = StylingInference(inference_config)
|
||||
|
||||
# Load model
|
||||
inferencer.load_model_and_tokenizer()
|
||||
|
||||
# Run inference based on mode
|
||||
if args.text:
|
||||
# Single text inference
|
||||
logger.info("Running single text inference...")
|
||||
result = inferencer.style_transfer(args.text, args.instruction, args.streaming)
|
||||
if not args.streaming:
|
||||
print(f"\nGenerated text: {result}")
|
||||
|
||||
elif args.input_file:
|
||||
# Batch file inference
|
||||
logger.info("Running batch file inference...")
|
||||
with open(args.input_file, 'r', encoding='utf-8') as f:
|
||||
input_texts = [line.strip() for line in f if line.strip()]
|
||||
|
||||
results = inferencer.batch_style_transfer(input_texts, args.instruction)
|
||||
|
||||
# Save results
|
||||
output_file = args.output_file or f"{Path(args.input_file).stem}_styled.txt"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for input_text, result in zip(input_texts, results):
|
||||
f.write(f"Input: {input_text}\n")
|
||||
f.write(f"Output: {result}\n")
|
||||
f.write("-" * 50 + "\n")
|
||||
|
||||
logger.info(f"✅ Results saved to: {output_file}")
|
||||
|
||||
else:
|
||||
# Interactive mode
|
||||
logger.info("Entering interactive mode. Type 'quit' to exit.")
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\nEnter text to style (or 'quit'): ").strip()
|
||||
if user_input.lower() == 'quit':
|
||||
break
|
||||
|
||||
if user_input:
|
||||
result = inferencer.style_transfer(user_input, args.instruction, args.streaming)
|
||||
if not args.streaming:
|
||||
print(f"\nStyled text: {result}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input: {e}")
|
||||
|
||||
logger.info("🎉 Inference completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Inference failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,446 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Training Pipeline using Unsloth and SFTTrainer
|
||||
Supports style transfer tasks with LoRA fine-tuning
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import yaml
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from utils.config.config_manager import ConfigManager
|
||||
#from utils.logging.logging import setup_logging
|
||||
|
||||
# Training imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel, is_bfloat16_supported
|
||||
from trl import SFTTrainer
|
||||
from transformers import TrainingArguments
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StylingTrainer:
|
||||
"""Styling task trainer using Unsloth and SFTTrainer"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.trainer = None
|
||||
|
||||
# Set device
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
# Training parameters
|
||||
self.max_seq_length = config.get('max_seq_length', 2048)
|
||||
self.dtype = config.get('dtype', None)
|
||||
self.load_in_4bit = config.get('load_in_4bit', True)
|
||||
self.hf_token = config.get('hf_token', None)
|
||||
|
||||
# LoRA parameters
|
||||
self.lora_r = config.get('lora_r', 16)
|
||||
self.lora_alpha = config.get('lora_alpha', 16)
|
||||
self.lora_dropout = config.get('lora_dropout', 0)
|
||||
self.target_modules = config.get('target_modules', [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
])
|
||||
|
||||
# Training arguments
|
||||
self.batch_size = config.get('batch_size', 2)
|
||||
self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
|
||||
self.learning_rate = config.get('learning_rate', 2e-4)
|
||||
self.num_epochs = config.get('num_epochs', 1)
|
||||
self.max_steps = config.get('max_steps', None)
|
||||
self.warmup_steps = config.get('warmup_steps', 5)
|
||||
self.weight_decay = config.get('weight_decay', 0.01)
|
||||
self.seed = config.get('seed', 3407)
|
||||
|
||||
# Output paths
|
||||
self.output_dir = config.get('output_dir', './outputs')
|
||||
self.model_output_dir = config.get('model_output_dir', './models/styling')
|
||||
|
||||
def load_model_and_tokenizer(self):
|
||||
"""Load the pre-trained model and tokenizer"""
|
||||
logger.info("Loading model and tokenizer...")
|
||||
|
||||
try:
|
||||
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=self.config['model_name'],
|
||||
max_seq_length=self.max_seq_length,
|
||||
dtype=self.dtype,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
token=self.hf_token
|
||||
)
|
||||
|
||||
logger.info(f"✅ Model loaded: {self.config['model_name']}")
|
||||
logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error loading model: {e}")
|
||||
raise
|
||||
|
||||
def setup_lora(self):
|
||||
"""Setup LoRA for efficient fine-tuning"""
|
||||
logger.info("Setting up LoRA configuration...")
|
||||
|
||||
try:
|
||||
self.model = FastLanguageModel.get_peft_model(
|
||||
self.model,
|
||||
r=self.lora_r,
|
||||
target_modules=self.target_modules,
|
||||
lora_alpha=self.lora_alpha,
|
||||
lora_dropout=self.lora_dropout,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=self.seed,
|
||||
use_rslora=False,
|
||||
loftq_config=None
|
||||
)
|
||||
|
||||
logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error setting up LoRA: {e}")
|
||||
raise
|
||||
|
||||
def load_dataset(self, dataset_path: str) -> Dataset:
|
||||
"""Load the training dataset"""
|
||||
logger.info(f"Loading dataset from: {dataset_path}")
|
||||
|
||||
try:
|
||||
if Path(dataset_path).exists():
|
||||
# Check if it's a HuggingFace dataset directory
|
||||
if (Path(dataset_path) / "dataset_info.json").exists():
|
||||
# Load from HuggingFace dataset directory
|
||||
dataset = load_from_disk(dataset_path)
|
||||
logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
|
||||
else:
|
||||
# Load from processed data files (JSONL format)
|
||||
logger.info("Loading from processed data files...")
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
all_data = []
|
||||
data_dir = Path(dataset_path)
|
||||
|
||||
# Look for train.jsonl, validation.jsonl, test.jsonl
|
||||
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
|
||||
file_path = data_dir / split_file
|
||||
if file_path.exists():
|
||||
logger.info(f"Loading {split_file}...")
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
raise ValueError(f"No data found in {dataset_path}")
|
||||
|
||||
# Create HuggingFace dataset
|
||||
dataset = Dataset.from_list(all_data)
|
||||
logger.info(f"Created HuggingFace dataset from {len(all_data)} samples")
|
||||
else:
|
||||
# Try loading from HuggingFace Hub
|
||||
logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
|
||||
dataset = Dataset.load_dataset(dataset_path, split="train")
|
||||
logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
|
||||
|
||||
logger.info(f"Dataset loaded: {len(dataset)} samples")
|
||||
logger.info(f"Dataset features: {dataset.features}")
|
||||
|
||||
# Verify required fields exist
|
||||
required_fields = ["instruction", "input", "output"]
|
||||
missing_fields = [field for field in required_fields if field not in dataset.features]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields in dataset: {missing_fields}")
|
||||
|
||||
return dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset: {e}")
|
||||
raise
|
||||
|
||||
def setup_trainer(self, train_dataset: Dataset):
|
||||
"""Setup the SFTTrainer"""
|
||||
logger.info("Setting up SFTTrainer...")
|
||||
|
||||
try:
|
||||
# First, map the dataset to create the text field with EOS token
|
||||
def formatting_prompts_func(examples):
|
||||
instructions = examples["instruction"]
|
||||
inputs = examples["input"]
|
||||
outputs = examples["output"]
|
||||
texts = []
|
||||
|
||||
for instruction, input_text, output in zip(instructions, inputs, outputs):
|
||||
# Must add EOS_TOKEN, otherwise your generation will go on forever!
|
||||
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
|
||||
|
||||
### Instruction:
|
||||
{}
|
||||
|
||||
### Input:
|
||||
{}
|
||||
|
||||
### Response:
|
||||
{}"""
|
||||
text = alpaca_prompt.format(instruction, input_text, output) + self.tokenizer.eos_token
|
||||
texts.append(text)
|
||||
|
||||
return {"text": texts}
|
||||
|
||||
# Apply the formatting function to create the text field
|
||||
logger.info("Mapping dataset to create text field with EOS token...")
|
||||
formatted_dataset = train_dataset.map(formatting_prompts_func, batched=True, remove_columns=train_dataset.column_names)
|
||||
|
||||
logger.info(f"Dataset mapped successfully. New features: {formatted_dataset.features}")
|
||||
logger.info(f"Sample text field: {formatted_dataset[0]['text'][:100]}...")
|
||||
|
||||
# Training arguments
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.batch_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
warmup_steps=self.warmup_steps,
|
||||
num_train_epochs=self.num_epochs,
|
||||
max_steps=self.max_steps,
|
||||
learning_rate=self.learning_rate,
|
||||
fp16=not is_bfloat16_supported(),
|
||||
bf16=is_bfloat16_supported(),
|
||||
logging_steps=1,
|
||||
optim="adamw_8bit",
|
||||
weight_decay=self.weight_decay,
|
||||
lr_scheduler_type="linear",
|
||||
seed=self.seed,
|
||||
output_dir=self.output_dir,
|
||||
report_to="none", # Disable wandb for now
|
||||
save_strategy="epoch",
|
||||
save_total_limit=2,
|
||||
evaluation_strategy="no", # No validation for now
|
||||
load_best_model_at_end=False,
|
||||
remove_unused_columns=False,
|
||||
dataloader_pin_memory=False,
|
||||
)
|
||||
|
||||
# Create trainer with the formatted dataset
|
||||
self.trainer = SFTTrainer(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=formatted_dataset, # Use the formatted dataset
|
||||
dataset_text_field="text", # The field we just created
|
||||
max_seq_length=self.max_seq_length,
|
||||
dataset_num_proc=2,
|
||||
packing=False, # Can make training 5x faster for short sequences
|
||||
args=training_args
|
||||
)
|
||||
|
||||
logger.info("SFTTrainer configured successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up trainer: {e}")
|
||||
raise
|
||||
|
||||
def train(self, dataset_path: str):
|
||||
"""Run the training process"""
|
||||
logger.info("🚀 Starting training process...")
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
self.load_model_and_tokenizer()
|
||||
|
||||
# Setup LoRA
|
||||
self.setup_lora()
|
||||
|
||||
# Load dataset
|
||||
train_dataset = self.load_dataset(dataset_path)
|
||||
|
||||
# Setup trainer
|
||||
self.setup_trainer(train_dataset)
|
||||
|
||||
# Start training
|
||||
logger.info("Starting training...")
|
||||
trainer_stats = self.trainer.train()
|
||||
|
||||
logger.info("✅ Training completed successfully!")
|
||||
logger.info(f"Training stats: {trainer_stats}")
|
||||
|
||||
# Save the model
|
||||
self.save_model()
|
||||
|
||||
return trainer_stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Training failed: {e}")
|
||||
raise
|
||||
|
||||
def save_model(self):
|
||||
"""Save the trained model"""
|
||||
logger.info("Saving trained model...")
|
||||
|
||||
try:
|
||||
# Create output directory
|
||||
Path(self.model_output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save model and tokenizer
|
||||
self.model.save_pretrained(self.model_output_dir)
|
||||
self.tokenizer.save_pretrained(self.model_output_dir)
|
||||
|
||||
# Save training config
|
||||
config_path = Path(self.model_output_dir) / "training_config.json"
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(self.config, f, indent=2)
|
||||
|
||||
logger.info(f"✅ Model saved to: {self.model_output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error saving model: {e}")
|
||||
raise
|
||||
|
||||
def prepare_for_inference(self):
|
||||
"""Prepare model for inference"""
|
||||
logger.info("Preparing model for inference...")
|
||||
|
||||
try:
|
||||
FastLanguageModel.for_inference(self.model)
|
||||
logger.info("✅ Model prepared for inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error preparing for inference: {e}")
|
||||
raise
|
||||
|
||||
def load_training_config(config_path: str) -> Dict[str, Any]:
|
||||
"""Load training configuration from YAML file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# Extract training configuration
|
||||
training_config = {}
|
||||
|
||||
# Model configuration
|
||||
if 'model' in config:
|
||||
model_data = config['model']
|
||||
training_config.update({
|
||||
'model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
|
||||
'max_seq_length': model_data.get('training_max_seq_length', 2048),
|
||||
'dtype': model_data.get('training_dtype'),
|
||||
'load_in_4bit': model_data.get('training_load_in_4bit', True),
|
||||
'hf_token': model_data.get('training_token')
|
||||
})
|
||||
|
||||
# Training configuration
|
||||
if 'training' in config:
|
||||
training_data = config['training']
|
||||
training_config.update({
|
||||
'num_epochs': training_data.get('num_epochs', 3),
|
||||
'batch_size': training_data.get('batch_size', 2),
|
||||
'learning_rate': training_data.get('learning_rate', 2e-4),
|
||||
'weight_decay': training_data.get('weight_decay', 0.01),
|
||||
'warmup_ratio': training_data.get('warmup_ratio', 0.1),
|
||||
'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear')
|
||||
})
|
||||
|
||||
# Data configuration - use output_dir from data section
|
||||
if 'data' in config:
|
||||
data_config = config['data']
|
||||
output_dir = data_config.get('output_dir', './data/processed/styling')
|
||||
training_config.update({
|
||||
'data_output_dir': output_dir,
|
||||
'dataset_path': output_dir, # Default dataset path is the output_dir
|
||||
'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style')
|
||||
})
|
||||
|
||||
# LoRA configuration
|
||||
training_config.update({
|
||||
'lora_r': 16,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0,
|
||||
'target_modules': [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
'gradient_accumulation_steps': 4,
|
||||
'max_steps': None,
|
||||
'warmup_steps': 5,
|
||||
'seed': 3407,
|
||||
'output_dir': './outputs',
|
||||
'model_output_dir': './models/styling'
|
||||
})
|
||||
|
||||
return training_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading training config: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
"""Main training function"""
|
||||
parser = argparse.ArgumentParser(description="Styling Training Pipeline")
|
||||
|
||||
# Configuration
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
|
||||
parser.add_argument("--dataset", type=str, help="Path to training dataset (HF dataset path or local path)")
|
||||
parser.add_argument("--output-dir", type=str, help="Output directory for model")
|
||||
parser.add_argument("--epochs", type=int, help="Number of training epochs")
|
||||
parser.add_argument("--batch-size", type=int, help="Training batch size")
|
||||
parser.add_argument("--learning-rate", type=float, help="Learning rate")
|
||||
parser.add_argument("--max-steps", type=int, help="Maximum training steps")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
# setup_logging() # Commented out as per user's change
|
||||
|
||||
try:
|
||||
# Load configuration
|
||||
logger.info(f"Loading configuration from: {args.config}")
|
||||
training_config = load_training_config(args.config)
|
||||
|
||||
# Override with CLI arguments
|
||||
if args.output_dir:
|
||||
training_config['model_output_dir'] = args.output_dir
|
||||
if args.epochs:
|
||||
training_config['num_epochs'] = args.epochs
|
||||
if args.batch_size:
|
||||
training_config['batch_size'] = args.batch_size
|
||||
if args.learning_rate:
|
||||
training_config['learning_rate'] = args.learning_rate
|
||||
if args.max_steps:
|
||||
training_config['max_steps'] = args.max_steps
|
||||
|
||||
# Determine dataset path: CLI argument takes precedence, then YAML config
|
||||
dataset_path = args.dataset or training_config.get('dataset_path')
|
||||
if not dataset_path:
|
||||
logger.error("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Training configuration:")
|
||||
for key, value in training_config.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
logger.info(f" Dataset path: {dataset_path}")
|
||||
|
||||
# Initialize trainer
|
||||
trainer = StylingTrainer(training_config)
|
||||
|
||||
# Start training
|
||||
trainer.train(dataset_path)
|
||||
|
||||
logger.info("Training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user