update
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
from ultralytics import YOLO
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from typing import Tuple, List, Dict
|
||||
|
||||
class MemoryDetector:
|
||||
def __init__(self,
|
||||
model_path='model/weights/best.pt',
|
||||
conf_threshold=0.25,
|
||||
iou_threshold=0.45):
|
||||
"""
|
||||
Initialize the detector with the trained model.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to the trained model weights
|
||||
conf_threshold (float): Confidence threshold for detections
|
||||
iou_threshold (float): IoU threshold for NMS
|
||||
"""
|
||||
self.model = YOLO(model_path)
|
||||
self.conf_threshold = conf_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
|
||||
def detect(self,
|
||||
image: Image.Image,
|
||||
conf_threshold: float = None,
|
||||
iou_threshold: float = None) -> Tuple[Image.Image, List[Dict]]:
|
||||
"""
|
||||
Detect memory modules in the given image.
|
||||
|
||||
Args:
|
||||
image (PIL.Image): Input image to process
|
||||
conf_threshold (float, optional): Override default confidence threshold
|
||||
iou_threshold (float, optional): Override default IoU threshold
|
||||
|
||||
Returns:
|
||||
Tuple[PIL.Image, List[Dict]]: Annotated image and list of detections
|
||||
"""
|
||||
# Use provided thresholds or defaults
|
||||
conf = conf_threshold if conf_threshold is not None else self.conf_threshold
|
||||
iou = iou_threshold if iou_threshold is not None else self.iou_threshold
|
||||
|
||||
# Run inference
|
||||
results = self.model.predict(
|
||||
source=image,
|
||||
conf=conf,
|
||||
iou=iou,
|
||||
max_det=10,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Get the annotated image
|
||||
annotated_img = results[0].plot()
|
||||
|
||||
# Extract detection information
|
||||
detections = []
|
||||
for box in results[0].boxes:
|
||||
detection = {
|
||||
'xyxy': box.xyxy[0].tolist(), # Bounding box coordinates
|
||||
'confidence': float(box.conf[0]), # Detection confidence
|
||||
'class': int(box.cls[0]) # Class ID
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return Image.fromarray(annotated_img), detections
|
||||
|
||||
def optimize_thresholds(self, validation_images: List[Image.Image]) -> Tuple[float, float]:
|
||||
"""
|
||||
Find optimal confidence and IoU thresholds using validation images.
|
||||
|
||||
Args:
|
||||
validation_images (List[Image.Image]): List of validation images
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: Optimal confidence and IoU thresholds
|
||||
"""
|
||||
best_conf = 0.25
|
||||
best_iou = 0.45
|
||||
|
||||
# Grid search for best parameters
|
||||
conf_range = [0.15, 0.2, 0.25, 0.3, 0.35]
|
||||
iou_range = [0.35, 0.4, 0.45, 0.5, 0.55]
|
||||
|
||||
best_score = 0
|
||||
|
||||
for conf in conf_range:
|
||||
for iou in iou_range:
|
||||
total_score = 0
|
||||
for img in validation_images:
|
||||
_, detections = self.detect(img, conf, iou)
|
||||
# Score based on number of detections and confidence
|
||||
score = sum([d['confidence'] for d in detections])
|
||||
total_score += score
|
||||
|
||||
if total_score > best_score:
|
||||
best_score = total_score
|
||||
best_conf = conf
|
||||
best_iou = iou
|
||||
|
||||
return best_conf, best_iou
|
||||
Reference in New Issue
Block a user