258 lines
8.7 KiB
Python
258 lines
8.7 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Test script for Memory Module Detection API
|
||
|
|
This script tests all API endpoints and provides usage examples.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import requests
|
||
|
|
import json
|
||
|
|
import base64
|
||
|
|
import os
|
||
|
|
from PIL import Image
|
||
|
|
import io
|
||
|
|
|
||
|
|
# API base URL. Defaults to the local development server; set the
# MEMORY_API_BASE_URL environment variable to point the script at another
# host. A trailing slash is stripped because every endpoint is appended as
# f"{BASE_URL}/path".
BASE_URL = os.environ.get("MEMORY_API_BASE_URL", "http://localhost:5001").rstrip("/")
|
||
|
|
|
||
|
|
def test_api_info():
    """Exercise the API info endpoint (``GET /``) and report the outcome.

    On success the ``message``, ``model_loaded`` and ``supported_formats``
    fields of the JSON reply are printed.

    Returns:
        bool: True when the endpoint answers with HTTP 200 and all three
        fields are present; False on any other status code or on any error
        (connection failure, timeout, invalid JSON, missing field).
    """
    print("🔍 Testing API Info...")
    try:
        # Bound the wait so an unresponsive server fails this check instead
        # of hanging the whole run.
        response = requests.get(f"{BASE_URL}/", timeout=5)
        if response.status_code != 200:
            print(f"❌ API Info failed: {response.status_code}")
            return False

        data = response.json()
        print(f"✅ API Info: {data['message']}")
        print(f" Model loaded: {data['model_loaded']}")
        print(f" Supported formats: {data['supported_formats']}")
        return True
    except Exception as e:  # deliberately broad: any failure fails the check
        print(f"❌ API Info error: {e}")
        return False
|
||
|
|
|
||
|
|
def test_health_check():
    """Exercise the health endpoint (``GET /health``) and report the outcome.

    On success the ``status`` and ``model_loaded`` fields of the JSON reply
    are printed.

    Returns:
        bool: True when the endpoint answers with HTTP 200 and both fields
        are present; False on any other status code or on any error
        (connection failure, timeout, invalid JSON, missing field).
    """
    print("\n🏥 Testing Health Check...")
    try:
        # Bound the wait so an unresponsive server fails this check instead
        # of hanging the whole run.
        response = requests.get(f"{BASE_URL}/health", timeout=5)
        if response.status_code != 200:
            print(f"❌ Health check failed: {response.status_code}")
            return False

        data = response.json()
        print(f"✅ Health: {data['status']}")
        print(f" Model loaded: {data['model_loaded']}")
        return True
    except Exception as e:  # deliberately broad: any failure fails the check
        print(f"❌ Health check error: {e}")
        return False
|
||
|
|
|
||
|
|
def test_hardcoded_detection():
    """Exercise ``GET /detect/hardcoded`` with ``confidence=0.5``.

    On success the number of detections and each detection's ``class_name``
    and ``confidence`` are printed. If the reply carries an
    ``annotated_image`` field, it is handed to ``save_base64_image`` to be
    written to ``test_hardcoded_result.png``.

    Returns:
        bool: True when the endpoint answers with HTTP 200 and a truthy
        ``success`` field; False on any other status code, on a falsy
        ``success``, or on any error (connection failure, timeout, invalid
        JSON, missing field).
    """
    print("\n🖼️ Testing Hardcoded Image Detection...")
    try:
        # Generous but finite timeout: the request must not be able to hang
        # the test run if the server stalls.
        response = requests.get(
            f"{BASE_URL}/detect/hardcoded?confidence=0.5",
            timeout=60,
        )
        if response.status_code != 200:
            print(f"❌ Hardcoded detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False

        data = response.json()
        if not data['success']:
            print(f"❌ Hardcoded detection failed: {data.get('error', 'Unknown error')}")
            return False

        print("✅ Hardcoded detection successful!")
        print(f" Found {data['num_detections']} memory modules")
        for number, detection in enumerate(data['detections'], start=1):
            print(f" Detection {number}: {detection['class_name']} "
                  f"(confidence: {detection['confidence']:.3f})")

        # Save annotated image
        if 'annotated_image' in data:
            save_base64_image(data['annotated_image'], 'test_hardcoded_result.png')
            print(" Annotated image saved as: test_hardcoded_result.png")

        return True
    except Exception as e:  # deliberately broad: any failure fails the check
        print(f"❌ Hardcoded detection error: {e}")
        return False
|
||
|
|
|
||
|
|
def test_file_upload():
    """Exercise ``POST /detect`` with a multipart image upload.

    The first existing file from a fixed list of candidate paths is sent as
    the ``image`` form file together with ``confidence='0.5'``. On success
    the detections are printed and, if the reply carries an
    ``annotated_image`` field, it is handed to ``save_base64_image`` to be
    written to ``test_upload_result.png``.

    Returns:
        bool: True when the endpoint answers with HTTP 200 and a truthy
        ``success`` field; False when no candidate image exists, on any
        other status code, on a falsy ``success``, or on any error.
    """
    print("\n📤 Testing File Upload Detection...")

    # Find a test image: first candidate that exists on disk
    possible_paths = [
        'training/memory/out1.png',
        'training/memory/out2.png',
        'training/val/images/memory_out8.png',
    ]
    test_image_path = next(
        (path for path in possible_paths if os.path.exists(path)),
        None,
    )
    if not test_image_path:
        print("❌ No test image found. Skipping file upload test.")
        return False

    try:
        with open(test_image_path, 'rb') as f:
            files = {'image': f}
            data = {'confidence': '0.5'}
            # Finite timeout so a stalled server cannot hang the test run.
            response = requests.post(
                f"{BASE_URL}/detect",
                files=files,
                data=data,
                timeout=60,
            )

        if response.status_code != 200:
            print(f"❌ File upload detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False

        result = response.json()
        if not result['success']:
            print(f"❌ File upload detection failed: {result.get('error', 'Unknown error')}")
            return False

        print("✅ File upload detection successful!")
        print(f" Test image: {test_image_path}")
        print(f" Found {result['num_detections']} memory modules")
        for number, detection in enumerate(result['detections'], start=1):
            print(f" Detection {number}: {detection['class_name']} "
                  f"(confidence: {detection['confidence']:.3f})")

        # Save annotated image
        if 'annotated_image' in result:
            save_base64_image(result['annotated_image'], 'test_upload_result.png')
            print(" Annotated image saved as: test_upload_result.png")

        return True
    except Exception as e:  # deliberately broad: any failure fails the check
        print(f"❌ File upload detection error: {e}")
        return False
|
||
|
|
|
||
|
|
def test_base64_detection():
    """Exercise ``POST /detect/base64`` with a base64-encoded image.

    The first existing file from a fixed list of candidate paths is read,
    base64-encoded and sent as JSON (``image`` plus ``confidence=0.5``). On
    success the detections are printed and, if the reply carries an
    ``annotated_image`` field, it is handed to ``save_base64_image`` to be
    written to ``test_base64_result.png``.

    Returns:
        bool: True when the endpoint answers with HTTP 200 and a truthy
        ``success`` field; False when no candidate image exists, on any
        other status code, on a falsy ``success``, or on any error.
    """
    print("\n🔢 Testing Base64 Detection...")

    # Find a test image: first candidate that exists on disk
    possible_paths = [
        'training/memory/out1.png',
        'training/memory/out2.png',
    ]
    test_image_path = next(
        (path for path in possible_paths if os.path.exists(path)),
        None,
    )
    if not test_image_path:
        print("❌ No test image found. Skipping base64 test.")
        return False

    try:
        # Convert image to base64
        with open(test_image_path, 'rb') as f:
            image_data = f.read()
        base64_string = base64.b64encode(image_data).decode('utf-8')

        # Send request
        payload = {
            'image': base64_string,
            'confidence': 0.5,
        }
        response = requests.post(
            f"{BASE_URL}/detect/base64",
            json=payload,
            headers={'Content-Type': 'application/json'},
            # Finite timeout so a stalled server cannot hang the test run.
            timeout=60,
        )

        if response.status_code != 200:
            print(f"❌ Base64 detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False

        result = response.json()
        if not result['success']:
            print(f"❌ Base64 detection failed: {result.get('error', 'Unknown error')}")
            return False

        print("✅ Base64 detection successful!")
        print(f" Test image: {test_image_path}")
        print(f" Found {result['num_detections']} memory modules")
        for number, detection in enumerate(result['detections'], start=1):
            print(f" Detection {number}: {detection['class_name']} "
                  f"(confidence: {detection['confidence']:.3f})")

        # Save annotated image
        if 'annotated_image' in result:
            save_base64_image(result['annotated_image'], 'test_base64_result.png')
            print(" Annotated image saved as: test_base64_result.png")

        return True
    except Exception as e:  # deliberately broad: any failure fails the check
        print(f"❌ Base64 detection error: {e}")
        return False
|
||
|
|
|
||
|
|
def save_base64_image(base64_string, filename):
    """Decode a base64-encoded image and save it to *filename*.

    This is best-effort: any failure (invalid base64, undecodable image
    data, write error) is reported as a printed warning and never raised.

    Args:
        base64_string: Base64 text of the encoded image bytes.
        filename: Destination path handed to ``Image.save``.

    Returns:
        bool: True if the image was saved, False if a warning was printed.
    """
    try:
        image_data = base64.b64decode(base64_string)
        # The context manager closes the image once it has been written.
        with Image.open(io.BytesIO(image_data)) as image:
            image.save(filename)
    except Exception as e:  # deliberately broad: saving is optional output
        print(f" Warning: Could not save image {filename}: {e}")
        return False
    return True
|
||
|
|
|
||
|
|
def main():
    """Run all API tests and print a summary.

    The ``/health`` endpoint is probed first; if the request raises, the
    suite is not run. Otherwise each test function is called in order and
    the number of tests that returned a truthy value is reported.

    Returns:
        int: 0 when every test passed; 1 when the API could not be reached
        or at least one test failed.
    """
    print("🧪 Memory Module Detection API Test Suite")
    print("=" * 50)

    # Check if API is running; only reachability matters here, the reply
    # itself is examined later by test_health_check.
    try:
        requests.get(f"{BASE_URL}/health", timeout=5)
    except requests.exceptions.ConnectionError:
        print("❌ API is not running!")
        print(" Please start the API first: python3 main.py")
        return 1
    except Exception as e:
        print(f"❌ Cannot connect to API: {e}")
        return 1

    # Run tests
    tests = [
        test_api_info,
        test_health_check,
        test_hardcoded_detection,
        test_file_upload,
        test_base64_detection,
    ]
    total = len(tests)
    # Every test runs exactly once, in list order; truthy results are counted.
    passed = sum(1 for test in tests if test())

    print("\n" + "=" * 50)
    print(f"🏁 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! The API is working correctly.")
        return 0

    print("⚠️ Some tests failed. Check the output above for details.")
    if passed == 0:
        print(" Make sure the model is trained: python3 train.py")
    return 1
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit
    # status (None or 0 means success) so shells and CI can detect failure.
    raise SystemExit(main())
|