更新yolo11训练流程来适应静态detect头和新的后处理模块
This commit is contained in:
12
train.py
12
train.py
@@ -1,16 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
# --- 引入你的模块 ---
|
|
||||||
from dataset import YOLODataset
|
from dataset import YOLODataset
|
||||||
from yolo11_standalone import YOLO11
|
from yolo11_standalone import YOLO11, YOLOPostProcessor
|
||||||
from loss import YOLOv8DetectionLoss
|
from loss import YOLOv8DetectionLoss
|
||||||
|
|
||||||
# --- 引入刚刚写的 Trainer ---
|
|
||||||
from trainer import YOLO11Trainer
|
from trainer import YOLO11Trainer
|
||||||
|
|
||||||
def run_training():
|
def run_training():
|
||||||
# --- 1.全局配置 ---
|
|
||||||
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
||||||
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
||||||
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
||||||
@@ -21,7 +16,6 @@ def run_training():
|
|||||||
img_size = 640
|
img_size = 640
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
# --- 2. 准备数据 ---
|
|
||||||
print("Loading Data...")
|
print("Loading Data...")
|
||||||
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
||||||
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
||||||
@@ -31,7 +25,6 @@ def run_training():
|
|||||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
||||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||||
|
|
||||||
# --- 3. 初始化模型 ---
|
|
||||||
print("Initializing Model...")
|
print("Initializing Model...")
|
||||||
model = YOLO11(nc=80, scale='s')
|
model = YOLO11(nc=80, scale='s')
|
||||||
# model.load_weights("yolo11s.pth")
|
# model.load_weights("yolo11s.pth")
|
||||||
@@ -45,9 +38,9 @@ def run_training():
|
|||||||
}
|
}
|
||||||
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
||||||
|
|
||||||
# --- 5. 初始化 Trainer 并开始训练 ---
|
|
||||||
trainer = YOLO11Trainer(
|
trainer = YOLO11Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
|
post_std=YOLOPostProcessor(model.model[-1], use_segmentation=False),
|
||||||
train_loader=train_loader,
|
train_loader=train_loader,
|
||||||
val_loader=val_loader,
|
val_loader=val_loader,
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
@@ -60,5 +53,4 @@ def run_training():
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Windows下多进程dataloader需要这个保护
|
|
||||||
run_training()
|
run_training()
|
||||||
59
trainer.py
59
trainer.py
@@ -10,19 +10,13 @@ import torch.optim as optim
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
|
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(format="%(message)s", level=logging.INFO)
|
logging.basicConfig(format="%(message)s", level=logging.INFO)
|
||||||
LOGGER = logging.getLogger("YOLO_Trainer")
|
LOGGER = logging.getLogger("YOLO_Trainer")
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Helper Class: Model EMA (Exponential Moving Average)
|
|
||||||
# ==============================================================================
|
|
||||||
class ModelEMA:
|
class ModelEMA:
|
||||||
""" Updated Exponential Moving Average (EMA) from Ultralytics """
|
|
||||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||||
self.ema = copy.deepcopy(model).eval() # FP32 EMA
|
self.ema = copy.deepcopy(model).eval()
|
||||||
self.updates = updates
|
self.updates = updates
|
||||||
# decay exponential ramp (to help early epochs)
|
|
||||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
@@ -39,12 +33,10 @@ class ModelEMA:
|
|||||||
v *= d
|
v *= d
|
||||||
v += (1 - d) * tmp
|
v += (1 - d) * tmp
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# Main Trainer Class
|
|
||||||
# ==============================================================================
|
|
||||||
class YOLO11Trainer:
|
class YOLO11Trainer:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model,
|
model,
|
||||||
|
post_std,
|
||||||
train_loader,
|
train_loader,
|
||||||
val_loader,
|
val_loader,
|
||||||
loss_fn,
|
loss_fn,
|
||||||
@@ -57,6 +49,7 @@ class YOLO11Trainer:
|
|||||||
|
|
||||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||||
self.model = model.to(self.device)
|
self.model = model.to(self.device)
|
||||||
|
self.post_std = post_std
|
||||||
self.train_loader = train_loader
|
self.train_loader = train_loader
|
||||||
self.val_loader = val_loader
|
self.val_loader = val_loader
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
@@ -66,18 +59,17 @@ class YOLO11Trainer:
|
|||||||
self.warmup_epochs = warmup_epochs
|
self.warmup_epochs = warmup_epochs
|
||||||
self.start_epoch = 0
|
self.start_epoch = 0
|
||||||
|
|
||||||
# --- Optimizer Building ---
|
g = [], [], []
|
||||||
g = [], [], [] # optimizer parameter groups
|
|
||||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
|
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
|
||||||
|
|
||||||
for v in self.model.modules():
|
for v in self.model.modules():
|
||||||
for p_name, p in v.named_parameters(recurse=False):
|
for p_name, p in v.named_parameters(recurse=False):
|
||||||
if p_name == 'bias':
|
if p_name == 'bias':
|
||||||
g[2].append(p) # biases
|
g[2].append(p)
|
||||||
elif isinstance(v, bn):
|
elif isinstance(v, bn):
|
||||||
g[1].append(p) # bn weights (no decay)
|
g[1].append(p)
|
||||||
else:
|
else:
|
||||||
g[0].append(p) # weights (decay)
|
g[0].append(p)
|
||||||
|
|
||||||
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
|
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
|
||||||
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
|
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
|
||||||
@@ -85,11 +77,8 @@ class YOLO11Trainer:
|
|||||||
|
|
||||||
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
|
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
|
||||||
|
|
||||||
# --- Scheduler ---
|
|
||||||
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
|
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
|
||||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||||
|
|
||||||
# --- AMP & EMA ---
|
|
||||||
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
|
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
|
||||||
self.ema = ModelEMA(self.model)
|
self.ema = ModelEMA(self.model)
|
||||||
|
|
||||||
@@ -109,7 +98,6 @@ class YOLO11Trainer:
|
|||||||
imgs = imgs.to(self.device, non_blocking=True)
|
imgs = imgs.to(self.device, non_blocking=True)
|
||||||
targets = targets.to(self.device)
|
targets = targets.to(self.device)
|
||||||
|
|
||||||
# --- Warmup ---
|
|
||||||
if ni <= nb * self.warmup_epochs:
|
if ni <= nb * self.warmup_epochs:
|
||||||
xp = [0, nb * self.warmup_epochs]
|
xp = [0, nb * self.warmup_epochs]
|
||||||
for j, x in enumerate(self.optimizer.param_groups):
|
for j, x in enumerate(self.optimizer.param_groups):
|
||||||
@@ -118,7 +106,6 @@ class YOLO11Trainer:
|
|||||||
if 'momentum' in x:
|
if 'momentum' in x:
|
||||||
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
|
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
|
||||||
|
|
||||||
# --- Forward ---
|
|
||||||
with torch.amp.autocast('cuda', enabled=True):
|
with torch.amp.autocast('cuda', enabled=True):
|
||||||
preds = self.model(imgs)
|
preds = self.model(imgs)
|
||||||
|
|
||||||
@@ -130,7 +117,6 @@ class YOLO11Trainer:
|
|||||||
|
|
||||||
loss, loss_items = self.loss_fn(preds, target_batch)
|
loss, loss_items = self.loss_fn(preds, target_batch)
|
||||||
|
|
||||||
# --- Backward ---
|
|
||||||
self.scaler.scale(loss).backward()
|
self.scaler.scale(loss).backward()
|
||||||
self.scaler.unscale_(self.optimizer)
|
self.scaler.unscale_(self.optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
||||||
@@ -138,10 +124,8 @@ class YOLO11Trainer:
|
|||||||
self.scaler.update()
|
self.scaler.update()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# --- EMA ---
|
|
||||||
self.ema.update(self.model)
|
self.ema.update(self.model)
|
||||||
|
|
||||||
# --- Logging ---
|
|
||||||
loss_items = loss_items.detach()
|
loss_items = loss_items.detach()
|
||||||
mloss = (mloss * i + loss_items) / (i + 1)
|
mloss = (mloss * i + loss_items) / (i + 1)
|
||||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
|
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
|
||||||
@@ -158,13 +142,12 @@ class YOLO11Trainer:
|
|||||||
model = self.ema.ema
|
model = self.ema.ema
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
# --- Metrics Config ---
|
conf_thres = 0.001
|
||||||
conf_thres = 0.001 # Low threshold for mAP calculation
|
iou_thres = 0.7
|
||||||
iou_thres = 0.7 # NMS IoU threshold
|
iouv = torch.linspace(0.5, 0.95, 10, device=device)
|
||||||
iouv = torch.linspace(0.5, 0.95, 10, device=device) # IoU vector for mAP@0.5:0.95
|
|
||||||
|
|
||||||
loss_sum = torch.zeros(3, device=device)
|
loss_sum = torch.zeros(3, device=device)
|
||||||
stats = [] # [(correct, conf, pred_cls, target_cls)]
|
stats = []
|
||||||
|
|
||||||
LOGGER.info("\nValidating...")
|
LOGGER.info("\nValidating...")
|
||||||
|
|
||||||
@@ -178,13 +161,10 @@ class YOLO11Trainer:
|
|||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
_, _, height, width = imgs.shape
|
_, _, height, width = imgs.shape
|
||||||
|
|
||||||
# Inference
|
|
||||||
preds = model(imgs)
|
preds = model(imgs)
|
||||||
|
preds = self.post_std(preds)
|
||||||
# NMS
|
|
||||||
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
|
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
|
||||||
|
|
||||||
# Metrics Processing
|
|
||||||
for si, pred in enumerate(preds):
|
for si, pred in enumerate(preds):
|
||||||
labels = targets[targets[:, 0] == si]
|
labels = targets[targets[:, 0] == si]
|
||||||
nl = len(labels)
|
nl = len(labels)
|
||||||
@@ -196,17 +176,13 @@ class YOLO11Trainer:
|
|||||||
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
|
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Predictions
|
|
||||||
predn = pred.clone()
|
predn = pred.clone()
|
||||||
|
|
||||||
# Ground Truth
|
|
||||||
if nl:
|
if nl:
|
||||||
tbox = xywh2xyxy(labels[:, 2:6])
|
tbox = xywh2xyxy(labels[:, 2:6])
|
||||||
tbox[:, [0, 2]] *= width
|
tbox[:, [0, 2]] *= width
|
||||||
tbox[:, [1, 3]] *= height
|
tbox[:, [1, 3]] *= height
|
||||||
labelsn = torch.cat((labels[:, 1:2], tbox), 1) # [cls, x1, y1, x2, y2]
|
labelsn = torch.cat((labels[:, 1:2], tbox), 1)
|
||||||
|
|
||||||
# Match predictions to GT
|
|
||||||
correct = self._process_batch(predn, labelsn, iouv)
|
correct = self._process_batch(predn, labelsn, iouv)
|
||||||
else:
|
else:
|
||||||
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
|
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
|
||||||
@@ -215,9 +191,9 @@ class YOLO11Trainer:
|
|||||||
|
|
||||||
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
|
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
|
||||||
if len(stats) and stats[0][0].any():
|
if len(stats) and stats[0][0].any():
|
||||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
|
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
|
||||||
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
|
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
|
||||||
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
ap50, ap = ap[:, 0], ap.mean(1)
|
||||||
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
||||||
|
|
||||||
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
|
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
|
||||||
@@ -225,11 +201,6 @@ class YOLO11Trainer:
|
|||||||
return map50
|
return map50
|
||||||
|
|
||||||
def _process_batch(self, detections, labels, iouv):
|
def _process_batch(self, detections, labels, iouv):
|
||||||
"""
|
|
||||||
Return correct prediction matrix
|
|
||||||
detections: [N, 6] (x1, y1, x2, y2, conf, cls)
|
|
||||||
labels: [M, 5] (cls, x1, y1, x2, y2)
|
|
||||||
"""
|
|
||||||
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
|
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
|
||||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user