204 lines
7.9 KiB
Python
204 lines
7.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Classification Trainer Script
|
|
Uses YAML configurations for flexible and maintainable model training.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import subprocess
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
def run_with_yaml_config(config_path: str, **cli_overrides):
    """Run the classification training pipeline with a YAML configuration.

    Launches ``pipelines/classification/train.py`` (resolved relative to the
    current working directory) in a subprocess, passing ``--config`` followed
    by one ``--key value`` pair for every keyword override that is not
    ``None``.

    Args:
        config_path: Path to the YAML configuration file handed to the
            pipeline via ``--config``.
        **cli_overrides: Extra options forwarded on the command line.
            Underscores in a key become hyphens (``num_epochs`` becomes
            ``--num-epochs``) and the value is converted with ``str()``.
            Entries whose value is ``None`` are skipped.

    Returns:
        True if the subprocess exited with status 0, False if it exited with
        a non-zero status or could not be started at all.
    """
    print("=== Running Classification Trainer ===")
    print(f"Config: {config_path}")

    # Run the pipeline with the interpreter executing this script instead of
    # whatever "python" happens to resolve to on PATH (it may be missing or
    # belong to a different environment).
    interpreter = sys.executable or "python"
    cmd = [
        interpreter, "pipelines/classification/train.py",
        "--config", str(config_path),
    ]

    # Add CLI overrides
    for key, value in cli_overrides.items():
        if value is not None:
            cmd.extend([f"--{key.replace('_', '-')}", str(value)])

    print(f"Command: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running trainer: {e}")
        # Output is captured, so anything the pipeline printed before failing
        # would otherwise be lost.
        if e.stdout:
            print(f"Output: {e.stdout}")
        print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The process could not be launched (e.g. interpreter not found).
        print(f"❌ Error running trainer: {e}")
        return False

    print("✅ Training completed successfully!")
    print(result.stdout)
    return True
|
|
|
|
def run_emotion_training():
    """Train the emotion classifier from its YAML config with test-sized overrides."""
    print("=== Emotion Classification Training ===")

    overrides = {
        "num_epochs": 2,  # Override YAML value
        "batch_size": 8,  # Smaller batch for testing
        "output_dir": "./results/emotion_model",
    }
    finished = run_with_yaml_config(
        "configs/classification/emotion.yaml", **overrides
    )

    # Report the outcome; the result is not propagated to the caller.
    if finished:
        print("✅ Emotion classification training completed!")
    else:
        print("❌ Emotion classification training failed!")
|
|
|
|
def run_custom_training():
    """Train on the custom dataset, or skip when its training file is absent."""
    print("\n=== Custom Dataset Training ===")

    # Guard: nothing to train on unless the processed training split exists.
    if not os.path.exists("data/custom_processed/train.jsonl"):
        print("⚠️ Custom dataset not found, skipping...")
        return

    overrides = {
        "data_dir": "data/custom_processed",
        "output_dir": "./results/custom_model",
    }
    finished = run_with_yaml_config(
        "configs/classification/custom.yaml", **overrides
    )

    if finished:
        print("✅ Custom dataset training completed!")
    else:
        print("❌ Custom dataset training failed!")
|
|
|
|
def create_training_config():
    """Write a default training configuration file.

    The file is written to ``configs/classification/training.yaml`` relative
    to the current working directory. Missing parent directories are created
    first, and an existing file at that path is overwritten.
    """
    # NOTE: the learning rate is spelled "2.0e-5" rather than "2e-5" because
    # YAML 1.1 loaders (e.g. PyYAML) only resolve exponent notation to a float
    # when the mantissa contains a decimal point; "2e-5" would load as a string.
    training_config = """model_name: "bert-base-uncased"
max_length: 512
num_epochs: 3
batch_size: 16
learning_rate: 2.0e-5
weight_decay: 0.01
lr_scheduler_type: "linear"
warmup_ratio: 0.1
data_dir: "./data/classification"
output_dir: "./results/classification_model"
"""

    config_path = "configs/classification/training.yaml"
    # open(..., 'w') does not create directories, so make sure the target
    # directory exists before writing.
    Path(config_path).parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(training_config)

    print(f"✅ Created training config: {config_path}")
|
|
|
|
def show_usage():
    """Print numbered example invocations of the trainer script."""
    # (heading, example command) pairs, printed in order.
    examples = (
        (
            "1. Use YAML config only:",
            " python scripts/classification/trainer.py --config configs/classification/emotion.yaml",
        ),
        (
            "2. Override YAML values:",
            " python scripts/classification/trainer.py --config configs/classification/emotion.yaml --num-epochs 5",
        ),
        (
            "3. Use CLI only (backward compatibility):",
            " python scripts/classification/trainer.py --model-name bert-base-uncased --num-epochs 3",
        ),
        (
            "4. Run examples:",
            " python scripts/classification/trainer.py examples",
        ),
        (
            "5. Create training config:",
            " python scripts/classification/trainer.py create-config",
        ),
    )

    print("=== Classification Trainer Usage ===")
    for heading, command in examples:
        # A blank line separates each example from what precedes it.
        print()
        print(heading)
        print(command)
|
|
|
|
def handle_direct_args(argv=None):
    """Parse trainer options and forward them to the training pipeline.

    Every option that ends up with a non-``None`` value is passed through to
    ``pipelines/classification/train.py`` (resolved relative to the current
    working directory), which is run in a subprocess.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        True if the pipeline exited with status 0, False if it exited with a
        non-zero status or could not be started at all.
    """
    parser = argparse.ArgumentParser(description="Classification Trainer")

    # Add all the same arguments as the pipeline
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--model-name", type=str, help="Model name from HuggingFace Hub")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
    parser.add_argument("--num-epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, help="Weight decay for optimizer")
    parser.add_argument("--lr-scheduler-type", choices=["linear", "cosine", "polynomial"], help="Learning rate scheduler type")
    parser.add_argument("--warmup-ratio", type=float, help="Warmup ratio for scheduler")
    parser.add_argument("--data-dir", type=str, help="Directory containing train/validation/test JSONL files")
    parser.add_argument("--output-dir", type=str, help="Output directory for saved model")
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args(argv)

    # Build command to call the pipeline. Use the interpreter running this
    # script rather than whatever "python" resolves to on PATH (it may be
    # missing or belong to a different environment).
    cmd = [sys.executable or "python", "pipelines/classification/train.py"]

    # Add all arguments that were provided. --log-level has a default, so it
    # is always forwarded.
    for arg_name, arg_value in vars(args).items():
        if arg_value is None:
            continue
        flag = f"--{arg_name.replace('_', '-')}"
        if isinstance(arg_value, bool):
            if arg_value:  # Only add flag if True
                cmd.append(flag)
        else:
            cmd.extend([flag, str(arg_value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running trainer: {e}")
        print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The process could not be launched (e.g. interpreter not found).
        print(f"❌ Error running trainer: {e}")
        return False

    print("✅ Training completed successfully!")
    print(result.stdout)
    return True
|
|
|
|
def main():
    """Dispatch on the command line and return a process exit status.

    With no arguments, prints an overview of the script. When the first
    argument is one of the known subcommands (``examples``, ``emotion``,
    ``custom``, ``create-config``, ``help``) the matching function(s) run.
    Anything else is treated as direct pipeline options and handed to
    ``handle_direct_args``.

    Returns:
        1 when a direct pipeline run reports failure, otherwise 0. The
        subcommand functions do not report a result, so those paths always
        return 0.
    """
    # Check if any command-line arguments were provided
    if len(sys.argv) <= 1:
        overview = (
            "Classification Trainer",
            "====================",
            "",
            "This script trains classification models using YAML configurations.",
            "",
            "Usage:",
            " python scripts/classification/trainer.py examples # Run examples",
            " python scripts/classification/trainer.py emotion # Run emotion training",
            " python scripts/classification/trainer.py custom # Run custom training",
            " python scripts/classification/trainer.py create-config # Create training config",
            " python scripts/classification/trainer.py help # Show usage",
            "",
            "Direct pipeline usage:",
            " python scripts/classification/trainer.py --config configs/classification/emotion.yaml",
            " python scripts/classification/trainer.py --model-name bert-base-uncased --num-epochs 3",
            "",
            "Benefits of YAML configurations:",
            " ✅ Easier to manage complex configurations",
            " ✅ Version control friendly",
            " ✅ Self-documenting",
            " ✅ Can still override with CLI args",
            " ✅ Better for team collaboration",
        )
        for line in overview:
            print(line)
        return 0

    # Subcommand name -> functions to run, in order.
    subcommands = {
        "examples": (run_emotion_training, run_custom_training),
        "emotion": (run_emotion_training,),
        "custom": (run_custom_training,),
        "create-config": (create_training_config,),
        "help": (show_usage,),
    }

    steps = subcommands.get(sys.argv[1])
    if steps is None:
        # Handle direct arguments (pass through to pipeline) and surface a
        # failed run as a non-zero exit status instead of discarding it.
        return 0 if handle_direct_args() else 1

    for step in steps:
        step()
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect a failed run.
    sys.exit(main())