First commit of the YOLO project
275
trainer.py
Normal file
@@ -0,0 +1,275 @@
import math
import copy
import time
import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy

# Configure logging
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("YOLO_Trainer")


# ==============================================================================
# Helper Class: Model EMA (Exponential Moving Average)
# ==============================================================================
class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from Ultralytics """
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
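        # The ramp means early updates barely trust the EMA weights:
        #   updates=1    -> d ≈ 0.0005           (EMA ≈ raw model)
        #   updates=tau  -> d ≈ 0.632 * decay
        #   updates→∞    -> d -> decay (0.9999), a very slow-moving average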
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        self.updates += 1
        d = self.decay(self.updates)

        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if k in msd:
                tmp = msd[k].to(v.device)
                if v.dtype.is_floating_point:
                    # in-place update: ema = d * ema + (1 - d) * model
                    v *= d
                    v += (1 - d) * tmp


# ==============================================================================
# Main Trainer Class
# ==============================================================================
class YOLO11Trainer:
    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 loss_fn,
                 epochs=100,
                 lr0=0.01,
                 lrf=0.01,
                 device='cuda',
                 save_dir='./runs/train',
                 warmup_epochs=3.0):

        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.warmup_epochs = warmup_epochs
        self.start_epoch = 0

        # --- Optimizer Building ---
        g = [], [], []  # optimizer parameter groups: (weights w/ decay, norm weights, biases)
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # all Norm layer types

        for v in self.model.modules():
            for p_name, p in v.named_parameters(recurse=False):
                if p_name == 'bias':
                    g[2].append(p)  # biases
                elif isinstance(v, bn):
                    g[1].append(p)  # norm-layer weights (no decay)
                else:
                    g[0].append(p)  # weights (decay)

        # Biases form param_groups[0]; the warmup in train_one_epoch relies on this order
        self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
        self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
        self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})

        LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")

        # --- Scheduler ---
        self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
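        # Cosine annealing: lf(0) = 1.0 and lf(epochs) = lrf, so the lr decays
        # from lr0 to lr0 * lrf (0.01 -> 1e-4 with the defaults) over training.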
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)

        # --- AMP & EMA ---
        self.amp = self.device.type == 'cuda'  # use mixed precision only on GPU
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.amp)
        self.ema = ModelEMA(self.model)

    def train_one_epoch(self, epoch):
        self.model.train()
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                    desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)

        mloss = torch.zeros(3, device=self.device)  # running mean of (box, cls, dfl) losses
        self.optimizer.zero_grad()
        nb = len(self.train_loader)  # number of batches per epoch

        for i, batch in pbar:
            ni = i + nb * epoch  # number of batches seen since training start

            imgs, targets, paths = batch
            imgs = imgs.to(self.device, non_blocking=True)
            targets = targets.to(self.device)

            # --- Warmup ---
            if ni <= nb * self.warmup_epochs:
                xp = [0, nb * self.warmup_epochs]
                for j, x in enumerate(self.optimizer.param_groups):
                    lr_target = x['initial_lr'] * self.lf(epoch)
                    x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
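            # Net effect of warmup: the bias group (param_groups[0]) anneals from
            # lr=0.1 toward the scheduled lr, the other groups ramp up from 0,
            # and SGD momentum eases from 0.8 to its final 0.937.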

            # --- Forward ---
            with torch.amp.autocast('cuda', enabled=self.amp):
                preds = self.model(imgs)

                target_batch = {
                    "batch_idx": targets[:, 0],
                    "cls": targets[:, 1],
                    "bboxes": targets[:, 2:],
                }
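                # The code assumes each `targets` row is (batch_idx, cls, x, y, w, h)
                # with xywh normalized to [0, 1]; validate() relies on the same layout
                # when it rescales ground-truth boxes by image width/height.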

                loss, loss_items = self.loss_fn(preds, target_batch)

            # --- Backward ---
            self.scaler.scale(loss).backward()
            # Unscale first so clipping operates on true (unscaled) gradients
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

            # --- EMA ---
            self.ema.update(self.model)

            # --- Logging ---
            loss_items = loss_items.detach()
            mloss = (mloss * i + loss_items) / (i + 1)
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'

            pbar.set_postfix({
                'mem': mem,
                'box': f"{mloss[0]:.4f}",
                'cls': f"{mloss[1]:.4f}",
                'dfl': f"{mloss[2]:.4f}",
                'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
            })

    def validate(self):
        model = self.ema.ema  # validate with the EMA weights
        device = self.device

        # --- Metrics Config ---
        conf_thres = 0.001  # low threshold for mAP calculation
        iou_thres = 0.7  # NMS IoU threshold
        iouv = torch.linspace(0.5, 0.95, 10, device=device)  # IoU vector for mAP@0.5:0.95

        stats = []  # [(correct, conf, pred_cls, target_cls)]

        LOGGER.info("\nValidating...")

        model.eval()

        pbar = tqdm(self.val_loader, desc="Calc Metrics")
        with torch.no_grad():
            for batch in pbar:
                imgs, targets, _ = batch
                imgs = imgs.to(device, non_blocking=True)
                targets = targets.to(device)
                _, _, height, width = imgs.shape

                # Inference
                preds = model(imgs)

                # NMS
                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
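                # After NMS, preds is a per-image list of [N, 6] tensors laid out
                # as (x1, y1, x2, y2, conf, cls), matching the indexing below.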

                # Metrics Processing
                for si, pred in enumerate(preds):
                    labels = targets[targets[:, 0] == si]
                    nl = len(labels)
                    tcls = labels[:, 1].tolist() if nl else []

                    if len(pred) == 0:
                        if nl:
                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
                        continue

                    # Predictions
                    predn = pred.clone()

                    # Ground Truth
                    if nl:
                        tbox = xywh2xyxy(labels[:, 2:6])
                        tbox[:, [0, 2]] *= width
                        tbox[:, [1, 3]] *= height
                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)  # [cls, x1, y1, x2, y2]

                        # Match predictions to GT
                        correct = self._process_batch(predn, labelsn, iouv)
                    else:
                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)

                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))

        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
        if len(stats) and stats[0][0].any():
            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()

        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f}, mAP50={map50:.3f}, mAP50-95={map5095:.3f}")

        return map50

    def _process_batch(self, detections, labels, iouv):
        """
        Return correct prediction matrix
        detections: [N, 6] (x1, y1, x2, y2, conf, cls)
        labels: [M, 5] (cls, x1, y1, x2, y2)
        """
        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
        iou = box_iou(labels[:, 1:], detections[:, :4])

        # candidate matches: IoU above the lowest threshold AND classes agree
        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))

        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy 1-to-1 assignment: sort by IoU, then keep each detection
                # and each ground-truth label at most once
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            matches = torch.from_numpy(matches).to(iouv.device)
            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
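        # correct[i, j] is True when detection i matched a same-class GT box
        # with IoU >= iouv[j], i.e. one column per mAP IoU threshold.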

        return correct

    def train(self):
        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
        start_time = time.time()
        best_fitness = 0.0

        for epoch in range(self.start_epoch, self.epochs):
            self.train_one_epoch(epoch)
            self.scheduler.step()
            map50 = self.validate()

            ckpt = {
                'epoch': epoch,
                'model': self.model.state_dict(),
                'ema': self.ema.ema.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }

            torch.save(ckpt, self.save_dir / 'last.pt')

            if map50 > best_fitness:
                best_fitness = map50
                torch.save(ckpt, self.save_dir / 'best.pt')
                LOGGER.info(f"--> Saved best model with mAP@0.5: {best_fitness:.4f}")

        LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")
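

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `YOLOModel`, `DetectionLoss` and
# `create_dataloader` are hypothetical stand-ins for this project's own model,
# loss, and data modules; swap in the real ones before running.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    from model import YOLOModel          # hypothetical import
    from loss import DetectionLoss       # hypothetical import
    from data import create_dataloader   # hypothetical import

    model = YOLOModel(num_classes=80)
    train_loader = create_dataloader("data/train", batch_size=16, train=True)
    val_loader = create_dataloader("data/val", batch_size=16, train=False)

    trainer = YOLO11Trainer(model, train_loader, val_loader,
                            loss_fn=DetectionLoss(model), epochs=100)
    trainer.train()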