Update the YOLO11 training pipeline for the static Detect head and the new post-processing module

Author: lhr
Date: 2025-12-30 17:29:36 +08:00
parent 553a63f521
commit f4b1f341fc
2 changed files with 18 additions and 55 deletions
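In outline, the change wires the new post-processing module between the static Detect head and NMS: the training script builds a YOLOPostProcessor from the head (model.model[-1]) and hands it to the trainer, which applies it only during validation. A minimal sketch of that flow, assuming YOLOPostProcessor decodes the head's raw multi-scale outputs into boxes that NMS can consume (its internals are not part of this diff):

    # Sketch only; names come from the diff below, and the decode behaviour of
    # YOLOPostProcessor is assumed, not shown in this commit.
    import torch
    from yolo11_standalone import YOLO11, YOLOPostProcessor
    from metrics import non_max_suppression

    model = YOLO11(nc=80, scale='s').eval()
    post = YOLOPostProcessor(model.model[-1], use_segmentation=False)   # built from the static Detect head

    imgs = torch.zeros(1, 3, 640, 640)        # dummy batch
    raw = model(imgs)                         # training path: raw head outputs feed the loss
    decoded = post(raw)                       # validation path: decode before NMS
    dets = non_max_suppression(decoded, conf_thres=0.001, iou_thres=0.7)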

View File

@@ -1,16 +1,11 @@
 import torch
 from torch.utils.data import DataLoader
-# --- import your own modules ---
 from dataset import YOLODataset
-from yolo11_standalone import YOLO11
+from yolo11_standalone import YOLO11, YOLOPostProcessor
 from loss import YOLOv8DetectionLoss
-# --- import the Trainer we just wrote ---
 from trainer import YOLO11Trainer
 def run_training():
-    # --- 1. Global configuration ---
     img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
     label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
     img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
@@ -21,7 +16,6 @@ def run_training():
     img_size = 640
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    # --- 2. Prepare the data ---
     print("Loading Data...")
     train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
     val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
@@ -31,7 +25,6 @@ def run_training():
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
-    # --- 3. Initialize the model ---
     print("Initializing Model...")
     model = YOLO11(nc=80, scale='s')
     # model.load_weights("yolo11s.pth")
@@ -45,9 +38,9 @@ def run_training():
     }
     loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
-    # --- 5. Initialize the Trainer and start training ---
     trainer = YOLO11Trainer(
         model=model,
+        post_std=YOLOPostProcessor(model.model[-1], use_segmentation=False),
         train_loader=train_loader,
         val_loader=val_loader,
         loss_fn=loss_fn,
@@ -60,5 +53,4 @@ def run_training():
     trainer.train()
 if __name__ == "__main__":
-    # Windows needs this guard for the multi-process DataLoader
     run_training()
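The deleted comment above referred to a real constraint: with num_workers=8, the DataLoader starts worker processes, and on Windows (spawn start method) each worker re-imports the main module, so the entry point has to stay behind the __main__ guard. A small generic illustration, not code from this repo:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def main():
        ds = TensorDataset(torch.arange(100).float())
        # num_workers > 0 starts worker processes; on Windows they re-import this file
        loader = DataLoader(ds, batch_size=16, num_workers=2)
        print(sum(batch[0].numel() for batch in loader))   # 100

    if __name__ == "__main__":   # without this guard the workers would recurse into main()
        main()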

View File

@@ -10,19 +10,13 @@ import torch.optim as optim
 from tqdm import tqdm
 from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
-# Configure logging
 logging.basicConfig(format="%(message)s", level=logging.INFO)
 LOGGER = logging.getLogger("YOLO_Trainer")
-# ==============================================================================
-# Helper Class: Model EMA (Exponential Moving Average)
-# ==============================================================================
 class ModelEMA:
-    """ Updated Exponential Moving Average (EMA) from Ultralytics """
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
+        self.ema = copy.deepcopy(model).eval()
         self.updates = updates
-        # decay exponential ramp (to help early epochs)
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))
         for p in self.ema.parameters():
             p.requires_grad_(False)
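The decay lambda kept above ramps the effective EMA decay from near zero toward its nominal value as updates accumulate, so early EMA weights still track the fast-moving live model. Plugging the defaults shown (decay=0.9999, tau=2000) into the formula:

    import math

    decay, tau = 0.9999, 2000
    d = lambda x: decay * (1 - math.exp(-x / tau))
    print(round(d(100), 4))     # 0.0488 -> EMA mostly copies the live model early on
    print(round(d(2000), 4))    # 0.6321 -> ramping up after roughly tau updates
    print(round(d(20000), 4))   # 0.9999 -> effectively the nominal decay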
@@ -39,12 +33,10 @@ class ModelEMA:
                 v *= d
                 v += (1 - d) * tmp
-# ==============================================================================
-# Main Trainer Class
-# ==============================================================================
 class YOLO11Trainer:
     def __init__(self,
                  model,
+                 post_std,
                  train_loader,
                  val_loader,
                  loss_fn,
@@ -57,6 +49,7 @@ class YOLO11Trainer:
         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
         self.model = model.to(self.device)
+        self.post_std = post_std
         self.train_loader = train_loader
         self.val_loader = val_loader
         self.loss_fn = loss_fn
@@ -66,18 +59,17 @@ class YOLO11Trainer:
         self.warmup_epochs = warmup_epochs
         self.start_epoch = 0
-        # --- Optimizer Building ---
-        g = [], [], []  # optimizer parameter groups
+        g = [], [], []
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
         for v in self.model.modules():
             for p_name, p in v.named_parameters(recurse=False):
                 if p_name == 'bias':
-                    g[2].append(p)  # biases
+                    g[2].append(p)
                 elif isinstance(v, bn):
-                    g[1].append(p)  # bn weights (no decay)
+                    g[1].append(p)
                 else:
-                    g[0].append(p)  # weights (decay)
+                    g[0].append(p)
         self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
         self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
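The grouping kept above is the usual Ultralytics-style split: g[2] biases, g[1] normalization weights, g[0] everything else, with weight decay applied only to g[0]. The hunk is cut off before g[1] is registered, so a third add_param_group call with weight_decay=0.0 is assumed to follow. A self-contained sketch of the pattern on a toy model:

    import torch.nn as nn
    import torch.optim as optim

    # Sketch of the three-group split (assumption: the file registers g[1]
    # with weight_decay=0.0 on a line past the hunk boundary).
    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    g = [], [], []
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
    for v in m.modules():
        for name, p in v.named_parameters(recurse=False):
            if name == 'bias':
                g[2].append(p)          # biases, no decay
            elif isinstance(v, bn):
                g[1].append(p)          # norm weights, no decay
            else:
                g[0].append(p)          # conv/linear weights, decayed
    opt = optim.SGD(g[2], lr=0.01, momentum=0.937, nesterov=True)
    opt.add_param_group({'params': g[0], 'weight_decay': 0.0005})
    opt.add_param_group({'params': g[1], 'weight_decay': 0.0})
    print([len(pg['params']) for pg in opt.param_groups])   # [2, 1, 1]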
@@ -85,11 +77,8 @@ class YOLO11Trainer:
         LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
-        # --- Scheduler ---
         self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
-        # --- AMP & EMA ---
         self.scaler = torch.amp.GradScaler('cuda', enabled=True)
         self.ema = ModelEMA(self.model)
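The lr lambda kept above is a one-cycle cosine ramp from 1 down to lrf over self.epochs. Evaluating it with illustrative values (epochs=100, lrf=0.01; the commit does not show the actual settings) confirms the endpoints:

    import math

    epochs, lrf = 100, 0.01    # illustrative values only
    lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (lrf - 1) + 1
    print(lf(0))     # 1.0     -> scheduler starts at the full base lr
    print(lf(50))    # ~0.505  -> midpoint of the cosine
    print(lf(100))   # ~0.01   -> final lr is lr0 * lrf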
@@ -109,7 +98,6 @@ class YOLO11Trainer:
             imgs = imgs.to(self.device, non_blocking=True)
             targets = targets.to(self.device)
-            # --- Warmup ---
             if ni <= nb * self.warmup_epochs:
                 xp = [0, nb * self.warmup_epochs]
                 for j, x in enumerate(self.optimizer.param_groups):
@@ -118,7 +106,6 @@ class YOLO11Trainer:
                     if 'momentum' in x:
                         x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
-            # --- Forward ---
             with torch.amp.autocast('cuda', enabled=True):
                 preds = self.model(imgs)
@@ -130,7 +117,6 @@ class YOLO11Trainer:
                 loss, loss_items = self.loss_fn(preds, target_batch)
-            # --- Backward ---
             self.scaler.scale(loss).backward()
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
@@ -138,10 +124,8 @@ class YOLO11Trainer:
             self.scaler.update()
             self.optimizer.zero_grad()
-            # --- EMA ---
             self.ema.update(self.model)
-            # --- Logging ---
             loss_items = loss_items.detach()
             mloss = (mloss * i + loss_items) / (i + 1)
             mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
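Across the two hunks above, the AMP step order matters: backward runs on the scaled loss, the gradients are unscaled before clipping so clip_grad_norm_ sees their true magnitudes, and scaler.step() skips the update when an overflow is detected. A condensed, runnable version of the same generic torch.amp pattern on a toy model (requires CUDA; not new code in this commit):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 1).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.amp.GradScaler('cuda', enabled=True)

    x, y = torch.randn(4, 8, device='cuda'), torch.randn(4, 1, device='cuda')
    with torch.amp.autocast('cuda', enabled=True):
        loss = nn.functional.mse_loss(model(x), y)        # forward in mixed precision

    scaler.scale(loss).backward()                         # backward on the scaled loss
    scaler.unscale_(opt)                                  # unscale so clipping sees real grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    scaler.step(opt)                                      # skipped automatically on inf/nan grads
    scaler.update()
    opt.zero_grad()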
@@ -158,13 +142,12 @@ class YOLO11Trainer:
         model = self.ema.ema
         device = self.device
-        # --- Metrics Config ---
-        conf_thres = 0.001  # Low threshold for mAP calculation
-        iou_thres = 0.7  # NMS IoU threshold
-        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU vector for mAP@0.5:0.95
+        conf_thres = 0.001
+        iou_thres = 0.7
+        iouv = torch.linspace(0.5, 0.95, 10, device=device)
         loss_sum = torch.zeros(3, device=device)
-        stats = []  # [(correct, conf, pred_cls, target_cls)]
+        stats = []
         LOGGER.info("\nValidating...")
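The iouv vector kept above holds the ten COCO-style IoU thresholds that mAP50-95 averages over:

    import torch

    iouv = torch.linspace(0.5, 0.95, 10)
    print([round(v, 2) for v in iouv.tolist()])   # [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]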
@@ -178,13 +161,10 @@ class YOLO11Trainer:
             targets = targets.to(device)
             _, _, height, width = imgs.shape
-            # Inference
             preds = model(imgs)
+            preds = self.post_std(preds)
-            # NMS
             preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
-            # Metrics Processing
             for si, pred in enumerate(preds):
                 labels = targets[targets[:, 0] == si]
                 nl = len(labels)
@@ -196,17 +176,13 @@ class YOLO11Trainer:
                                   torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                     continue
-                # Predictions
                 predn = pred.clone()
-                # Ground Truth
                 if nl:
                     tbox = xywh2xyxy(labels[:, 2:6])
                     tbox[:, [0, 2]] *= width
                     tbox[:, [1, 3]] *= height
-                    labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]
-                    # Match predictions to GT
+                    labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                     correct = self._process_batch(predn, labelsn, iouv)
                 else:
                     correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
@@ -215,9 +191,9 @@ class YOLO11Trainer:
         mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
         if len(stats) and stats[0][0].any():
-            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
+            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
             tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
-            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
+            ap50, ap = ap[:, 0], ap.mean(1)
             mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
         LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
@@ -225,11 +201,6 @@ class YOLO11Trainer:
         return map50
     def _process_batch(self, detections, labels, iouv):
-        """
-        Return correct prediction matrix
-        detections: [N, 6] (x1, y1, x2, y2, conf, cls)
-        labels: [M, 5] (cls, x1, y1, x2, y2)
-        """
         correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
         iou = box_iou(labels[:, 1:], detections[:, :4])
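_process_batch is cut off here; only the correct matrix allocation and the IoU computation are visible. For context, a sketch of how this helper typically continues in YOLOv5/Ultralytics-style validators (class match plus IoU threshold, one-to-one matching); the remaining lines of the actual file may differ. box_iou comes from metrics, and numpy is assumed available as np, which the trainer already uses:

    import numpy as np
    import torch
    from metrics import box_iou

    def _process_batch(detections, labels, iouv):
        # The first two lines mirror the diff; the rest is an assumed completion.
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])                  # [M, N] IoU of GT vs predictions
        correct_class = labels[:, 0:1] == detections[:, 5]               # [M, N] class agreement
        for i in range(len(iouv)):
            x = torch.where((iou >= iouv[i]) & correct_class)            # candidate (gt, pred) pairs
            if x[0].shape[0]:
                matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
                if x[0].shape[0] > 1:
                    matches = matches[matches[:, 2].argsort()[::-1]]                   # best IoU first
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # one GT per prediction
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # one prediction per GT
                correct[torch.from_numpy(matches[:, 1]).long(), i] = True
        return correct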