initial setup
This commit is contained in:
@@ -0,0 +1,481 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from pathlib import Path
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import List, Dict, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
# Module-level logger shared by every class and function in this file;
# its level and format are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class InferenceConfig:
    """Settings that control how a saved classification model is loaded and queried."""

    # Directory containing the saved model, tokenizer and (optionally) label_info.json.
    model_path: str
    # "auto" selects CUDA when available and CPU otherwise; any other value
    # (e.g. "cuda", "cpu") is handed to torch.device unchanged.
    device: str = "auto"
    # Number of texts sent through the model per forward pass in file prediction.
    batch_size: int = 32
    # Inputs are truncated to this many tokens by the tokenizer.
    max_length: int = 512
    # When True each result also carries an "all_probabilities" mapping.
    return_probabilities: bool = True
    # Number of highest-scoring predictions reported per text.
    return_top_k: int = 1
|
||||
|
||||
|
||||
class ModelInference:
    """Run a saved sequence-classification model on raw text.

    Constructing an instance resolves the target device and immediately loads
    the tokenizer, the model and the label mapping from ``config.model_path``.
    """

    def __init__(self, config: "InferenceConfig"):
        """Store ``config``, pick the device and load the model.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.label_info = {}
        self.device = self._setup_device()

        # Load model and tokenizer
        self.load_model()

    def _setup_device(self) -> torch.device:
        """Resolve ``config.device`` to a ``torch.device``.

        ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` and CPU
        otherwise; any other value is passed to ``torch.device`` unchanged.
        """
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
                logger.info("Using CUDA device")
            else:
                device = torch.device("cpu")
                logger.info("Using CPU device")
        else:
            device = torch.device(self.config.device)
            logger.info(f"Using specified device: {device}")

        return device

    def load_model(self):
        """Load the tokenizer, the model and the label mapping.

        The model is moved to ``self.device`` and switched to eval mode.  When
        ``label_info.json`` is absent from the model directory, numeric
        ``LABEL_<i>`` names are generated from ``model.config.num_labels``.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
        """
        model_path = Path(self.config.model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Model path not found: {model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info(f"Loaded tokenizer from {model_path}")

        # Load model
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Loaded model from {model_path}")

        # Load label information
        label_info_path = model_path / "label_info.json"
        if label_info_path.exists():
            # Explicit encoding so non-ASCII label names decode identically
            # on every platform instead of depending on the locale default.
            with open(label_info_path, 'r', encoding='utf-8') as f:
                self.label_info = json.load(f)
            logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
        else:
            logger.warning("No label_info.json found. Using default numeric labels.")
            # Create default mappings
            num_labels = self.model.config.num_labels
            self.label_info = {
                "id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
                "label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
                "num_labels": num_labels,
            }

    def predict_single(self, text: str) -> Dict:
        """Classify one text and return its result dict (see ``predict_batch``)."""
        return self.predict_batch([text])[0]

    def predict_batch(self, texts: List[str]) -> List[Dict]:
        """Classify ``texts`` in a single forward pass.

        Args:
            texts: Texts to classify; an empty list yields an empty result.

        Returns:
            One dict per input text, in input order, with the keys ``text``,
            ``predictions`` (list of ``{"label", "label_id", "score"}`` sorted
            by descending score), ``predicted_label`` and ``confidence`` (the
            top prediction), plus ``all_probabilities`` (label name -> score)
            when ``config.return_probabilities`` is true.
        """
        if not texts:
            return []

        # Tokenize inputs
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.config.max_length,
            return_tensors="pt",
        )

        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Convert to probabilities; one device-to-host copy for the whole
        # batch instead of one per row.
        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()

        # .get() so a label_info.json without "id_to_label" degrades to the
        # generic LABEL_<i> names instead of raising KeyError.
        id_to_label = self.label_info.get("id_to_label", {})

        # Clamp K to [1, num_classes]: a value of 0 would otherwise slice
        # ``[-0:]`` and return every class, and a negative value would
        # silently drop classes.
        num_classes = probabilities.shape[-1]
        top_k = max(1, min(int(self.config.return_top_k), num_classes))

        # Process results
        results = []
        for text, text_probs in zip(texts, probabilities):
            # Ascending argsort -> take the last K -> reverse for descending.
            top_indices = np.argsort(text_probs)[-top_k:][::-1]

            predictions = [
                {
                    "label": id_to_label.get(str(int(idx)), f"LABEL_{int(idx)}"),
                    "label_id": int(idx),
                    "score": float(text_probs[idx]),
                }
                for idx in top_indices
            ]

            result = {
                "text": text,
                "predictions": predictions,
                "predicted_label": predictions[0]["label"],
                "confidence": predictions[0]["score"],
            }

            if self.config.return_probabilities:
                # Add all class probabilities
                result["all_probabilities"] = {
                    label_name: float(text_probs[int(label_id)])
                    for label_id, label_name in id_to_label.items()
                }

            results.append(result)

        return results

    @staticmethod
    def _read_texts(input_path: Path) -> List[str]:
        """Read the non-empty texts from a ``.txt`` or ``.jsonl`` file."""
        suffix = input_path.suffix.lower()

        if suffix == '.txt':
            # Plain text file (one text per line)
            with open(input_path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]

        if suffix == '.jsonl':
            # JSONL file with a "text" (or fallback "input") field
            texts = []
            with open(input_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    text = data.get("text", data.get("input", ""))
                    if text:
                        texts.append(text)
            return texts

        raise ValueError(f"Unsupported file format: {input_path.suffix}")

    def predict_file(self, input_file: str, output_file: Optional[str] = None) -> List[Dict]:
        """Classify every text in ``input_file``.

        Args:
            input_file: ``.txt`` file (one text per line) or ``.jsonl`` file
                whose records carry a ``"text"`` or ``"input"`` field.  Blank
                lines and records with empty text are skipped.
            output_file: Optional path; when given, results are written there
                as JSON Lines and parent directories are created as needed.

        Returns:
            The result dicts for all texts, in file order.

        Raises:
            ValueError: If the file extension is neither ``.txt`` nor ``.jsonl``.
        """
        texts = self._read_texts(Path(input_file))
        logger.info(f"Loaded {len(texts)} texts from {input_file}")

        # Process in batches
        all_results = []
        for i in range(0, len(texts), self.config.batch_size):
            batch_texts = texts[i:i + self.config.batch_size]
            all_results.extend(self.predict_batch(batch_texts))

            # Report progress every tenth batch.
            if i % (self.config.batch_size * 10) == 0:
                logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")

        # Save results if output file specified
        if output_file:
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                for result in all_results:
                    f.write(json.dumps(result) + '\n')

            logger.info(f"Results saved to {output_file}")

        return all_results
|
||||
|
||||
|
||||
class BatchInference:
    """Streaming inference for input files too large to hold in memory."""

    def __init__(self, config: InferenceConfig):
        """Load the model described by ``config`` via ``ModelInference``."""
        self.inference = ModelInference(config)
        self.config = config

    def _flush_chunk(self, texts: List[str], outfile) -> int:
        """Predict ``texts`` and append one JSON line per result to ``outfile``.

        The chunk is fed to the model in slices of ``config.batch_size`` so a
        large ``chunk_size`` never turns into a single oversized forward pass.

        Returns:
            The number of texts processed.
        """
        batch_size = max(1, self.config.batch_size)
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            for result in self.inference.predict_batch(batch):
                outfile.write(json.dumps(result) + '\n')
        return len(texts)

    def predict_large_file(self, input_file: str, output_file: str,
                           chunk_size: int = 1000) -> None:
        """Process a large file chunk by chunk, writing results as JSON Lines.

        Args:
            input_file: ``.jsonl`` file whose records carry a ``"text"`` or
                ``"input"`` field; any other extension is read as plain text,
                one text per line.  Blank lines and records with empty text
                are skipped.
            output_file: Destination for the JSON Lines results; parent
                directories are created as needed.
            chunk_size: Number of texts buffered before they are predicted
                and written out.
        """
        input_path = Path(input_file)
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        is_jsonl = input_path.suffix.lower() == '.jsonl'

        # Count the non-blank lines first so progress messages have a total
        # that matches what will actually be read.
        with open(input_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for line in f if line.strip())

        logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")

        processed = 0
        with open(input_path, 'r', encoding='utf-8') as infile, \
                open(output_path, 'w', encoding='utf-8') as outfile:

            chunk_texts = []

            for line in infile:
                if not line.strip():
                    continue

                if is_jsonl:
                    data = json.loads(line)
                    text = data.get("text", data.get("input", ""))
                else:
                    text = line.strip()

                # Skip records without text rather than scoring an empty string.
                if not text:
                    continue

                chunk_texts.append(text)

                if len(chunk_texts) >= chunk_size:
                    processed += self._flush_chunk(chunk_texts, outfile)
                    logger.info(f"Processed {processed}/{total_lines} texts")
                    chunk_texts = []

            # Process remaining texts
            if chunk_texts:
                processed += self._flush_chunk(chunk_texts, outfile)

        logger.info(f"Completed processing {processed} texts")
|
||||
|
||||
|
||||
def create_inference_config(
    model_path: str,
    device: str = "auto",
    batch_size: int = 32,
    **kwargs
) -> InferenceConfig:
    """Build an ``InferenceConfig`` from keyword arguments.

    Args:
        model_path: Path to the saved model directory.
        device: Device selector forwarded to the config.
        batch_size: Batch size forwarded to the config.
        **kwargs: Any further ``InferenceConfig`` fields (for example
            ``max_length`` or ``return_top_k``).

    Returns:
        The populated configuration object.

    Raises:
        TypeError: If ``kwargs`` contains a name that is not an
            ``InferenceConfig`` field.
    """
    return InferenceConfig(
        model_path=model_path,
        device=device,
        batch_size=batch_size,
        **kwargs
    )
|
||||
|
||||
|
||||
def _positive_int(value: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return number


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference pipeline."""
    parser = argparse.ArgumentParser(description="Model Inference Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Model settings
    parser.add_argument("--model-path", type=str, help="Path to saved model directory")
    parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device to run inference on")
    parser.add_argument("--batch-size", type=_positive_int, help="Batch size for inference")
    parser.add_argument("--max-length", type=_positive_int, help="Maximum sequence length for tokenization")

    # Inference settings.  Both flags share one destination that defaults to
    # None, so "not given" stays distinguishable from an explicit choice.
    parser.add_argument("--return-probabilities", dest="return_probabilities", action="store_true",
                        default=None, help="Return all class probabilities")
    parser.add_argument("--no-return-probabilities", dest="return_probabilities", action="store_false",
                        default=None, help="Do not return all class probabilities")
    parser.add_argument("--return-top-k", type=_positive_int, help="Return top K predictions")

    # Input/Output settings
    parser.add_argument("--input-text", type=str, help="Single text for prediction")
    parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
    parser.add_argument("--output-file", type=str, help="Output file path for results")
    parser.add_argument("--chunk-size", type=_positive_int, help="Chunk size for large file processing")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    return parser


def _load_yaml_config(path: str) -> Dict:
    """Read the YAML file at ``path`` into a dict; log and exit(1) on failure."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        logger.error(f"Error loading YAML config: {e}")
        sys.exit(1)

    # safe_load returns None for an empty document; treat that as "no settings".
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        logger.error(
            f"Error loading YAML config: expected a mapping at the top level, got {type(loaded).__name__}"
        )
        sys.exit(1)

    logger.info(f"Loaded YAML configuration from: {path}")
    return loaded


def _collect_cli_overrides(args: argparse.Namespace) -> Dict:
    """Return the config values that were explicitly supplied on the command line."""
    candidates = {
        'model_path': args.model_path or None,
        'device': args.device,
        'batch_size': args.batch_size,
        'max_length': args.max_length,
        'return_probabilities': args.return_probabilities,
        'return_top_k': args.return_top_k,
    }
    return {key: value for key, value in candidates.items() if value is not None}


def main():
    """Command-line entry point with YAML configuration support.

    Settings come from an optional YAML file (``--config``) and are overridden
    by any command-line options given.  Depending on the arguments the model
    is run on a single text (``--input-text``), on a file (``--input-file``,
    using chunked processing for files over 10 MB when ``--output-file`` is
    set), or on a few built-in example sentences.  Exits with status 1 on any
    configuration or inference error.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = _load_yaml_config(args.config) if args.config else {}

    # Override YAML config with CLI arguments
    for key, value in _collect_cli_overrides(args).items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate model path
    model_path = config_dict.get('model_path')
    if not model_path:
        parser.error("--model-path is required (either in YAML config or CLI)")

    model_path = Path(model_path)
    if not model_path.exists():
        print(f"❌ Model path not found: {model_path}")
        print("Please ensure the model directory exists and contains the trained model files")
        sys.exit(1)

    # Create configuration object
    config = InferenceConfig(
        model_path=str(model_path),
        device=config_dict.get('device', 'auto'),
        batch_size=config_dict.get('batch_size', 32),
        max_length=config_dict.get('max_length', 512),
        return_probabilities=config_dict.get('return_probabilities', True),
        return_top_k=config_dict.get('return_top_k', 1)
    )

    large_file_bytes = 10 * 1024 * 1024  # files above 10MB are processed in chunks

    try:
        print(f"Loading model from: {config.model_path}")
        print(f"Device: {config.device}")
        print(f"Batch size: {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        # Check the input file before the (slow) model load so a wrong path
        # fails immediately.  --input-text takes precedence over --input-file.
        input_path = None
        use_chunked = False
        if not args.input_text and args.input_file:
            input_path = Path(args.input_file)
            if not input_path.exists():
                print(f"❌ Input file not found: {input_path}")
                sys.exit(1)
            use_chunked = bool(args.output_file) and input_path.stat().st_size > large_file_bytes

        # Initialize inference.  BatchInference owns a ModelInference, so the
        # chunked path reuses it instead of loading the model a second time.
        if use_chunked:
            batch_inference = BatchInference(config)
            inference = batch_inference.inference
        else:
            batch_inference = None
            inference = ModelInference(config)

        # Handle different input types
        if args.input_text:
            # Single text prediction
            print("=== Single Text Prediction ===")
            print(f"Input text: {args.input_text}")
            print()

            result = inference.predict_single(args.input_text)
            print(f"Predicted label: {result['predicted_label']}")
            print(f"Confidence: {result['confidence']:.4f}")
            print(f"Top {config.return_top_k} predictions:")
            for pred in result['predictions']:
                print(f" - {pred['label']}: {pred['score']:.4f}")

            if config.return_probabilities and 'all_probabilities' in result:
                print("\nAll class probabilities:")
                for label, prob in result['all_probabilities'].items():
                    print(f" - {label}: {prob:.4f}")

        elif input_path is not None:
            # File prediction
            print("=== File Prediction ===")
            print(f"Input file: {args.input_file}")
            print(f"Output file: {args.output_file}")
            print()

            if batch_inference is not None:
                # Use batch inference for large files
                print("Large file detected, using chunked processing...")
                batch_inference.predict_large_file(
                    args.input_file,
                    args.output_file,
                    chunk_size=args.chunk_size or 1000
                )
            elif args.output_file:
                # Regular file processing
                results = inference.predict_file(args.input_file, args.output_file)
                print(f"Processed {len(results)} texts")
            else:
                # Just predict without saving
                results = inference.predict_file(args.input_file)
                print(f"Processed {len(results)} texts")
                print("\nSample results:")
                for i, result in enumerate(results[:3]):  # Show first 3
                    print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")

        else:
            # Interactive mode - example predictions
            print("=== Interactive Mode ===")
            print("No input specified. Running example predictions...")
            print()

            # Example texts
            example_texts = [
                "I love this product! It's amazing.",
                "This is terrible, I hate it.",
                "The weather is okay today.",
                "Best purchase ever made!"
            ]

            print("Example predictions:")
            for text in example_texts:
                result = inference.predict_single(text)
                print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")

        print("\n✅ Inference completed successfully!")

    except Exception as e:
        # Top-level boundary: keep the short user-facing message, but make the
        # traceback available when running with --log-level DEBUG.
        logger.debug("Inference failed", exc_info=True)
        print(f"❌ Inference failed: {e}")
        sys.exit(1)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
|
||||
"""
|
||||
Quick usage examples:

# 1. Single prediction
config = InferenceConfig(model_path="./results/my_model")
inference = ModelInference(config)
result = inference.predict_single("Your text here")

# 2. Batch prediction
results = inference.predict_batch(["text1", "text2", "text3"])

# 3. File prediction
inference.predict_file("input.txt", "predictions.jsonl")

# 4. Large file processing
batch_inference = BatchInference(config)
batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
|
||||
"""
|
||||
Reference in New Issue
Block a user