Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/styling/inference.py
T
2025-08-13 23:50:20 +00:00

320 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Styling Inference Pipeline using Trained Models
Supports style transfer inference with streaming and batch processing
"""
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer
class StylingInference:
    """Style-transfer inference on top of a fine-tuned model.

    The instance is configured from a plain dictionary. Recognised keys
    (all optional, defaults in parentheses):

    * ``model_output_dir`` (``'./models/styling'``) - where the trained model lives.
    * ``max_seq_length`` (2048), ``dtype`` (None), ``load_in_4bit`` (True),
      ``hf_token`` (None) - forwarded to the model loader.
    * ``batch_size`` (1), ``max_new_tokens`` (128), ``temperature`` (0.8),
      ``top_p`` (0.9), ``do_sample`` (True) - generation settings.
    * ``alpaca_prompt`` - prompt template with three ``{}`` placeholders
      (instruction, input, response).
    * ``style_instruction`` - instruction used when the caller gives none.

    ``load_model_and_tokenizer`` must be called before any generation method.
    """

    # Single source of truth for the default prompt template. It has three
    # positional placeholders: instruction, input and response.
    DEFAULT_ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}"""

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and resolve every setting to an attribute.

        Args:
            config: Dictionary of settings; see the class docstring for keys.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/styling')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)

        # Alpaca prompt template
        self.alpaca_prompt = config.get('alpaca_prompt', self.DEFAULT_ALPACA_PROMPT)

        # Style instruction
        self.style_instruction = config.get('style_instruction', 'Rewrite the following text in a formal style')

    def load_model_and_tokenizer(self):
        """Load the trained model and tokenizer from ``self.model_output_dir``.

        Populates ``self.model`` and ``self.tokenizer`` and switches the model
        to inference mode.

        Raises:
            Exception: Whatever the loader raises; it is logged and re-raised.
        """
        print("Loading trained model and tokenizer...")
        model_path = self.model_output_dir
        try:
            print(f"Loading model from: {model_path}")
            load_kwargs = {
                "model_name": model_path,
                "max_seq_length": self.max_seq_length,
                "dtype": self.dtype,
                "load_in_4bit": self.load_in_4bit,
            }
            # Only forward the token when one was configured, so the loader
            # call is unchanged for the common token-less case.
            if self.hf_token:
                load_kwargs["token"] = self.hf_token
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)

            # Enable native 2x faster inference
            FastLanguageModel.for_inference(self.model)
            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Fill the configured prompt template for inference.

        Args:
            instruction: Text placed in the instruction slot.
            input_text: Text placed in the input slot.

        Returns:
            The template with the response slot left empty so that the model
            generates it.
        """
        # Use the configured template (self.alpaca_prompt) rather than a
        # second hard-coded copy, so a template override actually takes effect.
        return self.alpaca_prompt.format(instruction, input_text, "")

    def _generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Build the keyword arguments shared by every ``model.generate`` call."""
        gen_kwargs = {
            "max_new_tokens": max_new_tokens or self.max_new_tokens,
            "do_sample": self.do_sample,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling knobs are only meaningful when sampling is enabled.
        if self.do_sample:
            gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = self.top_p
        return gen_kwargs

    def generate_text(self, instruction: str, input_text: str = "", max_new_tokens: int = 128, stream: bool = False):
        """Generate a response for an instruction/input pair.

        Args:
            instruction: Instruction placed in the prompt.
            input_text: Input placed in the prompt.
            max_new_tokens: Upper bound on the number of generated tokens.
            stream: When true, tokens are printed as they are produced.

        Returns:
            The generated response text, or ``None`` when ``stream`` is true
            (the streamer prints the output instead).

        Raises:
            Exception: Any tokenization/generation error, logged and re-raised.
        """
        try:
            prompt = self.format_prompt(instruction, input_text)
            print(f"Formatted prompt: {prompt}")

            if stream:
                # Delegate so both streaming entry points share one code path
                # and the same generation settings.
                print("Generating with streaming...")
                self.generate_text_streaming(prompt, max_new_tokens)
                return None  # Streaming output is handled by streamer

            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            print("Generating...")
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **self._generation_kwargs(max_new_tokens))

            # Decode only the tokens produced after the prompt. Searching the
            # decoded text for "### Response:" is fragile: the marker may also
            # occur inside the user's input or be absent from a custom template.
            prompt_length = inputs["input_ids"].shape[1]
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"❌ Error generating text: {e}")
            raise

    def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
        """Perform style transfer on a single input text.

        Args:
            input_text: Text to restyle.
            instruction: Style instruction; ``self.style_instruction`` is used
                when omitted.
            streaming: When true, output is streamed to stdout.

        Returns:
            The restyled text, or ``""`` when ``streaming`` is true.

        Raises:
            Exception: Any generation error, logged and re-raised.
        """
        try:
            # Use default instruction if none provided
            if instruction is None:
                instruction = self.style_instruction
            print(f"Style instruction: {instruction}")
            print(f"Input text: {input_text}")

            if streaming:
                prompt = self.format_prompt(instruction, input_text)
                print(f"Style transfer prompt: {prompt}")
                print("Generating with streaming...")
                self.generate_text_streaming(prompt)
                return ""

            print("Generating text...")
            result = self.generate_text(instruction, input_text, self.max_new_tokens)
            print(f"Generated result: {result}")
            return result
        except Exception as e:
            print(f"❌ Error in style transfer: {e}")
            raise

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
        """Generate from an already formatted prompt, streaming to stdout.

        Args:
            prompt: Fully formatted prompt text.
            max_new_tokens: Token budget; falls back to ``self.max_new_tokens``.

        Raises:
            Exception: Any tokenization/generation error, logged and re-raised
                so that callers do not mistake a failure for an empty result.
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            text_streamer = TextStreamer(self.tokenizer)
            gen_kwargs = self._generation_kwargs(max_new_tokens)
            with torch.no_grad():
                _ = self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
        except Exception as e:
            print(f"❌ Error in streaming generation: {e}")
            raise

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Perform style transfer on several texts, one after another.

        Args:
            input_texts: Texts to restyle.
            instruction: Style instruction applied to every text.

        Returns:
            The restyled texts, in the same order as ``input_texts``.
        """
        results = []
        total = len(input_texts)
        for position, input_text in enumerate(input_texts, start=1):
            print(f"Processing text {position}/{total}")
            results.append(self.style_transfer(input_text, instruction))
        return results
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load the inference settings from a YAML configuration file.

    Only the sections that are present in the file contribute keys:

    * ``model`` -> ``base_model_name``, ``max_seq_length``, ``dtype``,
      ``load_in_4bit``, ``hf_token``
    * ``training`` -> ``model_output_dir``
    * ``inference`` -> ``batch_size``, ``max_new_tokens``, ``temperature``,
      ``top_p``, ``do_sample``
    * ``data`` -> ``style_instruction``

    Args:
        config_path: Path to the YAML file.

    Returns:
        A flat dictionary of inference settings.

    Raises:
        ValueError: If the top level of the YAML document is not a mapping.
        Exception: Any I/O or YAML parsing error, logged and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Configuration root must be a mapping, got {type(config).__name__}"
            )

        inference_config = {}

        # Model configuration
        if 'model' in config:
            # A section written as a bare key (e.g. "model:") parses to None.
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token')
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config.update({
                'model_output_dir': training_data.get('model_output_dir', './models/styling')
            })

        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8),
                'top_p': inference_data.get('top_p', 0.9),
                'do_sample': inference_data.get('do_sample', True)
            })

        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config.update({
                'style_instruction': data_config.get('instruction', 'Rewrite the following text in a formal style')
            })

        return inference_config
    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise
def main():
    """Command-line entry point: load the config and model, then run one inference.

    Command-line values take precedence over the YAML configuration.
    ``--instruction`` and ``--max-tokens`` fall back to the configured
    ``style_instruction`` and ``max_new_tokens`` when omitted. On any failure
    the traceback is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--instruction", type=str, default=None,
                        help="Style instruction (defaults to the instruction from the configuration)")
    parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    # No numeric default here: a default of 128 would always be truthy and
    # silently overwrite the value from the configuration file.
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="Maximum new tokens to generate (overrides the configuration)")
    parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens

        print("Inference configuration:")
        for key, value in inference_config.items():
            # Never echo the access token to the console or logs.
            shown = "***" if key == 'hf_token' and value else value
            print(f" {key}: {shown}")

        # Initialize inference
        inference = StylingInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        instruction = args.instruction if args.instruction is not None else inference.style_instruction
        max_new_tokens = inference.max_new_tokens

        # Run inference
        if args.stream:
            print("Running streaming inference...")
            inference.generate_text(instruction, args.input_text, max_new_tokens, stream=True)
        else:
            print("Running inference...")
            result = inference.generate_text(instruction, args.input_text, max_new_tokens)
            print(f"✅ Generated text: {result}")
    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()