320 lines
12 KiB
Python
320 lines
12 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Styling Inference Pipeline using Trained Models
|
||
|
|
Supports style transfer inference with streaming and batch processing
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import json
|
||
|
|
import argparse
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Dict, Any, Optional, List, Union
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
# Add the project root to the path
|
||
|
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||
|
|
|
||
|
|
# Inference imports
|
||
|
|
import torch
|
||
|
|
from datasets import load_from_disk, Dataset
|
||
|
|
from unsloth import FastLanguageModel
|
||
|
|
from transformers import TextStreamer
|
||
|
|
|
||
|
|
class StylingInference:
    """Styling (style-transfer) inference using a fine-tuned Unsloth model.

    The model is loaded from ``model_output_dir`` and prompted with an
    Alpaca-style template (instruction / input / response). Text can either be
    returned as a string or streamed to stdout while it is generated.
    """

    # Template used when the config has no ``alpaca_prompt`` key. It must match
    # the template used at training time; the three ``{}`` slots receive the
    # instruction, the input text and the (empty, at inference) response.
    DEFAULT_ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    def __init__(self, config: Dict[str, Any]):
        """Read settings from *config*; the model is loaded later.

        Args:
            config: Flat settings dict (as produced by ``load_inference_config``).
                Any missing key falls back to the default shown below.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/styling')
        self.max_seq_length = config.get('max_seq_length', 2048)
        # Passed straight through to FastLanguageModel.from_pretrained.
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)

        # Alpaca prompt template (config override or the training default)
        self.alpaca_prompt = config.get('alpaca_prompt', self.DEFAULT_ALPACA_PROMPT)

        # Style instruction used when callers do not pass one
        self.style_instruction = config.get('style_instruction', 'Rewrite the following text in a formal style')

    def load_model_and_tokenizer(self):
        """Load the fine-tuned model/tokenizer and switch the model to inference mode.

        Raises:
            Exception: any loading error is printed and re-raised unchanged.
        """
        print("Loading trained model and tokenizer...")

        try:
            # Load the saved LoRA model
            model_path = self.model_output_dir
            print(f"Loading model from: {model_path}")

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                # hf_token comes from the config; it is needed for gated/private repos.
                token=self.hf_token,
            )

            # Enable native 2x faster inference
            FastLanguageModel.for_inference(self.model)

            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Fill the Alpaca template with *instruction* and *input_text*.

        The response slot is left empty so that the model generates it.
        """
        return self.alpaca_prompt.format(instruction, input_text, "")

    def _generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Build ``model.generate`` keyword arguments from the configured settings.

        A falsy *max_new_tokens* falls back to ``self.max_new_tokens``.
        """
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens or self.max_new_tokens,
            "do_sample": self.do_sample,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # temperature/top_p only matter when sampling; with greedy decoding
        # transformers warns that they are ignored, so leave them out.
        if self.do_sample:
            kwargs["temperature"] = self.temperature
            kwargs["top_p"] = self.top_p
        return kwargs

    def generate_text(self, instruction: str, input_text: str = "", max_new_tokens: int = 128, stream: bool = False):
        """Generate a response for *instruction* applied to *input_text*.

        Args:
            instruction: Text for the ``### Instruction:`` slot.
            input_text: Text for the ``### Input:`` slot.
            max_new_tokens: Maximum number of tokens to generate.
            stream: When True, print tokens to stdout as they are produced.

        Returns:
            The generated response (prompt excluded, whitespace-stripped), or
            None when ``stream`` is True, because the output went to the streamer.

        Raises:
            Exception: generation errors are printed and re-raised.
        """
        try:
            prompt = self.format_prompt(instruction, input_text)
            print(f"Formatted prompt: {prompt}")

            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            gen_kwargs = self._generation_kwargs(max_new_tokens)

            if stream:
                print("Generating with streaming...")
                # skip_prompt: the prompt has already been printed above.
                text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
                self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
                return None  # Streaming output is handled by streamer

            print("Generating...")
            outputs = self.model.generate(**inputs, **gen_kwargs)

            # Decode only the tokens produced after the prompt. Searching the
            # decoded text for "### Response:" fails when the input contains it.
            prompt_length = inputs["input_ids"].shape[-1]
            generated = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            return generated.strip()

        except Exception as e:
            print(f"❌ Error generating text: {e}")
            raise

    def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
        """Rewrite *input_text* according to *instruction*.

        Args:
            input_text: Text to restyle.
            instruction: Style instruction; defaults to ``self.style_instruction``.
            streaming: When True, stream tokens to stdout and return ``""``.

        Returns:
            The generated text, or ``""`` in streaming mode.

        Raises:
            Exception: errors are printed and re-raised.
        """
        try:
            if instruction is None:
                instruction = self.style_instruction

            print(f"Style instruction: {instruction}")
            print(f"Input text: {input_text}")

            if streaming:
                prompt = self.format_prompt(instruction, input_text)
                print(f"Style transfer prompt: {prompt}")
                print("Generating with streaming...")
                self.generate_text_streaming(prompt)
                return ""

            print("Generating text...")
            # generate_text formats (and prints) the prompt itself.
            result = self.generate_text(instruction, input_text, self.max_new_tokens)
            print(f"Generated result: {result}")
            return result

        except Exception as e:
            print(f"❌ Error in style transfer: {e}")
            raise

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
        """Generate from an already formatted *prompt*, streaming tokens to stdout.

        Args:
            prompt: Fully formatted prompt text.
            max_new_tokens: Token limit; falls back to ``self.max_new_tokens``.

        Raises:
            Exception: errors are printed and re-raised so callers see failures.
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)

            with torch.no_grad():
                self.model.generate(**inputs, streamer=text_streamer, **self._generation_kwargs(max_new_tokens))

        except Exception as e:
            print(f"❌ Error in streaming generation: {e}")
            raise

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Apply ``style_transfer`` to each text, one at a time.

        Note: ``self.batch_size`` is not used here; texts are generated sequentially.
        """
        results = []
        total = len(input_texts)

        for index, input_text in enumerate(input_texts, start=1):
            print(f"Processing text {index}/{total}")
            results.append(self.style_transfer(input_text, instruction))

        return results
|
||
|
|
|
||
|
|
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML config file and flatten it into StylingInference settings.

    The optional top-level sections ``model``, ``training``, ``inference`` and
    ``data`` are mapped onto the flat keys read by ``StylingInference``. If a
    section is absent, its keys are omitted, so the class-side defaults apply.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Flat inference configuration dict.

    Raises:
        Exception: read/parse errors are printed and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat it as "no sections".
            config = yaml.safe_load(f) or {}

        inference_config: Dict[str, Any] = {}

        # Model configuration (an empty section parses as None, hence `or {}`)
        if 'model' in config:
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token'),
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config.update({
                'model_output_dir': training_data.get('model_output_dir', './models/styling'),
            })

        # Inference configuration; top_p/do_sample defaults match StylingInference.
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8),
                'top_p': inference_data.get('top_p', 0.9),
                'do_sample': inference_data.get('do_sample', True),
            })

        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config.update({
                'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style'),
            })

        return inference_config

    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise
|
||
|
|
|
||
|
|
def main():
    """CLI entry point: load the config and model, then run one generation.

    Prints a traceback and exits with status 1 on any failure.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--instruction", type=str, default=None,
                        help="Style instruction (defaults to data.instruction from the config)")
    parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    # default=None so the config's inference.max_new_tokens is only overridden
    # when the flag is actually given (a non-None default always overrode it).
    parser.add_argument("--max-tokens", type=int, default=None, help="Maximum new tokens to generate")
    parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens

        print("Inference configuration:")
        for key, value in inference_config.items():
            print(f" {key}: {value}")

        # Initialize inference
        inference = StylingInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        # Resolve effective settings: CLI flag first, then config/class defaults.
        instruction = args.instruction if args.instruction is not None else inference.style_instruction
        max_new_tokens = inference.max_new_tokens

        # Run inference
        if args.stream:
            print("Running streaming inference...")
            inference.generate_text(instruction, args.input_text, max_new_tokens, stream=True)
        else:
            print("Running inference...")
            result = inference.generate_text(instruction, args.input_text, max_new_tokens)
            print(f"✅ Generated text: {result}")

    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
||
|
|
|
||
|
|
if "__main__" == __name__:
    # Only run the CLI when executed as a script, not when imported.
    main()
|