update
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
from app.utils.detector import MemoryDetector
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Tuple
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestMemoryDetector:
|
||||
@pytest.fixture(scope="class")
|
||||
def results_dir(self):
|
||||
"""Create and return results directory"""
|
||||
dir_path = Path("test_results")
|
||||
dir_path.mkdir(exist_ok=True)
|
||||
logger.info(f"Created results directory: {dir_path}")
|
||||
return dir_path
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def detector(self):
|
||||
"""Initialize detector once for all tests"""
|
||||
logger.info("Initializing MemoryDetector...")
|
||||
return MemoryDetector()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def test_images(self):
|
||||
"""Load test images from validation directory"""
|
||||
val_dir = Path('training/val/images')
|
||||
assert val_dir.exists(), f"Validation directory not found: {val_dir}"
|
||||
|
||||
logger.info(f"Loading test images from {val_dir}")
|
||||
images = []
|
||||
for img_path in val_dir.glob('memory_*.png'):
|
||||
images.append({
|
||||
'path': str(img_path),
|
||||
'image': Image.open(img_path)
|
||||
})
|
||||
logger.info(f"Loaded {len(images)} test images")
|
||||
assert len(images) > 0, "No test images found"
|
||||
return images
|
||||
|
||||
def test_detector_initialization(self, detector):
|
||||
"""Test detector initialization and default parameters"""
|
||||
logger.info("Testing detector initialization...")
|
||||
assert detector.conf_threshold == 0.25
|
||||
assert detector.iou_threshold == 0.45
|
||||
assert detector.model is not None
|
||||
logger.info("Detector initialization test passed")
|
||||
|
||||
def test_single_image_detection(self, detector, test_images, results_dir):
|
||||
"""Test detection on a single image"""
|
||||
logger.info("Testing single image detection...")
|
||||
test_case = test_images[0]
|
||||
result_img, detections = detector.detect(test_case['image'])
|
||||
|
||||
# Save the result
|
||||
output_path = results_dir / "single_detection_test.png"
|
||||
result_img.save(output_path)
|
||||
logger.info(f"Saved detection result to {output_path}")
|
||||
|
||||
# Verify result type and content
|
||||
assert isinstance(result_img, Image.Image)
|
||||
assert isinstance(detections, list)
|
||||
assert all(isinstance(d, dict) for d in detections)
|
||||
|
||||
# Log detection results
|
||||
logger.info(f"Number of detections: {len(detections)}")
|
||||
if len(detections) > 0:
|
||||
for i, det in enumerate(detections):
|
||||
logger.info(f"Detection {i+1}: confidence={det['confidence']:.3f}")
|
||||
|
||||
def test_batch_detection(self, detector, test_images, results_dir):
|
||||
"""Test detection on multiple images"""
|
||||
logger.info("Testing batch detection...")
|
||||
results = []
|
||||
for i, test_case in enumerate(test_images):
|
||||
logger.info(f"Processing image {i+1}/{len(test_images)}")
|
||||
result_img, detections = detector.detect(test_case['image'])
|
||||
|
||||
# Save each result
|
||||
output_path = results_dir / f"batch_detection_{i}.png"
|
||||
result_img.save(output_path)
|
||||
|
||||
results.append({
|
||||
'path': test_case['path'],
|
||||
'detections': len(detections),
|
||||
'confidences': [d['confidence'] for d in detections]
|
||||
})
|
||||
|
||||
# Save detailed results
|
||||
results_path = results_dir / "batch_results.json"
|
||||
with open(results_path, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
logger.info(f"Saved batch results to {results_path}")
|
||||
|
||||
# Log statistics
|
||||
total_detections = sum(r['detections'] for r in results)
|
||||
avg_confidence = np.mean([conf for r in results for conf in r['confidences']]) if total_detections > 0 else 0
|
||||
|
||||
logger.info("\nBatch Detection Statistics:")
|
||||
logger.info(f"Total images processed: {len(results)}")
|
||||
logger.info(f"Total detections: {total_detections}")
|
||||
logger.info(f"Average confidence: {avg_confidence:.3f}")
|
||||
|
||||
assert total_detections > 0, "No detections found in any test image"
|
||||
|
||||
def test_threshold_optimization(self, detector, test_images):
|
||||
"""Test threshold optimization functionality"""
|
||||
images = [tc['image'] for tc in test_images]
|
||||
best_conf, best_iou = detector.optimize_thresholds(images)
|
||||
|
||||
# Verify threshold bounds
|
||||
assert 0 <= best_conf <= 1, f"Invalid confidence threshold: {best_conf}"
|
||||
assert 0 <= best_iou <= 1, f"Invalid IoU threshold: {best_iou}"
|
||||
|
||||
# Test detection with optimized thresholds
|
||||
test_case = test_images[0]
|
||||
result_img, detections = detector.detect(
|
||||
test_case['image'],
|
||||
conf_threshold=best_conf,
|
||||
iou_threshold=best_iou
|
||||
)
|
||||
|
||||
print(f"\nOptimized Thresholds:")
|
||||
print(f"Confidence: {best_conf:.3f}")
|
||||
print(f"IoU: {best_iou:.3f}")
|
||||
|
||||
@pytest.mark.parametrize("conf_threshold,iou_threshold", [
|
||||
(0.1, 0.1),
|
||||
(0.5, 0.5),
|
||||
(0.9, 0.9)
|
||||
])
|
||||
def test_different_thresholds(self, detector, test_images, conf_threshold, iou_threshold):
|
||||
"""Test detection with different threshold combinations"""
|
||||
test_case = test_images[0]
|
||||
result_img, detections = detector.detect(
|
||||
test_case['image'],
|
||||
conf_threshold=conf_threshold,
|
||||
iou_threshold=iou_threshold
|
||||
)
|
||||
|
||||
print(f"\nThreshold Test (conf={conf_threshold}, iou={iou_threshold}):")
|
||||
print(f"Detections found: {len(detections)}")
|
||||
|
||||
def test_visualization(self, detector, test_images, results_dir):
|
||||
"""Test detection visualization and save results"""
|
||||
logger.info("Testing visualization...")
|
||||
|
||||
# Process and visualize a batch of images
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.ravel()
|
||||
|
||||
for idx, test_case in enumerate(test_images[:4]):
|
||||
logger.info(f"Processing image {idx+1}/4 for visualization")
|
||||
result_img, detections = detector.detect(test_case['image'])
|
||||
|
||||
# Save individual result
|
||||
result_path = results_dir / f"visualization_{idx}.png"
|
||||
result_img.save(result_path)
|
||||
logger.info(f"Saved individual result to {result_path}")
|
||||
|
||||
# Plot result
|
||||
axes[idx].imshow(result_img)
|
||||
axes[idx].set_title(f"Detections: {len(detections)}")
|
||||
axes[idx].axis('off')
|
||||
|
||||
# Save summary plot
|
||||
summary_path = results_dir / "summary.png"
|
||||
plt.tight_layout()
|
||||
plt.savefig(summary_path)
|
||||
plt.close()
|
||||
logger.info(f"Saved summary visualization to {summary_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run with output capture disabled
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user