initial setupt
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,481 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import List, Dict, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
"""Simple configuration for inference"""
|
||||
model_path: str # Path to saved model
|
||||
device: str = "auto" # "auto", "cuda", "cpu"
|
||||
batch_size: int = 32
|
||||
max_length: int = 512
|
||||
return_probabilities: bool = True
|
||||
return_top_k: int = 1 # Return top K predictions
|
||||
|
||||
|
||||
class ModelInference:
|
||||
"""Simple inference class for text classification"""
|
||||
|
||||
def __init__(self, config: InferenceConfig):
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.label_info = {}
|
||||
self.device = self._setup_device()
|
||||
|
||||
# Load model and tokenizer
|
||||
self.load_model()
|
||||
|
||||
def _setup_device(self) -> torch.device:
|
||||
"""Setup device for inference"""
|
||||
if self.config.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
logger.info("Using CUDA device")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU device")
|
||||
else:
|
||||
device = torch.device(self.config.device)
|
||||
logger.info(f"Using specified device: {device}")
|
||||
|
||||
return device
|
||||
|
||||
def load_model(self):
|
||||
"""Load model, tokenizer, and label mappings"""
|
||||
model_path = Path(self.config.model_path)
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model path not found: {model_path}")
|
||||
|
||||
# Load tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
logger.info(f"Loaded tokenizer from {model_path}")
|
||||
|
||||
# Load model
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
logger.info(f"Loaded model from {model_path}")
|
||||
|
||||
# Load label information
|
||||
label_info_path = model_path / "label_info.json"
|
||||
if label_info_path.exists():
|
||||
with open(label_info_path, 'r') as f:
|
||||
self.label_info = json.load(f)
|
||||
logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
|
||||
else:
|
||||
logger.warning("No label_info.json found. Using default numeric labels.")
|
||||
# Create default mappings
|
||||
num_labels = self.model.config.num_labels
|
||||
self.label_info = {
|
||||
"id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
|
||||
"label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
|
||||
"num_labels": num_labels
|
||||
}
|
||||
|
||||
def predict_single(self, text: str) -> Dict:
|
||||
"""Predict single text sample"""
|
||||
return self.predict_batch([text])[0]
|
||||
|
||||
def predict_batch(self, texts: List[str]) -> List[Dict]:
|
||||
"""Predict batch of texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Tokenize inputs
|
||||
inputs = self.tokenizer(
|
||||
texts,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
max_length=self.config.max_length,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Move to device
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get predictions
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# Convert to probabilities
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
# Process results
|
||||
results = []
|
||||
for i, text in enumerate(texts):
|
||||
text_probs = probabilities[i].cpu().numpy()
|
||||
|
||||
# Get top K predictions
|
||||
top_indices = np.argsort(text_probs)[-self.config.return_top_k:][::-1]
|
||||
|
||||
predictions = []
|
||||
for idx in top_indices:
|
||||
label = self.label_info["id_to_label"].get(str(idx), f"LABEL_{idx}")
|
||||
prob = float(text_probs[idx])
|
||||
|
||||
pred_dict = {
|
||||
"label": label,
|
||||
"label_id": int(idx),
|
||||
"score": prob
|
||||
}
|
||||
predictions.append(pred_dict)
|
||||
|
||||
result = {
|
||||
"text": text,
|
||||
"predictions": predictions,
|
||||
"predicted_label": predictions[0]["label"],
|
||||
"confidence": predictions[0]["score"]
|
||||
}
|
||||
|
||||
if self.config.return_probabilities:
|
||||
# Add all class probabilities
|
||||
all_probs = {}
|
||||
for label_id, label_name in self.label_info["id_to_label"].items():
|
||||
all_probs[label_name] = float(text_probs[int(label_id)])
|
||||
result["all_probabilities"] = all_probs
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def predict_file(self, input_file: str, output_file: str = None) -> List[Dict]:
|
||||
"""Predict on texts from file"""
|
||||
input_path = Path(input_file)
|
||||
|
||||
# Read texts from file
|
||||
texts = []
|
||||
if input_path.suffix == '.txt':
|
||||
# Plain text file (one text per line)
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
texts = [line.strip() for line in f if line.strip()]
|
||||
|
||||
elif input_path.suffix == '.jsonl':
|
||||
# JSONL file with "text" field
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
text = data.get("text", data.get("input", ""))
|
||||
if text:
|
||||
texts.append(text)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {input_path.suffix}")
|
||||
|
||||
logger.info(f"Loaded {len(texts)} texts from {input_file}")
|
||||
|
||||
# Process in batches
|
||||
all_results = []
|
||||
for i in range(0, len(texts), self.config.batch_size):
|
||||
batch_texts = texts[i:i + self.config.batch_size]
|
||||
batch_results = self.predict_batch(batch_texts)
|
||||
all_results.extend(batch_results)
|
||||
|
||||
if i % (self.config.batch_size * 10) == 0:
|
||||
logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")
|
||||
|
||||
# Save results if output file specified
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for result in all_results:
|
||||
f.write(json.dumps(result) + '\n')
|
||||
|
||||
logger.info(f"Results saved to {output_file}")
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
class BatchInference:
|
||||
"""Optimized batch inference for large datasets"""
|
||||
|
||||
def __init__(self, config: InferenceConfig):
|
||||
self.inference = ModelInference(config)
|
||||
self.config = config
|
||||
|
||||
def predict_large_file(self, input_file: str, output_file: str,
|
||||
chunk_size: int = 1000) -> None:
|
||||
"""Process large files in chunks to manage memory"""
|
||||
input_path = Path(input_file)
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Count total lines first
|
||||
total_lines = 0
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
for _ in f:
|
||||
total_lines += 1
|
||||
|
||||
logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")
|
||||
|
||||
processed = 0
|
||||
with open(input_path, 'r', encoding='utf-8') as infile, \
|
||||
open(output_path, 'w', encoding='utf-8') as outfile:
|
||||
|
||||
chunk_texts = []
|
||||
|
||||
for line in infile:
|
||||
if line.strip():
|
||||
if input_path.suffix == '.jsonl':
|
||||
data = json.loads(line)
|
||||
text = data.get("text", data.get("input", ""))
|
||||
else:
|
||||
text = line.strip()
|
||||
|
||||
chunk_texts.append(text)
|
||||
|
||||
if len(chunk_texts) >= chunk_size:
|
||||
# Process chunk
|
||||
results = self.inference.predict_batch(chunk_texts)
|
||||
|
||||
# Write results
|
||||
for result in results:
|
||||
outfile.write(json.dumps(result) + '\n')
|
||||
|
||||
processed += len(chunk_texts)
|
||||
logger.info(f"Processed {processed}/{total_lines} texts")
|
||||
|
||||
chunk_texts = []
|
||||
|
||||
# Process remaining texts
|
||||
if chunk_texts:
|
||||
results = self.inference.predict_batch(chunk_texts)
|
||||
for result in results:
|
||||
outfile.write(json.dumps(result) + '\n')
|
||||
processed += len(chunk_texts)
|
||||
|
||||
logger.info(f"Completed processing {processed} texts")
|
||||
|
||||
|
||||
def create_inference_config(
|
||||
model_path: str,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
**kwargs
|
||||
) -> InferenceConfig:
|
||||
"""Helper function to create inference configuration"""
|
||||
return InferenceConfig(
|
||||
model_path=model_path,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with YAML configuration support"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Model Inference Pipeline")
|
||||
|
||||
# YAML configuration
|
||||
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
|
||||
|
||||
# Model settings
|
||||
parser.add_argument("--model-path", type=str, help="Path to saved model directory")
|
||||
parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
|
||||
parser.add_argument("--batch-size", type=int, help="Batch size for inference")
|
||||
parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
|
||||
|
||||
# Inference settings
|
||||
parser.add_argument("--return-probabilities", action="store_true", help="Return all class probabilities")
|
||||
parser.add_argument("--return-top-k", type=int, help="Return top K predictions")
|
||||
|
||||
# Input/Output settings
|
||||
parser.add_argument("--input-text", type=str, help="Single text for prediction")
|
||||
parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
|
||||
parser.add_argument("--output-file", type=str, help="Output file path for results")
|
||||
parser.add_argument("--chunk-size", type=int, help="Chunk size for large file processing")
|
||||
|
||||
# Logging
|
||||
parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# Load configuration
|
||||
config_dict = {}
|
||||
|
||||
# Load YAML config if provided
|
||||
if args.config:
|
||||
try:
|
||||
with open(args.config, 'r', encoding='utf-8') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
logger.info(f"Loaded YAML configuration from: {args.config}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading YAML config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Override YAML config with CLI arguments
|
||||
cli_overrides = {}
|
||||
if args.model_path:
|
||||
cli_overrides['model_path'] = args.model_path
|
||||
if args.device:
|
||||
cli_overrides['device'] = args.device
|
||||
if args.batch_size:
|
||||
cli_overrides['batch_size'] = args.batch_size
|
||||
if args.max_length:
|
||||
cli_overrides['max_length'] = args.max_length
|
||||
if args.return_probabilities:
|
||||
cli_overrides['return_probabilities'] = True
|
||||
if args.return_top_k:
|
||||
cli_overrides['return_top_k'] = args.return_top_k
|
||||
|
||||
# Merge configurations
|
||||
for key, value in cli_overrides.items():
|
||||
if key in config_dict:
|
||||
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
|
||||
config_dict[key] = value
|
||||
|
||||
# Validate model path
|
||||
model_path = config_dict.get('model_path')
|
||||
if not model_path:
|
||||
parser.error("--model-path is required (either in YAML config or CLI)")
|
||||
|
||||
model_path = Path(model_path)
|
||||
if not model_path.exists():
|
||||
print(f"❌ Model path not found: {model_path}")
|
||||
print("Please ensure the model directory exists and contains the trained model files")
|
||||
sys.exit(1)
|
||||
|
||||
# Create configuration object
|
||||
config = InferenceConfig(
|
||||
model_path=str(model_path),
|
||||
device=config_dict.get('device', 'auto'),
|
||||
batch_size=config_dict.get('batch_size', 32),
|
||||
max_length=config_dict.get('max_length', 512),
|
||||
return_probabilities=config_dict.get('return_probabilities', True),
|
||||
return_top_k=config_dict.get('return_top_k', 1)
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"Loading model from: {config.model_path}")
|
||||
print(f"Device: {config.device}")
|
||||
print(f"Batch size: {config.batch_size}")
|
||||
if args.config:
|
||||
print(f"Using YAML configuration: {args.config}")
|
||||
print()
|
||||
|
||||
# Initialize inference
|
||||
inference = ModelInference(config)
|
||||
|
||||
# Handle different input types
|
||||
if args.input_text:
|
||||
# Single text prediction
|
||||
print(f"=== Single Text Prediction ===")
|
||||
print(f"Input text: {args.input_text}")
|
||||
print()
|
||||
|
||||
result = inference.predict_single(args.input_text)
|
||||
print(f"Predicted label: {result['predicted_label']}")
|
||||
print(f"Confidence: {result['confidence']:.4f}")
|
||||
print(f"Top {config.return_top_k} predictions:")
|
||||
for pred in result['predictions']:
|
||||
print(f" - {pred['label']}: {pred['score']:.4f}")
|
||||
|
||||
if config.return_probabilities and 'all_probabilities' in result:
|
||||
print(f"\nAll class probabilities:")
|
||||
for label, prob in result['all_probabilities'].items():
|
||||
print(f" - {label}: {prob:.4f}")
|
||||
|
||||
elif args.input_file:
|
||||
# File prediction
|
||||
input_path = Path(args.input_file)
|
||||
if not input_path.exists():
|
||||
print(f"❌ Input file not found: {input_path}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"=== File Prediction ===")
|
||||
print(f"Input file: {args.input_file}")
|
||||
print(f"Output file: {args.output_file}")
|
||||
print()
|
||||
|
||||
if args.output_file:
|
||||
# Use batch inference for large files
|
||||
if input_path.stat().st_size > 10 * 1024 * 1024: # > 10MB
|
||||
print("Large file detected, using chunked processing...")
|
||||
batch_inference = BatchInference(config)
|
||||
chunk_size = args.chunk_size or 1000
|
||||
batch_inference.predict_large_file(
|
||||
args.input_file,
|
||||
args.output_file,
|
||||
chunk_size=chunk_size
|
||||
)
|
||||
else:
|
||||
# Regular file processing
|
||||
results = inference.predict_file(args.input_file, args.output_file)
|
||||
print(f"Processed {len(results)} texts")
|
||||
else:
|
||||
# Just predict without saving
|
||||
results = inference.predict_file(args.input_file)
|
||||
print(f"Processed {len(results)} texts")
|
||||
print("\nSample results:")
|
||||
for i, result in enumerate(results[:3]): # Show first 3
|
||||
print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")
|
||||
|
||||
else:
|
||||
# Interactive mode - example predictions
|
||||
print("=== Interactive Mode ===")
|
||||
print("No input specified. Running example predictions...")
|
||||
print()
|
||||
|
||||
# Example texts
|
||||
example_texts = [
|
||||
"I love this product! It's amazing.",
|
||||
"This is terrible, I hate it.",
|
||||
"The weather is okay today.",
|
||||
"Best purchase ever made!"
|
||||
]
|
||||
|
||||
print("Example predictions:")
|
||||
for text in example_texts:
|
||||
result = inference.predict_single(text)
|
||||
print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")
|
||||
|
||||
print(f"\n✅ Inference completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Inference failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
# Quick usage examples:
|
||||
|
||||
# 1. Single prediction
|
||||
config = InferenceConfig(model_path="./results/my_model")
|
||||
inference = ModelInference(config)
|
||||
result = inference.predict_single("Your text here")
|
||||
|
||||
# # 2. Batch prediction
|
||||
# results = inference.predict_batch(["text1", "text2", "text3"])
|
||||
|
||||
# # 3. File prediction
|
||||
# inference.predict_file("input.txt", "predictions.jsonl")
|
||||
|
||||
# # 4. Large file processing
|
||||
# batch_inference = BatchInference(config)
|
||||
# batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
|
||||
"""
|
||||
@@ -0,0 +1,467 @@
|
||||
import torch
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollatorWithPadding,
|
||||
get_scheduler
|
||||
)
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset
|
||||
from tqdm.auto import tqdm
|
||||
import evaluate
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
import logging
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class SimpleConfig:
|
||||
"""Simple configuration for accelerate training"""
|
||||
# Model settings
|
||||
model_name: str = "bert-base-uncased"
|
||||
max_length: int = 512
|
||||
|
||||
# Training settings
|
||||
num_epochs: int = 3
|
||||
batch_size: int = 16
|
||||
learning_rate: float = 2e-5
|
||||
weight_decay: float = 0.01
|
||||
|
||||
# Scheduler settings
|
||||
lr_scheduler_type: str = "linear"
|
||||
warmup_ratio: float = 0.1
|
||||
|
||||
# Paths
|
||||
data_dir: str = "./data/classification"
|
||||
output_dir: str = "./results"
|
||||
|
||||
|
||||
class AccelerateTrainer:
|
||||
"""Simple trainer using Accelerate for distributed training"""
|
||||
|
||||
def __init__(self, config: SimpleConfig):
|
||||
self.config = config
|
||||
|
||||
# Initialize accelerator
|
||||
self.accelerator = Accelerator()
|
||||
|
||||
# Setup logging only on main process
|
||||
if self.accelerator.is_main_process:
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
level=logging.INFO
|
||||
)
|
||||
|
||||
self.tokenizer = None
|
||||
self.model = None
|
||||
self.label_to_id = {}
|
||||
self.id_to_label = {}
|
||||
self.num_labels = 0
|
||||
|
||||
def load_data(self) -> Dict[str, List[Dict]]:
|
||||
"""Load data from JSONL files"""
|
||||
data_path = Path(self.config.data_dir)
|
||||
splits = {}
|
||||
|
||||
for split_name in ["train", "validation", "test"]:
|
||||
split_file = data_path / f"{split_name}.jsonl"
|
||||
if split_file.exists():
|
||||
split_data = []
|
||||
with open(split_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
split_data.append(json.loads(line))
|
||||
splits[split_name] = split_data
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
logger.info(f"Loaded {len(split_data)} samples from {split_name}")
|
||||
|
||||
return splits
|
||||
|
||||
def setup_labels(self, train_data: List[Dict]):
|
||||
"""Setup label mappings"""
|
||||
labels = set()
|
||||
for item in train_data:
|
||||
labels.add(str(item["label"]))
|
||||
|
||||
sorted_labels = sorted(list(labels))
|
||||
self.label_to_id = {label: idx for idx, label in enumerate(sorted_labels)}
|
||||
self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}
|
||||
self.num_labels = len(sorted_labels)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
logger.info(f"Found {self.num_labels} labels: {sorted_labels}")
|
||||
|
||||
def create_dataset(self, data: List[Dict]) -> Dataset:
|
||||
"""Create tokenized dataset"""
|
||||
texts = []
|
||||
labels = []
|
||||
|
||||
for item in data:
|
||||
text = item["text"] if "text" in item else item["input"]
|
||||
label = item["label"]
|
||||
|
||||
texts.append(str(text))
|
||||
# Convert label to ID
|
||||
if isinstance(label, str):
|
||||
label_id = self.label_to_id.get(label, 0)
|
||||
else:
|
||||
label_id = int(label)
|
||||
labels.append(label_id)
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset.from_dict({
|
||||
"text": texts,
|
||||
"labels": labels
|
||||
})
|
||||
|
||||
# Tokenize
|
||||
def tokenize_function(examples):
|
||||
return self.tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.config.max_length
|
||||
)
|
||||
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=["text"]
|
||||
)
|
||||
|
||||
return tokenized_dataset
|
||||
|
||||
def compute_metrics(self, predictions, labels):
|
||||
"""Compute accuracy and F1"""
|
||||
predictions = predictions.argmax(axis=-1)
|
||||
|
||||
# Gather predictions from all processes
|
||||
all_predictions = self.accelerator.gather_for_metrics(predictions)
|
||||
all_labels = self.accelerator.gather_for_metrics(labels)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
# Only compute metrics on main process
|
||||
metric = evaluate.load("glue", "mrpc") # Using MRPC as example
|
||||
results = metric.compute(
|
||||
predictions=all_predictions.cpu().numpy(),
|
||||
references=all_labels.cpu().numpy()
|
||||
)
|
||||
return results
|
||||
return {}
|
||||
|
||||
def train(self):
|
||||
"""Main training function"""
|
||||
if self.accelerator.is_main_process:
|
||||
logger.info("=== Starting Accelerate Training ===")
|
||||
|
||||
# Load data
|
||||
splits_data = self.load_data()
|
||||
if "train" not in splits_data:
|
||||
raise ValueError("No training data found!")
|
||||
|
||||
# Setup labels
|
||||
self.setup_labels(splits_data["train"])
|
||||
|
||||
# Initialize tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
self.config.model_name,
|
||||
num_labels=self.num_labels,
|
||||
id2label=self.id_to_label,
|
||||
label2id=self.label_to_id
|
||||
)
|
||||
|
||||
# Create datasets
|
||||
train_dataset = self.create_dataset(splits_data["train"])
|
||||
eval_dataset = None
|
||||
if "validation" in splits_data:
|
||||
eval_dataset = self.create_dataset(splits_data["validation"])
|
||||
|
||||
# Create data loaders
|
||||
data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=self.config.batch_size,
|
||||
collate_fn=data_collator
|
||||
)
|
||||
|
||||
eval_dataloader = None
|
||||
if eval_dataset:
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=self.config.batch_size,
|
||||
collate_fn=data_collator
|
||||
)
|
||||
|
||||
# Setup optimizer and scheduler
|
||||
optimizer = AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.config.learning_rate,
|
||||
weight_decay=self.config.weight_decay
|
||||
)
|
||||
|
||||
num_training_steps = self.config.num_epochs * len(train_dataloader)
|
||||
num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
self.config.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
self.model, optimizer, train_dataloader, lr_scheduler = self.accelerator.prepare(
|
||||
self.model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
if eval_dataloader:
|
||||
eval_dataloader = self.accelerator.prepare(eval_dataloader)
|
||||
|
||||
# Training loop
|
||||
if self.accelerator.is_main_process:
|
||||
progress_bar = tqdm(range(num_training_steps))
|
||||
logger.info(f"Training steps: {num_training_steps}")
|
||||
|
||||
self.model.train()
|
||||
|
||||
for epoch in range(self.config.num_epochs):
|
||||
if self.accelerator.is_main_process:
|
||||
logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
outputs = self.model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Backward pass
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
progress_bar.update(1)
|
||||
|
||||
# Log every 100 steps
|
||||
if step % 100 == 0:
|
||||
logger.info(f"Step {step}, Loss: {loss.item():.4f}")
|
||||
|
||||
# Evaluation at end of each epoch
|
||||
if eval_dataloader:
|
||||
self.evaluate(eval_dataloader, epoch)
|
||||
|
||||
# Save model
|
||||
self.save_model()
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
logger.info("=== Training Completed ===")
|
||||
|
||||
def evaluate(self, eval_dataloader, epoch):
|
||||
"""Evaluation function"""
|
||||
self.model.eval()
|
||||
|
||||
all_predictions = []
|
||||
all_labels = []
|
||||
|
||||
for batch in eval_dataloader:
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**batch)
|
||||
|
||||
predictions = outputs.logits
|
||||
labels = batch["labels"]
|
||||
|
||||
# Gather from all processes
|
||||
predictions = self.accelerator.gather_for_metrics(predictions)
|
||||
labels = self.accelerator.gather_for_metrics(labels)
|
||||
|
||||
all_predictions.append(predictions.cpu())
|
||||
all_labels.append(labels.cpu())
|
||||
|
||||
if self.accelerator.is_main_process and all_predictions:
|
||||
all_predictions = torch.cat(all_predictions)
|
||||
all_labels = torch.cat(all_labels)
|
||||
|
||||
predictions_np = all_predictions.argmax(dim=-1).numpy()
|
||||
labels_np = all_labels.numpy()
|
||||
|
||||
# Simple accuracy calculation
|
||||
accuracy = (predictions_np == labels_np).mean()
|
||||
logger.info(f"Epoch {epoch + 1} - Validation Accuracy: {accuracy:.4f}")
|
||||
|
||||
self.model.train()
|
||||
|
||||
def save_model(self):
|
||||
"""Save model and tokenizer"""
|
||||
if self.accelerator.is_main_process:
|
||||
output_path = Path(self.config.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save model using accelerator
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
unwrapped_model.save_pretrained(output_path)
|
||||
self.tokenizer.save_pretrained(output_path)
|
||||
|
||||
# Save label info
|
||||
label_info = {
|
||||
"label_to_id": self.label_to_id,
|
||||
"id_to_label": self.id_to_label,
|
||||
"num_labels": self.num_labels
|
||||
}
|
||||
|
||||
with open(output_path / "label_info.json", 'w') as f:
|
||||
json.dump(label_info, f, indent=2)
|
||||
|
||||
logger.info(f"Model saved to {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with YAML configuration support"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Accelerate Training Pipeline")
|
||||
|
||||
# YAML configuration
|
||||
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
|
||||
|
||||
# Model settings
|
||||
parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
|
||||
parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
|
||||
|
||||
# Training settings
|
||||
parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
|
||||
parser.add_argument("--batch-size", type=int, help="Training batch size")
|
||||
parser.add_argument("--learning-rate", type=float, help="Learning rate")
|
||||
parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")
|
||||
|
||||
# Scheduler settings
|
||||
parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"], help="Learning rate scheduler type")
|
||||
parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")
|
||||
|
||||
# Paths
|
||||
parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
|
||||
parser.add_argument("--output-dir", type=str, help="Output directory for saved model")
|
||||
|
||||
# Logging
|
||||
parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# Load configuration
|
||||
config_dict = {}
|
||||
|
||||
# Load YAML config if provided
|
||||
if args.config:
|
||||
try:
|
||||
with open(args.config, 'r', encoding='utf-8') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
logger.info(f"Loaded YAML configuration from: {args.config}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading YAML config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Override YAML config with CLI arguments
|
||||
cli_overrides = {}
|
||||
if args.model_name:
|
||||
cli_overrides['model_name'] = args.model_name
|
||||
if args.max_length:
|
||||
cli_overrides['max_length'] = args.max_length
|
||||
if args.num_epochs:
|
||||
cli_overrides['num_epochs'] = args.num_epochs
|
||||
if args.batch_size:
|
||||
cli_overrides['batch_size'] = args.batch_size
|
||||
if args.learning_rate:
|
||||
cli_overrides['learning_rate'] = args.learning_rate
|
||||
if args.weight_decay:
|
||||
cli_overrides['weight_decay'] = args.weight_decay
|
||||
if args.lr_scheduler_type:
|
||||
cli_overrides['lr_scheduler_type'] = args.lr_scheduler_type
|
||||
if args.warmup_ratio:
|
||||
cli_overrides['warmup_ratio'] = args.warmup_ratio
|
||||
if args.data_dir:
|
||||
cli_overrides['data_dir'] = args.data_dir
|
||||
if args.output_dir:
|
||||
cli_overrides['output_dir'] = args.output_dir
|
||||
|
||||
# Merge configurations
|
||||
for key, value in cli_overrides.items():
|
||||
if key in config_dict:
|
||||
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
|
||||
config_dict[key] = value
|
||||
|
||||
# Create configuration object
|
||||
config = SimpleConfig(
|
||||
model_name=config_dict.get('model_name', 'bert-base-uncased'),
|
||||
max_length=config_dict.get('max_length', 512),
|
||||
num_epochs=config_dict.get('num_epochs', 3),
|
||||
batch_size=config_dict.get('batch_size', 16),
|
||||
learning_rate=config_dict.get('learning_rate', 2e-5),
|
||||
weight_decay=config_dict.get('weight_decay', 0.01),
|
||||
lr_scheduler_type=config_dict.get('lr_scheduler_type', 'linear'),
|
||||
warmup_ratio=config_dict.get('warmup_ratio', 0.1),
|
||||
data_dir=config_dict.get('data_dir', './data/classification'),
|
||||
output_dir=config_dict.get('output_dir', './results')
|
||||
)
|
||||
|
||||
# Validate data directory
|
||||
data_path = Path(config.data_dir)
|
||||
if not data_path.exists():
|
||||
print(f"❌ Data directory not found: {data_path}")
|
||||
print("Please ensure the data directory exists and contains train.jsonl, validation.jsonl, and test.jsonl files")
|
||||
sys.exit(1)
|
||||
|
||||
# Check for required data files
|
||||
required_files = ["train.jsonl"]
|
||||
missing_files = []
|
||||
for file_name in required_files:
|
||||
if not (data_path / file_name).exists():
|
||||
missing_files.append(file_name)
|
||||
|
||||
if missing_files:
|
||||
print(f"❌ Missing required data files: {missing_files}")
|
||||
print(f"Please ensure these files exist in: {data_path}")
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize and run training
|
||||
try:
|
||||
print(f"Starting training with model: {config.model_name}")
|
||||
print(f"Data directory: {config.data_dir}")
|
||||
print(f"Output directory: {config.output_dir}")
|
||||
print(f"Training for {config.num_epochs} epochs with batch size {config.batch_size}")
|
||||
if args.config:
|
||||
print(f"Using YAML configuration: {args.config}")
|
||||
print()
|
||||
|
||||
trainer = AccelerateTrainer(config)
|
||||
trainer.train()
|
||||
|
||||
print(f"✅ Training completed successfully!")
|
||||
print(f"Model saved to: {config.output_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during training: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user