更新 YOLO11 训练流程，以适配静态 Detect 头和新的后处理模块

This commit is contained in:
lhr
2025-12-30 17:29:36 +08:00
parent 553a63f521
commit f4b1f341fc
2 changed files with 18 additions and 55 deletions

View File

@@ -1,16 +1,11 @@
import torch
from torch.utils.data import DataLoader
# --- 引入你的模块 ---
from dataset import YOLODataset
from yolo11_standalone import YOLO11
from yolo11_standalone import YOLO11, YOLOPostProcessor
from loss import YOLOv8DetectionLoss
# --- 引入刚刚写的 Trainer ---
from trainer import YOLO11Trainer
def run_training():
# --- 1.全局配置 ---
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
@@ -21,7 +16,6 @@ def run_training():
img_size = 640
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- 2. 准备数据 ---
print("Loading Data...")
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
@@ -31,7 +25,6 @@ def run_training():
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
# --- 3. 初始化模型 ---
print("Initializing Model...")
model = YOLO11(nc=80, scale='s')
# model.load_weights("yolo11s.pth")
@@ -45,9 +38,9 @@ def run_training():
}
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
# --- 5. 初始化 Trainer 并开始训练 ---
trainer = YOLO11Trainer(
model=model,
post_std=YOLOPostProcessor(model.model[-1], use_segmentation=False),
train_loader=train_loader,
val_loader=val_loader,
loss_fn=loss_fn,
@@ -60,5 +53,4 @@ def run_training():
trainer.train()
if __name__ == "__main__":
# Windows下多进程dataloader需要这个保护
run_training()

View File

@@ -10,19 +10,13 @@ import torch.optim as optim
from tqdm import tqdm
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
# 配置日志
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")
# ==============================================================================
# Helper Class: Model EMA (Exponential Moving Average)
# ==============================================================================
class ModelEMA:
""" Updated Exponential Moving Average (EMA) from Ultralytics """
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
self.ema = copy.deepcopy(model).eval() # FP32 EMA
self.ema = copy.deepcopy(model).eval()
self.updates = updates
# decay exponential ramp (to help early epochs)
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
for p in self.ema.parameters():
p.requires_grad_(False)
@@ -39,12 +33,10 @@ class ModelEMA:
v *= d
v += (1 - d) * tmp
# ==============================================================================
# Main Trainer Class
# ==============================================================================
class YOLO11Trainer:
def __init__(self,
model,
post_std,
train_loader,
val_loader,
loss_fn,
@@ -57,6 +49,7 @@ class YOLO11Trainer:
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.model = model.to(self.device)
self.post_std = post_std
self.train_loader = train_loader
self.val_loader = val_loader
self.loss_fn = loss_fn
@@ -66,18 +59,17 @@ class YOLO11Trainer:
self.warmup_epochs = warmup_epochs
self.start_epoch = 0
# --- Optimizer Building ---
g = [], [], [] # optimizer parameter groups
g = [], [], []
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
for v in self.model.modules():
for p_name, p in v.named_parameters(recurse=False):
if p_name == 'bias':
g[2].append(p) # biases
g[2].append(p)
elif isinstance(v, bn):
g[1].append(p) # bn weights (no decay)
g[1].append(p)
else:
g[0].append(p) # weights (decay)
g[0].append(p)
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
@@ -85,11 +77,8 @@ class YOLO11Trainer:
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
# --- Scheduler ---
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
# --- AMP & EMA ---
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
self.ema = ModelEMA(self.model)
@@ -109,7 +98,6 @@ class YOLO11Trainer:
imgs = imgs.to(self.device, non_blocking=True)
targets = targets.to(self.device)
# --- Warmup ---
if ni <= nb * self.warmup_epochs:
xp = [0, nb * self.warmup_epochs]
for j, x in enumerate(self.optimizer.param_groups):
@@ -118,7 +106,6 @@ class YOLO11Trainer:
if 'momentum' in x:
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
# --- Forward ---
with torch.amp.autocast('cuda', enabled=True):
preds = self.model(imgs)
@@ -130,7 +117,6 @@ class YOLO11Trainer:
loss, loss_items = self.loss_fn(preds, target_batch)
# --- Backward ---
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
@@ -138,10 +124,8 @@ class YOLO11Trainer:
self.scaler.update()
self.optimizer.zero_grad()
# --- EMA ---
self.ema.update(self.model)
# --- Logging ---
loss_items = loss_items.detach()
mloss = (mloss * i + loss_items) / (i + 1)
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
@@ -158,13 +142,12 @@ class YOLO11Trainer:
model = self.ema.ema
device = self.device
# --- Metrics Config ---
conf_thres = 0.001 # Low threshold for mAP calculation
iou_thres = 0.7 # NMS IoU threshold
iouv = torch.linspace(0.5, 0.95, 10, device=device) # IoU vector for mAP@0.5:0.95
conf_thres = 0.001
iou_thres = 0.7
iouv = torch.linspace(0.5, 0.95, 10, device=device)
loss_sum = torch.zeros(3, device=device)
stats = [] # [(correct, conf, pred_cls, target_cls)]
stats = []
LOGGER.info("\nValidating...")
@@ -178,13 +161,10 @@ class YOLO11Trainer:
targets = targets.to(device)
_, _, height, width = imgs.shape
# Inference
preds = model(imgs)
# NMS
preds = self.post_std(preds)
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
# Metrics Processing
for si, pred in enumerate(preds):
labels = targets[targets[:, 0] == si]
nl = len(labels)
@@ -196,17 +176,13 @@ class YOLO11Trainer:
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
continue
# Predictions
predn = pred.clone()
# Ground Truth
if nl:
tbox = xywh2xyxy(labels[:, 2:6])
tbox[:, [0, 2]] *= width
tbox[:, [1, 3]] *= height
labelsn = torch.cat((labels[:, 1:2], tbox), 1) # [cls, x1, y1, x2, y2]
# Match predictions to GT
labelsn = torch.cat((labels[:, 1:2], tbox), 1)
correct = self._process_batch(predn, labelsn, iouv)
else:
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
@@ -215,9 +191,9 @@ class YOLO11Trainer:
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
if len(stats) and stats[0][0].any():
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
ap50, ap = ap[:, 0], ap.mean(1)
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
@@ -225,11 +201,6 @@ class YOLO11Trainer:
return map50
def _process_batch(self, detections, labels, iouv):
"""
Return correct prediction matrix
detections: [N, 6] (x1, y1, x2, y2, conf, cls)
labels: [M, 5] (cls, x1, y1, x2, y2)
"""
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
iou = box_iou(labels[:, 1:], detections[:, :4])