updated instruct
This commit is contained in:
@@ -0,0 +1,393 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Instruct Inference Pipeline using Trained Models
|
||||
Supports conversational inference with streaming and batch processing
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
import yaml
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||
|
||||
# Inference imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel
|
||||
from unsloth.chat_templates import get_chat_template
|
||||
from transformers import TextStreamer
|
||||
|
||||
class InstructInference:
|
||||
"""Instruction fine-tuning inference using trained models"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
|
||||
# Set device
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {self.device}")
|
||||
|
||||
# Model parameters
|
||||
self.model_output_dir = config.get('model_output_dir', './models/instruct')
|
||||
self.base_model_name = config.get('base_model_name', 'unsloth/Qwen2.5-72B-Instruct')
|
||||
self.max_seq_length = config.get('max_seq_length', 2048)
|
||||
self.dtype = config.get('dtype', None)
|
||||
self.load_in_4bit = config.get('load_in_4bit', True)
|
||||
self.hf_token = config.get('hf_token', None)
|
||||
|
||||
# Inference parameters
|
||||
self.batch_size = config.get('batch_size', 1)
|
||||
self.max_new_tokens = config.get('max_new_tokens', 128)
|
||||
self.temperature = config.get('temperature', 1.5)
|
||||
self.min_p = config.get('min_p', 0.1)
|
||||
self.use_cache = config.get('use_cache', True)
|
||||
|
||||
# Chat template
|
||||
self.chat_template = config.get('chat_template', 'qwen-2.5')
|
||||
|
||||
def load_model_and_tokenizer(self):
|
||||
"""Load the trained model and tokenizer"""
|
||||
print("Loading trained instruction model and tokenizer...")
|
||||
|
||||
try:
|
||||
# Load the saved LoRA model
|
||||
model_path = self.model_output_dir
|
||||
print(f"Loading model from: {model_path}")
|
||||
|
||||
# Check if the model directory exists
|
||||
if not Path(model_path).exists():
|
||||
raise FileNotFoundError(f"Model directory not found: {model_path}")
|
||||
|
||||
# Load the model directly from the saved path
|
||||
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=model_path,
|
||||
max_seq_length=self.max_seq_length,
|
||||
dtype=self.dtype,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
)
|
||||
|
||||
# Enable native 2x faster inference
|
||||
FastLanguageModel.for_inference(self.model)
|
||||
|
||||
print(f"✅ Model loaded from: {model_path}")
|
||||
print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading model: {e}")
|
||||
raise
|
||||
|
||||
def setup_chat_template(self):
|
||||
"""Setup chat template for conversation formatting"""
|
||||
print("Setting up chat template...")
|
||||
|
||||
try:
|
||||
self.tokenizer = get_chat_template(
|
||||
self.tokenizer,
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
|
||||
print(f"✅ Chat template configured: {self.chat_template}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error setting up chat template: {e}")
|
||||
raise
|
||||
|
||||
def format_messages(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format messages using chat template"""
|
||||
try:
|
||||
# Apply chat template to format the conversation
|
||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True, # Add generation prompt for inference
|
||||
)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error formatting messages: {e}")
|
||||
raise
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
max_new_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
stream: bool = False
|
||||
) -> str:
|
||||
"""Generate response using the trained instruction model"""
|
||||
try:
|
||||
# Use default values if not provided
|
||||
max_tokens = max_new_tokens or self.max_new_tokens
|
||||
temp = temperature or self.temperature
|
||||
|
||||
# Format the messages
|
||||
formatted_prompt = self.format_messages(messages)
|
||||
print(f"Formatted prompt: {formatted_prompt[:200]}...")
|
||||
|
||||
# Tokenize the input
|
||||
inputs = self.tokenizer(
|
||||
[formatted_prompt],
|
||||
return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
if stream:
|
||||
# Streaming generation
|
||||
text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
|
||||
print("Generating with streaming...")
|
||||
_ = self.model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
streamer=text_streamer,
|
||||
max_new_tokens=max_tokens,
|
||||
use_cache=self.use_cache,
|
||||
temperature=temp,
|
||||
min_p=self.min_p
|
||||
)
|
||||
return "" # Streaming output is handled by streamer
|
||||
else:
|
||||
# Non-streaming generation
|
||||
print("Generating response...")
|
||||
outputs = self.model.generate(
|
||||
input_ids=inputs.input_ids,
|
||||
max_new_tokens=max_tokens,
|
||||
use_cache=self.use_cache,
|
||||
temperature=temp,
|
||||
min_p=self.min_p,
|
||||
pad_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# Decode the generated text
|
||||
full_response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||
|
||||
# Extract only the new generated response (remove the input prompt)
|
||||
prompt_length = len(formatted_prompt)
|
||||
response = full_response[prompt_length:].strip()
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error generating response: {e}")
|
||||
raise
|
||||
|
||||
def chat(self, user_input: str, conversation_history: Optional[List[Dict[str, str]]] = None, stream: bool = False) -> str:
|
||||
"""Have a chat conversation with the model"""
|
||||
try:
|
||||
# Initialize conversation history if not provided
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
# Add user input to conversation
|
||||
messages = conversation_history + [{"role": "user", "content": user_input}]
|
||||
|
||||
print(f"User: {user_input}")
|
||||
|
||||
if stream:
|
||||
print("Assistant: ", end="", flush=True)
|
||||
self.generate_response(messages, stream=True)
|
||||
return ""
|
||||
else:
|
||||
# Generate response
|
||||
response = self.generate_response(messages, stream=False)
|
||||
print(f"Assistant: {response}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error in chat: {e}")
|
||||
raise
|
||||
|
||||
def batch_inference(
|
||||
self,
|
||||
conversations: List[List[Dict[str, str]]],
|
||||
max_new_tokens: Optional[int] = None
|
||||
) -> List[str]:
|
||||
"""Perform batch inference on multiple conversations"""
|
||||
responses = []
|
||||
|
||||
for i, messages in enumerate(conversations):
|
||||
print(f"Processing conversation {i+1}/{len(conversations)}")
|
||||
response = self.generate_response(messages, max_new_tokens)
|
||||
responses.append(response)
|
||||
|
||||
return responses
|
||||
|
||||
def interactive_chat(self):
|
||||
"""Start an interactive chat session"""
|
||||
print("🤖 Starting interactive chat session...")
|
||||
print("Type 'quit', 'exit', or 'bye' to end the conversation.")
|
||||
print("Type 'clear' to clear conversation history.")
|
||||
print("Type 'stream on' or 'stream off' to toggle streaming.")
|
||||
print("-" * 50)
|
||||
|
||||
conversation_history = []
|
||||
streaming = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n👤 You: ").strip()
|
||||
|
||||
if user_input.lower() in ['quit', 'exit', 'bye']:
|
||||
print("👋 Goodbye!")
|
||||
break
|
||||
elif user_input.lower() == 'clear':
|
||||
conversation_history = []
|
||||
print("🗑️ Conversation history cleared.")
|
||||
continue
|
||||
elif user_input.lower() == 'stream on':
|
||||
streaming = True
|
||||
print("🔄 Streaming enabled.")
|
||||
continue
|
||||
elif user_input.lower() == 'stream off':
|
||||
streaming = False
|
||||
print("⏸️ Streaming disabled.")
|
||||
continue
|
||||
elif not user_input:
|
||||
continue
|
||||
|
||||
# Generate response
|
||||
if streaming:
|
||||
print("🤖 Assistant: ", end="", flush=True)
|
||||
self.chat(user_input, conversation_history, stream=True)
|
||||
# Add to history (we don't have the actual response text for streaming)
|
||||
conversation_history.extend([
|
||||
{"role": "user", "content": user_input},
|
||||
{"role": "assistant", "content": "[Streamed response]"}
|
||||
])
|
||||
else:
|
||||
response = self.chat(user_input, conversation_history, stream=False)
|
||||
# Add to history
|
||||
conversation_history.extend([
|
||||
{"role": "user", "content": user_input},
|
||||
{"role": "assistant", "content": response}
|
||||
])
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
continue
|
||||
|
||||
def load_inference_config(config_path: str) -> Dict[str, Any]:
|
||||
"""Load inference configuration from YAML file"""
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# Extract inference configuration
|
||||
inference_config = {}
|
||||
|
||||
# Model configuration
|
||||
if 'model' in config:
|
||||
model_data = config['model']
|
||||
inference_config.update({
|
||||
'base_model_name': model_data.get('training_model', 'unsloth/Qwen2.5-72B-Instruct'),
|
||||
'max_seq_length': model_data.get('training_max_seq_length', 2048),
|
||||
'dtype': model_data.get('training_dtype'),
|
||||
'load_in_4bit': model_data.get('training_load_in_4bit', True),
|
||||
'hf_token': model_data.get('training_token')
|
||||
})
|
||||
|
||||
# Training configuration - to get model_output_dir
|
||||
if 'training' in config:
|
||||
training_data = config['training']
|
||||
inference_config.update({
|
||||
'model_output_dir': training_data.get('model_output_dir', './models/instruct')
|
||||
})
|
||||
|
||||
# Inference configuration
|
||||
if 'inference' in config:
|
||||
inference_data = config['inference']
|
||||
inference_config.update({
|
||||
'batch_size': inference_data.get('batch_size', 1),
|
||||
'max_new_tokens': inference_data.get('max_new_tokens', 128),
|
||||
'temperature': inference_data.get('temperature', 1.5),
|
||||
'min_p': inference_data.get('min_p', 0.1),
|
||||
'use_cache': inference_data.get('use_cache', True)
|
||||
})
|
||||
|
||||
# Chat template
|
||||
inference_config.update({
|
||||
'chat_template': 'qwen-2.5' # Use Qwen chat template by default
|
||||
})
|
||||
|
||||
return inference_config
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading inference config: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
"""Main inference function"""
|
||||
parser = argparse.ArgumentParser(description="Instruction Inference Pipeline")
|
||||
|
||||
# Configuration
|
||||
parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
|
||||
parser.add_argument("--interactive", action="store_true", help="Start interactive chat session")
|
||||
parser.add_argument("--message", type=str, help="Single message to send to the model")
|
||||
parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
|
||||
parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
|
||||
parser.add_argument("--temperature", type=float, help="Sampling temperature")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Load configuration
|
||||
print(f"Loading configuration from: {args.config}")
|
||||
inference_config = load_inference_config(args.config)
|
||||
|
||||
# Override with CLI arguments
|
||||
if args.max_tokens:
|
||||
inference_config['max_new_tokens'] = args.max_tokens
|
||||
if args.temperature:
|
||||
inference_config['temperature'] = args.temperature
|
||||
|
||||
print("Inference configuration:")
|
||||
for key, value in inference_config.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
# Initialize inference
|
||||
inference = InstructInference(inference_config)
|
||||
|
||||
# Load model and tokenizer
|
||||
inference.load_model_and_tokenizer()
|
||||
|
||||
# Setup chat template
|
||||
inference.setup_chat_template()
|
||||
|
||||
# Run inference based on mode
|
||||
if args.interactive:
|
||||
# Interactive chat mode
|
||||
inference.interactive_chat()
|
||||
elif args.message:
|
||||
# Single message mode
|
||||
print("Running single message inference...")
|
||||
messages = [{"role": "user", "content": args.message}]
|
||||
|
||||
if args.stream:
|
||||
print("User:", args.message)
|
||||
print("Assistant: ", end="", flush=True)
|
||||
inference.generate_response(messages, stream=True)
|
||||
else:
|
||||
response = inference.generate_response(messages, stream=False)
|
||||
print(f"User: {args.message}")
|
||||
print(f"Assistant: {response}")
|
||||
else:
|
||||
# Default to interactive mode if no specific mode is chosen
|
||||
print("No specific mode chosen. Starting interactive chat...")
|
||||
inference.interactive_chat()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Inference failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user