"""Training loop, EMA helper and mAP@0.5 / mAP@0.5:0.95 validation for a YOLO11 detector."""

import copy
import logging
import math
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy

logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")


class ModelEMA:
    """Exponential moving average of a model's ``state_dict``.

    The EMA copy is created in eval mode with gradients disabled. The effective
    decay ramps up as ``decay * (1 - exp(-updates / tau))``, so early updates
    follow the live model closely and later ones average over a long window.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        self.ema = copy.deepcopy(model).eval()
        self.updates = updates  # number of EMA updates performed so far
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the floating-point entries of ``model``'s state into the EMA copy in place."""
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()
        with torch.no_grad():
            for k, v in self.ema.state_dict().items():
                if k not in msd:
                    continue
                # Non-floating entries (integer counters etc.) are left untouched.
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].to(v.device)


class YOLO11Trainer:
    """SGD + cosine-LR trainer with linear warmup, AMP, gradient clipping and EMA.

    Args:
        model: detector ``nn.Module``; moved to ``device`` on construction.
        post_std: callable applied to raw EMA-model outputs before NMS during
            validation (presumably decodes head outputs into boxes — TODO confirm).
        train_loader, val_loader: iterables yielding ``(imgs, targets, paths)``;
            ``targets`` rows are ``[batch_idx, cls, box...]``.
        loss_fn: callable ``(preds, target_batch) -> (loss, loss_items)`` where
            ``loss_items`` holds the (box, cls, dfl) components.
        epochs: total number of epochs.
        lr0: initial learning rate.
        lrf: final lr as a fraction of ``lr0`` (cosine schedule).
        device: requested device; falls back to CPU when CUDA is unavailable.
        save_dir: directory for ``last.pt`` / ``best.pt`` (created if missing).
        warmup_epochs: length of the linear lr/momentum warmup, in epochs.
    """

    def __init__(self, model, post_std, train_loader, val_loader, loss_fn, epochs=100,
                 lr0=0.01, lrf=0.01, device='cuda', save_dir='./runs/train', warmup_epochs=3.0):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.post_std = post_std
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.warmup_epochs = warmup_epochs
        self.start_epoch = 0
        # CUDA autocast/GradScaler only apply on a CUDA device; keep AMP off on a
        # CPU fallback instead of relying on torch to warn and self-disable.
        self.amp = self.device.type == 'cuda'

        # Three parameter groups: biases (no decay), other weights (decay),
        # and normalization-layer weights (no decay).
        g = [], [], []  # (weights, norm weights, biases)
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
        for v in self.model.modules():
            for p_name, p in v.named_parameters(recurse=False):
                if p_name == 'bias':
                    g[2].append(p)
                elif isinstance(v, bn):
                    g[1].append(p)
                else:
                    g[0].append(p)
        # Group order matters: param_groups[0] (biases) gets the special warmup start lr.
        self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
        self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
        self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
        LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")

        # Cosine lr multiplier: 1 at epoch 0, lrf at epoch == self.epochs.
        self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp)
        self.ema = ModelEMA(self.model)

    def train_one_epoch(self, epoch):
        """Run one training epoch: warmup, AMP forward/backward, clipping, EMA update."""
        self.model.train()
        nb = len(self.train_loader)  # batches per epoch
        nw = nb * self.warmup_epochs  # warmup length in iterations
        pbar = tqdm(enumerate(self.train_loader), total=nb,
                    desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)
        mloss = torch.zeros(3, device=self.device)  # running mean of (box, cls, dfl)
        self.optimizer.zero_grad()
        for i, batch in pbar:
            ni = i + nb * epoch  # global iteration index
            imgs, targets, _ = batch
            imgs = imgs.to(self.device, non_blocking=True)
            targets = targets.to(self.device)

            # Linear warmup: the bias group starts at lr 0.1, the others at 0, and
            # momentum starts at 0.8; all move linearly to their scheduled values.
            # Guarded so a zero-length warmup never calls np.interp on xp=[0, 0].
            if nw > 0 and ni <= nw:
                xp = [0, nw]
                for j, x in enumerate(self.optimizer.param_groups):
                    lr_target = x['initial_lr'] * self.lf(epoch)
                    x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xp, [0.8, 0.937])

            with torch.amp.autocast('cuda', enabled=self.amp):
                preds = self.model(imgs)
                target_batch = {
                    "batch_idx": targets[:, 0],
                    "cls": targets[:, 1],
                    "bboxes": targets[:, 2:],
                }
                loss, loss_items = self.loss_fn(preds, target_batch)

            self.scaler.scale(loss).backward()
            # Unscale first so max_norm is applied to the true (unscaled) gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            self.ema.update(self.model)

            loss_items = loss_items.detach()
            mloss = (mloss * i + loss_items) / (i + 1)
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
            pbar.set_postfix({
                'mem': mem,
                'box': f"{mloss[0]:.4f}",
                'cls': f"{mloss[1]:.4f}",
                'dfl': f"{mloss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
            })

    def validate(self):
        """Evaluate the EMA model on ``val_loader``.

        Logs mean precision, mean recall, mAP@0.5 and mAP@0.5:0.95.

        Returns:
            mAP@0.5 (0.0 when there are no true positives in the whole split).
        """
        model = self.ema.ema
        device = self.device
        conf_thres = 0.001
        iou_thres = 0.7
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU thresholds for mAP@0.5:0.95
        niou = iouv.numel()
        stats = []  # one (correct, conf, pred_cls, target_cls) tuple per image

        LOGGER.info("\nValidating...")
        model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape

                preds = self.post_std(model(imgs))
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        # Missed ground truth still counts towards recall.
                        if nl:
                            stats.append((torch.zeros(0, niou, dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    predn = pred.clone()
                    if nl:
                        # Presumably labels are normalized xywh; scale to pixel xyxy to
                        # compare with NMS output — TODO confirm against the dataset.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if stats:
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
            # Require a true positive anywhere in the split, not just in the first image.
            if stats[0].any():
                tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
                ap50, ap = ap[:, 0], ap.mean(1)
                mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
        return map50

    def _process_batch(self, detections, labels, iouv):
        """Mark each detection as a true positive at each IoU threshold.

        Args:
            detections: tensor whose columns 0-3 are the box and column 5 the class.
            labels: tensor of rows ``[cls, x1, y1, x2, y2]``.
            iouv: 1-D tensor of IoU thresholds.

        Returns:
            Bool tensor of shape ``(len(detections), len(iouv))``.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])
        # Candidate (label, detection) pairs: same class and IoU above the lowest threshold.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one matching by descending IoU: keep the best pair
                # per detection, then the best remaining pair per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
        return correct

    def train(self):
        """Train from ``start_epoch`` to ``epochs``, validating and checkpointing each epoch.

        ``last.pt`` is written every epoch; ``best.pt`` whenever the validation
        mAP@0.5 improves on the best seen so far.
        """
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = 0.0
        for epoch in range(self.start_epoch, self.epochs):
            self.train_one_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            is_best = map50 > best_fitness
            if is_best:
                best_fitness = map50
            ckpt = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'ema': self.ema.ema.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                # Extra state so a run can be resumed faithfully.
                'scheduler': self.scheduler.state_dict(),
                'scaler': self.scaler.state_dict(),
                'ema_updates': self.ema.updates,
                'best_fitness': best_fitness,
            }
            torch.save(ckpt, self.save_dir / 'last.pt')
            if is_best:
                torch.save(ckpt, self.save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")

        LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")