#!/usr/bin/env python3
"""
Styling Inference Pipeline using Trained Models
Supports style transfer inference with streaming and batch processing
"""

import os
import sys
import json
import logging
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml

# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))

from utils.config.config_manager import ConfigManager
from utils.logging.logging import setup_logging

# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from transformers import TextStreamer

logger = logging.getLogger(__name__)

# Model loaded when no (existing) local model path is configured.
DEFAULT_BASE_MODEL = 'unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit'

# Instruction used when neither the config nor the caller supplies one.
DEFAULT_STYLE_INSTRUCTION = 'Rewrite the following text in a formal style'

# Default prompt template; the three `{}` slots are filled, in order, with
# the instruction, the input text and the (normally empty) response.
# NOTE(review): the file this was recovered from had its whitespace collapsed,
# so the text is reproduced exactly as it appeared. The section headers were
# presumably on separate lines originally -- confirm against the template
# used at training time before relying on this default.
DEFAULT_ALPACA_PROMPT = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. \n"
    "Write a response that follows the instruction "
    "### Instruction: {} ### Input: {} ### Response: {}"
)

# Config keys whose values must never be written to the log.
_SENSITIVE_CONFIG_KEYS = frozenset({'hf_token'})


class StylingInference:
    """Styling task inference using trained models"""

    def __init__(self, config: Dict[str, Any]):
        """Store settings from ``config``; the model itself is loaded later.

        Args:
            config: Flat settings dictionary. Every key is optional and is
                read with ``config.get``; the fallbacks are the literals
                visible below.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Model parameters
        self.model_path = config.get('model_path')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 0.8)
        self.top_p = config.get('top_p', 0.9)
        self.do_sample = config.get('do_sample', True)

        # Alpaca prompt template
        self.alpaca_prompt = config.get('alpaca_prompt', DEFAULT_ALPACA_PROMPT)

        # Style instruction
        self.style_instruction = config.get('style_instruction', DEFAULT_STYLE_INSTRUCTION)

    def load_model_and_tokenizer(self):
        """Load the trained model and tokenizer

        Uses ``self.model_path`` when it names an existing path, otherwise
        the configured ``base_model_name`` (or the built-in default).

        Raises:
            Exception: whatever the loader raised, re-raised after logging.
        """
        logger.info("Loading model and tokenizer...")

        try:
            if self.model_path and Path(self.model_path).exists():
                # Load local trained model
                model_name = self.model_path
                logger.info(f"Loading local model from: {model_name}")
            else:
                if self.model_path:
                    # A path was requested but is missing: say so instead of
                    # silently running the untuned base model.
                    logger.warning(
                        f"Model path not found: {self.model_path}; falling back to base model"
                    )
                # Load base model from HuggingFace Hub
                model_name = self.config.get('base_model_name', DEFAULT_BASE_MODEL)
                logger.info(f"Loading base model: {model_name}")

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )

            # Prepare for inference
            FastLanguageModel.for_inference(self.model)

            logger.info("✅ Model loaded successfully")
            logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise

    def format_prompt(self, instruction: str, input_text: str, output: str = "") -> str:
        """Format the prompt using Alpaca template"""
        return self.alpaca_prompt.format(instruction, input_text, output)

    def _build_generation_kwargs(self, max_new_tokens: Optional[int] = None) -> Dict[str, Any]:
        """Return the keyword arguments shared by both generation paths.

        Args:
            max_new_tokens: Per-call override; a falsy value selects
                ``self.max_new_tokens``.
        """
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens or self.max_new_tokens,
            "use_cache": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # A non-positive temperature is not usable for sampling, so treat it
        # as a request for greedy decoding rather than letting generate() fail.
        non_positive_temperature = self.temperature is not None and self.temperature <= 0
        sampling = bool(self.do_sample) and not non_positive_temperature
        gen_kwargs["do_sample"] = sampling
        if sampling:
            # Sampling-only knobs are omitted when decoding greedily.
            gen_kwargs["temperature"] = self.temperature
            gen_kwargs["top_p"] = self.top_p
        return gen_kwargs

    def generate_text(self, prompt: str, max_new_tokens: Optional[int] = None) -> str:
        """Generate text from a single prompt

        Args:
            prompt: Fully formatted prompt.
            max_new_tokens: Optional override of the configured limit.

        Returns:
            The decoded continuation with surrounding whitespace stripped,
            or ``""`` if generation raised (the error is logged).
        """
        try:
            # Tokenize input
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            prompt_length = inputs["input_ids"].shape[1]

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs, **self._build_generation_kwargs(max_new_tokens)
                )

            # Drop the prompt by token count instead of by string matching:
            # the decoded text is not guaranteed to reproduce the prompt
            # verbatim, and the prompt text could also occur mid-output.
            new_tokens = outputs[0][prompt_length:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            logger.exception(f"❌ Error generating text: {e}")
            return ""

    def style_transfer(self, input_text: str, instruction: Optional[str] = None, streaming: bool = False) -> str:
        """Perform style transfer on input text

        Args:
            input_text: Text to rewrite.
            instruction: Style instruction; ``None`` selects
                ``self.style_instruction``.
            streaming: When true, generation goes through
                ``generate_text_streaming`` and ``""`` is returned.

        Returns:
            The generated text, or ``""`` in streaming mode.
        """
        if instruction is None:
            instruction = self.style_instruction

        # Format prompt
        prompt = self.format_prompt(instruction, input_text, "")
        logger.info(f"Style transfer prompt: {prompt}")

        if streaming:
            logger.info("Generating with streaming...")
            self.generate_text_streaming(prompt)
            return ""

        logger.info("Generating text...")
        result = self.generate_text(prompt)
        logger.info(f"Generated result: {result}")
        return result

    def generate_text_streaming(self, prompt: str, max_new_tokens: Optional[int] = None):
        """Generate text with streaming output

        The output is handed to a ``TextStreamer``; nothing is returned.
        Errors are logged and swallowed.
        """
        try:
            # Tokenize input
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

            # Setup text streamer
            text_streamer = TextStreamer(self.tokenizer)

            # Generate with streaming
            with torch.no_grad():
                self.model.generate(
                    **inputs,
                    streamer=text_streamer,
                    **self._build_generation_kwargs(max_new_tokens),
                )

        except Exception as e:
            logger.exception(f"❌ Error in streaming generation: {e}")

    def batch_style_transfer(self, input_texts: List[str], instruction: Optional[str] = None) -> List[str]:
        """Perform style transfer on multiple input texts

        Texts are processed one at a time, in order.

        Returns:
            One result per input text, in the same order.
        """
        total = len(input_texts)
        results = []
        for position, input_text in enumerate(input_texts, start=1):
            logger.info(f"Processing text {position}/{total}")
            results.append(self.style_transfer(input_text, instruction))
        return results


def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load inference configuration from YAML file

    Reads the optional ``model``, ``inference`` and ``data`` sections and
    flattens them into the dictionary ``StylingInference`` expects. A section
    that is absent contributes no keys.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The flattened configuration dictionary.

    Raises:
        ValueError: if the top level of the YAML document is not a mapping.
        Exception: any I/O or parse error, re-raised after logging.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Top level of {config_path} must be a mapping, got {type(config).__name__}"
            )

        # Extract inference configuration
        inference_config: Dict[str, Any] = {}

        # Model configuration
        if 'model' in config:
            model_data = config['model'] or {}  # tolerate an empty section
            inference_config.update({
                'base_model_name': model_data.get('training_model', DEFAULT_BASE_MODEL),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token'),
            })

        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 0.8),
            })
            # Pass these through only when given so that the defaults in
            # StylingInference keep applying otherwise.
            for optional_key in ('top_p', 'do_sample'):
                if optional_key in inference_data:
                    inference_config[optional_key] = inference_data[optional_key]

        # Style configuration
        if 'data' in config:
            data_config = config['data'] or {}
            inference_config.update({
                'style_instruction': data_config.get('instruction', DEFAULT_STYLE_INSTRUCTION),
            })

        return inference_config

    except Exception as e:
        logger.error(f"Error loading inference config: {e}")
        raise


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description="Styling Inference Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--model-path", type=str, help="Path to trained model (optional, uses base model if not provided)")

    # Inference modes
    parser.add_argument("--text", type=str, help="Single text to style transfer")
    parser.add_argument("--input-file", type=str, help="File containing texts to process (one per line)")

    # Generation parameters
    parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming generation")
    parser.add_argument("--instruction", type=str, help="Custom style instruction")

    # Output
    parser.add_argument("--output-file", type=str, help="Output file for results")
    return parser


def _apply_cli_overrides(inference_config: Dict[str, Any], args: argparse.Namespace) -> None:
    """Overwrite entries of ``inference_config`` in place with CLI values."""
    if args.model_path:
        inference_config['model_path'] = args.model_path
    # Compare against None so that an explicit 0 / 0.0 is not dropped.
    if args.max_tokens is not None:
        inference_config['max_new_tokens'] = args.max_tokens
    if args.temperature is not None:
        inference_config['temperature'] = args.temperature
    if args.instruction:
        inference_config['style_instruction'] = args.instruction


def _log_inference_config(inference_config: Dict[str, Any]) -> None:
    """Log the effective configuration, masking secret values."""
    logger.info("Inference configuration:")
    for key, value in inference_config.items():
        if key in _SENSITIVE_CONFIG_KEYS and value is not None:
            value = "***"
        logger.info(f" {key}: {value}")


def _run_batch_file(inferencer: StylingInference, args: argparse.Namespace) -> None:
    """Style every non-blank line of ``args.input_file`` and save the results."""
    with open(args.input_file, 'r', encoding='utf-8') as f:
        input_texts = [line.strip() for line in f if line.strip()]

    results = inferencer.batch_style_transfer(input_texts, args.instruction)

    # Save results
    output_file = args.output_file or f"{Path(args.input_file).stem}_styled.txt"
    with open(output_file, 'w', encoding='utf-8') as f:
        for input_text, result in zip(input_texts, results):
            f.write(f"Input: {input_text}\n")
            f.write(f"Output: {result}\n")
            f.write("-" * 50 + "\n")

    logger.info(f"✅ Results saved to: {output_file}")


def _run_interactive(inferencer: StylingInference, args: argparse.Namespace) -> None:
    """Prompt for texts on stdin until 'quit', Ctrl-C or end of input."""
    logger.info("Entering interactive mode. Type 'quit' to exit.")

    while True:
        try:
            user_input = input("\nEnter text to style (or 'quit'): ").strip()

            if user_input.lower() == 'quit':
                break

            if user_input:
                result = inferencer.style_transfer(user_input, args.instruction, args.streaming)
                if not args.streaming:
                    print(f"\nStyled text: {result}")

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed. Without this the generic handler
            # below would log it and the loop would retry forever.
            break
        except Exception as e:
            logger.error(f"Error processing input: {e}")


def main():
    """Main inference function

    Parses the command line, loads the configuration and the model, then
    runs single-text (``--text``), batch (``--input-file``) or interactive
    mode. Exits with status 1 if anything raises.
    """
    args = _build_arg_parser().parse_args()

    # Setup logging
    setup_logging()

    try:
        # Load configuration
        logger.info(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        _apply_cli_overrides(inference_config, args)
        _log_inference_config(inference_config)

        # Initialize inference
        inferencer = StylingInference(inference_config)

        # Load model
        inferencer.load_model_and_tokenizer()

        # Run inference based on mode
        if args.text:
            # Single text inference
            logger.info("Running single text inference...")
            result = inferencer.style_transfer(args.text, args.instruction, args.streaming)
            if not args.streaming:
                print(f"\nGenerated text: {result}")

        elif args.input_file:
            # Batch file inference
            logger.info("Running batch file inference...")
            _run_batch_file(inferencer, args)

        else:
            # Interactive mode
            _run_interactive(inferencer, args)

        logger.info("🎉 Inference completed successfully!")

    except Exception as e:
        logger.error(f"❌ Inference failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()