Files
recycling-project-solutions/train.py
T

159 lines
5.8 KiB
Python
Raw Normal View History

2025-07-11 20:07:36 +01:00
#!/usr/bin/env python3
"""
YOLOv8 Training Script for Memory Module Detection
This script trains a YOLOv8 nano model to detect memory modules in motherboard images.
"""
import os
import sys
from pathlib import Path
import yaml
from ultralytics import YOLO
import torch
def check_dataset_structure():
    """Verify that the dataset layout expected by this script exists.

    Checks, relative to the current working directory, that the train/val
    ``images`` and ``labels`` directories and ``dataset.yaml`` are present,
    then prints how many image and label files each split contains.

    Returns:
        bool: Always ``True`` when every required path exists.

    Raises:
        FileNotFoundError: If any required path is missing.
    """
    required_paths = [
        'training/train/images',
        'training/train/labels',
        'training/val/images',
        'training/val/labels',
        'dataset.yaml',
    ]
    for path in required_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Required path not found: {path}")

    # Count every common image extension, case-insensitively; counting only
    # lower-case '.png' reported "0 images" for JPEG or '.PNG' datasets.
    image_suffixes = (
        '.bmp', '.dng', '.jpeg', '.jpg', '.mpo',
        '.pfm', '.png', '.tif', '.tiff', '.webp',
    )

    def count_images(directory):
        """Count files in *directory* that have a known image extension."""
        return sum(
            name.lower().endswith(image_suffixes)
            for name in os.listdir(directory)
        )

    def count_labels(directory):
        """Count the '.txt' label files in *directory*."""
        return sum(name.endswith('.txt') for name in os.listdir(directory))

    train_images = count_images('training/train/images')
    train_labels = count_labels('training/train/labels')
    val_images = count_images('training/val/images')
    val_labels = count_labels('training/val/labels')

    print("Dataset structure verified:")
    print(f" Training: {train_images} images, {train_labels} labels")
    print(f" Validation: {val_images} images, {val_labels} labels")
    return True
def train_model(epochs=100, imgsz=640, batch_size=16, device='auto'):
    """
    Train YOLOv8 nano model on memory module dataset.

    Verifies the dataset layout, loads the pretrained ``yolov8n.pt`` weights,
    resolves the compute device and runs ``model.train`` with the
    configuration built below.

    Args:
        epochs (int): Number of training epochs
        imgsz (int): Image size for training
        batch_size (int): Batch size for training
        device (str): Device to use ('auto', 'cpu', 'cuda', or specific GPU id).
            'auto' resolves to 'cuda' when CUDA is available, otherwise 'cpu'.

    Returns:
        The object returned by ``model.train`` (passed through unchanged).

    Raises:
        FileNotFoundError: If the dataset layout check fails.
    """
    # Fail early, before loading any weights, if the dataset is incomplete.
    check_dataset_structure()

    # Initialize YOLOv8 nano model
    print("Initializing YOLOv8 nano model...")
    model = YOLO('yolov8n.pt')  # Load pretrained YOLOv8 nano model

    # Check available device
    cuda_available = torch.cuda.is_available()
    if device == 'auto':
        device = 'cuda' if cuda_available else 'cpu'
    print(f"Using device: {device}")
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name()}")

    # Training configuration
    train_config = {
        'data': 'dataset.yaml',
        'epochs': epochs,
        'imgsz': imgsz,
        'batch': batch_size,
        'device': device,
        'project': 'runs/detect',
        'name': 'memory_module_detection',
        'save': True,
        'save_period': 10,  # Save checkpoint every 10 epochs
        'cache': False,  # Don't cache images (saves RAM)
        'workers': 4,
        'patience': 50,  # Early stopping patience
        'optimizer': 'AdamW',
        'lr0': 0.01,  # Initial learning rate
        'lrf': 0.01,  # Final learning rate factor
        'momentum': 0.937,
        'weight_decay': 0.0005,
        'warmup_epochs': 3,
        'warmup_momentum': 0.8,
        'warmup_bias_lr': 0.1,
        'box': 7.5,  # Box loss gain
        'cls': 0.5,  # Class loss gain
        'dfl': 1.5,  # DFL loss gain
        'pose': 12.0,  # Pose loss gain
        'kobj': 1.0,  # Keypoint obj loss gain
        'label_smoothing': 0.0,
        'nbs': 64,  # Nominal batch size
        'hsv_h': 0.015,  # Image HSV-Hue augmentation
        'hsv_s': 0.7,  # Image HSV-Saturation augmentation
        'hsv_v': 0.4,  # Image HSV-Value augmentation
        'degrees': 0.0,  # Image rotation (+/- deg)
        'translate': 0.1,  # Image translation (+/- fraction)
        'scale': 0.5,  # Image scale (+/- gain)
        'shear': 0.0,  # Image shear (+/- deg)
        'perspective': 0.0,  # Image perspective (+/- fraction)
        'flipud': 0.0,  # Image flip up-down (probability)
        'fliplr': 0.5,  # Image flip left-right (probability)
        'mosaic': 1.0,  # Image mosaic (probability)
        'mixup': 0.0,  # Image mixup (probability)
        'copy_paste': 0.0,  # Segment copy-paste (probability)
    }
    print("Starting training...")
    print(f"Configuration: {train_config}")

    # Train the model
    results = model.train(**train_config)

    # Report where the weights actually landed. The run directory may be
    # auto-incremented when 'memory_module_detection' already exists, so a
    # hard-coded path can point at a stale run. NOTE(review): relies on the
    # trainer exposing `best`/`last`; falls back to the default location.
    trainer = getattr(model, 'trainer', None)
    best_path = (getattr(trainer, 'best', None)
                 or 'runs/detect/memory_module_detection/weights/best.pt')
    last_path = (getattr(trainer, 'last', None)
                 or 'runs/detect/memory_module_detection/weights/last.pt')

    # Print training results
    print("\nTraining completed!")
    print(f"Best model saved at: {best_path}")
    print(f"Last model saved at: {last_path}")
    return results
def validate_model(model_path='runs/detect/memory_module_detection/weights/best.pt'):
    """Run validation for a trained model checkpoint.

    Args:
        model_path (str): Path to the weights file to validate.

    Returns:
        The object returned by ``model.val`` for ``dataset.yaml``, or ``None``
        when ``model_path`` does not exist.
    """
    if os.path.exists(model_path):
        print(f"Validating model: {model_path}")
        # Load the checkpoint and evaluate it against the dataset config.
        metrics = YOLO(model_path).val(data='dataset.yaml')
        print("Validation completed!")
        return metrics

    # Missing weights are reported rather than raised.
    print(f"Model not found at {model_path}")
    return None
if __name__ == "__main__":
    # Command-line entry point: train (then validate) by default, or only
    # validate an existing checkpoint when --validate is given.
    import argparse

    parser = argparse.ArgumentParser(description='Train YOLOv8 for memory module detection')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--imgsz', type=int, default=640, help='Image size for training')
    parser.add_argument('--batch', type=int, default=16, help='Batch size')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (auto, cpu, cuda)')
    parser.add_argument('--validate', action='store_true', help='Only run validation')
    parser.add_argument('--model', type=str, default='runs/detect/memory_module_detection/weights/best.pt',
                        help='Model path for validation')
    args = parser.parse_args()

    if args.validate:
        validate_model(args.model)
    else:
        results = train_model(epochs=args.epochs, imgsz=args.imgsz, batch_size=args.batch, device=args.device)

        # Also run validation after training. Prefer the weights from the run
        # that just finished: the run directory may have been auto-incremented,
        # in which case the default path would validate an older run.
        # NOTE(review): assumes the returned results expose `save_dir`; if not,
        # or if no best.pt is found there, fall back to the default path.
        save_dir = getattr(results, 'save_dir', None)
        trained_weights = None
        if save_dir:
            candidate = os.path.join(str(save_dir), 'weights', 'best.pt')
            if os.path.exists(candidate):
                trained_weights = candidate

        if trained_weights is not None:
            validate_model(trained_weights)
        else:
            validate_model()