Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/classification/train.py
T
OwusuBlessing fef3f5ae35 initial setupt
2025-08-06 22:45:37 +01:00

468 lines
16 KiB
Python

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
get_scheduler
)
from accelerate import Accelerator
from datasets import Dataset
from tqdm.auto import tqdm
import evaluate
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional
import logging
import argparse
import sys
import yaml
# Module-level logger shared by AccelerateTrainer and the main() CLI entry point.
logger = logging.getLogger(__name__)
@dataclass
class SimpleConfig:
    """Simple configuration for accelerate training.

    Numeric fields are coerced to their annotated types in ``__post_init__``.
    This matters for YAML input: PyYAML (YAML 1.1) parses exponent-only
    literals such as ``2e-5`` as *strings*, which would otherwise reach the
    optimizer unconverted.

    Raises:
        ValueError: if a numeric field cannot be converted, or holds a value
            that cannot produce a valid training run.
    """

    # Model settings
    model_name: str = "bert-base-uncased"  # Hub id or local checkpoint path
    max_length: int = 512  # tokenizer truncation length (tokens)
    # Training settings
    num_epochs: int = 3
    batch_size: int = 16
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    # Scheduler settings
    lr_scheduler_type: str = "linear"  # name passed to transformers.get_scheduler
    warmup_ratio: float = 0.1  # fraction of total training steps used for warmup
    # Paths
    data_dir: str = "./data/classification"
    output_dir: str = "./results"

    def __post_init__(self):
        # Coerce numeric fields so string values (e.g. from YAML) are accepted.
        for name in ("max_length", "num_epochs", "batch_size"):
            setattr(self, name, int(getattr(self, name)))
        for name in ("learning_rate", "weight_decay", "warmup_ratio"):
            setattr(self, name, float(getattr(self, name)))

        # Fail fast on values that would otherwise error much later in training.
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.num_epochs < 0:
            raise ValueError(f"num_epochs must be >= 0, got {self.num_epochs}")
        if not 0.0 <= self.warmup_ratio <= 1.0:
            raise ValueError(f"warmup_ratio must be in [0, 1], got {self.warmup_ratio}")
class AccelerateTrainer:
    """Simple trainer using Accelerate for distributed training.

    Reads ``train``/``validation``/``test`` JSONL splits from
    ``config.data_dir``, builds the label vocabulary from the training split,
    fine-tunes a sequence-classification model with AdamW and a
    ``transformers`` LR scheduler, evaluates on the validation split after
    every epoch (when present), and saves the model, tokenizer and label
    mapping to ``config.output_dir``.
    """

    def __init__(self, config: SimpleConfig):
        self.config = config
        self.accelerator = Accelerator()
        # Logging is configured here only on the main process.
        if self.accelerator.is_main_process:
            logging.basicConfig(
                format='%(asctime)s - %(levelname)s - %(message)s',
                level=logging.INFO
            )
        self.tokenizer = None
        self.model = None
        self.label_to_id: Dict[str, int] = {}
        self.id_to_label: Dict[int, str] = {}
        self.num_labels = 0

    def load_data(self) -> Dict[str, List[Dict]]:
        """Load every existing ``<split>.jsonl`` file from ``config.data_dir``.

        Blank lines are skipped. Missing split files are simply absent from the
        returned dict.

        Raises:
            ValueError: if a line is not valid JSON (includes file and line number).
        """
        data_path = Path(self.config.data_dir)
        splits = {}
        for split_name in ["train", "validation", "test"]:
            split_file = data_path / f"{split_name}.jsonl"
            if not split_file.exists():
                continue
            split_data = []
            with open(split_file, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        split_data.append(json.loads(line))
                    except json.JSONDecodeError as err:
                        raise ValueError(
                            f"Invalid JSON in {split_file} at line {line_no}: {err}"
                        ) from err
            splits[split_name] = split_data
            if self.accelerator.is_main_process:
                logger.info(f"Loaded {len(split_data)} samples from {split_name}")
        return splits

    @staticmethod
    def _ordered_labels(labels) -> List[str]:
        """Sort label strings numerically when all parse as ints, else lexically.

        Numeric ordering makes integer labels 0..N-1 map to ids 0..N-1 (with
        plain string sorting "10" would come before "2"). The tuple key keeps
        the order deterministic for aliases such as "1" and "01".
        """
        try:
            return sorted(labels, key=lambda s: (int(s), s))
        except ValueError:
            return sorted(labels)

    def setup_labels(self, train_data: List[Dict]):
        """Build ``label_to_id`` / ``id_to_label`` from the training labels.

        Labels are keyed by ``str(label)``, so integer and string labels share
        one vocabulary.
        """
        sorted_labels = self._ordered_labels({str(item["label"]) for item in train_data})
        self.label_to_id = {label: idx for idx, label in enumerate(sorted_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
        self.num_labels = len(sorted_labels)
        if self.accelerator.is_main_process:
            logger.info(f"Found {self.num_labels} labels: {sorted_labels}")

    def _encode_label(self, label) -> int:
        """Map a raw label to its training id via the ``label_to_id`` vocabulary.

        Always going through the vocabulary keeps training ids consistent with
        the ``id2label``/``label2id`` mapping stored in the model config.

        Raises:
            ValueError: if the label was not seen in the training split.
        """
        try:
            return self.label_to_id[str(label)]
        except KeyError:
            raise ValueError(
                f"Label {label!r} does not appear in the training split; "
                f"known labels: {list(self.label_to_id)}"
            ) from None

    def create_dataset(self, data: List[Dict]) -> Dataset:
        """Create a tokenized ``datasets.Dataset`` with a ``labels`` column.

        Each record's text is taken from ``"text"``, falling back to
        ``"input"``. Sequences are truncated to ``config.max_length`` but not
        padded; ``DataCollatorWithPadding`` pads each batch to its longest
        sequence.
        """
        texts = []
        labels = []
        for item in data:
            text = item["text"] if "text" in item else item["input"]
            texts.append(str(text))
            labels.append(self._encode_label(item["label"]))

        dataset = Dataset.from_dict({"text": texts, "labels": labels})

        def tokenize_function(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.config.max_length,
            )

        return dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    def compute_metrics(self, predictions, labels):
        """Compute metrics over predictions gathered from all processes.

        For binary tasks this uses the GLUE/MRPC metric (accuracy + F1). MRPC's
        F1 is binary-only, so multi-class tasks report accuracy only. Non-main
        processes return ``{}``.
        """
        predictions = predictions.argmax(axis=-1)
        # Gathering is a collective call: every process must reach it.
        all_predictions = self.accelerator.gather_for_metrics(predictions)
        all_labels = self.accelerator.gather_for_metrics(labels)
        if not self.accelerator.is_main_process:
            return {}
        predictions_np = all_predictions.cpu().numpy()
        labels_np = all_labels.cpu().numpy()
        if self.num_labels == 2:
            metric = evaluate.load("glue", "mrpc")
            return metric.compute(predictions=predictions_np, references=labels_np)
        return {"accuracy": float((predictions_np == labels_np).mean())}

    def _init_model_and_tokenizer(self):
        """Load the tokenizer and classification model for ``config.model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        if self.tokenizer.pad_token is None:
            # Fall back to EOS for checkpoints that define no pad token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.config.model_name,
            num_labels=self.num_labels,
            id2label=self.id_to_label,
            label2id=self.label_to_id,
        )
        if self.model.config.pad_token_id is None:
            # Keep the model config in sync with the tokenizer: some
            # classification heads (e.g. GPT-2's) need pad_token_id to handle
            # batches larger than one.
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def _build_dataloaders(self, splits_data):
        """Tokenize the train/validation splits and wrap them in DataLoaders.

        Returns ``(train_dataloader, eval_dataloader)``. ``eval_dataloader`` is
        ``None`` when the validation split is missing or empty.
        """
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
        train_dataloader = DataLoader(
            self.create_dataset(splits_data["train"]),
            shuffle=True,
            batch_size=self.config.batch_size,
            collate_fn=data_collator,
        )
        eval_dataloader = None
        if "validation" in splits_data:
            eval_dataset = self.create_dataset(splits_data["validation"])
            if len(eval_dataset) > 0:
                eval_dataloader = DataLoader(
                    eval_dataset,
                    batch_size=self.config.batch_size,
                    collate_fn=data_collator,
                )
        return train_dataloader, eval_dataloader

    def train(self):
        """Run the full pipeline: load data, fine-tune, evaluate per epoch, save.

        Must be called on every process, because it contains collective
        operations (``prepare``, metric gathering, and the barrier in
        ``save_model``).

        Raises:
            ValueError: if the training split is missing or empty.
        """
        if self.accelerator.is_main_process:
            logger.info("=== Starting Accelerate Training ===")

        splits_data = self.load_data()
        if not splits_data.get("train"):
            raise ValueError("No training data found!")

        self.setup_labels(splits_data["train"])
        self._init_model_and_tokenizer()
        train_dataloader, eval_dataloader = self._build_dataloaders(splits_data)

        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
        # The schedule is sized from the un-sharded dataloader (before
        # prepare()). Accelerate's prepared scheduler advances the underlying
        # schedule once per process on each call (when batches are not split),
        # as in Accelerate's own examples.
        num_training_steps = self.config.num_epochs * len(train_dataloader)
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
        lr_scheduler = get_scheduler(
            self.config.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

        self.model, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
            self.model, optimizer, train_dataloader, lr_scheduler
        )
        if eval_dataloader is not None:
            eval_dataloader = self.accelerator.prepare(eval_dataloader)

        # After prepare() each process iterates only its own shard, so the
        # progress bar counts per-process steps.
        steps_per_process = self.config.num_epochs * len(train_dataloader)
        progress_bar = tqdm(
            range(steps_per_process), disable=not self.accelerator.is_main_process
        )
        if self.accelerator.is_main_process:
            logger.info(
                f"Training steps: {num_training_steps} ({steps_per_process} per process)"
            )

        self.model.train()
        for epoch in range(self.config.num_epochs):
            if self.accelerator.is_main_process:
                logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
            for step, batch in enumerate(train_dataloader):
                outputs = self.model(**batch)
                loss = outputs.loss
                self.accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                if self.accelerator.is_main_process and step % 100 == 0:
                    logger.info(f"Step {step}, Loss: {loss.item():.4f}")
            if eval_dataloader is not None:
                self.evaluate(eval_dataloader, epoch)
        progress_bar.close()

        self.save_model()
        if self.accelerator.is_main_process:
            logger.info("=== Training Completed ===")

    def evaluate(self, eval_dataloader, epoch):
        """Log validation accuracy for ``epoch`` and switch back to train mode.

        Predicted class ids and labels are gathered from all processes.
        Accuracy is computed and logged only on the main process.
        """
        self.model.eval()
        all_predictions = []
        all_labels = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = self.model(**batch)
            # Reduce logits to class ids before gathering to shrink the transfer.
            predictions = self.accelerator.gather_for_metrics(outputs.logits.argmax(dim=-1))
            labels = self.accelerator.gather_for_metrics(batch["labels"])
            all_predictions.append(predictions.cpu())
            all_labels.append(labels.cpu())

        if self.accelerator.is_main_process and all_predictions:
            predictions_np = torch.cat(all_predictions).numpy()
            labels_np = torch.cat(all_labels).numpy()
            accuracy = (predictions_np == labels_np).mean()
            logger.info(f"Epoch {epoch + 1} - Validation Accuracy: {accuracy:.4f}")
        self.model.train()

    def save_model(self):
        """Save model, tokenizer and ``label_info.json`` to ``config.output_dir``.

        Must be called on every process: it first waits for all processes, and
        then only the main process writes files.
        """
        self.accelerator.wait_for_everyone()
        if not self.accelerator.is_main_process:
            return
        output_path = Path(self.config.output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Strip the distributed wrapper added by prepare() before saving.
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)
        label_info = {
            "label_to_id": self.label_to_id,
            "id_to_label": self.id_to_label,
            "num_labels": self.num_labels,
        }
        with open(output_path / "label_info.json", 'w', encoding='utf-8') as f:
            json.dump(label_info, f, indent=2)
        logger.info(f"Model saved to {output_path}")
def main():
    """Main function with YAML configuration support.

    Builds a ``SimpleConfig`` with this precedence (highest first): CLI
    flags, then the optional YAML file, then ``SimpleConfig`` defaults. It
    checks the data directory and then runs ``AccelerateTrainer``. Exits with
    status 1 on configuration, data-layout or training errors.
    """
    from dataclasses import fields  # only needed to enumerate config keys

    parser = argparse.ArgumentParser(description="Accelerate Training Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Model settings
    parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
    # Training settings
    parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")
    # Scheduler settings
    parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"], help="Learning rate scheduler type")
    parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")
    # Paths
    parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
    parser.add_argument("--output-dir", type=str, help="Output directory for saved model")
    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)
        # safe_load returns None for an empty file; any other non-mapping
        # document cannot be merged with the CLI overrides below.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            logger.error(
                "Error loading YAML config: expected a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
            sys.exit(1)
        config_dict = loaded
        logger.info(f"Loaded YAML configuration from: {args.config}")

    # Override YAML config with CLI arguments. argparse derives each dest from
    # its flag (--model-name -> model_name), matching the SimpleConfig fields.
    field_names = [f.name for f in fields(SimpleConfig)]
    for key in field_names:
        value = getattr(args, key, None)
        # Compare with None, not truthiness, so explicit zeros such as
        # --weight-decay 0 or --warmup-ratio 0 still take effect.
        if value is None:
            continue
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    unknown_keys = sorted(str(k) for k in config_dict if k not in field_names)
    if unknown_keys:
        logger.warning(f"Ignoring unknown configuration keys: {unknown_keys}")
    try:
        # Fields absent from config_dict take SimpleConfig's own defaults.
        config = SimpleConfig(**{k: v for k, v in config_dict.items() if k in field_names})
    except (TypeError, ValueError) as e:
        logger.error(f"Invalid configuration: {e}")
        sys.exit(1)

    # Validate data directory
    data_path = Path(config.data_dir)
    if not data_path.exists():
        print(f"❌ Data directory not found: {data_path}")
        print("Please ensure the data directory exists and contains train.jsonl, validation.jsonl, and test.jsonl files")
        sys.exit(1)

    # Only the training split is mandatory; validation/test are optional.
    required_files = ["train.jsonl"]
    missing_files = [name for name in required_files if not (data_path / name).exists()]
    if missing_files:
        print(f"❌ Missing required data files: {missing_files}")
        print(f"Please ensure these files exist in: {data_path}")
        sys.exit(1)

    try:
        print(f"Starting training with model: {config.model_name}")
        print(f"Data directory: {config.data_dir}")
        print(f"Output directory: {config.output_dir}")
        print(f"Training for {config.num_epochs} epochs with batch size {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()
        trainer = AccelerateTrainer(config)
        trainer.train()
        print("✅ Training completed successfully!")
        print(f"Model saved to: {config.output_dir}")
    except Exception as e:
        # Log the traceback for debugging; the print is the user-facing summary.
        logger.exception("Training failed")
        print(f"❌ Error during training: {e}")
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()