Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/instruct/.ipynb_checkpoints/inference-checkpoint.py
T
2025-08-28 16:46:24 +00:00

394 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Instruct Inference Pipeline using Trained Models
Supports conversational inference with streaming and batch processing
"""
import os
import sys
import json
import argparse
from pathlib import Path
from typing import Dict, Any, Optional, List, Union
import yaml
# Add the project root to the path
sys.path.append(str(Path(__file__).parent.parent.parent))
# Inference imports
import torch
from datasets import load_from_disk, Dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer
class InstructInference:
    """Chat-style inference on top of a fine-tuned instruction model.

    Typical usage is: construct with a config dict, call
    ``load_model_and_tokenizer()``, then ``setup_chat_template()``, and finally
    one of ``generate_response``, ``chat``, ``batch_inference`` or
    ``interactive_chat``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read settings from ``config``; missing keys fall back to defaults.

        Nothing is loaded here: ``self.model`` and ``self.tokenizer`` stay
        ``None`` until ``load_model_and_tokenizer`` is called.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/instruct')
        self.base_model_name = config.get('base_model_name', 'unsloth/Qwen2.5-72B-Instruct')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 1.5)
        self.min_p = config.get('min_p', 0.1)
        self.use_cache = config.get('use_cache', True)

        # Chat template
        self.chat_template = config.get('chat_template', 'qwen-2.5')

    def load_model_and_tokenizer(self):
        """Load the saved model and tokenizer from ``self.model_output_dir``.

        Raises:
            FileNotFoundError: if the model directory does not exist.
            Exception: any loader error is printed and re-raised unchanged.
        """
        print("Loading trained instruction model and tokenizer...")
        try:
            model_path = self.model_output_dir
            print(f"Loading model from: {model_path}")

            # Fail early with a clear message instead of a loader-specific error.
            if not Path(model_path).exists():
                raise FileNotFoundError(f"Model directory not found: {model_path}")

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                # The configured token was previously read but never used.
                token=self.hf_token,
            )

            # Switch the model into Unsloth's inference mode.
            FastLanguageModel.for_inference(self.model)

            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def setup_chat_template(self):
        """Attach the configured chat template to the loaded tokenizer.

        Must be called after ``load_model_and_tokenizer``. Errors are printed
        and re-raised.
        """
        print("Setting up chat template...")
        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template=self.chat_template,
            )
            print(f"✅ Chat template configured: {self.chat_template}")
        except Exception as e:
            print(f"❌ Error setting up chat template: {e}")
            raise

    def format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a single prompt string.

        Args:
            messages: conversation turns as ``{"role": ..., "content": ...}``
                dicts, oldest first.

        Returns:
            The untokenized prompt, ending with the generation prompt so the
            model continues as the assistant.
        """
        try:
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,  # Add generation prompt for inference
            )
        except Exception as e:
            print(f"❌ Error formatting messages: {e}")
            raise

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False
    ) -> str:
        """Generate the assistant reply for ``messages``.

        Args:
            messages: conversation turns, oldest first.
            max_new_tokens: generation budget; ``None`` uses the configured value.
            temperature: sampling temperature; ``None`` uses the configured
                value. A value ``<= 0`` selects greedy decoding.
            stream: when true, tokens are also printed as they are produced.

        Returns:
            The newly generated text with special tokens removed. In streaming
            mode the same text is returned after it has been printed, so
            callers can record it in the conversation history.
        """
        try:
            # `is None` rather than `or`, so an explicit 0 is not silently
            # replaced by the configured default.
            max_tokens = self.max_new_tokens if max_new_tokens is None else max_new_tokens
            temp = self.temperature if temperature is None else temperature

            formatted_prompt = self.format_messages(messages)
            print(f"Formatted prompt: {formatted_prompt[:200]}...")

            inputs = self.tokenizer(
                [formatted_prompt],
                return_tensors="pt"
            ).to(self.device)

            generation_kwargs: Dict[str, Any] = {
                "input_ids": inputs.input_ids,
                "max_new_tokens": max_tokens,
                "use_cache": self.use_cache,
                "pad_token_id": self.tokenizer.eos_token_id,
            }

            attention_mask = getattr(inputs, "attention_mask", None)
            if attention_mask is not None:
                generation_kwargs["attention_mask"] = attention_mask

            if temp is not None and temp > 0:
                generation_kwargs["temperature"] = temp
                generation_kwargs["min_p"] = self.min_p
            else:
                # A non-positive temperature is not a valid sampling setting;
                # interpret it as a request for deterministic decoding.
                generation_kwargs["do_sample"] = False

            if stream:
                generation_kwargs["streamer"] = TextStreamer(self.tokenizer, skip_prompt=True)
                print("Generating with streaming...")
            else:
                print("Generating response...")

            outputs = self.model.generate(**generation_kwargs)

            # Drop the prompt by token count. Slicing the decoded string by
            # len(formatted_prompt) is wrong because the decoded text has its
            # special tokens stripped while the formatted prompt does not.
            prompt_token_count = inputs.input_ids.shape[1]
            new_token_ids = outputs[0][prompt_token_count:]
            return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"❌ Error generating response: {e}")
            raise

    def chat(self, user_input: str, conversation_history: Optional[List[Dict[str, str]]] = None, stream: bool = False) -> str:
        """Send one user turn and return the assistant reply.

        Args:
            user_input: the new user message.
            conversation_history: earlier turns; not modified by this method.
            stream: print the reply token by token while generating.

        Returns:
            The assistant reply text (also in streaming mode).
        """
        try:
            previous_turns = [] if conversation_history is None else conversation_history
            messages = previous_turns + [{"role": "user", "content": user_input}]
            print(f"User: {user_input}")

            if stream:
                print("Assistant: ", end="", flush=True)
                return self.generate_response(messages, stream=True)

            response = self.generate_response(messages, stream=False)
            print(f"Assistant: {response}")
            return response
        except Exception as e:
            print(f"❌ Error in chat: {e}")
            raise

    def batch_inference(
        self,
        conversations: List[List[Dict[str, str]]],
        max_new_tokens: Optional[int] = None
    ) -> List[str]:
        """Generate one reply per conversation, sequentially.

        Args:
            conversations: a list of message lists.
            max_new_tokens: optional generation budget applied to every item.

        Returns:
            Replies in the same order as ``conversations``.
        """
        responses = []
        total = len(conversations)
        for i, messages in enumerate(conversations):
            print(f"Processing conversation {i+1}/{total}")
            responses.append(self.generate_response(messages, max_new_tokens))
        return responses

    def interactive_chat(self):
        """Run a REPL-style chat session on stdin/stdout.

        Commands: ``quit``/``exit``/``bye`` end the session, ``clear`` resets
        the history, ``stream on``/``stream off`` toggle streaming. Ctrl-C or
        end-of-input also ends the session. Generation errors are reported and
        the session continues.
        """
        print("🤖 Starting interactive chat session...")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.")
        print("Type 'clear' to clear conversation history.")
        print("Type 'stream on' or 'stream off' to toggle streaming.")
        print("-" * 50)

        conversation_history = []
        streaming = False

        while True:
            try:
                user_input = input("\n👤 You: ").strip()
                command = user_input.lower()

                if command in ('quit', 'exit', 'bye'):
                    print("👋 Goodbye!")
                    break
                if command == 'clear':
                    conversation_history = []
                    print("🗑️ Conversation history cleared.")
                    continue
                if command == 'stream on':
                    streaming = True
                    print("🔄 Streaming enabled.")
                    continue
                if command == 'stream off':
                    streaming = False
                    print("⏸️ Streaming disabled.")
                    continue
                if not user_input:
                    continue

                # chat() prints the "User:"/"Assistant:" lines itself and
                # returns the reply text in both modes, so the real reply
                # (not a placeholder) is what later turns are conditioned on.
                response = self.chat(user_input, conversation_history, stream=streaming)
                conversation_history.extend([
                    {"role": "user", "content": user_input},
                    {"role": "assistant", "content": response}
                ])
            except (KeyboardInterrupt, EOFError):
                # EOFError must end the session: input() would raise it again
                # on every iteration once stdin is exhausted.
                print("\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {e}")
                continue
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Build the flat inference settings dict from a YAML pipeline config.

    Reads three optional top-level sections:

    * ``model``: ``training_model``, ``training_max_seq_length``,
      ``training_dtype``, ``training_load_in_4bit``, ``training_token``
    * ``training``: ``model_output_dir``
    * ``inference``: ``batch_size``, ``max_new_tokens``, ``temperature``,
      ``min_p``, ``use_cache`` and optionally ``chat_template``

    Keys of a section are only emitted when that section is present in the
    file; ``chat_template`` is always emitted (default ``'qwen-2.5'``).

    Args:
        config_path: path to the YAML file.

    Returns:
        A dict of settings for the inference pipeline.

    Raises:
        ValueError: if the YAML document is not a mapping at the top level.
        Exception: any I/O or parse error is printed and re-raised unchanged.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {config_path}, "
                f"got {type(config).__name__}"
            )

        inference_config: Dict[str, Any] = {}

        # Model configuration
        if 'model' in config:
            # A section header with no body parses to None; treat it as empty.
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Qwen2.5-72B-Instruct'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token')
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config['model_output_dir'] = training_data.get(
                'model_output_dir', './models/instruct'
            )

        # Inference configuration
        inference_data = config.get('inference') or {}
        if 'inference' in config:
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 1.5),
                'min_p': inference_data.get('min_p', 0.1),
                'use_cache': inference_data.get('use_cache', True)
            })

        # Chat template: Qwen by default, overridable via inference.chat_template
        inference_config['chat_template'] = inference_data.get('chat_template', 'qwen-2.5')

        return inference_config
    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise
def main():
    """Command-line entry point for the instruction inference pipeline.

    Parses the CLI, loads the YAML configuration, applies CLI overrides,
    prepares the model, and then either answers a single ``--message`` or
    starts the interactive chat. ``--interactive`` takes precedence over
    ``--message``; with neither flag the interactive chat is started.
    ``--stream`` is only consulted in single-message mode. Any failure is
    reported with a traceback and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(description="Instruction Inference Pipeline")
    cli.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    cli.add_argument("--interactive", action="store_true", help="Start interactive chat session")
    cli.add_argument("--message", type=str, help="Single message to send to the model")
    cli.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    cli.add_argument("--stream", action="store_true", help="Enable streaming generation")
    cli.add_argument("--temperature", type=float, help="Sampling temperature")
    args = cli.parse_args()

    try:
        print(f"Loading configuration from: {args.config}")
        settings = load_inference_config(args.config)

        # CLI values win over the file, but only when truthy: an omitted
        # option (None) or a zero leaves the configured value in place.
        cli_overrides = (
            ('max_new_tokens', args.max_tokens),
            ('temperature', args.temperature),
        )
        for setting_name, cli_value in cli_overrides:
            if cli_value:
                settings[setting_name] = cli_value

        print("Inference configuration:")
        for setting_name, setting_value in settings.items():
            print(f" {setting_name}: {setting_value}")

        # Build the pipeline: model + tokenizer first, then the chat template.
        pipeline = InstructInference(settings)
        pipeline.load_model_and_tokenizer()
        pipeline.setup_chat_template()

        single_message_mode = (not args.interactive) and bool(args.message)

        if single_message_mode:
            print("Running single message inference...")
            turn = [{"role": "user", "content": args.message}]
            if not args.stream:
                reply = pipeline.generate_response(turn, stream=False)
                print(f"User: {args.message}")
                print(f"Assistant: {reply}")
            else:
                print("User:", args.message)
                print("Assistant: ", end="", flush=True)
                pipeline.generate_response(turn, stream=True)
        else:
            if not args.interactive:
                # Neither --interactive nor --message was given.
                print("No specific mode chosen. Starting interactive chat...")
            pipeline.interactive_chat()
    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()