# Yolo-standalone/trainer.py — standalone YOLO training loop (246 lines, 9.4 KiB, Python)
import math
import copy
import time
import logging
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
# Configure root logging once at import time; a message-only format keeps the
# console output compact alongside tqdm progress bars.
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")
class ModelEMA:
    """Exponential Moving Average of a model's weights.

    Keeps a frozen shadow copy of the model whose floating-point state is
    updated as ``ema = d * ema + (1 - d) * model`` after each call to
    :meth:`update`. The decay ``d`` ramps from 0 toward ``decay`` with time
    constant ``tau`` updates, so early training is tracked closely while
    late training is smoothed heavily.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Shadow copy in eval mode; it is never trained directly.
        self.ema = copy.deepcopy(model).eval()
        self.updates = updates  # number of EMA updates performed so far
        # Ramped decay: d(x) = decay * (1 - exp(-x / tau)).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()  # FIX: the blend must not be recorded by autograd
    def update(self, model):
        """Blend the live model's state into the EMA copy (floats only)."""
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if k in msd and v.dtype.is_floating_point:
                # FIX: detach() — without it (and no_grad) the in-place update
                # below builds an autograd graph onto the EMA tensors, flips
                # them to requires_grad=True, and leaks retained graphs.
                tmp = msd[k].detach().to(v.device)
                v *= d
                v += (1 - d) * tmp
class YOLO11Trainer:
def __init__(self,
model,
post_std,
train_loader,
val_loader,
loss_fn,
epochs=100,
lr0=0.01,
lrf=0.01,
device='cuda',
save_dir='./runs/train',
warmup_epochs=3.0):
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.model = model.to(self.device)
self.post_std = post_std
self.train_loader = train_loader
self.val_loader = val_loader
self.loss_fn = loss_fn
self.epochs = epochs
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.warmup_epochs = warmup_epochs
self.start_epoch = 0
g = [], [], []
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
for v in self.model.modules():
for p_name, p in v.named_parameters(recurse=False):
if p_name == 'bias':
g[2].append(p)
elif isinstance(v, bn):
g[1].append(p)
else:
g[0].append(p)
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
self.ema = ModelEMA(self.model)
def train_one_epoch(self, epoch):
self.model.train()
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)
mloss = torch.zeros(3, device=self.device)
self.optimizer.zero_grad()
nb = len(self.train_loader)
for i, batch in pbar:
ni = i + nb * epoch
imgs, targets, paths = batch
imgs = imgs.to(self.device, non_blocking=True)
targets = targets.to(self.device)
if ni <= nb * self.warmup_epochs:
xp = [0, nb * self.warmup_epochs]
for j, x in enumerate(self.optimizer.param_groups):
lr_target = x['initial_lr'] * self.lf(epoch)
x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
if 'momentum' in x:
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
with torch.amp.autocast('cuda', enabled=True):
preds = self.model(imgs)
target_batch = {
"batch_idx": targets[:, 0],
"cls": targets[:, 1],
"bboxes": targets[:, 2:],
}
loss, loss_items = self.loss_fn(preds, target_batch)
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
self.ema.update(self.model)
loss_items = loss_items.detach()
mloss = (mloss * i + loss_items) / (i + 1)
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
pbar.set_postfix({
'mem': mem,
'box': f"{mloss[0]:.4f}",
'cls': f"{mloss[1]:.4f}",
'dfl': f"{mloss[2]:.4f}",
'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
})
    def validate(self):
        """Evaluate the EMA model on the validation set and return mAP@0.5.

        Runs inference + NMS on every validation batch, matches detections to
        ground truth across IoU thresholds 0.5:0.95, and aggregates
        precision / recall / mAP via ap_per_class.
        """
        model = self.ema.ema  # evaluate the EMA weights, not the live model
        device = self.device
        conf_thres = 0.001  # very low threshold so mAP sees low-confidence preds
        iou_thres = 0.7  # NMS IoU threshold
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # mAP@0.5:0.95 vector
        loss_sum = torch.zeros(3, device=device)  # NOTE(review): unused — val loss is never accumulated
        stats = []  # per-image (correct, conf, pred_cls, target_cls) tuples
        LOGGER.info("\nValidating...")
        model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape
                preds = model(imgs)
                preds = self.post_std(preds)  # project post-processing before NMS
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
                for si, pred in enumerate(preds):
                    # Targets for image si: rows are [batch_idx, cls, x, y, w, h].
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []
                    if len(pred) == 0:
                        # No detections: record only if there were GT boxes
                        # (pure false negatives for this image).
                        if nl:
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue
                    predn = pred.clone()
                    if nl:
                        # Denormalize GT boxes (xywh normalized -> xyxy pixels).
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        # No GT: every detection is a false positive.
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
                    # pred columns 4/5 are confidence and predicted class.
                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))
        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if len(stats) and stats[0][0].any():
            # Concatenate per-image stats into flat numpy arrays for ap_per_class.
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5 and AP@0.5:0.95 per class
            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
        return map50
def _process_batch(self, detections, labels, iouv):
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
iou = box_iou(labels[:, 1:], detections[:, :4])
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
matches = torch.from_numpy(matches).to(iouv.device)
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
return correct
def train(self):
LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
start_time = time.time()
best_fitness = 0.0
for epoch in range(self.start_epoch, self.epochs):
self.train_one_epoch(epoch)
self.scheduler.step()
map50 = self.validate()
ckpt = {
'epoch': epoch,
'model': self.model.state_dict(),
'ema': self.ema.ema.state_dict(),
'optimizer': self.optimizer.state_dict(),
}
torch.save(ckpt, self.save_dir / 'last.pt')
if map50 > best_fitness:
best_fitness = map50
torch.save(ckpt, self.save_dir / 'best.pt')
LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")
LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")