64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
# --- Import your own modules ---
|
|
from dataset import YOLODataset
|
|
from yolo11_standalone import YOLO11
|
|
from loss import YOLOv8DetectionLoss
|
|
|
|
# --- Import the Trainer that was just written ---
|
|
from trainer import YOLO11Trainer
|
|
|
|
def run_training(
    img_dir_train: str = "E:\\Datasets\\coco\\images\\train2017",
    label_dir_train: str = "E:\\Datasets\\coco\\labels\\train2017",
    img_dir_val: str = "E:\\Datasets\\coco\\images\\val2017",
    label_dir_val: str = "E:\\Datasets\\coco\\labels\\val2017",
    *,
    epochs: int = 50,
    batch_size: int = 36,
    img_size: int = 640,
    nc: int = 80,
    lr0: float = 0.01,
    num_workers: int = 8,
    save_dir: str = './my_yolo_result',
) -> None:
    """Build the datasets, model, loss and trainer, then run training.

    Calling ``run_training()`` with no arguments uses the same values the
    script previously hard-coded.

    Args:
        img_dir_train: Directory passed to ``YOLODataset`` as the training
            image directory.
        label_dir_train: Directory passed to ``YOLODataset`` as the training
            label directory.
        img_dir_val: Image directory for the validation ``YOLODataset``.
        label_dir_val: Label directory for the validation ``YOLODataset``.
        epochs: Forwarded to ``YOLO11Trainer(epochs=...)``.
        batch_size: Batch size used by both data loaders.
        img_size: Forwarded to ``YOLODataset(img_size=...)``.
        nc: Class count, passed to both ``YOLO11`` and
            ``YOLOv8DetectionLoss`` so the two cannot disagree.
        lr0: Forwarded to ``YOLO11Trainer(lr0=...)``.
        num_workers: Worker process count for both data loaders.
        save_dir: Forwarded to ``YOLO11Trainer(save_dir=...)``.
    """
    # --- 1. Global configuration ---
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- 2. Prepare the data ---
    print("Loading Data...")
    train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
    val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)

    # Settings shared by both loaders. Pinned host memory only helps when
    # batches are copied to a CUDA device; requesting it on a CPU-only machine
    # just makes PyTorch emit a warning, so it follows the selected device.
    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=YOLODataset.collate_fn,
        num_workers=num_workers,
        pin_memory=(device == "cuda"),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # --- 3. Initialise the model ---
    print("Initializing Model...")
    model = YOLO11(nc=nc, scale='s')
    # Left disabled, as in the original script:
    # model.load_weights("yolo11s.pth")
    # model.to(device)
    # NOTE(review): the model is not moved to `device` here; presumably
    # YOLO11Trainer does that with the `device` argument below -- confirm.

    # --- 4. Build the loss ---
    # NOTE(review): strides presumably match the model's detection heads, and
    # `hyp` presumably holds the per-term loss weights -- confirm against
    # YOLO11 / YOLOv8DetectionLoss.
    strides = [8, 16, 32]
    hyp = {
        'box': 7.5,
        'cls': 0.5,
        'dfl': 1.5,
    }
    loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=strides, hyp=hyp)

    # --- 5. Initialise the Trainer and start training ---
    trainer = YOLO11Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        epochs=epochs,
        lr0=lr0,
        device=device,
        save_dir=save_dir,
    )

    trainer.train()
|
|
|
|
if __name__ == "__main__":
    # This guard is required for multi-process DataLoader use on Windows.
    run_training()