updated style mimicking fine tuning

This commit is contained in:
Your Name
2025-08-13 23:50:20 +00:00
parent 8847035d12
commit 1b46270afa
83 changed files with 2537260 additions and 378 deletions
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Styling Inference Pipeline using Trained Models
Supports style transfer inference with streaming and batch processing
"""
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer
class StylingInference:
    """Style-transfer inference wrapper around a fine-tuned Unsloth model.

    The instance is configured from a plain dict (see ``__init__`` for the
    recognised keys). ``load_model_and_tokenizer`` must be called once before
    any of the generation methods.
    """

    # Default prompt layout. It has three positional slots: instruction,
    # input and response (the response slot is left empty at inference time).
    DEFAULT_ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}"""

    def __init__(self, config: Dict[str, Any]):
        """Read model and generation settings from ``config``.

        Recognised keys (all optional): ``model_output_dir``,
        ``max_seq_length``, ``dtype``, ``load_in_4bit``, ``hf_token``,
        ``batch_size``, ``max_new_tokens``, ``temperature``, ``top_p``,
        ``do_sample``, ``alpaca_prompt`` and ``style_instruction``.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/styling')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)

        # Prompt template (three ``{}`` slots) and default style instruction
        self.alpaca_prompt = config.get('alpaca_prompt', self.DEFAULT_ALPACA_PROMPT)
        self.style_instruction = config.get('style_instruction', 'Rewrite the following text in a formal style')

    def load_model_and_tokenizer(self):
        """Load the trained model and tokenizer from ``self.model_output_dir``.

        Raises:
            Exception: whatever the underlying loader raises; it is reported
                on stdout and re-raised unchanged.
        """
        print("Loading trained model and tokenizer...")
        try:
            model_path = self.model_output_dir
            print(f"Loading model from: {model_path}")
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                # Forward the configured token, as the training script does.
                token=self.hf_token,
            )
            # Switch the model into Unsloth's inference mode
            FastLanguageModel.for_inference(self.model)
            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Render the prompt for ``instruction`` / ``input_text``.

        Uses ``self.alpaca_prompt`` so a template supplied through the config
        is honoured; the response slot is left empty for the model to fill.
        """
        return self.alpaca_prompt.format(instruction, input_text, "")

    def _generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Build the keyword arguments shared by every ``model.generate`` call."""
        return {
            "max_new_tokens": max_new_tokens or self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "do_sample": self.do_sample,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

    def generate_text(self, instruction: str, input_text: str = "", max_new_tokens: int = 128, stream: bool = False):
        """Generate a response for ``instruction`` applied to ``input_text``.

        Args:
            instruction: Task description placed in the instruction slot.
            input_text: Text placed in the input slot.
            max_new_tokens: Upper bound on generated tokens.
            stream: When true, tokens are printed as they are produced and
                ``None`` is returned.

        Returns:
            The stripped generated response, or ``None`` when streaming.
        """
        try:
            prompt = self.format_prompt(instruction, input_text)
            print(f"Formatted prompt: {prompt}")

            if stream:
                print("Generating with streaming...")
                self.generate_text_streaming(prompt, max_new_tokens)
                return None  # Streaming output is handled by the streamer

            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            prompt_length = inputs["input_ids"].shape[1]

            print("Generating...")
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **self._generation_kwargs(max_new_tokens))

            # Decode only the tokens produced after the prompt. Searching the
            # decoded text for "### Response:" would pick the wrong span when
            # the instruction or input itself contains that marker.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"❌ Error generating text: {e}")
            raise

    def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
        """Perform style transfer on a single input text.

        Args:
            input_text: Text to restyle.
            instruction: Style instruction; defaults to
                ``self.style_instruction`` when ``None``.
            streaming: When true, output is streamed to stdout and an empty
                string is returned.

        Returns:
            The generated text, or ``""`` when streaming.
        """
        try:
            if instruction is None:
                instruction = self.style_instruction
            print(f"Style instruction: {instruction}")
            print(f"Input text: {input_text}")

            prompt = self.format_prompt(instruction, input_text)
            print(f"Style transfer prompt: {prompt}")

            if streaming:
                print("Generating with streaming...")
                self.generate_text_streaming(prompt)
                return ""

            print("Generating text...")
            result = self.generate_text(instruction, input_text, self.max_new_tokens)
            print(f"Generated result: {result}")
            return result
        except Exception as e:
            print(f"❌ Error in style transfer: {e}")
            raise

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
        """Generate from an already formatted ``prompt``, streaming to stdout.

        Args:
            prompt: Fully rendered prompt text.
            max_new_tokens: Upper bound on generated tokens; falls back to
                ``self.max_new_tokens`` when ``None``.

        Raises:
            Exception: generation errors are reported and re-raised so callers
                do not mistake a failure for an empty result.
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            text_streamer = TextStreamer(self.tokenizer)
            with torch.no_grad():
                _ = self.model.generate(
                    **inputs,
                    streamer=text_streamer,
                    **self._generation_kwargs(max_new_tokens),
                )
        except Exception as e:
            print(f"❌ Error in streaming generation: {e}")
            raise

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Perform style transfer on each text in ``input_texts``, in order.

        Texts are processed one at a time; the returned list has one result
        per input.
        """
        total = len(input_texts)
        results = []
        for position, input_text in enumerate(input_texts, start=1):
            print(f"Processing text {position}/{total}")
            results.append(self.style_transfer(input_text, instruction))
        return results
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load the inference settings from a YAML configuration file.

    Values are collected from the ``model``, ``training``, ``inference`` and
    ``data`` sections; a section that is absent contributes nothing, and a
    section that is present but empty contributes only its defaults.

    Args:
        config_path: Path to the YAML file.

    Returns:
        A flat dict suitable for ``StylingInference``.

    Raises:
        ValueError: if the YAML document is not a mapping at the top level.
        Exception: any I/O or YAML parsing error, reported and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(f"Top-level YAML content must be a mapping: {config_path}")

        inference_config = {}

        # Model configuration
        if 'model' in config:
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token')
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config.update({
                'model_output_dir': training_data.get('model_output_dir', './models/styling')
            })

        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8)
            })
            # Optional sampling knobs: forwarded only when the YAML sets them,
            # so the consumer's own defaults apply otherwise.
            for optional_key in ('top_p', 'do_sample'):
                if inference_data.get(optional_key) is not None:
                    inference_config[optional_key] = inference_data[optional_key]

        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config.update({
                'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style')
            })

        return inference_config
    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise
def main():
    """Command-line entry point for styling inference.

    Loads the YAML configuration, applies CLI overrides, loads the trained
    model and runs one generation (streamed to stdout with ``--stream``).
    Exits with status 1 when anything fails.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--instruction", type=str, default=None,
                        help="Style instruction (defaults to the instruction from the configuration)")
    parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    # No CLI default here: a default of 128 would always override the value
    # configured in the YAML file.
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="Maximum new tokens to generate (defaults to the configured value, else 128)")
    parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens

        print("Inference configuration:")
        for key, value in inference_config.items():
            # Never echo the access token to the console
            shown = "***" if key == 'hf_token' and value else value
            print(f" {key}: {shown}")

        # Initialize inference
        inference = StylingInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        instruction = args.instruction if args.instruction is not None else inference.style_instruction
        max_new_tokens = inference.max_new_tokens

        # Run inference
        if args.stream:
            print("Running streaming inference...")
            inference.generate_text(instruction, args.input_text, max_new_tokens, stream=True)
        else:
            print("Running inference...")
            result = inference.generate_text(instruction, args.input_text, max_new_tokens)
            print(f"✅ Generated text: {result}")
    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,493 @@
#!/usr/bin/env python3
"""
Styling Training Pipeline using Unsloth and SFTTrainer
Supports style transfer tasks with LoRA fine-tuning
"""
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, Any, Optional
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
from utils.config.config_manager import ConfigManager
#from utils.logging.logging import setup_logging
# Training imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
logger = logging.getLogger(__name__)
class StylingTrainer:
    """LoRA fine-tuning driver for the styling task (Unsloth + TRL SFTTrainer).

    Typical use is ``StylingTrainer(config).train(dataset_path)``, which loads
    the base model, attaches LoRA adapters, loads and formats the dataset,
    trains, and saves the result to ``model_output_dir``.
    """

    # Prompt layout used to render each training example: instruction, input
    # and expected output fill the three positional slots.
    ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}"""

    # Config keys that must not be written to disk next to the model.
    _SECRET_CONFIG_KEYS = ("hf_token",)

    def __init__(self, config: Dict[str, Any]):
        """Read model, LoRA and training settings from ``config``.

        ``config['model_name']`` is required later by
        ``load_model_and_tokenizer``; every other key has a default.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Training parameters
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # LoRA parameters
        self.lora_r = config.get('lora_r', 16)
        self.lora_alpha = config.get('lora_alpha', 16)
        self.lora_dropout = config.get('lora_dropout', 0)
        self.target_modules = config.get('target_modules', [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ])

        # Training arguments
        self.batch_size = config.get('batch_size', 2)
        self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
        self.learning_rate = config.get('learning_rate', 2e-4)
        self.num_epochs = config.get('num_epochs', 1)
        self.max_steps = config.get('max_steps', None)
        # NOTE(review): warmup_ratio is stored but setup_trainer only passes
        # warmup_steps to TrainingArguments.
        self.warmup_ratio = config.get('warmup_ratio', 0.1)
        self.warmup_steps = config.get('warmup_steps', 10)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.lr_scheduler_type = config.get('lr_scheduler_type', 'linear')
        self.seed = config.get('seed', 3407)

        # Output paths
        self.output_dir = config.get('output_dir', './outputs')
        self.model_output_dir = config.get('model_output_dir', './models/styling')

    def load_model_and_tokenizer(self):
        """Load the pre-trained model named by ``config['model_name']``.

        Raises:
            Exception: loader errors (including a missing ``model_name`` key)
                are logged and re-raised.
        """
        logger.info("Loading model and tokenizer...")
        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config['model_name'],
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token
            )
            logger.info(f"✅ Model loaded: {self.config['model_name']}")
            logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise

    def setup_lora(self):
        """Wrap ``self.model`` with LoRA adapters using the configured settings."""
        logger.info("Setting up LoRA configuration...")
        try:
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.lora_r,
                target_modules=self.target_modules,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=self.seed,
                use_rslora=False,
                loftq_config=None
            )
            logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")
        except Exception as e:
            logger.error(f"❌ Error setting up LoRA: {e}")
            raise

    def _load_jsonl_splits(self, data_dir: Path) -> Dataset:
        """Build one dataset from the ``*.jsonl`` split files found in ``data_dir``."""
        # NOTE(review): validation.jsonl and test.jsonl are merged into the
        # training data here, so no held-out split remains — confirm that this
        # is intended.
        logger.info("Loading from processed data files...")
        all_data = []
        for split_file in ("train.jsonl", "validation.jsonl", "test.jsonl"):
            file_path = data_dir / split_file
            if not file_path.exists():
                continue
            logger.info(f"Loading {split_file}...")
            with open(file_path, 'r', encoding='utf-8') as f:
                all_data.extend(json.loads(line) for line in f if line.strip())
        if not all_data:
            raise ValueError(f"No data found in {data_dir}")
        dataset = Dataset.from_list(all_data)
        logger.info(f"Created HuggingFace dataset from {len(all_data)} samples")
        return dataset

    def load_dataset(self, dataset_path: str) -> Dataset:
        """Load the training dataset.

        Resolution order: a saved HuggingFace dataset directory (detected by
        ``dataset_info.json``), a directory of JSONL split files, and finally
        the HuggingFace Hub when ``dataset_path`` does not exist locally.

        Returns:
            A dataset that contains the ``instruction``, ``input`` and
            ``output`` fields.

        Raises:
            ValueError: when no data is found or a required field is missing.
        """
        logger.info(f"Loading dataset from: {dataset_path}")
        try:
            data_dir = Path(dataset_path)
            if data_dir.exists():
                if (data_dir / "dataset_info.json").exists():
                    dataset = load_from_disk(dataset_path)
                    logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
                else:
                    dataset = self._load_jsonl_splits(data_dir)
            else:
                # The Hub loader is the module-level ``datasets.load_dataset``;
                # the ``Dataset`` class has no ``load_dataset`` method.
                from datasets import load_dataset as hf_load_dataset
                logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
                dataset = hf_load_dataset(dataset_path, split="train")
                logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")

            logger.info(f"Dataset loaded: {len(dataset)} samples")
            logger.info(f"Dataset features: {dataset.features}")

            # Verify required fields exist
            required_fields = ["instruction", "input", "output"]
            missing_fields = [field for field in required_fields if field not in dataset.features]
            if missing_fields:
                raise ValueError(f"Missing required fields in dataset: {missing_fields}")
            return dataset
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise

    def setup_trainer(self, train_dataset: Dataset):
        """Format ``train_dataset`` into prompt text and build ``self.trainer``."""
        print("Setting up SFTTrainer...")
        try:
            prompt_template = self.ALPACA_PROMPT
            eos_token = self.tokenizer.eos_token

            def formatting_prompts_func(examples):
                # EOS is appended so the model learns where a response ends.
                texts = [
                    prompt_template.format(instruction, input_text, output) + eos_token
                    for instruction, input_text, output in zip(
                        examples["instruction"], examples["input"], examples["output"]
                    )
                ]
                return {"text": texts}

            # Apply the formatting function to create the text field
            print("Mapping dataset to create text field with EOS token...")
            formatted_dataset = train_dataset.map(formatting_prompts_func, batched=True, remove_columns=train_dataset.column_names)
            print(f"Dataset mapped successfully. New features: {formatted_dataset.features}")
            print(f"Sample text field: {formatted_dataset[0]['text'][:100]}...")

            # Debug output to identify parameter issues
            print("Training parameters for TrainingArguments:")
            print(f" batch_size: {self.batch_size} (type: {type(self.batch_size)})")
            print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps} (type: {type(self.gradient_accumulation_steps)})")
            print(f" warmup_steps: {self.warmup_steps} (type: {type(self.warmup_steps)})")
            print(f" num_epochs: {self.num_epochs} (type: {type(self.num_epochs)})")
            print(f" max_steps: {self.max_steps} (type: {type(self.max_steps)})")
            print(f" learning_rate: {self.learning_rate} (type: {type(self.learning_rate)})")
            print(f" weight_decay: {self.weight_decay} (type: {type(self.weight_decay)})")
            print(f" seed: {self.seed} (type: {type(self.seed)})")

            print("Creating TrainingArguments...")
            training_args = TrainingArguments(
                per_device_train_batch_size=self.batch_size,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                warmup_steps=self.warmup_steps,
                num_train_epochs=self.num_epochs,
                # A positive max_steps takes precedence over num_train_epochs.
                max_steps=self.max_steps if self.max_steps is not None else 60,
                learning_rate=self.learning_rate,
                fp16=not is_bfloat16_supported(),
                bf16=is_bfloat16_supported(),
                logging_steps=1,
                optim="adamw_8bit",
                weight_decay=self.weight_decay,
                lr_scheduler_type=self.lr_scheduler_type,
                seed=self.seed,
                output_dir=self.output_dir,
                report_to="none",  # Disable wandb for now
            )
            print("TrainingArguments created successfully!")

            print("SFTTrainer parameters:")
            print(f" model: {type(self.model)}")
            print(f" tokenizer: {type(self.tokenizer)}")
            print(f" train_dataset: {type(formatted_dataset)} with {len(formatted_dataset)} samples")
            print(f" dataset_text_field: text")
            print(f" max_seq_length: {self.max_seq_length} (type: {type(self.max_seq_length)})")
            print(f" dataset_num_proc: 2")
            print(f" packing: False")
            print(f" args: {type(training_args)}")

            print("Creating SFTTrainer...")
            self.trainer = SFTTrainer(
                model=self.model,
                tokenizer=self.tokenizer,
                train_dataset=formatted_dataset,
                dataset_text_field="text",
                max_seq_length=int(self.max_seq_length) if self.max_seq_length is not None else 2048,
                dataset_num_proc=2,
                packing=False,
                args=training_args
            )
            print("SFTTrainer configured successfully")
        except Exception as e:
            print(f"Error setting up trainer: {e}")
            import traceback
            print("Full error traceback:")
            traceback.print_exc()
            raise

    def train(self, dataset_path: str):
        """Run the full pipeline: load model, set up LoRA, train and save.

        Returns:
            The value returned by ``self.trainer.train()``.
        """
        print("🚀 Starting training process...")
        try:
            print("Loading model and tokenizer...")
            self.load_model_and_tokenizer()

            print("Setting up LoRA...")
            self.setup_lora()

            print(f"Loading dataset from: {dataset_path}")
            train_dataset = self.load_dataset(dataset_path)

            print("Setting up trainer...")
            self.setup_trainer(train_dataset)

            print("Starting training...")
            trainer_stats = self.trainer.train()
            print("✅ Training completed successfully!")
            print(f"Training stats: {trainer_stats}")

            self.save_model()
            return trainer_stats
        except Exception as e:
            print(f"❌ Training failed: {e}")
            import traceback
            print("Full error traceback:")
            traceback.print_exc()
            raise

    def save_model(self):
        """Save model, tokenizer and the training config to ``model_output_dir``.

        Secret entries (see ``_SECRET_CONFIG_KEYS``) are left out of the
        persisted ``training_config.json``.
        """
        print("Saving trained model...")
        try:
            target_dir = Path(self.model_output_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            self.model.save_pretrained(self.model_output_dir)
            self.tokenizer.save_pretrained(self.model_output_dir)

            persisted_config = {
                key: value for key, value in self.config.items()
                if key not in self._SECRET_CONFIG_KEYS
            }
            with open(target_dir / "training_config.json", 'w', encoding='utf-8') as f:
                json.dump(persisted_config, f, indent=2)

            print(f"✅ Model saved to: {self.model_output_dir}")
            print(f"✅ You can now use this model for inference with: --config {self.model_output_dir}")
        except Exception as e:
            print(f"❌ Error saving model: {e}")
            raise

    def prepare_for_inference(self):
        """Switch ``self.model`` into Unsloth's inference mode."""
        logger.info("Preparing model for inference...")
        try:
            FastLanguageModel.for_inference(self.model)
            logger.info("✅ Model prepared for inference")
        except Exception as e:
            logger.error(f"❌ Error preparing for inference: {e}")
            raise
def load_training_config(yaml_path: str) -> Dict[str, Any]:
    """Load the training configuration from a YAML file.

    Built-in LoRA and output-path defaults are applied first; values from the
    ``model``, ``training`` and ``data`` sections then override them, so a
    ``model_output_dir`` set in the YAML file is respected.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        A flat dict suitable for ``StylingTrainer``.

    Raises:
        ValueError: if the YAML document is not a mapping at the top level.
        Exception: any I/O, parsing or conversion error, logged and re-raised.
    """
    def _cast(raw, caster, default):
        # An explicit ``null`` in YAML means "use the default".
        return default if raw is None else caster(raw)

    try:
        with open(yaml_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(f"Top-level YAML content must be a mapping: {yaml_path}")

        # Defaults first, so the YAML sections below can override them.
        training_config = {
            'lora_r': 16,
            'lora_alpha': 16,
            'lora_dropout': 0,
            'target_modules': [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ],
            'output_dir': './outputs',
            'model_output_dir': './models/styling'
        }

        # Model configuration - extract from model section
        if 'model' in config:
            model_config = config['model'] or {}
            training_config.update({
                'model_name': model_config.get('name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': _cast(model_config.get('max_seq_length'), int, 2048),
                'dtype': model_config.get('dtype', None),
                'load_in_4bit': model_config.get('load_in_4bit', True),
                'hf_token': model_config.get('token', None)
            })

        # Training configuration - extract from training section
        if 'training' in config:
            training_data = config['training'] or {}
            print("DEBUG: Training data from YAML:")
            for key in ('num_epochs', 'batch_size', 'learning_rate', 'weight_decay',
                        'warmup_steps', 'max_steps', 'gradient_accumulation_steps',
                        'seed', 'model_output_dir'):
                raw = training_data.get(key)
                print(f" {key}: {raw} (type: {type(raw)})")
            training_config.update({
                'num_epochs': _cast(training_data.get('num_epochs'), int, 1),
                'batch_size': _cast(training_data.get('batch_size'), int, 2),
                'learning_rate': _cast(training_data.get('learning_rate'), float, 2e-4),
                'weight_decay': _cast(training_data.get('weight_decay'), float, 0.01),
                'warmup_steps': _cast(training_data.get('warmup_steps'), int, 5),
                'max_steps': _cast(training_data.get('max_steps'), int, 60),
                'gradient_accumulation_steps': _cast(training_data.get('gradient_accumulation_steps'), int, 4),
                'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'),
                'seed': _cast(training_data.get('seed'), int, 3407),
                'model_output_dir': training_data.get('model_output_dir', './models/styling')
            })

        # Data configuration - use output_dir from data section
        if 'data' in config:
            data_config = config['data'] or {}
            output_dir = data_config.get('output_dir', './data/processed/styling')
            training_config.update({
                'data_output_dir': output_dir,
                'dataset_path': output_dir,  # Default dataset path is the output_dir
                'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style')
            })

        print("DEBUG: Final training_config:")
        for key, value in training_config.items():
            # Never echo the access token to the console
            shown = "***" if key == 'hf_token' and value else value
            print(f" {key}: {shown} (type: {type(value)})")
        return training_config
    except Exception as e:
        logger.error(f"Error loading training config: {e}")
        raise
def main():
    """Command-line entry point for styling training.

    Loads the YAML configuration, applies CLI overrides, resolves the dataset
    path (``--dataset`` wins over the YAML value) and runs training. Exits
    with status 1 when no dataset path is available or training fails.
    """
    parser = argparse.ArgumentParser(description="Styling Training Pipeline")
    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--dataset", type=str, help="Path to training dataset (HF dataset path or local path)")
    parser.add_argument("--output-dir", type=str, help="Output directory for model")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-steps", type=int, help="Maximum training steps")
    args = parser.parse_args()

    # Without any logging configuration the INFO messages below are dropped.
    # basicConfig does nothing when the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    try:
        # Load configuration
        logger.info(f"Loading configuration from: {args.config}")
        training_config = load_training_config(args.config)

        # Override with CLI arguments (only those actually supplied)
        if args.output_dir:
            training_config['model_output_dir'] = args.output_dir
        if args.epochs is not None:
            training_config['num_epochs'] = int(args.epochs)
        if args.batch_size is not None:
            training_config['batch_size'] = int(args.batch_size)
        if args.learning_rate is not None:
            training_config['learning_rate'] = float(args.learning_rate)
        if args.max_steps is not None:
            training_config['max_steps'] = int(args.max_steps)

        # Determine dataset path: CLI argument takes precedence, then YAML config
        dataset_path = args.dataset or training_config.get('dataset_path')
        if not dataset_path:
            logger.error("No dataset path provided. Use --dataset or ensure output_dir is set in YAML config.")
            sys.exit(1)

        logger.info("Training configuration:")
        for key, value in training_config.items():
            # Never write the access token to the log
            shown = "***" if key == 'hf_token' and value else value
            logger.info(f" {key}: {shown}")
        logger.info(f" Dataset path: {dataset_path}")

        # Initialize trainer
        trainer = StylingTrainer(training_config)

        # Start training
        trainer.train(dataset_path)
        logger.info("Training completed successfully!")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
+45 -21
View File
@@ -910,28 +910,49 @@ class StylingDataPipeline:
def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
"""Load and preprocess data"""
# Load data
if config.data_source == "huggingface":
raw_splits = self.hf_loader.load(config)
processed_splits = self.hf_loader.preprocess(raw_splits, config)
elif config.data_source == "custom":
raw_splits = self.custom_loader.load(config)
processed_splits = self.custom_loader.preprocess(raw_splits, config)
else:
raise ValueError(f"Unsupported data source: {config.data_source}")
logger.info(f"Starting data loading and preprocessing...")
logger.info(f"Data source: {config.data_source}")
# Validate processed data
is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
if not is_valid:
logger.error("Data validation failed:")
for error in errors:
logger.error(f" - {error}")
raise ValueError("Data validation failed")
# Analyze dataset
analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
return processed_splits, analysis
try:
# Load data
if config.data_source == "huggingface":
logger.info("Loading HuggingFace dataset...")
raw_splits = self.hf_loader.load(config)
logger.info("Preprocessing HuggingFace dataset...")
processed_splits = self.hf_loader.preprocess(raw_splits, config)
elif config.data_source == "custom":
logger.info("Loading custom dataset...")
raw_splits = self.custom_loader.load(config)
logger.info("Preprocessing custom dataset...")
processed_splits = self.custom_loader.preprocess(raw_splits, config)
else:
raise ValueError(f"Unsupported data source: {config.data_source}")
logger.info(f"Data loading and preprocessing completed successfully")
logger.info(f"Raw splits: {list(raw_splits.keys())}")
logger.info(f"Processed splits: {list(processed_splits.keys())}")
# Validate processed data
logger.info("Validating processed data...")
is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
if not is_valid:
logger.error("Data validation failed:")
for error in errors:
logger.error(f" - {error}")
raise ValueError("Data validation failed")
logger.info("Data validation passed")
# Analyze dataset
logger.info("Analyzing dataset...")
analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
logger.info("Dataset analysis completed")
return processed_splits, analysis
except Exception as e:
logger.error(f"Error in load_and_preprocess: {e}")
raise
def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
"""Convert styling data to Alpaca format with instruction"""
@@ -1481,6 +1502,9 @@ def main():
except Exception as e:
print(f"❌ Error running pipeline: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
sys.exit(1)
+134 -161
View File
@@ -7,7 +7,6 @@ Supports style transfer inference with streaming and batch processing
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
@@ -16,17 +15,12 @@ import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
from utils.config.config_manager import ConfigManager
from utils.logging.logging import setup_logging
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer
logger = logging.getLogger(__name__)
class StylingInference:
"""Styling task inference using trained models"""
@@ -37,10 +31,10 @@ class StylingInference:
# Set device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
print(f"Using device: {self.device}")
# Model parameters
self.model_path = config.get('model_path')
self.model_output_dir = config.get('model_output_dir', './models/styling')
self.max_seq_length = config.get('max_seq_length', 2048)
self.dtype = config.get('dtype', None)
self.load_in_4bit = config.get('load_in_4bit', True)
@@ -70,96 +64,123 @@ class StylingInference:
def load_model_and_tokenizer(self):
"""Load the trained model and tokenizer"""
logger.info("Loading model and tokenizer...")
print("Loading trained model and tokenizer...")
try:
if self.model_path and Path(self.model_path).exists():
# Load local trained model
logger.info(f"Loading local model from: {self.model_path}")
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=self.model_path,
max_seq_length=self.max_seq_length,
dtype=self.dtype,
load_in_4bit=self.load_in_4bit,
token=self.hf_token
)
else:
# Load base model from HuggingFace Hub
logger.info(f"Loading base model: {self.config.get('base_model_name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit')}")
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=self.config.get('base_model_name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
max_seq_length=self.max_seq_length,
dtype=self.dtype,
load_in_4bit=self.load_in_4bit,
token=self.hf_token
)
# Load the saved LoRA model
model_path = self.config.get('model_output_dir', './models/styling')
print(f"Loading model from: {model_path}")
# Prepare for inference
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=model_path,
max_seq_length=self.max_seq_length,
dtype=self.dtype,
load_in_4bit=self.load_in_4bit,
)
# Enable native 2x faster inference
FastLanguageModel.for_inference(self.model)
logger.info(f"✅ Model loaded successfully")
logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
print(f"✅ Model loaded from: {model_path}")
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
except Exception as e:
logger.error(f"❌ Error loading model: {e}")
print(f"❌ Error loading model: {e}")
raise
def format_prompt(self, instruction: str, input_text: str = "", output: str = "") -> str:
    """Format a prompt using the Alpaca template.

    Args:
        instruction: Text placed under ``### Instruction:``.
        input_text: Text placed under ``### Input:`` (may be empty).
        output: Text placed under ``### Response:``. Leave empty for
            inference so the model generates the response.

    Returns:
        The filled-in prompt string.
    """
    # Fallback template, reproduced from the literal previously hard-coded in
    # this method; used only when the instance has no ``alpaca_prompt``.
    default_template = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that follows the instruction\n"
        "### Instruction:\n"
        "{}\n"
        "### Input:\n"
        "{}\n"
        "### Response:\n"
        "{}"
    )
    # Prefer the instance's template so a configured prompt is not silently
    # ignored; fall back to the built-in one otherwise.
    template = getattr(self, "alpaca_prompt", None) or default_template
    return template.format(instruction, input_text, output)
def generate_text(self, instruction: str, input_text: str = "", max_new_tokens: Optional[int] = 128, stream: bool = False):
    """Generate text using the trained model.

    Args:
        instruction: Style instruction placed in the prompt.
        input_text: Text to be restyled (may be empty).
        max_new_tokens: Upper bound on generated tokens. ``None`` falls back
            to ``self.max_new_tokens`` (or 128 if that is not set).
        stream: When True, tokens are printed through a ``TextStreamer`` as
            they are produced and nothing is returned.

    Returns:
        The stripped generated response, or ``None`` in streaming mode.

    Raises:
        Exception: any tokenization/generation error is printed and re-raised.
    """
    try:
        # Format the prompt
        prompt = self.format_prompt(instruction, input_text)
        print(f"Formatted prompt: {prompt}")

        # Tokenize the input
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        if max_new_tokens is None:
            max_new_tokens = getattr(self, "max_new_tokens", 128)

        # One set of generation settings for both modes, taken from the
        # instance when available; fallbacks match the previously hard-coded
        # values so behaviour is unchanged when nothing is configured.
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": getattr(self, "do_sample", True),
            "temperature": getattr(self, "temperature", 0.7),
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        top_p = getattr(self, "top_p", None)
        if top_p is not None:
            gen_kwargs["top_p"] = top_p

        if stream:
            # Streaming generation: output is written by the streamer itself.
            from transformers import TextStreamer
            text_streamer = TextStreamer(self.tokenizer)
            print("Generating with streaming...")
            self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
            return None

        # Non-streaming generation
        print("Generating...")
        outputs = self.model.generate(**inputs, **gen_kwargs)

        # Keep only the newly generated tokens. Slicing by prompt length is
        # robust even when the input text itself contains "### Response:",
        # which a string search for that marker would mis-handle.
        prompt_length = len(inputs["input_ids"][0])
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"❌ Error generating text: {e}")
        raise
def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
    """Perform style transfer on a single input text.

    Args:
        input_text: Text to restyle.
        instruction: Style instruction; defaults to ``self.style_instruction``.
        streaming: When True, output is streamed via
            ``generate_text_streaming`` and an empty string is returned.

    Returns:
        The generated text, or ``""`` in streaming mode.

    Raises:
        Exception: any error from prompt formatting or generation is printed
            and re-raised.
    """
    try:
        # Use default instruction if none provided
        if instruction is None:
            instruction = self.style_instruction

        # Label corrected: this line shows the instruction, not the prompt.
        print(f"Style instruction: {instruction}")
        print(f"Input text: {input_text}")

        if streaming:
            # Only the streaming helper takes a pre-formatted prompt.
            prompt = self.format_prompt(instruction, input_text)
            print(f"Style transfer prompt: {prompt}")
            print("Generating with streaming...")
            self.generate_text_streaming(prompt)
            return ""

        # generate_text formats (and prints) the prompt itself, so it is not
        # formatted a second time here.
        print("Generating text...")
        result = self.generate_text(instruction, input_text, self.max_new_tokens)
        print(f"Generated result: {result}")
        return result
    except Exception as e:
        print(f"❌ Error in style transfer: {e}")
        raise
def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
"""Generate text with streaming output"""
@@ -185,14 +206,14 @@ class StylingInference:
_ = self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
except Exception as e:
logger.error(f"❌ Error in streaming generation: {e}")
print(f"❌ Error in streaming generation: {e}")
def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
"""Perform style transfer on multiple input texts"""
results = []
for i, input_text in enumerate(input_texts):
logger.info(f"Processing text {i+1}/{len(input_texts)}")
print(f"Processing text {i+1}/{len(input_texts)}")
result = self.style_transfer(input_text, instruction)
results.append(result)
@@ -218,6 +239,13 @@ def load_inference_config(config_path: str) -> Dict[str, Any]:
'hf_token': model_data.get('training_token')
})
# Training configuration - to get model_output_dir
if 'training' in config:
training_data = config['training']
inference_config.update({
'model_output_dir': training_data.get('model_output_dir', './models/styling')
})
# Inference configuration
if 'inference' in config:
inference_data = config['inference']
@@ -237,7 +265,7 @@ def load_inference_config(config_path: str) -> Dict[str, Any]:
return inference_config
except Exception as e:
logger.error(f"Error loading inference config: {e}")
print(f"Error loading inference config: {e}")
raise
def main():
@@ -246,100 +274,45 @@ def main():
# Configuration
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
parser.add_argument("--model-path", type=str, help="Path to trained model (optional, uses base model if not provided)")
# Inference modes
parser.add_argument("--text", type=str, help="Single text to style transfer")
parser.add_argument("--input-file", type=str, help="File containing texts to process (one per line)")
# Generation parameters
parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
parser.add_argument("--temperature", type=float, help="Sampling temperature")
parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
parser.add_argument("--instruction", type=str, help="Custom style instruction")
# Output
parser.add_argument("--output-file", type=str, help="Output file for results")
parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
parser.add_argument("--input-text", type=str, default="", help="Input text to style")
parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
args = parser.parse_args()
# Setup logging
setup_logging()
try:
# Load configuration
logger.info(f"Loading configuration from: {args.config}")
print(f"Loading configuration from: {args.config}")
inference_config = load_inference_config(args.config)
# Override with CLI arguments
if args.model_path:
inference_config['model_path'] = args.model_path
if args.max_tokens:
inference_config['max_new_tokens'] = args.max_tokens
if args.temperature:
inference_config['temperature'] = args.temperature
if args.instruction:
inference_config['style_instruction'] = args.instruction
logger.info("Inference configuration:")
print("Inference configuration:")
for key, value in inference_config.items():
logger.info(f" {key}: {value}")
print(f" {key}: {value}")
# Initialize inference
inferencer = StylingInference(inference_config)
# Load model
inferencer.load_model_and_tokenizer()
# Run inference based on mode
if args.text:
# Single text inference
logger.info("Running single text inference...")
result = inferencer.style_transfer(args.text, args.instruction, args.streaming)
if not args.streaming:
print(f"\nGenerated text: {result}")
elif args.input_file:
# Batch file inference
logger.info("Running batch file inference...")
with open(args.input_file, 'r', encoding='utf-8') as f:
input_texts = [line.strip() for line in f if line.strip()]
results = inferencer.batch_style_transfer(input_texts, args.instruction)
# Save results
output_file = args.output_file or f"{Path(args.input_file).stem}_styled.txt"
with open(output_file, 'w', encoding='utf-8') as f:
for input_text, result in zip(input_texts, results):
f.write(f"Input: {input_text}\n")
f.write(f"Output: {result}\n")
f.write("-" * 50 + "\n")
logger.info(f"✅ Results saved to: {output_file}")
inference = StylingInference(inference_config)
# Load model and tokenizer
inference.load_model_and_tokenizer()
# Run inference
if args.stream:
print("Running streaming inference...")
inference.generate_text(args.instruction, args.input_text, args.max_tokens, stream=True)
else:
# Interactive mode
logger.info("Entering interactive mode. Type 'quit' to exit.")
while True:
try:
user_input = input("\nEnter text to style (or 'quit'): ").strip()
if user_input.lower() == 'quit':
break
if user_input:
result = inferencer.style_transfer(user_input, args.instruction, args.streaming)
if not args.streaming:
print(f"\nStyled text: {result}")
except KeyboardInterrupt:
break
except Exception as e:
logger.error(f"Error processing input: {e}")
logger.info("🎉 Inference completed successfully!")
print("Running inference...")
result = inference.generate_text(args.instruction, args.input_text, args.max_tokens)
print(f"✅ Generated text: {result}")
except Exception as e:
logger.error(f"Inference failed: {e}")
print(f"Inference failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
+96 -49
View File
@@ -62,7 +62,11 @@ class StylingTrainer:
self.learning_rate = config.get('learning_rate', 2e-4)
self.num_epochs = config.get('num_epochs', 1)
self.max_steps = config.get('max_steps', None)
self.warmup_steps = config.get('warmup_steps', 5)
self.warmup_ratio = config.get('warmup_ratio', 0.1)
# Set a default warmup_steps value
self.warmup_steps = config.get('warmup_steps', 10)
self.weight_decay = config.get('weight_decay', 0.01)
self.seed = config.get('seed', 3407)
@@ -174,7 +178,7 @@ class StylingTrainer:
def setup_trainer(self, train_dataset: Dataset):
"""Setup the SFTTrainer"""
logger.info("Setting up SFTTrainer...")
print("Setting up SFTTrainer...")
try:
# First, map the dataset to create the text field with EOS token
@@ -202,19 +206,31 @@ class StylingTrainer:
return {"text": texts}
# Apply the formatting function to create the text field
logger.info("Mapping dataset to create text field with EOS token...")
print("Mapping dataset to create text field with EOS token...")
formatted_dataset = train_dataset.map(formatting_prompts_func, batched=True, remove_columns=train_dataset.column_names)
logger.info(f"Dataset mapped successfully. New features: {formatted_dataset.features}")
logger.info(f"Sample text field: {formatted_dataset[0]['text'][:100]}...")
print(f"Dataset mapped successfully. New features: {formatted_dataset.features}")
print(f"Sample text field: {formatted_dataset[0]['text'][:100]}...")
# Training arguments
# Debug logging to identify parameter issues
print("Training parameters for TrainingArguments:")
print(f" batch_size: {self.batch_size} (type: {type(self.batch_size)})")
print(f" gradient_accumulation_steps: {self.gradient_accumulation_steps} (type: {type(self.gradient_accumulation_steps)})")
print(f" warmup_steps: {self.warmup_steps} (type: {type(self.warmup_steps)})")
print(f" num_epochs: {self.num_epochs} (type: {type(self.num_epochs)})")
print(f" max_steps: {self.max_steps} (type: {type(self.max_steps)})")
print(f" learning_rate: {self.learning_rate} (type: {type(self.learning_rate)})")
print(f" weight_decay: {self.weight_decay} (type: {type(self.weight_decay)})")
print(f" seed: {self.seed} (type: {type(self.seed)})")
print("Creating TrainingArguments...")
# Training arguments - using the exact working configuration
training_args = TrainingArguments(
per_device_train_batch_size=self.batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
warmup_steps=self.warmup_steps,
num_train_epochs=self.num_epochs,
max_steps=self.max_steps,
max_steps=self.max_steps if self.max_steps is not None else 60, # Use default if None
learning_rate=self.learning_rate,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
@@ -225,55 +241,68 @@ class StylingTrainer:
seed=self.seed,
output_dir=self.output_dir,
report_to="none", # Disable wandb for now
save_strategy="epoch",
save_total_limit=2,
evaluation_strategy="no", # No validation for now
load_best_model_at_end=False,
remove_unused_columns=False,
dataloader_pin_memory=False,
)
print("TrainingArguments created successfully!")
print("SFTTrainer parameters:")
print(f" model: {type(self.model)}")
print(f" tokenizer: {type(self.tokenizer)}")
print(f" train_dataset: {type(formatted_dataset)} with {len(formatted_dataset)} samples")
print(f" dataset_text_field: text")
print(f" max_seq_length: {self.max_seq_length} (type: {type(self.max_seq_length)})")
print(f" dataset_num_proc: 2")
print(f" packing: False")
print(f" args: {type(training_args)}")
print("Creating SFTTrainer...")
# Create trainer with the formatted dataset
self.trainer = SFTTrainer(
model=self.model,
tokenizer=self.tokenizer,
train_dataset=formatted_dataset, # Use the formatted dataset
dataset_text_field="text", # The field we just created
max_seq_length=self.max_seq_length,
max_seq_length=int(self.max_seq_length) if self.max_seq_length is not None else 2048,
dataset_num_proc=2,
packing=False, # Can make training 5x faster for short sequences
args=training_args
)
logger.info("SFTTrainer configured successfully")
print("SFTTrainer configured successfully")
except Exception as e:
logger.error(f"Error setting up trainer: {e}")
print(f"Error setting up trainer: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
raise
def train(self, dataset_path: str):
"""Run the training process"""
logger.info("🚀 Starting training process...")
print("🚀 Starting training process...")
try:
# Load model and tokenizer
print("Loading model and tokenizer...")
self.load_model_and_tokenizer()
# Setup LoRA
print("Setting up LoRA...")
self.setup_lora()
# Load dataset
print(f"Loading dataset from: {dataset_path}")
train_dataset = self.load_dataset(dataset_path)
# Setup trainer
print("Setting up trainer...")
self.setup_trainer(train_dataset)
# Start training
logger.info("Starting training...")
print("Starting training...")
trainer_stats = self.trainer.train()
logger.info("✅ Training completed successfully!")
logger.info(f"Training stats: {trainer_stats}")
print("✅ Training completed successfully!")
print(f"Training stats: {trainer_stats}")
# Save the model
self.save_model()
@@ -281,12 +310,15 @@ class StylingTrainer:
return trainer_stats
except Exception as e:
logger.error(f"❌ Training failed: {e}")
print(f"❌ Training failed: {e}")
import traceback
print("Full error traceback:")
traceback.print_exc()
raise
def save_model(self):
"""Save the trained model"""
logger.info("Saving trained model...")
print("Saving trained model...")
try:
# Create output directory
@@ -301,10 +333,11 @@ class StylingTrainer:
with open(config_path, 'w') as f:
json.dump(self.config, f, indent=2)
logger.info(f"✅ Model saved to: {self.model_output_dir}")
print(f"✅ Model saved to: {self.model_output_dir}")
print(f"✅ You can now use this model for inference with: --config {self.model_output_dir}")
except Exception as e:
logger.error(f"❌ Error saving model: {e}")
print(f"❌ Error saving model: {e}")
raise
def prepare_for_inference(self):
@@ -319,36 +352,50 @@ class StylingTrainer:
logger.error(f"❌ Error preparing for inference: {e}")
raise
def load_training_config(config_path: str) -> Dict[str, Any]:
def load_training_config(yaml_path: str) -> Dict[str, Any]:
"""Load training configuration from YAML file"""
try:
with open(config_path, 'r', encoding='utf-8') as f:
with open(yaml_path, 'r') as f:
config = yaml.safe_load(f)
# Extract training configuration
training_config = {}
# Model configuration
# Model configuration - extract from model section
if 'model' in config:
model_data = config['model']
model_config = config['model']
training_config.update({
'model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
'max_seq_length': model_data.get('training_max_seq_length', 2048),
'dtype': model_data.get('training_dtype'),
'load_in_4bit': model_data.get('training_load_in_4bit', True),
'hf_token': model_data.get('training_token')
'model_name': model_config.get('name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
'max_seq_length': int(model_config.get('max_seq_length', 2048)),
'dtype': model_config.get('dtype', None),
'load_in_4bit': model_config.get('load_in_4bit', True),
'hf_token': model_config.get('token', None)
})
# Training configuration
# Training configuration - extract from training section
if 'training' in config:
training_data = config['training']
print("DEBUG: Training data from YAML:")
print(f" num_epochs: {training_data.get('num_epochs')} (type: {type(training_data.get('num_epochs'))})")
print(f" batch_size: {training_data.get('batch_size')} (type: {type(training_data.get('batch_size'))})")
print(f" learning_rate: {training_data.get('learning_rate')} (type: {type(training_data.get('learning_rate'))})")
print(f" weight_decay: {training_data.get('weight_decay')} (type: {type(training_data.get('weight_decay'))})")
print(f" warmup_steps: {training_data.get('warmup_steps')} (type: {type(training_data.get('warmup_steps'))})")
print(f" max_steps: {training_data.get('max_steps')} (type: {type(training_data.get('max_steps'))})")
print(f" gradient_accumulation_steps: {training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})")
print(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})")
print(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})")
training_config.update({
'num_epochs': training_data.get('num_epochs', 3),
'batch_size': training_data.get('batch_size', 2),
'learning_rate': training_data.get('learning_rate', 2e-4),
'weight_decay': training_data.get('weight_decay', 0.01),
'warmup_ratio': training_data.get('warmup_ratio', 0.1),
'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear')
'num_epochs': int(training_data.get('num_epochs', 1)),
'batch_size': int(training_data.get('batch_size', 2)),
'learning_rate': float(training_data.get('learning_rate', 2e-4)),
'weight_decay': float(training_data.get('weight_decay', 0.01)),
'warmup_steps': int(training_data.get('warmup_steps', 5)),
'max_steps': int(training_data.get('max_steps', 60)),
'gradient_accumulation_steps': int(training_data.get('gradient_accumulation_steps', 4)),
'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'),
'seed': int(training_data.get('seed', 3407)),
'model_output_dir': training_data.get('model_output_dir', './models/styling')
})
# Data configuration - use output_dir from data section
@@ -370,14 +417,14 @@ def load_training_config(config_path: str) -> Dict[str, Any]:
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
'gradient_accumulation_steps': 4,
'max_steps': None,
'warmup_steps': 5,
'seed': 3407,
'output_dir': './outputs',
'model_output_dir': './models/styling'
})
print("DEBUG: Final training_config:")
for key, value in training_config.items():
print(f" {key}: {value} (type: {type(value)})")
return training_config
except Exception as e:
@@ -411,13 +458,13 @@ def main():
if args.output_dir:
training_config['model_output_dir'] = args.output_dir
if args.epochs:
training_config['num_epochs'] = args.epochs
training_config['num_epochs'] = int(args.epochs)
if args.batch_size:
training_config['batch_size'] = args.batch_size
training_config['batch_size'] = int(args.batch_size)
if args.learning_rate:
training_config['learning_rate'] = args.learning_rate
training_config['learning_rate'] = float(args.learning_rate)
if args.max_steps:
training_config['max_steps'] = args.max_steps
training_config['max_steps'] = int(args.max_steps)
# Determine dataset path: CLI argument takes precedence, then YAML config
dataset_path = args.dataset or training_config.get('dataset_path')