Files
DS-LLM-TEMPLATE-FINETUNING/scripts/styling/.ipynb_checkpoints/inference-checkpoint.py
T
2025-08-13 23:50:20 +00:00

195 lines
7.1 KiB
Python

#!/usr/bin/env python3
"""
Styling Inference Script
Provides a command-line interface to run the styling inference pipeline
"""
import argparse
import os
import shlex
import subprocess
import sys
import tempfile
from pathlib import Path
def _stream_subprocess(cmd):
    """Run *cmd*, echoing its stdout line by line as it arrives; return all of it.

    stderr is spooled to a temporary file rather than a second pipe, so a
    child that writes a lot of stderr cannot deadlock us by filling an
    unread pipe buffer.

    Raises:
        subprocess.CalledProcessError: if the child exits non-zero; its
            ``stderr`` attribute holds the captured error output.
        OSError: if the process cannot be launched.
    """
    # A Python child block-buffers stdout when it is a pipe; without this
    # nothing would be shown until the child exits, defeating streaming.
    env = {**os.environ, "PYTHONUNBUFFERED": "1"}
    chunks = []
    with tempfile.TemporaryFile(mode="w+b") as err_file:
        with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=err_file, text=True, env=env
        ) as proc:
            for line in proc.stdout:
                print(line, end="", flush=True)
                chunks.append(line)
        # Popen.__exit__ has waited for the child, so returncode is set here.
        err_file.seek(0)
        stderr = err_file.read().decode("utf-8", errors="replace")
    stdout = "".join(chunks)
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd, output=stdout, stderr=stderr)
    return stdout


def run_inference_with_config(config_path: str, instruction: str, input_text: str = "", max_tokens: int = 128, stream: bool = False):
    """Run inference using a YAML configuration file.

    Launches ``pipelines/styling/inference.py`` (a path resolved relative to
    the current working directory) as a subprocess of the interpreter that
    is running this script.

    Args:
        config_path: YAML configuration file, passed as ``--config``.
        instruction: Style instruction, passed as ``--instruction``.
        input_text: Text to restyle, passed as ``--input-text``.
        max_tokens: Maximum new tokens, passed as ``--max-tokens``.
        stream: If True, also pass ``--stream`` and echo the child's stdout
            live instead of printing it only once the child has finished.

    Returns:
        The child's stdout (possibly an empty string) on success, or None if
        the process could not be started or exited with a non-zero status.
    """
    print(f"Running styling inference with config: {config_path}")
    print(f"Instruction: {instruction}")
    print(f"Input text: {input_text}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")
    cmd = [
        # Use the interpreter running this script. A bare "python" may not
        # exist (this file's shebang targets python3) or may belong to a
        # different environment that lacks the project's dependencies.
        sys.executable, "pipelines/styling/inference.py",
        "--config", config_path,
        "--instruction", instruction,
        "--input-text", input_text,
        "--max-tokens", str(max_tokens),
    ]
    if stream:
        cmd.append("--stream")
    # Quote each argument so the echoed command can be pasted into a shell.
    print(f"Running: {' '.join(shlex.quote(part) for part in cmd)}")
    try:
        if stream:
            print("Output:")
            stdout = _stream_subprocess(cmd)
        else:
            stdout = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None
    except OSError as e:
        # The interpreter or script could not be launched at all.
        print(f"❌ Inference failed: {e}")
        return None
    print("✅ Inference completed successfully!")
    if not stream:
        # In streaming mode the output has already been echoed live.
        print("Output:")
        print(stdout)
    return stdout
def run_batch_inference_example(config_path: str, input_file: str, output_file: str, instruction: str, max_tokens: int = 128):
    """Run batch inference over every non-blank line of *input_file*.

    Each stripped, non-empty line is passed to ``run_inference_with_config``
    as a separate subprocess call. Successful results are written to
    *output_file* as ``Input N`` / ``Output N`` pairs followed by a separator
    line. N is the text's 1-based position among the non-blank input lines,
    so a failed text does not shift later outputs onto the wrong inputs.

    Returns:
        False if *input_file* does not exist; otherwise True, even if some
        individual texts failed (failures are reported on stdout and omitted
        from the output file).
    """
    print("=== Batch Inference Example ===")
    print(f"Config: {config_path}")
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Instruction: {instruction}")
    print(f"Max tokens: {max_tokens}")
    # Check if input file exists
    if not Path(input_file).exists():
        print(f"❌ Input file not found: {input_file}")
        return False
    # Read input texts, dropping blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]
    total = len(input_texts)
    print(f"Found {total} texts to process")
    # Keep each output paired with its own input. Zipping all inputs against
    # only the successful outputs would misalign them after any failure.
    results = []
    for position, text in enumerate(input_texts, start=1):
        print(f"\nProcessing text {position}/{total}: {text[:50]}...")
        result = run_inference_with_config(config_path, instruction, text, max_tokens)
        # None signals failure; an empty string is a successful (empty) output.
        if result is None:
            print(f"❌ Failed to process text {position}")
        else:
            results.append((position, text, result))
    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        for position, text, result in results:
            f.write(f"Input {position}: {text}\n")
            f.write(f"Output {position}: {result}\n")
            f.write("-" * 50 + "\n")
    failed = total - len(results)
    if failed:
        print(f"⚠️ {failed} of {total} texts failed and were omitted from the output")
    print(f"✅ Batch inference completed! Results saved to: {output_file}")
    return True
def show_inference_features():
    """Print a fixed, numbered overview of the styling inference pipeline.

    Purely informational: writes a title line followed by four sections to
    stdout, each section preceded by a blank line. Returns None.
    """
    feature_sections = (
        ("1. **Model Support**:", (
            " - Trained LoRA models from training pipeline",
            " - Automatic model loading from config",
            " - Native Unsloth inference optimization",
        )),
        ("2. **Inference Modes**:", (
            " - Single text inference with instruction",
            " - Batch file processing",
            " - Streaming generation",
        )),
        ("3. **Generation Control**:", (
            " - Configurable max tokens",
            " - Same alpaca prompt format as training",
            " - Automatic response extraction",
        )),
        ("4. **Usage Examples**:", (
            " - Single inference: --instruction 'style instruction' --input-text 'your text'",
            " - Streaming: add --stream flag",
            " - Batch: use batch subcommand with input/output files",
        )),
    )
    print("=== Styling Inference Pipeline Features ===")
    for heading, bullet_lines in feature_sections:
        # A blank separator line comes before every section, not after.
        print()
        print(heading)
        for bullet in bullet_lines:
            print(bullet)
def create_inference_example():
    """Create an inference example using the formal style configuration.

    Looks for ``configs/styling/formal.yaml`` relative to the current working
    directory and, if present, restyles a fixed casual sentence with the
    instruction "Rewrite the following text in a formal style".

    Returns:
        False if the config file is missing or inference failed; True if
        inference succeeded, even if it produced empty output.
    """
    print("=== Inference Example: Formal Style Transfer ===")
    print()
    # Check if we have the required files
    config_path = "configs/styling/formal.yaml"
    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False
    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()
    # Example text
    example_text = "Hey, what's up? I'm gonna go grab some food later."
    print(f"Example text: {example_text}")
    print()
    # Run inference
    output = run_inference_with_config(
        config_path=config_path,
        instruction="Rewrite the following text in a formal style",
        input_text=example_text,
    )
    # run_inference_with_config returns None only on failure. An empty stdout
    # string is still a success, so do not test truthiness here.
    if output is not None:
        print("✅ Example inference completed successfully!")
        return True
    print("❌ Example inference failed!")
    return False
def main(argv=None):
    """Main inference function: parse CLI arguments and run a subcommand.

    Subcommands are ``infer`` (single text), ``batch`` (file of texts),
    ``features`` (print a capability summary) and ``example`` (formal-style
    demo). With no subcommand, the help text is printed.

    Args:
        argv: Argument list to parse; None (the default) means ``sys.argv[1:]``.

    Returns None on success. If the selected command reports failure, the
    process exits with status 1 so shell scripts and CI can detect it.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    infer_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    infer_parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
    # Batch inference command
    batch_parser = subparsers.add_parser("batch", help="Run batch inference")
    batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    batch_parser.add_argument("--input-file", type=str, required=True, help="Input file with texts to style")
    batch_parser.add_argument("--output-file", type=str, required=True, help="Output file for styled texts")
    batch_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    batch_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    # Features command
    subparsers.add_parser("features", help="Show available features")
    # Example command
    subparsers.add_parser("example", help="Run example inference")
    args = parser.parse_args(argv)
    if args.command == "infer":
        # run_inference_with_config signals failure by returning None.
        ok = run_inference_with_config(
            args.config,
            args.instruction,
            args.input_text,
            args.max_tokens,
            args.stream,
        ) is not None
    elif args.command == "batch":
        ok = run_batch_inference_example(args.config, args.input_file, args.output_file, args.instruction, args.max_tokens)
    elif args.command == "features":
        show_inference_features()
        ok = True
    elif args.command == "example":
        ok = create_inference_example()
    else:
        parser.print_help()
        ok = True
    if not ok:
        # Propagate failure to the caller's shell instead of exiting 0.
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()