import math
import copy
import time
import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy

# Configure logging
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")


# ==============================================================================
# Helper Class: Model EMA (Exponential Moving Average)
# ==============================================================================
class ModelEMA:
    """Keep an exponential moving average (EMA) of a model's state dict.

    Adapted from the Ultralytics implementation. The averaged copy lives in
    ``self.ema`` and is what ``YOLO11Trainer.validate`` evaluates.

    Args:
        model: Module whose state is tracked; it is deep-copied, not modified.
        decay: Asymptotic EMA decay factor.
        tau: Time constant (in updates) of the decay ramp-up.
        updates: Number of EMA updates already performed (for resuming).
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Independent copy in eval mode; it is never trained directly.
        self.ema = copy.deepcopy(model).eval()
        self.updates = updates
        # Exponential ramp of the decay: starts near 0 and approaches `decay`,
        # so the average follows the model closely during early updates.
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy in place."""
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            # Only floating-point entries are averaged; other entries
            # (e.g. integer buffers) keep the value they had at copy time.
            if k in msd and v.dtype.is_floating_point:
                v *= d
                v += (1 - d) * msd[k].to(v.device)


# ==============================================================================
# Main Trainer Class
# ==============================================================================
class YOLO11Trainer:
    """Training / validation loop with SGD, cosine LR schedule, AMP and EMA.

    Args:
        model: Detection model to train.
        train_loader: Iterable yielding ``(imgs, targets, paths)`` batches.
        val_loader: Iterable yielding ``(imgs, targets, _)`` batches.
        loss_fn: Callable ``loss_fn(preds, target_batch)`` returning
            ``(loss, loss_items)``; ``loss_items`` is averaged into a
            3-element running mean logged as box / cls / dfl.
        epochs: Total number of epochs.
        lr0: Initial learning rate.
        lrf: Final learning-rate fraction (final LR is ``lr0 * lrf``).
        device: Requested device; falls back to CPU when CUDA is unavailable.
        save_dir: Directory for ``last.pt`` / ``best.pt`` (created if missing).
        warmup_epochs: Length of the LR / momentum warmup, in epochs
            (may be fractional; ``0`` disables warmup).
    """

    def __init__(self, model, train_loader, val_loader, loss_fn,
                 epochs=100, lr0=0.01, lrf=0.01,
                 device='cuda', save_dir='./runs/train', warmup_epochs=3.0):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.warmup_epochs = warmup_epochs
        self.start_epoch = 0

        # --- Optimizer Building ---
        g = [], [], []  # parameter groups: (decayed weights, norm weights, biases)
        # Every class exported by torch.nn whose name contains "Norm".
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
        for v in self.model.modules():
            for p_name, p in v.named_parameters(recurse=False):
                if p_name == 'bias':
                    g[2].append(p)  # biases
                elif isinstance(v, bn):
                    g[1].append(p)  # norm-layer weights (no decay)
                else:
                    g[0].append(p)  # weights (decay)

        # Group order matters: biases must be param group 0 because the warmup
        # in train_one_epoch starts group 0 from a different learning rate.
        self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
        self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
        self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
        LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")

        # --- Scheduler ---
        # Cosine factor going from 1 at epoch 0 to `lrf` at epoch `epochs`.
        self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

        # --- AMP & EMA ---
        # Mixed precision only makes sense on CUDA; enabling the CUDA scaler /
        # autocast after a CPU fallback only produces warnings.
        self.amp = self.device.type == 'cuda'
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp)
        self.ema = ModelEMA(self.model)

    def train_one_epoch(self, epoch):
        """Run one training epoch (warmup, forward, backward, EMA, logging).

        Args:
            epoch: Zero-based epoch index.
        """
        self.model.train()
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                    desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)

        mloss = torch.zeros(3, device=self.device)  # running mean of loss items
        self.optimizer.zero_grad()

        nb = len(self.train_loader)
        nw = nb * self.warmup_epochs  # number of warmup iterations
        # Read via getattr so the method also works on instances that were not
        # built through __init__ (falls back to the same device-based rule).
        use_amp = getattr(self, 'amp', self.device.type == 'cuda')

        for i, batch in pbar:
            ni = i + nb * epoch  # global iteration index
            imgs, targets, paths = batch
            imgs = imgs.to(self.device, non_blocking=True)
            targets = targets.to(self.device)

            # --- Warmup ---
            # Skipped entirely when warmup is disabled: with nw == 0 the
            # interpolation grid [0, 0] would be degenerate.
            if nw > 0 and ni <= nw:
                xp = [0, nw]
                for j, x in enumerate(self.optimizer.param_groups):
                    lr_target = x['initial_lr'] * self.lf(epoch)
                    # Group 0 (biases) ramps down from 0.1, the others ramp
                    # up from 0.0, all towards the scheduled learning rate.
                    x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xp, [0.8, 0.937])

            # --- Forward ---
            with torch.amp.autocast('cuda', enabled=use_amp):
                preds = self.model(imgs)
                # Targets are split by column: [batch_idx, cls, bbox...].
                target_batch = {
                    "batch_idx": targets[:, 0],
                    "cls": targets[:, 1],
                    "bboxes": targets[:, 2:],
                }
                loss, loss_items = self.loss_fn(preds, target_batch)

            # --- Backward ---
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # --- EMA ---
            self.ema.update(self.model)

            # --- Logging ---
            loss_items = loss_items.detach()
            mloss = (mloss * i + loss_items) / (i + 1)
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
            pbar.set_postfix({
                'mem': mem,
                'box': f"{mloss[0]:.4f}",
                'cls': f"{mloss[1]:.4f}",
                'dfl': f"{mloss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
            })

    def validate(self):
        """Evaluate the EMA model on ``val_loader`` and log detection metrics.

        Returns:
            mAP@0.5 over the validation set (``0.0`` when no prediction
            matched any ground-truth box).
        """
        model = self.ema.ema
        device = self.device

        # --- Metrics Config ---
        conf_thres = 0.001  # Low threshold for mAP calculation
        iou_thres = 0.7  # NMS IoU threshold
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU vector for mAP@0.5:0.95

        stats = []  # [(correct, conf, pred_cls, target_cls)], one entry per image

        LOGGER.info("\nValidating...")
        model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")

        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape

                # Inference
                preds = model(imgs)

                # NMS
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

                # Metrics Processing (one entry of `preds` per image)
                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        # No detections: record the missed targets only.
                        if nl:
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    # Predictions
                    predn = pred.clone()

                    # Ground Truth
                    if nl:
                        # Boxes are scaled by image width/height below, so they
                        # are presumably normalized xywh — confirm with the dataset.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]

                        # Match predictions to GT
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)

                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        # Concatenate across images BEFORE checking for true positives, so the
        # decision reflects the whole set rather than just the first image.
        stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
        if stats and stats[0].any():
            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
        return map50

    def _process_batch(self, detections, labels, iouv):
        """Return the correct-prediction matrix for one image.

        Args:
            detections: [N, 6] tensor (x1, y1, x2, y2, conf, cls).
            labels: [M, 5] tensor (cls, x1, y1, x2, y2).
            iouv: 1-D tensor of IoU thresholds.

        Returns:
            Bool tensor of shape [N, len(iouv)] on ``iouv.device``; entry
            ``[n, t]`` is True when detection ``n`` is matched to a label of
            the same class with IoU >= ``iouv[t]``.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])
        # Candidate pairs: IoU above the lowest threshold AND matching class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))

        if x[0].shape[0]:
            # Rows are [label_idx, detection_idx, iou].
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Sort by IoU (descending), then keep a single row per
                # detection and a single row per label (one-to-one matching).
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv

        return correct

    def train(self):
        """Run the full loop: train, step the scheduler, validate, checkpoint.

        Writes ``last.pt`` after every epoch and ``best.pt`` whenever the
        validation mAP@0.5 improves on the best value seen so far.
        """
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = 0.0

        for epoch in range(self.start_epoch, self.epochs):
            self.train_one_epoch(epoch)
            self.scheduler.step()

            map50 = self.validate()

            ckpt = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'ema': self.ema.ema.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            torch.save(ckpt, self.save_dir / 'last.pt')

            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, self.save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")

        LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")