import math from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def make_divisible(x, divisor): if isinstance(x, torch.Tensor): return x return math.ceil(x / divisor) * divisor def autopad(k, p=None, d=1): if d > 1: k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k] return p def make_anchors(feats, strides, grid_cell_offset=0.5): anchor_points, stride_tensor = [], [] assert feats is not None dtype, device = feats[0].dtype, feats[0].device for i, stride in enumerate(strides): _, _, h, w = feats[i].shape sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset sy, sx = torch.meshgrid(sy, sx, indexing="ij") anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) return torch.cat(anchor_points), torch.cat(stride_tensor) def dist2bbox(distance, anchor_points, xywh=True, dim=-1): lt, rb = distance.chunk(2, dim) x1y1 = anchor_points - lt x2y2 = anchor_points + rb if xywh: c_xy = (x1y1 + x2y2) / 2 wh = x2y2 - x1y1 return torch.cat((c_xy, wh), dim) return torch.cat((x1y1, x2y2), dim) class Concat(nn.Module): def __init__(self, dimension=1): super().__init__() self.d = dimension def forward(self, x: List[torch.Tensor]): return torch.cat(x, self.d) class Conv(nn.Module): default_act = nn.SiLU(inplace=True) def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) # type: ignore self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) def forward_fuse(self, x): return self.act(self.conv(x)) class DWConv(Conv): def __init__(self, c1, c2, k=1, s=1, d=1, act=True): super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) class DFL(nn.Module): def __init__(self, c1: int = 16): super().__init__() self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) x = torch.arange(c1, dtype=torch.float) self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) self.c1 = c1 def forward(self, x: torch.Tensor) -> torch.Tensor: b, _, a = x.shape return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) class Bottleneck(nn.Module): def __init__( self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5 ): super().__init__() c_ = int(c2 * e) self.cv1 = Conv(c1, c_, k[0], 1) self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.add = shortcut and c1 == c2 def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) class C2f(nn.Module): def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5): super().__init__() self.c = int(c2 * e) self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv2 = Conv((2 + n) * self.c, c2, 1) self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) # type: ignore def forward(self, x): chunk_result = self.cv1(x).chunk(2, 1) y = [chunk_result[0], chunk_result[1]] for m_module in self.m: y.append(m_module(y[-1])) return self.cv2(torch.cat(y, 1)) def forward_split(self, x: torch.Tensor) -> torch.Tensor: y = self.cv1(x).split((self.c, self.c), 1) y = [y[0], y[1]] y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) class C3(nn.Module): def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5): super().__init__() c_ = int(c2 * e) self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1) self.cv3 = Conv(2 * c_, c2, 1) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) # type: ignore def forward(self, x: torch.Tensor) -> torch.Tensor: return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) class C3k(C3): def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3): super().__init__(c1, c2, n, shortcut, g, e) c_ = int(c2 * e) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) class C3k2(C2f): def __init__( self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True ): super().__init__(c1, c2, n, shortcut, g, e) self.m = nn.ModuleList( C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n) ) class SPPF(nn.Module): def __init__(self, c1: int, c2: int, k: int = 5): super().__init__() c_ = c1 // 2 self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_ * 4, c2, 1, 1) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) def forward(self, x: torch.Tensor) -> torch.Tensor: y = [self.cv1(x)] y.extend(self.m(y[-1]) for _ in range(3)) return self.cv2(torch.cat(y, 1)) class Attention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5): super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.key_dim = int(self.head_dim * attn_ratio) self.scale = self.key_dim**-0.5 nh_kd = self.key_dim * num_heads h = dim + nh_kd * 2 self.qkv = Conv(dim, h, 1, act=False) self.proj = Conv(dim, dim, 1, act=False) self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape N = H * W qkv = self.qkv(x) q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( [self.key_dim, self.key_dim, self.head_dim], dim=2 ) attn = (q.transpose(-2, -1) @ k) * self.scale attn = attn.softmax(dim=-1) x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) x = self.proj(x) return x class PSABlock(nn.Module): def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None: super().__init__() self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads) self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False)) self.add = shortcut def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(x) if self.add else self.attn(x) x = x + self.ffn(x) if self.add else self.ffn(x) return x class C2PSA(nn.Module): def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5): super().__init__() assert c1 == c2 self.c = int(c1 * e) self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv2 = Conv(2 * self.c, c1, 1) self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))) def forward(self, x: torch.Tensor) -> torch.Tensor: a, b = self.cv1(x).split((self.c, self.c), dim=1) b = self.m(b) return self.cv2(torch.cat((a, b), 1)) class Proto(nn.Module): def __init__(self, c1: int, c_: int = 256, c2: int = 32): super().__init__() self.cv1 = Conv(c1, c_, k=3) self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) self.cv2 = Conv(c_, c_, k=3) self.cv3 = Conv(c_, c2) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.cv3(self.cv2(self.upsample(self.cv1(x)))) class BNContrastiveHead(nn.Module): def __init__(self, embed_dims: int): super().__init__() self.norm = nn.BatchNorm2d(embed_dims) self.bias = nn.Parameter(torch.tensor([-10.0])) self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) def fuse(self): del self.norm del self.bias del self.logit_scale self.forward = self.forward_fuse def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: return x def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: x = self.norm(x) w = F.normalize(w, dim=-1, p=2) x = torch.einsum("bchw,bkc->bkhw", x, w) return x * self.logit_scale.exp() + self.bias class SwiGLUFFN(nn.Module): def __init__(self, gc: int, ec: int, e: int = 4) -> None: super().__init__() self.w12 = nn.Linear(gc, e * ec) self.w3 = nn.Linear(e * ec // 2, ec) def forward(self, x: torch.Tensor) -> torch.Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) class Residual(nn.Module): def __init__(self, m: nn.Module) -> None: super().__init__() self.m = m nn.init.zeros_(self.m.w3.bias) # For models with l scale, please change the initialization to # nn.init.constant_(self.m.w3.weight, 1e-6) nn.init.zeros_(self.m.w3.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.m(x) class SAVPE(nn.Module): def __init__(self, ch: List[int], c3: int, embed: int): super().__init__() self.cv1 = nn.ModuleList( nn.Sequential( Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity() ) for i, x in enumerate(ch) ) self.cv2 = nn.ModuleList( nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()) for i, x in enumerate(ch) ) self.c = 16 self.cv3 = nn.Conv2d(3 * c3, embed, 1) self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1) self.cv5 = nn.Conv2d(1, self.c, 3, padding=1) self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1)) def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor: y = [self.cv2[i](xi) for i, xi in enumerate(x)] y = self.cv4(torch.cat(y, dim=1)) x = [self.cv1[i](xi) for i, xi in enumerate(x)] x = self.cv3(torch.cat(x, dim=1)) B, C, H, W = x.shape # type: ignore Q = vp.shape[1] x = x.view(B, C, -1) # type: ignore y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W) vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W) y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1)) y = y.reshape(B, Q, self.c, -1) vp = vp.reshape(B, Q, 1, -1) score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min score = F.softmax(score, dim=-1).to(y.dtype) aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2) return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2) class Detect(nn.Module): dynamic = False export = False shape = None anchors = torch.empty(0) strides = torch.empty(0) def __init__(self, nc=80, ch=()): super().__init__() self.nc = nc self.nl = len(ch) self.reg_max = 16 self.no = nc + self.reg_max * 4 self.stride = torch.tensor([8., 16., 32.]) c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) self.cv2 = nn.ModuleList( nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch ) self.cv3 = nn.ModuleList( nn.Sequential( nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), nn.Conv2d(c3, self.nc, 1), ) for x in ch ) self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() def forward(self, x): for i in range(self.nl): x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) if self.training: return x # Inference path shape = x[0].shape x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) if self.dynamic or self.shape != shape: self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides return torch.cat((dbox, cls.sigmoid()), 1) def bias_init(self): m = self for a, b, s in zip(m.cv2, m.cv3, m.stride): a[-1].bias.data[:] = 1.0 # type: ignore b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # type: ignore def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor: return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1) @staticmethod def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor: batch_size, anchors, _ = preds.shape boxes, scores = preds.split([4, nc], dim=-1) index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1) boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4)) scores = scores.gather(dim=1, index=index.repeat(1, 1, nc)) scores, index = scores.flatten(1).topk(min(max_det, anchors)) i = torch.arange(batch_size)[..., None] return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1) class YOLOEDetect(Detect): def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()): super().__init__(nc, ch) c3 = max(ch[0], min(self.nc, 100)) assert c3 <= embed self.cv3 = ( nn.ModuleList( nn.Sequential( nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), nn.Conv2d(c3, embed, 1), ) for x in ch ) ) self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch) self.reprta = Residual(SwiGLUFFN(embed, embed)) self.savpe = SAVPE(ch, c3, embed) # type: ignore self.embed = embed def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2) def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor: if vpe.shape[1] == 0: # no visual prompt embeddings return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device) if vpe.ndim == 4: # (B, N, H, W) vpe = self.savpe(x, vpe) assert vpe.ndim == 3 # (B, N, D) return vpe def forward( # type: ignore self, x: List[torch.Tensor], cls_pe: torch.Tensor ) -> Union[torch.Tensor, Tuple]: for i in range(self.nl): x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1) if self.training: return x # type: ignore self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts # Inference path shape = x[0].shape x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) if self.dynamic or self.shape != shape: self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides return torch.cat((dbox, cls.sigmoid()), 1) def bias_init(self): m = self for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): a[-1].bias.data[:] = 1.0 # box b[-1].bias.data[:] = 0.0 c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) class YOLOESegment(YOLOEDetect): def __init__( self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = () ): super().__init__(nc, embed, ch) self.nm = nm self.npr = npr self.proto = Proto(ch[0], self.npr, self.nm) c5 = max(ch[0] // 4, self.nm) self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch) def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]: p = self.proto(x[0]) # mask protos bs = p.shape[0] # batch size mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients x = YOLOEDetect.forward(self, x, text) if self.training: return x, mc, p return (torch.cat([x, mc], 1), p) class YOLO11(nn.Module): def __init__(self, nc=80, scale='n'): super().__init__() self.nc = nc # Scales: [depth, width, max_channels] # 对应 yolo11.yaml 中的 scales 参数 scales = { 'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512], } if scale not in scales: raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}") depth, width, max_channels = scales[scale] if scale in ['n', 's']: c3k_override = False else: c3k_override = True # 辅助函数:计算通道数 (Width Scaling) def gw(channels): return make_divisible(min(channels, max_channels) * width, 8) # 辅助函数:计算层重复次数 (Depth Scaling) def gd(n): return max(round(n * depth), 1) if n > 1 else n self.model = nn.ModuleList() # --- Backbone --- # 0: Conv [64, 3, 2] self.model.append(Conv(3, gw(64), 3, 2)) # 1: Conv [128, 3, 2] self.model.append(Conv(gw(64), gw(128), 3, 2)) # 2: C3k2 [256, False, 0.25] -> n=2 self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25)) # 3: Conv [256, 3, 2] self.model.append(Conv(gw(256), gw(256), 3, 2)) # 4: C3k2 [512, False, 0.25] -> n=2 self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25)) # 5: Conv [512, 3, 2] self.model.append(Conv(gw(512), gw(512), 3, 2)) # 6: C3k2 [512, True] -> n=2 self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True)) # 7: Conv [1024, 3, 2] self.model.append(Conv(gw(512), gw(1024), 3, 2)) # 8: C3k2 [1024, True] -> n=2 self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True)) # 9: SPPF [1024, 5] self.model.append(SPPF(gw(1024), gw(1024), 5)) # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2) self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2))) # --- Head --- # 11: Upsample self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) # 12: Concat [-1, 6] (P4) self.model.append(Concat(dimension=1)) # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512)) self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) # 14: Upsample self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) # 15: Concat [-1, 4] (P3) self.model.append(Concat(dimension=1)) # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512)) # 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512) self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override)) # 17: Conv [256, 3, 2] self.model.append(Conv(gw(256), gw(256), 3, 2)) # 18: Concat [-1, 13] (Head P4) self.model.append(Concat(dimension=1)) # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512)) self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) # 20: Conv [512, 3, 2] self.model.append(Conv(gw(512), gw(512), 3, 2)) # 21: Concat [-1, 10] (P5) self.model.append(Concat(dimension=1)) # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024)) self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True)) # 23: Detect [nc] self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)])) # --- 初始化权重 --- self.initialize_weights() def initialize_weights(self): """初始化模型权重,特别是 Detect 头的 Bias""" for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): # 使用 Kaiming 初始化或其他合适的初始化 pass elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) detect_layer = self.model[-1] if isinstance(detect_layer, Detect): detect_layer.bias_init() def forward(self, x): # Backbone x = self.model[0](x) x = self.model[1](x) x = self.model[2](x) x = self.model[3](x) p3 = self.model[4](x) # 保存 P3 (layer 4) x = self.model[5](p3) p4 = self.model[6](x) # 保存 P4 (layer 6) x = self.model[7](p4) x = self.model[8](x) x = self.model[9](x) p5 = self.model[10](x) # 保存 P5 (layer 10) # Head x = self.model[11](p5) # Upsample x = self.model[12]([x, p4]) # Concat P4 h1 = self.model[13](x) # Head P4 (layer 13) x = self.model[14](h1) # Upsample x = self.model[15]([x, p3]) # Concat P3 h2 = self.model[16](x) # Output P3 (layer 16) x = self.model[17](h2) # Conv x = self.model[18]([x, h1]) # Concat Head P4 h3 = self.model[19](x) # Output P4 (layer 19) x = self.model[20](h3) # Conv x = self.model[21]([x, p5]) # Concat P5 h4 = self.model[22](x) # Output P5 (layer 22) return self.model[23]([h2, h3, h4]) # Detect def load_weights(self, pth_file): state_dict = torch.load(pth_file, map_location='cpu', weights_only=False) # 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方) # 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv... # 这里做一个简单的兼容性处理 new_state_dict = {} for k, v in state_dict.items(): # 处理 ultralytics 权重字典中的 'model' 键 if k == 'model': # 如果是完整的 checkpoint,权重在 'model' 键下 # 且通常是 model.state_dict() if hasattr(v, 'state_dict'): v = v.state_dict() elif isinstance(v, dict): pass # v 就是 state_dict else: # 可能是 model 对象本身 try: v = v.float().state_dict() except: continue for sub_k, sub_v in v.items(): new_state_dict[sub_k] = sub_v break else: new_state_dict[k] = v if not new_state_dict: new_state_dict = state_dict # 尝试加载 try: self.load_state_dict(new_state_dict, strict=True) print(f"Successfully loaded weights from {pth_file}") except Exception as e: print(f"Error loading weights: {e}") print("Trying to load with strict=False...") self.load_state_dict(new_state_dict, strict=False) class YOLO11E(nn.Module): def __init__(self, nc=80, scale='n'): super().__init__() self.nc = nc self.pe = None # Scales: [depth, width, max_channels] # 对应 yolo11.yaml 中的 scales 参数 scales = { 'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512], } if scale not in scales: raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}") depth, width, max_channels = scales[scale] if scale in ['n', 's']: c3k_override = False else: c3k_override = True # 辅助函数:计算通道数 (Width Scaling) def gw(channels): return make_divisible(min(channels, max_channels) * width, 8) # 辅助函数:计算层重复次数 (Depth Scaling) def gd(n): return max(round(n * depth), 1) if n > 1 else n self.model = nn.ModuleList() # --- Backbone --- # 0: Conv [64, 3, 2] self.model.append(Conv(3, gw(64), 3, 2)) # 1: Conv [128, 3, 2] self.model.append(Conv(gw(64), gw(128), 3, 2)) # 2: C3k2 [256, False, 0.25] -> n=2 self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25)) # 3: Conv [256, 3, 2] self.model.append(Conv(gw(256), gw(256), 3, 2)) # 4: C3k2 [512, False, 0.25] -> n=2 self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25)) # 5: Conv [512, 3, 2] self.model.append(Conv(gw(512), gw(512), 3, 2)) # 6: C3k2 [512, True] -> n=2 self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True)) # 7: Conv [1024, 3, 2] self.model.append(Conv(gw(512), gw(1024), 3, 2)) # 8: C3k2 [1024, True] -> n=2 self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True)) # 9: SPPF [1024, 5] self.model.append(SPPF(gw(1024), gw(1024), 5)) # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2) self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2))) # --- Head --- # 11: Upsample self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) # 12: Concat [-1, 6] (P4) self.model.append(Concat(dimension=1)) # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512)) self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) # 14: Upsample self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) # 15: Concat [-1, 4] (P3) self.model.append(Concat(dimension=1)) # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512)) # 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512) self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override)) # 17: Conv [256, 3, 2] self.model.append(Conv(gw(256), gw(256), 3, 2)) # 18: Concat [-1, 13] (Head P4) self.model.append(Concat(dimension=1)) # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512)) self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) # 20: Conv [512, 3, 2] self.model.append(Conv(gw(512), gw(512), 3, 2)) # 21: Concat [-1, 10] (P5) self.model.append(Concat(dimension=1)) # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024)) self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True)) # 23: Detect [nc] self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)])) # --- 初始化权重 --- self.initialize_weights() def initialize_weights(self): """初始化模型权重,特别是 Detect 头的 Bias""" for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): # 使用 Kaiming 初始化或其他合适的初始化 pass elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) detect_layer = self.model[-1] if isinstance(detect_layer, Detect): detect_layer.bias_init() def set_classes(self, names: List[str], embeddings: torch.Tensor): assert embeddings.ndim == 3, "Embeddings must be (1, N, D)" self.pe = embeddings self.model[-1].nc = len(names) # type: ignore self.nc = len(names) def forward(self, x, tpe=None, vpe=None): # Backbone x = self.model[0](x) x = self.model[1](x) x = self.model[2](x) x = self.model[3](x) p3 = self.model[4](x) # 保存 P3 (layer 4) x = self.model[5](p3) p4 = self.model[6](x) # 保存 P4 (layer 6) x = self.model[7](p4) x = self.model[8](x) x = self.model[9](x) p5 = self.model[10](x) # 保存 P5 (layer 10) # Head x = self.model[11](p5) # Upsample x = self.model[12]([x, p4]) # Concat P4 h1 = self.model[13](x) # Head P4 (layer 13) x = self.model[14](h1) # Upsample x = self.model[15]([x, p3]) # Concat P3 h2 = self.model[16](x) # Output P3 (layer 16) x = self.model[17](h2) # Conv x = self.model[18]([x, h1]) # Concat Head P4 h3 = self.model[19](x) # Output P4 (layer 19) x = self.model[20](h3) # Conv x = self.model[21]([x, p5]) # Concat P5 h4 = self.model[22](x) # Output P5 (layer 22) head = self.model[23] feats = [h2, h3, h4] processed_tpe = head.get_tpe(tpe) # type: ignore processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None # type: ignore all_pe = [] if processed_tpe is not None: all_pe.append(processed_tpe) if processed_vpe is not None: all_pe.append(processed_vpe) if not all_pe: if self.pe is not None: all_pe.append(self.pe.to(device=x.device, dtype=x.dtype)) else: all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype)) cls_pe = torch.cat(all_pe, dim=1) b = x.shape[0] if cls_pe.shape[0] != b: cls_pe = cls_pe.expand(b, -1, -1) return head(feats, cls_pe) def load_weights(self, pth_file): state_dict = torch.load(pth_file, map_location='cpu', weights_only=False) # 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方) # 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv... # 这里做一个简单的兼容性处理 new_state_dict = {} for k, v in state_dict.items(): # 处理 ultralytics 权重字典中的 'model' 键 if k == 'model': # 如果是完整的 checkpoint,权重在 'model' 键下 # 且通常是 model.state_dict() if hasattr(v, 'state_dict'): v = v.state_dict() elif isinstance(v, dict): pass # v 就是 state_dict else: # 可能是 model 对象本身 try: v = v.float().state_dict() except: continue for sub_k, sub_v in v.items(): new_state_dict[sub_k] = sub_v break else: new_state_dict[k] = v if not new_state_dict: new_state_dict = state_dict # 尝试加载 try: self.load_state_dict(new_state_dict, strict=True) print(f"Successfully loaded weights from {pth_file}") except Exception as e: print(f"Error loading weights: {e}") print("Trying to load with strict=False...") self.load_state_dict(new_state_dict, strict=False) if __name__ == "__main__": model = YOLO11E(nc=80, scale='l') model.load_weights("yoloe-11l-seg.pth") # 模拟 set_classes # 假设我们有2个类,embedding维度是512 fake_embeddings = torch.randn(1, 2, 512) model.set_classes(["class1", "class2"], fake_embeddings) # 推理 dummy_input = torch.randn(1, 3, 640, 640) model.eval() output = model(dummy_input) print("Output shape:", output[0].shape) # 应该是 (1, 4+mask_coeffs+num_classes, anchors)