224 lines
7.4 KiB
Python
224 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Styling Inference Script
|
|
Provides a command-line interface to run the styling inference pipeline
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import subprocess
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
def run_inference_with_config(config_path: str, **cli_overrides):
    """Run the styling inference pipeline with YAML configuration.

    Launches ``pipelines/styling/inference.py`` (resolved relative to the
    current working directory) in a child process, passing ``config_path``
    via ``--config`` and translating recognised keyword overrides into the
    matching command-line flags.

    Args:
        config_path: Path to the YAML configuration file.
        **cli_overrides: Optional overrides. ``model_path``, ``text``,
            ``input_file``, ``max_tokens``, ``temperature``, ``instruction``
            and ``output_file`` are forwarded as ``--flag <value>`` whenever
            the value is not ``None`` (so ``0`` / ``0.0`` are still
            forwarded). ``streaming`` is forwarded as a bare ``--streaming``
            flag only when truthy. Unrecognised keys are ignored.

    Returns:
        True if the child process exits with status 0, False if it exits
        with a non-zero status.
    """
    print(f"🚀 Starting styling inference with config: {config_path}")
    print()

    # Overrides that take a value, mapped to the flag they are forwarded as.
    value_flags = {
        "model_path": "--model-path",
        "text": "--text",
        "input_file": "--input-file",
        "max_tokens": "--max-tokens",
        "temperature": "--temperature",
        "instruction": "--instruction",
        "output_file": "--output-file",
    }

    # Run the pipeline with the interpreter executing this script so the
    # child shares the same environment; a bare "python" may resolve to a
    # different installation or not exist on PATH at all.
    cmd = [
        sys.executable or "python",
        "pipelines/styling/inference.py",
        "--config",
        config_path,
    ]

    for key, value in cli_overrides.items():
        if value is None:
            continue
        if key == "streaming":
            # Boolean switch: argparse's store_true yields False (not None)
            # when the flag is absent, so test truthiness rather than None.
            if value:
                cmd.append("--streaming")
        elif key in value_flags:
            cmd.extend([value_flags[key], str(value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    # NOTE: output is captured, so anything the child prints only becomes
    # visible here after the child process has exited.
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("✅ Inference completed successfully!")
    print(result.stdout)
    return True
|
|
|
|
def show_inference_features():
    """Print an overview of what the styling inference pipeline offers.

    The overview is a title followed by four numbered sections, each
    preceded by a blank line and followed by its bullet points. Nothing is
    returned.
    """
    feature_sections = (
        (
            "1. **Model Support**:",
            (
                " - Trained LoRA models",
                " - Base models from HuggingFace Hub",
                " - Automatic model loading and preparation",
            ),
        ),
        (
            "2. **Inference Modes**:",
            (
                " - Single text inference",
                " - Batch file processing",
                " - Interactive mode",
                " - Streaming generation",
            ),
        ),
        (
            "3. **Generation Control**:",
            (
                " - Configurable temperature and top-p",
                " - Adjustable max tokens",
                " - Custom style instructions",
            ),
        ),
        (
            "4. **Output Options**:",
            (
                " - Console output",
                " - File output",
                " - Streaming real-time generation",
            ),
        ),
    )

    print("=== Styling Inference Pipeline Features ===")
    for heading, bullets in feature_sections:
        # A blank line separates the title / previous section from this one.
        print()
        print(heading)
        for bullet in bullets:
            print(bullet)
|
|
|
|
def create_inference_example():
    """Run a single-text demo using the formal style configuration.

    Looks for ``configs/styling/formal.yaml`` relative to the current
    working directory. When it is missing, a hint is printed and the
    pipeline is not invoked.

    Returns:
        False when the configuration file is missing; otherwise the result
        of ``run_inference_with_config``.
    """
    formal_config = "configs/styling/formal.yaml"
    sample_sentence = "Hey, what's up? I'm gonna go grab some food later."

    print("=== Inference Example: Formal Style Transfer ===")
    print()

    if Path(formal_config).exists():
        print("✅ Found configuration file!")
        print(f" Config: {formal_config}")
        print()

        print(f"📝 Example text: {sample_sentence}")
        print()

        # Hand the sample sentence to the pipeline with a formal-style prompt.
        outcome = run_inference_with_config(
            config_path=formal_config,
            text=sample_sentence,
            instruction="Rewrite the following text in a formal style",
        )
        if outcome:
            print("🎉 Inference example completed!")
        return outcome

    # Without the configuration there is nothing to run.
    print(f"❌ Configuration file not found: {formal_config}")
    print(" Please run the data processor first to create the configuration")
    return False
|
|
|
|
def create_test_file():
    """Write a small set of casual sentences to ``test_texts.txt``.

    The file is created (or overwritten) in the current working directory,
    UTF-8 encoded, with one sentence per line.

    Returns:
        The name of the file that was written.
    """
    samples = (
        "Hey, what's up? How are you doing today?",
        "I'm gonna go to the store later to get some stuff.",
        "This is pretty cool, right?",
        "Can you help me out with this?",
        "Thanks a lot for your help!",
    )
    test_file = "test_texts.txt"

    # One sentence per line, each terminated by a newline.
    payload = "".join(f"{sentence}\n" for sentence in samples)
    Path(test_file).write_text(payload, encoding="utf-8")

    print(f"✅ Created test file: {test_file}")
    print(f" Contains {len(samples)} sample texts")
    return test_file
|
|
|
|
def run_batch_inference_example():
    """Run a batch inference example.

    Verifies that ``configs/styling/formal.yaml`` exists (relative to the
    current working directory), then writes the sample input file via
    ``create_test_file`` and runs the pipeline over it, asking for results
    to be written to ``styled_results.txt``.

    The configuration is checked before the sample file is written so that
    a missing configuration does not leave a stray ``test_texts.txt``
    behind.

    Returns:
        False when the configuration file is missing; otherwise the result
        of ``run_inference_with_config``.
    """
    print("=== Batch Inference Example ===")
    print()

    # Check configuration first: nothing should be written to disk if the
    # example cannot run.
    config_path = "configs/styling/formal.yaml"
    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        return False

    # Create test file
    test_file = create_test_file()

    print("✅ Running batch inference...")
    print()

    # Run batch inference
    success = run_inference_with_config(
        config_path=config_path,
        input_file=test_file,
        output_file="styled_results.txt",
        instruction="Rewrite the following text in a formal style",
    )

    if success:
        print("🎉 Batch inference completed!")
        print(" Results saved to: styled_results.txt")

    return success
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and dispatch to the requested command.

    Commands:
        features: print the pipeline feature overview.
        example:  run the single-text formal-style example.
        batch:    run the batch example.
        infer:    run the pipeline with ``--config`` and either ``--text``
                  or ``--input-file``, plus any optional overrides.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) makes argparse read the process arguments.

    Raises:
        SystemExit: with status 1 when required ``infer`` arguments are
            missing or when ``example`` / ``batch`` / ``infer`` report
            failure; with argparse's own status on a parsing error.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Script")

    # Command selector (a positional argument restricted to known commands)
    parser.add_argument("command", choices=["infer", "example", "batch", "features"],
                        help="Command to run")

    # Inference arguments
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--model-path", type=str, help="Path to trained model")
    parser.add_argument("--text", type=str, help="Single text to style transfer")
    parser.add_argument("--input-file", type=str, help="File containing texts to process")
    parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")
    parser.add_argument("--instruction", type=str, help="Custom style instruction")
    parser.add_argument("--output-file", type=str, help="Output file for results")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")

    args = parser.parse_args(argv)

    if args.command == "features":
        show_inference_features()
        return

    if args.command == "example":
        success = create_inference_example()

    elif args.command == "batch":
        success = run_batch_inference_example()

    else:  # "infer" -- the only remaining choice argparse accepts
        if not args.config:
            print("❌ --config is required for inference")
            print("Usage: python scripts/styling/inference.py infer --config config.yaml [options]")
            sys.exit(1)

        # Check if we have input
        if not args.text and not args.input_file:
            print("❌ Either --text or --input-file is required")
            print("Usage: python scripts/styling/inference.py infer --config config.yaml --text 'your text'")
            sys.exit(1)

        success = run_inference_with_config(
            config_path=args.config,
            model_path=args.model_path,
            text=args.text,
            input_file=args.input_file,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            instruction=args.instruction,
            output_file=args.output_file,
            # store_true gives False when the flag is absent; pass None so
            # an unset flag is not forwarded as an override.
            streaming=True if args.streaming else None,
        )

    # Every command that can fail reports it through the exit status.
    if not success:
        sys.exit(1)
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|