Files
OwusuBlessing fef3f5ae35 initial setupt
2025-08-06 22:45:37 +01:00

481 lines
18 KiB
Python

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import json
import numpy as np
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
import logging
import argparse
import sys
import yaml
# Module-level logger named after this module. Its level and output format are
# configured by logging.basicConfig() inside main() when run as a script.
logger = logging.getLogger(__name__)
@dataclass
class InferenceConfig:
    """Settings that control how a saved model is loaded and queried.

    Only ``model_path`` is required; every other field has a default.
    """

    # Directory containing the saved model and tokenizer files.
    model_path: str
    # One of "auto", "cuda" or "cpu"; "auto" selects CUDA when it is available.
    device: str = "auto"
    # Number of texts sent through the model per batch when reading a file.
    batch_size: int = 32
    # Token limit handed to the tokenizer; longer inputs are truncated.
    max_length: int = 512
    # When true, each result also carries the probability of every class.
    return_probabilities: bool = True
    # How many of the highest-scoring classes to report per text.
    return_top_k: int = 1
class ModelInference:
    """Run text-classification inference with a saved model.

    Constructing an instance resolves the target device and immediately loads
    the tokenizer, the model and the label mappings from ``config.model_path``.

    Attributes:
        config: The inference settings passed to the constructor.
        model: The loaded sequence-classification model, switched to eval mode.
        tokenizer: The tokenizer loaded from the same directory as the model.
        label_info: Dict holding at least ``"id_to_label"`` (class id as a
            string -> label name), read from ``label_info.json`` or generated.
        device: The ``torch.device`` the model was moved to.
    """

    def __init__(self, config: "InferenceConfig"):
        """Store ``config``, pick the device and load the model.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.label_info = {}
        self.device = self._setup_device()
        # Load model and tokenizer
        self.load_model()

    def _setup_device(self) -> torch.device:
        """Return the device to run on, honouring ``config.device``.

        ``"auto"`` selects CUDA when available and CPU otherwise; any other
        value is passed straight to ``torch.device``.
        """
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
                logger.info("Using CUDA device")
            else:
                device = torch.device("cpu")
                logger.info("Using CPU device")
        else:
            device = torch.device(self.config.device)
            logger.info(f"Using specified device: {device}")
        return device

    @staticmethod
    def _default_label_info(num_labels: int) -> Dict:
        """Build placeholder ``LABEL_<i>`` mappings for ``num_labels`` classes."""
        return {
            "id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
            "label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
            "num_labels": num_labels,
        }

    def load_model(self):
        """Load model, tokenizer, and label mappings from ``config.model_path``.

        Label names come from ``label_info.json`` inside the model directory.
        When that file is missing, or does not contain an ``id_to_label``
        mapping, numeric ``LABEL_<i>`` names are generated instead.

        Raises:
            FileNotFoundError: If the model path does not exist.
        """
        model_path = Path(self.config.model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Model path not found: {model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info(f"Loaded tokenizer from {model_path}")

        # Load model
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Loaded model from {model_path}")

        # Load label information
        label_info_path = model_path / "label_info.json"
        if label_info_path.exists():
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(label_info_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            if isinstance(loaded, dict) and isinstance(loaded.get("id_to_label"), dict):
                self.label_info = loaded
                logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
                return
            # Without this guard a malformed file only fails later, with a
            # KeyError on the first prediction.
            logger.warning(
                "label_info.json has no 'id_to_label' mapping. Using default numeric labels."
            )
        else:
            logger.warning("No label_info.json found. Using default numeric labels.")

        # Create default mappings
        self.label_info = self._default_label_info(self.model.config.num_labels)

    def predict_single(self, text: str) -> Dict:
        """Predict a single text; returns the same dict shape as ``predict_batch``."""
        return self.predict_batch([text])[0]

    def predict_batch(self, texts: List[str]) -> List[Dict]:
        """Predict labels for ``texts`` in a single forward pass.

        All texts are tokenized and sent through the model together, so the
        caller is responsible for keeping ``texts`` to a manageable size.

        Args:
            texts: Texts to classify. An empty list returns ``[]``.

        Returns:
            One dict per input text, in order, with keys ``"text"``,
            ``"predictions"`` (top-k entries of ``label``/``label_id``/``score``,
            best first), ``"predicted_label"``, ``"confidence"`` and, when
            ``config.return_probabilities`` is true, ``"all_probabilities"``.
        """
        if not texts:
            return []

        # Tokenize inputs
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )
        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Convert to probabilities
            probabilities = torch.softmax(logits, dim=-1)

        # One device-to-host copy for the whole batch instead of one per row.
        batch_probs = probabilities.cpu().numpy()
        num_classes = batch_probs.shape[-1]

        id_to_label = self.label_info.get("id_to_label")
        if not id_to_label:
            id_to_label = {str(i): f"LABEL_{i}" for i in range(num_classes)}

        # Clamp k to [1, num_classes]. A slice of [-0:] would return every
        # class and a negative k would return an arbitrary tail.
        top_k = min(max(int(self.config.return_top_k), 1), num_classes)

        results = []
        for text, text_probs in zip(texts, batch_probs):
            # Indices of the k highest probabilities, best first
            top_indices = np.argsort(text_probs)[-top_k:][::-1]
            predictions = [
                {
                    "label": id_to_label.get(str(idx), f"LABEL_{idx}"),
                    "label_id": int(idx),
                    "score": float(text_probs[idx]),
                }
                for idx in top_indices
            ]
            result = {
                "text": text,
                "predictions": predictions,
                "predicted_label": predictions[0]["label"],
                "confidence": predictions[0]["score"],
            }
            if self.config.return_probabilities:
                # Add all class probabilities
                result["all_probabilities"] = {
                    label_name: float(text_probs[int(label_id)])
                    for label_id, label_name in id_to_label.items()
                }
            results.append(result)
        return results

    def predict_file(self, input_file: str, output_file: Optional[str] = None) -> List[Dict]:
        """Predict on every text in ``input_file`` and optionally save results.

        Args:
            input_file: ``.txt`` file (one text per line) or ``.jsonl`` file
                whose records carry the text under ``"text"`` or ``"input"``.
                Blank lines and records without text are skipped.
            output_file: If given, results are written there as JSON Lines;
                missing parent directories are created.

        Returns:
            The prediction dicts for all texts, in input order.

        Raises:
            ValueError: If the file extension is neither ``.txt`` nor ``.jsonl``.
        """
        input_path = Path(input_file)
        suffix = input_path.suffix.lower()  # accept ".TXT" / ".JSONL" as well

        # Read texts from file
        texts = []
        if suffix == '.txt':
            # Plain text file (one text per line)
            with open(input_path, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]
        elif suffix == '.jsonl':
            # JSONL file with "text" field
            with open(input_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    text = data.get("text", data.get("input", ""))
                    if text:
                        texts.append(text)
        else:
            raise ValueError(f"Unsupported file format: {input_path.suffix}")
        logger.info(f"Loaded {len(texts)} texts from {input_file}")

        # Process in batches
        all_results = []
        batch_size = self.config.batch_size
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            all_results.extend(self.predict_batch(batch_texts))
            # Report progress on every tenth batch only
            if i % (batch_size * 10) == 0:
                logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")

        # Save results if output file specified
        if output_file:
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w', encoding='utf-8') as f:
                for result in all_results:
                    f.write(json.dumps(result) + '\n')
            logger.info(f"Results saved to {output_file}")
        return all_results
class BatchInference:
    """Chunked inference for input files too large to hold in memory.

    Wraps a ``ModelInference`` instance and streams the input file so that at
    most ``chunk_size`` texts are buffered at any time.
    """

    def __init__(self, config: "InferenceConfig"):
        """Load the model described by ``config`` and keep the config."""
        self.inference = ModelInference(config)
        self.config = config

    def _write_predictions(self, texts: List[str], outfile) -> int:
        """Predict ``texts`` in ``config.batch_size`` slices and write JSON Lines.

        Returns the number of texts handled so the caller can track progress.
        """
        # Sending the whole chunk through the model at once would ignore the
        # configured batch size and can exhaust memory for large chunks.
        batch_size = max(1, int(self.config.batch_size))
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            for result in self.inference.predict_batch(batch):
                outfile.write(json.dumps(result) + '\n')
        return len(texts)

    def predict_large_file(self, input_file: str, output_file: str,
                           chunk_size: int = 1000) -> None:
        """Process a large file in chunks to manage memory.

        Args:
            input_file: Path to the input. A ``.jsonl`` suffix means each line
                is a JSON record with the text under ``"text"`` or ``"input"``;
                any other suffix is read as one text per line.
            output_file: Destination JSON Lines file; parent directories are
                created and an existing file is overwritten.
            chunk_size: Number of texts buffered before predictions are run
                and written out.

        Blank lines and records without text are skipped, matching
        ``ModelInference.predict_file``. The logged total counts non-blank
        lines, so the processed count can be lower when records lack text.
        """
        input_path = Path(input_file)
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        is_jsonl = input_path.suffix == '.jsonl'

        # Count non-blank lines first; blank lines are never processed, so
        # including them would make the progress total unreachable.
        with open(input_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for line in f if line.strip())
        logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")

        processed = 0
        with open(input_path, 'r', encoding='utf-8') as infile, \
                open(output_path, 'w', encoding='utf-8') as outfile:
            chunk_texts = []
            for line in infile:
                stripped = line.strip()
                if not stripped:
                    continue
                if is_jsonl:
                    data = json.loads(stripped)
                    text = data.get("text", data.get("input", ""))
                else:
                    text = stripped
                if not text:
                    # Record has no usable text; nothing to classify
                    continue
                chunk_texts.append(text)
                if len(chunk_texts) >= chunk_size:
                    processed += self._write_predictions(chunk_texts, outfile)
                    logger.info(f"Processed {processed}/{total_lines} texts")
                    chunk_texts = []
            # Process remaining texts
            if chunk_texts:
                processed += self._write_predictions(chunk_texts, outfile)
        logger.info(f"Completed processing {processed} texts")
def create_inference_config(
    model_path: str,
    device: str = "auto",
    batch_size: int = 32,
    **kwargs
) -> InferenceConfig:
    """Build an ``InferenceConfig`` from the common settings plus extra fields.

    Args:
        model_path: Path to the saved model directory.
        device: Device selector forwarded to the config.
        batch_size: Batch size forwarded to the config.
        **kwargs: Any further ``InferenceConfig`` fields, passed through as-is.

    Returns:
        The populated ``InferenceConfig``.
    """
    settings = dict(kwargs)
    settings.update(model_path=model_path, device=device, batch_size=batch_size)
    return InferenceConfig(**settings)
# Input files larger than this many bytes are processed with chunked inference.
_LARGE_FILE_BYTES = 10 * 1024 * 1024


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference pipeline."""
    parser = argparse.ArgumentParser(description="Model Inference Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Model settings
    parser.add_argument("--model-path", type=str, help="Path to saved model directory")
    parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
    parser.add_argument("--batch-size", type=int, help="Batch size for inference")
    parser.add_argument("--max-length", type=int, help="Maximum sequence length for tokenization")
    # Inference settings. Both flags default to None so that "not given" can be
    # told apart from an explicit choice and does not clobber the YAML value.
    parser.add_argument("--return-probabilities", dest="return_probabilities",
                        action="store_true", default=None,
                        help="Return all class probabilities")
    parser.add_argument("--no-return-probabilities", dest="return_probabilities",
                        action="store_false", default=None,
                        help="Do not return all class probabilities")
    parser.add_argument("--return-top-k", type=int, help="Return top K predictions")
    # Input/Output settings
    parser.add_argument("--input-text", type=str, help="Single text for prediction")
    parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
    parser.add_argument("--output-file", type=str, help="Output file path for results")
    parser.add_argument("--chunk-size", type=int, help="Chunk size for large file processing")
    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        default="INFO", help="Logging level")
    return parser


def _load_yaml_config(path: str) -> dict:
    """Read the YAML file at ``path`` and return its top-level mapping.

    An empty file yields ``{}``. Any read/parse failure, or a top-level value
    that is not a mapping, is logged and terminates the process with status 1.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except Exception as e:
        logger.error(f"Error loading YAML config: {e}")
        sys.exit(1)
    if loaded is None:
        # safe_load returns None for an empty document
        loaded = {}
    if not isinstance(loaded, dict):
        logger.error(
            f"Error loading YAML config: expected a mapping at the top level, "
            f"got {type(loaded).__name__}"
        )
        sys.exit(1)
    logger.info(f"Loaded YAML configuration from: {path}")
    return loaded


def _collect_cli_overrides(args: argparse.Namespace) -> dict:
    """Return the config keys that were explicitly set on the command line."""
    overrides = {}
    if args.model_path:
        overrides['model_path'] = args.model_path
    if args.device:
        overrides['device'] = args.device
    # Compare against None rather than truthiness so an explicit 0 or False is
    # kept (and validated later) instead of being silently dropped.
    for key in ('batch_size', 'max_length', 'return_probabilities', 'return_top_k'):
        value = getattr(args, key)
        if value is not None:
            overrides[key] = value
    return overrides


def main():
    """Command-line entry point with YAML configuration support.

    Settings are read from the optional ``--config`` YAML file and then
    overridden by any command-line flags. Depending on the arguments, this runs
    a single-text prediction (``--input-text``), a file prediction
    (``--input-file``, chunked when the file is large and ``--output-file`` is
    given), or a handful of built-in example predictions.

    Exits with status 1 when the model path or input file is missing or when
    inference raises; invalid settings are reported through ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided, then layer CLI arguments on top
    config_dict = _load_yaml_config(args.config) if args.config else {}
    for key, value in _collect_cli_overrides(args).items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate model path
    model_path = config_dict.get('model_path')
    if not model_path:
        parser.error("--model-path is required (either in YAML config or CLI)")
    model_path = Path(model_path)
    if not model_path.exists():
        print(f"❌ Model path not found: {model_path}")
        print("Please ensure the model directory exists and contains the trained model files")
        sys.exit(1)

    # Reject non-positive or non-integer sizes up front; a batch size of 0,
    # for example, would otherwise only fail deep inside file processing.
    for key in ('batch_size', 'max_length', 'return_top_k'):
        value = config_dict.get(key)
        if value is None:
            continue
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            parser.error(f"'{key}' must be a positive integer, got {value!r}")

    def setting(key, default):
        """Return the configured value for ``key``, or ``default`` if unset/null."""
        value = config_dict.get(key)
        return default if value is None else value

    # Create configuration object
    config = InferenceConfig(
        model_path=str(model_path),
        device=setting('device', 'auto'),
        batch_size=setting('batch_size', 32),
        max_length=setting('max_length', 512),
        return_probabilities=setting('return_probabilities', True),
        return_top_k=setting('return_top_k', 1)
    )

    try:
        print(f"Loading model from: {config.model_path}")
        print(f"Device: {config.device}")
        print(f"Batch size: {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        # Handle different input types. The model is loaded inside each branch
        # so it is loaded exactly once and only after the inputs are checked.
        if args.input_text:
            inference = ModelInference(config)
            # Single text prediction
            print("=== Single Text Prediction ===")
            print(f"Input text: {args.input_text}")
            print()
            result = inference.predict_single(args.input_text)
            print(f"Predicted label: {result['predicted_label']}")
            print(f"Confidence: {result['confidence']:.4f}")
            print(f"Top {config.return_top_k} predictions:")
            for pred in result['predictions']:
                print(f" - {pred['label']}: {pred['score']:.4f}")
            if config.return_probabilities and 'all_probabilities' in result:
                print("\nAll class probabilities:")
                for label, prob in result['all_probabilities'].items():
                    print(f" - {label}: {prob:.4f}")
        elif args.input_file:
            # File prediction
            input_path = Path(args.input_file)
            if not input_path.exists():
                print(f"❌ Input file not found: {input_path}")
                sys.exit(1)
            print("=== File Prediction ===")
            print(f"Input file: {args.input_file}")
            print(f"Output file: {args.output_file}")
            print()
            is_large = input_path.stat().st_size > _LARGE_FILE_BYTES
            if args.output_file and is_large:
                # Use batch inference for large files. BatchInference loads its
                # own model, so no separate ModelInference is created here.
                print("Large file detected, using chunked processing...")
                batch_inference = BatchInference(config)
                batch_inference.predict_large_file(
                    args.input_file,
                    args.output_file,
                    chunk_size=args.chunk_size or 1000
                )
            else:
                # Regular file processing (saves only when an output is given)
                inference = ModelInference(config)
                results = inference.predict_file(args.input_file, args.output_file)
                print(f"Processed {len(results)} texts")
                if not args.output_file:
                    print("\nSample results:")
                    for i, result in enumerate(results[:3]):  # Show first 3
                        print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")
        else:
            inference = ModelInference(config)
            # Interactive mode - example predictions
            print("=== Interactive Mode ===")
            print("No input specified. Running example predictions...")
            print()
            # Example texts
            example_texts = [
                "I love this product! It's amazing.",
                "This is terrible, I hate it.",
                "The weather is okay today.",
                "Best purchase ever made!"
            ]
            print("Example predictions:")
            for text in example_texts:
                result = inference.predict_single(text)
                print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")
        print("\n✅ Inference completed successfully!")
    except Exception as e:
        # Keep the short message for users; the traceback is available at
        # --log-level DEBUG instead of being discarded.
        logger.debug("Inference failed", exc_info=True)
        print(f"❌ Inference failed: {e}")
        sys.exit(1)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
"""
Quick usage examples:

# 1. Single prediction
config = InferenceConfig(model_path="./results/my_model")
inference = ModelInference(config)
result = inference.predict_single("Your text here")

# 2. Batch prediction
results = inference.predict_batch(["text1", "text2", "text3"])

# 3. File prediction
inference.predict_file("input.txt", "predictions.jsonl")

# 4. Large file processing
batch_inference = BatchInference(config)
batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
"""