import math
import copy
import time
import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy

# Configure logging
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")

# ==============================================================================
# Helper Class: Model EMA (Exponential Moving Average)
# ==============================================================================
class ModelEMA:
    """Updated Exponential Moving Average (EMA) from Ultralytics.

    Keeps a shadow copy of the model whose parameters track an exponential
    moving average of the training model's parameters.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA copy, never trained directly
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramps up exponentially from 0 toward `decay`
        # (helps early epochs, when the model is still changing quickly).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        self.updates += 1
        d = self.decay(self.updates)

        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if k in msd:
                tmp = msd[k].to(v.device)
                if v.dtype.is_floating_point:
                    # In-place EMA update: v <- d * v + (1 - d) * model_param
                    v *= d
                    v += (1 - d) * tmp

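# A quick sanity check of the decay ramp (illustrative values only, assuming
# the defaults decay=0.9999, tau=2000):
#   d(100)   = 0.9999 * (1 - exp(-100 / 2000))   ~= 0.0488  -> EMA tracks the live model closely
#   d(10000) = 0.9999 * (1 - exp(-10000 / 2000)) ~= 0.9932  -> EMA becomes slow-moving
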
# ==============================================================================
# Main Trainer Class
# ==============================================================================
class YOLO11Trainer:
    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 loss_fn,
                 epochs=100,
                 lr0=0.01,
                 lrf=0.01,
                 device='cuda',
                 save_dir='./runs/train',
                 warmup_epochs=3.0):

        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.warmup_epochs = warmup_epochs
        self.start_epoch = 0

        # --- Optimizer Building ---
        # Three parameter groups: conv/linear weights (weight decay applied),
        # normalization weights (no decay), and biases (no decay).
        g = [], [], []  # tuple of three lists: (weights, norm weights, biases)
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)

        for v in self.model.modules():
            for p_name, p in v.named_parameters(recurse=False):
                if p_name == 'bias':
                    g[2].append(p)  # biases
                elif isinstance(v, bn):
                    g[1].append(p)  # norm-layer weights (no decay)
                else:
                    g[0].append(p)  # weights (decay)

        self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
        self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
        self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})

        LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")

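        # Note: applying weight decay to normalization scales and biases tends to
        # hurt detection accuracy, which is why only group g[0] is decayed above.
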
        # --- Scheduler ---
        # Cosine annealing: the multiplicative LR factor decays from 1 down to
        # lrf over `epochs`.
        self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

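        # Endpoint check (illustrative): lf(0) = 1 and lf(epochs) = lrf, so the
        # learning rate anneals from lr0 to lr0 * lrf (0.01 -> 1e-4 with the defaults).
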
        # --- AMP & EMA ---
        # Enable mixed precision only when actually running on CUDA, so the
        # trainer still works on the CPU fallback path.
        self.amp = self.device.type == 'cuda'
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp)
        self.ema = ModelEMA(self.model)

    def train_one_epoch(self, epoch):
        self.model.train()
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                    desc=f'Epoch {epoch + 1}/{self.epochs}', leave=True)

        mloss = torch.zeros(3, device=self.device)  # running mean of (box, cls, dfl) losses
        self.optimizer.zero_grad()
        nb = len(self.train_loader)  # number of batches per epoch

        for i, batch in pbar:
            ni = i + nb * epoch  # number of integrated batches since training start

            imgs, targets, paths = batch
            imgs = imgs.to(self.device, non_blocking=True)
            targets = targets.to(self.device)

            # --- Warmup ---
            # Linearly ramp lr (and momentum) over the first `warmup_epochs` epochs.
            # Biases (param group 0) ramp down from 0.1; other groups ramp up from 0.0.
            if ni <= nb * self.warmup_epochs:
                xp = [0, nb * self.warmup_epochs]
                for j, x in enumerate(self.optimizer.param_groups):
                    lr_target = x['initial_lr'] * self.lf(epoch)
                    x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xp, [0.8, 0.937])

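            # Illustrative warmup trajectory (assuming nb=100, warmup_epochs=3, lr0=0.01):
            # at ni=0 the bias group starts at lr=0.1 and the other groups at 0.0;
            # by ni=300 all groups meet the cosine schedule's lr for the current epoch.
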
            # --- Forward ---
            with torch.amp.autocast('cuda', enabled=self.amp):
                preds = self.model(imgs)

                # Targets arrive as rows of [batch_idx, cls, x, y, w, h];
                # repack into the dict layout the loss function expects.
                target_batch = {
                    "batch_idx": targets[:, 0],
                    "cls": targets[:, 1],
                    "bboxes": targets[:, 2:],
                }

                loss, loss_items = self.loss_fn(preds, target_batch)

            # --- Backward ---
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm is computed on true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # --- EMA ---
            self.ema.update(self.model)

            # --- Logging ---
            loss_items = loss_items.detach()
            mloss = (mloss * i + loss_items) / (i + 1)  # running mean over batches
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'

            pbar.set_postfix({
                'mem': mem,
                'box': f"{mloss[0]:.4f}",
                'cls': f"{mloss[1]:.4f}",
                'dfl': f"{mloss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
            })

    def validate(self):
        model = self.ema.ema  # evaluate the EMA weights, not the live model
        device = self.device

        # --- Metrics Config ---
        conf_thres = 0.001  # low confidence threshold for mAP calculation
        iou_thres = 0.7     # NMS IoU threshold
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU vector for mAP@0.5:0.95

        stats = []  # per-image tuples of (correct, conf, pred_cls, target_cls)

        LOGGER.info("\nValidating...")

        model.eval()

        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape

                # Inference
                preds = model(imgs)

                # NMS
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

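                # After NMS, `preds` is a per-image list of [n_i, 6] tensors with
                # rows (x1, y1, x2, y2, conf, cls) in network-input pixel coordinates,
                # which is why the ground-truth boxes are scaled to pixels below.
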
                # Metrics Processing
                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]  # GT rows for image si
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        if nl:
                            # No detections but GT exists: record an all-miss entry.
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    # Predictions
                    predn = pred.clone()

                    # Ground Truth: denormalize xywh -> pixel xyxy
                    if nl:
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]

                        # Match predictions to GT
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)

                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if stats:
            # Concatenate per-image stats first, then check for any true positive
            # (checking only the first image's `correct` matrix would skip metrics
            # whenever that image happened to have no hits).
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
            if stats[0].any():
                tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
                ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
                mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f}, mAP50={map50:.3f}, mAP50-95={map5095:.3f}")

        return map50

    def _process_batch(self, detections, labels, iouv):
        """
        Return the correct-prediction matrix.
        detections: [N, 6] (x1, y1, x2, y2, conf, cls)
        labels:     [M, 5] (cls, x1, y1, x2, y2)
        Returns:    [N, len(iouv)] bool tensor; True where a detection matches a
                    same-class GT box at that IoU threshold.
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])  # [M, N] IoU matrix

        # Candidate matches: IoU above the lowest threshold AND matching class.
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))

        if x[0].shape[0]:
            # matches columns: [label_idx, detection_idx, iou]
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one assignment: sort by IoU descending, keep each
                # detection at most once, then keep each label at most once.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv

        return correct

    def train(self):
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = 0.0

        for epoch in range(self.start_epoch, self.epochs):
            self.train_one_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            ckpt = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'ema': self.ema.ema.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }

            torch.save(ckpt, self.save_dir / 'last.pt')

            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, self.save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP50: {best_fitness:.4f}")

        LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")
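
# ==============================================================================
# Usage sketch (illustrative only). The `YOLO11`, `create_dataloaders`, and
# `DetectionLoss` names below are hypothetical stand-ins, not part of this
# module: substitute your own model, loaders yielding (imgs, targets, paths)
# with targets rows of [batch_idx, cls, x, y, w, h] (normalized coords), and a
# loss returning (total_loss, (box, cls, dfl)) items.
# ==============================================================================
# if __name__ == "__main__":
#     from model import YOLO11                # hypothetical
#     from dataset import create_dataloaders  # hypothetical
#     from loss import DetectionLoss          # hypothetical
#
#     model = YOLO11(num_classes=80)
#     train_loader, val_loader = create_dataloaders("data.yaml", batch_size=16)
#     trainer = YOLO11Trainer(model, train_loader, val_loader,
#                             loss_fn=DetectionLoss(model),
#                             epochs=100, lr0=0.01, lrf=0.01, device='cuda')
#     trainer.train()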