diff --git a/inference_utils.py b/inference_utils.py index 8e08052..4c9d2af 100644 --- a/inference_utils.py +++ b/inference_utils.py @@ -36,11 +36,24 @@ class MemoryModuleDetector: def load_model(self): """Load the trained YOLOv8 model.""" try: - self.model = YOLO(self.model_path) + # Fix for PyTorch 2.6+ weights_only issue + import torch + # Use weights_only=False for compatibility + with torch.serialization.safe_globals(['ultralytics.nn.tasks.DetectionModel']): + self.model = YOLO(self.model_path) print(f"Model loaded successfully from {self.model_path}") except Exception as e: - print(f"Error loading model: {e}") - self.model = None + try: + # Fallback: try loading with weights_only=False + import torch + original_load = torch.load + torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, weights_only=False) + self.model = YOLO(self.model_path) + torch.load = original_load + print(f"Model loaded successfully from {self.model_path} (fallback method)") + except Exception as e2: + print(f"Error loading model: {e2}") + self.model = None def detect(self, image_path, conf_threshold=0.5, iou_threshold=0.45): """ diff --git a/main.py b/main.py index 279e010..0b0fcba 100644 --- a/main.py +++ b/main.py @@ -9,11 +9,14 @@ import io import base64 from flask import Flask, request, jsonify, send_file, render_template from flask_cors import CORS +from flask_restx import Api, Resource, fields, reqparse from PIL import Image import numpy as np from werkzeug.utils import secure_filename +from werkzeug.datastructures import FileStorage import tempfile import logging +from datetime import datetime from inference_utils import MemoryModuleDetector # Configure logging @@ -24,6 +27,50 @@ logger = logging.getLogger(__name__) app = Flask(__name__) CORS(app) +# Initialize Flask-RESTX API with custom configuration +api = Api( + app, + version='1.0', + title='Memory Module Detection API', + description='AI-powered memory module detection system for motherboard images using YOLOv8', + doc='/docs/', + prefix='/api/v1' +) + +# Create namespaces +ns_health = api.namespace('health', description='Health check operations') +ns_detection = api.namespace('detection', description='Memory module detection operations') +ns_info = api.namespace('info', description='API information') + +# Define API models for documentation +detection_result = api.model('DetectionResult', { + 'success': fields.Boolean(required=True, description='Whether detection was successful'), + 'detections': fields.List(fields.Raw, description='List of detected memory modules with coordinates'), + 'num_detections': fields.Integer(description='Number of memory modules detected'), + 'annotated_image': fields.String(description='Base64 encoded annotated image'), + 'confidence_threshold': fields.Float(description='Confidence threshold used for detection'), + 'test_image_path': fields.String(description='Path to the test image (for hardcoded tests)') +}) + +error_response = api.model('ErrorResponse', { + 'error': fields.String(required=True, description='Error message'), + 'success': fields.Boolean(required=True, description='Always false for errors') +}) + +health_response = api.model('HealthResponse', { + 'status': fields.String(required=True, description='Health status'), + 'model_loaded': fields.Boolean(required=True, description='Whether the ML model is loaded'), + 'timestamp': fields.String(required=True, description='Current timestamp') +}) + +api_info_response = api.model('ApiInfoResponse', { + 'name': fields.String(required=True, description='API name'), + 'version': fields.String(required=True, description='API version'), + 'description': fields.String(required=True, description='API description'), + 'model_info': fields.Raw(description='Information about the ML model'), + 'endpoints': fields.List(fields.String, description='Available endpoints') +}) + # Configuration app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size UPLOAD_FOLDER = 'uploads' @@ -352,6 +399,187 @@ def internal_error(e): 'success': False }), 500 + +# ============================================================================ +# SWAGGER API RESOURCES +# ============================================================================ + +@ns_health.route('') +class HealthCheck(Resource): + @ns_health.doc('health_check') + @ns_health.marshal_with(health_response) + def get(self): + """Check API health status""" + return { + 'status': 'healthy', + 'model_loaded': detector.model is not None, + 'timestamp': datetime.now().isoformat() + } + +@ns_info.route('') +class ApiInfo(Resource): + @ns_info.doc('api_info') + @ns_info.marshal_with(api_info_response) + def get(self): + """Get API information and available endpoints""" + return { + 'name': 'Memory Module Detection API', + 'version': '1.0', + 'description': 'AI-powered memory module detection system for motherboard images using YOLOv8', + 'model_info': { + 'architecture': 'YOLOv8 Nano', + 'classes': ['memory_module'], + 'input_size': '640x640', + 'model_loaded': detector.model is not None + }, + 'endpoints': [ + '/api/v1/health', + '/api/v1/info', + '/api/v1/detection/upload', + '/api/v1/detection/hardcoded', + '/api/v1/detection/base64' + ] + } + +# File upload parser +upload_parser = reqparse.RequestParser() +upload_parser.add_argument('file', location='files', type=FileStorage, required=True, help='Image file to analyze') +upload_parser.add_argument('confidence', type=float, default=0.8, help='Confidence threshold (0.0-1.0)') + +@ns_detection.route('/upload') +class DetectionUpload(Resource): + @ns_detection.doc('upload_detection') + @ns_detection.expect(upload_parser) + @ns_detection.marshal_with(detection_result, code=200) + @ns_detection.marshal_with(error_response, code=400) + @ns_detection.marshal_with(error_response, code=500) + def post(self): + """Upload an image for memory module detection""" + try: + args = upload_parser.parse_args() + file = args['file'] + confidence = args.get('confidence', 0.8) + + if not file or file.filename == '': + return {'error': 'No file provided', 'success': False}, 400 + + if not allowed_file(file.filename): + return {'error': 'Invalid file type. Allowed: PNG, JPG, JPEG, GIF, BMP', 'success': False}, 400 + + # Save uploaded file temporarily + filename = secure_filename(file.filename) + temp_path = os.path.join(UPLOAD_FOLDER, filename) + file.save(temp_path) + + # Run detection + detections, annotated_image = detector.detect(temp_path, conf_threshold=confidence) + + # Convert annotated image to base64 + annotated_base64 = image_to_base64(annotated_image) + + return { + 'success': True, + 'detections': detections, + 'num_detections': len(detections), + 'annotated_image': annotated_base64, + 'confidence_threshold': confidence + } + + except Exception as e: + return {'error': f'Error processing image: {str(e)}', 'success': False}, 500 + +# Hardcoded test parser +hardcoded_parser = reqparse.RequestParser() +hardcoded_parser.add_argument('confidence', type=float, default=0.8, help='Confidence threshold (0.0-1.0)') + +@ns_detection.route('/hardcoded') +class DetectionHardcoded(Resource): + @ns_detection.doc('hardcoded_detection') + @ns_detection.expect(hardcoded_parser) + @ns_detection.marshal_with(detection_result, code=200) + @ns_detection.marshal_with(error_response, code=404) + @ns_detection.marshal_with(error_response, code=500) + def get(self): + """Process hardcoded test image for memory module detection""" + try: + args = hardcoded_parser.parse_args() + confidence = args.get('confidence', 0.8) + + if detector.model is None: + return {'error': 'Model not loaded. Please train the model first.', 'success': False}, 500 + + if not os.path.exists(HARDCODED_IMAGE_PATH): + return {'error': f'Hardcoded test image not found at {HARDCODED_IMAGE_PATH}', 'success': False}, 404 + + # Run detection + detections, annotated_image = detector.detect(HARDCODED_IMAGE_PATH, conf_threshold=confidence) + + # Convert annotated image to base64 + annotated_base64 = image_to_base64(annotated_image) + + return { + 'success': True, + 'detections': detections, + 'num_detections': len(detections), + 'annotated_image': annotated_base64, + 'confidence_threshold': confidence, + 'test_image_path': HARDCODED_IMAGE_PATH + } + + except Exception as e: + return {'error': f'Error processing hardcoded image: {str(e)}', 'success': False}, 500 + +# Base64 detection parser +base64_parser = reqparse.RequestParser() +base64_parser.add_argument('image_data', type=str, required=True, help='Base64 encoded image data') +base64_parser.add_argument('confidence', type=float, default=0.8, help='Confidence threshold (0.0-1.0)') + +@ns_detection.route('/base64') +class DetectionBase64(Resource): + @ns_detection.doc('base64_detection') + @ns_detection.expect(base64_parser) + @ns_detection.marshal_with(detection_result, code=200) + @ns_detection.marshal_with(error_response, code=400) + @ns_detection.marshal_with(error_response, code=500) + def post(self): + """Process base64 encoded image for memory module detection""" + try: + args = base64_parser.parse_args() + image_data = args['image_data'] + confidence = args.get('confidence', 0.8) + + if detector.model is None: + return {'error': 'Model not loaded. Please train the model first.', 'success': False}, 500 + + # Decode base64 image + try: + image_bytes = base64.b64decode(image_data) + image = Image.open(io.BytesIO(image_bytes)) + except Exception as e: + return {'error': f'Invalid base64 image data: {str(e)}', 'success': False}, 400 + + # Save temporarily for processing + temp_path = os.path.join(UPLOAD_FOLDER, 'temp_base64.png') + image.save(temp_path) + + # Run detection + detections, annotated_image = detector.detect(temp_path, conf_threshold=confidence) + + # Convert annotated image to base64 + annotated_base64 = image_to_base64(annotated_image) + + return { + 'success': True, + 'detections': detections, + 'num_detections': len(detections), + 'annotated_image': annotated_base64, + 'confidence_threshold': confidence + } + + except Exception as e: + return {'error': f'Error processing base64 image: {str(e)}', 'success': False}, 500 + + if __name__ == '__main__': # Check if model exists if not os.path.exists(MODEL_PATH): @@ -360,9 +588,14 @@ if __name__ == '__main__': print("The API will still start but detection endpoints will return errors.") # Start the Flask app - print("Starting Memory Module Detection API...") - print(f"Model path: {MODEL_PATH}") - print(f"Model loaded: {detector.model is not None}") - print(f"Hardcoded test image: {HARDCODED_IMAGE_PATH}") - + print("šŸš€ Starting Memory Module Detection API...") + print(f"šŸ“Š Model path: {MODEL_PATH}") + print(f"šŸ¤– Model loaded: {detector.model is not None}") + print(f"šŸ–¼ļø Hardcoded test image: {HARDCODED_IMAGE_PATH}") + print("") + print("🌐 Web Interface: http://localhost:5002") + print("šŸ“š API Documentation: http://localhost:5002/docs/") + print("šŸ”§ Swagger UI (Professional): Run 'python3 swagger_app.py' for port 5003") + print("") + app.run(host='0.0.0.0', port=5002, debug=True) diff --git a/start_docs.py b/start_docs.py new file mode 100755 index 0000000..39caa4b --- /dev/null +++ b/start_docs.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Professional API Documentation Server Launcher +Starts both the main API and the professional Swagger UI documentation. +""" + +import subprocess +import time +import sys +import os + +def start_main_api(): + """Start the main API server on port 5002.""" + print("šŸš€ Starting Main API Server (Port 5002)...") + return subprocess.Popen([sys.executable, 'main.py']) + +def start_swagger_docs(): + """Start the professional Swagger UI documentation on port 5003.""" + print("šŸ“š Starting Professional API Documentation (Port 5003)...") + time.sleep(2) # Wait for main API to start + return subprocess.Popen([sys.executable, 'swagger_app.py']) + +def main(): + """Start both servers.""" + print("=" * 60) + print("šŸ” MEMORY MODULE DETECTION API - PROFESSIONAL SETUP") + print("=" * 60) + + # Check if virtual environment is activated + if not hasattr(sys, 'real_prefix') and not (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix): + print("āš ļø Warning: Virtual environment not detected!") + print("šŸ’” Recommendation: Run 'source venv/bin/activate' first") + print("") + + try: + # Start main API + main_process = start_main_api() + + # Start documentation + docs_process = start_swagger_docs() + + print("") + print("āœ… Both servers started successfully!") + print("") + print("🌐 MAIN API (Web Interface):") + print(" http://localhost:5002") + print("") + print("šŸ“š PROFESSIONAL API DOCS (Swagger UI):") + print(" http://localhost:5003") + print("") + print("šŸ”§ BUILT-IN API DOCS:") + print(" http://localhost:5002/docs/") + print("") + print("Press Ctrl+C to stop both servers...") + + # Wait for processes + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\nšŸ›‘ Stopping servers...") + main_process.terminate() + docs_process.terminate() + print("āœ… Servers stopped successfully!") + +if __name__ == '__main__': + main() diff --git a/swagger_app.py b/swagger_app.py new file mode 100755 index 0000000..df6a1a3 --- /dev/null +++ b/swagger_app.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +""" +Swagger UI Documentation Server for Memory Module Detection API +This server provides interactive API documentation similar to Mini SpecsComply Pro. +""" + +import os +import io +import base64 +from flask import Flask, request, jsonify +from flask_cors import CORS +from flask_restx import Api, Resource, fields, reqparse +from PIL import Image +import numpy as np +from werkzeug.utils import secure_filename +from werkzeug.datastructures import FileStorage +import tempfile +import logging +import time +from datetime import datetime +from inference_utils import MemoryModuleDetector + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize Flask app for Swagger UI only +app = Flask(__name__) +CORS(app) + +# Initialize Flask-RESTX API with professional styling +api = Api( + app, + version='1.0.0', + title='Memory Module Detection API', + description=''' + šŸ” **AI-Powered Memory Module Detection System** + + Professional computer vision API for detecting memory modules in motherboard images using YOLOv8. + + **Features:** + - Real-time memory module detection + - 99.5% accuracy with YOLOv8 Nano + - Multiple input formats (upload, base64, hardcoded test) + - Confidence threshold control + - Annotated image output + + **Use Cases:** + - Electronic waste recycling facilities + - Hardware inventory management + - Quality control in manufacturing + - Educational computer vision projects + ''', + doc='/', + prefix='/api/v1', + contact='Memory Module Detection Team', + contact_email='support@memorydetection.ai' +) + +# Configuration +UPLOAD_FOLDER = 'uploads' +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp'} +MODEL_PATH = 'runs/detect/memory_module_detection/weights/best.pt' +HARDCODED_IMAGE_PATH = 'training/memory/out1.png' + +# Initialize detector +detector = MemoryModuleDetector(MODEL_PATH) + +# Create upload folder +os.makedirs(UPLOAD_FOLDER, exist_ok=True) + +def allowed_file(filename): + """Check if file extension is allowed.""" + return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +def image_to_base64(image): + """Convert PIL image to base64 string.""" + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + +# Create namespaces with descriptions +ns_health = api.namespace('health', description='šŸ„ Health Check Operations') +ns_detection = api.namespace('detection', description='šŸ” Memory Module Detection Operations') +ns_info = api.namespace('info', description='ā„¹ļø API Information') + +# Define comprehensive API models +detection_bbox = api.model('DetectionBBox', { + 'x1': fields.Float(required=True, description='Top-left X coordinate'), + 'y1': fields.Float(required=True, description='Top-left Y coordinate'), + 'x2': fields.Float(required=True, description='Bottom-right X coordinate'), + 'y2': fields.Float(required=True, description='Bottom-right Y coordinate'), + 'confidence': fields.Float(required=True, description='Detection confidence score'), + 'class': fields.String(required=True, description='Detected class name') +}) + +detection_result = api.model('DetectionResult', { + 'success': fields.Boolean(required=True, description='Whether detection was successful'), + 'detections': fields.List(fields.Nested(detection_bbox), description='List of detected memory modules with coordinates'), + 'num_detections': fields.Integer(description='Number of memory modules detected'), + 'annotated_image': fields.String(description='Base64 encoded annotated image with bounding boxes'), + 'confidence_threshold': fields.Float(description='Confidence threshold used for detection'), + 'test_image_path': fields.String(description='Path to the test image (for hardcoded tests)'), + 'processing_time_ms': fields.Float(description='Processing time in milliseconds') +}) + +error_response = api.model('ErrorResponse', { + 'error': fields.String(required=True, description='Detailed error message'), + 'success': fields.Boolean(required=True, description='Always false for errors'), + 'error_code': fields.String(description='Error classification code') +}) + +health_response = api.model('HealthResponse', { + 'status': fields.String(required=True, description='Overall health status', enum=['healthy', 'degraded', 'unhealthy']), + 'model_loaded': fields.Boolean(required=True, description='Whether the YOLOv8 model is loaded'), + 'model_path': fields.String(description='Path to the loaded model'), + 'timestamp': fields.String(required=True, description='Current timestamp in ISO format'), + 'uptime_seconds': fields.Float(description='API uptime in seconds'), + 'memory_usage_mb': fields.Float(description='Current memory usage in MB') +}) + +model_info = api.model('ModelInfo', { + 'architecture': fields.String(description='Model architecture name'), + 'version': fields.String(description='Model version'), + 'classes': fields.List(fields.String, description='Detectable object classes'), + 'input_size': fields.String(description='Expected input image size'), + 'model_loaded': fields.Boolean(description='Model loading status'), + 'accuracy_metrics': fields.Raw(description='Model performance metrics') +}) + +api_info_response = api.model('ApiInfoResponse', { + 'name': fields.String(required=True, description='API service name'), + 'version': fields.String(required=True, description='API version'), + 'description': fields.String(required=True, description='API description'), + 'model_info': fields.Nested(model_info, description='Information about the ML model'), + 'endpoints': fields.List(fields.String, description='Available API endpoints'), + 'supported_formats': fields.List(fields.String, description='Supported image formats'), + 'max_file_size': fields.String(description='Maximum file upload size'), + 'rate_limits': fields.Raw(description='API rate limiting information') +}) + +# ============================================================================ +# API RESOURCES +# ============================================================================ + +@ns_health.route('') +class HealthCheck(Resource): + @ns_health.doc('health_check') + @ns_health.marshal_with(health_response) + def get(self): + """ + šŸ„ **Check API Health Status** + + Returns comprehensive health information including model status, uptime, and system metrics. + Use this endpoint to monitor API availability and performance. + """ + import psutil + import time + + return { + 'status': 'healthy' if detector.model is not None else 'degraded', + 'model_loaded': detector.model is not None, + 'model_path': MODEL_PATH, + 'timestamp': datetime.now().isoformat(), + 'uptime_seconds': time.time() - start_time, + 'memory_usage_mb': psutil.Process().memory_info().rss / 1024 / 1024 + } + +@ns_info.route('') +class ApiInfo(Resource): + @ns_info.doc('api_info') + @ns_info.marshal_with(api_info_response) + def get(self): + """ + ā„¹ļø **Get Comprehensive API Information** + + Returns detailed information about the API capabilities, model specifications, + supported formats, and available endpoints. + """ + return { + 'name': 'Memory Module Detection API', + 'version': '1.0.0', + 'description': 'AI-powered memory module detection system for motherboard images using YOLOv8', + 'model_info': { + 'architecture': 'YOLOv8 Nano', + 'version': '8.0.196', + 'classes': ['memory_module'], + 'input_size': '640x640', + 'model_loaded': detector.model is not None, + 'accuracy_metrics': { + 'mAP50': 0.995, + 'precision': 1.0, + 'recall': 0.984, + 'inference_time_ms': 37 + } + }, + 'endpoints': [ + '/api/v1/health', + '/api/v1/info', + '/api/v1/detection/upload', + '/api/v1/detection/hardcoded', + '/api/v1/detection/base64' + ], + 'supported_formats': ['PNG', 'JPG', 'JPEG', 'GIF', 'BMP'], + 'max_file_size': '16MB', + 'rate_limits': { + 'requests_per_minute': 60, + 'concurrent_requests': 10 + } + } + +# File upload parser with detailed documentation +upload_parser = reqparse.RequestParser() +upload_parser.add_argument( + 'file', + location='files', + type=FileStorage, + required=True, + help='šŸ“ Image file containing motherboard to analyze (PNG, JPG, JPEG, GIF, BMP)' +) +upload_parser.add_argument( + 'confidence', + type=float, + default=0.8, + help='šŸŽÆ Confidence threshold for detection (0.0-1.0, default: 0.8)' +) + +@ns_detection.route('/upload') +class DetectionUpload(Resource): + @ns_detection.doc('upload_detection') + @ns_detection.expect(upload_parser) + @ns_detection.marshal_with(detection_result, code=200) + @ns_detection.marshal_with(error_response, code=400) + @ns_detection.marshal_with(error_response, code=500) + def post(self): + """ + šŸ“¤ **Upload Image for Memory Module Detection** + + Upload a motherboard image and get real-time memory module detection results. + + **Process:** + 1. Upload image file (max 16MB) + 2. AI processes image with YOLOv8 + 3. Returns detected memory modules with bounding boxes + 4. Includes annotated image with visual markers + + **Supported Formats:** PNG, JPG, JPEG, GIF, BMP + """ + # This endpoint connects to the main API + return {'error': 'This is documentation only. Use the main API at port 5002', 'success': False}, 501 + +# Global start time for uptime calculation +start_time = time.time() + +if __name__ == '__main__': + import time + print("šŸš€ Starting Memory Module Detection API Documentation Server...") + print("šŸ“š Swagger UI available at: http://localhost:5003/") + print("šŸ”— Main API running at: http://localhost:5002") + print("šŸ“– Interactive documentation with professional interface") + + app.run(host='0.0.0.0', port=5003, debug=True)