Files
DS-LLM-TEMPLATE-FINETUNING/scripts/styling/inference.py
T
2025-08-13 21:17:01 +01:00

224 lines
7.4 KiB
Python

#!/usr/bin/env python3
"""
Styling Inference Script
Provides a command-line interface to run the styling inference pipeline
"""
import sys
import os
import subprocess
import argparse
from pathlib import Path
def run_inference_with_config(config_path: str, **cli_overrides) -> bool:
    """Run the styling inference pipeline with a YAML configuration.

    Launches ``pipelines/styling/inference.py`` (a path relative to the
    current working directory) in a child process. The child uses the same
    Python interpreter that is running this script. Recognised overrides
    are forwarded as command-line flags.

    Args:
        config_path: Path to the YAML configuration file, forwarded as
            ``--config``.
        **cli_overrides: Optional pipeline overrides.
            - ``model_path``, ``text``, ``input_file``, ``max_tokens``,
              ``temperature``, ``instruction`` and ``output_file`` are each
              forwarded as ``--<name-with-dashes> <value>`` when the value
              is not ``None``.
            - ``streaming`` is forwarded as the bare ``--streaming`` flag
              only when truthy.
            - Unrecognised keys are ignored.

    Returns:
        True if the pipeline process exited with status 0. False if it
        exited with a non-zero status or could not be started.
    """
    # Overrides that carry a value, mapped to the pipeline's flag names.
    value_flags = {
        "model_path": "--model-path",
        "text": "--text",
        "input_file": "--input-file",
        "max_tokens": "--max-tokens",
        "temperature": "--temperature",
        "instruction": "--instruction",
        "output_file": "--output-file",
    }

    print(f"🚀 Starting styling inference with config: {config_path}")
    print()

    # Use sys.executable rather than a bare "python" so the child runs in the
    # same interpreter / virtualenv as this script.
    cmd = [sys.executable, "pipelines/styling/inference.py", "--config", str(config_path)]
    for key, value in cli_overrides.items():
        if value is None:
            continue
        if key in value_flags:
            cmd.extend([value_flags[key], str(value)])
        elif key == "streaming" and value:
            # A store_true default of False must NOT turn streaming on.
            cmd.append("--streaming")

    streaming = bool(cli_overrides.get("streaming"))

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        # When streaming, let the child write straight to the terminal so
        # tokens appear as they are generated. Otherwise capture the output
        # and print it once the run finishes.
        result = subprocess.run(cmd, check=True, capture_output=not streaming, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        if e.stdout:
            print(e.stdout)
        if e.stderr is not None:
            print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # For example, the interpreter could not be launched at all.
        print(f"❌ Could not start inference process: {e}")
        return False

    print("✅ Inference completed successfully!")
    if result.stdout is not None:
        print(result.stdout)
    return True
def show_inference_features():
    """Print a human-readable overview of what the styling inference pipeline offers.

    Each numbered section is preceded by a blank line. No trailing blank line
    is printed after the final section.
    """
    sections = (
        ("1. **Model Support**:", (
            " - Trained LoRA models",
            " - Base models from HuggingFace Hub",
            " - Automatic model loading and preparation",
        )),
        ("2. **Inference Modes**:", (
            " - Single text inference",
            " - Batch file processing",
            " - Interactive mode",
            " - Streaming generation",
        )),
        ("3. **Generation Control**:", (
            " - Configurable temperature and top-p",
            " - Adjustable max tokens",
            " - Custom style instructions",
        )),
        ("4. **Output Options**:", (
            " - Console output",
            " - File output",
            " - Streaming real-time generation",
        )),
    )

    print("=== Styling Inference Pipeline Features ===")
    for heading, bullets in sections:
        print()
        print(heading)
        for bullet in bullets:
            print(bullet)
def create_inference_example():
    """Demonstrate single-sentence formal style transfer using the bundled config.

    Looks for ``configs/styling/formal.yaml`` relative to the current working
    directory.

    Returns:
        False when that configuration file is missing. Otherwise, whatever
        ``run_inference_with_config`` returns.
    """
    cfg = "configs/styling/formal.yaml"
    sample = "Hey, what's up? I'm gonna go grab some food later."

    print("=== Inference Example: Formal Style Transfer ===", end="\n\n")

    # Bail out early when the configuration has not been generated yet.
    if not Path(cfg).exists():
        print(f"❌ Configuration file not found: {cfg}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!", f" Config: {cfg}", sep="\n", end="\n\n")
    print(f"📝 Example text: {sample}", end="\n\n")

    ok = run_inference_with_config(
        config_path=cfg,
        text=sample,
        instruction="Rewrite the following text in a formal style",
    )
    if ok:
        print("🎉 Inference example completed!")
    return ok
def create_test_file():
    """Write a few casual sample sentences to ``test_texts.txt`` for batch inference.

    The file is created (or overwritten) in the current working directory,
    UTF-8 encoded, with one sentence per line and a trailing newline.

    Returns:
        The file name, ``"test_texts.txt"``.
    """
    target = "test_texts.txt"
    samples = (
        "Hey, what's up? How are you doing today?",
        "I'm gonna go to the store later to get some stuff.",
        "This is pretty cool, right?",
        "Can you help me out with this?",
        "Thanks a lot for your help!",
    )

    Path(target).write_text("".join(f"{line}\n" for line in samples), encoding="utf-8")

    print(f"✅ Created test file: {target}")
    print(f" Contains {len(samples)} sample texts")
    return target
def run_batch_inference_example():
    """Run formal-style batch inference over a freshly generated sample file.

    Checks for ``configs/styling/formal.yaml`` before doing anything else.
    Only when it exists is the sample input written via ``create_test_file``,
    so a missing config no longer leaves a stray ``test_texts.txt`` behind.
    The pipeline is asked to write its results to ``styled_results.txt``.

    Returns:
        True if the pipeline run succeeded. False if the configuration file
        is missing or the run failed.
    """
    config_path = "configs/styling/formal.yaml"
    output_file = "styled_results.txt"

    print("=== Batch Inference Example ===")
    print()

    # Validate prerequisites before creating any files on disk.
    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    test_file = create_test_file()

    print("✅ Running batch inference...")
    print()

    success = run_inference_with_config(
        config_path=config_path,
        input_file=test_file,
        output_file=output_file,
        instruction="Rewrite the following text in a formal style",
    )
    if success:
        print("🎉 Batch inference completed!")
        print(f" Results saved to: {output_file}")
    return success
def main(argv=None):
    """Parse command-line arguments and dispatch to the requested command.

    Commands:
        features  Print the feature overview.
        example   Run the single-text formal-style demo.
        batch     Run the batch formal-style demo.
        infer     Run the pipeline with ``--config`` plus ``--text`` or
                  ``--input-file`` and any optional overrides.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. Defaults to
            None, which means the real command line.

    Exits with status 1 when ``infer`` is missing required options, or when
    an ``example``/``batch``/``infer`` run reports failure.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Script")
    # Subcommands
    parser.add_argument("command", choices=["infer", "example", "batch", "features"],
                        help="Command to run")
    # Inference arguments (only consulted by the "infer" command)
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--model-path", type=str, help="Path to trained model")
    parser.add_argument("--text", type=str, help="Single text to style transfer")
    parser.add_argument("--input-file", type=str, help="File containing texts to process")
    parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")
    parser.add_argument("--instruction", type=str, help="Custom style instruction")
    parser.add_argument("--output-file", type=str, help="Output file for results")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
    args = parser.parse_args(argv)

    if args.command == "features":
        show_inference_features()
        return

    if args.command == "example":
        success = create_inference_example()
    elif args.command == "batch":
        success = run_batch_inference_example()
    else:  # "infer"; argparse's `choices` rules out any other value
        if not args.config:
            print("❌ --config is required for inference")
            print("Usage: python scripts/styling/inference.py infer --config config.yaml [options]")
            sys.exit(1)
        # Check if we have input
        if not args.text and not args.input_file:
            print("❌ Either --text or --input-file is required")
            print("Usage: python scripts/styling/inference.py infer --config config.yaml --text 'your text'")
            sys.exit(1)
        success = run_inference_with_config(
            config_path=args.config,
            model_path=args.model_path,
            text=args.text,
            input_file=args.input_file,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            instruction=args.instruction,
            output_file=args.output_file,
            # store_true defaults to False, not None. Pass None when the flag
            # is unset so the callee never forwards --streaming by accident.
            streaming=args.streaming or None,
        )

    # Report failures to the shell / CI instead of always exiting 0.
    if not success:
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: parse the command line and dispatch the sub-command.
    main()