initial setupt

This commit is contained in:
OwusuBlessing
2025-08-06 22:45:37 +01:00
commit fef3f5ae35
42 changed files with 7147 additions and 0 deletions
File diff suppressed because it is too large Load Diff
+481
View File
@@ -0,0 +1,481 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import json
import numpy as np
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
import logging
import argparse
import sys
import yaml
logger = logging.getLogger(__name__)
@dataclass
class InferenceConfig:
"""Simple configuration for inference"""
model_path: str # Path to saved model
device: str = "auto" # "auto", "cuda", "cpu"
batch_size: int = 32
max_length: int = 512
return_probabilities: bool = True
return_top_k: int = 1 # Return top K predictions
class ModelInference:
"""Simple inference class for text classification"""
def __init__(self, config: InferenceConfig):
self.config = config
self.model = None
self.tokenizer = None
self.label_info = {}
self.device = self._setup_device()
# Load model and tokenizer
self.load_model()
def _setup_device(self) -> torch.device:
"""Setup device for inference"""
if self.config.device == "auto":
if torch.cuda.is_available():
device = torch.device("cuda")
logger.info("Using CUDA device")
else:
device = torch.device("cpu")
logger.info("Using CPU device")
else:
device = torch.device(self.config.device)
logger.info(f"Using specified device: {device}")
return device
def load_model(self):
"""Load model, tokenizer, and label mappings"""
model_path = Path(self.config.model_path)
if not model_path.exists():
raise FileNotFoundError(f"Model path not found: {model_path}")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info(f"Loaded tokenizer from {model_path}")
# Load model
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
logger.info(f"Loaded model from {model_path}")
# Load label information
label_info_path = model_path / "label_info.json"
if label_info_path.exists():
with open(label_info_path, 'r') as f:
self.label_info = json.load(f)
logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
else:
logger.warning("No label_info.json found. Using default numeric labels.")
# Create default mappings
num_labels = self.model.config.num_labels
self.label_info = {
"id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
"label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
"num_labels": num_labels
}
def predict_single(self, text: str) -> Dict:
"""Predict single text sample"""
return self.predict_batch([text])[0]
def predict_batch(self, texts: List[str]) -> List[Dict]:
"""Predict batch of texts"""
if not texts:
return []
# Tokenize inputs
inputs = self.tokenizer(
texts,
truncation=True,
padding=True,
max_length=self.config.max_length,
return_tensors="pt"
)
# Move to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get predictions
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
# Convert to probabilities
probabilities = torch.softmax(logits, dim=-1)
# Process results
results = []
for i, text in enumerate(texts):
text_probs = probabilities[i].cpu().numpy()
# Get top K predictions
top_indices = np.argsort(text_probs)[-self.config.return_top_k:][::-1]
predictions = []
for idx in top_indices:
label = self.label_info["id_to_label"].get(str(idx), f"LABEL_{idx}")
prob = float(text_probs[idx])
pred_dict = {
"label": label,
"label_id": int(idx),
"score": prob
}
predictions.append(pred_dict)
result = {
"text": text,
"predictions": predictions,
"predicted_label": predictions[0]["label"],
"confidence": predictions[0]["score"]
}
if self.config.return_probabilities:
# Add all class probabilities
all_probs = {}
for label_id, label_name in self.label_info["id_to_label"].items():
all_probs[label_name] = float(text_probs[int(label_id)])
result["all_probabilities"] = all_probs
results.append(result)
return results
def predict_file(self, input_file: str, output_file: str = None) -> List[Dict]:
"""Predict on texts from file"""
input_path = Path(input_file)
# Read texts from file
texts = []
if input_path.suffix == '.txt':
# Plain text file (one text per line)
with open(input_path, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()]
elif input_path.suffix == '.jsonl':
# JSONL file with "text" field
with open(input_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
data = json.loads(line)
text = data.get("text", data.get("input", ""))
if text:
texts.append(text)
else:
raise ValueError(f"Unsupported file format: {input_path.suffix}")
logger.info(f"Loaded {len(texts)} texts from {input_file}")
# Process in batches
all_results = []
for i in range(0, len(texts), self.config.batch_size):
batch_texts = texts[i:i + self.config.batch_size]
batch_results = self.predict_batch(batch_texts)
all_results.extend(batch_results)
if i % (self.config.batch_size * 10) == 0:
logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")
# Save results if output file specified
if output_file:
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
for result in all_results:
f.write(json.dumps(result) + '\n')
logger.info(f"Results saved to {output_file}")
return all_results
class BatchInference:
"""Optimized batch inference for large datasets"""
def __init__(self, config: InferenceConfig):
self.inference = ModelInference(config)
self.config = config
def predict_large_file(self, input_file: str, output_file: str,
chunk_size: int = 1000) -> None:
"""Process large files in chunks to manage memory"""
input_path = Path(input_file)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Count total lines first
total_lines = 0
with open(input_path, 'r', encoding='utf-8') as f:
for _ in f:
total_lines += 1
logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")
processed = 0
with open(input_path, 'r', encoding='utf-8') as infile, \
open(output_path, 'w', encoding='utf-8') as outfile:
chunk_texts = []
for line in infile:
if line.strip():
if input_path.suffix == '.jsonl':
data = json.loads(line)
text = data.get("text", data.get("input", ""))
else:
text = line.strip()
chunk_texts.append(text)
if len(chunk_texts) >= chunk_size:
# Process chunk
results = self.inference.predict_batch(chunk_texts)
# Write results
for result in results:
outfile.write(json.dumps(result) + '\n')
processed += len(chunk_texts)
logger.info(f"Processed {processed}/{total_lines} texts")
chunk_texts = []
# Process remaining texts
if chunk_texts:
results = self.inference.predict_batch(chunk_texts)
for result in results:
outfile.write(json.dumps(result) + '\n')
processed += len(chunk_texts)
logger.info(f"Completed processing {processed} texts")
def create_inference_config(
model_path: str,
device: str = "auto",
batch_size: int = 32,
**kwargs
) -> InferenceConfig:
"""Helper function to create inference configuration"""
return InferenceConfig(
model_path=model_path,
device=device,
batch_size=batch_size,
**kwargs
)
def main():
"""Main function with YAML configuration support"""
parser = argparse.ArgumentParser(description="Model Inference Pipeline")
# YAML configuration
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
# Model settings
parser.add_argument("--model-path", type=str, help="Path to saved model directory")
parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
parser.add_argument("--batch-size", type=int, help="Batch size for inference")
parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
# Inference settings
parser.add_argument("--return-probabilities", action="store_true", help="Return all class probabilities")
parser.add_argument("--return-top-k", type=int, help="Return top K predictions")
# Input/Output settings
parser.add_argument("--input-text", type=str, help="Single text for prediction")
parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
parser.add_argument("--output-file", type=str, help="Output file path for results")
parser.add_argument("--chunk-size", type=int, help="Chunk size for large file processing")
# Logging
parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
args = parser.parse_args()
# Set up logging
logging.basicConfig(
level=getattr(logging, args.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Load configuration
config_dict = {}
# Load YAML config if provided
if args.config:
try:
with open(args.config, 'r', encoding='utf-8') as f:
config_dict = yaml.safe_load(f)
logger.info(f"Loaded YAML configuration from: {args.config}")
except Exception as e:
logger.error(f"Error loading YAML config: {e}")
sys.exit(1)
# Override YAML config with CLI arguments
cli_overrides = {}
if args.model_path:
cli_overrides['model_path'] = args.model_path
if args.device:
cli_overrides['device'] = args.device
if args.batch_size:
cli_overrides['batch_size'] = args.batch_size
if args.max_length:
cli_overrides['max_length'] = args.max_length
if args.return_probabilities:
cli_overrides['return_probabilities'] = True
if args.return_top_k:
cli_overrides['return_top_k'] = args.return_top_k
# Merge configurations
for key, value in cli_overrides.items():
if key in config_dict:
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
config_dict[key] = value
# Validate model path
model_path = config_dict.get('model_path')
if not model_path:
parser.error("--model-path is required (either in YAML config or CLI)")
model_path = Path(model_path)
if not model_path.exists():
print(f"❌ Model path not found: {model_path}")
print("Please ensure the model directory exists and contains the trained model files")
sys.exit(1)
# Create configuration object
config = InferenceConfig(
model_path=str(model_path),
device=config_dict.get('device', 'auto'),
batch_size=config_dict.get('batch_size', 32),
max_length=config_dict.get('max_length', 512),
return_probabilities=config_dict.get('return_probabilities', True),
return_top_k=config_dict.get('return_top_k', 1)
)
try:
print(f"Loading model from: {config.model_path}")
print(f"Device: {config.device}")
print(f"Batch size: {config.batch_size}")
if args.config:
print(f"Using YAML configuration: {args.config}")
print()
# Initialize inference
inference = ModelInference(config)
# Handle different input types
if args.input_text:
# Single text prediction
print(f"=== Single Text Prediction ===")
print(f"Input text: {args.input_text}")
print()
result = inference.predict_single(args.input_text)
print(f"Predicted label: {result['predicted_label']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Top {config.return_top_k} predictions:")
for pred in result['predictions']:
print(f" - {pred['label']}: {pred['score']:.4f}")
if config.return_probabilities and 'all_probabilities' in result:
print(f"\nAll class probabilities:")
for label, prob in result['all_probabilities'].items():
print(f" - {label}: {prob:.4f}")
elif args.input_file:
# File prediction
input_path = Path(args.input_file)
if not input_path.exists():
print(f"❌ Input file not found: {input_path}")
sys.exit(1)
print(f"=== File Prediction ===")
print(f"Input file: {args.input_file}")
print(f"Output file: {args.output_file}")
print()
if args.output_file:
# Use batch inference for large files
if input_path.stat().st_size > 10 * 1024 * 1024: # > 10MB
print("Large file detected, using chunked processing...")
batch_inference = BatchInference(config)
chunk_size = args.chunk_size or 1000
batch_inference.predict_large_file(
args.input_file,
args.output_file,
chunk_size=chunk_size
)
else:
# Regular file processing
results = inference.predict_file(args.input_file, args.output_file)
print(f"Processed {len(results)} texts")
else:
# Just predict without saving
results = inference.predict_file(args.input_file)
print(f"Processed {len(results)} texts")
print("\nSample results:")
for i, result in enumerate(results[:3]): # Show first 3
print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")
else:
# Interactive mode - example predictions
print("=== Interactive Mode ===")
print("No input specified. Running example predictions...")
print()
# Example texts
example_texts = [
"I love this product! It's amazing.",
"This is terrible, I hate it.",
"The weather is okay today.",
"Best purchase ever made!"
]
print("Example predictions:")
for text in example_texts:
result = inference.predict_single(text)
print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")
print(f"\n✅ Inference completed successfully!")
except Exception as e:
print(f"❌ Inference failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()
"""
# Quick usage examples:
# 1. Single prediction
config = InferenceConfig(model_path="./results/my_model")
inference = ModelInference(config)
result = inference.predict_single("Your text here")
# # 2. Batch prediction
# results = inference.predict_batch(["text1", "text2", "text3"])
# # 3. File prediction
# inference.predict_file("input.txt", "predictions.jsonl")
# # 4. Large file processing
# batch_inference = BatchInference(config)
# batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
"""
+467
View File
@@ -0,0 +1,467 @@
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
get_scheduler
)
from accelerate import Accelerator
from datasets import Dataset
from tqdm.auto import tqdm
import evaluate
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional
import logging
import argparse
import sys
import yaml
logger = logging.getLogger(__name__)
@dataclass
class SimpleConfig:
"""Simple configuration for accelerate training"""
# Model settings
model_name: str = "bert-base-uncased"
max_length: int = 512
# Training settings
num_epochs: int = 3
batch_size: int = 16
learning_rate: float = 2e-5
weight_decay: float = 0.01
# Scheduler settings
lr_scheduler_type: str = "linear"
warmup_ratio: float = 0.1
# Paths
data_dir: str = "./data/classification"
output_dir: str = "./results"
class AccelerateTrainer:
"""Simple trainer using Accelerate for distributed training"""
def __init__(self, config: SimpleConfig):
self.config = config
# Initialize accelerator
self.accelerator = Accelerator()
# Setup logging only on main process
if self.accelerator.is_main_process:
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s',
level=logging.INFO
)
self.tokenizer = None
self.model = None
self.label_to_id = {}
self.id_to_label = {}
self.num_labels = 0
def load_data(self) -> Dict[str, List[Dict]]:
"""Load data from JSONL files"""
data_path = Path(self.config.data_dir)
splits = {}
for split_name in ["train", "validation", "test"]:
split_file = data_path / f"{split_name}.jsonl"
if split_file.exists():
split_data = []
with open(split_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
split_data.append(json.loads(line))
splits[split_name] = split_data
if self.accelerator.is_main_process:
logger.info(f"Loaded {len(split_data)} samples from {split_name}")
return splits
def setup_labels(self, train_data: List[Dict]):
"""Setup label mappings"""
labels = set()
for item in train_data:
labels.add(str(item["label"]))
sorted_labels = sorted(list(labels))
self.label_to_id = {label: idx for idx, label in enumerate(sorted_labels)}
self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
self.num_labels = len(sorted_labels)
if self.accelerator.is_main_process:
logger.info(f"Found {self.num_labels} labels: {sorted_labels}")
def create_dataset(self, data: List[Dict]) -> Dataset:
"""Create tokenized dataset"""
texts = []
labels = []
for item in data:
text = item["text"] if "text" in item else item["input"]
label = item["label"]
texts.append(str(text))
# Convert label to ID
if isinstance(label, str):
label_id = self.label_to_id.get(label, 0)
else:
label_id = int(label)
labels.append(label_id)
# Create dataset
dataset = Dataset.from_dict({
"text": texts,
"labels": labels
})
# Tokenize
def tokenize_function(examples):
return self.tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=self.config.max_length
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=["text"]
)
return tokenized_dataset
def compute_metrics(self, predictions, labels):
"""Compute accuracy and F1"""
predictions = predictions.argmax(axis=-1)
# Gather predictions from all processes
all_predictions = self.accelerator.gather_for_metrics(predictions)
all_labels = self.accelerator.gather_for_metrics(labels)
if self.accelerator.is_main_process:
# Only compute metrics on main process
metric = evaluate.load("glue", "mrpc") # Using MRPC as example
results = metric.compute(
predictions=all_predictions.cpu().numpy(),
references=all_labels.cpu().numpy()
)
return results
return {}
def train(self):
"""Main training function"""
if self.accelerator.is_main_process:
logger.info("=== Starting Accelerate Training ===")
# Load data
splits_data = self.load_data()
if "train" not in splits_data:
raise ValueError("No training data found!")
# Setup labels
self.setup_labels(splits_data["train"])
# Initialize tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForSequenceClassification.from_pretrained(
self.config.model_name,
num_labels=self.num_labels,
id2label=self.id_to_label,
label2id=self.label_to_id
)
# Create datasets
train_dataset = self.create_dataset(splits_data["train"])
eval_dataset = None
if "validation" in splits_data:
eval_dataset = self.create_dataset(splits_data["validation"])
# Create data loaders
data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
batch_size=self.config.batch_size,
collate_fn=data_collator
)
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=self.config.batch_size,
collate_fn=data_collator
)
# Setup optimizer and scheduler
optimizer = AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
num_training_steps = self.config.num_epochs * len(train_dataloader)
num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
lr_scheduler = get_scheduler(
self.config.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
# Prepare everything with accelerator
self.model, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
self.model, optimizer, train_dataloader, lr_scheduler
)
if eval_dataloader:
eval_dataloader = self.accelerator.prepare(eval_dataloader)
# Training loop
if self.accelerator.is_main_process:
progress_bar = tqdm(range(num_training_steps))
logger.info(f"Training steps: {num_training_steps}")
self.model.train()
for epoch in range(self.config.num_epochs):
if self.accelerator.is_main_process:
logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
for step, batch in enumerate(train_dataloader):
outputs = self.model(**batch)
loss = outputs.loss
# Backward pass
self.accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if self.accelerator.is_main_process:
progress_bar.update(1)
# Log every 100 steps
if step % 100 == 0:
logger.info(f"Step {step}, Loss: {loss.item():.4f}")
# Evaluation at end of each epoch
if eval_dataloader:
self.evaluate(eval_dataloader, epoch)
# Save model
self.save_model()
if self.accelerator.is_main_process:
logger.info("=== Training Completed ===")
def evaluate(self, eval_dataloader, epoch):
"""Evaluation function"""
self.model.eval()
all_predictions = []
all_labels = []
for batch in eval_dataloader:
with torch.no_grad():
outputs = self.model(**batch)
predictions = outputs.logits
labels = batch["labels"]
# Gather from all processes
predictions = self.accelerator.gather_for_metrics(predictions)
labels = self.accelerator.gather_for_metrics(labels)
all_predictions.append(predictions.cpu())
all_labels.append(labels.cpu())
if self.accelerator.is_main_process and all_predictions:
all_predictions = torch.cat(all_predictions)
all_labels = torch.cat(all_labels)
predictions_np = all_predictions.argmax(dim=-1).numpy()
labels_np = all_labels.numpy()
# Simple accuracy calculation
accuracy = (predictions_np == labels_np).mean()
logger.info(f"Epoch {epoch + 1} - Validation Accuracy: {accuracy:.4f}")
self.model.train()
def save_model(self):
"""Save model and tokenizer"""
if self.accelerator.is_main_process:
output_path = Path(self.config.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save model using accelerator
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)
# Save label info
label_info = {
"label_to_id": self.label_to_id,
"id_to_label": self.id_to_label,
"num_labels": self.num_labels
}
with open(output_path / "label_info.json", 'w') as f:
json.dump(label_info, f, indent=2)
logger.info(f"Model saved to {output_path}")
def main():
"""Main function with YAML configuration support"""
parser = argparse.ArgumentParser(description="Accelerate Training Pipeline")
# YAML configuration
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
# Model settings
parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
# Training settings
parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
parser.add_argument("--batch-size", type=int, help="Training batch size")
parser.add_argument("--learning-rate", type=float, help="Learning rate")
parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")
# Scheduler settings
parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"], help="Learning rate scheduler type")
parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")
# Paths
parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
parser.add_argument("--output-dir", type=str, help="Output directory for saved model")
# Logging
parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
args = parser.parse_args()
# Set up logging
logging.basicConfig(
level=getattr(logging, args.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Load configuration
config_dict = {}
# Load YAML config if provided
if args.config:
try:
with open(args.config, 'r', encoding='utf-8') as f:
config_dict = yaml.safe_load(f)
logger.info(f"Loaded YAML configuration from: {args.config}")
except Exception as e:
logger.error(f"Error loading YAML config: {e}")
sys.exit(1)
# Override YAML config with CLI arguments
cli_overrides = {}
if args.model_name:
cli_overrides['model_name'] = args.model_name
if args.max_length:
cli_overrides['max_length'] = args.max_length
if args.num_epochs:
cli_overrides['num_epochs'] = args.num_epochs
if args.batch_size:
cli_overrides['batch_size'] = args.batch_size
if args.learning_rate:
cli_overrides['learning_rate'] = args.learning_rate
if args.weight_decay:
cli_overrides['weight_decay'] = args.weight_decay
if args.lr_scheduler_type:
cli_overrides['lr_scheduler_type'] = args.lr_scheduler_type
if args.warmup_ratio:
cli_overrides['warmup_ratio'] = args.warmup_ratio
if args.data_dir:
cli_overrides['data_dir'] = args.data_dir
if args.output_dir:
cli_overrides['output_dir'] = args.output_dir
# Merge configurations
for key, value in cli_overrides.items():
if key in config_dict:
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
config_dict[key] = value
# Create configuration object
config = SimpleConfig(
model_name=config_dict.get('model_name', 'bert-base-uncased'),
max_length=config_dict.get('max_length', 512),
num_epochs=config_dict.get('num_epochs', 3),
batch_size=config_dict.get('batch_size', 16),
learning_rate=config_dict.get('learning_rate', 2e-5),
weight_decay=config_dict.get('weight_decay', 0.01),
lr_scheduler_type=config_dict.get('lr_scheduler_type', 'linear'),
warmup_ratio=config_dict.get('warmup_ratio', 0.1),
data_dir=config_dict.get('data_dir', './data/classification'),
output_dir=config_dict.get('output_dir', './results')
)
# Validate data directory
data_path = Path(config.data_dir)
if not data_path.exists():
print(f"❌ Data directory not found: {data_path}")
print("Please ensure the data directory exists and contains train.jsonl, validation.jsonl, and test.jsonl files")
sys.exit(1)
# Check for required data files
required_files = ["train.jsonl"]
missing_files = []
for file_name in required_files:
if not (data_path / file_name).exists():
missing_files.append(file_name)
if missing_files:
print(f"❌ Missing required data files: {missing_files}")
print(f"Please ensure these files exist in: {data_path}")
sys.exit(1)
# Initialize and run training
try:
print(f"Starting training with model: {config.model_name}")
print(f"Data directory: {config.data_dir}")
print(f"Output directory: {config.output_dir}")
print(f"Training for {config.num_epochs} epochs with batch size {config.batch_size}")
if args.config:
print(f"Using YAML configuration: {args.config}")
print()
trainer = AccelerateTrainer(config)
trainer.train()
print(f"✅ Training completed successfully!")
print(f"Model saved to: {config.output_dir}")
except Exception as e:
print(f"❌ Error during training: {e}")
sys.exit(1)
if __name__ == "__main__":
main()