#!/usr/bin/env python3
"""
Classification Inference Script

Uses YAML configurations for flexible and maintainable model inference.

This module is a thin command-line wrapper: it assembles an argument list
for ``pipelines/classification/inference.py`` and runs that script in a
child process.
"""

import sys
import os
import subprocess
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Sequence

# Script that performs the actual inference; resolved relative to the
# current working directory of the caller.
PIPELINE_SCRIPT = "pipelines/classification/inference.py"
# Model directory and configuration used by the bundled examples.
EXAMPLE_MODEL_PATH = "./results/emotion_model"
EXAMPLE_CONFIG_PATH = "configs/classification/emotion.yaml"
# Location written by the ``create-config`` subcommand.
INFERENCE_CONFIG_PATH = "configs/classification/inference.yaml"


def _build_command(options: Dict[str, object]) -> List[str]:
    """Translate an option mapping into the pipeline command line.

    ``None`` and ``False`` values are skipped, ``True`` becomes a bare flag
    and every other value is emitted as ``--name value``. Underscores in
    option names are converted to dashes.
    """
    # sys.executable keeps the child on the same interpreter (and therefore
    # the same virtualenv) as this script; a bare "python" may be missing
    # from PATH or point at a different installation.
    cmd = [sys.executable, PIPELINE_SCRIPT]
    for name, value in options.items():
        if value is None or value is False:
            continue
        flag = "--" + name.replace("_", "-")
        if value is True:
            cmd.append(flag)
        else:
            cmd.extend([flag, str(value)])
    return cmd


def _run_pipeline(cmd: List[str], capture_output: bool = True) -> bool:
    """Run *cmd* and report the outcome on stdout.

    Args:
        cmd: Full argument list for the child process.
        capture_output: When true the child's output is collected and
            printed after it exits; when false it is streamed directly to
            the terminal.

    Returns:
        True if the child exited with status 0, False otherwise.
    """
    try:
        result = subprocess.run(
            cmd, check=True, capture_output=capture_output, text=True
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running inference: {e}")
        # stderr is None when the output was streamed rather than captured.
        if e.stderr is not None:
            print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The interpreter itself could not be launched.
        print(f"❌ Error running inference: {e}")
        return False

    print("✅ Inference completed successfully!")
    if result.stdout is not None:
        print(result.stdout)
    return True


def _run_config(
    config_path: str, overrides: Dict[str, object], capture_output: bool = True
) -> bool:
    """Announce and run the pipeline for *config_path* plus *overrides*."""
    print("=== Running Classification Inference ===")
    print(f"Config: {config_path}")

    cmd = _build_command({"config": config_path, **overrides})

    print(f"Command: {' '.join(cmd)}")
    print()
    return _run_pipeline(cmd, capture_output=capture_output)


def _example_model_available() -> bool:
    """Return True if the example model exists, printing a hint otherwise."""
    if os.path.exists(EXAMPLE_MODEL_PATH):
        return True
    print(f"⚠️ Model not found: {EXAMPLE_MODEL_PATH}")
    print("Please train a model first using the trainer script.")
    return False


def run_with_yaml_config(config_path: str, **cli_overrides):
    """Run inference with YAML configuration.

    Args:
        config_path: Path of the YAML file passed to the pipeline as
            ``--config``.
        **cli_overrides: Extra pipeline options. Names use underscores
            (``input_text``) and are converted to dashed flags
            (``--input-text``). ``None``/``False`` values are omitted and
            ``True`` is emitted as a bare flag.

    Returns:
        True if the pipeline process exited successfully, False otherwise.
    """
    return _run_config(config_path, cli_overrides)


def run_single_text_inference():
    """Run the single-text example against the example model.

    Returns:
        True on success, False if the model is missing or inference failed.
    """
    print("=== Single Text Inference ===")

    if not _example_model_available():
        return False

    success = run_with_yaml_config(
        EXAMPLE_CONFIG_PATH,
        model_path=EXAMPLE_MODEL_PATH,
        input_text="I love this product! It's amazing.",
        return_top_k=3,
    )

    if success:
        print("✅ Single text inference completed!")
    else:
        print("❌ Single text inference failed!")
    return success


def run_file_inference():
    """Run the file-based example against the example model.

    Writes ``sample_texts.txt`` to the current directory and asks the
    pipeline to write its results to ``predictions.jsonl``.

    Returns:
        True on success, False if the model is missing or inference failed.
    """
    print("\n=== File-Based Inference ===")

    if not _example_model_available():
        return False

    sample_texts = [
        "I love this product! It's amazing.",
        "This is terrible, I hate it.",
        "The weather is okay today.",
        "Best purchase ever made!",
    ]

    input_file = "sample_texts.txt"
    # Explicit encoding so the file does not depend on the platform locale.
    with open(input_file, "w", encoding="utf-8") as f:
        f.writelines(text + "\n" for text in sample_texts)

    success = run_with_yaml_config(
        EXAMPLE_CONFIG_PATH,
        model_path=EXAMPLE_MODEL_PATH,
        input_file=input_file,
        output_file="predictions.jsonl",
        batch_size=16,
    )

    if success:
        print("✅ File-based inference completed!")
        print("Results saved to: predictions.jsonl")
    else:
        print("❌ File-based inference failed!")
    return success


def run_interactive_inference():
    """Run the interactive example against the example model.

    Returns:
        True on success, False if the model is missing or inference failed.
    """
    print("\n=== Interactive Inference ===")

    if not _example_model_available():
        return False

    # Stream the child's output instead of capturing it: captured output is
    # only shown after the child exits, which would hide anything it prints
    # during the session.
    # NOTE(review): assumes the pipeline prompts on the terminal when neither
    # input text nor an input file is given - confirm against the pipeline.
    success = _run_config(
        EXAMPLE_CONFIG_PATH,
        {"model_path": EXAMPLE_MODEL_PATH, "return_top_k": 3},
        capture_output=False,
    )

    if success:
        print("✅ Interactive inference completed!")
    else:
        print("❌ Interactive inference failed!")
    return success


def create_inference_config():
    """Create an inference configuration file.

    Writes a default configuration to ``configs/classification/inference.yaml``,
    creating the parent directories when they do not exist yet.
    """
    inference_config = """model_path: "./results/emotion_model"
device: "auto"
batch_size: 32
max_length: 512
return_probabilities: true
return_top_k: 3
"""

    config_path = INFERENCE_CONFIG_PATH
    # Without this the write fails on a fresh checkout that has no configs/
    # directory.
    Path(config_path).parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(inference_config)

    print(f"✅ Created inference config: {config_path}")


def show_usage():
    """Show usage examples."""
    lines = (
        "=== Classification Inference Usage ===",
        "",
        "1. Use YAML config only:",
        " python scripts/classification/inference.py --config configs/classification/inference.yaml",
        "",
        "2. Override YAML values:",
        " python scripts/classification/inference.py --config configs/classification/inference.yaml --input-text 'Your text here'",
        "",
        "3. Use CLI only (backward compatibility):",
        " python scripts/classification/inference.py --model-path ./results/emotion_model --input-text 'Your text here'",
        "",
        "4. Run examples:",
        " python scripts/classification/inference.py examples",
        "",
        "5. Create inference config:",
        " python scripts/classification/inference.py create-config",
    )
    print("\n".join(lines))


def handle_direct_args(argv: Optional[Sequence[str]] = None):
    """Handle direct command-line arguments by passing them to the pipeline.

    Args:
        argv: Arguments to parse. Defaults to ``sys.argv[1:]`` when omitted.

    Returns:
        True if the pipeline process exited successfully, False otherwise.
    """
    parser = argparse.ArgumentParser(description="Classification Inference")

    # Add all the same arguments as the pipeline
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--model-path", type=str, help="Path to saved model directory")
    parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
    parser.add_argument("--batch-size", type=int, help="Batch size for inference")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
    parser.add_argument("--return-probabilities", action="store_true", help="Return all class probabilities")
    parser.add_argument("--return-top-k", type=int, help="Return top K predictions")
    parser.add_argument("--input-text", type=str, help="Single text for prediction")
    parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
    parser.add_argument("--output-file", type=str, help="Output file path for results")
    parser.add_argument("--chunk-size", type=int, help="Chunk size for large file processing")
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args(argv)

    # Forward only what was provided; --log-level always carries its default.
    cmd = _build_command(vars(args))

    print(f"Running: {' '.join(cmd)}")
    print()
    return _run_pipeline(cmd)


def _print_overview():
    """Print the overview shown when the script is run without arguments."""
    lines = (
        "Classification Inference",
        "=======================",
        "",
        "This script performs inference using trained classification models.",
        "",
        "Usage:",
        " python scripts/classification/inference.py examples # Run examples",
        " python scripts/classification/inference.py single # Single text inference",
        " python scripts/classification/inference.py file # File-based inference",
        " python scripts/classification/inference.py interactive # Interactive inference",
        " python scripts/classification/inference.py create-config # Create inference config",
        " python scripts/classification/inference.py help # Show usage",
        "",
        "Direct pipeline usage:",
        " python scripts/classification/inference.py --config configs/classification/inference.yaml",
        " python scripts/classification/inference.py --model-path ./results/emotion_model --input-text 'Your text here'",
        "",
        "Benefits of YAML configurations:",
        " ✅ Easier to manage complex configurations",
        " ✅ Version control friendly",
        " ✅ Self-documenting",
        " ✅ Can still override with CLI args",
        " ✅ Better for team collaboration",
    )
    print("\n".join(lines))


def main(argv: Optional[Sequence[str]] = None):
    """Main function.

    Dispatches on the first argument: a known subcommand runs the matching
    helper, anything else is parsed as pipeline options and forwarded, and
    no arguments at all prints an overview.

    Args:
        argv: Arguments without the program name. Defaults to
            ``sys.argv[1:]`` when omitted.

    Returns:
        Process exit status: 0 on success, 1 if an inference run failed.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        _print_overview()
        return 0

    subcommands = {
        "single": run_single_text_inference,
        "file": run_file_inference,
        "interactive": run_interactive_inference,
        "create-config": create_inference_config,
        "help": show_usage,
    }

    command = args[0]
    if command == "examples":
        # Run every example even if an earlier one fails.
        outcomes = [
            run_single_text_inference(),
            run_file_inference(),
            run_interactive_inference(),
        ]
        ok = all(outcomes)
    elif command in subcommands:
        # create-config and help return None, which counts as success.
        ok = subcommands[command]() is not False
    else:
        # Handle direct arguments (pass through to pipeline)
        ok = handle_direct_args(args)

    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())