updated style mimicking fine tuning
This commit is contained in:
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Inference Script
|
||||
Provides a command-line interface to run the styling inference pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_inference_with_config(config_path: str, instruction: str, input_text: str = "", max_tokens: int = 128, stream: bool = False):
    """Run the styling inference script in a child process.

    Args:
        config_path: Path to the YAML configuration file (passed as ``--config``).
        instruction: Style instruction (passed as ``--instruction``).
        input_text: Text to style (passed as ``--input-text``).
        max_tokens: Passed, converted to a string, as ``--max-tokens``.
        stream: When true, ``--stream`` is appended to the command line.

    Returns:
        The captured standard output of the child process on success, or
        ``None`` if the child exited with a non-zero status.
    """
    print(f"Running styling inference with config: {config_path}")
    print(f"Instruction: {instruction}")
    print(f"Input text: {input_text}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    # Launch the child with the interpreter that is running this script so it
    # shares the same environment; a bare "python" may be missing from PATH or
    # point at a different installation.
    interpreter = sys.executable or "python"

    cmd = [
        interpreter, "pipelines/styling/inference.py",
        "--config", config_path,
        "--instruction", instruction,
        "--input-text", input_text,
        "--max-tokens", str(max_tokens),
    ]

    if stream:
        cmd.append("--stream")

    print(f"Running: {' '.join(cmd)}")

    try:
        # NOTE: stdout/stderr are captured, so even with --stream the child's
        # output is only printed here after the child has exited.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout
|
||||
|
||||
def run_batch_inference_example(config_path: str, input_file: str, output_file: str, instruction: str, max_tokens: int = 128):
    """Style every non-blank line of a text file and save the results.

    Each non-blank, stripped line of ``input_file`` is passed to
    ``run_inference_with_config`` together with ``instruction``. Successful
    outputs are written to ``output_file`` next to the input line that
    produced them; failed items are reported on stdout and left out of the
    output file.

    Args:
        config_path: Path to the YAML configuration file.
        input_file: UTF-8 text file with one input text per line.
        output_file: Destination file; overwritten if it exists.
        instruction: Style instruction applied to every text.
        max_tokens: Forwarded to ``run_inference_with_config``.

    Returns:
        ``False`` if ``input_file`` does not exist, otherwise ``True`` once
        the output file has been written (even if some items failed).
    """
    print("=== Batch Inference Example ===")
    print(f"Config: {config_path}")
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Instruction: {instruction}")
    print(f"Max tokens: {max_tokens}")

    # Check if input file exists
    if not Path(input_file).exists():
        print(f"❌ Input file not found: {input_file}")
        return False

    # Read input texts, ignoring blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]

    total = len(input_texts)
    print(f"Found {total} texts to process")

    # Process each text. Every output is stored together with its own input
    # and 1-based position, so a failed item cannot shift later outputs onto
    # the wrong input when the results are written.
    completed = []
    for position, text in enumerate(input_texts, start=1):
        print(f"\nProcessing text {position}/{total}: {text[:50]}...")
        output = run_inference_with_config(config_path, instruction, text, max_tokens)
        # The helper returns None on failure; an empty string is still a
        # successful (if empty) generation.
        if output is None:
            print(f"❌ Failed to process text {position}")
            continue
        completed.append((position, text, output))

    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        for position, text, output in completed:
            f.write(f"Input {position}: {text}\n")
            f.write(f"Output {position}: {output}\n")
            f.write("-" * 50 + "\n")

    print(f"✅ Batch inference completed! Results saved to: {output_file}")
    return True
|
||||
|
||||
def show_inference_features():
    """Print a static overview of the styling inference pipeline features."""
    overview = (
        "=== Styling Inference Pipeline Features ===",
        "",
        "1. **Model Support**:",
        " - Trained LoRA models from training pipeline",
        " - Automatic model loading from config",
        " - Native Unsloth inference optimization",
        "",
        "2. **Inference Modes**:",
        " - Single text inference with instruction",
        " - Batch file processing",
        " - Streaming generation",
        "",
        "3. **Generation Control**:",
        " - Configurable max tokens",
        " - Same alpaca prompt format as training",
        " - Automatic response extraction",
        "",
        "4. **Usage Examples**:",
        " - Single inference: --instruction 'style instruction' --input-text 'your text'",
        " - Streaming: add --stream flag",
        " - Batch: use batch subcommand with input/output files",
    )
    # One print call per entry; an empty entry yields a blank separator line.
    for entry in overview:
        print(entry)
|
||||
|
||||
def create_inference_example():
    """Run one demonstration inference with the formal-style configuration.

    Looks for ``configs/styling/formal.yaml`` relative to the current working
    directory and, if present, asks ``run_inference_with_config`` to rewrite a
    fixed casual sentence in a formal style.

    Returns:
        ``True`` if the inference call succeeded, ``False`` if the
        configuration file is missing or the inference call failed.
    """
    print("=== Inference Example: Formal Style Transfer ===")
    print()

    # Check if we have the required files
    config_path = "configs/styling/formal.yaml"

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    # Example text
    example_text = "Hey, what's up? I'm gonna go grab some food later."

    print(f"Example text: {example_text}")
    print()

    # Run inference
    output = run_inference_with_config(
        config_path=config_path,
        instruction="Rewrite the following text in a formal style",
        input_text=example_text,
    )

    # The helper returns the generated text, or None on failure. Compare with
    # None explicitly so an empty-but-successful output is not misreported as
    # a failure.
    if output is None:
        print("❌ Example inference failed!")
        return False

    print("✅ Example inference completed successfully!")
    return True
|
||||
|
||||
def main():
    """Parse the command line and run the selected sub-command.

    Sub-commands: ``infer`` (single inference), ``batch`` (file-based batch
    inference), ``features`` (print the feature overview) and ``example``
    (run the built-in example). With no sub-command the help text is printed.
    """
    cli = argparse.ArgumentParser(description="Styling Inference Pipeline")
    commands = cli.add_subparsers(dest="command", help="Available commands")

    # Flags for each sub-command, registered in the listed order.
    infer_flags = (
        ("--config", {"type": str, "required": True, "help": "Path to YAML configuration file"}),
        ("--instruction", {"type": str, "required": True, "help": "Style instruction"}),
        ("--input-text", {"type": str, "default": "", "help": "Input text to style"}),
        ("--max-tokens", {"type": int, "default": 128, "help": "Maximum new tokens to generate"}),
        ("--stream", {"action": "store_true", "help": "Enable streaming generation"}),
    )
    batch_flags = (
        ("--config", {"type": str, "required": True, "help": "Path to YAML configuration file"}),
        ("--input-file", {"type": str, "required": True, "help": "Input file with texts to style"}),
        ("--output-file", {"type": str, "required": True, "help": "Output file for styled texts"}),
        ("--instruction", {"type": str, "required": True, "help": "Style instruction"}),
        ("--max-tokens", {"type": int, "default": 128, "help": "Maximum new tokens to generate"}),
    )

    # Single-inference sub-command
    single = commands.add_parser("infer", help="Run single inference")
    for flag, settings in infer_flags:
        single.add_argument(flag, **settings)

    # Batch sub-command
    batch = commands.add_parser("batch", help="Run batch inference")
    for flag, settings in batch_flags:
        batch.add_argument(flag, **settings)

    # Argument-less sub-commands
    commands.add_parser("features", help="Show available features")
    commands.add_parser("example", help="Run example inference")

    opts = cli.parse_args()
    chosen = opts.command

    if chosen == "infer":
        run_inference_with_config(
            opts.config,
            opts.instruction,
            opts.input_text,
            opts.max_tokens,
            opts.stream,
        )
        return

    if chosen == "batch":
        run_batch_inference_example(
            opts.config,
            opts.input_file,
            opts.output_file,
            opts.instruction,
            opts.max_tokens,
        )
        return

    if chosen == "features":
        show_inference_features()
        return

    if chosen == "example":
        create_inference_example()
        return

    # No sub-command given: show usage.
    cli.print_help()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user