481 lines
18 KiB
Python
481 lines
18 KiB
Python
|
|
import torch
|
||
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||
|
|
from pathlib import Path
|
||
|
|
import json
|
||
|
|
import numpy as np
|
||
|
|
from typing import List, Dict, Union, Optional
|
||
|
|
from dataclasses import dataclass
|
||
|
|
import logging
|
||
|
|
import argparse
|
||
|
|
import sys
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
@dataclass
class InferenceConfig:
    """Settings used to load a saved classification model and query it.

    Only ``model_path`` is required; every other field has a default.
    """

    # Directory that holds the saved model and tokenizer files.
    model_path: str
    # Device selector: "auto", "cuda" or "cpu".
    device: str = "auto"
    # Number of texts sent to the model per batch.
    batch_size: int = 32
    # Maximum sequence length handed to the tokenizer.
    max_length: int = 512
    # Whether each result also lists the probability of every class.
    return_probabilities: bool = True
    # Number of highest-scoring predictions to return per text.
    return_top_k: int = 1
|
||
|
|
|
||
|
|
|
||
|
|
class ModelInference:
    """Run text-classification inference with a saved Hugging Face model.

    Constructing an instance resolves the target device and immediately loads
    the tokenizer, the model and the label mappings from ``config.model_path``.
    """

    def __init__(self, config: "InferenceConfig"):
        """Store *config*, select the device and load the model artifacts.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.label_info = {}
        self.device = self._setup_device()

        # Load model and tokenizer
        self.load_model()

    def _setup_device(self) -> torch.device:
        """Return the torch device selected by ``config.device``.

        ``"auto"`` picks CUDA when it is available and CPU otherwise; any
        other value is handed to ``torch.device`` unchanged.
        """
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
                logger.info("Using CUDA device")
            else:
                device = torch.device("cpu")
                logger.info("Using CPU device")
        else:
            device = torch.device(self.config.device)
            logger.info(f"Using specified device: {device}")

        return device

    def load_model(self):
        """Load the tokenizer, model and label mappings from the model path.

        Label names are read from ``label_info.json`` inside the model
        directory. When that file is missing, or does not hold an
        ``id_to_label`` mapping, numeric ``LABEL_<i>`` names are generated
        from the model's ``num_labels`` instead.

        Raises:
            FileNotFoundError: If the model path does not exist.
        """
        model_path = Path(self.config.model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Model path not found: {model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info(f"Loaded tokenizer from {model_path}")

        # Load model
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Loaded model from {model_path}")

        # Load label information
        label_info_path = model_path / "label_info.json"
        has_label_file = label_info_path.exists()
        loaded = None
        if has_label_file:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(label_info_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)

        if isinstance(loaded, dict) and isinstance(loaded.get("id_to_label"), dict):
            self.label_info = loaded
            logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
            return

        # predict_batch indexes label_info["id_to_label"] directly, so a file
        # without that mapping must not be kept as-is.
        if has_label_file:
            logger.warning(
                "label_info.json has no 'id_to_label' mapping. Using default numeric labels."
            )
        else:
            logger.warning("No label_info.json found. Using default numeric labels.")

        # Create default mappings
        num_labels = self.model.config.num_labels
        self.label_info = {
            "id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
            "label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
            "num_labels": num_labels
        }

    def predict_single(self, text: str) -> Dict:
        """Predict a single text and return its result dictionary."""
        return self.predict_batch([text])[0]

    def predict_batch(self, texts: List[str]) -> List[Dict]:
        """Predict labels for *texts* in one forward pass.

        The whole list is tokenized and sent through the model at once;
        callers that hold many texts should split them into batches first.

        Args:
            texts: Texts to classify.

        Returns:
            One dictionary per input text, in input order, with the keys
            ``text``, ``predictions`` (top-K entries, best first, each with
            ``label``, ``label_id`` and ``score``), ``predicted_label`` and
            ``confidence``. When ``config.return_probabilities`` is set, an
            ``all_probabilities`` mapping of label name to probability is
            included as well. An empty input yields an empty list.
        """
        if not texts:
            return []

        # Tokenize inputs
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.config.max_length,
            return_tensors="pt"
        )

        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Convert to probabilities; transfer to host memory once for the
        # whole batch instead of once per row.
        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()

        id_to_label = self.label_info["id_to_label"]
        num_classes = probabilities.shape[-1]
        # Clamp K to [1, num_classes]: a slice of [-0:] would return every
        # class and a negative K could leave no prediction at all.
        top_k = min(max(int(self.config.return_top_k), 1), num_classes)

        # Process results
        results = []
        for text, text_probs in zip(texts, probabilities):
            # Get top K predictions, highest probability first
            top_indices = np.argsort(text_probs)[-top_k:][::-1]

            predictions = [
                {
                    "label": id_to_label.get(str(idx), f"LABEL_{idx}"),
                    "label_id": int(idx),
                    "score": float(text_probs[idx]),
                }
                for idx in top_indices
            ]

            result = {
                "text": text,
                "predictions": predictions,
                "predicted_label": predictions[0]["label"],
                "confidence": predictions[0]["score"]
            }

            if self.config.return_probabilities:
                # Add all class probabilities
                result["all_probabilities"] = {
                    label_name: float(text_probs[int(label_id)])
                    for label_id, label_name in id_to_label.items()
                }

            results.append(result)

        return results

    def _read_texts(self, input_path: Path) -> List[str]:
        """Return the non-empty texts stored in a ``.txt`` or ``.jsonl`` file.

        Raises:
            ValueError: If the file suffix is neither ``.txt`` nor ``.jsonl``.
        """
        if input_path.suffix == '.txt':
            # Plain text file (one text per line)
            with open(input_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]

        if input_path.suffix == '.jsonl':
            # JSONL file with a "text" field ("input" is the fallback key)
            texts = []
            with open(input_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    text = data.get("text", data.get("input", ""))
                    if text:
                        texts.append(text)
            return texts

        raise ValueError(f"Unsupported file format: {input_path.suffix}")

    def predict_file(self, input_file: str, output_file: Optional[str] = None) -> List[Dict]:
        """Predict every text found in *input_file*.

        Args:
            input_file: Path to a ``.txt`` file (one text per line) or a
                ``.jsonl`` file whose records carry a ``text`` or ``input``
                field. Blank lines and empty texts are skipped.
            output_file: Optional path; when given, results are written there
                as JSON Lines and missing parent directories are created.

        Returns:
            The prediction dictionaries for all texts, in file order.

        Raises:
            ValueError: If the input file has an unsupported suffix.
        """
        texts = self._read_texts(Path(input_file))
        logger.info(f"Loaded {len(texts)} texts from {input_file}")

        # A non-positive batch size would make range() raise or never advance.
        step = max(1, self.config.batch_size)

        # Process in batches
        all_results = []
        for i in range(0, len(texts), step):
            batch_texts = texts[i:i + step]
            all_results.extend(self.predict_batch(batch_texts))

            # Report progress on every tenth batch
            if i % (step * 10) == 0:
                logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")

        # Save results if output file specified
        if output_file:
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                for result in all_results:
                    f.write(json.dumps(result) + '\n')

            logger.info(f"Results saved to {output_file}")

        return all_results
|
||
|
|
|
||
|
|
|
||
|
|
class BatchInference:
    """Chunked inference for input files too large to hold in memory."""

    def __init__(self, config: "InferenceConfig"):
        """Create the underlying ``ModelInference`` (loads the model) and keep *config*."""
        self.inference = ModelInference(config)
        self.config = config

    def _write_predictions(self, texts, outfile) -> int:
        """Predict *texts* in ``config.batch_size`` batches and append JSON lines.

        Splitting here keeps each forward pass at the configured batch size
        no matter how large the surrounding chunk is. Returns the number of
        texts processed.
        """
        step = max(1, self.config.batch_size)
        for start in range(0, len(texts), step):
            for result in self.inference.predict_batch(texts[start:start + step]):
                outfile.write(json.dumps(result) + '\n')
        return len(texts)

    def predict_large_file(self, input_file: str, output_file: str,
                           chunk_size: int = 1000) -> None:
        """Process a large file in chunks and stream the results to disk.

        Args:
            input_file: Path to the input. ``.jsonl`` files are parsed per
                line and the ``text`` (or ``input``) field is used; any other
                suffix is read as plain text with one text per line.
            output_file: Destination JSON Lines file; missing parent
                directories are created.
            chunk_size: Number of texts collected before predictions are run
                and written out.

        Blank lines and records with an empty text are skipped, matching
        ``ModelInference.predict_file``.
        """
        input_path = Path(input_file)
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        is_jsonl = input_path.suffix == '.jsonl'

        # Count the non-blank lines first so progress messages have a total
        with open(input_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for line in f if line.strip())

        logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")

        processed = 0
        with open(input_path, 'r', encoding='utf-8') as infile, \
                open(output_path, 'w', encoding='utf-8') as outfile:

            chunk_texts = []

            for line in infile:
                stripped = line.strip()
                if not stripped:
                    continue

                if is_jsonl:
                    data = json.loads(stripped)
                    text = data.get("text", data.get("input", ""))
                else:
                    text = stripped

                if not text:
                    continue

                chunk_texts.append(text)

                if len(chunk_texts) >= chunk_size:
                    # Process and write the full chunk, then start a new one
                    processed += self._write_predictions(chunk_texts, outfile)
                    logger.info(f"Processed {processed}/{total_lines} texts")
                    chunk_texts = []

            # Process remaining texts
            if chunk_texts:
                processed += self._write_predictions(chunk_texts, outfile)

        logger.info(f"Completed processing {processed} texts")
|
||
|
|
|
||
|
|
|
||
|
|
def create_inference_config(
    model_path: str,
    device: str = "auto",
    batch_size: int = 32,
    **kwargs
) -> InferenceConfig:
    """Build an ``InferenceConfig`` from the common settings.

    Additional keyword arguments are forwarded unchanged to the
    ``InferenceConfig`` constructor.
    """
    settings = {
        "model_path": model_path,
        "device": device,
        "batch_size": batch_size,
    }
    # kwargs can never repeat the three named parameters, so this only adds keys.
    settings.update(kwargs)
    return InferenceConfig(**settings)
|
||
|
|
|
||
|
|
|
||
|
|
def _positive_int(value: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script."""
    parser = argparse.ArgumentParser(description="Model Inference Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Model settings
    parser.add_argument("--model-path", type=str, help="Path to saved model directory")
    parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
    parser.add_argument("--batch-size", type=_positive_int, help="Batch size for inference")
    parser.add_argument("--max-length", type=_positive_int, help="Maximum sequence length for tokenization")

    # Inference settings. Both flags share one destination whose default is
    # None, so "not given" can be told apart from an explicit choice.
    parser.add_argument("--return-probabilities", dest="return_probabilities",
                        action="store_true", default=None,
                        help="Return all class probabilities")
    parser.add_argument("--no-return-probabilities", dest="return_probabilities",
                        action="store_false", default=None,
                        help="Do not return all class probabilities")
    parser.add_argument("--return-top-k", type=_positive_int, help="Return top K predictions")

    # Input/Output settings
    parser.add_argument("--input-text", type=str, help="Single text for prediction")
    parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
    parser.add_argument("--output-file", type=str, help="Output file path for results")
    parser.add_argument("--chunk-size", type=_positive_int, help="Chunk size for large file processing")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    return parser


def _load_yaml_config(path: str) -> dict:
    """Read a YAML file and return its top-level mapping.

    An empty file yields an empty dict (``yaml.safe_load`` returns ``None``
    for it); any other non-mapping document raises ``ValueError``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"expected a mapping at the top level, got {type(loaded).__name__}"
        )
    return loaded


def _collect_cli_overrides(args: argparse.Namespace) -> dict:
    """Return the config values that were explicitly given on the command line."""
    keys = ("model_path", "device", "batch_size", "max_length",
            "return_probabilities", "return_top_k")
    return {
        key: getattr(args, key)
        for key in keys
        if getattr(args, key) is not None
    }


def _print_single_prediction(result: dict, config) -> None:
    """Print the outcome of a single-text prediction."""
    print(f"Predicted label: {result['predicted_label']}")
    print(f"Confidence: {result['confidence']:.4f}")
    print(f"Top {config.return_top_k} predictions:")
    for pred in result['predictions']:
        print(f" - {pred['label']}: {pred['score']:.4f}")

    if config.return_probabilities and 'all_probabilities' in result:
        print("\nAll class probabilities:")
        for label, prob in result['all_probabilities'].items():
            print(f" - {label}: {prob:.4f}")


def main():
    """Command-line entry point with YAML configuration support.

    Settings come from an optional YAML file and are overridden by any
    command-line options that were given. Depending on the arguments the
    script predicts a single text, predicts every text of a file (switching
    to chunked processing for inputs above 10 MB when an output file is
    given), or runs a few built-in example texts.

    Exits with status 1 when the YAML file cannot be loaded, the model path
    or input file does not exist, or inference fails; argument errors exit
    through ``argparse`` with status 2.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            config_dict = _load_yaml_config(args.config)
            logger.info(f"Loaded YAML configuration from: {args.config}")
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Override YAML config with CLI arguments
    for key, value in _collect_cli_overrides(args).items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate model path
    model_path = config_dict.get('model_path')
    if not model_path:
        parser.error("--model-path is required (either in YAML config or CLI)")

    model_path = Path(model_path)
    if not model_path.exists():
        print(f"❌ Model path not found: {model_path}")
        print("Please ensure the model directory exists and contains the trained model files")
        sys.exit(1)

    # Create configuration object
    config = InferenceConfig(
        model_path=str(model_path),
        device=config_dict.get('device', 'auto'),
        batch_size=config_dict.get('batch_size', 32),
        max_length=config_dict.get('max_length', 512),
        return_probabilities=config_dict.get('return_probabilities', True),
        return_top_k=config_dict.get('return_top_k', 1)
    )

    # Check the input file before the slow model load so a typo fails fast.
    # --input-text takes precedence over --input-file.
    input_path = None
    if args.input_file and not args.input_text:
        input_path = Path(args.input_file)
        if not input_path.exists():
            print(f"❌ Input file not found: {input_path}")
            sys.exit(1)

    large_file_bytes = 10 * 1024 * 1024  # 10 MB

    try:
        print(f"Loading model from: {config.model_path}")
        print(f"Device: {config.device}")
        print(f"Batch size: {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        # Chunked processing is only used when results are written to a file
        use_chunked = (
            input_path is not None
            and bool(args.output_file)
            and input_path.stat().st_size > large_file_bytes
        )

        # Initialize inference. BatchInference builds its own ModelInference,
        # so only one of the two is created to avoid loading the model twice.
        batch_inference = None
        inference = None
        if use_chunked:
            batch_inference = BatchInference(config)
        else:
            inference = ModelInference(config)

        # Handle different input types
        if args.input_text:
            # Single text prediction
            print("=== Single Text Prediction ===")
            print(f"Input text: {args.input_text}")
            print()

            result = inference.predict_single(args.input_text)
            _print_single_prediction(result, config)

        elif args.input_file:
            # File prediction
            print("=== File Prediction ===")
            print(f"Input file: {args.input_file}")
            print(f"Output file: {args.output_file}")
            print()

            if use_chunked:
                print("Large file detected, using chunked processing...")
                batch_inference.predict_large_file(
                    args.input_file,
                    args.output_file,
                    chunk_size=args.chunk_size or 1000
                )
            elif args.output_file:
                # Regular file processing
                results = inference.predict_file(args.input_file, args.output_file)
                print(f"Processed {len(results)} texts")
            else:
                # Just predict without saving
                results = inference.predict_file(args.input_file)
                print(f"Processed {len(results)} texts")
                print("\nSample results:")
                for i, result in enumerate(results[:3]):  # Show first 3
                    print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")

        else:
            # Interactive mode - example predictions
            print("=== Interactive Mode ===")
            print("No input specified. Running example predictions...")
            print()

            # Example texts
            example_texts = [
                "I love this product! It's amazing.",
                "This is terrible, I hate it.",
                "The weather is okay today.",
                "Best purchase ever made!"
            ]

            print("Example predictions:")
            for text in example_texts:
                result = inference.predict_single(text)
                print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")

        print("\n✅ Inference completed successfully!")

    except Exception as e:
        print(f"❌ Inference failed: {e}")
        # Keep the traceback available for --log-level DEBUG runs
        logger.debug("Inference failure details", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||
|
|
|
||
|
|
"""
Quick usage examples:

# 1. Single prediction
config = InferenceConfig(model_path="./results/my_model")
inference = ModelInference(config)
result = inference.predict_single("Your text here")

# 2. Batch prediction
results = inference.predict_batch(["text1", "text2", "text3"])

# 3. File prediction
inference.predict_file("input.txt", "predictions.jsonl")

# 4. Large file processing
batch_inference = BatchInference(config)
batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
"""
|