# 448 lines · 20 KiB · Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
import numpy as np
|
|
|
|
# ==============================================================================
|
|
# 1. Basic utility functions (Utils)
|
|
# ==============================================================================
|
|
|
|
def xywh2xyxy(x):
    """Convert boxes from centre format [x, y, w, h] to corner format [x1, y1, x2, y2].

    Operates on the last dimension of a ``torch.Tensor`` or NumPy array. The input is
    copied first (``clone`` / ``np.copy``), so the caller's data is never modified.
    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w  # left edge
    out[..., 1] = x[..., 1] - half_h  # top edge
    out[..., 2] = x[..., 0] + half_w  # right edge
    out[..., 3] = x[..., 1] + half_h  # bottom edge
    return out
|
|
|
|
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate IoU (or GIoU / DIoU / CIoU) between ``box1`` and ``box2`` element-wise.

    Args:
        box1 (torch.Tensor): Boxes of shape (..., 4), e.g. predicted boxes (N, 4).
        box2 (torch.Tensor): Boxes of shape (..., 4), broadcastable against ``box1``.
        xywh (bool): If True, boxes are (x_center, y_center, w, h). If False, (x1, y1, x2, y2).
        GIoU (bool): If True, return Generalized IoU.
        DIoU (bool): If True, return Distance IoU.
        CIoU (bool): If True, return Complete IoU (takes precedence over DIoU and GIoU).
        eps (float): Small value added to box heights and denominators to avoid
            division by zero.

    Returns:
        torch.Tensor: Shape (..., 1); IoU or the selected variant for each box pair.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
        # Pad heights by eps exactly like the xyxy branch below; otherwise a zero-size
        # box makes the CIoU aspect term atan(w / h) evaluate 0 / 0 = NaN.
        h1, h2 = h1 + eps, h2 + eps
    else:  # x1, y1, x2, y2
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
    ).clamp(0)

    # Union area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                # alpha is a trade-off weight and is treated as a constant (no gradient).
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU
    return iou
|
|
|
|
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchor centres and per-anchor strides for a set of feature maps.

    Args:
        feats: Sequence of feature tensors; the spatial size (h, w) is read from
            ``shape[2:]``. dtype and device are taken from ``feats[0]``.
        strides: Stride for each level; ``strides[i]`` pairs with ``feats[i]``.
        grid_cell_offset (float): Offset added to integer grid indices
            (0.5 places anchors at cell centres).

    Returns:
        tuple: ``(anchor_points, stride_tensor)``. ``anchor_points`` has shape
        (sum(h*w), 2) and holds (x, y) in grid units in row-major order per level.
        ``stride_tensor`` has shape (sum(h*w), 1) and holds each anchor's stride.
    """
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    all_points, all_strides = [], []
    for level, stride in enumerate(strides):
        height, width = feats[level].shape[2:]
        xs = torch.arange(end=width, device=device, dtype=dtype) + grid_cell_offset
        ys = torch.arange(end=height, device=device, dtype=dtype) + grid_cell_offset
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        all_points.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        all_strides.append(torch.full((height * width, 1), stride, dtype=dtype, device=device))
    return torch.cat(all_points), torch.cat(all_strides)
|
|
|
|
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Convert (left, top, right, bottom) distances around anchor points into boxes.

    Args:
        distance: Tensor whose ``dim`` axis holds [l, t, r, b].
        anchor_points: Anchor (x, y) coordinates, broadcastable against each half of ``distance``.
        xywh (bool): Return [cx, cy, w, h] if True, otherwise [x1, y1, x2, y2].
        dim (int): Axis along which the coordinates are stored.
    """
    left_top, right_bottom = distance.chunk(2, dim)
    corner_min = anchor_points - left_top
    corner_max = anchor_points + right_bottom
    if not xywh:
        return torch.cat((corner_min, corner_max), dim)
    centre = (corner_min + corner_max) / 2
    size = corner_max - corner_min
    return torch.cat([centre, size], dim)
|
|
|
|
def bbox2dist(anchor_points, bbox, reg_max):
    """Convert xyxy boxes into (left, top, right, bottom) distances from anchor points.

    The distances are clamped to [0, reg_max - 0.01], so the upper end stays strictly
    below ``reg_max``.
    """
    top_left, bottom_right = bbox.chunk(2, -1)
    ltrb = torch.cat((anchor_points - top_left, bottom_right - anchor_points), -1)
    return ltrb.clamp_(0, reg_max - 0.01)
|
|
|
|
|
|
class TaskAlignedAssigner(nn.Module):
    """Task-aligned assigner: matches ground-truth boxes to anchor points.

    For each GT box, anchors whose centres fall inside the box are scored with the
    alignment metric ``score ** alpha * CIoU ** beta``. The ``topk`` highest-scoring
    anchors become positives. When an anchor is positive for several GT boxes, only
    the GT with the largest overlap is kept.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """
        Args:
            topk (int): Number of candidate anchors kept per GT box.
            num_classes (int): Number of classes; also the label given to background
                anchors when there is no GT at all.
            alpha (float): Exponent applied to the classification score in the metric.
            beta (float): Exponent applied to the IoU in the metric.
            eps (float): Small constant for numeric stability.
        """
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """Assign a target to every anchor.

        Args:
            pd_scores (Tensor): (bs, num_anchors, num_classes) predicted class scores.
            pd_bboxes (Tensor): (bs, num_anchors, 4) predicted boxes, xyxy.
            anc_points (Tensor): (num_anchors, 2) anchor centres, same units as the boxes.
            gt_labels (Tensor): (bs, n_max_boxes, 1) GT class indices.
            gt_bboxes (Tensor): (bs, n_max_boxes, 4) GT boxes, xyxy.
            mask_gt (Tensor): (bs, n_max_boxes, 1) 1 for real GT rows, 0 for padding.

        Returns:
            target_labels (Tensor): (bs, num_anchors) long class index per anchor.
            target_bboxes (Tensor): (bs, num_anchors, 4) assigned GT box per anchor.
            target_scores (Tensor): (bs, num_anchors, num_classes) one-hot targets scaled
                by the normalized alignment metric; zero for background anchors.
            fg_mask (Tensor): (bs, num_anchors) bool, True for positive anchors.
            target_gt_idx (Tensor): (bs, num_anchors) long index of the assigned GT row.
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            # No GT rows in the batch: every anchor is background. Return the same dtypes
            # as the regular path (long labels/indices, bool mask) so callers can index
            # with fg_mask / target_gt_idx without special-casing this branch.
            per_anchor = pd_scores[..., 0]
            return (
                torch.full_like(per_anchor, self.num_classes, dtype=torch.long),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(per_anchor, dtype=torch.bool),
                torch.zeros_like(per_anchor, dtype=torch.long),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: rescale each GT's metric so that its best positive anchor receives
        # that GT's best positive IoU, then take the max over GTs for each anchor.
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # (bs, n_max_boxes, 1)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # (bs, n_max_boxes, 1)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Return ``(mask_pos, align_metric, overlaps)``, each (bs, n_max_boxes, num_anchors).

        ``mask_pos`` is non-zero where an anchor lies inside a real GT box and is among that
        GT's top-k anchors by alignment metric.
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        mask_pos = mask_topk * mask_in_gts * mask_gt
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Compute alignment metric and CIoU for each (GT, anchor) pair selected by ``mask_gt``.

        Pairs outside ``mask_gt`` get 0 for both outputs.

        Returns:
            align_metric, overlaps: each (bs, n_max_boxes, num_anchors).
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # (bs, n_max_boxes, na)
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        # ind[0]: batch index, ind[1]: GT class index, both (bs, n_max_boxes). Built on the
        # score tensor's device so the gather below does not need CPU->device index copies.
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long, device=pd_scores.device)
        ind[0] = torch.arange(end=self.bs, device=pd_scores.device).view(-1, 1).expand(-1, self.n_max_boxes)
        ind[1] = gt_labels.squeeze(-1)

        # Keep the class gather in range if a label >= num_classes slips through.
        ind[1] = ind[1].clamp(max=self.num_classes - 1)

        # Predicted score of each GT's class at every selected anchor.
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def select_topk_candidates(self, metrics, topk_mask=None):
        """Mark the top-k anchors per GT by ``metrics``.

        Args:
            metrics (Tensor): (bs, n_max_boxes, num_anchors) ranking metric.
            topk_mask (Tensor | None): bool mask over the top-k slots; when None, slots of a
                GT whose best metric is <= eps are disabled.

        Returns:
            Tensor of ``metrics.dtype``, (bs, n_max_boxes, num_anchors), 1 at selected anchors.
        """
        # Never ask torch.topk for more elements than there are anchors.
        k = min(self.topk, metrics.shape[-1])
        topk_metrics, topk_idxs = torch.topk(metrics, k, dim=-1, largest=True)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        else:
            topk_mask = topk_mask[..., :k]
        # Disabled slots all point at anchor 0; the "> 1" filter below removes duplicates.
        topk_idxs.masked_fill_(~topk_mask, 0)

        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for slot in range(k):
            count_tensor.scatter_add_(-1, topk_idxs[:, :, slot : slot + 1], ones)
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """Gather per-anchor labels, boxes and one-hot scores from the assigned GT index.

        Returns:
            target_labels (bs, na) long, target_bboxes (bs, na, 4),
            target_scores (bs, na, num_classes) int64 one-hot, zeroed where ``fg_mask`` is 0.
        """
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        # Offset into the flattened (bs * n_max_boxes) GT rows.
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
        target_labels = gt_labels.long().flatten()[target_gt_idx]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
        target_labels.clamp_(0)

        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """Return a (bs, n_boxes, n_anchors) float 0/1 mask of anchor centres strictly inside each xyxy GT box."""
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """Resolve anchors claimed by several GTs, keeping the GT with the highest overlap.

        Returns:
            target_gt_idx (bs, na) long, fg_mask (bs, na) count of positives per anchor,
            mask_pos (bs, n_max_boxes, na) updated positive mask.
        """
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)
        return target_gt_idx, fg_mask, mask_pos
|
|
|
|
# ==============================================================================
|
|
# 3. Loss modules (from loss.py)
|
|
# ==============================================================================
|
|
|
|
class DFLoss(nn.Module):
    """Distribution Focal Loss (DFL).

    Each continuous target value is split between its two neighbouring integer bins
    (``floor`` and ``floor + 1``). The loss is the cross-entropy of the predicted
    distribution against each bin, weighted by linear interpolation.
    """

    def __init__(self, reg_max: int = 16) -> None:
        """
        Args:
            reg_max: Number of distribution bins per target value.
        """
        super().__init__()
        self.reg_max = reg_max

    def forward(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute DFL, averaged over the last dimension of ``target``.

        Defined as ``forward`` (not ``__call__``) so ``nn.Module`` hooks keep working.

        Args:
            pred_dist: Logits of shape (target.numel(), reg_max), one row per target value.
            target: Continuous targets of shape (..., k); in this file k = 4 (ltrb distances).

        Returns:
            Tensor of shape (..., 1).
        """
        # Out-of-place clamp: the caller's tensor must not be modified. The upper bound stays
        # strictly below reg_max - 1 so the right-hand bin (tl + 1) is a valid class index.
        target = target.clamp(0, self.reg_max - 1 - 0.01)
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)
|
|
|
|
class BboxLoss(nn.Module):
    """Box regression criterion: CIoU loss plus, when enabled, Distribution Focal Loss."""

    def __init__(self, reg_max: int = 16):
        """
        Args:
            reg_max: DFL bins per box side; DFL is disabled (``dfl_loss is None``) when ``reg_max <= 1``.
        """
        super().__init__()
        self.dfl_loss = None if reg_max <= 1 else DFLoss(reg_max)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Return ``(loss_iou, loss_dfl)``, both normalized by ``target_scores_sum``.

        Only anchors selected by ``fg_mask`` contribute. Each one is weighted by the sum of
        its target scores over classes. ``loss_dfl`` is a zero tensor when DFL is disabled.
        """
        box_weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        ciou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - ciou) * box_weight).sum() / target_scores_sum

        if not self.dfl_loss:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        bins = self.dfl_loss.reg_max
        target_ltrb = bbox2dist(anchor_points, target_bboxes, bins - 1)
        per_anchor = self.dfl_loss(pred_dist[fg_mask].view(-1, bins), target_ltrb[fg_mask]) * box_weight
        return loss_iou, per_anchor.sum() / target_scores_sum
|
|
|
|
class YOLOv8DetectionLoss:
    """
    Standalone YOLOv8/v11 detection loss.

    Refactored to take raw head outputs plus a target dict instead of a full Model object.
    Combines a CIoU box loss, a BCE classification loss and (when ``reg_max > 1``) a
    Distribution Focal Loss, with targets assigned by ``TaskAlignedAssigner``.
    """

    def __init__(self, nc=80, reg_max=16, stride=None, hyp=None):
        """
        Args:
            nc (int): Number of classes.
            reg_max (int): DFL bins per box side (default 16); DFL is used only when > 1.
            stride (list/torch.Tensor): Stride of each prediction level, in the same order as
                the feature maps passed to ``__call__`` (default ``[8, 16, 32]``).
            hyp (dict): Loss gains ``'box'``, ``'cls'``, ``'dfl'``. Missing keys fall back to
                the YOLOv8 defaults ``{'box': 7.5, 'cls': 0.5, 'dfl': 1.5}``.
        """
        if stride is None:
            stride = [8, 16, 32]  # Default strides

        self.nc = nc
        self.reg_max = reg_max
        self.stride = stride
        # Default YOLOv8 gains; user-supplied values override individual keys, so a partial
        # dict such as {'box': 5.0} no longer fails later with a KeyError.
        self.hyp = {'box': 7.5, 'cls': 0.5, 'dfl': 1.5, **(hyp or {})}
        self.no = nc + reg_max * 4  # Output channels per anchor
        self.use_dfl = reg_max > 1

        # Components
        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(reg_max)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        # Bin values 0..reg_max-1 used to take the expectation of each DFL distribution;
        # moved to the prediction device/dtype inside bbox_decode.
        self.proj = torch.arange(reg_max, dtype=torch.float)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Pack flat targets into a zero-padded per-image tensor.

        Args:
            targets: (N, 6) rows of ``[image_idx, cls, x, y, w, h]``.
            batch_size (int): Number of images.
            scale_tensor: (4,) multiplier applied to the xywh columns (``__call__`` passes
                the image size as ``[w, h, w, h]``).

        Returns:
            (batch_size, max_targets_per_image, 5) tensor of ``[cls, x1, y1, x2, y2]``;
            rows beyond an image's target count are zero.
        """
        if targets.shape[0] == 0:
            return torch.zeros(batch_size, 0, 5, device=targets.device)

        i = targets[:, 0]  # image index
        _, counts = i.unique(return_counts=True)
        counts = counts.to(dtype=torch.int32)
        out = torch.zeros(batch_size, counts.max(), 5, device=targets.device)
        for j in range(batch_size):
            matches = i == j
            n = matches.sum()
            if n:
                out[j, :n] = targets[matches, 1:]  # cls, x, y, w, h
        out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted boxes (xyxy, grid units) from anchor points and distance predictions.

        With DFL, ``pred_dist`` is (b, anchors, 4 * reg_max) and each side's distribution is
        reduced to its expected value; otherwise it is used directly as ltrb distances.
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Project distributions to scalars using self.proj (0, 1, 2, ..., reg_max-1)
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.to(pred_dist.device).type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """
        Compute the detection loss.

        Args:
            preds: List of prediction tensors [B, C, H, W], one per stride level, in the same
                order as ``self.stride``. C must equal reg_max * 4 + nc.
            batch: Dict containing:
                'batch_idx': Tensor [N], image index for each target
                'cls': Tensor [N], class index for each target
                'bboxes': Tensor [N, 4], normalized xywh format

        Returns:
            total_loss: Scalar tensor ``(box + cls + dfl) * batch_size``, where each term has
                already been multiplied by its gain in ``self.hyp`` (matches Ultralytics).
            loss_items: Detached tensor [box, cls, dfl] of the gained terms, for logging.

        Raises:
            ValueError: If ``batch`` lacks any of 'batch_idx', 'cls', 'bboxes'.
        """
        # Validate the target dict up front; previously only 'batch_idx' was checked and a
        # missing 'cls' / 'bboxes' surfaced later as a bare KeyError.
        missing = [key for key in ("batch_idx", "cls", "bboxes") if key not in batch]
        if missing:
            raise ValueError(f"Batch dict must contain 'batch_idx', 'cls', and 'bboxes' (missing: {missing})")

        self.device = preds[0].device
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl

        feats = preds

        # Concat features from different strides along the anchor axis: (B, no, Total_Anchors)
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()  # (B, Anchors, nc)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()  # (B, Anchors, reg_max * 4)

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]

        # Input image size (h, w) inferred as first feature map size * stride[0];
        # assumes feats[0] is the level that pairs with stride[0].
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

        # Generate anchors (grid units) and the stride of each anchor
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Process targets: move every field to the prediction device before concatenating.
        targets = torch.cat(
            (
                batch["batch_idx"].to(self.device).view(-1, 1),
                batch["cls"].to(self.device).view(-1, 1),
                batch["bboxes"].to(self.device),
            ),
            1,
        )
        # Scale normalized xywh to pixels ([w, h, w, h]) and convert to xyxy.
        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        # Padding rows are all-zero boxes; a positive coordinate sum marks a real target.
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Decode boxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        # Task-aligned assignment (in pixel units)
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # CLS loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # BBox loss (IoU + DFL), computed in grid units
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        # Apply gains
        loss[0] *= self.hyp['box']
        loss[1] *= self.hyp['cls']
        loss[2] *= self.hyp['dfl']

        return loss.sum() * batch_size, loss.detach()  # scaling by batch_size to match Ultralytics behavior
|
|
|
|
# ==============================================================================
|
|
# Usage Example
|
|
# ==============================================================================
|
|
if __name__ == "__main__":
    # Configuration
    num_classes = 80
    reg_max = 16
    strides = [8, 16, 32]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the loss
    criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides)

    # Simulated inputs (batch size = 2)
    bs = 2
    # Simulated model outputs: one feature map per stride level,
    # shape [BS, Channels, H, W] with Channels = reg_max * 4 + num_classes.
    channels = reg_max * 4 + num_classes
    preds = [
        torch.randn(bs, channels, 80, 80, device=device, requires_grad=True),  # stride 8
        torch.randn(bs, channels, 40, 40, device=device, requires_grad=True),  # stride 16
        torch.randn(bs, channels, 20, 20, device=device, requires_grad=True),  # stride 32
    ]

    # Simulated ground truth
    # batch: idx, cls, x, y, w, h (normalized 0-1)
    target_batch = {
        "batch_idx": torch.tensor([0, 0, 1], device=device),  # Image indices
        "cls": torch.tensor([1, 5, 10], device=device),  # Class indices
        "bboxes": torch.tensor(
            [  # Normalized xywh
                [0.5, 0.5, 0.2, 0.3],
                [0.1, 0.1, 0.1, 0.1],
                [0.8, 0.8, 0.2, 0.2],
            ],
            device=device,
        ),
    }

    # Compute the loss
    total_loss, loss_items = criterion(preds, target_batch)

    print(f"Total Loss: {total_loss.item()}")
    print(f"Loss Items (Box, Cls, DFL): {loss_items}")

    # Backward-pass check
    total_loss.backward()
    print("Backward pass successful.")