diff --git a/api_docs.py b/api_docs.py deleted file mode 100644 index f1ab065..0000000 --- a/api_docs.py +++ /dev/null @@ -1,245 +0,0 @@ -#!/usr/bin/env python3 -""" -Swagger UI API Documentation for Memory Module Detection -This creates a separate API documentation interface using Flask-RESTX -""" - -from flask import Flask, request, jsonify -from flask_restx import Api, Resource, fields, reqparse -from flask_cors import CORS -from werkzeug.datastructures import FileStorage -import os -from inference_utils import MemoryModuleDetector - -# Initialize Flask app with Swagger -app = Flask(__name__) -CORS(app) - -# Configure Swagger UI -api = Api( - app, - version='1.0.0', - title='Memory Module Detection API', - description='AI-powered memory module detection in motherboard images using YOLOv8', - doc='/docs/', # Swagger UI will be available at /docs/ - prefix='/api/v1' -) - -# Create namespaces -ns_health = api.namespace('health', description='System health and status') -ns_detect = api.namespace('detect', description='Memory module detection operations') - -# Initialize detector -MODEL_PATH = 'runs/detect/memory_module_detection/weights/best.pt' -detector = MemoryModuleDetector(MODEL_PATH) - -# Define models for Swagger documentation -detection_model = api.model('Detection', { - 'bbox': fields.List(fields.Float, description='Bounding box coordinates [x1, y1, x2, y2]'), - 'confidence': fields.Float(description='Detection confidence score (0.0-1.0)'), - 'class': fields.Integer(description='Class ID (0 for memory_module)'), - 'class_name': fields.String(description='Class name (memory_module)') -}) - -detection_response = api.model('DetectionResponse', { - 'success': fields.Boolean(description='Whether detection was successful'), - 'detections': fields.List(fields.Nested(detection_model), description='List of detected memory modules'), - 'num_detections': fields.Integer(description='Number of memory modules detected'), - 'annotated_image': fields.String(description='Base64 encoded annotated image'), - 'confidence_threshold': fields.Float(description='Confidence threshold used'), - 'original_filename': fields.String(description='Original filename (for uploads)') -}) - -health_response = api.model('HealthResponse', { - 'status': fields.String(description='System health status'), - 'model_loaded': fields.Boolean(description='Whether the AI model is loaded'), - 'model_path': fields.String(description='Path to the AI model file') -}) - -error_response = api.model('ErrorResponse', { - 'success': fields.Boolean(description='Always false for errors'), - 'error': fields.String(description='Error message') -}) - -# Health endpoint -@ns_health.route('/') -class Health(Resource): - @ns_health.doc('health_check') - @ns_health.marshal_with(health_response) - def get(self): - """Check system health and model status""" - return { - 'status': 'healthy', - 'model_loaded': detector.model is not None, - 'model_path': MODEL_PATH - } - -# File upload parser -upload_parser = reqparse.RequestParser() -upload_parser.add_argument('image', location='files', type=FileStorage, required=True, - help='Motherboard image file (PNG, JPG, JPEG, GIF, BMP)') -upload_parser.add_argument('confidence', type=float, default=0.8, - help='Confidence threshold (0.1-1.0, default: 0.8)') - -@ns_detect.route('/upload') -class DetectUpload(Resource): - @ns_detect.doc('detect_upload') - @ns_detect.expect(upload_parser) - @ns_detect.marshal_with(detection_response, code=200) - @ns_detect.marshal_with(error_response, code=400) - @ns_detect.marshal_with(error_response, code=500) - def post(self): - """Upload and analyze motherboard image for memory modules""" - try: - if detector.model is None: - return {'success': False, 'error': 'Model not loaded'}, 500 - - args = upload_parser.parse_args() - file = args['image'] - confidence = args['confidence'] - - if not file: - return {'success': False, 'error': 'No image file provided'}, 400 - - # Save file temporarily - temp_path = f"temp_{file.filename}" - file.save(temp_path) - - try: - # Run detection - detections, annotated_image = detector.detect(temp_path, conf_threshold=confidence) - - # Convert annotated image to base64 - import io - import base64 - buffer = io.BytesIO() - annotated_image.save(buffer, format='PNG') - annotated_base64 = base64.b64encode(buffer.getvalue()).decode() - - return { - 'success': True, - 'detections': detections, - 'num_detections': len(detections), - 'annotated_image': annotated_base64, - 'confidence_threshold': confidence, - 'original_filename': file.filename - } - - finally: - # Clean up - if os.path.exists(temp_path): - os.remove(temp_path) - - except Exception as e: - return {'success': False, 'error': str(e)}, 500 - -# Hardcoded image parser -hardcoded_parser = reqparse.RequestParser() -hardcoded_parser.add_argument('confidence', type=float, default=0.8, location='args', - help='Confidence threshold (0.1-1.0, default: 0.8)') - -@ns_detect.route('/hardcoded') -class DetectHardcoded(Resource): - @ns_detect.doc('detect_hardcoded') - @ns_detect.expect(hardcoded_parser) - @ns_detect.marshal_with(detection_response, code=200) - @ns_detect.marshal_with(error_response, code=404) - @ns_detect.marshal_with(error_response, code=500) - def get(self): - """Analyze predefined test image for memory modules""" - try: - if detector.model is None: - return {'success': False, 'error': 'Model not loaded'}, 500 - - args = hardcoded_parser.parse_args() - confidence = args['confidence'] - - test_image_path = 'training/memory/out1.png' - if not os.path.exists(test_image_path): - return {'success': False, 'error': f'Test image not found at {test_image_path}'}, 404 - - # Run detection - detections, annotated_image = detector.detect(test_image_path, conf_threshold=confidence) - - # Convert annotated image to base64 - import io - import base64 - buffer = io.BytesIO() - annotated_image.save(buffer, format='PNG') - annotated_base64 = base64.b64encode(buffer.getvalue()).decode() - - return { - 'success': True, - 'detections': detections, - 'num_detections': len(detections), - 'annotated_image': annotated_base64, - 'confidence_threshold': confidence, - 'test_image_path': test_image_path - } - - except Exception as e: - return {'success': False, 'error': str(e)}, 500 - -# Base64 image model -base64_model = api.model('Base64Request', { - 'image': fields.String(required=True, description='Base64 encoded image data'), - 'confidence': fields.Float(default=0.8, description='Confidence threshold (0.1-1.0)') -}) - -@ns_detect.route('/base64') -class DetectBase64(Resource): - @ns_detect.doc('detect_base64') - @ns_detect.expect(base64_model) - @ns_detect.marshal_with(detection_response, code=200) - @ns_detect.marshal_with(error_response, code=400) - @ns_detect.marshal_with(error_response, code=500) - def post(self): - """Analyze base64 encoded image for memory modules""" - try: - if detector.model is None: - return {'success': False, 'error': 'Model not loaded'}, 500 - - data = request.get_json() - if not data or 'image' not in data: - return {'success': False, 'error': 'No base64 image data provided'}, 400 - - confidence = data.get('confidence', 0.8) - - # Decode base64 image - import base64 - import io - from PIL import Image - import numpy as np - - try: - img_data = base64.b64decode(data['image']) - image = Image.open(io.BytesIO(img_data)) - except Exception as e: - return {'success': False, 'error': f'Invalid base64 image data: {str(e)}'}, 400 - - # Run detection - detections, annotated_image = detector.detect_from_array(np.array(image), conf_threshold=confidence) - - # Convert annotated image to base64 - buffer = io.BytesIO() - annotated_image.save(buffer, format='PNG') - annotated_base64 = base64.b64encode(buffer.getvalue()).decode() - - return { - 'success': True, - 'detections': detections, - 'num_detections': len(detections), - 'annotated_image': annotated_base64, - 'confidence_threshold': confidence - } - - except Exception as e: - return {'success': False, 'error': str(e)}, 500 - -if __name__ == '__main__': - print("Starting Memory Module Detection API with Swagger UI...") - print(f"Model path: {MODEL_PATH}") - print(f"Model loaded: {detector.model is not None}") - print("Swagger UI available at: http://localhost:5003/docs/") - - app.run(host='0.0.0.0', port=5003, debug=True) diff --git a/main.py b/main.py index 2879b10..022dc23 100644 --- a/main.py +++ b/main.py @@ -88,63 +88,16 @@ def api_info(): 'endpoints': { '/': 'GET - Frontend interface or API information', '/api': 'GET - API information (JSON)', - '/docs': 'GET - API documentation (Swagger UI)', '/detect': 'POST - Upload image for memory module detection', '/detect/hardcoded': 'GET - Process hardcoded test image', '/detect/base64': 'POST - Process base64 encoded image', '/health': 'GET - Health check' }, 'model_loaded': detector.model is not None, - 'supported_formats': list(ALLOWED_EXTENSIONS), - 'swagger_ui': 'http://localhost:5003/docs/ (run: python3 api_docs.py)' + 'supported_formats': list(ALLOWED_EXTENSIONS) }) -@app.route('/docs') -def api_docs(): - """Redirect to API documentation.""" - return """ - - -
-Interactive Swagger UI documentation for all API endpoints
-python3 api_docs.pyhttp://localhost:5003/docs/
-