更新yolo11训练流程来适应静态detect头和新的后处理模块

This commit is contained in:
lhr
2025-12-30 17:29:36 +08:00
parent 553a63f521
commit f4b1f341fc
2 changed files with 18 additions and 55 deletions

View File

@@ -1,16 +1,11 @@
import torch
from torch.utils.data import DataLoader
# --- 引入你的模块 ---
from dataset import YOLODataset
from yolo11_standalone import YOLO11
from yolo11_standalone import YOLO11, YOLOPostProcessor
from loss import YOLOv8DetectionLoss
# --- 引入刚刚写的 Trainer ---
from trainer import YOLO11Trainer
def run_training():
# --- 1.全局配置 ---
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
@@ -21,7 +16,6 @@ def run_training():
img_size = 640
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- 2. 准备数据 ---
print("Loading Data...")
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
@@ -31,7 +25,6 @@ def run_training():
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
# --- 3. 初始化模型 ---
print("Initializing Model...")
model = YOLO11(nc=80, scale='s')
# model.load_weights("yolo11s.pth")
@@ -45,9 +38,9 @@ def run_training():
}
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
# --- 5. 初始化 Trainer 并开始训练 ---
trainer = YOLO11Trainer(
model=model,
post_std=YOLOPostProcessor(model.model[-1], use_segmentation=False),
train_loader=train_loader,
val_loader=val_loader,
loss_fn=loss_fn,
@@ -60,5 +53,4 @@ def run_training():
trainer.train()
if __name__ == "__main__":
# Windows下多进程dataloader需要这个保护
run_training()