update
This commit is contained in:
@@ -1,158 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
YOLOv8 Training Script for Memory Module Detection
|
||||
This script trains a YOLOv8 nano model to detect memory modules in motherboard images.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
def check_dataset_structure():
|
||||
"""Verify that the dataset structure is correct."""
|
||||
required_paths = [
|
||||
'training/train/images',
|
||||
'training/train/labels',
|
||||
'training/val/images',
|
||||
'training/val/labels',
|
||||
'dataset.yaml'
|
||||
]
|
||||
def train_model():
|
||||
# Load YOLOv8n (nano) for faster training with decent accuracy
|
||||
model = YOLO('yolov8n.pt')
|
||||
|
||||
for path in required_paths:
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"Required path not found: {path}")
|
||||
# Train with optimized parameters for speed and quality
|
||||
results = model.train(
|
||||
data='dataset.yaml',
|
||||
epochs=50, # Reduced number of epochs
|
||||
imgsz=640, # Standard image size for faster processing
|
||||
batch=8, # Smaller batch size for less memory usage
|
||||
name='memory_detector_fast',
|
||||
save=True,
|
||||
device='cpu',
|
||||
patience=15, # Shorter patience for earlier stopping
|
||||
save_period=5, # Save every 5 epochs
|
||||
verbose=True,
|
||||
|
||||
# Effective but lightweight augmentation
|
||||
degrees=5.0, # Less rotation for speed
|
||||
scale=0.5,
|
||||
translate=0.1,
|
||||
fliplr=0.5,
|
||||
mosaic=1.0, # Keep mosaic as it's very effective
|
||||
|
||||
# Speed-optimized optimization parameters
|
||||
lr0=0.01,
|
||||
lrf=0.01,
|
||||
momentum=0.937,
|
||||
weight_decay=0.0005,
|
||||
warmup_epochs=1.0, # Shorter warmup
|
||||
|
||||
# Performance parameters
|
||||
workers=0, # Fewer workers for CPU training
|
||||
cache='disk', # Changed to disk caching for deterministic results
|
||||
)
|
||||
|
||||
# Check if we have images and labels
|
||||
train_images = len([f for f in os.listdir('training/train/images') if f.endswith('.png')])
|
||||
train_labels = len([f for f in os.listdir('training/train/labels') if f.endswith('.txt')])
|
||||
val_images = len([f for f in os.listdir('training/val/images') if f.endswith('.png')])
|
||||
val_labels = len([f for f in os.listdir('training/val/labels') if f.endswith('.txt')])
|
||||
|
||||
print(f"Dataset structure verified:")
|
||||
print(f" Training: {train_images} images, {train_labels} labels")
|
||||
print(f" Validation: {val_images} images, {val_labels} labels")
|
||||
|
||||
return True
|
||||
# Save the trained model
|
||||
model.save('model/weights/best.pt')
|
||||
|
||||
def train_model(epochs=100, imgsz=640, batch_size=16, device='auto'):
|
||||
"""
|
||||
Train YOLOv8 nano model on memory module dataset.
|
||||
|
||||
Args:
|
||||
epochs (int): Number of training epochs
|
||||
imgsz (int): Image size for training
|
||||
batch_size (int): Batch size for training
|
||||
device (str): Device to use ('auto', 'cpu', 'cuda', or specific GPU id)
|
||||
"""
|
||||
|
||||
# Check dataset structure
|
||||
check_dataset_structure()
|
||||
|
||||
# Initialize YOLOv8 nano model
|
||||
print("Initializing YOLOv8 nano model...")
|
||||
model = YOLO('yolov8n.pt') # Load pretrained YOLOv8 nano model
|
||||
|
||||
# Check available device
|
||||
if device == 'auto':
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
print(f"Using device: {device}")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
if torch.cuda.is_available():
|
||||
print(f"CUDA device: {torch.cuda.get_device_name()}")
|
||||
|
||||
# Training configuration
|
||||
train_config = {
|
||||
'data': 'dataset.yaml',
|
||||
'epochs': epochs,
|
||||
'imgsz': imgsz,
|
||||
'batch': batch_size,
|
||||
'device': device,
|
||||
'project': 'runs/detect',
|
||||
'name': 'memory_module_detection',
|
||||
'save': True,
|
||||
'save_period': 10, # Save checkpoint every 10 epochs
|
||||
'cache': False, # Don't cache images (saves RAM)
|
||||
'workers': 4,
|
||||
'patience': 50, # Early stopping patience
|
||||
'optimizer': 'AdamW',
|
||||
'lr0': 0.01, # Initial learning rate
|
||||
'lrf': 0.01, # Final learning rate factor
|
||||
'momentum': 0.937,
|
||||
'weight_decay': 0.0005,
|
||||
'warmup_epochs': 3,
|
||||
'warmup_momentum': 0.8,
|
||||
'warmup_bias_lr': 0.1,
|
||||
'box': 7.5, # Box loss gain
|
||||
'cls': 0.5, # Class loss gain
|
||||
'dfl': 1.5, # DFL loss gain
|
||||
'pose': 12.0, # Pose loss gain
|
||||
'kobj': 1.0, # Keypoint obj loss gain
|
||||
'label_smoothing': 0.0,
|
||||
'nbs': 64, # Nominal batch size
|
||||
'hsv_h': 0.015, # Image HSV-Hue augmentation
|
||||
'hsv_s': 0.7, # Image HSV-Saturation augmentation
|
||||
'hsv_v': 0.4, # Image HSV-Value augmentation
|
||||
'degrees': 0.0, # Image rotation (+/- deg)
|
||||
'translate': 0.1, # Image translation (+/- fraction)
|
||||
'scale': 0.5, # Image scale (+/- gain)
|
||||
'shear': 0.0, # Image shear (+/- deg)
|
||||
'perspective': 0.0, # Image perspective (+/- fraction)
|
||||
'flipud': 0.0, # Image flip up-down (probability)
|
||||
'fliplr': 0.5, # Image flip left-right (probability)
|
||||
'mosaic': 1.0, # Image mosaic (probability)
|
||||
'mixup': 0.0, # Image mixup (probability)
|
||||
'copy_paste': 0.0, # Segment copy-paste (probability)
|
||||
}
|
||||
|
||||
print("Starting training...")
|
||||
print(f"Configuration: {train_config}")
|
||||
|
||||
# Train the model
|
||||
results = model.train(**train_config)
|
||||
|
||||
# Print training results
|
||||
print("\nTraining completed!")
|
||||
print(f"Best model saved at: runs/detect/memory_module_detection/weights/best.pt")
|
||||
print(f"Last model saved at: runs/detect/memory_module_detection/weights/last.pt")
|
||||
|
||||
return results
|
||||
|
||||
def validate_model(model_path='runs/detect/memory_module_detection/weights/best.pt'):
|
||||
"""Validate the trained model."""
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Model not found at {model_path}")
|
||||
return None
|
||||
|
||||
print(f"Validating model: {model_path}")
|
||||
model = YOLO(model_path)
|
||||
|
||||
# Run validation
|
||||
results = model.val(data='dataset.yaml')
|
||||
|
||||
print("Validation completed!")
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train YOLOv8 for memory module detection')
|
||||
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
||||
parser.add_argument('--imgsz', type=int, default=640, help='Image size for training')
|
||||
parser.add_argument('--batch', type=int, default=16, help='Batch size')
|
||||
parser.add_argument('--device', type=str, default='auto', help='Device to use (auto, cpu, cuda)')
|
||||
parser.add_argument('--validate', action='store_true', help='Only run validation')
|
||||
parser.add_argument('--model', type=str, default='runs/detect/memory_module_detection/weights/best.pt',
|
||||
help='Model path for validation')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.validate:
|
||||
validate_model(args.model)
|
||||
else:
|
||||
train_model(epochs=args.epochs, imgsz=args.imgsz, batch_size=args.batch, device=args.device)
|
||||
# Also run validation after training
|
||||
validate_model()
|
||||
if __name__ == '__main__':
|
||||
train_model()
|
||||
Reference in New Issue
Block a user