updated style mimicking fine tuning
This commit is contained in:
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Inference Script
|
||||
Provides a command-line interface to run the styling inference pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_inference_with_config(config_path: str, instruction: str, input_text: str = "", max_tokens: int = 128, stream: bool = False):
    """Run the styling inference pipeline script in a child process.

    Builds a command line for ``pipelines/styling/inference.py`` (a path
    relative to the current working directory), runs it to completion and
    echoes the captured standard output.

    Args:
        config_path: Path to the YAML configuration file (``--config``).
        instruction: Style instruction (``--instruction``).
        input_text: Text to restyle (``--input-text``); may be empty.
        max_tokens: Value forwarded as ``--max-tokens``.
        stream: When true, ``--stream`` is appended to the command.

    Returns:
        The child's captured stdout as a string when it exits with status 0,
        or ``None`` when it exits with a non-zero status.
    """
    import shlex  # local import: only needed to render the command for display

    print(f"Running styling inference with config: {config_path}")
    print(f"Instruction: {instruction}")
    print(f"Input text: {input_text}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    # Launch the pipeline with the interpreter that is running this script so
    # the child sees the same environment/virtualenv; a bare "python" may be
    # missing from PATH or resolve to a different installation.
    interpreter = sys.executable or "python"
    cmd = [
        interpreter, "pipelines/styling/inference.py",
        "--config", config_path,
        "--instruction", instruction,
        "--input-text", input_text,
        "--max-tokens", str(max_tokens),
    ]

    if stream:
        cmd.append("--stream")

    # Quote every argument so multi-word values (the instruction, the input
    # text) remain unambiguous and the echoed line can be pasted into a shell.
    rendered = " ".join(shlex.quote(part) for part in cmd)
    print(f"Running: {rendered}")

    try:
        # NOTE(review): output is captured, so even with --stream the child's
        # text is only printed here after the process has exited.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout
|
||||
|
||||
def run_batch_inference_example(config_path: str, input_file: str, output_file: str, instruction: str, max_tokens: int = 128):
    """Style every non-blank line of a text file and write the results.

    Each non-blank line of ``input_file`` (stripped of surrounding
    whitespace) is passed to ``run_inference_with_config`` together with the
    shared ``instruction``. Successful outputs are written to ``output_file``
    as ``Input N`` / ``Output N`` pairs separated by a dashed rule, where
    ``N`` is the 1-based position of the text among the non-blank input lines.

    Args:
        config_path: Path to the YAML configuration file.
        input_file: UTF-8 text file with one text per line.
        output_file: Destination file; overwritten if it exists.
        instruction: Style instruction applied to every text.
        max_tokens: Forwarded to each inference call.

    Returns:
        ``False`` when ``input_file`` does not exist (nothing is written),
        otherwise ``True`` once the output file has been written.
    """
    print("=== Batch Inference Example ===")
    print(f"Config: {config_path}")
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Instruction: {instruction}")
    print(f"Max tokens: {max_tokens}")

    # Check if input file exists
    if not Path(input_file).exists():
        print(f"❌ Input file not found: {input_file}")
        return False

    # Read input texts, ignoring blank lines
    with open(input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]

    total = len(input_texts)
    print(f"Found {total} texts to process")

    # Keep every output together with its own input and original position.
    # Collecting outputs in a separate list and zipping it with the inputs
    # afterwards would shift all later pairs after the first failure.
    completed = []
    for number, text in enumerate(input_texts, start=1):
        print(f"\nProcessing text {number}/{total}: {text[:50]}...")
        result = run_inference_with_config(config_path, instruction, text, max_tokens)
        if not result:
            print(f"❌ Failed to process text {number}")
            continue
        completed.append((number, text, result))

    # Save results
    with open(output_file, 'w', encoding='utf-8') as f:
        for number, text, result in completed:
            f.write(f"Input {number}: {text}\n")
            f.write(f"Output {number}: {result}\n")
            f.write("-" * 50 + "\n")

    failed = total - len(completed)
    if failed:
        print(f"⚠️ {failed} of {total} texts failed and were left out of the results")

    print(f"✅ Batch inference completed! Results saved to: {output_file}")
    return True
|
||||
|
||||
def show_inference_features():
    """Print a fixed, human-readable overview of the inference pipeline."""
    overview = (
        "=== Styling Inference Pipeline Features ===",
        "",
        "1. **Model Support**:",
        " - Trained LoRA models from training pipeline",
        " - Automatic model loading from config",
        " - Native Unsloth inference optimization",
        "",
        "2. **Inference Modes**:",
        " - Single text inference with instruction",
        " - Batch file processing",
        " - Streaming generation",
        "",
        "3. **Generation Control**:",
        " - Configurable max tokens",
        " - Same alpaca prompt format as training",
        " - Automatic response extraction",
        "",
        "4. **Usage Examples**:",
        " - Single inference: --instruction 'style instruction' --input-text 'your text'",
        " - Streaming: add --stream flag",
        " - Batch: use batch subcommand with input/output files",
    )
    # An empty entry yields the same blank line that a bare print() would.
    for entry in overview:
        print(entry)
|
||||
|
||||
def create_inference_example():
    """Run one sample inference using the formal-style configuration.

    Looks for ``configs/styling/formal.yaml`` and, when present, sends a
    casual sample sentence through ``run_inference_with_config``.

    Returns:
        ``False`` when the configuration file is missing or the inference
        call returns a falsy value, ``True`` otherwise.
    """
    print("=== Inference Example: Formal Style Transfer ===")
    print()

    # The example depends on this configuration being present.
    config_path = "configs/styling/formal.yaml"
    config_present = Path(config_path).exists()
    if not config_present:
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    # Casual sentence used as the demonstration input.
    sample = "Hey, what's up? I'm gonna go grab some food later."
    print(f"Example text: {sample}")
    print()

    output = run_inference_with_config(
        config_path=config_path,
        input_text=sample,
        instruction="Rewrite the following text in a formal style",
    )

    if not output:
        print("❌ Example inference failed!")
        return False

    print("✅ Example inference completed successfully!")
    return True
|
||||
|
||||
def main(argv=None):
    """Parse the command line and dispatch to the selected sub-command.

    Sub-commands: ``infer`` (single inference), ``batch`` (file in, file
    out), ``features`` (print an overview) and ``example`` (sample run).
    With no sub-command the help text is printed.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line.

    Raises:
        SystemExit: With status 1 when ``infer``, ``batch`` or ``example``
            reports failure, so shells and CI can detect it; argparse itself
            exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    infer_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    infer_parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Batch inference command
    batch_parser = subparsers.add_parser("batch", help="Run batch inference")
    batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    batch_parser.add_argument("--input-file", type=str, required=True, help="Input file with texts to style")
    batch_parser.add_argument("--output-file", type=str, required=True, help="Output file for styled texts")
    batch_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
    batch_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")

    # Features command
    subparsers.add_parser("features", help="Show available features")

    # Example command
    subparsers.add_parser("example", help="Run example inference")

    args = parser.parse_args(argv)

    if args.command == "infer":
        output = run_inference_with_config(
            args.config,
            args.instruction,
            args.input_text,
            args.max_tokens,
            args.stream,
        )
        # None is the helper's failure signal; an empty string is still a
        # completed run.
        if output is None:
            sys.exit(1)
    elif args.command == "batch":
        completed = run_batch_inference_example(
            args.config, args.input_file, args.output_file, args.instruction, args.max_tokens
        )
        if not completed:
            sys.exit(1)
    elif args.command == "features":
        show_inference_features()
    elif args.command == "example":
        if not create_inference_example():
            sys.exit(1)
    else:
        parser.print_help()
|
||||
|
||||
# Script entry point: start the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Training Script
|
||||
Provides a command-line interface to run the styling training pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_training_with_config(config_path: str, dataset_path: str = None, **cli_overrides):
    """Run the styling training pipeline script in a child process.

    Builds a command line for ``pipelines/styling/train.py`` (a path relative
    to the current working directory), runs it to completion and prints its
    captured standard output.

    Args:
        config_path: Path to the YAML configuration file (``--config``).
        dataset_path: Optional dataset path (``--dataset``); omitted from the
            command when falsy.
        **cli_overrides: Optional overrides. Recognised keys are
            ``output_dir``, ``epochs``, ``batch_size``, ``learning_rate`` and
            ``max_steps``; entries whose value is ``None`` and any other keys
            are ignored.

    Returns:
        ``True`` when the child exits with status 0, ``False`` when it exits
        with a non-zero status.
    """
    import shlex  # local import: only needed to render the command for display

    print(f"Starting styling training with config: {config_path}")
    if dataset_path:
        print(f"Training dataset: {dataset_path}")
    else:
        print("Training dataset: Will use output_dir from YAML config")
    print()

    # Launch the pipeline with the interpreter that is running this script so
    # the child sees the same environment/virtualenv; a bare "python" may be
    # missing from PATH or resolve to a different installation.
    interpreter = sys.executable or "python"
    cmd = [interpreter, "pipelines/styling/train.py", "--config", config_path]

    # Add dataset path if provided
    if dataset_path:
        cmd.extend(["--dataset", dataset_path])

    # Keyword name -> command-line flag for the overrides that are forwarded.
    override_flags = {
        "output_dir": "--output-dir",
        "epochs": "--epochs",
        "batch_size": "--batch-size",
        "learning_rate": "--learning-rate",
        "max_steps": "--max-steps",
    }
    for key, value in cli_overrides.items():
        flag = override_flags.get(key)
        if flag is not None and value is not None:
            cmd.extend([flag, str(value)])

    # Quote every argument so the echoed line is unambiguous and can be
    # pasted into a shell.
    rendered = " ".join(shlex.quote(part) for part in cmd)
    print(f"Running: {rendered}")
    print()

    try:
        # NOTE(review): output is captured, so the child's log only appears
        # after the process has exited.
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Training failed: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("Training completed successfully!")
    print(result.stdout)
    return True
|
||||
|
||||
def show_training_features():
    """Print a fixed, human-readable overview of the training pipeline."""
    overview = (
        "=== Styling Training Pipeline Features ===",
        "",
        "1. **Model Support**:",
        " - Unsloth optimized models (4x faster)",
        " - LoRA fine-tuning for efficiency",
        " - Support for Llama-3.1, Mistral, Phi-3, Gemma",
        "",
        "2. **Training Features**:",
        " - SFTTrainer with instruction tuning",
        " - Automatic mixed precision (FP16/BF16)",
        " - Gradient checkpointing for memory efficiency",
        " - Configurable LoRA parameters",
        "",
        "3. **Configuration**:",
        " - YAML configuration files",
        " - CLI argument overrides",
        " - Automatic device detection",
        "",
        "4. **Output**:",
        " - Saved LoRA models",
        " - Training logs and checkpoints",
        " - Ready for inference",
    )
    # An empty entry yields the same blank line that a bare print() would.
    for entry in overview:
        print(entry)
|
||||
|
||||
def create_training_example():
    """Launch a short sample training run with the formal-style configuration.

    Looks for ``configs/styling/formal.yaml`` and, when present, calls
    ``run_training_with_config`` with no explicit dataset and small override
    values (1 epoch, batch size 2, learning rate 2e-4).

    Returns:
        ``False`` when the configuration file is missing, otherwise whatever
        ``run_training_with_config`` returns.
    """
    print("=== Training Example: Formal Style Transfer ===")
    print()

    # The example depends on this configuration being present.
    config_path = "configs/styling/formal.yaml"
    config_present = Path(config_path).exists()
    if not config_present:
        print(f"Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    # Summary shown to the user before training starts.
    briefing = (
        "Found required files!",
        f" Config: {config_path}",
        " Dataset: Will use output_dir from YAML config",
        " The training pipeline will automatically:",
        " - Load data from the output_dir specified in YAML",
        " - Convert JSONL files to HuggingFace dataset format",
        " - Apply formatting with EOS tokens",
        " - Train the model using SFTTrainer",
        "",
    )
    for entry in briefing:
        print(entry)

    # No dataset path is passed, so the pipeline falls back to the YAML config.
    outcome = run_training_with_config(
        config_path=config_path,
        dataset_path=None,
        epochs=1,
        batch_size=2,
        learning_rate=2e-4,
    )

    if outcome:
        for entry in (
            "Training example completed!",
            " Model saved to: ./models/styling",
            " Ready for inference!",
        ):
            print(entry)

    return outcome
|
||||
|
||||
def main(argv=None):
    """Parse the command line and run the selected training command.

    Commands: ``train`` (requires ``--config``), ``example`` (sample run) and
    ``features`` (print an overview).

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line.

    Raises:
        SystemExit: With status 1 when ``train`` is given without
            ``--config``, or when ``train`` or ``example`` reports failure;
            argparse itself exits with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Styling Training Script")

    # Command selector
    parser.add_argument("command", choices=["train", "example", "features"],
                        help="Command to run")

    # Training arguments
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--dataset", type=str, help="Path to training dataset")
    parser.add_argument("--output-dir", type=str, help="Output directory for model")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-steps", type=int, help="Maximum training steps")

    args = parser.parse_args(argv)

    if args.command == "features":
        show_training_features()

    elif args.command == "example":
        # Propagate failure through the exit status, as "train" does.
        if not create_training_example():
            sys.exit(1)

    elif args.command == "train":
        if not args.config:
            print("❌ --config is required for training")
            print("Usage: python scripts/styling/train.py train --config config.yaml")
            sys.exit(1)

        success = run_training_with_config(
            config_path=args.config,
            # An absent or empty --dataset makes the pipeline fall back to
            # the YAML configuration.
            dataset_path=args.dataset or None,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_steps=args.max_steps,
        )

        if not success:
            sys.exit(1)
|
||||
|
||||
# Script entry point: start the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
|
||||
+115
-144
@@ -10,71 +10,102 @@ import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_inference_with_config(config_path: str, **cli_overrides):
|
||||
"""Run the styling inference pipeline with YAML configuration"""
|
||||
print(f"🚀 Starting styling inference with config: {config_path}")
|
||||
print()
|
||||
def run_inference_with_config(config_path: str, instruction: str, input_text: str = "", max_tokens: int = 128, stream: bool = False):
|
||||
"""Run inference using a YAML configuration file"""
|
||||
print(f"Running styling inference with config: {config_path}")
|
||||
print(f"Instruction: {instruction}")
|
||||
print(f"Input text: {input_text}")
|
||||
print(f"Max tokens: {max_tokens}")
|
||||
print(f"Streaming: {stream}")
|
||||
|
||||
# Build command
|
||||
cmd = ["python", "pipelines/styling/inference.py", "--config", config_path]
|
||||
cmd = [
|
||||
"python", "pipelines/styling/inference.py",
|
||||
"--config", config_path,
|
||||
"--instruction", instruction,
|
||||
"--input-text", input_text,
|
||||
"--max-tokens", str(max_tokens)
|
||||
]
|
||||
|
||||
# Add CLI overrides
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
if key == "model_path":
|
||||
cmd.extend(["--model-path", str(value)])
|
||||
elif key == "text":
|
||||
cmd.extend(["--text", str(value)])
|
||||
elif key == "input_file":
|
||||
cmd.extend(["--input-file", str(value)])
|
||||
elif key == "max_tokens":
|
||||
cmd.extend(["--max-tokens", str(value)])
|
||||
elif key == "temperature":
|
||||
cmd.extend(["--temperature", str(value)])
|
||||
elif key == "instruction":
|
||||
cmd.extend(["--instruction", str(value)])
|
||||
elif key == "output_file":
|
||||
cmd.extend(["--output-file", str(value)])
|
||||
elif key == "streaming":
|
||||
cmd.append("--streaming")
|
||||
if stream:
|
||||
cmd.append("--stream")
|
||||
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
print()
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
print("✅ Inference completed successfully!")
|
||||
print("Output:")
|
||||
print(result.stdout)
|
||||
return True
|
||||
return result.stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ Inference failed: {e}")
|
||||
print(f"Error output: {e.stderr}")
|
||||
print("Error output:")
|
||||
print(e.stderr)
|
||||
return None
|
||||
|
||||
def run_batch_inference_example(config_path: str, input_file: str, output_file: str, instruction: str, max_tokens: int = 128):
|
||||
"""Run batch inference example"""
|
||||
print(f"=== Batch Inference Example ===")
|
||||
print(f"Config: {config_path}")
|
||||
print(f"Input file: {input_file}")
|
||||
print(f"Output file: {output_file}")
|
||||
print(f"Instruction: {instruction}")
|
||||
print(f"Max tokens: {max_tokens}")
|
||||
|
||||
# Check if input file exists
|
||||
if not Path(input_file).exists():
|
||||
print(f"❌ Input file not found: {input_file}")
|
||||
return False
|
||||
|
||||
# Read input texts
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
input_texts = [line.strip() for line in f if line.strip()]
|
||||
|
||||
print(f"Found {len(input_texts)} texts to process")
|
||||
|
||||
# Process each text
|
||||
results = []
|
||||
for i, text in enumerate(input_texts):
|
||||
print(f"\nProcessing text {i+1}/{len(input_texts)}: {text[:50]}...")
|
||||
result = run_inference_with_config(config_path, instruction, text, max_tokens)
|
||||
if result:
|
||||
results.append(result)
|
||||
else:
|
||||
print(f"❌ Failed to process text {i+1}")
|
||||
|
||||
# Save results
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for i, (input_text, result) in enumerate(zip(input_texts, results)):
|
||||
f.write(f"Input {i+1}: {input_text}\n")
|
||||
f.write(f"Output {i+1}: {result}\n")
|
||||
f.write("-" * 50 + "\n")
|
||||
|
||||
print(f"✅ Batch inference completed! Results saved to: {output_file}")
|
||||
return True
|
||||
|
||||
def show_inference_features():
|
||||
"""Show the features of the styling inference pipeline"""
|
||||
print("=== Styling Inference Pipeline Features ===")
|
||||
print()
|
||||
print("1. **Model Support**:")
|
||||
print(" - Trained LoRA models")
|
||||
print(" - Base models from HuggingFace Hub")
|
||||
print(" - Automatic model loading and preparation")
|
||||
print(" - Trained LoRA models from training pipeline")
|
||||
print(" - Automatic model loading from config")
|
||||
print(" - Native Unsloth inference optimization")
|
||||
print()
|
||||
print("2. **Inference Modes**:")
|
||||
print(" - Single text inference")
|
||||
print(" - Single text inference with instruction")
|
||||
print(" - Batch file processing")
|
||||
print(" - Interactive mode")
|
||||
print(" - Streaming generation")
|
||||
print()
|
||||
print("3. **Generation Control**:")
|
||||
print(" - Configurable temperature and top-p")
|
||||
print(" - Adjustable max tokens")
|
||||
print(" - Custom style instructions")
|
||||
print(" - Configurable max tokens")
|
||||
print(" - Same alpaca prompt format as training")
|
||||
print(" - Automatic response extraction")
|
||||
print()
|
||||
print("4. **Output Options**:")
|
||||
print(" - Console output")
|
||||
print(" - File output")
|
||||
print(" - Streaming real-time generation")
|
||||
print("4. **Usage Examples**:")
|
||||
print(" - Single inference: --instruction 'style instruction' --input-text 'your text'")
|
||||
print(" - Streaming: add --stream flag")
|
||||
print(" - Batch: use batch subcommand with input/output files")
|
||||
|
||||
def create_inference_example():
|
||||
"""Create an inference example using the formal style configuration"""
|
||||
@@ -96,128 +127,68 @@ def create_inference_example():
|
||||
# Example text
|
||||
example_text = "Hey, what's up? I'm gonna go grab some food later."
|
||||
|
||||
print(f"📝 Example text: {example_text}")
|
||||
print(f"Example text: {example_text}")
|
||||
print()
|
||||
|
||||
# Run inference
|
||||
success = run_inference_with_config(
|
||||
config_path=config_path,
|
||||
text=example_text,
|
||||
instruction="Rewrite the following text in a formal style"
|
||||
instruction="Rewrite the following text in a formal style",
|
||||
input_text=example_text
|
||||
)
|
||||
|
||||
if success:
|
||||
print("🎉 Inference example completed!")
|
||||
|
||||
return success
|
||||
|
||||
def create_test_file():
|
||||
"""Create a test file with sample texts for batch inference"""
|
||||
test_file = "test_texts.txt"
|
||||
|
||||
test_texts = [
|
||||
"Hey, what's up? How are you doing today?",
|
||||
"I'm gonna go to the store later to get some stuff.",
|
||||
"This is pretty cool, right?",
|
||||
"Can you help me out with this?",
|
||||
"Thanks a lot for your help!"
|
||||
]
|
||||
|
||||
with open(test_file, 'w', encoding='utf-8') as f:
|
||||
for text in test_texts:
|
||||
f.write(text + '\n')
|
||||
|
||||
print(f"✅ Created test file: {test_file}")
|
||||
print(f" Contains {len(test_texts)} sample texts")
|
||||
return test_file
|
||||
|
||||
def run_batch_inference_example():
|
||||
"""Run a batch inference example"""
|
||||
print("=== Batch Inference Example ===")
|
||||
print()
|
||||
|
||||
# Create test file
|
||||
test_file = create_test_file()
|
||||
|
||||
# Check configuration
|
||||
config_path = "configs/styling/formal.yaml"
|
||||
if not Path(config_path).exists():
|
||||
print(f"❌ Configuration file not found: {config_path}")
|
||||
print("✅ Example inference completed successfully!")
|
||||
return True
|
||||
else:
|
||||
print("❌ Example inference failed!")
|
||||
return False
|
||||
|
||||
print("✅ Running batch inference...")
|
||||
print()
|
||||
|
||||
# Run batch inference
|
||||
success = run_inference_with_config(
|
||||
config_path=config_path,
|
||||
input_file=test_file,
|
||||
output_file="styled_results.txt",
|
||||
instruction="Rewrite the following text in a formal style"
|
||||
)
|
||||
|
||||
if success:
|
||||
print("🎉 Batch inference completed!")
|
||||
print(" Results saved to: styled_results.txt")
|
||||
|
||||
return success
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
parser = argparse.ArgumentParser(description="Styling Inference Script")
|
||||
"""Main inference function"""
|
||||
parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Subcommands
|
||||
parser.add_argument("command", choices=["infer", "example", "batch", "features"],
|
||||
help="Command to run")
|
||||
# Inference command
|
||||
infer_parser = subparsers.add_parser("infer", help="Run single inference")
|
||||
infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
|
||||
infer_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
|
||||
infer_parser.add_argument("--input-text", type=str, default="", help="Input text to style")
|
||||
infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
|
||||
infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
|
||||
|
||||
# Inference arguments
|
||||
parser.add_argument("--config", type=str, help="Path to YAML configuration file")
|
||||
parser.add_argument("--model-path", type=str, help="Path to trained model")
|
||||
parser.add_argument("--text", type=str, help="Single text to style transfer")
|
||||
parser.add_argument("--input-file", type=str, help="File containing texts to process")
|
||||
parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
|
||||
parser.add_argument("--temperature", type=float, help="Sampling temperature")
|
||||
parser.add_argument("--instruction", type=str, help="Custom style instruction")
|
||||
parser.add_argument("--output-file", type=str, help="Output file for results")
|
||||
parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
|
||||
# Batch inference command
|
||||
batch_parser = subparsers.add_parser("batch", help="Run batch inference")
|
||||
batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
|
||||
batch_parser.add_argument("--input-file", type=str, required=True, help="Input file with texts to style")
|
||||
batch_parser.add_argument("--output-file", type=str, required=True, help="Output file for styled texts")
|
||||
batch_parser.add_argument("--instruction", type=str, required=True, help="Style instruction")
|
||||
batch_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
|
||||
|
||||
# Features command
|
||||
subparsers.add_parser("features", help="Show available features")
|
||||
|
||||
# Example command
|
||||
subparsers.add_parser("example", help="Run example inference")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "features":
|
||||
if args.command == "infer":
|
||||
run_inference_with_config(
|
||||
args.config,
|
||||
args.instruction,
|
||||
args.input_text,
|
||||
args.max_tokens,
|
||||
args.stream
|
||||
)
|
||||
elif args.command == "batch":
|
||||
run_batch_inference_example(args.config, args.input_file, args.output_file, args.instruction, args.max_tokens)
|
||||
elif args.command == "features":
|
||||
show_inference_features()
|
||||
|
||||
elif args.command == "example":
|
||||
create_inference_example()
|
||||
|
||||
elif args.command == "batch":
|
||||
run_batch_inference_example()
|
||||
|
||||
elif args.command == "infer":
|
||||
if not args.config:
|
||||
print("❌ --config is required for inference")
|
||||
print("Usage: python scripts/styling/inference.py infer --config config.yaml [options]")
|
||||
sys.exit(1)
|
||||
|
||||
# Check if we have input
|
||||
if not args.text and not args.input_file:
|
||||
print("❌ Either --text or --input-file is required")
|
||||
print("Usage: python scripts/styling/inference.py infer --config config.yaml --text 'your text'")
|
||||
sys.exit(1)
|
||||
|
||||
success = run_inference_with_config(
|
||||
config_path=args.config,
|
||||
model_path=args.model_path,
|
||||
text=args.text,
|
||||
input_file=args.input_file,
|
||||
max_tokens=args.max_tokens,
|
||||
temperature=args.temperature,
|
||||
instruction=args.instruction,
|
||||
output_file=args.output_file,
|
||||
streaming=args.streaming
|
||||
)
|
||||
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user