"""Inference utilities for sequence-classification models.

Provides single-text, batch, and file-based prediction on top of a saved
``transformers`` model directory, plus a CLI entry point that merges an
optional YAML configuration with command-line overrides.
"""

import argparse
import json
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import yaml
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)


@dataclass
class InferenceConfig:
    """Simple configuration for inference."""

    model_path: str  # Path to saved model
    device: str = "auto"  # "auto", "cuda", "cpu"
    batch_size: int = 32
    max_length: int = 512
    return_probabilities: bool = True
    return_top_k: int = 1  # Return top K predictions


def _iter_texts(input_path: Path, is_jsonl: bool) -> Iterator[str]:
    """Yield the non-empty texts stored in ``input_path``.

    Args:
        input_path: File to read (UTF-8).
        is_jsonl: When true, each non-blank line is parsed as JSON and the
            text is taken from its ``"text"`` key, falling back to
            ``"input"``. Otherwise each stripped line is the text itself.

    Yields:
        One text per non-blank line; records without any text are skipped.

    Raises:
        json.JSONDecodeError: If a JSONL line is not valid JSON.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            if is_jsonl:
                record = json.loads(stripped)
                text = record.get("text", record.get("input", ""))
            else:
                text = stripped
            if text:
                yield text


class ModelInference:
    """Simple inference class for text classification."""

    def __init__(self, config: InferenceConfig):
        """Store ``config``, pick the device, and load model and tokenizer."""
        self.config = config
        self.model = None
        self.tokenizer = None
        self.label_info = {}
        self.device = self._setup_device()

        # Load model and tokenizer
        self.load_model()

    def _setup_device(self) -> torch.device:
        """Return the torch device selected by ``config.device``.

        ``"auto"`` picks CUDA when available and CPU otherwise; any other
        value is passed to ``torch.device`` unchanged.
        """
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
                logger.info("Using CUDA device")
            else:
                device = torch.device("cpu")
                logger.info("Using CPU device")
        else:
            device = torch.device(self.config.device)
            logger.info(f"Using specified device: {device}")

        return device

    def load_model(self):
        """Load model, tokenizer, and label mappings.

        Label mappings are read from ``label_info.json`` inside the model
        directory when present. Any of ``id_to_label``, ``label_to_id`` or
        ``num_labels`` that is missing is filled with numeric defaults
        (``LABEL_<i>``) derived from ``model.config.num_labels``.

        Raises:
            FileNotFoundError: If ``config.model_path`` does not exist.
        """
        model_path = Path(self.config.model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Model path not found: {model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info(f"Loaded tokenizer from {model_path}")

        # Load model
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Loaded model from {model_path}")

        # Load label information
        label_info_path = model_path / "label_info.json"
        if label_info_path.exists():
            # Explicit encoding: label names may be non-ASCII and the
            # platform default encoding is not guaranteed to be UTF-8.
            with open(label_info_path, 'r', encoding='utf-8') as f:
                self.label_info = json.load(f)
            logger.info(f"Loaded label mappings: {self.label_info.get('id_to_label', {})}")
        else:
            logger.warning("No label_info.json found. Using default numeric labels.")
            self.label_info = {}

        # Fill whatever is missing with default mappings so that later
        # lookups of label_info["id_to_label"] cannot raise KeyError.
        num_labels = self.model.config.num_labels
        defaults = {
            "id_to_label": {str(i): f"LABEL_{i}" for i in range(num_labels)},
            "label_to_id": {f"LABEL_{i}": i for i in range(num_labels)},
            "num_labels": num_labels
        }
        for key, value in defaults.items():
            self.label_info.setdefault(key, value)

    def predict_single(self, text: str) -> Dict:
        """Predict single text sample."""
        return self.predict_batch([text])[0]

    def predict_batch(self, texts: List[str]) -> List[Dict]:
        """Predict batch of texts.

        All ``texts`` are tokenized and run through the model in a single
        forward pass; callers are responsible for slicing large inputs.

        Args:
            texts: Texts to classify.

        Returns:
            One dict per input text with keys ``text``, ``predictions``
            (top-K list of ``label`` / ``label_id`` / ``score``),
            ``predicted_label``, ``confidence`` and, when
            ``config.return_probabilities`` is set, ``all_probabilities``.
            An empty input yields an empty list.
        """
        if not texts:
            return []

        # Tokenize inputs
        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.config.max_length,
            return_tensors="pt"
        )

        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Convert to probabilities
            probabilities = torch.softmax(logits, dim=-1)

        # Single device-to-host copy for the whole batch. The cast to
        # float32 is needed because NumPy has no bfloat16 dtype.
        batch_probs = probabilities.float().cpu().numpy()

        id_to_label = self.label_info["id_to_label"]
        # Clamp K to [1, num_classes]: a K of 0 would otherwise slice
        # "[-0:]" and silently return every class, and a negative K would
        # drop classes from the wrong end.
        top_k = min(max(int(self.config.return_top_k), 1), batch_probs.shape[-1])

        # Process results
        results = []
        for text, text_probs in zip(texts, batch_probs):
            # Indices of the K highest probabilities, best first
            top_indices = np.argsort(text_probs)[::-1][:top_k]

            predictions = [
                {
                    "label": id_to_label.get(str(idx), f"LABEL_{idx}"),
                    "label_id": int(idx),
                    "score": float(text_probs[idx])
                }
                for idx in top_indices
            ]

            result = {
                "text": text,
                "predictions": predictions,
                "predicted_label": predictions[0]["label"],
                "confidence": predictions[0]["score"]
            }

            if self.config.return_probabilities:
                # Add all class probabilities
                result["all_probabilities"] = {
                    label_name: float(text_probs[int(label_id)])
                    for label_id, label_name in id_to_label.items()
                }

            results.append(result)

        return results

    def predict_file(self, input_file: str, output_file: Optional[str] = None) -> List[Dict]:
        """Predict on texts from file.

        Args:
            input_file: ``.txt`` file (one text per line) or ``.jsonl`` file
                (text under ``"text"`` or ``"input"``). Blank lines and
                records without text are skipped.
            output_file: Optional path; when given, results are written
                there as JSON Lines (parent directories are created).

        Returns:
            The prediction dicts for every text read, in file order.

        Raises:
            ValueError: If the file suffix is neither ``.txt`` nor ``.jsonl``.
        """
        input_path = Path(input_file)

        # Read texts from file
        suffix = input_path.suffix
        if suffix not in ('.txt', '.jsonl'):
            raise ValueError(f"Unsupported file format: {input_path.suffix}")
        texts = list(_iter_texts(input_path, is_jsonl=(suffix == '.jsonl')))

        logger.info(f"Loaded {len(texts)} texts from {input_file}")

        # Process in batches
        batch_size = self.config.batch_size
        all_results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            all_results.extend(self.predict_batch(batch_texts))

            # Report progress on every tenth batch
            if i % (batch_size * 10) == 0:
                logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")

        # Save results if output file specified
        if output_file:
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                for result in all_results:
                    f.write(json.dumps(result) + '\n')

            logger.info(f"Results saved to {output_file}")

        return all_results


class BatchInference:
    """Optimized batch inference for large datasets."""

    def __init__(self, config: InferenceConfig, inference: Optional[ModelInference] = None):
        """Create the wrapper.

        Args:
            config: Inference configuration.
            inference: Optional already-loaded ``ModelInference`` to reuse.
                When omitted a new one is built from ``config``, which
                loads the model again.
        """
        self.inference = inference if inference is not None else ModelInference(config)
        self.config = config

    def _write_predictions(self, texts: List[str], outfile) -> None:
        """Predict ``texts`` in ``config.batch_size`` slices and write JSONL rows."""
        # Feeding a whole chunk to the model at once would ignore
        # batch_size and can exhaust memory on large chunks.
        step = max(1, self.config.batch_size)
        for start in range(0, len(texts), step):
            for result in self.inference.predict_batch(texts[start:start + step]):
                outfile.write(json.dumps(result) + '\n')

    def predict_large_file(self, input_file: str, output_file: str, chunk_size: int = 1000) -> None:
        """Process large files in chunks to manage memory.

        Args:
            input_file: ``.jsonl`` file (text under ``"text"`` or
                ``"input"``); any other suffix is read as one text per line.
                Blank lines and records without text are skipped.
            output_file: Destination JSON Lines file (parent directories
                are created; an existing file is overwritten).
            chunk_size: Number of texts buffered before predictions are
                computed and written.
        """
        input_path = Path(input_file)
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Count non-blank lines first so progress totals are meaningful
        with open(input_path, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for line in f if line.strip())

        logger.info(f"Processing {total_lines} lines in chunks of {chunk_size}")

        is_jsonl = input_path.suffix == '.jsonl'
        processed = 0
        chunk_texts = []

        with open(output_path, 'w', encoding='utf-8') as outfile:
            for text in _iter_texts(input_path, is_jsonl=is_jsonl):
                chunk_texts.append(text)

                if len(chunk_texts) >= chunk_size:
                    # Process chunk and write results
                    self._write_predictions(chunk_texts, outfile)

                    processed += len(chunk_texts)
                    logger.info(f"Processed {processed}/{total_lines} texts")
                    chunk_texts = []

            # Process remaining texts
            if chunk_texts:
                self._write_predictions(chunk_texts, outfile)
                processed += len(chunk_texts)

        logger.info(f"Completed processing {processed} texts")


def create_inference_config(
    model_path: str,
    device: str = "auto",
    batch_size: int = 32,
    **kwargs
) -> InferenceConfig:
    """Helper function to create inference configuration.

    Extra keyword arguments are forwarded to ``InferenceConfig``.
    """
    return InferenceConfig(
        model_path=model_path,
        device=device,
        batch_size=batch_size,
        **kwargs
    )


def main():
    """Main function with YAML configuration support.

    Settings come from an optional YAML file and are overridden by any CLI
    options that were supplied. Exits with status 1 when the configuration,
    model path, or input file cannot be used, or when inference fails.
    """
    parser = argparse.ArgumentParser(description="Model Inference Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Model settings
    parser.add_argument("--model-path", type=str, help="Path to saved model directory")
    parser.add_argument("--device", choices=["auto", "cuda", "cpu"],
                        help="Device to run inference on")
    parser.add_argument("--batch-size", type=int, help="Batch size for inference")
    parser.add_argument("--max-length", type=int,
                        help="Maximum sequence length for tokenization")

    # Inference settings
    parser.add_argument("--return-probabilities", action="store_true",
                        help="Return all class probabilities")
    parser.add_argument("--return-top-k", type=int, help="Return top K predictions")

    # Input/Output settings
    parser.add_argument("--input-text", type=str, help="Single text for prediction")
    parser.add_argument("--input-file", type=str, help="Input file path (txt or jsonl)")
    parser.add_argument("--output-file", type=str, help="Output file path for results")
    parser.add_argument("--chunk-size", type=int,
                        help="Chunk size for large file processing")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load configuration
    config_dict = {}

    # Load YAML config if provided
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except (OSError, yaml.YAMLError) as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

        # safe_load returns None for an empty document; treat that as
        # "no settings" instead of crashing on the merge below.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading YAML config: expected a mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
            sys.exit(1)
        config_dict = loaded
        logger.info(f"Loaded YAML configuration from: {args.config}")

    # Override YAML config with CLI arguments
    cli_overrides = {}
    if args.model_path:
        cli_overrides['model_path'] = args.model_path
    if args.device:
        cli_overrides['device'] = args.device
    if args.batch_size:
        cli_overrides['batch_size'] = args.batch_size
    if args.max_length:
        cli_overrides['max_length'] = args.max_length
    if args.return_probabilities:
        cli_overrides['return_probabilities'] = True
    if args.return_top_k:
        cli_overrides['return_top_k'] = args.return_top_k

    # Merge configurations
    for key, value in cli_overrides.items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate model path
    model_path = config_dict.get('model_path')
    if not model_path:
        parser.error("--model-path is required (either in YAML config or CLI)")

    model_path = Path(model_path)
    if not model_path.exists():
        print(f"❌ Model path not found: {model_path}")
        print("Please ensure the model directory exists and contains the trained model files")
        sys.exit(1)

    # Create configuration object
    config = InferenceConfig(
        model_path=str(model_path),
        device=config_dict.get('device', 'auto'),
        batch_size=config_dict.get('batch_size', 32),
        max_length=config_dict.get('max_length', 512),
        return_probabilities=config_dict.get('return_probabilities', True),
        return_top_k=config_dict.get('return_top_k', 1)
    )

    try:
        print(f"Loading model from: {config.model_path}")
        print(f"Device: {config.device}")
        print(f"Batch size: {config.batch_size}")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        # Initialize inference
        inference = ModelInference(config)

        # Handle different input types
        if args.input_text:
            # Single text prediction
            print("=== Single Text Prediction ===")
            print(f"Input text: {args.input_text}")
            print()

            result = inference.predict_single(args.input_text)

            print(f"Predicted label: {result['predicted_label']}")
            print(f"Confidence: {result['confidence']:.4f}")
            print(f"Top {config.return_top_k} predictions:")

            for pred in result['predictions']:
                print(f" - {pred['label']}: {pred['score']:.4f}")

            if config.return_probabilities and 'all_probabilities' in result:
                print("\nAll class probabilities:")
                for label, prob in result['all_probabilities'].items():
                    print(f" - {label}: {prob:.4f}")

        elif args.input_file:
            # File prediction
            input_path = Path(args.input_file)
            if not input_path.exists():
                print(f"❌ Input file not found: {input_path}")
                sys.exit(1)

            print("=== File Prediction ===")
            print(f"Input file: {args.input_file}")
            print(f"Output file: {args.output_file}")
            print()

            if args.output_file:
                # Use batch inference for large files
                if input_path.stat().st_size > 10 * 1024 * 1024:  # > 10MB
                    print("Large file detected, using chunked processing...")
                    # Reuse the model that is already loaded instead of
                    # loading it a second time.
                    batch_inference = BatchInference(config, inference=inference)
                    chunk_size = args.chunk_size or 1000
                    batch_inference.predict_large_file(
                        args.input_file,
                        args.output_file,
                        chunk_size=chunk_size
                    )
                else:
                    # Regular file processing
                    results = inference.predict_file(args.input_file, args.output_file)
                    print(f"Processed {len(results)} texts")
            else:
                # Just predict without saving
                results = inference.predict_file(args.input_file)
                print(f"Processed {len(results)} texts")
                print("\nSample results:")
                for i, result in enumerate(results[:3]):  # Show first 3
                    print(f" {i+1}. '{result['text'][:50]}...' -> {result['predicted_label']} ({result['confidence']:.4f})")

        else:
            # Interactive mode - example predictions
            print("=== Interactive Mode ===")
            print("No input specified. Running example predictions...")
            print()

            # Example texts
            example_texts = [
                "I love this product! It's amazing.",
                "This is terrible, I hate it.",
                "The weather is okay today.",
                "Best purchase ever made!"
            ]

            print("Example predictions:")
            for text in example_texts:
                result = inference.predict_single(text)
                print(f" '{text}' -> {result['predicted_label']} ({result['confidence']:.4f})")

        print("\n✅ Inference completed successfully!")

    except Exception as e:
        # Top-level boundary: report briefly, keep the traceback available
        # under --log-level DEBUG.
        logger.debug("Inference failed", exc_info=True)
        print(f"❌ Inference failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()


"""
# Quick usage examples:

# 1. Single prediction
config = InferenceConfig(model_path="./results/my_model")
inference = ModelInference(config)
result = inference.predict_single("Your text here")
#
# 2. Batch prediction
# results = inference.predict_batch(["text1", "text2", "text3"])
#
# 3. File prediction
# inference.predict_file("input.txt", "predictions.jsonl")
#
# 4. Large file processing
# batch_inference = BatchInference(config)
# batch_inference.predict_large_file("large_input.jsonl", "large_output.jsonl", chunk_size=1000)
"""