Files
Yolo-standalone/train.py
2025-12-27 02:14:11 +08:00

64 lines
2.0 KiB
Python

import torch
from torch.utils.data import DataLoader
# --- 引入你的模块 ---
from dataset import YOLODataset
from yolo11_standalone import YOLO11
from loss import YOLOv8DetectionLoss
# --- 引入刚刚写的 Trainer ---
from trainer import YOLO11Trainer
def _make_loader(dataset, *, batch_size, shuffle, num_workers, pin_memory):
    """Wrap *dataset* in a DataLoader that batches with YOLODataset.collate_fn."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=YOLODataset.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def run_training(
    *,
    img_dir_train="E:\\Datasets\\coco\\images\\train2017",
    label_dir_train="E:\\Datasets\\coco\\labels\\train2017",
    img_dir_val="E:\\Datasets\\coco\\images\\val2017",
    label_dir_val="E:\\Datasets\\coco\\labels\\val2017",
    epochs=50,
    batch_size=36,
    img_size=640,
    num_classes=80,
    num_workers=8,
    lr0=0.01,
    save_dir="./my_yolo_result",
    device=None,
    weights=None,
):
    """Build datasets, a YOLO11 ('s' scale) model and its loss, then train.

    Every argument is keyword-only and defaults to the value this script
    used to hard-code, so ``run_training()`` with no arguments trains with
    the same configuration as before.

    Args:
        img_dir_train: Directory of training images.
        label_dir_train: Directory of training labels.
        img_dir_val: Directory of validation images.
        label_dir_val: Directory of validation labels.
        epochs: Number of epochs, forwarded to YOLO11Trainer.
        batch_size: Batch size for both the train and val loaders.
        img_size: Image size passed to YOLODataset.
        num_classes: Class count shared by the model and the loss, so the
            two can never disagree.
        num_workers: DataLoader worker processes per loader.
        lr0: Initial learning rate, forwarded to YOLO11Trainer.
        save_dir: Output directory, forwarded to YOLO11Trainer.
        device: Target device string (e.g. "cuda", "cpu"); None picks
            "cuda" when available, otherwise "cpu".
        weights: Optional path passed to ``model.load_weights`` before
            training; None (default) trains from the freshly built model.
    """
    # --- 1. Global configuration ---
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pinned host memory only speeds up host->GPU copies; on CPU it is
    # ignored (with a warning), so request it only for CUDA targets.
    pin_memory = torch.device(device).type == "cuda"

    # --- 2. Data ---
    print("Loading Data...")
    train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
    val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
    train_loader = _make_loader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = _make_loader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # --- 3. Model ---
    print("Initializing Model...")
    model = YOLO11(nc=num_classes, scale='s')
    if weights is not None:
        model.load_weights(weights)
    # NOTE(review): the model is not moved to `device` here; presumably
    # YOLO11Trainer does that since it receives `device` -- confirm.

    # --- 4. Loss ---
    # Assumed to match the output strides of the model's detection head -- confirm.
    strides = [8, 16, 32]
    # Presumably per-term gains (box / cls / dfl) consumed by YOLOv8DetectionLoss.
    hyp = {
        'box': 7.5,
        'cls': 0.5,
        'dfl': 1.5,
    }
    loss_fn = YOLOv8DetectionLoss(nc=num_classes, reg_max=16, stride=strides, hyp=hyp)

    # --- 5. Trainer ---
    trainer = YOLO11Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        epochs=epochs,
        lr0=lr0,
        device=device,
        save_dir=save_dir,
    )
    trainer.train()
if __name__ == "__main__":
    # Required guard on Windows: DataLoader worker processes (num_workers > 0)
    # are started with "spawn", which re-imports this module in each child.
    # Without the guard every worker would try to start training again.
    run_training()