#!/usr/bin/env python3
"""
Instruct Inference Pipeline using Trained Models
Supports conversational inference with streaming and batch processing
"""

import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml

# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))

# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer


def _resolve_dtype(dtype: Any) -> Any:
    """Turn a dtype value read from configuration into a ``torch.dtype``.

    YAML can only express the dtype as a string (or null), so strings such as
    ``"bfloat16"`` or ``"torch.float16"`` are looked up on the ``torch``
    module. Non-string values (including ``None``) are returned unchanged.

    Raises:
        ValueError: if a string does not name a ``torch.dtype``.
    """
    if not isinstance(dtype, str):
        return dtype

    name = dtype.strip()
    if name.lower() in ("", "none", "auto"):
        # Treat these spellings as "no explicit dtype".
        return None

    resolved = getattr(torch, name.split(".")[-1], None)
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unsupported dtype in config: {dtype!r}")
    return resolved


class InstructInference:
    """Instruction fine-tuning inference using trained models.

    Typical usage is ``load_model_and_tokenizer()`` followed by
    ``setup_chat_template()``, after which ``generate_response``, ``chat``,
    ``batch_inference`` and ``interactive_chat`` can be called.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration; the model itself is loaded lazily.

        Args:
            config: Flat settings dictionary. Every key is optional and falls
                back to the default shown below.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/instruct')
        self.base_model_name = config.get('base_model_name', 'unsloth/Qwen2.5-72B-Instruct')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 1.5)
        self.min_p = config.get('min_p', 0.1)
        self.use_cache = config.get('use_cache', True)

        # Chat template
        self.chat_template = config.get('chat_template', 'qwen-2.5')

    def load_model_and_tokenizer(self):
        """Load the trained model and tokenizer from ``model_output_dir``.

        Raises:
            FileNotFoundError: if the model directory does not exist.
            ValueError: if the configured dtype string is not a torch dtype.
            Exception: any loader error is printed and re-raised unchanged.
        """
        print("Loading trained instruction model and tokenizer...")

        try:
            # Load the saved LoRA model
            model_path = self.model_output_dir
            print(f"Loading model from: {model_path}")

            # Check if the model directory exists
            if not Path(model_path).exists():
                raise FileNotFoundError(f"Model directory not found: {model_path}")

            # Load the model directly from the saved path. The token is
            # forwarded so gated/private base models can be fetched.
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=_resolve_dtype(self.dtype),
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )

            # Switch the model into Unsloth's inference mode
            FastLanguageModel.for_inference(self.model)

            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def setup_chat_template(self):
        """Attach the configured chat template to the tokenizer.

        Must be called after ``load_model_and_tokenizer``. Errors are printed
        and re-raised.
        """
        print("Setting up chat template...")

        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template=self.chat_template,
            )
            print(f"✅ Chat template configured: {self.chat_template}")

        except Exception as e:
            print(f"❌ Error setting up chat template: {e}")
            raise

    def format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a single prompt string.

        Args:
            messages: Conversation as a list of ``{"role", "content"}`` dicts.

        Returns:
            The templated prompt text, ending with the generation prompt so
            the model continues as the assistant.
        """
        try:
            # Apply chat template to format the conversation
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,  # Add generation prompt for inference
            )
            return formatted_prompt

        except Exception as e:
            print(f"❌ Error formatting messages: {e}")
            raise

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False
    ) -> str:
        """Generate the assistant reply for a conversation.

        Args:
            messages: Conversation as a list of ``{"role", "content"}`` dicts.
            max_new_tokens: Generation budget; ``None`` (or 0, which is not a
                usable budget) falls back to the configured default.
            temperature: Sampling temperature; ``None`` falls back to the
                configured default. A value ``<= 0`` selects greedy decoding.
            stream: When true, tokens are also printed as they are produced.

        Returns:
            The newly generated text with special tokens removed and
            surrounding whitespace stripped. This is returned in streaming
            mode as well, so callers can keep an accurate history.
        """
        try:
            # Use default values if not provided. Temperature is compared
            # against None so that an explicit 0.0 is honoured.
            max_tokens = max_new_tokens or self.max_new_tokens
            temp = self.temperature if temperature is None else temperature

            # Format the messages
            formatted_prompt = self.format_messages(messages)
            print(f"Formatted prompt: {formatted_prompt[:200]}...")

            # Tokenize the input
            inputs = self.tokenizer(
                [formatted_prompt],
                return_tensors="pt"
            ).to(self.device)
            prompt_token_count = inputs.input_ids.shape[1]

            # One kwargs dict keeps streaming and non-streaming in agreement.
            generation_kwargs: Dict[str, Any] = {
                "input_ids": inputs.input_ids,
                "max_new_tokens": max_tokens,
                "use_cache": self.use_cache,
                "pad_token_id": self.tokenizer.eos_token_id,
            }

            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                generation_kwargs["attention_mask"] = attention_mask

            if temp is not None and temp > 0:
                generation_kwargs["temperature"] = temp
                generation_kwargs["min_p"] = self.min_p
            else:
                # Sampling needs a strictly positive temperature; treat
                # zero/negative as a request for deterministic decoding.
                generation_kwargs["do_sample"] = False

            if stream:
                # Streaming generation
                print("Generating with streaming...")
                generation_kwargs["streamer"] = TextStreamer(
                    self.tokenizer, skip_prompt=True
                )
            else:
                # Non-streaming generation
                print("Generating response...")

            outputs = self.model.generate(**generation_kwargs)

            # Keep only the tokens produced after the prompt. Slicing token
            # ids (not decoded characters) is required because decoding drops
            # the template's special tokens, so the decoded text is shorter
            # than the prompt string it was built from.
            new_tokens = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return response.strip()

        except Exception as e:
            print(f"❌ Error generating response: {e}")
            raise

    def chat(
        self,
        user_input: str,
        conversation_history: Optional[List[Dict[str, str]]] = None,
        stream: bool = False
    ) -> str:
        """Run one chat turn and return the assistant reply.

        Args:
            user_input: The new user message.
            conversation_history: Earlier turns; not modified by this method.
            stream: When true, the reply is printed as it is generated.

        Returns:
            The assistant reply text (in both streaming and non-streaming
            mode).
        """
        try:
            # Initialize conversation history if not provided
            if conversation_history is None:
                conversation_history = []

            # Add user input to conversation (builds a new list)
            messages = conversation_history + [{"role": "user", "content": user_input}]

            print(f"User: {user_input}")

            if stream:
                print("Assistant: ", end="", flush=True)
                return self.generate_response(messages, stream=True)

            # Generate response
            response = self.generate_response(messages, stream=False)
            print(f"Assistant: {response}")
            return response

        except Exception as e:
            print(f"❌ Error in chat: {e}")
            raise

    def batch_inference(
        self,
        conversations: List[List[Dict[str, str]]],
        max_new_tokens: Optional[int] = None
    ) -> List[str]:
        """Generate one reply per conversation, processed sequentially.

        Args:
            conversations: List of conversations (each a list of messages).
            max_new_tokens: Optional per-call generation budget.

        Returns:
            Replies in the same order as ``conversations``.
        """
        responses = []
        total = len(conversations)

        for i, messages in enumerate(conversations):
            print(f"Processing conversation {i+1}/{total}")
            response = self.generate_response(messages, max_new_tokens)
            responses.append(response)

        return responses

    def interactive_chat(self):
        """Run a console chat loop until the user quits.

        Supports the commands ``quit``/``exit``/``bye``, ``clear`` and
        ``stream on``/``stream off``. Ctrl-C or end-of-input ends the session.
        """
        print("🤖 Starting interactive chat session...")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.")
        print("Type 'clear' to clear conversation history.")
        print("Type 'stream on' or 'stream off' to toggle streaming.")
        print("-" * 50)

        conversation_history = []
        streaming = False

        while True:
            try:
                user_input = input("\n👤 You: ").strip()
                command = user_input.lower()

                if command in ['quit', 'exit', 'bye']:
                    print("👋 Goodbye!")
                    break
                elif command == 'clear':
                    conversation_history = []
                    print("🗑️ Conversation history cleared.")
                    continue
                elif command == 'stream on':
                    streaming = True
                    print("🔄 Streaming enabled.")
                    continue
                elif command == 'stream off':
                    streaming = False
                    print("⏸️ Streaming disabled.")
                    continue
                elif not user_input:
                    continue

                # Generate response; chat() prints the prefixes itself.
                response = self.chat(user_input, conversation_history, stream=streaming)

                # Add to history, using the real reply so later turns see
                # what the assistant actually said.
                conversation_history.extend([
                    {"role": "user", "content": user_input},
                    {"role": "assistant", "content": response}
                ])

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop: with stdin closed, input()
                # would raise again on every iteration.
                print("\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {e}")
                continue


def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load inference configuration from YAML file.

    Reads the optional ``model``, ``training`` and ``inference`` sections and
    flattens them into the dictionary ``InstructInference`` expects. Keys of a
    section are only emitted when that section is present in the file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Flat configuration dictionary; always contains ``chat_template``.

    Raises:
        ValueError: if the YAML top level is not a mapping.
        Exception: I/O and parse errors are printed and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty file parses to None; treat it as an empty mapping.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Configuration root must be a mapping, got {type(config).__name__}"
            )

        # Extract inference configuration
        inference_config = {}
        chat_template = 'qwen-2.5'  # Use Qwen chat template by default

        # Model configuration
        if 'model' in config:
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Qwen2.5-72B-Instruct'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token')
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config.update({
                'model_output_dir': training_data.get('model_output_dir', './models/instruct')
            })

        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 1.5),
                'min_p': inference_data.get('min_p', 0.1),
                'use_cache': inference_data.get('use_cache', True)
            })
            # Optional override of the default template.
            chat_template = inference_data.get('chat_template', chat_template)

        # Chat template
        inference_config.update({
            'chat_template': chat_template
        })

        return inference_config

    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise


def main():
    """Main inference function: parse CLI arguments and run the chosen mode.

    Exits with status 1 (after printing a traceback) if anything fails.
    """
    parser = argparse.ArgumentParser(description="Instruction Inference Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True,
                        help="Path to YAML configuration file")
    parser.add_argument("--interactive", action="store_true",
                        help="Start interactive chat session")
    parser.add_argument("--message", type=str,
                        help="Single message to send to the model")
    parser.add_argument("--max-tokens", type=int,
                        help="Maximum new tokens to generate")
    parser.add_argument("--stream", action="store_true",
                        help="Enable streaming generation")
    parser.add_argument("--temperature", type=float,
                        help="Sampling temperature")

    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments
        if args.max_tokens:
            inference_config['max_new_tokens'] = args.max_tokens
        if args.temperature is not None:
            # Compared against None so "--temperature 0" is not dropped.
            inference_config['temperature'] = args.temperature

        print("Inference configuration:")
        for key, value in inference_config.items():
            # Never echo the access token to the console or logs.
            shown = "***" if key == 'hf_token' and value else value
            print(f"  {key}: {shown}")

        # Initialize inference
        inference = InstructInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        # Setup chat template
        inference.setup_chat_template()

        # Run inference based on mode
        if args.interactive:
            # Interactive chat mode
            inference.interactive_chat()

        elif args.message:
            # Single message mode
            print("Running single message inference...")
            messages = [{"role": "user", "content": args.message}]

            if args.stream:
                print("User:", args.message)
                print("Assistant: ", end="", flush=True)
                inference.generate_response(messages, stream=True)
            else:
                response = inference.generate_response(messages, stream=False)
                print(f"User: {args.message}")
                print(f"Assistant: {response}")

        else:
            # Default to interactive mode if no specific mode is chosen
            print("No specific mode chosen. Starting interactive chat...")
            inference.interactive_chat()

    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()