import torch
from torch.utils.data import DataLoader

# --- Project modules ---
from dataset import YOLODataset
from yolo11_standalone import YOLO11
from loss import YOLOv8DetectionLoss

# --- Training loop wrapper ---
from trainer import YOLO11Trainer


def run_training():
    """Build the datasets, model, loss and trainer, then start training.

    All settings (dataset locations, epochs, batch size, image size) are
    hard-coded below. The device is "cuda" when ``torch.cuda.is_available()``
    and "cpu" otherwise. Results are written by the trainer under
    ``./my_yolo_result`` (the ``save_dir`` handed to ``YOLO11Trainer``).
    """
    # --- 1. Global configuration ---
    img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
    label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
    img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
    label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"

    epochs = 50
    batch_size = 36
    img_size = 640
    # Single source of truth for the class count shared by model and loss.
    num_classes = 80
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- 2. Prepare data ---
    print("Loading Data...")
    train_dataset = YOLODataset(
        img_dir_train, label_dir_train, img_size=img_size, is_train=True
    )
    val_dataset = YOLODataset(
        img_dir_val, label_dir_val, img_size=img_size, is_train=False
    )

    # Settings common to both loaders. Page-locked host memory only helps
    # host-to-GPU copies, so request it only when training on CUDA; asking
    # for it on a CPU-only machine just triggers a PyTorch warning.
    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=YOLODataset.collate_fn,
        num_workers=8,
        pin_memory=(device == "cuda"),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # --- 3. Initialize model ---
    print("Initializing Model...")
    model = YOLO11(nc=num_classes, scale='s')
    # NOTE: the original script left the next two calls disabled; they are
    # kept disabled here.
    #   model.load_weights("yolo11s.pth")
    #   model.to(device)
    # NOTE(review): presumably YOLO11Trainer moves the model to `device`
    # itself since it receives `device` below -- confirm in trainer.py.

    # --- 4. Loss function ---
    # NOTE(review): presumably the strides of the model's detection outputs
    # -- confirm against the YOLO11 head.
    strides = [8, 16, 32]
    # NOTE(review): presumably per-term loss weights keyed by term name
    # -- confirm in loss.py.
    hyp = {
        'box': 7.5,
        'cls': 0.5,
        'dfl': 1.5,
    }
    loss_fn = YOLOv8DetectionLoss(
        nc=num_classes, reg_max=16, stride=strides, hyp=hyp
    )

    # --- 5. Initialize the Trainer and start training ---
    trainer = YOLO11Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        epochs=epochs,
        lr0=0.01,
        device=device,
        save_dir='./my_yolo_result',
    )
    trainer.train()


if __name__ == "__main__":
    # On Windows, multi-process DataLoader workers require this guard.
    run_training()