import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==============================================================================
# 1. Basic utility functions (Utils)
# ==============================================================================


def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] (centre, size) to [x1, y1, x2, y2] (top-left, bottom-right).

    Accepts a ``torch.Tensor`` or ``np.ndarray``; only the last axis is interpreted as box
    coordinates, so leading dimensions are preserved. The input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate IoU (or GIoU / DIoU / CIoU) between box1 and box2.

    Coordinates are split along the last axis, so box1 and box2 may have any
    broadcast-compatible leading shape, e.g. (N, 4) vs (N, 4) or (1, 4) vs (N, 4).

    Args:
        box1 (torch.Tensor): First boxes, shape (..., 4).
        box2 (torch.Tensor): Second boxes, shape (..., 4).
        xywh (bool): If True, boxes are (x_center, y_center, w, h); otherwise (x1, y1, x2, y2).
        GIoU (bool): If True, return Generalized IoU.
        DIoU (bool): If True, return Distance IoU.
        CIoU (bool): If True, return Complete IoU. Precedence is CIoU > DIoU > GIoU.
        eps (float): Small constant guarding divisions.

    Returns:
        torch.Tensor: IoU values of shape (..., 1) (the coordinate axis is kept with size 1).
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
    ).clamp(0)

    # Union area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    # alpha is a trade-off weight; it is treated as a constant for gradients.
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU
    return iou


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchor-point centres and per-anchor strides for a list of feature maps.

    Args:
        feats (list[torch.Tensor]): Feature maps; only ``shape[2:]`` (H, W), dtype and device are read.
        strides (Sequence): One stride per feature map, in the same order as ``feats``.
        grid_cell_offset (float): Offset added to integer grid coordinates (0.5 = cell centre).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Anchor points of shape (sum(H*W), 2) as (x, y) in
        grid units of their own level, and strides of shape (sum(H*W), 1).
    """
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        h, w = feats[i].shape[2:]
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform (left, top, right, bottom) distances from anchor points into boxes.

    Returns xywh boxes when ``xywh`` is True, otherwise xyxy boxes.
    """
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat([c_xy, wh], dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


def bbox2dist(anchor_points, bbox, reg_max):
    """Transform xyxy boxes into (left, top, right, bottom) distances from anchor points.

    Distances are clamped to [0, reg_max - 0.01] so they stay inside the DFL bin range.
    """
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)


# ==============================================================================
# 2. Label assignment (Task-Aligned Assigner)
# ==============================================================================


class TaskAlignedAssigner(nn.Module):
    """Assign ground-truth boxes to anchors using the task-aligned metric score^alpha * IoU^beta."""

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute assignment targets.

        Args:
            pd_scores (torch.Tensor): Predicted class probabilities, (bs, num_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted xyxy boxes, (bs, num_anchors, 4).
            anc_points (torch.Tensor): Anchor centres, (num_anchors, 2).
            gt_labels (torch.Tensor): GT classes, (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): GT xyxy boxes, (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): 1 for real (non-padding) GT rows, (bs, n_max_boxes, 1).

        Returns:
            tuple: target_labels (bs, num_anchors) long, target_bboxes (bs, num_anchors, 4),
            target_scores (bs, num_anchors, num_classes), fg_mask (bs, num_anchors) bool,
            target_gt_idx (bs, num_anchors) long.
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            # No GT in the whole batch: everything is background. Dtypes match the normal
            # return path (long labels/indices, bool foreground mask).
            return (
                torch.full_like(pd_scores[..., 0], self.num_classes, dtype=torch.long),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.bool),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.long),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: rescale each positive's align metric so the best anchor of every GT
        # gets that GT's best IoU, then use it as the soft classification target.
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Return (mask_pos, align_metric, overlaps), each (bs, n_max_boxes, num_anchors).

        An anchor is a positive candidate for a GT when its centre lies inside the GT box,
        it is among that GT's top-k by align metric, and the GT row is real.
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        mask_pos = mask_topk * mask_in_gts * mask_gt
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Compute align metric and CIoU overlaps for the (GT, anchor) pairs selected by ``mask_gt``.

        ``mask_gt`` here is the combined in-box * valid-GT mask, (bs, n_max_boxes, num_anchors).
        Unselected pairs keep a metric and overlap of 0.
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        # ind[0] = image index, ind[1] = GT class, used to gather each GT's class score per anchor.
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        # Clamp labels so padded/background label values cannot index out of range.
        ind[1] = ind[1].clamp(max=self.num_classes - 1)
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def select_topk_candidates(self, metrics, topk_mask=None):
        """Return a 0/1 mask (same shape as ``metrics``) of the top-k anchors per GT.

        Top-k indices of masked-out GTs are redirected to anchor 0; they then hit that
        anchor more than once and are removed by the ``> 1`` filter.
        """
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        topk_idxs.masked_fill_(~topk_mask, 0)

        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        count_tensor.masked_fill_(count_tensor > 1, 0)  # filter invalid (redirected) boxes
        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """Gather per-anchor labels, boxes and one-hot scores from the assigned GT index.

        Background anchors (fg_mask == 0) get an all-zero score row.
        """
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # flat index into (b * max_num_obj)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]  # (b, h*w, 4)

        target_labels.clamp_(0)
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """Return a (bs, n_boxes, n_anchors) mask of anchor centres strictly inside each GT box."""
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """Resolve anchors assigned to several GTs by keeping the GT with the highest overlap.

        Returns:
            tuple: target_gt_idx (bs, num_anchors), fg_mask (bs, num_anchors), mask_pos
            (bs, n_max_boxes, num_anchors).
        """
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)
        return target_gt_idx, fg_mask, mask_pos


# ==============================================================================
# 3. Loss modules (Loss Modules - from loss.py)
# ==============================================================================


class DFLoss(nn.Module):
    """Distribution Focal Loss (DFL).

    Each continuous distance target is split between its two neighbouring integer bins;
    the loss is the cross-entropy against both bins, weighted by proximity.
    """

    def __init__(self, reg_max: int = 16) -> None:
        super().__init__()
        self.reg_max = reg_max

    def forward(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the DFL averaged over the last axis of ``target``.

        Args:
            pred_dist: Bin logits of shape (target.numel(), reg_max).
            target: Continuous distances; clamped into [0, reg_max - 1.01] on a copy,
                so the caller's tensor is not modified.

        Returns:
            Tensor of shape ``target.shape[:-1] + (1,)``.
        """
        target = target.clamp(0, self.reg_max - 1 - 0.01)
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)


class BboxLoss(nn.Module):
    """Criterion for computing training losses for bounding boxes (IoU + DFL)."""

    def __init__(self, reg_max: int = 16):
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Return (loss_iou, loss_dfl), each weighted by the summed target score of foreground anchors."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        if self.dfl_loss is not None:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0, device=pred_dist.device)

        return loss_iou, loss_dfl


class YOLOv8DetectionLoss:
    """
    Standalone YOLOv8/v11 detection loss.

    Refactored to take raw multi-scale head outputs plus a target dict, without relying
    on a full model object.
    """

    # Default YOLOv8 loss gains; a user-supplied dict is merged over these.
    DEFAULT_HYP = {"box": 7.5, "cls": 0.5, "dfl": 1.5}

    def __init__(self, nc=80, reg_max=16, stride=None, hyp=None):
        """
        Args:
            nc (int): Number of classes.
            reg_max (int): DFL bins per box side (default 16); 1 disables DFL.
            stride (list | torch.Tensor): Model strides (e.g. [8, 16, 32]), one per head level.
            hyp (dict): Loss gains with keys 'box', 'cls', 'dfl'. Missing keys fall back to
                ``DEFAULT_HYP``; non-dict objects are stored unchanged.
        """
        if stride is None:
            stride = [8, 16, 32]  # default strides
        if hyp is None:
            hyp = dict(self.DEFAULT_HYP)
        elif isinstance(hyp, dict):
            hyp = {**self.DEFAULT_HYP, **hyp}

        self.nc = nc
        self.reg_max = reg_max
        self.stride = stride
        self.hyp = hyp
        self.no = nc + reg_max * 4  # output channels per anchor
        self.use_dfl = reg_max > 1

        # Components
        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(reg_max)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        # Bin values 0..reg_max-1; moved to the prediction's device/dtype in bbox_decode.
        self.proj = torch.arange(reg_max, dtype=torch.float)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Group flat [idx, cls, x, y, w, h] rows into a zero-padded (batch_size, max_targets, 5) tensor.

        Each row of the result is [cls, x1, y1, x2, y2], with xywh scaled by ``scale_tensor``
        and converted to xyxy. Padding rows stay all-zero.
        """
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 5, device=targets.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 5, device=targets.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    out[j, :n] = targets[matches, 1:]  # cls, x, y, w, h
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted xyxy boxes (grid units) from anchor points and predicted distances."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Expected value of each side's softmax distribution over bins 0..reg_max-1.
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.to(pred_dist.device).type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """
        Compute the detection loss.

        Args:
            preds (list[torch.Tensor]): One raw head output per stride level, each of shape
                (B, reg_max * 4 + nc, H_i, W_i), ordered like ``self.stride``.
            batch (dict): Targets with keys
                'batch_idx': Tensor [N], image index of each target;
                'cls': Tensor [N], class index of each target (integer or float);
                'bboxes': Tensor [N, 4], normalized (0-1) xywh boxes.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                total_loss: scalar, (box + cls + dfl) * batch_size, matching Ultralytics.
                loss_items: detached [box, cls, dfl] after gains (not scaled by batch size),
                    for logging.

        Raises:
            ValueError: If ``batch`` has no 'batch_idx' key.
        """
        self.device = preds[0].device
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds

        # Concat features from all strides: (B, no, total_anchors), then split channels.
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()  # (B, anchors, nc)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()  # (B, anchors, reg_max * 4)

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        # Image size (H, W) inferred as first-level feature size * first stride.
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

        # Generate anchors
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Process targets
        if "batch_idx" not in batch:
            raise ValueError("Batch dict must contain 'batch_idx', 'cls', and 'bboxes'")
        # Cast every field to float on the prediction device before concatenating, so integer
        # 'batch_idx'/'cls' tensors (or fields on another device) do not break torch.cat.
        fields = (batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"])
        targets = torch.cat([t.to(device=self.device, dtype=torch.float32) for t in fields], 1)
        # Scale normalized xywh by (W, H, W, H) and convert to pixel xyxy.
        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)  # real rows vs zero padding

        # Decode boxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4), grid units

        # Task-aligned assignment (in pixel units)
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Classification loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # Box loss (IoU + DFL), computed in grid units
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        # Apply gains
        loss[0] *= self.hyp["box"]
        loss[1] *= self.hyp["cls"]
        loss[2] *= self.hyp["dfl"]

        return loss.sum() * batch_size, loss.detach()  # scaled by batch_size to match Ultralytics


# ==============================================================================
# Usage Example
# ==============================================================================
if __name__ == "__main__":
    # Configuration
    num_classes = 80
    reg_max = 16
    strides = [8, 16, 32]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the loss
    criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides)

    # Simulated input (batch size = 2)
    bs = 2
    # Simulated model output: feature maps for 3 levels
    # [BS, Channels, H, W], Channels = reg_max * 4 + num_classes
    channels = reg_max * 4 + num_classes
    preds = [
        torch.randn(bs, channels, 80, 80, device=device, requires_grad=True),  # stride 8
        torch.randn(bs, channels, 40, 40, device=device, requires_grad=True),  # stride 16
        torch.randn(bs, channels, 20, 20, device=device, requires_grad=True),  # stride 32
    ]

    # Simulated ground truth
    # batch: idx, cls, x, y, w, h (normalized 0-1)
    target_batch = {
        "batch_idx": torch.tensor([0, 0, 1], device=device),  # image indices
        "cls": torch.tensor([1, 5, 10], device=device),  # class indices
        "bboxes": torch.tensor(
            [  # normalized xywh
                [0.5, 0.5, 0.2, 0.3],
                [0.1, 0.1, 0.1, 0.1],
                [0.8, 0.8, 0.2, 0.2],
            ],
            device=device,
        ),
    }

    # Compute the loss
    total_loss, loss_items = criterion(preds, target_batch)

    print(f"Total Loss: {total_loss.item()}")
    print(f"Loss Items (Box, Cls, DFL): {loss_items}")

    # Backward-pass test
    total_loss.backward()
    print("Backward pass successful.")