# Yolo-standalone/trainer.py

import math
import copy
import time
import logging
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
# Configure logging
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")
# ==============================================================================
# Helper Class: Model EMA (Exponential Moving Average)
# ==============================================================================
class ModelEMA:
    """Updated Exponential Moving Average (EMA) from Ultralytics.

    Keeps a shadow copy of the model whose weights track the training weights
    with an exponentially ramped decay; the EMA copy typically validates better
    than the raw weights.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA copy, never trained directly
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramps up exponentially toward `decay` (helps early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if k in msd and v.dtype.is_floating_point:
                tmp = msd[k].to(v.device)
                v *= d
                v += (1 - d) * tmp
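
# Sanity check of the decay ramp above (values follow directly from the formula
# with the defaults decay=0.9999, tau=2000):
#   decay(100)   = 0.9999 * (1 - e^(-100/2000))  ~ 0.0488
#   decay(2000)  = 0.9999 * (1 - e^(-1))         ~ 0.6320
#   decay(20000) = 0.9999 * (1 - e^(-10))        ~ 0.9999
# Early updates track the live model closely; the average stiffens as training
# progresses.
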
# ==============================================================================
# Main Trainer Class
# ==============================================================================
class YOLO11Trainer:
    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 loss_fn,
                 epochs=100,
                 lr0=0.01,
                 lrf=0.01,
                 device='cuda',
                 save_dir='./runs/train',
                 warmup_epochs=3.0):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.warmup_epochs = warmup_epochs
        self.start_epoch = 0

        # --- Optimizer building ---
        # Three parameter groups: weights (with decay), norm-layer weights
        # (no decay) and biases (no decay)
        g = [], [], []
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
        for v in self.model.modules():
            for p_name, p in v.named_parameters(recurse=False):
                if p_name == 'bias':
                    g[2].append(p)  # biases
                elif isinstance(v, bn):
                    g[1].append(p)  # norm-layer weights (no decay)
                else:
                    g[0].append(p)  # weights (decay)
        # Biases become optimizer group 0; weights and norm weights follow
        self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
        self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
        self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
        LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")

        # --- Scheduler: cosine anneal from lr0 down to lr0 * lrf over `epochs` ---
        self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

        # --- AMP & EMA ---
        self.amp = self.device.type == 'cuda'  # only enable AMP on CUDA
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp)
        self.ema = ModelEMA(self.model)
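
    # Worked values for self.lf above: lf(0) = 1 (training starts at lr0),
    # lf(epochs / 2) = (1 + lrf) / 2, and lf(epochs) = lrf, so with the defaults
    # lr0=0.01 and lrf=0.01 each group anneals from 0.01 down to 1e-4.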

    def train_one_epoch(self, epoch):
        self.model.train()
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                    desc=f'Epoch {epoch + 1}/{self.epochs}', leave=True)
        mloss = torch.zeros(3, device=self.device)  # running mean of (box, cls, dfl)
        self.optimizer.zero_grad()
        nb = len(self.train_loader)  # batches per epoch
        for i, batch in pbar:
            ni = i + nb * epoch  # batches seen since the start of training
            imgs, targets, paths = batch
            imgs = imgs.to(self.device, non_blocking=True)
            targets = targets.to(self.device)

            # --- Warmup: linearly ramp lr and momentum over the first epochs ---
            if ni <= nb * self.warmup_epochs:
                xp = [0, nb * self.warmup_epochs]
                for j, x in enumerate(self.optimizer.param_groups):
                    lr_target = x['initial_lr'] * self.lf(epoch)
                    # bias group (j == 0) warms up from 0.1, the others from 0.0
                    x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xp, [0.8, 0.937])

            # --- Forward ---
            with torch.amp.autocast('cuda', enabled=self.amp):
                preds = self.model(imgs)
                target_batch = {
                    "batch_idx": targets[:, 0],
                    "cls": targets[:, 1],
                    "bboxes": targets[:, 2:],
                }
                loss, loss_items = self.loss_fn(preds, target_batch)

            # --- Backward ---
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)  # unscale grads before clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # --- EMA ---
            self.ema.update(self.model)

            # --- Logging ---
            loss_items = loss_items.detach()
            mloss = (mloss * i + loss_items) / (i + 1)  # incremental mean
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
            pbar.set_postfix({
                'mem': mem,
                'box': f"{mloss[0]:.4f}",
                'cls': f"{mloss[1]:.4f}",
                'dfl': f"{mloss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}",
            })

    def validate(self):
        model = self.ema.ema  # validate with the EMA weights
        device = self.device

        # --- Metrics config ---
        conf_thres = 0.001  # low confidence threshold for mAP calculation
        iou_thres = 0.7  # NMS IoU threshold
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU vector for mAP@0.5:0.95
        stats = []  # per image: (correct, conf, pred_cls, target_cls)

        LOGGER.info("\nValidating...")
        model.eval()
        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape

                # Inference
                preds = model(imgs)
                # NMS
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)

                # Metrics processing, one image at a time
                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []
                    if len(pred) == 0:
                        if nl:  # missed ground truth still counts against recall
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    # Predictions
                    predn = pred.clone()
                    # Ground truth: denormalize xywh -> pixel xyxy
                    if nl:
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]
                        # Match predictions to GT
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        # Concatenate across all images before checking for any correct match
        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if stats:
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
            if stats[0].any():
                tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
                ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
                mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f}, mAP50={map50:.3f}, mAP50-95={map5095:.3f}")
        return map50

    def _process_batch(self, detections, labels, iouv):
        """
        Return the correct-prediction matrix.
        detections: [N, 6] (x1, y1, x2, y2, conf, cls)
        labels:     [M, 5] (cls, x1, y1, x2, y2)
        returns:    [N, len(iouv)] bool, one column per IoU threshold
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])  # [M, N] pairwise IoU
        # Candidate matches: IoU above the lowest threshold and matching class
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
        if x[0].shape[0]:
            # matches columns: [label_idx, detection_idx, iou]
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one assignment: sort by IoU descending, then
                # keep each detection and each label at most once
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
        return correct
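
    # Worked example (hypothetical numbers): a single class-0 detection that
    # overlaps a class-0 label at IoU 0.62 produces a row that is True for the
    # thresholds 0.50, 0.55 and 0.60 and False for 0.65 through 0.95.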

    def train(self):
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = 0.0
        for epoch in range(self.start_epoch, self.epochs):
            self.train_one_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            # --- Checkpointing: always save last.pt, track best by mAP@0.5 ---
            ckpt = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'ema': self.ema.ema.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            torch.save(ckpt, self.save_dir / 'last.pt')
            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, self.save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP@0.5: {best_fitness:.4f}")
        LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")