update project structure and improve scripts
This commit is contained in:
+116
-35
@@ -1,55 +1,136 @@
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
from exceptions import ModelLoadError, DetectionError, ImageProcessingError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryDetector:
|
||||
def __init__(self, model_path):
|
||||
def __init__(self, model_path, confidence_threshold=0.3, image_size=416):
|
||||
self.model_path = model_path
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self.image_size = image_size
|
||||
self.model = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load YOLO model with error handling."""
|
||||
try:
|
||||
self.model = YOLO(model_path)
|
||||
logger.info(f"Loaded model from {model_path}")
|
||||
if not Path(self.model_path).exists():
|
||||
raise FileNotFoundError(f"Model file not found: {self.model_path}")
|
||||
|
||||
logger.info(f"Loading model from {self.model_path}")
|
||||
self.model = YOLO(self.model_path)
|
||||
logger.info("Model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Model loading failed: {str(e)}")
|
||||
raise
|
||||
|
||||
def detect(self, image_path):
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise ModelLoadError(f"Model loading failed: {str(e)}")
|
||||
|
||||
def detect_from_bytes(self, image_bytes):
|
||||
"""Detect memory modules from image bytes."""
|
||||
try:
|
||||
# Decode image
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
raise ImageProcessingError("Could not decode image")
|
||||
|
||||
return self._perform_detection(image)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Detection from bytes failed: {e}")
|
||||
raise DetectionError(f"Detection failed: {str(e)}")
|
||||
|
||||
def detect_from_file(self, file_path):
|
||||
"""Detect memory modules from image file."""
|
||||
try:
|
||||
if not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"Image file not found: {file_path}")
|
||||
|
||||
image = cv2.imread(str(file_path))
|
||||
if image is None:
|
||||
raise ImageProcessingError("Could not load image file")
|
||||
|
||||
return self._perform_detection(image)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Detection from file failed: {e}")
|
||||
raise DetectionError(f"Detection failed: {str(e)}")
|
||||
|
||||
def _perform_detection(self, image):
|
||||
"""Perform detection on image."""
|
||||
try:
|
||||
logger.info("Running detection")
|
||||
|
||||
# Run inference
|
||||
results = self.model.predict(image_path, imgsz=416, conf=0.5)
|
||||
results = self.model.predict(
|
||||
image,
|
||||
imgsz=self.image_size,
|
||||
conf=self.confidence_threshold,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Extract results
|
||||
boxes = results[0].boxes.xyxy.cpu().numpy()
|
||||
confidences = results[0].boxes.conf.cpu().numpy()
|
||||
# Extract detections
|
||||
detections = self._extract_detections(results[0])
|
||||
|
||||
# Convert to list of [x1, y1, x2, y2, confidence]
|
||||
detections = []
|
||||
for box, conf in zip(boxes, confidences):
|
||||
detections.append({
|
||||
'box': [int(x) for x in box],
|
||||
'confidence': float(conf)
|
||||
})
|
||||
# Create annotated image
|
||||
annotated_image = self._draw_boxes(image, detections)
|
||||
|
||||
# Annotate image
|
||||
annotated_img = self._draw_boxes(image_path, detections)
|
||||
logger.info(f"Detection completed: {len(detections)} objects found")
|
||||
|
||||
return {
|
||||
'detections': detections,
|
||||
'annotated_image': annotated_img
|
||||
'annotated_image': annotated_image,
|
||||
'detection_count': len(detections)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Detection failed: {str(e)}")
|
||||
raise
|
||||
|
||||
def _draw_boxes(self, image_path, detections):
|
||||
img = cv2.imread(str(image_path))
|
||||
for det in detections:
|
||||
x1, y1, x2, y2 = det['box']
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(img, f"{det['confidence']:.2f}",
|
||||
(x1, y1-10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 1)
|
||||
return img
|
||||
logger.error(f"Detection processing failed: {e}")
|
||||
raise DetectionError(f"Detection processing failed: {str(e)}")
|
||||
|
||||
def _extract_detections(self, result):
|
||||
"""Extract detection results."""
|
||||
detections = []
|
||||
|
||||
if result.boxes is not None:
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
confidences = result.boxes.conf.cpu().numpy()
|
||||
classes = result.boxes.cls.cpu().numpy() if result.boxes.cls is not None else None
|
||||
|
||||
for i, (box, conf) in enumerate(zip(boxes, confidences)):
|
||||
detection = {
|
||||
'box': [float(coord) for coord in box], # [x1, y1, x2, y2]
|
||||
'confidence': float(conf),
|
||||
'class': int(classes[i]) if classes is not None else 0
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return detections
|
||||
|
||||
def _draw_boxes(self, image, detections):
|
||||
"""Draw bounding boxes on image."""
|
||||
annotated = image.copy()
|
||||
|
||||
for detection in detections:
|
||||
box = detection['box']
|
||||
confidence = detection['confidence']
|
||||
|
||||
# Extract coordinates
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# Draw bounding box
|
||||
cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Draw confidence score
|
||||
label = f"Memory: {confidence:.2f}"
|
||||
cv2.putText(
|
||||
annotated, label, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1
|
||||
)
|
||||
|
||||
return annotated
|
||||
|
||||
Reference in New Issue
Block a user