#!/usr/bin/env python3
"""
Test script for Memory Module Detection API

This script tests all API endpoints and provides usage examples.
Run it against a live server; the process exits with status 0 only when
every test passes, so it can be used from CI or shell scripts.
"""

import base64
import io
import json
import os
import sys

import requests
from PIL import Image

# API base URL
BASE_URL = "http://localhost:5002"

# Seconds to wait for any single HTTP request. Detection endpoints run model
# inference, so this is generous; without a timeout a stalled server would
# hang the whole suite indefinitely.
REQUEST_TIMEOUT = 60

# Confidence threshold sent with every detection request.
CONFIDENCE = 0.8


def _find_test_image(candidates):
    """Return the first path in *candidates* that exists on disk, or None."""
    return next((path for path in candidates if os.path.exists(path)), None)


def _report_detection(label, response, output_file, test_image_path=None):
    """Print the outcome of a detection request and save its annotated image.

    Args:
        label: Human-readable test name used in messages,
            e.g. "Base64 detection".
        response: The ``requests.Response`` returned by the detection endpoint.
        output_file: Where to write the annotated image, if the server sent one.
        test_image_path: Input image path to echo in the output; omitted when
            None (the hard-coded endpoint uses a server-side image).

    Returns:
        True if the server answered 200 with a truthy ``success`` field.
    """
    if response.status_code != 200:
        print(f"❌ {label} failed: {response.status_code}")
        if response.text:
            print(f"   Response: {response.text}")
        return False

    result = response.json()
    if not result['success']:
        print(f"❌ {label} failed: {result.get('error', 'Unknown error')}")
        return False

    print(f"✅ {label} successful!")
    if test_image_path:
        print(f"   Test image: {test_image_path}")
    print(f"   Found {result['num_detections']} memory modules")
    for index, detection in enumerate(result['detections'], start=1):
        print(f"   Detection {index}: {detection['class_name']} "
              f"(confidence: {detection['confidence']:.3f})")

    # Only claim the file was written if saving actually succeeded.
    if 'annotated_image' in result and save_base64_image(
            result['annotated_image'], output_file):
        print(f"   Annotated image saved as: {output_file}")
    return True


def test_api_info():
    """Test the API info endpoint (``GET /``).

    Returns:
        True if the server answered 200 and its summary was printed.
    """
    print("🔍 Testing API Info...")
    try:
        response = requests.get(f"{BASE_URL}/", timeout=REQUEST_TIMEOUT)
        if response.status_code != 200:
            print(f"❌ API Info failed: {response.status_code}")
            return False
        data = response.json()
        print(f"✅ API Info: {data['message']}")
        print(f"   Model loaded: {data['model_loaded']}")
        print(f"   Supported formats: {data['supported_formats']}")
        return True
    except Exception as e:
        # Test-harness boundary: report and continue with the next test.
        print(f"❌ API Info error: {e}")
        return False


def test_health_check():
    """Test the health check endpoint (``GET /health``).

    Returns:
        True if the server answered 200 and its status was printed.
    """
    print("\n🏥 Testing Health Check...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=REQUEST_TIMEOUT)
        if response.status_code != 200:
            print(f"❌ Health check failed: {response.status_code}")
            return False
        data = response.json()
        print(f"✅ Health: {data['status']}")
        print(f"   Model loaded: {data['model_loaded']}")
        return True
    except Exception as e:
        print(f"❌ Health check error: {e}")
        return False


def test_hardcoded_detection():
    """Test detection on the server's built-in image (``GET /detect/hardcoded``).

    Returns:
        True if detection succeeded.
    """
    print("\n🖼️ Testing Hardcoded Image Detection...")
    try:
        response = requests.get(
            f"{BASE_URL}/detect/hardcoded",
            params={'confidence': CONFIDENCE},
            timeout=REQUEST_TIMEOUT,
        )
        return _report_detection(
            "Hardcoded detection", response, 'test_hardcoded_result.png')
    except Exception as e:
        print(f"❌ Hardcoded detection error: {e}")
        return False


def test_file_upload():
    """Test detection with a multipart file upload (``POST /detect``).

    Returns:
        True if detection succeeded; False if it failed or no local test
        image could be found.
    """
    print("\n📤 Testing File Upload Detection...")

    test_image_path = _find_test_image([
        'training/memory/out1.png',
        'training/memory/out2.png',
        'training/val/images/memory_out8.png',
    ])
    if not test_image_path:
        print("❌ No test image found. Skipping file upload test.")
        return False

    try:
        with open(test_image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/detect",
                files={'image': f},
                data={'confidence': str(CONFIDENCE)},
                timeout=REQUEST_TIMEOUT,
            )
        return _report_detection(
            "File upload detection", response,
            'test_upload_result.png', test_image_path)
    except Exception as e:
        print(f"❌ File upload detection error: {e}")
        return False


def test_base64_detection():
    """Test detection with a base64-encoded JSON payload (``POST /detect/base64``).

    Returns:
        True if detection succeeded; False if it failed or no local test
        image could be found.
    """
    print("\n🔢 Testing Base64 Detection...")

    test_image_path = _find_test_image([
        'training/memory/out1.png',
        'training/memory/out2.png',
    ])
    if not test_image_path:
        print("❌ No test image found. Skipping base64 test.")
        return False

    try:
        with open(test_image_path, 'rb') as f:
            base64_string = base64.b64encode(f.read()).decode('utf-8')

        # ``json=`` serializes the payload and sets the
        # ``Content-Type: application/json`` header itself.
        response = requests.post(
            f"{BASE_URL}/detect/base64",
            json={'image': base64_string, 'confidence': CONFIDENCE},
            timeout=REQUEST_TIMEOUT,
        )
        return _report_detection(
            "Base64 detection", response,
            'test_base64_result.png', test_image_path)
    except Exception as e:
        print(f"❌ Base64 detection error: {e}")
        return False


def save_base64_image(base64_string, filename):
    """Decode a base64-encoded image and write it to *filename*.

    PIL chooses the output format from the file extension. Saving is
    best-effort: any failure is printed as a warning instead of raised, so a
    bad annotated image does not fail the detection test that produced it.

    Returns:
        True if the image was written, False otherwise.
    """
    try:
        image_data = base64.b64decode(base64_string)
        with Image.open(io.BytesIO(image_data)) as image:
            image.save(filename)
    except Exception as e:
        print(f"   Warning: Could not save image {filename}: {e}")
        return False
    return True


def main():
    """Run all API tests.

    Returns:
        Process exit code: 0 if every test passed, 1 otherwise (including
        when the API cannot be reached).
    """
    print("🧪 Memory Module Detection API Test Suite")
    print("=" * 50)

    # Fail fast with a helpful hint if the server is not reachable at all.
    try:
        requests.get(f"{BASE_URL}/health", timeout=5)
    except requests.exceptions.ConnectionError:
        print("❌ API is not running!")
        print("   Please start the API first: python3 main.py")
        return 1
    except Exception as e:
        print(f"❌ Cannot connect to API: {e}")
        return 1

    tests = [
        test_api_info,
        test_health_check,
        test_hardcoded_detection,
        test_file_upload,
        test_base64_detection,
    ]

    # Every test runs even after a failure; each returns True on success.
    passed = 0
    for test in tests:
        if test():
            passed += 1
    total = len(tests)

    print("\n" + "=" * 50)
    print(f"🏁 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! The API is working correctly.")
        return 0

    print("⚠️ Some tests failed. Check the output above for details.")
    if passed == 0:
        print("   Make sure the model is trained: python3 train.py")
    return 1


if __name__ == "__main__":
    sys.exit(main())