134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Setup script for Memory Module Detection Project
|
||
|
|
This script helps users set up the project quickly.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import subprocess
|
||
|
|
import time
|
||
|
|
|
||
|
|
def run_command(command, description):
    """Run a command, print progress messages, and report success.

    Args:
        command: The command to execute. A string is run through the shell
            (the historical behavior). A sequence of arguments is executed
            directly, without a shell, so no quoting is needed.
        description: Human-readable label used in the progress messages.

    Returns:
        True if the command exited with status 0, False if it exited
        non-zero or could not be started.
    """
    print(f"🔄 {description}...")
    # Strings keep shell semantics for existing callers; argument lists
    # bypass the shell entirely.
    use_shell = isinstance(command, str)
    try:
        subprocess.run(
            command,
            shell=use_shell,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ {description} failed:")
        print(f" Command: {command}")
        # Many tools (pip included) write their errors to stdout, so fall
        # back to it rather than printing an empty error message.
        details = (e.stderr or "").strip() or (e.stdout or "").strip()
        print(f" Error: {details}")
        return False
    except OSError as e:
        # Raised when an argument-list command's executable cannot be
        # found or started (no shell is involved to turn it into an exit code).
        print(f"❌ {description} failed:")
        print(f" Command: {command}")
        print(f" Error: {e}")
        return False
    print(f"✅ {description} completed successfully")
    return True
|
||
|
|
|
||
|
|
def check_python_version(min_version=(3, 8)):
    """Report whether the running interpreter meets a minimum version.

    Args:
        min_version: Minimum acceptable version as a tuple such as
            ``(major, minor)``. Defaults to ``(3, 8)``.

    Returns:
        True if ``sys.version_info`` is at least ``min_version``, else False.
    """
    print("🐍 Checking Python version...")
    version = sys.version_info
    running = f"{version.major}.{version.minor}.{version.micro}"
    # Tuple comparison also accepts any newer major release. The old
    # ``major == 3 and minor >= 8`` test would reject a future Python 4.
    if version >= tuple(min_version):
        print(f"✅ Python {running} is compatible")
        return True
    print(f"❌ Python {running} is not compatible")
    required = ".".join(str(part) for part in min_version)
    print(f" Please use Python {required} or higher")
    return False
|
||
|
|
|
||
|
|
def check_files():
    """Verify that every file and directory the project needs is present.

    Paths are checked relative to the current working directory.

    Returns:
        True when nothing is missing. Otherwise the missing paths are
        printed and False is returned.
    """
    print("📁 Checking project files...")
    expected_paths = (
        'requirements.txt',
        'main.py',
        'train.py',
        'inference_utils.py',
        'prepare_dataset.py',
        'dataset.yaml',
        'training/memory',
        'training/no_memory',
    )
    absent = [path for path in expected_paths if not os.path.exists(path)]
    if not absent:
        print("✅ All required files found")
        return True
    print(f"❌ Missing files: {absent}")
    return False
|
||
|
|
|
||
|
|
def install_dependencies():
    """Install the packages listed in ``requirements.txt``.

    First runs pip through the interpreter executing this script, so the
    packages land in the same environment that passed the version check.
    If that fails, falls back to whatever ``pip3`` is on PATH.

    Returns:
        True if either installation attempt succeeded, False otherwise.
    """
    # A bare ``pip`` on PATH can belong to a different interpreter or venv
    # than sys.executable; ``-m pip`` pins the install to this one. The
    # double quotes keep an interpreter path containing spaces intact when
    # the string is run through the shell.
    pip_command = f'"{sys.executable}" -m pip install -r requirements.txt'
    if run_command(pip_command, "Installing dependencies"):
        return True
    print(" Try using: pip3 install -r requirements.txt")
    return run_command("pip3 install -r requirements.txt", "Installing dependencies with pip3")
|
||
|
|
|
||
|
|
def prepare_dataset():
    """Create the train/val dataset layout unless it already exists.

    Skips work when both ``training/train/images`` and
    ``training/val/images`` exist. Otherwise runs ``prepare_dataset.py``.

    Returns:
        True if the layout already existed or the script succeeded,
        False otherwise.
    """
    if os.path.exists('training/train/images') and os.path.exists('training/val/images'):
        print("✅ Dataset already prepared")
        return True

    # Use the interpreter running this script rather than whatever
    # ``python3`` resolves to on PATH. That name is often absent on Windows
    # and may point at an environment without the installed dependencies.
    return run_command(f'"{sys.executable}" prepare_dataset.py', "Preparing dataset structure")
|
||
|
|
|
||
|
|
def train_model(epochs=50, batch=8):
    """Train the YOLOv8 model unless trained weights already exist.

    Args:
        epochs: Value passed to ``train.py --epochs``. Defaults to 50.
        batch: Value passed to ``train.py --batch``. Defaults to 8.

    Returns:
        True if weights already exist or training succeeded, False otherwise.
    """
    # NOTE(review): assumed to be where train.py writes its best weights;
    # confirm against train.py.
    model_path = 'runs/detect/memory_module_detection/weights/best.pt'
    if os.path.exists(model_path):
        print("✅ Model already trained")
        return True

    print("🤖 Training YOLOv8 model...")
    print(" This may take 5-60 minutes depending on your hardware...")
    # int() guarantees only plain numbers are interpolated into the shell
    # command. sys.executable avoids relying on a ``python3`` on PATH.
    command = f'"{sys.executable}" train.py --epochs {int(epochs)} --batch {int(batch)}'
    return run_command(command, "Training YOLOv8 model")
|
||
|
|
|
||
|
|
def test_setup():
    """Run the project's ``test_api.py`` script as a smoke test.

    NOTE(review): this is not part of main()'s step list. main()'s printed
    next steps say test_api.py should be run while the API (main.py) is
    running in another terminal, so this likely fails unless the API is
    already serving — confirm before wiring it into the setup steps.

    Returns:
        True if ``test_api.py`` exited successfully, False otherwise.
    """
    print("🧪 Testing setup...")
    # Use the running interpreter instead of relying on ``python3`` on PATH.
    return run_command(f'"{sys.executable}" test_api.py', "Running API tests")
|
||
|
|
|
||
|
|
def main():
    """Drive the end-to-end project setup.

    Runs the prerequisite checks (Python version, then project files). If
    both pass, runs the setup steps in order: install dependencies, prepare
    the dataset, and train the model. Stops at the first failure and prints
    usage hints on success.

    Returns:
        True when every check and step succeeded, False otherwise.
    """
    banner = "=" * 50
    print("🚀 Memory Module Detection Project Setup")
    print(banner)

    # all() evaluates lazily and in order, so the file check is skipped
    # entirely when the Python version is rejected.
    prerequisites = (check_python_version, check_files)
    if not all(check() for check in prerequisites):
        return False

    setup_steps = (
        ("Install Dependencies", install_dependencies),
        ("Prepare Dataset", prepare_dataset),
        ("Train Model", train_model),
    )
    for label, action in setup_steps:
        print(f"\n📋 Step: {label}")
        if action():
            continue
        print(f"❌ Setup failed at step: {label}")
        return False

    closing_lines = (
        "\n" + banner,
        "🎉 Setup completed successfully!",
        "\n📖 Next steps:",
        "1. Start the API:",
        " python3 main.py",
        "\n2. Test the API (in another terminal):",
        " python3 test_api.py",
        "\n3. Or test manually:",
        " curl http://localhost:5000/detect/hardcoded",
    )
    for line in closing_lines:
        print(line)

    return True
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Map main()'s True/False result onto a conventional process exit status.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)
|