# NOTE(review): removed duplicated file-listing artifact ("200 lines / 7.3 KiB / Python")
# that was not valid Python and appeared to be pasted-in metadata.
# Standard library
import logging
import math
import time
from pathlib import Path
from typing import List, Optional

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Project-local
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
from qat_utils import QATConfig, prepare_model_for_qat

# Message-only format so log lines read cleanly next to tqdm progress bars.
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("QAT_Trainer")
class YOLO11QATTrainer:
    """Quantization-aware fine-tuning loop for a YOLO11 detection model.

    Wraps a pre-trained model with fake-quantization nodes (via
    ``prepare_model_for_qat``), fine-tunes it for a few epochs under a cosine
    learning-rate schedule, and tracks mAP@0.5 on the validation set to keep
    the best checkpoint in ``config.save_dir``.
    """

    def __init__(self,
                 model,
                 post_processor,
                 train_loader,
                 val_loader,
                 loss_fn,
                 config: QATConfig,
                 epochs: int = 5,
                 lr: float = 0.001,
                 device: str = 'cuda'):
        """Store training components; no heavy work happens until setup()/train().

        Args:
            model: Float detection model to be quantization-aware fine-tuned.
            post_processor: Callable mapping raw model outputs to NMS-ready
                predictions (applied only during validation).
            train_loader / val_loader: Iterables yielding (imgs, targets, _)
                batches; targets rows are (batch_idx, cls, x, y, w, h)
                — presumably normalized xywh, as validate() scales by
                image width/height (confirm against the dataset code).
            loss_fn: Callable (preds, target_dict) -> (loss, loss_items) where
                loss_items holds (box, cls, dfl) components.
            config: QAT configuration (provides img_size and save_dir).
            epochs: Number of fine-tuning epochs.
            lr: Initial Adam learning rate.
            device: Requested device string; silently falls back to CPU when
                CUDA is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        self.config = config
        self.save_dir = config.save_dir

        self.original_model = model.to(self.device)
        self.post_processor = post_processor
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.lr = lr

        # Populated by setup(); left as None so premature use fails loudly.
        self.qat_model = None
        self.optimizer = None
        self.scheduler = None

    def setup(self):
        """Insert fake-quant observers into the model and build optimizer/scheduler."""
        LOGGER.info(">>> Setting up QAT...")

        # Example input drives tracing/observer placement inside prepare_model_for_qat.
        example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device)
        self.qat_model = prepare_model_for_qat(
            self.original_model,
            example_input,
            config=self.config,
        )
        self.qat_model.to(self.device)

        self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5)

        # Cosine decay from lr down to 1% of lr over `epochs` scheduler steps.
        lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (0.01 - 1) + 1
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)

    def train_epoch(self, epoch):
        """Run one training epoch.

        Returns:
            The running mean of the (box, cls, dfl) loss components as a
            3-element tensor (backward-compatible addition; callers ignoring
            the return value are unaffected).
        """
        self.qat_model.train()
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')

        mean_loss = torch.zeros(3, device=self.device)

        for i, (imgs, targets, _) in enumerate(pbar):
            imgs = imgs.to(self.device).float()
            targets = targets.to(self.device)

            self.optimizer.zero_grad()

            preds = self.qat_model(imgs)

            # Loss expects targets split into batch-index / class / box groups.
            target_batch = {
                "batch_idx": targets[:, 0],
                "cls": targets[:, 1],
                "bboxes": targets[:, 2:],
            }
            loss, loss_items = self.loss_fn(preds, target_batch)

            loss.backward()
            # Clip gradients to stabilise training while fake-quant observers settle.
            torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
            self.optimizer.step()

            # Incremental (streaming) mean of the loss components.
            mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
            pbar.set_postfix({
                'box': f"{mean_loss[0]:.4f}",
                'cls': f"{mean_loss[1]:.4f}",
                'dfl': f"{mean_loss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
            })

        return mean_loss

    def validate(self):
        """Evaluate the QAT model on the validation set.

        Returns:
            mAP@0.5 as a float (0.0 when no statistics were collected).
        """
        conf_thres = 0.001  # very low: keep nearly all detections so PR curves are complete
        iou_thres = 0.7     # NMS IoU threshold
        # IoU thresholds 0.5:0.05:0.95 for mAP50-95.
        iouv = torch.linspace(0.5, 0.95, 10, device=self.device)

        stats = []  # per-image tuples: (correct, conf, pred_cls, target_cls)

        LOGGER.info("\nValidating...")

        self.qat_model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(self.device, non_blocking=True)
                targets = targets.to(self.device)
                _, _, height, width = imgs.shape

                preds = self.qat_model(imgs)
                preds = self.post_processor(preds)
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

                for si, pred in enumerate(preds):
                    # Ground-truth rows belonging to image `si` in this batch.
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        # No detections: record an empty entry only if targets exist,
                        # so missed objects still count against recall.
                        if nl:
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    predn = pred.clone()

                    if nl:
                        # Targets are stored as normalized xywh; convert to
                        # pixel-space xyxy to match the detections.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)

                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if stats:
            # Concatenate per-image stats across the whole dataset first, THEN
            # check for any true positive.
            # BUGFIX: the previous `stats[0][0].any()` looked only at the first
            # image's correctness tensor, silently skipping metric computation
            # whenever the first validation image happened to have no correct
            # prediction even though later images did.
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
            if stats[0].any():
                tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
                ap50, ap = ap[:, 0], ap.mean(1)
                mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")

        return map50

    def _process_batch(self, detections, labels, iouv):
        """Match detections to ground-truth labels at each IoU threshold.

        Args:
            detections: (N, 6+) tensor [x1, y1, x2, y2, conf, cls, ...].
            labels: (M, 5) tensor [cls, x1, y1, x2, y2].
            iouv: (T,) tensor of IoU thresholds.

        Returns:
            (N, T) bool tensor: detection i is a true positive at threshold t.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])

        # Candidate pairs: IoU above the lowest threshold AND matching class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))

        if x[0].shape[0]:
            # matches: (label_idx, detection_idx, iou), one row per candidate pair.
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one assignment by descending IoU:
                # keep the best match per detection, then per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            matches = torch.from_numpy(matches).to(iouv.device)
            # A matched detection is correct at every threshold its IoU clears.
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv

        return correct

    def train(self):
        """Full QAT fine-tuning loop: setup, per-epoch train + validate, checkpointing."""
        self.setup()

        # BUGFIX: torch.save fails if the target directory does not exist.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = -1  # below any possible mAP so the first epoch always saves

        for epoch in range(self.epochs):
            self.train_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            ckpt = {
                'epoch': epoch,
                'model': self.qat_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }

            # Always keep the most recent checkpoint for resuming.
            torch.save(ckpt, self.save_dir / 'last_qat.pt')

            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, self.save_dir / 'best.pt')
                # BUGFIX: message previously said "Recall/mAP" but the value is mAP50.
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")

        LOGGER.info(">>> Training Finished.")