#!/usr/bin/env python3
"""
Styling Inference Script
Provides a command-line interface to run the styling inference pipeline
"""

import sys
import os
import subprocess
import argparse
from pathlib import Path
from typing import Optional


def run_inference_with_config(config_path: str, instruction: str, input_text: str = "",
                              max_tokens: int = 128, stream: bool = False) -> Optional[str]:
    """Run a single inference in a child process using a YAML configuration file.

    Args:
        config_path: Path to the YAML configuration passed as ``--config``.
        instruction: Style instruction passed as ``--instruction``.
        input_text: Text to restyle, passed as ``--input-text``.
        max_tokens: Value passed as ``--max-tokens``.
        stream: When true, ``--stream`` is appended to the command.

    Returns:
        The captured standard output of the child process, or ``None`` when
        the child exits with a non-zero status.
    """
    print(f"Running styling inference with config: {config_path}")
    print(f"Instruction: {instruction}")
    print(f"Input text: {input_text}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    cmd = [
        # Use the interpreter running this script so the child shares the same
        # environment; a bare "python" may be missing or a different install.
        sys.executable,
        "pipelines/styling/inference.py",
        "--config", config_path,
        "--instruction", instruction,
        "--input-text", input_text,
        "--max-tokens", str(max_tokens),
    ]
    if stream:
        cmd.append("--stream")

    print(f"Running: {' '.join(cmd)}")

    # NOTE: output is captured, so even with --stream the generated text is
    # only printed here once the child process has finished.
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout


def run_batch_inference_example(config_path: str, input_file: str, output_file: str,
                                instruction: str, max_tokens: int = 128) -> bool:
    """Run inference once per non-empty line of ``input_file`` and save the results.

    Each successful result is written to ``output_file`` next to the input
    line that produced it, numbered by that line's position among the
    non-empty input lines. Inputs whose inference failed are left out of the
    output file and counted in a warning.

    Args:
        config_path: Path to the YAML configuration file.
        input_file: UTF-8 text file with one input text per line.
        output_file: Destination file for the input/output pairs.
        instruction: Style instruction applied to every text.
        max_tokens: Maximum new tokens per inference.

    Returns:
        ``False`` when ``input_file`` does not exist, ``True`` once the output
        file has been written.
    """
    print("=== Batch Inference Example ===")
    print(f"Config: {config_path}")
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Instruction: {instruction}")
    print(f"Max tokens: {max_tokens}")

    # Check if input file exists
    if not Path(input_file).exists():
        print(f"❌ Input file not found: {input_file}")
        return False

    # Read input texts
    with open(input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]

    print(f"Found {len(input_texts)} texts to process")

    # Process each text, keeping every result tied to the input that produced
    # it so a failed item cannot shift later outputs onto the wrong input.
    results = []
    failed = 0
    for i, text in enumerate(input_texts):
        print(f"\nProcessing text {i+1}/{len(input_texts)}: {text[:50]}...")
        result = run_inference_with_config(config_path, instruction, text, max_tokens)
        # None signals failure; an empty string is still a successful run.
        if result is not None:
            results.append((i, text, result))
        else:
            failed += 1
            print(f"❌ Failed to process text {i+1}")

    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        for i, input_text, result in results:
            f.write(f"Input {i+1}: {input_text}\n")
            f.write(f"Output {i+1}: {result}\n")
            f.write("-" * 50 + "\n")

    if failed:
        print(f"⚠️ {failed} of {len(input_texts)} texts failed and were skipped")
    print(f"✅ Batch inference completed! Results saved to: {output_file}")
    return True


def show_inference_features():
    """Print an overview of the styling inference pipeline's features."""
    print("=== Styling Inference Pipeline Features ===")
    print()

    print("1. **Model Support**:")
    print(" - Trained LoRA models from training pipeline")
    print(" - Automatic model loading from config")
    print(" - Native Unsloth inference optimization")
    print()

    print("2. **Inference Modes**:")
    print(" - Single text inference with instruction")
    print(" - Batch file processing")
    print(" - Streaming generation")
    print()

    print("3. **Generation Control**:")
    print(" - Configurable max tokens")
    print(" - Same alpaca prompt format as training")
    print(" - Automatic response extraction")
    print()

    print("4. **Usage Examples**:")
    print(" - Single inference: --instruction 'style instruction' --input-text 'your text'")
    print(" - Streaming: add --stream flag")
    print(" - Batch: use batch subcommand with input/output files")


def create_inference_example() -> bool:
    """Run one example inference using the formal style configuration.

    Returns:
        ``True`` when the example inference succeeds; ``False`` when
        ``configs/styling/formal.yaml`` is missing or the inference fails.
    """
    print("=== Inference Example: Formal Style Transfer ===")
    print()

    # Check if we have the required files
    config_path = "configs/styling/formal.yaml"
    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    # Example text
    example_text = "Hey, what's up? I'm gonna go grab some food later."
    print(f"Example text: {example_text}")
    print()

    # Run inference
    output = run_inference_with_config(
        config_path=config_path,
        instruction="Rewrite the following text in a formal style",
        input_text=example_text,
    )

    # Compare against None: an empty string is still a successful run.
    if output is not None:
        print("✅ Example inference completed successfully!")
        return True

    print("❌ Example inference failed!")
    return False


def main() -> int:
    """Parse command-line arguments and dispatch to the selected command.

    Returns:
        ``0`` when the selected command succeeds (or only help is shown),
        ``1`` when it reports failure.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True,
                              help="Path to YAML configuration file")
    infer_parser.add_argument("--instruction", type=str, required=True,
                              help="Style instruction")
    infer_parser.add_argument("--input-text", type=str, default="",
                              help="Input text to style")
    infer_parser.add_argument("--max-tokens", type=int, default=128,
                              help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true",
                              help="Enable streaming generation")

    # Batch inference command
    batch_parser = subparsers.add_parser("batch", help="Run batch inference")
    batch_parser.add_argument("--config", type=str, required=True,
                              help="Path to YAML configuration file")
    batch_parser.add_argument("--input-file", type=str, required=True,
                              help="Input file with texts to style")
    batch_parser.add_argument("--output-file", type=str, required=True,
                              help="Output file for styled texts")
    batch_parser.add_argument("--instruction", type=str, required=True,
                              help="Style instruction")
    batch_parser.add_argument("--max-tokens", type=int, default=128,
                              help="Maximum new tokens to generate")

    # Features command
    subparsers.add_parser("features", help="Show available features")

    # Example command
    subparsers.add_parser("example", help="Run example inference")

    args = parser.parse_args()

    succeeded = True
    if args.command == "infer":
        output = run_inference_with_config(
            args.config,
            args.instruction,
            args.input_text,
            args.max_tokens,
            args.stream,
        )
        succeeded = output is not None
    elif args.command == "batch":
        succeeded = run_batch_inference_example(
            args.config,
            args.input_file,
            args.output_file,
            args.instruction,
            args.max_tokens,
        )
    elif args.command == "features":
        show_inference_features()
    elif args.command == "example":
        succeeded = create_inference_example()
    else:
        parser.print_help()

    return 0 if succeeded else 1


if __name__ == "__main__":
    # Propagate failure to the shell so callers and CI can detect it.
    sys.exit(main())