"""Fine-tune a sequence-classification model using Hugging Face Accelerate.

The script reads ``train.jsonl`` / ``validation.jsonl`` / ``test.jsonl`` from a
data directory, builds a label mapping from the training split, trains the
model, reports validation accuracy after every epoch and saves the model,
tokenizer and label mapping to the output directory.

Configuration comes from ``SimpleConfig`` defaults, optionally overridden by a
YAML file (``--config``), which is in turn overridden by CLI flags.
"""

import argparse
import json
import logging
import sys
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Dict, List, Optional

import evaluate
import torch
import yaml
from accelerate import Accelerator
from datasets import Dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler,
)

logger = logging.getLogger(__name__)


@dataclass
class SimpleConfig:
    """Simple configuration for accelerate training"""

    # Model settings
    model_name: str = "bert-base-uncased"
    max_length: int = 512

    # Training settings
    num_epochs: int = 3
    batch_size: int = 16
    learning_rate: float = 2e-5
    weight_decay: float = 0.01

    # Scheduler settings
    lr_scheduler_type: str = "linear"
    warmup_ratio: float = 0.1

    # Paths
    data_dir: str = "./data/classification"
    output_dir: str = "./results"


class AccelerateTrainer:
    """Simple trainer using Accelerate for distributed training"""

    def __init__(self, config: SimpleConfig):
        """Store the configuration and create the ``Accelerator``.

        Args:
            config: Training configuration.
        """
        self.config = config

        # Initialize accelerator
        self.accelerator = Accelerator()

        # Setup logging only on main process. This is a no-op when the root
        # logger was already configured (e.g. by ``main``).
        if self.accelerator.is_main_process:
            logging.basicConfig(
                format='%(asctime)s - %(levelname)s - %(message)s',
                level=logging.INFO,
            )

        self.tokenizer = None
        self.model = None
        self.label_to_id = {}
        self.id_to_label = {}
        self.num_labels = 0

    def load_data(self) -> Dict[str, List[Dict]]:
        """Load data from JSONL files.

        Reads ``train.jsonl``, ``validation.jsonl`` and ``test.jsonl`` from
        ``config.data_dir``; splits whose file does not exist are omitted.
        Blank lines are skipped.

        Returns:
            Mapping from split name to the list of decoded JSON records.
        """
        data_path = Path(self.config.data_dir)
        splits = {}

        for split_name in ["train", "validation", "test"]:
            split_file = data_path / f"{split_name}.jsonl"
            if not split_file.exists():
                continue

            with open(split_file, 'r', encoding='utf-8') as f:
                split_data = [json.loads(line) for line in f if line.strip()]

            splits[split_name] = split_data
            if self.accelerator.is_main_process:
                logger.info(f"Loaded {len(split_data)} samples from {split_name}")

        return splits

    def setup_labels(self, train_data: List[Dict]):
        """Setup label mappings.

        Labels are normalised with ``str()`` and sorted, so ids are assigned
        in lexicographic order of the string form.

        Args:
            train_data: Training records; each must contain a ``"label"`` key.
        """
        sorted_labels = sorted({str(item["label"]) for item in train_data})

        self.label_to_id = {label: idx for idx, label in enumerate(sorted_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
        self.num_labels = len(sorted_labels)

        if self.accelerator.is_main_process:
            logger.info(f"Found {self.num_labels} labels: {sorted_labels}")

    def _label_index(self, label: Any) -> Optional[int]:
        """Return the class id for *label*, or ``None`` if it is unknown."""
        # setup_labels() keys the mapping by str(label), so the lookup must
        # use the same normalisation for string and non-string labels alike.
        key = str(label)
        if key in self.label_to_id:
            return self.label_to_id[key]
        # No mapping has been built: trust labels that are already numeric.
        if not self.label_to_id and not isinstance(label, str):
            return int(label)
        return None

    def create_dataset(self, data: List[Dict]) -> Dataset:
        """Create tokenized dataset.

        Args:
            data: Records with a ``"label"`` key and the input under
                ``"text"`` (or ``"input"`` when ``"text"`` is absent).

        Returns:
            A ``Dataset`` holding the tokenizer outputs and a ``labels``
            column. Sequences are truncated to ``config.max_length`` but not
            padded; padding is left to the batch collator.
        """
        texts = []
        labels = []
        unknown_count = 0

        for item in data:
            text = item["text"] if "text" in item else item["input"]
            texts.append(str(text))

            label_id = self._label_index(item["label"])
            if label_id is None:
                # Best effort: keep the sample, fall back to class 0.
                unknown_count += 1
                label_id = 0
            labels.append(label_id)

        if unknown_count and self.accelerator.is_main_process:
            logger.warning(
                f"{unknown_count} samples have labels missing from the label "
                "mapping; they were assigned class id 0"
            )

        # Create dataset
        dataset = Dataset.from_dict({
            "text": texts,
            "labels": labels,
        })

        # Tokenize without padding: DataCollatorWithPadding pads each batch
        # to its own longest sequence, so padding every sample to max_length
        # here would only waste compute.
        def tokenize_function(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.config.max_length,
            )

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
        )

        return tokenized_dataset

    def compute_metrics(self, predictions, labels):
        """Compute accuracy and F1.

        Must be called on every process because it gathers across processes.

        Args:
            predictions: Logits tensor; the arg-max over the last axis is
                used as the predicted class.
            labels: Tensor of reference class ids.

        Returns:
            The metric dictionary on the main process, ``{}`` elsewhere.
        """
        predictions = predictions.argmax(axis=-1)

        # Gather predictions from all processes
        all_predictions = self.accelerator.gather_for_metrics(predictions)
        all_labels = self.accelerator.gather_for_metrics(labels)

        if not self.accelerator.is_main_process:
            return {}

        # Only compute metrics on main process
        # NOTE(review): the GLUE/MRPC metric is used as an example; confirm it
        # is appropriate when there are more than two classes.
        metric = evaluate.load("glue", "mrpc")  # Using MRPC as example
        return metric.compute(
            predictions=all_predictions.cpu().numpy(),
            references=all_labels.cpu().numpy(),
        )

    def train(self):
        """Main training function.

        Loads the data, builds the label mapping, tokenizer, model, data
        loaders, optimizer and scheduler, runs the training loop with an
        evaluation pass after each epoch (when validation data exists) and
        finally saves the model.

        Raises:
            ValueError: If the training split is missing or empty.
        """
        is_main = self.accelerator.is_main_process
        if is_main:
            logger.info("=== Starting Accelerate Training ===")

        # Load data
        splits_data = self.load_data()
        if not splits_data.get("train"):
            # Covers both a missing train.jsonl and one without any records.
            raise ValueError("No training data found!")

        # Setup labels
        self.setup_labels(splits_data["train"])

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.config.model_name,
            num_labels=self.num_labels,
            id2label=self.id_to_label,
            label2id=self.label_to_id,
        )
        # Keep the model's padding id in sync with the tokenizer when the
        # checkpoint does not define one.
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        # Create datasets
        train_dataset = self.create_dataset(splits_data["train"])
        eval_dataset = None
        if splits_data.get("validation"):
            eval_dataset = self.create_dataset(splits_data["validation"])

        # Create data loaders
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=self.config.batch_size,
            collate_fn=data_collator,
        )

        eval_dataloader = None
        if eval_dataset is not None:
            eval_dataloader = DataLoader(
                eval_dataset,
                batch_size=self.config.batch_size,
                collate_fn=data_collator,
            )

        # Setup optimizer and scheduler
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

        # The schedule length is derived from the dataloader *before*
        # accelerator.prepare(), as in the original code.
        # NOTE(review): presumably the prepared scheduler compensates for
        # per-process sharding - confirm against the Accelerate docs.
        num_training_steps = self.config.num_epochs * len(train_dataloader)
        num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)

        lr_scheduler = get_scheduler(
            self.config.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

        # Prepare everything with accelerator
        self.model, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
            self.model, optimizer, train_dataloader, lr_scheduler
        )

        if eval_dataloader is not None:
            eval_dataloader = self.accelerator.prepare(eval_dataloader)

        # The prepared dataloader may be shorter than the original one, so the
        # progress bar counts the steps this process will actually run.
        local_steps = self.config.num_epochs * len(train_dataloader)
        progress_bar = tqdm(total=local_steps, disable=not is_main)
        if is_main:
            logger.info(f"Training steps: {local_steps}")

        # Training loop
        self.model.train()
        for epoch in range(self.config.num_epochs):
            if is_main:
                logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

            for step, batch in enumerate(train_dataloader):
                outputs = self.model(**batch)
                loss = outputs.loss

                # Backward pass
                self.accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                progress_bar.update(1)

                # Log every 100 steps
                if is_main and step % 100 == 0:
                    logger.info(f"Step {step}, Loss: {loss.item():.4f}")

            # Evaluation at end of each epoch
            if eval_dataloader is not None:
                self.evaluate(eval_dataloader, epoch)

        progress_bar.close()

        # Save model
        self.save_model()

        if is_main:
            logger.info("=== Training Completed ===")

    def evaluate(self, eval_dataloader, epoch) -> Optional[float]:
        """Evaluation function.

        Runs the model over ``eval_dataloader`` without gradients, gathers
        predicted class ids and labels from all processes and logs the
        accuracy on the main process. The model is switched back to training
        mode afterwards.

        Args:
            eval_dataloader: Prepared dataloader yielding batches that
                include a ``"labels"`` entry.
            epoch: Zero-based epoch index, used only for the log message.

        Returns:
            The accuracy on the main process when at least one batch was
            evaluated, otherwise ``None``.
        """
        self.model.eval()

        all_predictions = []
        all_labels = []

        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = self.model(**batch)

            # Reduce logits to class ids before gathering so only one value
            # per sample is exchanged between processes.
            predictions = outputs.logits.argmax(dim=-1)
            labels = batch["labels"]

            # Gather from all processes
            predictions = self.accelerator.gather_for_metrics(predictions)
            labels = self.accelerator.gather_for_metrics(labels)

            all_predictions.append(predictions.cpu())
            all_labels.append(labels.cpu())

        accuracy = None
        if self.accelerator.is_main_process and all_predictions:
            predictions_np = torch.cat(all_predictions).numpy()
            labels_np = torch.cat(all_labels).numpy()

            # Simple accuracy calculation
            accuracy = float((predictions_np == labels_np).mean())
            logger.info(f"Epoch {epoch + 1} - Validation Accuracy: {accuracy:.4f}")

        self.model.train()
        return accuracy

    def save_model(self):
        """Save model and tokenizer.

        Writes the unwrapped model, the tokenizer and ``label_info.json`` to
        ``config.output_dir`` from the main process only.
        """
        # Make sure every process has finished before the main one writes.
        self.accelerator.wait_for_everyone()

        if not self.accelerator.is_main_process:
            return

        output_path = Path(self.config.output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model using accelerator
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        # Save label info (integer keys of id_to_label become JSON strings)
        label_info = {
            "label_to_id": self.label_to_id,
            "id_to_label": self.id_to_label,
            "num_labels": self.num_labels,
        }

        with open(output_path / "label_info.json", 'w', encoding='utf-8') as f:
            json.dump(label_info, f, indent=2)

        logger.info(f"Model saved to {output_path}")


def _load_yaml_config(path: str) -> Dict[str, Any]:
    """Read a YAML mapping from *path*; an empty file yields ``{}``."""
    with open(path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)

    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"top-level YAML value must be a mapping, got {type(loaded).__name__}"
        )
    return loaded


def _collect_cli_overrides(args: argparse.Namespace) -> Dict[str, Any]:
    """Return the config options that were explicitly given on the CLI."""
    # Compare against None rather than truthiness so that explicit zero
    # values such as ``--weight-decay 0`` are honoured.
    overrides = {}
    for config_field in fields(SimpleConfig):
        value = getattr(args, config_field.name, None)
        if value is not None:
            overrides[config_field.name] = value
    return overrides


def _build_config(config_dict: Dict[str, Any]) -> SimpleConfig:
    """Build a ``SimpleConfig``, coercing values to the type of each default."""
    kwargs = {}
    for config_field in fields(SimpleConfig):
        raw = config_dict.get(config_field.name)
        if raw is None:
            continue  # keep the dataclass default

        # YAML scalars may arrive with an unexpected type (e.g. a number
        # written as ``2e-5`` can be loaded as a string), so cast explicitly.
        target = type(config_field.default)
        try:
            if target is int and isinstance(raw, float) and not raw.is_integer():
                raise ValueError("not an integer")
            kwargs[config_field.name] = target(raw)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid value for '{config_field.name}': {raw!r} "
                f"(expected {target.__name__})"
            ) from err
    return SimpleConfig(**kwargs)


def main():
    """Main function with YAML configuration support.

    Precedence is CLI flag > YAML file > ``SimpleConfig`` default. Exits with
    status 1 when the configuration cannot be loaded, the data directory or
    ``train.jsonl`` is missing, or training raises.
    """
    parser = argparse.ArgumentParser(description="Accelerate Training Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Model settings
    parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")

    # Training settings
    parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")

    # Scheduler settings
    parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"],
                        help="Learning rate scheduler type")
    parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")

    # Paths
    parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
    parser.add_argument("--output-dir", type=str, help="Output directory for saved model")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            config_dict = _load_yaml_config(args.config)
        except (OSError, yaml.YAMLError, ValueError) as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)
        logger.info(f"Loaded YAML configuration from: {args.config}")

    known_keys = {config_field.name for config_field in fields(SimpleConfig)}
    ignored_keys = [key for key in config_dict if key not in known_keys]
    if ignored_keys:
        logger.warning(f"Ignoring unknown configuration keys: {ignored_keys}")

    # Override YAML config with CLI arguments
    for key, value in _collect_cli_overrides(args).items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Create configuration object
    try:
        config = _build_config(config_dict)
    except ValueError as e:
        logger.error(f"Invalid configuration: {e}")
        sys.exit(1)

    # Validate data directory
    data_path = Path(config.data_dir)
    if not data_path.exists():
        print(f"❌ Data directory not found: {data_path}")
        print("Please ensure the data directory exists and contains train.jsonl, validation.jsonl, and test.jsonl files")
        sys.exit(1)

    # Check for required data files
    required_files = ["train.jsonl"]
    missing_files = [
        file_name for file_name in required_files
        if not (data_path / file_name).exists()
    ]

    if missing_files:
        print(f"❌ Missing required data files: {missing_files}")
        print(f"Please ensure these files exist in: {data_path}")
        sys.exit(1)

    # Initialize and run training
    try:
        print(f"Starting training with model: {config.model_name}")
        print(f"Data directory: {config.data_dir}")
        print(f"Output directory: {config.output_dir}")
        print(f"Training for {config.num_epochs} epochs with batch size {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        trainer = AccelerateTrainer(config)
        trainer.train()

        print("✅ Training completed successfully!")
        print(f"Model saved to: {config.output_dir}")

    except Exception as e:
        # Top-level boundary: keep the traceback in the log before exiting.
        logger.exception("Training failed")
        print(f"❌ Error during training: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()