第一次提交Yolo项目
This commit is contained in:
64
train.py
Normal file
64
train.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# --- 引入你的模块 ---
|
||||
from dataset import YOLODataset
|
||||
from yolo11_standalone import YOLO11
|
||||
from loss import YOLOv8DetectionLoss
|
||||
|
||||
# --- 引入刚刚写的 Trainer ---
|
||||
from trainer import YOLO11Trainer
|
||||
|
||||
def run_training():
|
||||
# --- 1.全局配置 ---
|
||||
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
||||
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
||||
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
||||
label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
|
||||
|
||||
epochs = 50
|
||||
batch_size = 36
|
||||
img_size = 640
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# --- 2. 准备数据 ---
|
||||
print("Loading Data...")
|
||||
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
||||
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
|
||||
# --- 3. 初始化模型 ---
|
||||
print("Initializing Model...")
|
||||
model = YOLO11(nc=80, scale='s')
|
||||
# model.load_weights("yolo11s.pth")
|
||||
# model.to(device)
|
||||
|
||||
strides = [8, 16, 32]
|
||||
hyp = {
|
||||
'box': 7.5,
|
||||
'cls': 0.5,
|
||||
'dfl': 1.5
|
||||
}
|
||||
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
||||
|
||||
# --- 5. 初始化 Trainer 并开始训练 ---
|
||||
trainer = YOLO11Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
loss_fn=loss_fn,
|
||||
epochs=epochs,
|
||||
lr0=0.01,
|
||||
device=device,
|
||||
save_dir='./my_yolo_result'
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Windows下多进程dataloader需要这个保护
|
||||
run_training()
|
||||
Reference in New Issue
Block a user