# Yolo-standalone/trainer_qat.py
# Snapshot: 2026-01-08 15:12:27 +08:00 (200 lines, 7.3 KiB, Python)

import math
import logging
import time
import numpy as np
from pathlib import Path
from typing import Optional, List
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
from qat_utils import prepare_model_for_qat, QATConfig
# Module logger for the trainer; messages are emitted bare (message text only) at INFO level and above.
LOGGER = logging.getLogger("QAT_Trainer")
logging.basicConfig(level=logging.INFO, format="%(message)s")
class YOLO11QATTrainer:
    """Quantization-aware training (QAT) loop for a YOLO11-style detector.

    ``setup`` wraps the float model with ``prepare_model_for_qat``. ``train`` then runs
    ``epochs`` rounds of train -> LR-scheduler step -> validate. It writes
    ``last_qat.pt`` every epoch and ``best.pt`` whenever validation mAP@0.5 improves.
    """

    def __init__(self,
                 model,
                 post_processor,
                 train_loader,
                 val_loader,
                 loss_fn,
                 config: "QATConfig",
                 epochs: int = 5,
                 lr: float = 0.001,
                 device: str = 'cuda'):
        """Store the training components; QAT preparation is deferred to ``setup``.

        Args:
            model: Float model; moved to the selected device immediately.
            post_processor: Callable applied to raw model outputs before NMS in ``validate``.
            train_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            val_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            loss_fn: Callable ``(preds, batch_dict) -> (loss, loss_items)``.
            config: QAT configuration. This class reads ``save_dir`` and ``img_size`` and
                passes the whole object to ``prepare_model_for_qat``.
            epochs: Number of training epochs.
            lr: Base learning rate for Adam.
            device: Requested device; falls back to CPU when CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config = config
        self.save_dir = config.save_dir
        self.original_model = model.to(self.device)
        self.post_processor = post_processor
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.lr = lr
        # Populated by setup().
        self.qat_model = None
        self.optimizer = None
        self.scheduler = None

    def setup(self):
        """Prepare the QAT model, the Adam optimizer and a cosine LR schedule.

        The schedule scales the base LR from 1.0x at epoch 0 down to 0.01x at
        epoch ``self.epochs``.
        """
        LOGGER.info(">>> Setting up QAT...")
        example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device)
        self.qat_model = prepare_model_for_qat(
            self.original_model,
            example_input,
            config=self.config,
        )
        self.qat_model.to(self.device)
        self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5)

        final_lr_factor = 0.01

        def lr_lambda(epoch):
            # max(..., 1) avoids ZeroDivisionError: LambdaLR evaluates this at construction time.
            span = max(self.epochs, 1)
            return ((1 - math.cos(epoch * math.pi / span)) / 2) * (final_lr_factor - 1) + 1

        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)

    def train_epoch(self, epoch):
        """Run one training pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Running mean of ``loss_items`` over the epoch. The progress bar shows
            entries 0/1/2 as box/cls/dfl, so a 3-element tensor is assumed.
        """
        self.qat_model.train()
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')
        mean_loss = torch.zeros(3, device=self.device)
        for i, (imgs, targets, _) in enumerate(pbar):
            imgs = imgs.to(self.device).float()
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            preds = self.qat_model(imgs)
            # Column 0 = image index within the batch, column 1 = class, the rest = box.
            target_batch = {
                "batch_idx": targets[:, 0],
                "cls": targets[:, 1],
                "bboxes": targets[:, 2:],
            }
            loss, loss_items = self.loss_fn(preds, target_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
            self.optimizer.step()
            # Incremental running mean over the batches seen so far.
            mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
            pbar.set_postfix({
                'box': f"{mean_loss[0]:.4f}",
                'cls': f"{mean_loss[1]:.4f}",
                'dfl': f"{mean_loss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
            })
        return mean_loss

    def validate(self, conf_thres: float = 0.001, iou_thres: float = 0.7):
        """Evaluate the QAT model on ``val_loader`` and return mAP@0.5.

        Args:
            conf_thres: Confidence threshold passed to ``non_max_suppression``.
            iou_thres: IoU threshold passed to ``non_max_suppression``.

        Returns:
            mAP@0.5 across classes, or 0.0 when there is not a single true positive
            in the whole validation set.
        """
        iouv = torch.linspace(0.5, 0.95, 10, device=self.device)  # IoU thresholds for mAP@0.5:0.95
        niou = iouv.numel()
        stats = []
        LOGGER.info("\nValidating...")
        self.qat_model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for imgs, targets, _ in pbar:
                # Same dtype handling as train_epoch, so integer image batches work here too.
                imgs = imgs.to(self.device, non_blocking=True).float()
                targets = targets.to(self.device)
                _, _, height, width = imgs.shape
                preds = self.post_processor(self.qat_model(imgs))
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].cpu()
                    if len(pred) == 0:
                        # No detections: the image's labels still count as missed targets.
                        if nl:
                            stats.append((torch.zeros(0, niou, dtype=torch.bool),
                                          torch.zeros(0), torch.zeros(0), tcls))
                        continue
                    if nl:
                        # Label boxes are scaled by the input width/height; presumably they
                        # are normalized xywh in [0, 1] — confirm against the dataset.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                        correct = self._process_batch(pred, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        merged = self._concat_stats(stats)
        if merged is not None:
            _, _, p, r, _, ap, _ = ap_per_class(*merged)
            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5 and AP averaged over 0.5:0.95
            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
        return map50

    @staticmethod
    def _concat_stats(stats):
        """Merge per-image ``(correct, conf, pred_cls, target_cls)`` tuples into numpy arrays.

        The true-positive check runs on the concatenated ``correct`` matrix, i.e. over the
        whole dataset, not just the first image.

        Returns:
            A list of four numpy arrays ready for ``ap_per_class``, or None when ``stats``
            is empty or contains no true positive at any IoU threshold.
        """
        if not stats:
            return None
        merged = [torch.cat(column, 0).cpu().numpy() for column in zip(*stats)]
        return merged if merged[0].any() else None

    def _process_batch(self, detections, labels, iouv):
        """Mark which detections are true positives at each IoU threshold in ``iouv``.

        Args:
            detections: ``(N, >=6)`` tensor; columns 0-3 are compared against label boxes
                by IoU, and column 5 is the class (assumed x1, y1, x2, y2, conf, cls).
            labels: ``(M, 5)`` tensor of ``(cls, x1, y1, x2, y2)``.
            iouv: 1-D tensor of IoU thresholds.

        Returns:
            ``(N, len(iouv))`` bool tensor on ``iouv.device``.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])
        # Candidate (label, detection) pairs: IoU above the lowest threshold and matching class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one matching by descending IoU: each detection keeps at most one
                # label, then each label keeps at most one detection.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
        return correct

    def train(self):
        """Run setup, then train/validate for ``self.epochs`` epochs, checkpointing each epoch.

        Writes ``last_qat.pt`` every epoch and ``best.pt`` when mAP@0.5 improves, both under
        ``save_dir``, which is created if missing.

        Returns:
            The best mAP@0.5 observed (-1.0 if no epoch ran).
        """
        self.setup()
        save_dir = Path(self.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = -1.0
        for epoch in range(self.epochs):
            self.train_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()
            ckpt = {
                'epoch': epoch,
                'model': self.qat_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }
            torch.save(ckpt, save_dir / 'last_qat.pt')
            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")
        elapsed = time.time() - start_time
        LOGGER.info(f"{self.epochs} epochs completed in {elapsed / 3600:.3f} hours.")
        LOGGER.info(">>> Training Finished.")
        return best_fitness