#!/usr/bin/env python3
"""
Styling Inference Pipeline using Trained Models

Supports style transfer inference with streaming and batch processing.
"""

import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union

import yaml

# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))

# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer


# Alpaca-style prompt template. The three ``{}`` slots are, in order:
# instruction, input text, response. It must match the training template
# exactly, otherwise the fine-tuned model sees an unfamiliar prompt.
# NOTE(review): the newline layout below was reconstructed from a flattened
# copy of this file (only single spaces between the sections survived) —
# confirm it byte-for-byte against the training script.
DEFAULT_ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction

### Instruction:
{}

### Input:
{}

### Response:
{}"""

DEFAULT_STYLE_INSTRUCTION = 'Rewrite the following text in a formal style'
DEFAULT_MODEL_OUTPUT_DIR = './models/styling'


class StylingInference:
    """Styling task inference using trained models.

    Loads a fine-tuned model through Unsloth and exposes single, streaming and
    sequential batch style-transfer helpers.

    Args:
        config: Flat settings dict (as produced by ``load_inference_config``).
            Any missing key falls back to the default read in ``__init__``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', DEFAULT_MODEL_OUTPUT_DIR)
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters.
        # batch_size is stored but not used by the methods below:
        # batch_style_transfer processes texts one at a time.
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)

        # Alpaca prompt template (must match the one used for training)
        self.alpaca_prompt = config.get('alpaca_prompt', DEFAULT_ALPACA_PROMPT)

        # Default style instruction used when none is given explicitly
        self.style_instruction = config.get('style_instruction', DEFAULT_STYLE_INSTRUCTION)

    def load_model_and_tokenizer(self):
        """Load the trained (LoRA) model and tokenizer from ``model_output_dir``.

        The model is switched into Unsloth's inference mode after loading.

        Raises:
            Exception: Any error from ``FastLanguageModel.from_pretrained`` is
                printed and re-raised.
        """
        print("Loading trained model and tokenizer...")
        model_path = self.model_output_dir
        print(f"Loading model from: {model_path}")
        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                # Forward the configured Hugging Face token; it was previously
                # read from the config but never used.
                token=self.hf_token,
            )
            # Enable native 2x faster inference
            FastLanguageModel.for_inference(self.model)
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise
        print(f"✅ Model loaded from: {model_path}")
        print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

    def format_prompt(self, instruction: str, input_text: str = "") -> str:
        """Fill the Alpaca template with *instruction* and *input_text*.

        The response slot is left empty so the model generates it. The template
        is ``self.alpaca_prompt``, so a config override is honoured.
        """
        return self.alpaca_prompt.format(instruction, input_text, "")

    def _generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Build ``model.generate`` keyword arguments from the configured settings.

        Args:
            max_new_tokens: Generation cap; ``None`` means ``self.max_new_tokens``.
        """
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens if max_new_tokens is not None else self.max_new_tokens,
            "do_sample": self.do_sample,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # temperature/top_p only affect sampling; omit them for greedy decoding.
        if self.do_sample:
            gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = self.top_p
        return gen_kwargs

    def generate_text(self, instruction: str, input_text: str = "",
                      max_new_tokens: Optional[int] = 128, stream: bool = False):
        """Generate a response for *instruction* applied to *input_text*.

        Sampling settings (``do_sample``, ``temperature``, ``top_p``) come from
        the configuration for both the streaming and non-streaming paths.

        Args:
            instruction: Text for the template's instruction slot.
            input_text: Text for the template's input slot.
            max_new_tokens: Generation cap; ``None`` uses the configured value.
            stream: If true, tokens are printed through ``TextStreamer`` as they
                are generated and ``None`` is returned.

        Returns:
            The generated response with the prompt excluded, or ``None`` when
            streaming.

        Raises:
            Exception: Tokenization/generation errors are printed and re-raised.
        """
        try:
            prompt = self.format_prompt(instruction, input_text)
            print(f"Formatted prompt: {prompt}")

            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            gen_kwargs = self._generation_kwargs(max_new_tokens)

            if stream:
                print("Generating with streaming...")
                with torch.no_grad():
                    self.model.generate(**inputs, streamer=TextStreamer(self.tokenizer), **gen_kwargs)
                return None  # Streaming output is handled by the streamer

            print("Generating...")
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)

            # Decode only the newly generated tokens. Searching the full decoded
            # text for "### Response:" breaks when the user input contains it.
            prompt_len = inputs["input_ids"].shape[-1]
            return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        except Exception as e:
            print(f"❌ Error generating text: {e}")
            raise

    def style_transfer(self, input_text: str, instruction: Optional[str] = None,
                       streaming: bool = False) -> str:
        """Perform style transfer on a single input text.

        Args:
            input_text: Text to restyle.
            instruction: Style instruction; defaults to ``self.style_instruction``.
            streaming: Stream tokens to stdout instead of returning them.

        Returns:
            The generated text, or ``""`` when streaming.
        """
        try:
            if instruction is None:
                instruction = self.style_instruction

            print(f"Style instruction: {instruction}")
            print(f"Input text: {input_text}")

            prompt = self.format_prompt(instruction, input_text)
            print(f"Style transfer prompt: {prompt}")

            if streaming:
                print("Generating with streaming...")
                self.generate_text_streaming(prompt)
                return ""

            print("Generating text...")
            result = self.generate_text(instruction, input_text, self.max_new_tokens)
            print(f"Generated result: {result}")
            return result
        except Exception as e:
            print(f"❌ Error in style transfer: {e}")
            raise

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
        """Stream a completion for an already formatted *prompt* to stdout.

        Args:
            prompt: Fully formatted prompt (see ``format_prompt``).
            max_new_tokens: Generation cap; ``None`` uses the configured value.

        Raises:
            Exception: Errors are printed and re-raised rather than swallowed, so
                callers such as ``style_transfer`` can see the failure.
        """
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            text_streamer = TextStreamer(self.tokenizer)
            gen_kwargs = self._generation_kwargs(max_new_tokens)
            with torch.no_grad():
                self.model.generate(**inputs, streamer=text_streamer, **gen_kwargs)
        except Exception as e:
            print(f"❌ Error in streaming generation: {e}")
            raise

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Perform style transfer on multiple input texts, one at a time.

        Returns:
            One result per input, in input order.
        """
        results = []
        for i, input_text in enumerate(input_texts):
            print(f"Processing text {i+1}/{len(input_texts)}")
            results.append(self.style_transfer(input_text, instruction))
        return results


def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load inference configuration from a YAML file.

    Flattens the relevant parts of the ``model``, ``training``, ``inference``
    and ``data`` sections into the dict ``StylingInference`` expects. Keys
    belonging to a section are only set when that section is present.

    Raises:
        Exception: File/YAML errors are printed and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        inference_config: Dict[str, Any] = {}

        # Model configuration
        if 'model' in config:
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token'),
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config['model_output_dir'] = training_data.get('model_output_dir', DEFAULT_MODEL_OUTPUT_DIR)

        # Inference configuration (top_p/do_sample defaults match StylingInference)
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8),
                'top_p': inference_data.get('top_p', 0.9),
                'do_sample': inference_data.get('do_sample', True),
            })

        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config['style_instruction'] = data_config.get('instruction', DEFAULT_STYLE_INSTRUCTION)

        return inference_config

    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise


def main():
    """Main inference function: parse CLI args, load config and model, generate."""
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--instruction", type=str, default=None,
                        help="Style instruction (defaults to the config's data.instruction)")
    parser.add_argument("--input-text", type=str, default="", help="Input text to style")
    # Default None so the YAML value is only overridden when the flag is given.
    parser.add_argument("--max-tokens", type=int, default=None,
                        help="Maximum new tokens to generate (defaults to the config value, else 128)")
    parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens

        print("Inference configuration:")
        for key, value in inference_config.items():
            # Never echo credentials to stdout/logs.
            shown = "***" if key == 'hf_token' and value else value
            print(f"  {key}: {shown}")

        # Initialize inference
        inference = StylingInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        instruction = args.instruction if args.instruction is not None else inference.style_instruction

        # Run inference
        if args.stream:
            print("Running streaming inference...")
            inference.generate_text(instruction, args.input_text, inference.max_new_tokens, stream=True)
        else:
            print("Running inference...")
            result = inference.generate_text(instruction, args.input_text, inference.max_new_tokens)
            print(f"✅ Generated text: {result}")

    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()