Files
recycling-project-solutions/test_api.py
T

258 lines
8.7 KiB
Python
Raw Normal View History

2025-07-11 20:07:36 +01:00
#!/usr/bin/env python3
"""
Test script for Memory Module Detection API
This script tests all API endpoints and provides usage examples.
"""
import requests
import json
import base64
import os
from PIL import Image
import io
# API base URL. Defaults to a locally running server; set the API_BASE_URL
# environment variable to point the suite at another host/port. A trailing
# slash is stripped because every endpoint is joined as f"{BASE_URL}/...".
BASE_URL = (os.environ.get("API_BASE_URL") or "http://localhost:5002").rstrip("/")
def test_api_info(*, timeout=10):
    """Test the API info endpoint (``GET /``).

    Prints the ``message``, ``model_loaded`` and ``supported_formats`` fields
    of the JSON response.

    Args:
        timeout: Seconds to wait for the server. Without a timeout a stalled
            server would block the whole suite indefinitely.

    Returns:
        True if the endpoint answered HTTP 200 with the expected fields,
        False otherwise. Errors are printed, never raised.
    """
    print("🔍 Testing API Info...")
    try:
        response = requests.get(f"{BASE_URL}/", timeout=timeout)
        if response.status_code == 200:
            data = response.json()
            print(f"✅ API Info: {data['message']}")
            print(f" Model loaded: {data['model_loaded']}")
            print(f" Supported formats: {data['supported_formats']}")
            return True
        else:
            print(f"❌ API Info failed: {response.status_code}")
            return False
    except Exception as e:
        # Harness boundary: any network/JSON/key error counts as a failed
        # test instead of aborting the remaining tests.
        print(f"❌ API Info error: {e}")
        return False
def test_health_check(*, timeout=10):
    """Test the health check endpoint (``GET /health``).

    Prints the ``status`` and ``model_loaded`` fields of the JSON response.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True on HTTP 200 with the expected fields, False otherwise.
        Errors are printed, never raised.
    """
    print("\n🏥 Testing Health Check...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=timeout)
        if response.status_code == 200:
            data = response.json()
            print(f"✅ Health: {data['status']}")
            print(f" Model loaded: {data['model_loaded']}")
            return True
        else:
            print(f"❌ Health check failed: {response.status_code}")
            return False
    except Exception as e:
        # Harness boundary: report and continue with the next test.
        print(f"❌ Health check error: {e}")
        return False
def test_hardcoded_detection(*, confidence=0.8, timeout=60):
    """Test detection on the endpoint's built-in image (``GET /detect/hardcoded``).

    Prints each returned detection and, if the response carries an
    ``annotated_image``, saves it to ``test_hardcoded_result.png``.

    Args:
        confidence: Value sent as the ``confidence`` query parameter.
        timeout: Seconds to wait for the server. This is longer than for the
            info endpoints because the request runs detection.

    Returns:
        True if the server reported success, False otherwise.
        Errors are printed, never raised.
    """
    print("\n🖼️ Testing Hardcoded Image Detection...")
    try:
        # params= lets requests build and encode the query string.
        response = requests.get(
            f"{BASE_URL}/detect/hardcoded",
            params={'confidence': confidence},
            timeout=timeout,
        )
        if response.status_code == 200:
            data = response.json()
            if data.get('success'):
                print("✅ Hardcoded detection successful!")
                print(f" Found {data['num_detections']} memory modules")
                for i, detection in enumerate(data['detections'], start=1):
                    print(f" Detection {i}: {detection['class_name']} "
                          f"(confidence: {detection['confidence']:.3f})")
                # Save annotated image
                if 'annotated_image' in data:
                    save_base64_image(data['annotated_image'], 'test_hardcoded_result.png')
                    print(" Annotated image saved as: test_hardcoded_result.png")
                return True
            else:
                print(f"❌ Hardcoded detection failed: {data.get('error', 'Unknown error')}")
                return False
        else:
            print(f"❌ Hardcoded detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False
    except Exception as e:
        # Harness boundary: report and continue with the next test.
        print(f"❌ Hardcoded detection error: {e}")
        return False
def test_file_upload(*, confidence=0.8, timeout=60):
    """Test detection with a multipart file upload (``POST /detect``).

    Uploads the first candidate image that exists on disk (paths are relative
    to the current working directory). Prints each detection and, if present,
    saves the ``annotated_image`` to ``test_upload_result.png``.

    Args:
        confidence: Value sent as the ``confidence`` form field.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server reported success. False if no test image is found,
        the request fails, or the server reports an error.
    """
    print("\n📤 Testing File Upload Detection...")
    # Find a test image
    possible_paths = [
        'training/memory/out1.png',
        'training/memory/out2.png',
        'training/val/images/memory_out8.png'
    ]
    test_image_path = next((p for p in possible_paths if os.path.exists(p)), None)
    if not test_image_path:
        print("❌ No test image found. Skipping file upload test.")
        return False
    try:
        # The file only needs to stay open while requests streams the upload.
        with open(test_image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/detect",
                files={'image': f},
                data={'confidence': str(confidence)},
                timeout=timeout,
            )
        if response.status_code == 200:
            result = response.json()
            if result.get('success'):
                print("✅ File upload detection successful!")
                print(f" Test image: {test_image_path}")
                print(f" Found {result['num_detections']} memory modules")
                for i, detection in enumerate(result['detections'], start=1):
                    print(f" Detection {i}: {detection['class_name']} "
                          f"(confidence: {detection['confidence']:.3f})")
                # Save annotated image
                if 'annotated_image' in result:
                    save_base64_image(result['annotated_image'], 'test_upload_result.png')
                    print(" Annotated image saved as: test_upload_result.png")
                return True
            else:
                print(f"❌ File upload detection failed: {result.get('error', 'Unknown error')}")
                return False
        else:
            print(f"❌ File upload detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False
    except Exception as e:
        # Harness boundary: report and continue with the next test.
        print(f"❌ File upload detection error: {e}")
        return False
def test_base64_detection(*, confidence=0.8, timeout=60):
    """Test detection with a base64-encoded image in a JSON body (``POST /detect/base64``).

    Encodes the first candidate image that exists on disk, posts it as
    ``{"image": <base64>, "confidence": <float>}``, prints each detection and,
    if present, saves the ``annotated_image`` to ``test_base64_result.png``.

    Args:
        confidence: Value sent as the ``confidence`` JSON field.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server reported success. False if no test image is found,
        the request fails, or the server reports an error.
    """
    print("\n🔢 Testing Base64 Detection...")
    # Find a test image
    possible_paths = [
        'training/memory/out1.png',
        'training/memory/out2.png'
    ]
    test_image_path = next((p for p in possible_paths if os.path.exists(p)), None)
    if not test_image_path:
        print("❌ No test image found. Skipping base64 test.")
        return False
    try:
        # Convert image to base64
        with open(test_image_path, 'rb') as f:
            base64_string = base64.b64encode(f.read()).decode('utf-8')
        payload = {
            'image': base64_string,
            'confidence': confidence,
        }
        # json= serializes the payload and sets Content-Type: application/json.
        response = requests.post(
            f"{BASE_URL}/detect/base64",
            json=payload,
            timeout=timeout,
        )
        if response.status_code == 200:
            result = response.json()
            if result.get('success'):
                print("✅ Base64 detection successful!")
                print(f" Test image: {test_image_path}")
                print(f" Found {result['num_detections']} memory modules")
                for i, detection in enumerate(result['detections'], start=1):
                    print(f" Detection {i}: {detection['class_name']} "
                          f"(confidence: {detection['confidence']:.3f})")
                # Save annotated image
                if 'annotated_image' in result:
                    save_base64_image(result['annotated_image'], 'test_base64_result.png')
                    print(" Annotated image saved as: test_base64_result.png")
                return True
            else:
                print(f"❌ Base64 detection failed: {result.get('error', 'Unknown error')}")
                return False
        else:
            print(f"❌ Base64 detection failed: {response.status_code}")
            if response.text:
                print(f" Response: {response.text}")
            return False
    except Exception as e:
        # Harness boundary: report and continue with the next test.
        print(f"❌ Base64 detection error: {e}")
        return False
def save_base64_image(base64_string, filename):
    """Decode a base64-encoded image and save it to *filename*.

    Best effort: any failure is printed as a warning and never raised, so a
    bad image in a response does not abort the test that received it.

    Args:
        base64_string: Base64 image data. A ``data:<mime>;base64,`` URI
            prefix, if present, is stripped before decoding.
        filename: Destination path. Pillow chooses the output format from
            the file extension.
    """
    try:
        # Accept data URIs as well as bare base64 payloads.
        if base64_string.startswith('data:'):
            base64_string = base64_string.partition(',')[2]
        image_data = base64.b64decode(base64_string)
        # The context manager releases the image Pillow opened on the buffer.
        with Image.open(io.BytesIO(image_data)) as image:
            image.save(filename)
    except Exception as e:
        print(f" Warning: Could not save image {filename}: {e}")
def main():
    """Run all API tests and print a summary.

    Returns:
        Process exit status: 0 when every test passes, 1 when any test fails
        or the API cannot be reached. Callers that ignore the return value
        see the same printed output as before.
    """
    print("🧪 Memory Module Detection API Test Suite")
    print("=" * 50)
    # Check if API is running: fail fast instead of letting every test
    # report its own connection error.
    try:
        requests.get(f"{BASE_URL}/health", timeout=5)
    except requests.exceptions.ConnectionError:
        print("❌ API is not running!")
        print(" Please start the API first: python3 main.py")
        return 1
    except Exception as e:
        print(f"❌ Cannot connect to API: {e}")
        return 1
    # Run tests
    tests = [
        test_api_info,
        test_health_check,
        test_hardcoded_detection,
        test_file_upload,
        test_base64_detection,
    ]
    total = len(tests)
    passed = 0
    for test in tests:
        if test():
            passed += 1
    print("\n" + "=" * 50)
    print(f"🏁 Test Results: {passed}/{total} tests passed")
    if passed == total:
        print("🎉 All tests passed! The API is working correctly.")
        return 0
    print("⚠️ Some tests failed. Check the output above for details.")
    if passed == 0:
        print(" Make sure the model is trained: python3 train.py")
    return 1
if __name__ == "__main__":
    # Propagate main()'s result as the process exit status so CI can detect
    # failures; SystemExit(None) exits with status 0.
    raise SystemExit(main())