添加qat量化支持
This commit is contained in:
199
trainer_qat.py
Normal file
199
trainer_qat.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import math
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
|
||||
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
|
||||
from qat_utils import prepare_model_for_qat, QATConfig
|
||||
|
||||
# Emit bare messages (no level/timestamp prefix) at INFO and above, and expose
# the module-wide logger used by the QAT trainer.
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger("QAT_Trainer")
|
||||
|
||||
class YOLO11QATTrainer:
    """Drive quantization-aware training (QAT) and validation of a detection model.

    Typical use is ``YOLO11QATTrainer(...).train()``, which prepares the QAT
    model, runs ``epochs`` rounds of training + validation, and writes
    ``last_qat.pt`` / ``best.pt`` checkpoints into ``config.save_dir``.
    """

    def __init__(self,
                 model,
                 post_processor,
                 train_loader,
                 val_loader,
                 loss_fn,
                 config: "QATConfig",
                 epochs: int = 5,
                 lr: float = 0.001,
                 device: str = 'cuda'):
        """Store the training components; no QAT preparation happens here.

        Args:
            model: Float model to be converted for QAT; moved to the device.
            post_processor: Callable applied to raw model outputs before NMS
                during validation.
            train_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            val_loader: Iterable yielding ``(imgs, targets, _)`` batches.
            loss_fn: Callable ``loss_fn(preds, target_batch)`` returning
                ``(loss, loss_items)``.
            config: QAT configuration; ``save_dir`` and ``img_size`` are read.
            epochs: Number of training epochs.
            lr: Initial learning rate for Adam.
            device: Requested device string.
        """
        # Fall back to CPU only when CUDA was requested but is unavailable;
        # other explicitly requested device types are honoured as given.
        requested = torch.device(device)
        if requested.type == 'cuda' and not torch.cuda.is_available():
            requested = torch.device('cpu')
        self.device = requested

        self.config = config
        # Normalise to Path so that ``save_dir / name`` works for str configs too.
        self.save_dir = Path(config.save_dir)

        self.original_model = model.to(self.device)
        self.post_processor = post_processor
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.lr = lr

        # Populated by setup().
        self.qat_model = None
        self.optimizer = None
        self.scheduler = None

    @staticmethod
    def _cosine_lr_factor(epoch, total_epochs, final_ratio=0.01):
        """Return the cosine-decay LR multiplier: 1.0 at epoch 0, ``final_ratio`` at ``total_epochs``."""
        # Guard against total_epochs == 0, which would otherwise divide by zero.
        span = max(total_epochs, 1)
        return ((1 - math.cos(epoch * math.pi / span)) / 2) * (final_ratio - 1) + 1

    def setup(self):
        """Build the QAT model, the Adam optimizer and the cosine LR scheduler."""
        LOGGER.info(">>> Setting up QAT...")

        example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device)
        self.qat_model = prepare_model_for_qat(
            self.original_model,
            example_input,
            config=self.config
        )
        self.qat_model.to(self.device)

        self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5)

        # Kept as a lambda so self.epochs is read at call time.
        lf = lambda x: self._cosine_lr_factor(x, self.epochs)
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index (used only for the progress-bar label).
        """
        self.qat_model.train()
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')

        # Running mean of the three loss components shown as box / cls / dfl.
        mean_loss = torch.zeros(3, device=self.device)

        for i, (imgs, targets, _) in enumerate(pbar):
            imgs = imgs.to(self.device).float()
            targets = targets.to(self.device)

            self.optimizer.zero_grad()

            preds = self.qat_model(imgs)

            # Column 0 = image index in batch, column 1 = class, rest = box.
            # NOTE(review): box format expected by loss_fn is not visible here.
            target_batch = {
                "batch_idx": targets[:, 0],
                "cls": targets[:, 1],
                "bboxes": targets[:, 2:],
            }
            loss, loss_items = self.loss_fn(preds, target_batch)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
            self.optimizer.step()

            # Incremental mean over the batches seen so far.
            mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
            pbar.set_postfix({
                'box': f"{mean_loss[0]:.4f}",
                'cls': f"{mean_loss[1]:.4f}",
                'dfl': f"{mean_loss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
            })

    def validate(self):
        """Evaluate ``qat_model`` on ``val_loader`` and log detection metrics.

        Returns:
            mAP at the first IoU threshold (0.5); ``0.0`` when nothing matched.
        """
        conf_thres = 0.001
        iou_thres = 0.7
        # IoU thresholds 0.50:0.05:0.95 used for mAP50-95.
        iouv = torch.linspace(0.5, 0.95, 10, device=self.device)
        num_iou = iouv.numel()

        # One (correct, conf, pred_cls, target_cls) tuple per evaluated image.
        stats = []

        LOGGER.info("\nValidating...")

        self.qat_model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for imgs, targets, _ in pbar:
                # Cast to float for consistency with train_epoch().
                imgs = imgs.to(self.device, non_blocking=True).float()
                targets = targets.to(self.device)
                _, _, height, width = imgs.shape

                preds = self.qat_model(imgs)
                preds = self.post_processor(preds)
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        # No detections: record the missed targets, if any.
                        if nl:
                            stats.append((torch.zeros(0, num_iou, dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    predn = pred.clone()

                    if nl:
                        # Scale label boxes by image width/height.
                        # NOTE(review): assumes labels are normalised xywh - confirm against the loader.
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], num_iou, dtype=torch.bool)

                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if stats:
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
            # Test the TP matrix of ALL images, not just the first one, so that
            # an empty or unmatched first image does not zero out the metrics.
            if stats[0].any():
                tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
                ap50, ap = ap[:, 0], ap.mean(1)
                mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")

        return map50

    def _process_batch(self, detections, labels, iouv):
        """Mark which detections are correct at each IoU threshold.

        Args:
            detections: Tensor whose columns 0-3 are a box and column 5 the class.
            labels: Tensor whose column 0 is the class and columns 1-4 a box.
            iouv: 1-D tensor of IoU thresholds.

        Returns:
            Bool tensor of shape ``(len(detections), len(iouv))`` on ``iouv.device``.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        # NOTE(review): box_iou presumably returns a (num_labels, num_detections) matrix - confirm.
        iou = box_iou(labels[:, 1:], detections[:, :4])

        # Candidate pairs: IoU above the lowest threshold and matching class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))

        if x[0].shape[0]:
            # Rows are [label_index, detection_index, iou].
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Sort by IoU descending, then keep the first (best) row per
                # detection and afterwards per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv

        return correct

    def train(self):
        """Run the full QAT loop: setup, then train/validate/checkpoint per epoch.

        Writes ``last_qat.pt`` after every epoch and ``best.pt`` whenever the
        validation mAP50 improves.
        """
        self.setup()

        # torch.save does not create missing directories.
        save_dir = Path(self.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        best_fitness = -1

        for epoch in range(self.epochs):
            self.train_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            ckpt = {
                'epoch': epoch,
                'model': self.qat_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }

            torch.save(ckpt, save_dir / 'last_qat.pt')

            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")

        LOGGER.info(">>> Training Finished.")
|
||||
|
||||
Reference in New Issue
Block a user