# Yolo-standalone/loss.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# ==============================================================================
# 1. Utility Functions
# ==============================================================================
def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

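# Worked example (illustration only, not part of the module):
#   xywh2xyxy(torch.tensor([50.0, 50.0, 20.0, 10.0]))
#   -> tensor([40., 45., 60., 55.])   # x1 = 50 - 20/2, y1 = 50 - 10/2, ...
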
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1 to box2 (broadcastable shapes).

    Args:
        box1 (torch.Tensor): predicted, shape (N, 4)
        box2 (torch.Tensor): target, shape (N, 4)
        xywh (bool): If True, input boxes are (x, y, w, h). If False, (x1, y1, x2, y2).
        GIoU (bool): If True, calculate Generalized IoU.
        DIoU (bool): If True, calculate Distance IoU.
        CIoU (bool): If True, calculate Complete IoU.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU
    return iou

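# For reference, the quantities above combine as (CIoU paper, arXiv:1911.08287):
#   CIoU = IoU - rho^2 / c^2 - alpha * v
# where rho^2 is the squared distance between box centers, c^2 the squared
# diagonal of the smallest enclosing box, and
#   v = (4 / pi^2) * (atan(w2 / h2) - atan(w1 / h1))^2
# penalizes aspect-ratio mismatch.
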
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchor points and per-anchor strides from a list of feature maps."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        h, w = feats[i].shape[2:]
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

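# Illustrative shapes (a sketch, assuming a 640x640 input and strides [8, 16, 32]):
# the 80x80, 40x40 and 20x20 maps yield 6400 + 1600 + 400 = 8400 anchor points,
# each at a cell center such as (0.5, 0.5), (1.5, 0.5), ... in grid units.
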
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance (ltrb) to box (xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat([c_xy, wh], dim)
    return torch.cat((x1y1, x2y2), dim)

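# Worked example (for intuition only): anchor (10, 10) with ltrb distances
# (2, 3, 4, 5) decodes to xyxy (10-2, 10-3, 10+4, 10+5) = (8, 7, 14, 15).
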
def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox (xyxy) to dist (ltrb), clamped to [0, reg_max - 0.01]."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)

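# This is the inverse of dist2bbox; the clamp keeps DFL targets strictly below
# the last bin index so that target.long() + 1 remains a valid class for the
# cross-entropy terms in DFLoss below.
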
# ==============================================================================
# 2. Label Assignment (TaskAlignedAssigner)
# ==============================================================================
class TaskAlignedAssigner(nn.Module):
    """Task-aligned one-to-many assigner (TOOD-style), as used by YOLOv8/v11."""

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]
        if self.n_max_boxes == 0:
            # No ground truth in the whole batch: everything is background.
            return (
                torch.full_like(pd_scores[..., 0], self.num_classes),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0]),
                torch.zeros_like(pd_scores[..., 0]),
            )
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )
        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize target scores by each GT's best alignment metric and overlap
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric
        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        mask_pos = mask_topk * mask_in_gts * mask_gt
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long, device=gt_labels.device)
        ind[0] = torch.arange(end=self.bs, device=gt_labels.device).view(-1, 1).expand(-1, self.n_max_boxes)
        ind[1] = gt_labels.squeeze(-1)
        # Clamp labels to handle cases where no GT exists or background indices slip in
        ind[1] = ind[1].clamp(max=self.num_classes - 1)
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # score of each GT's class at every anchor
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def select_topk_candidates(self, metrics, topk_mask=None):
        # (b, max_num_obj, topk) metrics/indices of the topk anchors per GT
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        topk_idxs.masked_fill_(~topk_mask, 0)
        # Count how often each anchor is selected, yielding a {0, 1} candidate mask
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Anchors counted more than once are duplicates from the masked fill above; drop them
        count_tensor.masked_fill_(count_tensor > 1, 0)
        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        # Flatten batch and GT dimensions so target_gt_idx can index them directly
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
        target_labels = gt_labels.long().flatten()[target_gt_idx]
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # One-hot class targets, zeroed outside the foreground mask
        target_labels.clamp_(0)
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """Return a mask of anchors whose centers fall strictly inside each GT box."""
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """Resolve anchors assigned to multiple GTs by keeping the GT with the highest overlap."""
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # at least one anchor is assigned to multiple GTs
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)
            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)
        return target_gt_idx, fg_mask, mask_pos

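# The assigner's core quantity is the task-aligned metric
#   t = s^alpha * IoU^beta
# where s is the predicted score of the GT class and IoU the CIoU overlap with
# the GT box; the topk anchors per GT under t become positives (see
# get_box_metrics / select_topk_candidates above).
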
# ==============================================================================
# 3. Loss Modules
# ==============================================================================
class DFLoss(nn.Module):
    """Distribution Focal Loss (DFL), https://arxiv.org/abs/2006.04388."""

    def __init__(self, reg_max: int = 16) -> None:
        super().__init__()
        self.reg_max = reg_max

    def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = target.clamp_(0, self.reg_max - 1 - 0.01)
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)

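# Worked example (illustration only): for a target distance of 3.7 the two
# nearest bins are tl = 3 and tr = 4, weighted wl = 0.3 and wr = 0.7, so the
# loss is 0.3 * CE(pred, 3) + 0.7 * CE(pred, 4) -- a soft label whose expected
# bin value regresses to exactly 3.7.
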
class BboxLoss(nn.Module):
    """Criterion for computing bounding-box training losses (CIoU + DFL)."""

    def __init__(self, reg_max: int = 16):
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        # IoU loss, weighted by each positive anchor's (aligned) target score
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
        return loss_iou, loss_dfl

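# Note: `weight` is each positive anchor's summed aligned class score, so
# better-aligned anchors contribute more to both the IoU and DFL terms, and
# both sums share the same target_scores_sum normalizer as the BCE loss below.
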
class YOLOv8DetectionLoss:
    """
    Standalone YOLOv8/v11 detection loss.

    Refactored to take explicit configuration instead of a full Model object.
    """

    def __init__(self, nc=80, reg_max=16, stride=None, hyp=None):
        """
        Args:
            nc (int): Number of classes.
            reg_max (int): DFL channels (default 16).
            stride (list | torch.Tensor): Model strides (e.g., [8, 16, 32]).
            hyp (dict): Hyperparameters (box, cls, dfl gains).
        """
        if stride is None:
            stride = [8, 16, 32]  # default strides
        if hyp is None:
            hyp = {"box": 7.5, "cls": 0.5, "dfl": 1.5}  # default YOLOv8 gains

        self.nc = nc
        self.reg_max = reg_max
        self.stride = stride
        self.hyp = hyp
        self.no = nc + reg_max * 4  # output channels per anchor
        self.use_dfl = reg_max > 1

        # Components
        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(reg_max)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.proj = torch.arange(reg_max, dtype=torch.float)  # moved to the right device in bbox_decode

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocess targets: convert [idx, cls, xywh] rows to a padded (bs, max_gt, 5) tensor in xyxy pixels."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 5, device=targets.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 5, device=targets.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    out[j, :n] = targets[matches, 1:]  # cls, x, y, w, h
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

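    # Shape sketch (using the demo values from __main__ below): with batch_size 2
    # and target rows [[0, 1, .5, .5, .2, .3], [0, 5, .1, .1, .1, .1],
    # [1, 10, .8, .8, .2, .2]], preprocess returns a (2, 2, 5) tensor whose rows
    # are [cls, x1, y1, x2, y2] in pixels, zero-padded where an image has fewer GTs.
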
    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted bounding boxes from anchor points and the DFL distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Project each side's distribution onto the bin values 0..reg_max-1 (expected value)
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(
                self.proj.to(pred_dist.device).type(pred_dist.dtype)
            )
        return dist2bbox(pred_dist, anchor_points, xywh=False)

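    # The projection above computes, per box side, E[d] = sum_i i * softmax(logits)_i
    # over the reg_max bins -- e.g. with reg_max = 16, a distribution peaked
    # between bins 3 and 4 decodes to a fractional distance such as 3.7.
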
    def __call__(self, preds, batch):
        """
        Args:
            preds: List of prediction tensors [B, C, H, W], one per stride level,
                with C = reg_max * 4 + nc.
            batch: Dict containing:
                'batch_idx': Tensor [N], image index of each target
                'cls': Tensor [N], class index of each target
                'bboxes': Tensor [N, 4], normalized xywh

        Returns:
            total_loss: scalar tensor, (box + cls + dfl) scaled by batch size,
                matching the Ultralytics convention of returning loss * bs.
            loss_items: detached tensor [box, cls, dfl] for logging.
        """
        self.device = preds[0].device
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds

        # Concatenate features from all strides: (B, no, total_anchors), then split
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()  # (B, anchors, nc)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()  # (B, anchors, reg_max * 4)

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        # Infer the input image size from the first feature map and its stride.
        # Ultralytics passes imgsz explicitly; assuming H * stride[0] = image height here.
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # (h, w)

        # Generate anchors
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Process targets
        if "batch_idx" not in batch:
            raise ValueError("Batch dict must contain 'batch_idx', 'cls', and 'bboxes'")
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        # Scale normalized xywh targets to pixels and convert to xyxy
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Decode boxes (grid units)
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        # Task-aligned assignment (in pixel units, hence the stride multiplication)
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        target_scores_sum = max(target_scores.sum(), 1)

        # CLS loss (BCE against the aligned soft targets)
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # BBox loss (CIoU + DFL), computed back in grid units
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        # Apply gains
        loss[0] *= self.hyp["box"]
        loss[1] *= self.hyp["cls"]
        loss[2] *= self.hyp["dfl"]
        return loss.sum() * batch_size, loss.detach()  # scale by batch size to match Ultralytics behavior

# ==============================================================================
# Usage Example
# ==============================================================================
if __name__ == "__main__":
    # Configuration
    num_classes = 80
    reg_max = 16
    strides = [8, 16, 32]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the loss
    criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides)

    # Simulated model output (batch size = 2): feature maps at 3 levels,
    # each [BS, Channels, H, W] with Channels = reg_max * 4 + num_classes
    bs = 2
    channels = reg_max * 4 + num_classes
    preds = [
        torch.randn(bs, channels, 80, 80, device=device, requires_grad=True),  # stride 8
        torch.randn(bs, channels, 40, 40, device=device, requires_grad=True),  # stride 16
        torch.randn(bs, channels, 20, 20, device=device, requires_grad=True),  # stride 32
    ]

    # Simulated ground truth: idx, cls, x, y, w, h (normalized 0-1)
    target_batch = {
        "batch_idx": torch.tensor([0, 0, 1], device=device),  # image indices
        "cls": torch.tensor([1, 5, 10], device=device),  # class indices
        "bboxes": torch.tensor(
            [  # normalized xywh
                [0.5, 0.5, 0.2, 0.3],
                [0.1, 0.1, 0.1, 0.1],
                [0.8, 0.8, 0.2, 0.2],
            ],
            device=device,
        ),
    }

    # Compute the loss
    total_loss, loss_items = criterion(preds, target_batch)
    print(f"Total Loss: {total_loss.item()}")
    print(f"Loss Items (Box, Cls, DFL): {loss_items}")

    # Backward-pass smoke test
    total_loss.backward()
    print("Backward pass successful.")