Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/styling/inference.py
T

347 lines
13 KiB
Python
Raw Normal View History

2025-08-13 21:17:01 +01:00
#!/usr/bin/env python3
"""
Styling Inference Pipeline using Trained Models
Supports style transfer inference with streaming and batch processing
"""
import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
from utils.config.config_manager import ConfigManager
from utils.logging.logging import setup_logging
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer
logger = logging.getLogger(__name__)
class StylingInference:
    """Style-transfer generation on top of an Unsloth ``FastLanguageModel``.

    The instance is configured from a plain dict and does not touch the model
    until :meth:`load_model_and_tokenizer` is called; every generation method
    requires that call to have happened first.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store model and generation settings from *config*.

        Args:
            config: Settings mapping. Keys read here: ``model_path``,
                ``max_seq_length``, ``dtype``, ``load_in_4bit``, ``hf_token``,
                ``batch_size``, ``max_new_tokens``, ``temperature``, ``top_p``,
                ``do_sample``, ``alpaca_prompt`` and ``style_instruction``
                (``base_model_name`` is read later, when loading).
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Model parameters
        self.model_path = config.get('model_path')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)
        # Inference parameters
        # NOTE: batch_size is stored for config compatibility only;
        # batch_style_transfer processes texts one at a time.
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)
        # Alpaca prompt template (three positional slots: instruction, input,
        # response). Presumably must match the template used in training —
        # confirm against the training pipeline before editing.
        self.alpaca_prompt = config.get('alpaca_prompt', """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}""")
        # Style instruction
        self.style_instruction = config.get('style_instruction', 'Rewrite the following text in a formal style')

    def load_model_and_tokenizer(self):
        """Load the model/tokenizer and switch the model to inference mode.

        If ``model_path`` is set, it is handed directly to
        ``FastLanguageModel.from_pretrained``, so it may be a local directory
        or a Hugging Face Hub repo id. Otherwise ``base_model_name`` from the
        config (default ``unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit``) is
        loaded.

        Raises:
            Exception: Anything raised by ``from_pretrained`` or
                ``for_inference``; it is logged and re-raised.
        """
        logger.info("Loading model and tokenizer...")
        if self.model_path:
            model_name = str(self.model_path)
            if Path(model_name).exists():
                logger.info(f"Loading local model from: {model_name}")
            else:
                # Previously a non-local model_path silently fell back to the
                # base model, ignoring the requested fine-tune. Try it as a Hub
                # id instead; a typo now fails loudly rather than quietly.
                logger.warning(f"'{model_name}' is not a local path; trying it as a HuggingFace Hub model id")
        else:
            model_name = self.config.get('base_model_name', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit')
            logger.info(f"Loading base model: {model_name}")
        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )
            # Prepare for inference
            FastLanguageModel.for_inference(self.model)
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise
        logger.info("✅ Model loaded successfully")
        logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

    def format_prompt(self, instruction: str, input_text: str, output: str = "") -> str:
        """Fill the Alpaca template with *instruction*, *input_text* and *output*."""
        return self.alpaca_prompt.format(instruction, input_text, output)

    def _generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Build the keyword arguments shared by every ``model.generate`` call.

        A falsy *max_new_tokens* falls back to ``self.max_new_tokens``.
        """
        kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens or self.max_new_tokens,
            "do_sample": self.do_sample,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling knobs only matter when sampling; transformers warns when
        # they are passed alongside do_sample=False.
        if self.do_sample:
            kwargs["temperature"] = self.temperature
            kwargs["top_p"] = self.top_p
        return kwargs

    def _decode_new_tokens(self, outputs: torch.Tensor, prompt_length: int) -> str:
        """Decode the first sequence in *outputs*, skipping its first *prompt_length* tokens."""
        new_tokens = outputs[:, prompt_length:]
        return self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

    def generate_text(self, prompt: str, max_new_tokens: Optional[int] = None) -> str:
        """Generate a continuation of *prompt* and return only the new text.

        Args:
            prompt: Fully formatted prompt (see :meth:`format_prompt`).
            max_new_tokens: Per-call override of ``self.max_new_tokens``.

        Returns:
            The generated text with the prompt removed and surrounding
            whitespace stripped, or ``""`` if generation failed (the error is
            logged, so batch callers can continue).
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **self._generation_kwargs(max_new_tokens))
            # Strip the prompt by token count, not by string matching: the
            # decoded prompt need not equal `prompt` byte-for-byte (e.g. after
            # tokenization-space cleanup), which used to leak the prompt into
            # the result.
            return self._decode_new_tokens(outputs, inputs["input_ids"].shape[1])
        except Exception as e:
            logger.error(f"❌ Error generating text: {e}")
            return ""

    def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
        """Rewrite *input_text* according to *instruction*.

        Args:
            input_text: Text to restyle.
            instruction: Style instruction; defaults to ``self.style_instruction``.
            streaming: If true, tokens are printed to stdout while generating.

        Returns:
            The generated text (also when streaming), or ``""`` on failure.
        """
        if instruction is None:
            instruction = self.style_instruction
        prompt = self.format_prompt(instruction, input_text, "")
        logger.info(f"Style transfer prompt: {prompt}")
        if streaming:
            logger.info("Generating with streaming...")
            return self.generate_text_streaming(prompt)
        logger.info("Generating text...")
        result = self.generate_text(prompt)
        logger.info(f"Generated result: {result}")
        return result

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None) -> str:
        """Generate for *prompt* while printing new tokens to stdout.

        Returns:
            The generated text with the prompt removed, or ``""`` on failure
            (the error is logged).
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            # skip_prompt: stream only the model's answer, not the echoed template.
            text_streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs, streamer=text_streamer, **self._generation_kwargs(max_new_tokens)
                )
            return self._decode_new_tokens(outputs, inputs["input_ids"].shape[1])
        except Exception as e:
            logger.error(f"❌ Error in streaming generation: {e}")
            return ""

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Restyle each text in *input_texts* sequentially (non-streaming).

        Returns:
            One result per input, in order; a failed item yields ``""``.
        """
        results = []
        total = len(input_texts)
        for i, input_text in enumerate(input_texts, start=1):
            logger.info(f"Processing text {i}/{total}")
            results.append(self.style_transfer(input_text, instruction))
        return results
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Read a YAML config file and map it onto ``StylingInference`` settings.

    Sections read (each optional):
      * ``model``: ``training_model`` -> ``base_model_name``,
        ``training_max_seq_length`` -> ``max_seq_length``,
        ``training_dtype`` -> ``dtype``, ``training_load_in_4bit`` ->
        ``load_in_4bit``, ``training_token`` -> ``hf_token``.
      * ``inference``: ``batch_size``, ``max_new_tokens``, ``temperature``;
        ``top_p`` and ``do_sample`` are forwarded only when present.
      * ``data``: ``instruction`` -> ``style_instruction``.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Flat settings dict. Keys belonging to absent sections are omitted, so
        ``StylingInference`` falls back to its own defaults.

    Raises:
        Exception: File, YAML-parse or structure errors (a non-mapping top
            level raises ValueError); they are logged and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document parses to None; treat it as "no settings".
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {config_path}, got {type(config).__name__}"
            )
        inference_config: Dict[str, Any] = {}
        # Model configuration
        if 'model' in config:
            # `model:` with no body parses to None; treat it as empty.
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token'),
            })
        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8),
            })
            # StylingInference also reads these; previously they were dropped.
            for key in ('top_p', 'do_sample'):
                if key in inference_data:
                    inference_config[key] = inference_data[key]
        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config['style_instruction'] = data_config.get(
                'instruction', 'Rewrite the following text in a formal style'
            )
        return inference_config
    except Exception as e:
        logger.error(f"Error loading inference config: {e}")
        raise
def _log_inference_config(config: Dict[str, Any]) -> None:
    """Log every setting, masking the Hugging Face token so it never reaches logs."""
    logger.info("Inference configuration:")
    for key, value in config.items():
        if key == 'hf_token' and value:
            value = '***'
        logger.info(f" {key}: {value}")


def _run_batch_file(inferencer: "StylingInference", input_file: str,
                    output_file: Optional[str], instruction: Optional[str]) -> None:
    """Restyle each non-blank line of *input_file* and write Input/Output pairs to a file."""
    logger.info("Running batch file inference...")
    with open(input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]
    results = inferencer.batch_style_transfer(input_texts, instruction)
    # Default output goes to the current working directory, named after the input file.
    output_file = output_file or f"{Path(input_file).stem}_styled.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        for input_text, result in zip(input_texts, results):
            f.write(f"Input: {input_text}\n")
            f.write(f"Output: {result}\n")
            f.write("-" * 50 + "\n")
    logger.info(f"✅ Results saved to: {output_file}")


def _run_interactive(inferencer: "StylingInference", instruction: Optional[str], streaming: bool) -> None:
    """Read texts from stdin until 'quit', EOF or Ctrl-C, printing each styled result."""
    logger.info("Entering interactive mode. Type 'quit' to exit.")
    while True:
        try:
            user_input = input("\nEnter text to style (or 'quit'): ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D / exhausted piped stdin) used to fall into the
            # generic handler below and loop forever.
            break
        if user_input.lower() == 'quit':
            break
        if not user_input:
            continue
        try:
            result = inferencer.style_transfer(user_input, instruction, streaming)
        except KeyboardInterrupt:
            break
        except Exception as e:
            logger.error(f"Error processing input: {e}")
            continue
        if not streaming:
            print(f"\nStyled text: {result}")


def main():
    """CLI entry point: load config and model, then run one inference mode.

    Modes: ``--text`` (single string), ``--input-file`` (one text per line,
    results written to ``--output-file``), or an interactive stdin loop when
    neither is given. ``--text`` takes precedence if both are supplied.
    Exits with status 1 on any unhandled error; invalid numeric arguments
    are rejected by argparse (status 2).
    """
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")
    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--model-path", type=str, help="Path to trained model (optional, uses base model if not provided)")
    # Inference modes
    parser.add_argument("--text", type=str, help="Single text to style transfer")
    parser.add_argument("--input-file", type=str, help="File containing texts to process (one per line)")
    # Generation parameters
    parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
    parser.add_argument("--instruction", type=str, help="Custom style instruction")
    # Output
    parser.add_argument("--output-file", type=str, help="Output file for results")
    args = parser.parse_args()
    # Validate up front instead of failing deep inside generate().
    if args.max_tokens is not None and args.max_tokens <= 0:
        parser.error("--max-tokens must be a positive integer")
    if args.temperature is not None and args.temperature < 0:
        parser.error("--temperature must be non-negative")

    setup_logging()
    try:
        logger.info(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)
        # Override with CLI arguments (`is not None` so explicit values such
        # as --temperature 0 are honoured instead of silently ignored).
        if args.model_path:
            inference_config['model_path'] = args.model_path
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens
        if args.temperature is not None:
            inference_config['temperature'] = args.temperature
            if args.temperature == 0:
                # transformers rejects temperature=0 when sampling; a zero
                # temperature is treated as a request for greedy decoding.
                inference_config['do_sample'] = False
        if args.instruction:
            inference_config['style_instruction'] = args.instruction
        _log_inference_config(inference_config)

        inferencer = StylingInference(inference_config)
        inferencer.load_model_and_tokenizer()

        if args.text:
            logger.info("Running single text inference...")
            result = inferencer.style_transfer(args.text, args.instruction, args.streaming)
            if not args.streaming:
                print(f"\nGenerated text: {result}")
        elif args.input_file:
            _run_batch_file(inferencer, args.input_file, args.output_file, args.instruction)
        else:
            _run_interactive(inferencer, args.instruction, args.streaming)
        logger.info("🎉 Inference completed successfully!")
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message lost.
        logger.exception(f"❌ Inference failed: {e}")
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()