195 lines
7.1 KiB
Python
195 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
Styling Inference Script

Provides a command-line interface to run the styling inference pipeline.
"""
|
|
|
|
import sys
|
|
import os
|
|
import subprocess
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
def run_inference_with_config(config_path: str, instruction: str, input_text: str = "", max_tokens: int = 128, stream: bool = False):
    """Run ``pipelines/styling/inference.py`` as a subprocess and return its stdout.

    Args:
        config_path: Path to the YAML configuration file (forwarded as ``--config``).
        instruction: Style instruction (forwarded as ``--instruction``).
        input_text: Text to restyle (forwarded as ``--input-text``).
        max_tokens: Forwarded as ``--max-tokens``.
        stream: When true, ``--stream`` is appended to the child command.

    Returns:
        The child's captured stdout on success, or ``None`` when the child
        could not be started or exited with a non-zero status.

    Note:
        The child's output is captured, so it is printed only after the child
        has exited, even when ``stream`` is set.
    """
    print(f"Running styling inference with config: {config_path}")
    print(f"Instruction: {instruction}")
    print(f"Input text: {input_text}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    cmd = [
        # Reuse the interpreter running this script: a bare "python" may be
        # absent from PATH or resolve to a different environment.
        sys.executable or "python", "pipelines/styling/inference.py",
        "--config", config_path,
        "--instruction", instruction,
        "--input-text", input_text,
        "--max-tokens", str(max_tokens),
    ]

    if stream:
        cmd.append("--stream")

    print(f"Running: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None
    except OSError as e:
        # The interpreter itself could not be launched (missing, not executable, ...).
        print(f"❌ Inference failed: {e}")
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout
|
|
|
|
def run_batch_inference_example(config_path: str, input_file: str, output_file: str, instruction: str, max_tokens: int = 128):
    """Run inference on every non-blank line of a file and save the results.

    Args:
        config_path: Path to the YAML configuration file.
        input_file: UTF-8 text file; each non-blank line (stripped) is one input.
        output_file: File that receives the input/output pairs (overwritten).
        instruction: Style instruction applied to every input.
        max_tokens: Forwarded to each inference call.

    Returns:
        ``False`` if ``input_file`` does not exist; otherwise ``True`` once
        the output file has been written. Inputs whose inference failed are
        reported on stdout and left out of the output file.
    """
    print("=== Batch Inference Example ===")
    print(f"Config: {config_path}")
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Instruction: {instruction}")
    print(f"Max tokens: {max_tokens}")

    # Check if input file exists
    if not Path(input_file).exists():
        print(f"❌ Input file not found: {input_file}")
        return False

    # Read input texts, skipping blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        input_texts = [text for text in stripped_lines if text]

    total = len(input_texts)
    print(f"Found {total} texts to process")

    # Keep each output together with its own input and 1-based input number,
    # so a failed item cannot shift later outputs onto the wrong input.
    completed = []
    for number, text in enumerate(input_texts, start=1):
        print(f"\nProcessing text {number}/{total}: {text[:50]}...")
        result = run_inference_with_config(config_path, instruction, text, max_tokens)
        # None is the failure signal; an empty string is still a valid output.
        if result is None:
            print(f"❌ Failed to process text {number}")
            continue
        completed.append((number, text, result))

    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        for number, text, result in completed:
            f.write(f"Input {number}: {text}\n")
            f.write(f"Output {number}: {result}\n")
            f.write("-" * 50 + "\n")

    print(f"✅ Batch inference completed! Results saved to: {output_file}")
    return True
|
|
|
|
def show_inference_features():
    """Print an overview of the styling inference pipeline's features to stdout."""
    # Each entry is a section heading followed by its bullet lines.
    sections = (
        ("1. **Model Support**:", (
            " - Trained LoRA models from training pipeline",
            " - Automatic model loading from config",
            " - Native Unsloth inference optimization",
        )),
        ("2. **Inference Modes**:", (
            " - Single text inference with instruction",
            " - Batch file processing",
            " - Streaming generation",
        )),
        ("3. **Generation Control**:", (
            " - Configurable max tokens",
            " - Same alpaca prompt format as training",
            " - Automatic response extraction",
        )),
        ("4. **Usage Examples**:", (
            " - Single inference: --instruction 'style instruction' --input-text 'your text'",
            " - Streaming: add --stream flag",
            " - Batch: use batch subcommand with input/output files",
        )),
    )

    print("=== Styling Inference Pipeline Features ===")
    for heading, bullets in sections:
        # A blank line separates the title and each section from what precedes it.
        print()
        print(heading)
        for bullet in bullets:
            print(bullet)
|
|
|
|
def create_inference_example():
    """Run one sample inference using the formal-style configuration.

    Looks for ``configs/styling/formal.yaml`` relative to the current working
    directory and, if present, runs a single inference on a fixed example
    sentence.

    Returns:
        ``True`` if the inference call succeeded, ``False`` if the
        configuration file is missing or the inference failed.
    """
    print("=== Inference Example: Formal Style Transfer ===")
    print()

    # Check if we have the required files
    config_path = "configs/styling/formal.yaml"
    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    # Example text
    example_text = "Hey, what's up? I'm gonna go grab some food later."
    print(f"Example text: {example_text}")
    print()

    # Run inference
    output = run_inference_with_config(
        config_path=config_path,
        instruction="Rewrite the following text in a formal style",
        input_text=example_text,
    )

    # Failure is signalled by None; an empty string is still a successful run.
    if output is None:
        print("❌ Example inference failed!")
        return False

    print("✅ Example inference completed successfully!")
    return True
|
|
|
|
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An integer status: ``0`` when the selected command succeeded, or when
        no command was given and the help text was printed; ``1`` when the
        command reported a failure.

    Raises:
        SystemExit: Raised by argparse for invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    infer_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    infer_parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Batch inference command
    batch_parser = subparsers.add_parser("batch", help="Run batch inference")
    batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    batch_parser.add_argument("--input-file", type=str, required=True, help="Input file with texts to style")
    batch_parser.add_argument("--output-file", type=str, required=True, help="Output file for styled texts")
    batch_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    batch_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")

    # Features command
    subparsers.add_parser("features", help="Show available features")

    # Example command
    subparsers.add_parser("example", help="Run example inference")

    args = parser.parse_args(argv)

    if args.command == "infer":
        output = run_inference_with_config(
            args.config,
            args.instruction,
            args.input_text,
            args.max_tokens,
            args.stream,
        )
        # The inference helper returns None on failure.
        return 0 if output is not None else 1

    if args.command == "batch":
        succeeded = run_batch_inference_example(
            args.config, args.input_file, args.output_file, args.instruction, args.max_tokens
        )
        return 0 if succeeded else 1

    if args.command == "features":
        show_inference_features()
        return 0

    if args.command == "example":
        return 0 if create_inference_example() else 1

    # No sub-command given: show usage.
    parser.print_help()
    return 0
|
|
|
|
if __name__ == "__main__":
    # Pass main()'s return value to sys.exit so a non-zero status reaches the
    # calling shell; a None return exits with status 0.
    sys.exit(main())
|