Files
Yolo-standalone/yolo11_standalone.py
2025-12-27 02:14:11 +08:00

923 lines
34 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_divisible(x, divisor):
    """Round ``x`` up to the nearest multiple of ``divisor``.

    A ``torch.Tensor`` ``x`` is returned untouched; any other number is
    rounded up with ``math.ceil`` (the result is an ``int`` multiple of
    ``divisor`` when ``divisor`` is an int).
    """
    if not isinstance(x, torch.Tensor):
        multiples = math.ceil(x / divisor)
        return multiples * divisor
    return x
def autopad(k, p=None, d=1):
    """Return the padding that keeps a stride-1 convolution's output size.

    Args:
        k: kernel size, an int or a sequence of ints.
        p: explicit padding; returned unchanged when not ``None``.
        d: dilation; when > 1 the kernel is first expanded to its effective
            size ``d * (k - 1) + 1``.

    Returns:
        An int for an int kernel, otherwise a list with one value per dimension.
    """
    scalar_kernel = isinstance(k, int)
    if d > 1:
        if scalar_kernel:
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is not None:
        return p
    return k // 2 if scalar_kernel else [size // 2 for size in k]
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Build anchor centres and per-anchor strides for a list of feature maps.

    Args:
        feats: sequence of 4-D feature maps; only their last two dimensions
            (height, width) are used, plus dtype/device of ``feats[0]``.
        strides: one stride per level, iterated in order (``feats[i]`` is
            paired with ``strides[i]``).
        grid_cell_offset: offset added to the integer grid coordinates.

    Returns:
        ``(anchor_points, stride_tensor)`` with shapes ``(sum(H*W), 2)``
        holding ``(x, y)`` pairs in row-major order, and ``(sum(H*W), 1)``.
    """
    assert feats is not None
    reference = feats[0]
    opts = dict(dtype=reference.dtype, device=reference.device)
    points, level_strides = [], []
    for level, stride in enumerate(strides):
        _, _, height, width = feats[level].shape
        xs = torch.arange(end=width, **opts) + grid_cell_offset
        ys = torch.arange(end=height, **opts) + grid_cell_offset
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        points.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        level_strides.append(torch.full((height * width, 1), stride, **opts))
    return torch.cat(points), torch.cat(level_strides)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Turn (left, top, right, bottom) distances around anchor points into boxes.

    Args:
        distance: distances whose ``dim`` axis holds (l, t, r, b).
        anchor_points: anchor ``(x, y)`` coordinates, broadcastable against
            each half of ``distance``.
        xywh: return ``(cx, cy, w, h)`` when True, else ``(x1, y1, x2, y2)``.
        dim: axis along which coordinates are split and concatenated.
    """
    near, far = distance.chunk(2, dim)
    top_left = anchor_points - near
    bottom_right = anchor_points + far
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)
    centre = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((centre, size), dim)
class Concat(nn.Module):
    """Module wrapper around ``torch.cat`` with a fixed concatenation axis."""

    def __init__(self, dimension=1):
        super().__init__()
        # Axis used by torch.cat in forward.
        self.d = dimension

    def forward(self, x: List[torch.Tensor]):
        """Concatenate the tensors in ``x`` along ``self.d``."""
        axis = self.d
        return torch.cat(x, dim=axis)
class Conv(nn.Module):
    """``Conv2d`` (no bias) -> ``BatchNorm2d`` -> activation.

    ``act`` chooses the activation: ``True`` uses the class-level
    ``default_act`` (one module instance shared by every such Conv), an
    ``nn.Module`` is used as given, and any other value gives ``nn.Identity``.
    Padding defaults to ``autopad(k, p, d)``.
    """

    default_act = nn.SiLU(inplace=True)  # shared instance used when act=True

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)  # type: ignore
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Convolution, batch norm, activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def forward_fuse(self, x):
        """Convolution and activation only (skips ``bn``)."""
        out = self.conv(x)
        return self.act(out)
class DWConv(Conv):
    """Grouped ``Conv`` with ``gcd(c1, c2)`` groups (depth-wise when ``c1 == c2``)."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        groups = math.gcd(c1, c2)
        super().__init__(c1, c2, k, s, g=groups, d=d, act=act)
class DFL(nn.Module):
    """Integral (expected-value) decoding of per-side box distributions.

    Input channels are read as 4 groups of ``c1`` bins; each group is
    softmax-normalised and reduced to ``sum_i i * p_i`` by a frozen 1x1
    convolution whose weights are ``0, 1, ..., c1 - 1``.
    """

    def __init__(self, c1: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)
        self.conv.weight.data[:] = nn.Parameter(bins)
        self.c1 = c1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, 4 * c1, A)`` logits to ``(B, 4, A)`` expected distances."""
        batch, _, anchors = x.shape
        grouped = x.view(batch, 4, self.c1, anchors).transpose(2, 1)
        probs = grouped.softmax(1)
        return self.conv(probs).view(batch, 4, anchors)
class Bottleneck(nn.Module):
    """Two stacked ``Conv`` layers with an optional residual connection.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: add the input to the output (only honoured when ``c1 == c2``).
        g: groups of the second convolution.
        k: kernel sizes of the first and second convolution.
        e: hidden-channel ratio (hidden = ``int(c2 * e)``).
    """

    def __init__(
        self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5
    ):
        super().__init__()
        hidden = int(c2 * e)
        first_k, second_k = k[0], k[1]
        self.cv1 = Conv(c1, hidden, first_k, 1)
        self.cv2 = Conv(hidden, c2, second_k, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.cv2(self.cv1(x))
        if not self.add:
            return branch
        return x + branch
class C2f(nn.Module):
    """CSP-style block with two convolutions and ``n`` chained inner blocks.

    ``cv1`` produces ``2 * c`` channels which are split into two halves; each
    inner block consumes the most recently produced tensor and its output is
    appended, and ``cv2`` fuses all ``2 + n`` tensors.
    """

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
        super().__init__()
        self.c = int(c2 * e)  # channels per chunk
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        blocks = [Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)]  # type: ignore
        self.m = nn.ModuleList(blocks)

    def forward(self, x):
        """Forward pass that splits the ``cv1`` output with ``chunk``."""
        halves = self.cv1(x).chunk(2, 1)
        outputs = [halves[0], halves[1]]
        for block in self.m:
            outputs.append(block(outputs[-1]))
        return self.cv2(torch.cat(outputs, 1))

    def forward_split(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass that splits the ``cv1`` output with ``split``."""
        outputs = list(self.cv1(x).split((self.c, self.c), 1))
        for block in self.m:
            outputs.append(block(outputs[-1]))
        return self.cv2(torch.cat(outputs, 1))
class C3(nn.Module):
    """CSP bottleneck with three convolutions.

    One branch is ``cv1`` followed by ``n`` bottlenecks, the other a plain
    ``cv2`` projection; ``cv3`` fuses their concatenation.
    """

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        self.m = nn.Sequential(
            *[Bottleneck(hidden, hidden, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)]  # type: ignore
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        deep = self.m(self.cv1(x))
        shallow = self.cv2(x)
        return self.cv3(torch.cat((deep, shallow), 1))
class C3k(C3):
    """``C3`` whose bottlenecks use a square ``k x k`` kernel for both convolutions."""

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        # Replace the bottleneck stack built by C3 with (k, k)-kernel bottlenecks.
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, k=(k, k), e=1.0) for _ in range(n)])
class C3k2(C2f):
    """``C2f`` whose inner blocks are ``C3k`` (when ``c3k``) or plain ``Bottleneck``s."""

    def __init__(
        self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
    ):
        super().__init__(c1, c2, n, shortcut, g, e)

        def build_inner():
            if c3k:
                return C3k(self.c, self.c, 2, shortcut, g)
            return Bottleneck(self.c, self.c, shortcut, g)

        self.m = nn.ModuleList(build_inner() for _ in range(n))
class SPPF(nn.Module):
    """Spatial pyramid pooling (fast variant).

    The reduced input and the results of applying the same stride-1 max-pool
    one, two and three times in sequence are concatenated and fused by ``cv2``.
    """

    def __init__(self, c1: int, c2: int, k: int = 5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reduced = self.cv1(x)
        pool1 = self.m(reduced)
        pool2 = self.m(pool1)
        pool3 = self.m(pool2)
        return self.cv2(torch.cat([reduced, pool1, pool2, pool3], 1))
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Queries and keys use ``key_dim = int(head_dim * attn_ratio)`` channels per
    head and values use ``head_dim``. A depth-wise 3x3 conv of the values
    (``pe``) is added to the attention output before the ``proj`` projection.
    """

    def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        qk_channels = self.key_dim * num_heads
        self.qkv = Conv(dim, dim + 2 * qk_channels, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        tokens = H * W
        per_head = 2 * self.key_dim + self.head_dim
        q, k, v = self.qkv(x).view(B, self.num_heads, per_head, tokens).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )
        weights = ((q.transpose(-2, -1) @ k) * self.scale).softmax(dim=-1)
        mixed = (v @ weights.transpose(-2, -1)).view(B, C, H, W)
        mixed = mixed + self.pe(v.reshape(B, C, H, W))
        return self.proj(mixed)
class PSABlock(nn.Module):
    """``Attention`` followed by a 2x-expansion convolutional feed-forward network.

    With ``shortcut`` each of the two sub-layers is wrapped in a residual add.
    """

    def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
        super().__init__()
        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
        self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
        self.add = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.add:
            return self.ffn(self.attn(x))
        x = x + self.attn(x)
        return x + self.ffn(x)
class C2PSA(nn.Module):
    """CSP block that applies ``n`` ``PSABlock``s to half of the channels.

    ``cv1`` expands to ``2 * c`` channels; the first half passes through, the
    second half goes through the attention blocks, and ``cv2`` fuses both.
    Input and output channel counts must be equal.
    """

    def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)
        heads = self.c // 64  # one attention head per 64 channels
        self.m = nn.Sequential(*[PSABlock(self.c, attn_ratio=0.5, num_heads=heads) for _ in range(n)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        passthrough, attended = self.cv1(x).split((self.c, self.c), dim=1)
        return self.cv2(torch.cat((passthrough, self.m(attended)), 1))
class Proto(nn.Module):
    """Mask-prototype branch: 3x3 conv, 2x transposed-conv upsample, 3x3 conv, 1x1 conv to ``c2`` maps."""

    def __init__(self, c1: int, c_: int = 256, c2: int = 32):
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.cv1(x)
        out = self.upsample(out)
        out = self.cv2(out)
        return self.cv3(out)
class BNContrastiveHead(nn.Module):
    """Similarity head scoring batch-normalised features against embeddings.

    ``forward(x, w)`` takes features ``x`` of shape ``(B, C, H, W)`` and
    embeddings ``w`` of shape ``(B, K, C)``. It returns ``(B, K, H, W)``
    dot products between ``norm(x)`` and L2-normalised ``w``, multiplied by
    ``exp(logit_scale)`` and shifted by ``bias``.
    """

    def __init__(self, embed_dims: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(embed_dims)
        self.bias = nn.Parameter(torch.tensor([-10.0]))
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def fuse(self):
        """Delete the head's parameters and make ``forward`` return ``x`` unchanged."""
        for name in ("norm", "bias", "logit_scale"):
            delattr(self, name)
        self.forward = self.forward_fuse

    def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Identity on ``x``; ``w`` is ignored."""
        return x

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        unit_w = F.normalize(w, dim=-1, p=2)
        similarity = torch.einsum("bchw,bkc->bkhw", normed, unit_w)
        return similarity * self.logit_scale.exp() + self.bias
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer mapping ``gc`` input features to ``ec`` outputs.

    ``w12`` projects to ``e * ec`` features, split in half into a gate (SiLU)
    and a value; their product (``e * ec // 2`` features) is projected by ``w3``.
    """

    def __init__(self, gc: int, ec: int, e: int = 4) -> None:
        super().__init__()
        self.w12 = nn.Linear(gc, e * ec)
        self.w3 = nn.Linear(e * ec // 2, ec)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
class Residual(nn.Module):
    """Residual wrapper computing ``x + m(x)``.

    ``m`` must expose a ``w3`` layer with ``weight`` and ``bias``; both are
    zero-initialised here.
    """

    def __init__(self, m: nn.Module) -> None:
        super().__init__()
        self.m = m
        out_layer = self.m.w3
        nn.init.zeros_(out_layer.bias)
        # For models with l scale, please change the initialization to
        # nn.init.constant_(self.m.w3.weight, 1e-6)
        nn.init.zeros_(out_layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        update = self.m(x)
        return x + update
class SAVPE(nn.Module):
    """Visual-prompt encoder: pools fused image features under prompt masks.

    Args:
        ch: channel count of each input feature level. Levels 1 and 2 are
            upsampled by 2x and 4x respectively; other levels are not resized.
        c3: hidden channel count of the per-level branches.
        embed: output embedding dimension. Must be divisible by ``self.c``
            (16), the number of aggregation groups used in ``forward``.
    """

    def __init__(self, ch: List[int], c3: int, embed: int):
        super().__init__()
        self.cv1 = nn.ModuleList(
            nn.Sequential(
                Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
            )
            for i, x in enumerate(ch)
        )
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
            for i, x in enumerate(ch)
        )
        self.c = 16  # number of channel groups / score maps per prompt
        self.cv3 = nn.Conv2d(3 * c3, embed, 1)
        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))

    def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
        """Encode visual prompts into L2-normalised embeddings.

        Args:
            x: per-level feature maps matching ``ch``.
            vp: prompt masks of shape ``(B, Q, H, W)``, where ``H, W`` equal the
                spatial size of the fused features (after upsampling).

        Returns:
            ``(B, Q, embed)`` embeddings, L2-normalised along the last dim.
        """
        # Score branch: fuse all levels into `self.c` score channels.
        y = [self.cv2[i](xi) for i, xi in enumerate(x)]
        y = self.cv4(torch.cat(y, dim=1))
        # Value branch: fuse all levels into `embed` channels.
        x = [self.cv1[i](xi) for i, xi in enumerate(x)]
        x = self.cv3(torch.cat(x, dim=1))

        B, C, H, W = x.shape  # type: ignore
        Q = vp.shape[1]
        x = x.view(B, C, -1)  # type: ignore

        # Repeat the score features once per prompt and condition them on that prompt's mask.
        y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
        vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
        y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))

        y = y.reshape(B, Q, self.c, -1)
        vp = vp.reshape(B, Q, 1, -1)

        # Positions where the mask is zero get the most negative finite value so the
        # softmax over positions effectively ignores them.
        score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
        # Softmax in float32: in half precision it is less accurate and some CPU
        # builds lack a Half kernel. Float32 inputs give identical results.
        score = F.softmax(score, dim=-1, dtype=torch.float).to(y.dtype)

        # Weighted pooling per channel group: (B, c, Q, HW) @ (B, c, HW, C // c).
        aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
        return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
class Detect(nn.Module):
    """Detection head with separate box-distribution and class branches.

    For each feature level ``i`` the head concatenates ``cv2[i]`` (``4 * reg_max``
    box-distribution channels) and ``cv3[i]`` (``nc`` class logits).

    - Training mode: the list of per-level maps is returned.
    - Inference mode: boxes are decoded with ``DFL`` + ``dist2bbox`` as
      ``(cx, cy, w, h)`` scaled by the per-anchor stride. A ``(B, 4 + nc, A)``
      tensor of boxes followed by sigmoid class scores is returned.

    Args:
        nc: number of classes.
        ch: input channel count of each feature level.
    """

    dynamic = False  # when True, anchors are rebuilt on every inference call
    export = False  # not read by this class
    end2end = False  # when True, decode_bboxes returns (x1, y1, x2, y2) boxes
    xyxy = False  # when True, decode_bboxes returns (x1, y1, x2, y2) boxes
    shape = None  # shape of x[0] the cached anchors/strides were built for
    anchors = torch.empty(0)  # cached anchor centres, (2, A) after the first inference call
    strides = torch.empty(0)  # cached per-anchor strides, (1, A) after the first inference call

    def __init__(self, nc=80, ch=()):
        super().__init__()
        self.nc = nc
        self.nl = len(ch)  # number of feature levels
        self.reg_max = 16  # distribution bins per box side
        self.no = nc + self.reg_max * 4  # output channels per anchor
        self.stride = torch.tensor([8., 16., 32.])  # fixed strides of the three levels
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Run the head on a list of per-level feature maps.

        Note: the entries of ``x`` are replaced in place by the per-level outputs.
        """
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x

        # Inference path
        shape = x[0].shape
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            # Anchors only depend on the feature-map sizes; cache them keyed on x[0]'s shape.
            self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
            self.shape = shape
        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialise the final-layer biases of both branches.

        Box biases are set to 1.0. For each level stride ``s``, the first ``nc``
        class biases are set to ``log(5 / nc / (640 / s) ** 2)``. The constants
        presumably encode a prior of ~5 objects per 640-pixel image.
        """
        m = self
        for a, b, s in zip(m.cv2, m.cv3, m.stride):
            a[-1].bias.data[:] = 1.0  # type: ignore
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # type: ignore

    def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
        """Decode distance predictions (along dim 1) into boxes.

        Returns ``(cx, cy, w, h)`` when ``xywh`` is True and neither ``end2end``
        nor ``xyxy`` is set; otherwise ``(x1, y1, x2, y2)``.
        """
        return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
        """Keep the top-scoring (anchor, class) pairs of each image.

        Args:
            preds: ``(B, A, 4 + nc)`` tensor of 4 box values followed by class scores.
            max_det: maximum number of detections kept per image.
            nc: number of class-score columns.

        Returns:
            ``(B, min(max_det, A), 6)`` tensor of ``[box(4), score, class_index]``.
        """
        batch_size, anchors, _ = preds.shape
        boxes, scores = preds.split([4, nc], dim=-1)
        # First keep the anchors whose best class score is highest...
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        # ...then rank every (anchor, class) pair among those anchors.
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size, device=preds.device)[..., None]  # same device as preds
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
class YOLOEDetect(Detect):
    """Prompt-driven variant of ``Detect``.

    The class branch ``cv3`` ends in ``embed`` channels. ``cv4`` (one
    ``BNContrastiveHead`` per level) scores those features against the prompt
    embeddings ``cls_pe``, so the number of class channels equals
    ``cls_pe.shape[1]``.

    Prompt helpers:
    - ``get_tpe`` refines embeddings through ``reprta``.
    - ``get_vpe`` encodes visual prompts through ``savpe``.

    Args:
        nc: initial number of classes.
        embed: prompt-embedding dimension.
        ch: input channel count of each feature level.
    """

    def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        assert c3 <= embed
        # Replace Detect's class branch so it ends in `embed` channels instead of `nc`.
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, embed, 1),
            )
            for x in ch
        )
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)
        self.reprta = Residual(SwiGLUFFN(embed, embed))
        self.savpe = SAVPE(ch, c3, embed)  # type: ignore
        self.embed = embed

    def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Refine ``tpe`` with ``reprta`` and L2-normalise it; ``None`` passes through."""
        return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)

    def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
        """Return ``(B, N, D)`` visual-prompt embeddings.

        ``vpe`` is either ``(B, N, H, W)`` prompt masks, encoded with ``savpe``
        against the features ``x``, or already ``(B, N, D)`` embeddings. When
        ``N == 0`` an empty ``(B, 0, embed)`` tensor matching ``x[0]`` is returned.
        """
        if vpe.shape[1] == 0:  # no visual prompt embeddings
            return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device, dtype=x[0].dtype)
        if vpe.ndim == 4:  # (B, N, H, W)
            vpe = self.savpe(x, vpe)
        assert vpe.ndim == 3  # (B, N, D)
        return vpe

    def forward(  # type: ignore
        self, x: List[torch.Tensor], cls_pe: torch.Tensor
    ) -> Union[torch.Tensor, Tuple]:
        """Run the head; ``x`` entries are replaced in place by per-level outputs.

        Args:
            x: per-level feature maps.
            cls_pe: ``(B, K, embed)`` class/prompt embeddings.

        Returns:
            The per-level maps in training mode; otherwise a ``(B, 4 + K, A)``
            tensor of decoded ``(cx, cy, w, h)`` boxes and sigmoid class scores.
        """
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
        if self.training:
            return x  # type: ignore

        # The class-channel count is fixed by the embeddings actually used, which
        # may differ from self.nc (e.g. combined text + visual prompts). Derive
        # it from the produced tensors so the box/class split is always correct.
        nc = x[0].shape[1] - self.reg_max * 4
        self.no = nc + self.reg_max * 4

        # Inference path
        shape = x[0].shape
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
            self.shape = shape
        box, cls = x_cat.split((self.reg_max * 4, nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialise final-layer biases.

        - Box biases are set to 1.0.
        - The class-embedding conv biases are set to 0.
        - Each contrastive head's bias is set to ``log(5 / nc / (640 / s) ** 2)``,
          where ``s`` is the level stride.
        """
        m = self
        for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:] = 0.0
            c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
class YOLOESegment(YOLOEDetect):
    """``YOLOEDetect`` head with instance-segmentation outputs.

    Adds:
    - a ``Proto`` module producing ``nm`` mask prototypes from the first
      feature level;
    - a per-level ``cv5`` branch producing ``nm`` mask coefficients per anchor.
    """

    def __init__(
        self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = ()
    ):
        super().__init__(nc, embed, ch)
        self.nm = nm  # number of mask prototypes / coefficients
        self.npr = npr  # hidden channels of the prototype branch
        self.proto = Proto(ch[0], self.npr, self.nm)
        c5 = max(ch[0] // 4, self.nm)
        self.cv5 = nn.ModuleList(
            nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch
        )

    def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:
        """Return detections with appended mask coefficients, plus mask prototypes.

        The mask branches run before the detection head because
        ``YOLOEDetect.forward`` overwrites the entries of ``x`` in place.
        """
        protos = self.proto(x[0])  # mask protos
        batch = protos.shape[0]
        coeff_levels = [self.cv5[level](x[level]).view(batch, self.nm, -1) for level in range(self.nl)]
        coeffs = torch.cat(coeff_levels, 2)  # mask coefficients
        detections = YOLOEDetect.forward(self, x, text)
        if not self.training:
            return (torch.cat([detections, coeffs], 1), protos)
        return detections, coeffs, protos
class YOLO11(nn.Module):
    """Standalone YOLO11 detection network (backbone, neck and ``Detect`` head).

    Layers are stored in ``self.model`` (an ``nn.ModuleList``) at the indices
    given by the numbered comments below. These mirror the layer list of
    ``yolo11.yaml``, so parameter names look like ``model.<idx>.<...>``.
    ``forward`` wires the skip connections explicitly.

    Args:
        nc: number of classes for the ``Detect`` head.
        scale: model size, one of ``'n'``, ``'s'``, ``'m'``, ``'l'``, ``'x'``.

    Raises:
        ValueError: if ``scale`` is not a known size.
    """

    def __init__(self, nc=80, scale='n'):
        super().__init__()
        self.nc = nc

        # Scales: [depth, width, max_channels]
        # (the `scales` entry of yolo11.yaml)
        scales = {
            'n': [0.50, 0.25, 1024],
            's': [0.50, 0.50, 1024],
            'm': [0.50, 1.00, 512],
            'l': [1.00, 1.00, 512],
            'x': [1.00, 1.50, 512],
        }
        if scale not in scales:
            raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
        depth, width, max_channels = scales[scale]
        # C3k2 layers whose YAML c3k flag is False use C3k blocks anyway for the m/l/x sizes.
        c3k_override = scale not in ['n', 's']

        def gw(channels):
            """Width scaling: cap at max_channels, scale by width, round up to a multiple of 8."""
            return make_divisible(min(channels, max_channels) * width, 8)

        def gd(n):
            """Depth scaling: scale a repeat count > 1 by depth (at least 1)."""
            return max(round(n * depth), 1) if n > 1 else n

        self.model = nn.ModuleList()

        # --- Backbone ---
        # 0: Conv [64, 3, 2]
        self.model.append(Conv(3, gw(64), 3, 2))
        # 1: Conv [128, 3, 2]
        self.model.append(Conv(gw(64), gw(128), 3, 2))
        # 2: C3k2 [256, False, 0.25] -> n=2
        self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=c3k_override, e=0.25))
        # 3: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        # 4: C3k2 [512, False, 0.25] -> n=2
        self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=c3k_override, e=0.25))
        # 5: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        # 6: C3k2 [512, True] -> n=2
        self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
        # 7: Conv [1024, 3, 2]
        self.model.append(Conv(gw(512), gw(1024), 3, 2))
        # 8: C3k2 [1024, True] -> n=2
        self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
        # 9: SPPF [1024, 5]
        self.model.append(SPPF(gw(1024), gw(1024), 5))
        # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
        self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))

        # --- Head ---
        # 11: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        # 12: Concat [-1, 6] (P4)
        self.model.append(Concat(dimension=1))
        # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
        self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        # 14: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        # 15: Concat [-1, 4] (P3)
        self.model.append(Concat(dimension=1))
        # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
        # Note: layer 4 and layer 13 both output gw(512) channels.
        self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=c3k_override))
        # 17: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        # 18: Concat [-1, 13] (Head P4)
        self.model.append(Concat(dimension=1))
        # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
        self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        # 20: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        # 21: Concat [-1, 10] (P5)
        self.model.append(Concat(dimension=1))
        # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
        self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
        # 23: Detect [nc] on layers 16, 19, 22
        self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))

        # --- Weight initialisation ---
        self.initialize_weights()

    def initialize_weights(self):
        """Reset normalisation layers and initialise the Detect head's biases.

        Conv/Linear layers keep PyTorch's default initialisation.
        """
        for m in self.modules():
            # Skip norm layers without affine parameters (weight/bias are None).
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        detect_layer = self.model[-1]
        if isinstance(detect_layer, Detect):
            detect_layer.bias_init()

    def forward(self, x):
        # Backbone
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        p3 = self.model[4](x)  # keep P3 (layer 4)
        x = self.model[5](p3)
        p4 = self.model[6](x)  # keep P4 (layer 6)
        x = self.model[7](p4)
        x = self.model[8](x)
        x = self.model[9](x)
        p5 = self.model[10](x)  # keep P5 (layer 10)

        # Head
        x = self.model[11](p5)  # Upsample
        x = self.model[12]([x, p4])  # Concat P4
        h1 = self.model[13](x)  # Head P4 (layer 13)

        x = self.model[14](h1)  # Upsample
        x = self.model[15]([x, p3])  # Concat P3
        h2 = self.model[16](x)  # Output P3 (layer 16)

        x = self.model[17](h2)  # Conv
        x = self.model[18]([x, h1])  # Concat Head P4
        h3 = self.model[19](x)  # Output P4 (layer 19)

        x = self.model[20](h3)  # Conv
        x = self.model[21]([x, p5])  # Concat P5
        h4 = self.model[22](x)  # Output P5 (layer 22)

        return self.model[23]([h2, h3, h4])  # Detect

    @staticmethod
    def _extract_state_dict(ckpt):
        """Return the parameter mapping contained in a loaded checkpoint object.

        Accepts:
        - a plain state_dict;
        - a dict storing the network under ``'model'``, either as a module or
          as a state_dict (metadata keys such as ``'epoch'`` are ignored);
        - a pickled ``nn.Module``.
        """
        if isinstance(ckpt, nn.Module):
            return ckpt.state_dict()
        inner = ckpt.get('model') if isinstance(ckpt, dict) else None
        if isinstance(inner, dict):
            return inner
        if inner is not None and hasattr(inner, 'state_dict'):
            return inner.state_dict()
        return ckpt

    def load_weights(self, pth_file):
        """Load weights from ``pth_file`` into this model.

        A strict load is tried first. If it raises (e.g. missing or unexpected
        keys), the error is printed and the load is retried with
        ``strict=False``.

        NOTE(security): ``torch.load(..., weights_only=False)`` unpickles
        arbitrary objects; only load checkpoint files from trusted sources.
        """
        ckpt = torch.load(pth_file, map_location='cpu', weights_only=False)
        state_dict = self._extract_state_dict(ckpt)
        try:
            self.load_state_dict(state_dict, strict=True)
            print(f"Successfully loaded weights from {pth_file}")
        except RuntimeError as e:
            print(f"Error loading weights: {e}")
            print("Trying to load with strict=False...")
            result = self.load_state_dict(state_dict, strict=False)
            print(f"Missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")
class YOLO11E(nn.Module):
    """YOLO11 backbone/neck with a prompt-driven ``YOLOESegment`` head.

    Layer indices in ``self.model`` follow the numbered comments below, which
    mirror ``yolo11.yaml``. Class embeddings come from one of three sources:
    - ``tpe`` / ``vpe`` passed to ``forward``;
    - embeddings fixed with ``set_classes``;
    - all-zero embeddings for ``nc`` classes when neither is available.

    Args:
        nc: initial number of classes.
        scale: model size, one of ``'n'``, ``'s'``, ``'m'``, ``'l'``, ``'x'``.

    Raises:
        ValueError: if ``scale`` is not a known size.
    """

    def __init__(self, nc=80, scale='n'):
        super().__init__()
        self.nc = nc
        self.pe = None  # class embeddings set by set_classes()

        # Scales: [depth, width, max_channels]
        # (the `scales` entry of yolo11.yaml)
        scales = {
            'n': [0.50, 0.25, 1024],
            's': [0.50, 0.50, 1024],
            'm': [0.50, 1.00, 512],
            'l': [1.00, 1.00, 512],
            'x': [1.00, 1.50, 512],
        }
        if scale not in scales:
            raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
        depth, width, max_channels = scales[scale]
        # C3k2 layers whose YAML c3k flag is False use C3k blocks anyway for the m/l/x sizes.
        c3k_override = scale not in ['n', 's']

        def gw(channels):
            """Width scaling: cap at max_channels, scale by width, round up to a multiple of 8."""
            return make_divisible(min(channels, max_channels) * width, 8)

        def gd(n):
            """Depth scaling: scale a repeat count > 1 by depth (at least 1)."""
            return max(round(n * depth), 1) if n > 1 else n

        self.model = nn.ModuleList()

        # --- Backbone ---
        # 0: Conv [64, 3, 2]
        self.model.append(Conv(3, gw(64), 3, 2))
        # 1: Conv [128, 3, 2]
        self.model.append(Conv(gw(64), gw(128), 3, 2))
        # 2: C3k2 [256, False, 0.25] -> n=2
        self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=c3k_override, e=0.25))
        # 3: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        # 4: C3k2 [512, False, 0.25] -> n=2
        self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=c3k_override, e=0.25))
        # 5: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        # 6: C3k2 [512, True] -> n=2
        self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
        # 7: Conv [1024, 3, 2]
        self.model.append(Conv(gw(512), gw(1024), 3, 2))
        # 8: C3k2 [1024, True] -> n=2
        self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
        # 9: SPPF [1024, 5]
        self.model.append(SPPF(gw(1024), gw(1024), 5))
        # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
        self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))

        # --- Head ---
        # 11: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        # 12: Concat [-1, 6] (P4)
        self.model.append(Concat(dimension=1))
        # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
        self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        # 14: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        # 15: Concat [-1, 4] (P3)
        self.model.append(Concat(dimension=1))
        # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
        # Note: layer 4 and layer 13 both output gw(512) channels.
        self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=c3k_override))
        # 17: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        # 18: Concat [-1, 13] (Head P4)
        self.model.append(Concat(dimension=1))
        # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
        self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        # 20: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        # 21: Concat [-1, 10] (P5)
        self.model.append(Concat(dimension=1))
        # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
        self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
        # 23: YOLOESegment head on layers 16, 19, 22
        self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)]))

        # --- Weight initialisation ---
        self.initialize_weights()

    def initialize_weights(self):
        """Reset normalisation layers and initialise the head's biases.

        Conv/Linear layers keep PyTorch's default initialisation.
        """
        for m in self.modules():
            # Skip norm layers without affine parameters (weight/bias are None).
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)) and m.weight is not None:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        detect_layer = self.model[-1]
        if isinstance(detect_layer, Detect):
            detect_layer.bias_init()

    def set_classes(self, names: List[str], embeddings: torch.Tensor):
        """Fix the class vocabulary used when ``forward`` receives no prompts.

        Args:
            names: class names; only their count is used.
            embeddings: ``(1, N, D)`` class embeddings with ``N == len(names)``.

        Raises:
            AssertionError: if ``embeddings`` is not 3-D or its class count
                differs from ``len(names)``.
        """
        assert embeddings.ndim == 3, "Embeddings must be (1, N, D)"
        assert embeddings.shape[1] == len(names), "Embeddings must have one entry per class name"
        self.pe = embeddings
        self.model[-1].nc = len(names)  # type: ignore
        self.nc = len(names)

    def forward(self, x, tpe=None, vpe=None):
        """Run the network with optional prompt embeddings.

        Args:
            x: input image batch.
            tpe: optional embeddings refined by ``head.get_tpe``.
            vpe: optional visual prompts for ``head.get_vpe``, either
                ``(B, N, H, W)`` masks or ``(B, N, D)`` embeddings.

        The class embeddings are the processed ``tpe`` and ``vpe``,
        concatenated along the class axis. Without either, the embeddings from
        ``set_classes`` are used, or zeros for ``self.nc`` classes.

        Returns:
            The ``YOLOESegment`` head output; in eval mode a tuple
            ``(detections, mask_protos)``.
        """
        # Backbone
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        p3 = self.model[4](x)  # keep P3 (layer 4)
        x = self.model[5](p3)
        p4 = self.model[6](x)  # keep P4 (layer 6)
        x = self.model[7](p4)
        x = self.model[8](x)
        x = self.model[9](x)
        p5 = self.model[10](x)  # keep P5 (layer 10)

        # Head
        x = self.model[11](p5)  # Upsample
        x = self.model[12]([x, p4])  # Concat P4
        h1 = self.model[13](x)  # Head P4 (layer 13)

        x = self.model[14](h1)  # Upsample
        x = self.model[15]([x, p3])  # Concat P3
        h2 = self.model[16](x)  # Output P3 (layer 16)

        x = self.model[17](h2)  # Conv
        x = self.model[18]([x, h1])  # Concat Head P4
        h3 = self.model[19](x)  # Output P4 (layer 19)

        x = self.model[20](h3)  # Conv
        x = self.model[21]([x, p5])  # Concat P5
        h4 = self.model[22](x)  # Output P5 (layer 22)

        head = self.model[23]
        feats = [h2, h3, h4]

        processed_tpe = head.get_tpe(tpe)  # type: ignore
        processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None  # type: ignore

        all_pe = []
        if processed_tpe is not None:
            all_pe.append(processed_tpe)
        if processed_vpe is not None:
            all_pe.append(processed_vpe)

        if not all_pe:
            if self.pe is not None:
                all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
            else:
                all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))

        cls_pe = torch.cat(all_pe, dim=1)
        b = x.shape[0]
        if cls_pe.shape[0] != b:
            cls_pe = cls_pe.expand(b, -1, -1)

        # Keep the head's class count in sync with the embeddings actually used
        # (e.g. text and visual prompts combined), so its output split is correct.
        head.nc = cls_pe.shape[1]
        return head(feats, cls_pe)

    @staticmethod
    def _extract_state_dict(ckpt):
        """Return the parameter mapping contained in a loaded checkpoint object.

        Accepts:
        - a plain state_dict;
        - a dict storing the network under ``'model'``, either as a module or
          as a state_dict (metadata keys such as ``'epoch'`` are ignored);
        - a pickled ``nn.Module``.
        """
        if isinstance(ckpt, nn.Module):
            return ckpt.state_dict()
        inner = ckpt.get('model') if isinstance(ckpt, dict) else None
        if isinstance(inner, dict):
            return inner
        if inner is not None and hasattr(inner, 'state_dict'):
            return inner.state_dict()
        return ckpt

    def load_weights(self, pth_file):
        """Load weights from ``pth_file`` into this model.

        A strict load is tried first. If it raises (e.g. missing or unexpected
        keys), the error is printed and the load is retried with
        ``strict=False``.

        NOTE(security): ``torch.load(..., weights_only=False)`` unpickles
        arbitrary objects; only load checkpoint files from trusted sources.
        """
        ckpt = torch.load(pth_file, map_location='cpu', weights_only=False)
        state_dict = self._extract_state_dict(ckpt)
        try:
            self.load_state_dict(state_dict, strict=True)
            print(f"Successfully loaded weights from {pth_file}")
        except RuntimeError as e:
            print(f"Error loading weights: {e}")
            print("Trying to load with strict=False...")
            result = self.load_state_dict(state_dict, strict=False)
            print(f"Missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")
if __name__ == "__main__":
    model = YOLO11E(nc=80, scale='l')
    model.load_weights("yoloe-11l-seg.pth")

    # Simulate set_classes: two classes with 512-dim embeddings
    # (random stand-ins for real prompt embeddings).
    fake_embeddings = torch.randn(1, 2, 512)
    model.set_classes(["class1", "class2"], fake_embeddings)

    # Inference (no autograd graph needed)
    dummy_input = torch.randn(1, 3, 640, 640)
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    # output is (detections, mask_protos). Detections are laid out as
    # (batch, 4 box + num_classes + mask_coeffs, anchors).
    print("Output shape:", output[0].shape)