#!/usr/bin/env python3
"""
Inference utilities for memory module detection.

Contains functions for model loading, inference, and visualization.
"""

import os

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

# Default location of the trained weights; used by the class, the helper
# function and the command-line entry point.
DEFAULT_MODEL_PATH = 'runs/detect/memory_module_detection/weights/best.pt'


class MemoryModuleDetector:
    """Memory module detector using YOLOv8.

    Attributes:
        model_path (str): Path the model is (or would be) loaded from.
        model: The loaded ``YOLO`` object, or ``None`` when the weights file
            is missing or loading failed.
        class_names (list): Class labels indexed by class id.
        colors (list): RGB colors used when drawing boxes.
    """

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """
        Initialize the detector.

        The model is loaded immediately when ``model_path`` exists; otherwise
        a warning is printed and ``self.model`` stays ``None``.

        Args:
            model_path (str): Path to the trained YOLOv8 model
        """
        self.model_path = model_path
        self.model = None
        self.class_names = ['memory_module']
        self.colors = [(0, 255, 0)]  # Green for memory modules

        # Load model if it exists
        if os.path.exists(model_path):
            self.load_model()
        else:
            print(f"Warning: Model not found at {model_path}")
            print("Please train the model first using train.py")

    def load_model(self):
        """Load the trained YOLOv8 model.

        First tries to load with the detection model class allow-listed for
        PyTorch's ``weights_only`` unpickler (PyTorch 2.6+ default).  If that
        fails for any reason, falls back to loading with
        ``weights_only=False``.  On total failure an error is printed and
        ``self.model`` is set to ``None``; no exception is raised.
        """
        try:
            self.model = self._load_with_allowlist()
        except Exception:
            # Best effort by design: older torch versions have no
            # ``safe_globals`` and checkpoints may need more globals than the
            # one allow-listed here, so any failure goes to the fallback.
            self._load_with_full_unpickling()
        else:
            print(f"Model loaded successfully from {self.model_path}")

    def _load_with_allowlist(self):
        """Load the model with ``DetectionModel`` allow-listed for unpickling."""
        # ``safe_globals`` expects the class objects themselves, not dotted
        # path strings, so the class is imported here.
        from ultralytics.nn.tasks import DetectionModel

        with torch.serialization.safe_globals([DetectionModel]):
            return YOLO(self.model_path)

    def _load_with_full_unpickling(self):
        """Fallback loader that forces ``weights_only=False`` in ``torch.load``.

        SECURITY NOTE: ``weights_only=False`` performs full unpickling, which
        can execute arbitrary code embedded in the checkpoint.  Only load
        weight files from trusted sources.
        """
        original_load = torch.load

        def permissive_load(*args, **kwargs):
            # Override rather than append the keyword so a caller that
            # already passes ``weights_only`` does not trigger a TypeError.
            kwargs['weights_only'] = False
            return original_load(*args, **kwargs)

        torch.load = permissive_load
        try:
            self.model = YOLO(self.model_path)
        except Exception as error:
            print(f"Error loading model: {error}")
            self.model = None
        else:
            print(f"Model loaded successfully from {self.model_path} (fallback method)")
        finally:
            # Always undo the monkey-patch, even when loading failed.
            torch.load = original_load

    def _require_model(self):
        """Raise ``ValueError`` when no model is loaded."""
        if self.model is None:
            raise ValueError("Model not loaded. Please check model path.")

    def _detections_from_results(self, results):
        """Convert YOLO results for a single image into detection dicts.

        Args:
            results: Sequence returned by calling the model; only the first
                entry is used.

        Returns:
            list: One dict per box with keys ``bbox`` ([x1, y1, x2, y2]),
            ``confidence``, ``class`` and ``class_name``.
        """
        if not results or results[0].boxes is None:
            return []

        boxes = results[0].boxes
        # Move each tensor to host memory once instead of once per box.
        coords = boxes.xyxy.cpu().numpy()  # x1, y1, x2, y2
        scores = boxes.conf.cpu().numpy()
        class_ids = boxes.cls.cpu().numpy()

        detections = []
        for box, score, class_id in zip(coords, scores, class_ids):
            cls = int(class_id)
            is_known = 0 <= cls < len(self.class_names)
            detections.append({
                'bbox': box.tolist(),
                'confidence': float(score),
                'class': cls,
                'class_name': self.class_names[cls] if is_known else 'unknown',
            })
        return detections

    @staticmethod
    def _load_font():
        """Return a 16pt Arial font, or Pillow's default font if unavailable."""
        try:
            return ImageFont.truetype("arial.ttf", 16)
        except (OSError, ImportError):
            # OSError: font file not found/unreadable.
            # ImportError: Pillow built without FreeType support.
            return ImageFont.load_default()

    def _annotate(self, image, detections):
        """Draw boxes and labels for ``detections`` onto ``image`` in place.

        Args:
            image (PIL.Image): Image to draw on (modified in place)
            detections (list): List of detection dictionaries

        Returns:
            PIL.Image: The same ``image`` object, annotated
        """
        draw = ImageDraw.Draw(image)
        font = self._load_font()
        color = self.colors[0]  # Green for memory modules

        for detection in detections:
            x1, y1, x2, y2 = detection['bbox']

            # Draw bounding box
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

            # Measure the label so a filled background can be drawn behind it
            label = f"{detection['class_name']}: {detection['confidence']:.2f}"
            text_box = draw.textbbox((0, 0), label, font=font)
            text_width = text_box[2] - text_box[0]
            text_height = text_box[3] - text_box[1]

            # The label normally sits directly above the box.  When the box
            # touches the top of the image that position would be off-canvas,
            # so the label is moved just inside the box instead.
            label_top = y1 - text_height - 4
            if label_top < 0:
                label_top = y1

            draw.rectangle(
                [x1, label_top, x1 + text_width + 4, label_top + text_height + 4],
                fill=color,
                outline=color,
            )
            draw.text((x1 + 2, label_top + 2), label, fill=(255, 255, 255), font=font)

        return image

    @staticmethod
    def _to_pil(image_array):
        """Convert a numpy array to a PIL image; pass other inputs through.

        ``uint8`` arrays are used as-is.  Integer arrays of other widths are
        assumed to already hold 0-255 pixel values and are clipped to that
        range.  All remaining dtypes (floats, bools) are assumed to be in the
        0-1 range and are scaled by 255 and clipped, so out-of-range values
        saturate instead of wrapping around.
        """
        if not isinstance(image_array, np.ndarray):
            return image_array

        if image_array.dtype != np.uint8:
            if np.issubdtype(image_array.dtype, np.integer):
                pixels = np.clip(image_array, 0, 255)
            else:
                pixels = np.clip(image_array * 255, 0, 255)
            image_array = pixels.astype(np.uint8)
        return Image.fromarray(image_array)

    def detect(self, image_path, conf_threshold=0.5, iou_threshold=0.45):
        """
        Detect memory modules in an image.

        Args:
            image_path (str): Path to the input image
            conf_threshold (float): Confidence threshold for detections
            iou_threshold (float): IoU threshold for NMS

        Returns:
            tuple: (detections, annotated_image)

        Raises:
            ValueError: If no model is loaded
        """
        self._require_model()

        # Run inference
        results = self.model(image_path, conf=conf_threshold, iou=iou_threshold)
        detections = self._detections_from_results(results)

        # Create annotated image
        annotated_image = self.draw_detections(image_path, detections)
        return detections, annotated_image

    def draw_detections(self, image_path, detections):
        """
        Draw bounding boxes on the image.

        Args:
            image_path (str): Path to the input image
            detections (list): List of detection dictionaries

        Returns:
            PIL.Image: Annotated image
        """
        # ``convert`` returns an independent in-memory image, so the file
        # handle can be closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        return self._annotate(image, detections)

    def detect_from_array(self, image_array, conf_threshold=0.5, iou_threshold=0.45):
        """
        Detect memory modules from a numpy array.

        Args:
            image_array (np.ndarray): Input image as numpy array.  Anything
                that is not a numpy array is passed to the model unchanged
                and is expected to be a PIL image.
            conf_threshold (float): Confidence threshold for detections
            iou_threshold (float): IoU threshold for NMS

        Returns:
            tuple: (detections, annotated_image)

        Raises:
            ValueError: If no model is loaded
        """
        self._require_model()

        # Convert numpy array to PIL Image if needed
        image = self._to_pil(image_array)

        # Run inference
        results = self.model(image, conf=conf_threshold, iou=iou_threshold)
        detections = self._detections_from_results(results)

        # Create annotated image
        annotated_image = self.draw_detections_on_image(image, detections)
        return detections, annotated_image

    def draw_detections_on_image(self, image, detections):
        """
        Draw bounding boxes on a PIL Image.

        Args:
            image (PIL.Image): Input image
            detections (list): List of detection dictionaries

        Returns:
            PIL.Image: Annotated image
        """
        # Work on a copy to avoid modifying the caller's image
        return self._annotate(image.copy(), detections)


def test_inference(image_path, model_path=DEFAULT_MODEL_PATH, conf_threshold=0.5):
    """
    Test inference on a single image.

    Prints the detections and saves the annotated image to the current
    working directory as ``annotated_<basename of image_path>``.

    Args:
        image_path (str): Path to test image
        model_path (str): Path to trained model
        conf_threshold (float): Confidence threshold for detections

    Returns:
        tuple or None: (detections, annotated_image), or ``None`` when the
        model could not be loaded.
    """
    detector = MemoryModuleDetector(model_path)

    if detector.model is None:
        print("Cannot run inference without a trained model.")
        return None

    print(f"Running inference on: {image_path}")
    detections, annotated_image = detector.detect(
        image_path, conf_threshold=conf_threshold
    )

    print(f"Found {len(detections)} memory modules:")
    for index, detection in enumerate(detections, start=1):
        print(f" {index}. {detection['class_name']} (confidence: {detection['confidence']:.3f})")

    # Save annotated image
    output_path = f"annotated_{os.path.basename(image_path)}"
    annotated_image.save(output_path)
    print(f"Annotated image saved as: {output_path}")

    return detections, annotated_image


def main():
    """Command-line entry point: run detection on a single image."""
    import argparse

    parser = argparse.ArgumentParser(description='Test memory module detection')
    parser.add_argument('--image', type=str, required=True, help='Path to test image')
    parser.add_argument('--model', type=str,
                        default=DEFAULT_MODEL_PATH,
                        help='Path to trained model')
    parser.add_argument('--conf', type=float, default=0.5, help='Confidence threshold')

    args = parser.parse_args()

    # Forward --conf so the flag actually affects inference.
    test_inference(args.image, args.model, conf_threshold=args.conf)


if __name__ == "__main__":
    main()