更新yolo11训练流程来适应静态detect头和新的后处理模块
This commit is contained in:
12
train.py
12
train.py
@@ -1,16 +1,11 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# --- 引入你的模块 ---
|
||||
from dataset import YOLODataset
|
||||
from yolo11_standalone import YOLO11
|
||||
from yolo11_standalone import YOLO11, YOLOPostProcessor
|
||||
from loss import YOLOv8DetectionLoss
|
||||
|
||||
# --- 引入刚刚写的 Trainer ---
|
||||
from trainer import YOLO11Trainer
|
||||
|
||||
def run_training():
|
||||
# --- 1.全局配置 ---
|
||||
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
||||
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
||||
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
||||
@@ -21,7 +16,6 @@ def run_training():
|
||||
img_size = 640
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# --- 2. 准备数据 ---
|
||||
print("Loading Data...")
|
||||
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
||||
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
||||
@@ -31,7 +25,6 @@ def run_training():
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
|
||||
# --- 3. 初始化模型 ---
|
||||
print("Initializing Model...")
|
||||
model = YOLO11(nc=80, scale='s')
|
||||
# model.load_weights("yolo11s.pth")
|
||||
@@ -45,9 +38,9 @@ def run_training():
|
||||
}
|
||||
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
||||
|
||||
# --- 5. 初始化 Trainer 并开始训练 ---
|
||||
trainer = YOLO11Trainer(
|
||||
model=model,
|
||||
post_std=YOLOPostProcessor(model.model[-1], use_segmentation=False),
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
loss_fn=loss_fn,
|
||||
@@ -60,5 +53,4 @@ def run_training():
|
||||
trainer.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Windows下多进程dataloader需要这个保护
|
||||
run_training()
|
||||
Reference in New Issue
Block a user