# File-viewer metadata (not part of the program): 468 lines, 16 KiB, Python
import torch
|
|
from torch.optim import AdamW
|
|
from torch.utils.data import DataLoader
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForSequenceClassification,
|
|
DataCollatorWithPadding,
|
|
get_scheduler
|
|
)
|
|
from accelerate import Accelerator
|
|
from datasets import Dataset
|
|
from tqdm.auto import tqdm
|
|
import evaluate
|
|
import json
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional
|
|
import logging
|
|
import argparse
|
|
import sys
|
|
import yaml
|
|
|
|
# Module-level logger shared by the trainer and the CLI entry point; its
# handlers/level are set by the logging.basicConfig() calls made further down.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class SimpleConfig:
    """Hyper-parameters and paths consumed by ``AccelerateTrainer``.

    Numeric fields are coerced to their declared types on construction, so
    values that arrive as strings (PyYAML, for instance, loads ``2e-5`` as a
    string because it has no decimal point) still work.

    Raises:
        ValueError: if a numeric field cannot be converted or is out of range.
    """

    # Model settings
    model_name: str = "bert-base-uncased"  # Hub id or local path of the checkpoint
    max_length: int = 512                  # truncation length passed to the tokenizer

    # Training settings
    num_epochs: int = 3
    batch_size: int = 16
    learning_rate: float = 2e-5
    weight_decay: float = 0.01

    # Scheduler settings
    lr_scheduler_type: str = "linear"
    warmup_ratio: float = 0.1              # fraction of total steps used for warm-up

    # Paths
    data_dir: str = "./data/classification"  # holds train/validation/test .jsonl files
    output_dir: str = "./results"            # where model, tokenizer and labels are saved

    def __post_init__(self):
        """Coerce numeric fields and fail fast on values that cannot train."""
        numeric_fields = (
            ("max_length", int),
            ("num_epochs", int),
            ("batch_size", int),
            ("learning_rate", float),
            ("weight_decay", float),
            ("warmup_ratio", float),
        )
        for name, caster in numeric_fields:
            raw = getattr(self, name)
            try:
                # int() would silently truncate 2.5 -> 2; reject it instead.
                if caster is int and isinstance(raw, float) and not raw.is_integer():
                    raise ValueError("not a whole number")
                value = caster(raw)
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"{name} must be a valid {caster.__name__}, got {raw!r}"
                ) from err
            setattr(self, name, value)

        # The `not (x >= bound)` form also rejects NaN.
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.num_epochs < 0:
            raise ValueError(f"num_epochs must be >= 0, got {self.num_epochs}")
        if not self.learning_rate >= 0:
            raise ValueError(f"learning_rate must be >= 0, got {self.learning_rate}")
        if not self.weight_decay >= 0:
            raise ValueError(f"weight_decay must be >= 0, got {self.weight_decay}")
        if not 0 <= self.warmup_ratio <= 1:
            raise ValueError(f"warmup_ratio must be within [0, 1], got {self.warmup_ratio}")
|
|
|
|
|
|
class AccelerateTrainer:
    """Fine-tune a sequence-classification model with Hugging Face Accelerate.

    The trainer reads ``train`` / ``validation`` / ``test`` JSONL splits from
    ``config.data_dir``, builds the label vocabulary from the training split,
    tokenizes the text, runs a plain training loop (with per-epoch validation
    accuracy when a validation split exists) and saves the model, tokenizer
    and label mappings to ``config.output_dir``. The ``test`` split is loaded
    but not used by ``train``.
    """

    def __init__(self, config: SimpleConfig):
        """Store *config*, create the ``Accelerator`` and reset trainer state."""
        self.config = config

        # Initialize accelerator
        self.accelerator = Accelerator()

        # Setup logging only on main process
        if self.accelerator.is_main_process:
            logging.basicConfig(
                format='%(asctime)s - %(levelname)s - %(message)s',
                level=logging.INFO
            )

        # Populated by train() / setup_labels().
        self.tokenizer = None
        self.model = None
        self.label_to_id = {}
        self.id_to_label = {}
        self.num_labels = 0

    def load_data(self) -> Dict[str, List[Dict]]:
        """Load every available split from ``<data_dir>/<split>.jsonl``.

        Blank lines are skipped; missing split files are simply absent from
        the result.

        Returns:
            Mapping of split name to the list of decoded JSON records.

        Raises:
            ValueError: if a non-blank line is not valid JSON (the message
                carries the file name and line number).
        """
        data_path = Path(self.config.data_dir)
        splits = {}

        for split_name in ("train", "validation", "test"):
            split_file = data_path / f"{split_name}.jsonl"
            if not split_file.exists():
                continue

            records = []
            with open(split_file, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError as err:
                        raise ValueError(
                            f"{split_file}:{line_no}: invalid JSON ({err})"
                        ) from err
            splits[split_name] = records

            if self.accelerator.is_main_process:
                logger.info(f"Loaded {len(records)} samples from {split_name}")

        return splits

    def setup_labels(self, train_data: List[Dict]):
        """Build ``label_to_id`` / ``id_to_label`` from the training records.

        Labels are stringified and sorted lexicographically, so ids are
        deterministic for a given label set.
        """
        sorted_labels = sorted({str(item["label"]) for item in train_data})
        self.label_to_id = {label: idx for idx, label in enumerate(sorted_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
        self.num_labels = len(sorted_labels)

        if self.accelerator.is_main_process:
            logger.info(f"Found {self.num_labels} labels: {sorted_labels}")

    def create_dataset(self, data: List[Dict]) -> Dataset:
        """Turn raw records into a tokenized ``Dataset`` with a ``labels`` column.

        Each record supplies its text under ``"text"`` (or ``"input"`` as a
        fallback) and its class under ``"label"``. Labels are looked up by
        their string form, the same key ``setup_labels`` uses, so integer and
        string labels map to the same ids. Labels missing from the vocabulary
        fall back to id 0 and are reported in a warning. Sequences are
        truncated but not padded here; padding is left to the batch collator.
        """
        texts = []
        labels = []
        unseen = set()

        for item in data:
            text = item["text"] if "text" in item else item["input"]
            raw_label = item["label"]
            key = str(raw_label)

            if key in self.label_to_id:
                label_id = self.label_to_id[key]
            elif not self.label_to_id and not isinstance(raw_label, str):
                # No vocabulary has been built: treat the value as an id.
                label_id = int(raw_label)
            else:
                unseen.add(key)
                label_id = 0

            texts.append(str(text))
            labels.append(label_id)

        if unseen and self.accelerator.is_main_process:
            logger.warning(
                f"{len(unseen)} label(s) not present in the training vocabulary "
                f"were mapped to id 0: {sorted(unseen)}"
            )

        dataset = Dataset.from_dict({
            "text": texts,
            "labels": labels
        })

        def tokenize_function(examples):
            # No padding here: DataCollatorWithPadding pads each batch to its
            # own longest sequence, which avoids padding everything to
            # max_length.
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.config.max_length
            )

        return dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"]
        )

    def compute_metrics(self, predictions, labels):
        """Gather predictions/labels across processes and score them.

        Args:
            predictions: per-class scores; the arg-max over the last axis is
                taken as the predicted class.
            labels: reference class ids.

        Returns:
            The metric result dict on the main process, ``{}`` elsewhere.

        NOTE(review): the metric is hard-coded to GLUE/MRPC and is loaded on
        every call; confirm it suits the label set before wiring this in
        (``train`` does not call this method).
        """
        predictions = predictions.argmax(axis=-1)

        # Gather predictions from all processes
        all_predictions = self.accelerator.gather_for_metrics(predictions)
        all_labels = self.accelerator.gather_for_metrics(labels)

        if self.accelerator.is_main_process:
            # Only compute metrics on main process
            metric = evaluate.load("glue", "mrpc")  # Using MRPC as example
            return metric.compute(
                predictions=all_predictions.cpu().numpy(),
                references=all_labels.cpu().numpy()
            )
        return {}

    def train(self):
        """Run the full pipeline: load data, train, validate per epoch, save.

        Raises:
            ValueError: if the training split is missing or empty.
        """
        is_main = self.accelerator.is_main_process
        if is_main:
            logger.info("=== Starting Accelerate Training ===")

        # Load data
        splits_data = self.load_data()
        if not splits_data.get("train"):
            raise ValueError("No training data found!")

        # Setup labels
        self.setup_labels(splits_data["train"])

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.config.model_name,
            num_labels=self.num_labels,
            id2label=self.id_to_label,
            label2id=self.label_to_id
        )
        # Keep the model in sync with the tokenizer when the pad token had to
        # be borrowed from EOS; without this, models that ship no pad token
        # cannot handle padded batches.
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        # Create datasets and data loaders
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        train_dataset = self.create_dataset(splits_data["train"])
        train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=self.config.batch_size,
            collate_fn=data_collator
        )

        eval_dataloader = None
        if splits_data.get("validation"):
            eval_dataset = self.create_dataset(splits_data["validation"])
            eval_dataloader = DataLoader(
                eval_dataset,
                batch_size=self.config.batch_size,
                collate_fn=data_collator
            )

        # Setup optimizer and scheduler
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        # Deliberately computed from the un-sharded dataloader: the scheduler
        # is passed through accelerator.prepare() below, matching the order
        # used in Accelerate's own examples.
        num_training_steps = self.config.num_epochs * len(train_dataloader)
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)

        lr_scheduler = get_scheduler(
            self.config.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        # Prepare everything with accelerator
        self.model, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
            self.model, optimizer, train_dataloader, lr_scheduler
        )
        if eval_dataloader is not None:
            eval_dataloader = self.accelerator.prepare(eval_dataloader)

        # Steps this process will actually run; the prepared dataloader may be
        # shorter than the original one when data is sharded across processes.
        local_steps = self.config.num_epochs * len(train_dataloader)
        progress_bar = tqdm(range(local_steps), disable=not is_main)
        if is_main:
            logger.info(f"Training steps: {local_steps}")

        # Training loop
        self.model.train()

        for epoch in range(self.config.num_epochs):
            if is_main:
                logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            for step, batch in enumerate(train_dataloader):
                outputs = self.model(**batch)
                loss = outputs.loss

                # Backward pass
                self.accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                progress_bar.update(1)

                # Log every 100 steps
                if is_main and step % 100 == 0:
                    logger.info(f"Step {step}, Loss: {loss.item():.4f}")

            # Evaluation at end of each epoch
            if eval_dataloader is not None:
                self.evaluate(eval_dataloader, epoch)

        progress_bar.close()

        # Save model
        self.save_model()

        if is_main:
            logger.info("=== Training Completed ===")

    def evaluate(self, eval_dataloader, epoch):
        """Log validation accuracy for *epoch* (0-based) on the main process.

        The model is switched to eval mode for the pass and back to train
        mode afterwards.
        """
        self.model.eval()

        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for batch in eval_dataloader:
                logits = self.model(**batch).logits

                # Reduce to class ids before gathering so only one integer per
                # sample is exchanged between processes.
                predictions, labels = self.accelerator.gather_for_metrics(
                    (logits.argmax(dim=-1), batch["labels"])
                )
                all_predictions.append(predictions.cpu())
                all_labels.append(labels.cpu())

        if self.accelerator.is_main_process and all_predictions:
            predictions_np = torch.cat(all_predictions).numpy()
            labels_np = torch.cat(all_labels).numpy()

            # Simple accuracy calculation
            accuracy = (predictions_np == labels_np).mean()
            logger.info(f"Epoch {epoch + 1} - Validation Accuracy: {accuracy:.4f}")

        self.model.train()

    def save_model(self):
        """Write model, tokenizer and ``label_info.json`` to ``output_dir``.

        Every process waits at a barrier first; only the main process writes.
        """
        # Make sure no process is still mid-step when the weights are read.
        self.accelerator.wait_for_everyone()
        if not self.accelerator.is_main_process:
            return

        output_path = Path(self.config.output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save the underlying model, not the distributed wrapper
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        # Save label info (JSON turns the integer ids in id_to_label into
        # string keys)
        label_info = {
            "label_to_id": self.label_to_id,
            "id_to_label": self.id_to_label,
            "num_labels": self.num_labels
        }
        with open(output_path / "label_info.json", 'w', encoding='utf-8') as f:
            json.dump(label_info, f, indent=2)

        logger.info(f"Model saved to {output_path}")
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Settings are resolved with the precedence CLI flag > YAML file
    (``--config``) > ``SimpleConfig`` default. The function then checks that
    the data directory and ``train.jsonl`` exist, runs ``AccelerateTrainer``
    and exits with status 1 on any configuration or training error.
    """
    # Local import: the module header only imports `dataclass`.
    from dataclasses import fields

    parser = argparse.ArgumentParser(description="Accelerate Training Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Model settings
    parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")

    # Training settings
    parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")

    # Scheduler settings
    parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"], help="Learning rate scheduler type")
    parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")

    # Paths
    parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
    parser.add_argument("--output-dir", type=str, help="Output directory for saved model")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

        # safe_load() returns None for an empty document.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading YAML config: expected a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
            sys.exit(1)
        config_dict = loaded
        logger.info(f"Loaded YAML configuration from: {args.config}")

    # Every SimpleConfig field has a CLI flag with the same dest name.
    config_fields = {f.name: f.type for f in fields(SimpleConfig)}

    unknown_keys = sorted(str(key) for key in config_dict if key not in config_fields)
    if unknown_keys:
        logger.warning(f"Ignoring unknown configuration keys: {unknown_keys}")

    # Override YAML config with CLI arguments. `is not None` (rather than
    # truthiness) lets explicit zeros such as `--weight-decay 0` through.
    for key in config_fields:
        value = getattr(args, key, None)
        if value is None:
            continue
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Create configuration object; keys that are absent or null keep the
    # dataclass default.
    try:
        config_kwargs = {}
        for key, caster in config_fields.items():
            value = config_dict.get(key)
            if value is None:
                continue
            # PyYAML loads e.g. `2e-5` as a string; cast to the declared type.
            config_kwargs[key] = caster(value) if callable(caster) else value
        config = SimpleConfig(**config_kwargs)
    except (TypeError, ValueError) as e:
        print(f"❌ Invalid configuration: {e}")
        sys.exit(1)

    # Validate data directory
    data_path = Path(config.data_dir)
    if not data_path.exists():
        print(f"❌ Data directory not found: {data_path}")
        print("Please ensure the data directory exists and contains train.jsonl, validation.jsonl, and test.jsonl files")
        sys.exit(1)

    # Check for required data files
    required_files = ["train.jsonl"]
    missing_files = [name for name in required_files if not (data_path / name).exists()]
    if missing_files:
        print(f"❌ Missing required data files: {missing_files}")
        print(f"Please ensure these files exist in: {data_path}")
        sys.exit(1)

    # Initialize and run training
    try:
        print(f"Starting training with model: {config.model_name}")
        print(f"Data directory: {config.data_dir}")
        print(f"Output directory: {config.output_dir}")
        print(f"Training for {config.num_epochs} epochs with batch size {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        trainer = AccelerateTrainer(config)
        trainer.train()

        print("✅ Training completed successfully!")
        print(f"Model saved to: {config.output_dir}")

    except Exception as e:
        # Top-level boundary: keep the traceback in the log, then exit.
        logger.exception("Training failed")
        print(f"❌ Error during training: {e}")
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|