import logging
import math
import time
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
from qat_utils import prepare_model_for_qat, QATConfig

logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("QAT_Trainer")


class YOLO11QATTrainer:
    """Quantization-aware-training loop with per-epoch detection validation.

    The trainer wraps a model with ``prepare_model_for_qat``, optimizes it
    with Adam under a cosine learning-rate schedule, evaluates
    precision / recall / mAP after every epoch, and writes ``last_qat.pt``
    and ``best.pt`` checkpoints into ``config.save_dir``.
    """

    def __init__(self, model, post_processor, train_loader, val_loader,
                 loss_fn, config: QATConfig, epochs: int = 5,
                 lr: float = 0.001, device: str = 'cuda'):
        """Store collaborators and hyper-parameters; no QAT work happens here.

        Args:
            model: Network to be prepared for QAT; moved to the chosen device.
            post_processor: Callable applied to raw model outputs before NMS
                during validation.
            train_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            val_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            loss_fn: Callable ``(preds, target_batch) -> (loss, loss_items)``.
            config: QAT configuration; ``save_dir`` and ``img_size`` are read
                from it by this class.
            epochs: Number of training epochs.
            lr: Initial Adam learning rate.
            device: Requested device string; falls back to CPU whenever CUDA
                is not available.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config = config
        self.save_dir = config.save_dir
        self.original_model = model.to(self.device)
        self.post_processor = post_processor
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.lr = lr
        # Populated by setup().
        self.qat_model = None
        self.optimizer = None
        self.scheduler = None

    def setup(self):
        """Build the QAT model, the Adam optimizer and the cosine LR scheduler."""
        LOGGER.info(">>> Setting up QAT...")
        example_input = torch.randn(
            1, 3, self.config.img_size, self.config.img_size
        ).to(self.device)
        self.qat_model = prepare_model_for_qat(
            self.original_model,
            example_input,
            config=self.config
        )
        self.qat_model.to(self.device)
        self.optimizer = optim.Adam(
            self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5
        )

        def lf(x):
            # Cosine ramp of the LR multiplier from 1.0 (epoch 0) down to
            # 0.01 (epoch == self.epochs).
            return ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (0.01 - 1) + 1

        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)

    def train_epoch(self, epoch):
        """Run one optimization pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar
                label.
        """
        self.qat_model.train()
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')
        # Running mean of the three loss components shown as box / cls / dfl.
        mean_loss = torch.zeros(3, device=self.device)
        for i, (imgs, targets, _) in enumerate(pbar):
            imgs = imgs.to(self.device).float()
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            preds = self.qat_model(imgs)
            # Column 0 = image index in batch, column 1 = class, rest = box.
            target_batch = {
                "batch_idx": targets[:, 0],
                "cls": targets[:, 1],
                "bboxes": targets[:, 2:],
            }
            loss, loss_items = self.loss_fn(preds, target_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
            self.optimizer.step()
            # Incremental mean keeps the displayed value stable across batches.
            mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
            pbar.set_postfix({
                'box': f"{mean_loss[0]:.4f}",
                'cls': f"{mean_loss[1]:.4f}",
                'dfl': f"{mean_loss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
            })

    def validate(self):
        """Evaluate the QAT model on ``val_loader`` and log detection metrics.

        Returns:
            mAP at IoU 0.5 averaged over classes (``0.0`` when no detection in
            the whole validation set matched a label).
        """
        conf_thres = 0.001
        iou_thres = 0.7
        # IoU thresholds 0.50:0.05:0.95 used for mAP50-95.
        iouv = torch.linspace(0.5, 0.95, 10, device=self.device)
        niou = iouv.numel()
        stats = []
        LOGGER.info("\nValidating...")
        self.qat_model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                # Cast to float like train_epoch does, so both paths feed the
                # model the same dtype.
                imgs = imgs.to(self.device, non_blocking=True).float()
                targets = targets.to(self.device)
                _, _, height, width = imgs.shape
                preds = self.qat_model(imgs)
                preds = self.post_processor(preds)
                preds = non_max_suppression(
                    preds, conf_thres=conf_thres, iou_thres=iou_thres
                )
                for si, pred in enumerate(preds):
                    # Labels belonging to image `si` of this batch.
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []
                    if len(pred) == 0:
                        # No detections: record the missed labels (if any) so
                        # they still count against recall.
                        if nl:
                            stats.append((
                                torch.zeros(0, niou, dtype=torch.bool),
                                torch.Tensor(),
                                torch.Tensor(),
                                torch.tensor(tcls),
                            ))
                        continue
                    predn = pred.clone()
                    if nl:
                        # NOTE(review): assumes label boxes are normalized
                        # xywh, hence the scaling to pixels — confirm against
                        # the dataset.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
                    # NOTE(review): column 4 is presumably confidence; column
                    # 5 is compared to the label class in _process_batch.
                    stats.append((
                        correct.cpu(),
                        pred[:, 4].cpu(),
                        pred[:, 5].cpu(),
                        torch.tensor(tcls),
                    ))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        # Concatenate across images BEFORE testing for true positives.
        # Testing only the first image's matrix would zero every metric
        # whenever that single image had no correct detection.
        stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
        if len(stats) and stats[0].any():
            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
            ap50, ap = ap[:, 0], ap.mean(1)
            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
        LOGGER.info(
            f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} "
            f"mAP50={map50:.3f} mAP50-95={map5095:.3f}"
        )
        return map50

    def _process_batch(self, detections, labels, iouv):
        """Mark which detections are true positives at each IoU threshold.

        Args:
            detections: Tensor with box corners in columns 0-3 and class in
                column 5.
            labels: Tensor with class in column 0 and box corners in
                columns 1-4.
            iouv: 1-D tensor of IoU thresholds.

        Returns:
            Bool tensor of shape ``(len(detections), len(iouv))`` on
            ``iouv.device``.
        """
        correct = torch.zeros(
            detections.shape[0], iouv.shape[0],
            dtype=torch.bool, device=iouv.device,
        )
        iou = box_iou(labels[:, 1:], detections[:, :4])
        # Candidate pairs: IoU above the lowest threshold and same class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
        if x[0].shape[0]:
            # Rows of `matches`: [label index, detection index, iou].
            matches = torch.cat(
                (torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1
            ).cpu().numpy()
            if x[0].shape[0] > 1:
                # Sort by IoU descending, then keep the first (best) row per
                # detection and then per label, so matching is one-to-one.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
        return correct

    def train(self):
        """Run setup, then train and validate for ``self.epochs`` epochs.

        After each epoch the current state is written to ``last_qat.pt``;
        whenever validation mAP50 improves it is also written to ``best.pt``.
        """
        self.setup()
        # Accept a str or a Path for save_dir and make sure it exists before
        # the first checkpoint is written.
        save_dir = Path(self.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        best_fitness = -1
        for epoch in range(self.epochs):
            self.train_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()
            ckpt = {
                'epoch': epoch,
                'model': self.qat_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            torch.save(ckpt, save_dir / 'last_qat.pt')
            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, save_dir / 'best.pt')
                # The tracked fitness is mAP50 only, so the message says so.
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")
        LOGGER.info(">>> Training Finished.")