"""YOLO11 / YOLO11E model definitions with PyTorch and NumPy post-processors."""
import math
|
||
from typing import List, Optional, Tuple, Union
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
# ==============================================================================
|
||
# [Part 1] Utils & Basic Modules
|
||
# ==============================================================================
|
||
|
||
def make_divisible(x, divisor):
    """Round *x* up to the nearest multiple of *divisor* (tensors pass through unchanged)."""
    if isinstance(x, torch.Tensor):
        # Rounding only applies to Python scalars; tensors are returned as-is.
        return x
    quotient = math.ceil(x / divisor)
    return quotient * divisor
|
||
|
||
def autopad(k, p=None, d=1):
    """Return 'same' padding for kernel *k*, accounting for dilation *d*.

    An explicit *p* is passed through untouched.
    """
    if d > 1:
        # Effective kernel size once dilation is applied.
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (n - 1) + 1 for n in k]
    if p is not None:
        return p
    # Half the (effective) kernel size, element-wise for sequences.
    return k // 2 if isinstance(k, int) else [n // 2 for n in k]
|
||
|
||
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchor points and per-anchor strides for the post-processing stage."""
    assert feats is not None
    points, stride_cols = [], []
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        xs = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
        ys = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        # Flatten row-major so anchors walk across each row before moving down.
        points.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        stride_cols.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(points), torch.cat(stride_cols)
|
||
|
||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Decode predicted (left, top, right, bottom) distances into bounding boxes."""
    lt, rb = distance.chunk(2, dim)
    top_left = anchor_points - lt
    bottom_right = anchor_points + rb
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((center, size), dim)
|
||
|
||
class Concat(nn.Module):
    """Concatenate a list of tensors along a configurable dimension."""

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension  # concatenation axis

    def forward(self, x: List[torch.Tensor]):
        """Concatenate all tensors in *x* along ``self.d``."""
        return torch.cat(x, dim=self.d)
|
||
|
||
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, the basic building block."""

    default_act = nn.SiLU(inplace=True)  # shared default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        # act=True -> default SiLU; an nn.Module -> used as-is; anything else -> identity.
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization, then the activation."""
        return self.act(self.bn(self.conv(x)))
|
||
|
||
class DWConv(Conv):
    """Depth-wise convolution: groups fixed to gcd(c1, c2)."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        groups = math.gcd(c1, c2)  # fully depth-wise when c1 == c2
        super().__init__(c1, c2, k, s, g=groups, d=d, act=act)
|
||
|
||
class DFL(nn.Module):
    """Distribution Focal Loss decoder: softmax over bins -> expected distance."""

    def __init__(self, c1: int = 16):
        super().__init__()
        # Fixed (non-trainable) 1x1 conv whose weights are the bin indices
        # 0..c1-1, so convolving a softmax distribution yields its expectation.
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = bins.view(1, c1, 1, 1)
        self.c1 = c1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode (b, 4*c1, a) bin logits into (b, 4, a) expected distances."""
        b, _, a = x.shape
        probs = x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)
        return self.conv(probs).view(b, 4, a)
|
||
|
||
class Bottleneck(nn.Module):
    """Two stacked convs with an optional residual connection."""

    def __init__(self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5):
        super().__init__()
        hidden = int(c2 * e)  # squeezed intermediate width
        self.cv1 = Conv(c1, hidden, k[0], 1)
        self.cv2 = Conv(hidden, c2, k[1], 1, g=g)
        # Residual add only when requested and the shapes actually match.
        self.add = shortcut and c1 == c2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both convs, adding the input back when ``self.add`` is set."""
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
|
||
|
||
class C2f(nn.Module):
    """CSP-style block: split in two, chain *n* bottlenecks, concat every stage."""

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden width per split
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)
        )

    def forward(self, x):
        """Split, feed each bottleneck the latest branch, then fuse all branches."""
        first, second = self.cv1(x).chunk(2, 1)
        branches = [first, second]
        for block in self.m:
            branches.append(block(branches[-1]))
        return self.cv2(torch.cat(branches, 1))
|
||
|
||
class C3(nn.Module):
    """CSP bottleneck with three convolutions."""

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one branch through the bottlenecks, then fuse with the shortcut branch."""
        main = self.m(self.cv1(x))
        side = self.cv2(x)
        return self.cv3(torch.cat((main, side), 1))
|
||
|
||
class C3k(C3):
    """C3 variant whose bottlenecks use a configurable square kernel *k*."""

    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        # Rebuild the bottleneck stack with (k, k) kernels instead of C3's defaults.
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, k=(k, k), e=1.0) for _ in range(n)])
|
||
|
||
class C3k2(C2f):
    """C2f variant that can swap each bottleneck for a full C3k block."""

    def __init__(self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True):
        super().__init__(c1, c2, n, shortcut, g, e)
        inner = []
        for _ in range(n):
            if c3k:
                inner.append(C3k(self.c, self.c, 2, shortcut, g))
            else:
                inner.append(Bottleneck(self.c, self.c, shortcut, g))
        self.m = nn.ModuleList(inner)
|
||
|
||
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast: three cascaded max-pools, then fuse."""

    def __init__(self, c1: int, c2: int, k: int = 5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the squeezed input three times in sequence and concat all four maps."""
        pooled = [self.cv1(x)]
        for _ in range(3):
            pooled.append(self.m(pooled[-1]))
        return self.cv2(torch.cat(pooled, 1))
|
||
|
||
|
||
# ==============================================================================
|
||
# [Part 2] Advanced Modules & Pure Heads
|
||
# ==============================================================================
|
||
|
||
class Attention(nn.Module):
    """Multi-head self-attention over a (B, C, H, W) map with a conv positional encoding."""

    def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Query/key width per head is a fraction of the value width.
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2  # fused qkv channels: q + k (nh_kd each) + v (dim)
        self.qkv = Conv(dim, h, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        # Depth-wise 3x3 conv acting as a positional encoding on the values.
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention; input and output are both (B, C, H, W)."""
        B, C, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        # Split the fused projection into per-head q, k, v along the channel axis.
        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )
        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        # Attention-weighted values plus a conv positional encoding of the raw values.
        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
        x = self.proj(x)
        return x
|
||
|
||
class PSABlock(nn.Module):
    """Attention + feed-forward block with optional residual connections."""

    def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
        super().__init__()
        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
        self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
        self.add = shortcut  # when True, both sub-layers are residual

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention then the FFN, with residual adds when enabled."""
        attended = self.attn(x)
        x = x + attended if self.add else attended
        expanded = self.ffn(x)
        return x + expanded if self.add else expanded
|
||
|
||
class C2PSA(nn.Module):
    """CSP-style wrapper around a stack of PSA attention blocks.

    Splits the input into a pass-through branch and an attention branch,
    then fuses both with a 1x1 conv. Requires c1 == c2.
    """

    def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)
        # max(..., 1) guards against zero heads (and the resulting division by
        # zero inside Attention) when the hidden width is below 64 channels.
        self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=max(self.c // 64, 1)) for _ in range(n)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split, attend over one half, and fuse both halves back together."""
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = self.m(b)
        return self.cv2(torch.cat((a, b), 1))
|
||
|
||
class Proto(nn.Module):
    """Mask prototype head: conv -> learned 2x upsample -> conv -> projection."""

    def __init__(self, c1: int, c_: int = 256, c2: int = 32):
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        # Learned 2x upsampling via transposed convolution.
        self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Produce c2 prototype masks at twice the input resolution."""
        up = self.upsample(self.cv1(x))
        return self.cv3(self.cv2(up))
|
||
|
||
class BNContrastiveHead(nn.Module):
    """Contrastive head scoring BN'd features against normalized prompt embeddings."""

    def __init__(self, embed_dims: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(embed_dims)
        # Large negative bias so untrained logits start near zero probability.
        self.bias = nn.Parameter(torch.tensor([-10.0]))
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def fuse(self):
        """Strip the head for export; forward becomes a pass-through."""
        del self.norm
        del self.bias
        del self.logit_scale
        self.forward = self.forward_fuse

    def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Fused (export) path: return the features unchanged."""
        return x

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Score (B, C, H, W) features against (B, K, C) embeddings -> (B, K, H, W)."""
        x = self.norm(x)
        w = F.normalize(w, dim=-1, p=2)
        similarity = torch.einsum("bchw,bkc->bkhw", x, w)
        return similarity * self.logit_scale.exp() + self.bias
|
||
|
||
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: fused gate/value projection, SiLU gating, out projection."""

    def __init__(self, gc: int, ec: int, e: int = 4) -> None:
        super().__init__()
        self.w12 = nn.Linear(gc, e * ec)       # emits gate and value halves together
        self.w3 = nn.Linear(e * ec // 2, ec)   # projects the gated half back to ec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply silu(gate) * value, then the output projection."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||
|
||
class Residual(nn.Module):
    """Residual wrapper whose branch starts as an exact identity.

    NOTE(review): assumes the wrapped module exposes a final ``nn.Linear``
    named ``w3`` (as SwiGLUFFN does); its weights are zeroed at init so the
    branch contributes nothing until trained.
    """

    def __init__(self, m: nn.Module) -> None:
        super().__init__()
        self.m = m
        nn.init.zeros_(self.m.w3.bias)
        nn.init.zeros_(self.m.w3.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x plus the wrapped branch's output."""
        return x + self.m(x)
|
||
|
||
class SAVPE(nn.Module):
    """Visual prompt encoder: pools multi-scale features under prompt masks.

    Produces one L2-normalized embedding per visual prompt via a
    mask-restricted, softmax-weighted group pooling.
    """

    def __init__(self, ch: List[int], c3: int, embed: int):
        super().__init__()
        # Per-scale embedding branches; levels 1 and 2 are upsampled (2x, 4x)
        # so all scales align spatially with level 0 before concatenation.
        self.cv1 = nn.ModuleList(
            nn.Sequential(
                Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
            )
            for i, x in enumerate(ch)
        )
        # Lightweight per-scale branches feeding the attention-score maps.
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
            for i, x in enumerate(ch)
        )
        self.c = 16  # number of score channels / pooling groups
        self.cv3 = nn.Conv2d(3 * c3, embed, 1)          # fused features -> embedding channels
        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)  # fused features -> score maps
        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)       # prompt mask -> score maps
        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))

    def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
        """Encode visual prompts into embeddings.

        Args:
            x: multi-scale feature maps, one per entry in ``ch``.
            vp: visual-prompt masks; reshaped as (B, Q, H, W) below, where Q is
                the number of prompts — assumed binary-valued; TODO confirm with caller.

        Returns:
            L2-normalized embeddings of shape (B, Q, embed).
        """
        y = [self.cv2[i](xi) for i, xi in enumerate(x)]
        y = self.cv4(torch.cat(y, dim=1))
        x = [self.cv1[i](xi) for i, xi in enumerate(x)]
        x = self.cv3(torch.cat(x, dim=1))
        B, C, H, W = x.shape
        Q = vp.shape[1]  # number of prompts
        x = x.view(B, C, -1)
        # Broadcast score maps across prompts, folding prompts into the batch dim.
        y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
        vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
        # Mix the score maps with an encoding of the prompt mask itself.
        y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
        y = y.reshape(B, Q, self.c, -1)
        vp = vp.reshape(B, Q, 1, -1)
        # Push non-prompt locations to dtype-min so the softmax ignores them.
        score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
        score = F.softmax(score, dim=-1).to(y.dtype)
        # Group-wise weighted pooling of the embedding features over space.
        aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
        return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
|
||
|
||
class Detect(nn.Module):
    """Anchor-free detection head emitting per-level box-distribution + class maps."""

    dynamic = False           # force grid reconstruction
    export = False            # export-mode flag
    shape = None              # cached input shape
    anchors = torch.empty(0)  # cached anchor points
    strides = torch.empty(0)  # cached per-anchor strides

    def __init__(self, nc=80, ch=()):
        super().__init__()
        self.nc = nc                     # number of classes
        self.nl = len(ch)                # number of detection levels
        self.reg_max = 16                # DFL bins per box side
        self.no = nc + self.reg_max * 4  # outputs per anchor
        self.stride = torch.tensor([8., 16., 32.])

        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))

        # Box-regression branch, one per level.
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        # Classification branch, depth-wise separable, one per level.
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Return raw per-level maps: cat(box distribution, class logits) on channels."""
        return [torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) for i in range(self.nl)]

    def bias_init(self):
        """Initialize detection biases (assumes stride order matches head order)."""
        for box_branch, cls_branch, s in zip(self.cv2, self.cv3, self.stride):
            box_branch[-1].bias.data[:] = 1.0
            # Class prior tuned for roughly 5 objects in a 640px image.
            cls_branch[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
|
||
|
||
class YOLOEDetect(Detect):
    """Open-vocabulary detect head: class scores come from embedding similarity."""

    def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        assert c3 <= embed
        # Replace the class branch with one that emits embedding maps instead.
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, embed, 1),
            )
            for x in ch
        )
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)
        self.reprta = Residual(SwiGLUFFN(embed, embed))  # text-prompt refiner
        self.savpe = SAVPE(ch, c3, embed)                # visual-prompt encoder
        self.embed = embed

    def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Refine and L2-normalize text-prompt embeddings (None passes through)."""
        if tpe is None:
            return None
        return F.normalize(self.reprta(tpe), dim=-1, p=2)

    def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
        """Turn visual prompts into embeddings; 4-D mask input goes through SAVPE."""
        if vpe.shape[1] == 0:
            # No prompts supplied: return an empty embedding batch.
            return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
        if vpe.ndim == 4:
            vpe = self.savpe(x, vpe)
        assert vpe.ndim == 3
        return vpe

    def forward(self, x: List[torch.Tensor], cls_pe: torch.Tensor) -> List[torch.Tensor]:
        """Per-level cat of box distributions and contrastive class logits."""
        return [
            torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
            for i in range(self.nl)
        ]

    def bias_init(self):
        """Initialize biases for the box, embedding, and contrastive branches."""
        for box_branch, emb_branch, head, s in zip(self.cv2, self.cv3, self.cv4, self.stride):
            box_branch[-1].bias.data[:] = 1.0
            emb_branch[-1].bias.data[:] = 0.0
            head.bias.data[:] = math.log(5 / self.nc / (640 / s) ** 2)
|
||
|
||
class YOLOESegment(YOLOEDetect):
    """YOLOEDetect plus instance segmentation (mask coefficients + prototypes)."""

    def __init__(
        self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = ()
    ):
        super().__init__(nc, embed, ch)
        self.nm = nm    # number of mask coefficients
        self.npr = npr  # prototype hidden width
        self.proto = Proto(ch[0], self.npr, self.nm)

        c5 = max(ch[0] // 4, self.nm)
        self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)

    def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
        """Return (detect outputs, mask coefficients, prototype masks)."""
        protos = self.proto(x[0])  # prototypes come from the highest-resolution level
        bs = protos.shape[0]
        coeffs = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
        det = YOLOEDetect.forward(self, x, text)
        return det, coeffs, protos
|
||
|
||
|
||
# ==============================================================================
|
||
# [Part 3] PostProcessor & Top Level Models
|
||
# ==============================================================================
|
||
|
||
class YOLOPostProcessor(nn.Module):
    """Decode raw Detect/Segment head outputs into (boxes, class scores)."""

    def __init__(self, detect_head, use_segmentation=False):
        super().__init__()

        self.reg_max = detect_head.reg_max
        self.stride = detect_head.stride

        # Reuse the head's DFL decoder when it has one; otherwise pass through.
        self.dfl = detect_head.dfl if hasattr(detect_head, 'dfl') else nn.Identity()

        self.use_segmentation = use_segmentation
        self.register_buffer('anchors', torch.empty(0))
        self.register_buffer('strides', torch.empty(0))
        self.shape = None  # cached feature shape, used to invalidate anchors

    def forward(self, outputs):
        """
        outputs:
            - Detect: List[Tensor]
            - Segment: (List[Tensor], Tensor, Tensor)
        """
        if self.use_segmentation:
            x, mc, p = outputs
        else:
            x = outputs

        no = x[0].shape[1]
        nc = no - self.reg_max * 4  # class count inferred from channel width
        shape = x[0].shape
        x_cat = torch.cat([level.view(shape[0], no, -1) for level in x], 2)

        # Rebuild cached anchors when the device or the input shape changes.
        if self.anchors.device != x[0].device or self.shape != shape:
            anchor_pts, stride_col = make_anchors(x, self.stride, 0.5)
            self.anchors = anchor_pts.transpose(0, 1)
            self.strides = stride_col.transpose(0, 1)
            self.shape = shape

        box, cls = x_cat.split((self.reg_max * 4, nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        final_box = torch.cat((dbox, cls.sigmoid()), 1)

        if self.use_segmentation:
            return final_box, mc, p

        return final_box
|
||
|
||
class YOLO11(nn.Module):
    """YOLO11 detection model: scaled backbone + PAN-FPN neck + Detect head."""

    def __init__(self, nc=80, scale='n'):
        """Build the model.

        Args:
            nc: number of detection classes.
            scale: one of 'n', 's', 'm', 'l', 'x' selecting depth/width presets.

        Raises:
            ValueError: if *scale* is not a known preset.
        """
        super().__init__()
        self.nc = nc

        # Scales: [depth, width, max_channels]
        scales = {'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512]}
        if scale not in scales:
            # Include the offending value so the error is actionable.
            raise ValueError(f"Invalid scale {scale!r}; expected one of {sorted(scales)}")
        depth, width, max_channels = scales[scale]
        c3k_override = scale not in ('n', 's')  # medium and larger always use C3k blocks

        def gw(channels):
            # Width-scale a channel count, capped at max_channels, rounded to /8.
            return make_divisible(min(channels, max_channels) * width, 8)

        def gd(n):
            # Depth-scale a repeat count; single repeats are never scaled.
            return max(round(n * depth), 1) if n > 1 else n

        self.model = nn.ModuleList()
        # Backbone (indices 0-10)
        self.model.append(Conv(3, gw(64), 3, 2))
        self.model.append(Conv(gw(64), gw(128), 3, 2))
        self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=c3k_override, e=0.25))
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=c3k_override, e=0.25))
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
        self.model.append(Conv(gw(512), gw(1024), 3, 2))
        self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
        self.model.append(SPPF(gw(1024), gw(1024), 5))
        self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
        # Neck (indices 11-22)
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        self.model.append(Concat(dimension=1))
        self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
        self.model.append(Concat(dimension=1))
        self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=c3k_override))
        self.model.append(Conv(gw(256), gw(256), 3, 2))
        self.model.append(Concat(dimension=1))
        self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=c3k_override))
        self.model.append(Conv(gw(512), gw(512), 3, 2))
        self.model.append(Concat(dimension=1))
        self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))

        # 23: Standard Detect Head
        self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))
        self.initialize_weights()

    def initialize_weights(self):
        """Reset norm-layer affine parameters and the detection-head biases."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if isinstance(self.model[-1], Detect):
            self.model[-1].bias_init()

    def forward(self, x):
        """Run backbone + neck and return the raw per-level head outputs."""
        # Backbone
        x = self.model[0](x); x = self.model[1](x); x = self.model[2](x); x = self.model[3](x)
        p3 = self.model[4](x)
        x = self.model[5](p3); p4 = self.model[6](x)
        x = self.model[7](p4); x = self.model[8](x); x = self.model[9](x); p5 = self.model[10](x)
        # Top-down path
        x = self.model[11](p5); x = self.model[12]([x, p4]); h1 = self.model[13](x)
        x = self.model[14](h1); x = self.model[15]([x, p3]); h2 = self.model[16](x)
        # Bottom-up path
        x = self.model[17](h2); x = self.model[18]([x, h1]); h3 = self.model[19](x)
        x = self.model[20](h3); x = self.model[21]([x, p5]); h4 = self.model[22](x)
        return self.model[23]([h2, h3, h4])

    def load_weights(self, pth_file):
        """Load a checkpoint, unwrapping Ultralytics-style {'model': module} files.

        SECURITY: ``weights_only=False`` unpickles arbitrary objects — only
        load checkpoints from trusted sources.
        """
        state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
        new_state_dict = {}
        for k, v in state_dict.items():
            if k == 'model':
                if hasattr(v, 'state_dict'):
                    v = v.state_dict()
                elif isinstance(v, dict):
                    pass
                else:
                    try:
                        v = v.float().state_dict()
                    except Exception:
                        # Unrecognized wrapper object; skip the 'model' entry.
                        continue
                for sub_k, sub_v in v.items():
                    new_state_dict[sub_k] = sub_v
                break
            else:
                new_state_dict[k] = v
        if not new_state_dict:
            new_state_dict = state_dict
        try:
            self.load_state_dict(new_state_dict, strict=True)
            print(f"Successfully loaded weights from {pth_file}")
        except Exception as e:
            print(f"Error loading weights: {e}")
            # Fall back to a best-effort partial load.
            self.load_state_dict(new_state_dict, strict=False)
|
||
|
||
class YOLO11E(YOLO11):
    """YOLO11 variant with a prompt-driven YOLOE segmentation head."""

    def __init__(self, nc=80, scale='n'):
        super().__init__(nc, scale)
        # Recompute the width helper to size the replacement head.
        scales = {'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512]}
        depth, width, max_channels = scales[scale]

        def gw(channels):
            return make_divisible(min(channels, max_channels) * width, 8)

        self.nc = nc
        self.pe = None  # cached class prompt embeddings, set via set_classes()
        self.model[-1] = YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)])
        self.initialize_weights()

    def set_classes(self, names: List[str], embeddings: torch.Tensor):
        """Fix the class set to *names* backed by precomputed (1, K, embed) embeddings."""
        assert embeddings.ndim == 3
        self.pe = embeddings
        self.model[-1].nc = len(names)
        self.nc = len(names)

    def forward(self, x, tpe=None, vpe=None):
        """Run the network with optional text (*tpe*) / visual (*vpe*) prompts."""
        # Backbone
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        p3 = self.model[4](x)
        x = self.model[5](p3)
        p4 = self.model[6](x)
        x = self.model[7](p4)
        x = self.model[8](x)
        x = self.model[9](x)
        p5 = self.model[10](x)
        # Neck
        x = self.model[11](p5)
        x = self.model[12]([x, p4])
        h1 = self.model[13](x)
        x = self.model[14](h1)
        x = self.model[15]([x, p3])
        h2 = self.model[16](x)
        x = self.model[17](h2)
        x = self.model[18]([x, h1])
        h3 = self.model[19](x)
        x = self.model[20](h3)
        x = self.model[21]([x, p5])
        h4 = self.model[22](x)

        head = self.model[23]
        feats = [h2, h3, h4]

        # Gather whichever prompt embeddings are available.
        processed_tpe = head.get_tpe(tpe)
        processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None

        all_pe = []
        if processed_tpe is not None:
            all_pe.append(processed_tpe)
        if processed_vpe is not None:
            all_pe.append(processed_vpe)
        if not all_pe:
            # Fall back to cached class embeddings, then to zeros.
            if self.pe is not None:
                all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
            else:
                all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))

        cls_pe = torch.cat(all_pe, dim=1)
        batch = x.shape[0]
        if cls_pe.shape[0] != batch:
            cls_pe = cls_pe.expand(batch, -1, -1)

        return head(feats, cls_pe)
|
||
|
||
|
||
# ==============================================================================
|
||
# [Part 4] PostProcessorNumpy
|
||
# ==============================================================================
|
||
|
||
import numpy as np
|
||
|
||
class YOLOPostProcessorNumpy:
    """NumPy port of YOLOPostProcessor for decoding exported-model outputs.

    NOTE(review): unlike the torch post-processor (which returns (B, 4+nc, A)),
    this class returns boxes in (B, A, 4+nc) layout — confirm downstream code.
    """

    def __init__(self, strides=(8, 16, 32), reg_max=16, use_segmentation=False):
        """
        Args:
            strides: per-level downsampling strides. An immutable tuple default
                replaces the previous mutable-list default (shared-state pitfall).
            reg_max: number of DFL bins per box side.
            use_segmentation: expect (feats, mask coeffs, protos) tuples.
        """
        self.strides = np.array(strides, dtype=np.float32)
        self.reg_max = reg_max
        self.use_segmentation = use_segmentation
        self.anchors = None        # cached anchor points
        self.strides_array = None  # cached per-anchor strides
        self.shape = None          # cached input feature shape
        # Bin indices used to take the expectation of the DFL distribution.
        self.dfl_weights = np.arange(reg_max, dtype=np.float32).reshape(1, 1, reg_max, 1)

    def sigmoid(self, x):
        """Element-wise logistic function."""
        return 1 / (1 + np.exp(-x))

    def softmax(self, x, axis=-1):
        """Numerically-stable softmax along *axis*."""
        x_max = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - x_max)
        return e_x / np.sum(e_x, axis=axis, keepdims=True)

    def make_anchors(self, feats, strides, grid_cell_offset=0.5):
        """Build (A, 2) anchor centers and (A, 1) per-anchor strides for all levels."""
        anchor_points, stride_list = [], []
        for i, stride in enumerate(strides):
            _, _, h, w = feats[i].shape
            sx = np.arange(w, dtype=np.float32) + grid_cell_offset
            sy = np.arange(h, dtype=np.float32) + grid_cell_offset
            sy, sx = np.meshgrid(sy, sx, indexing='ij')

            anchor_points.append(np.stack((sx, sy), -1).reshape(-1, 2))
            stride_list.append(np.full((h * w, 1), stride, dtype=np.float32))

        return np.concatenate(anchor_points), np.concatenate(stride_list)

    def dist2bbox(self, distance, anchor_points, xywh=True, dim=-1):
        """Convert (l, t, r, b) distances around anchors into boxes."""
        lt, rb = np.split(distance, 2, axis=dim)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        if xywh:
            c_xy = (x1y1 + x2y2) / 2
            wh = x2y2 - x1y1
            return np.concatenate((c_xy, wh), axis=dim)
        return np.concatenate((x1y1, x2y2), axis=dim)

    def dfl_decode(self, x):
        """Decode (B, 4*reg_max, A) bin logits to (B, 4, A) expected distances."""
        B, C, A = x.shape
        x = x.reshape(B, 4, self.reg_max, A)
        x = self.softmax(x, axis=2)
        return np.sum(x * self.dfl_weights, axis=2)

    def __call__(self, outputs):
        """Decode raw head outputs into (B, A, 4 + nc) boxes + class scores."""
        if self.use_segmentation:
            x, mc, p = outputs
        else:
            x = outputs

        current_no = x[0].shape[1]
        current_nc = current_no - self.reg_max * 4
        shape = x[0].shape

        x_cat = np.concatenate([xi.reshape(shape[0], current_no, -1) for xi in x], axis=2)

        # Rebuild cached anchors only when the input shape changes.
        if self.anchors is None or self.shape != shape:
            self.anchors, self.strides_array = self.make_anchors(x, self.strides, 0.5)
            self.shape = shape

        box, cls = np.split(x_cat, [self.reg_max * 4], axis=1)
        dist = self.dfl_decode(box)
        dist = dist.transpose(0, 2, 1)
        dbox = self.dist2bbox(dist, self.anchors, xywh=True, dim=2) * self.strides_array
        cls = cls.transpose(0, 2, 1)
        sigmoid_cls = self.sigmoid(cls)
        final_box = np.concatenate((dbox, sigmoid_cls), axis=2)

        if self.use_segmentation:
            return final_box, mc, p

        return final_box
|
||
|
||
|
||
if __name__ == "__main__":
    # --- Plain detection model: raw outputs decoded by the post-processor ---
    print("Testing Standard YOLO11...")
    det_model = YOLO11(nc=80, scale='n')
    det_model.eval()
    det_post = YOLOPostProcessor(det_model.model[-1], use_segmentation=False)

    det_input = torch.randn(1, 3, 640, 640)
    raw_det = det_model(det_input)    # raw per-level maps
    decoded_det = det_post(raw_det)   # (1, 84, 8400)
    print(f"Standard Output: {decoded_det.shape}")

    # --- Prompt-driven segmentation model ---
    print("\nTesting YOLO11E (Segment)...")
    seg_model = YOLO11E(nc=80, scale='n')
    seg_model.eval()
    seg_post = YOLOPostProcessor(seg_model.model[-1], use_segmentation=True)

    seg_model.set_classes(["a", "b"], torch.randn(1, 2, 512))
    seg_input = torch.randn(1, 3, 640, 640)
    raw_seg = seg_model(seg_input)    # (feats, mask coeffs, protos)
    decoded_seg, mc, p = seg_post(raw_seg)
    print(f"Segment Output: {decoded_seg.shape}")  # (1, 4+2, 8400)
    print(f"Mask Coeffs: {mc.shape}, Protos: {p.shape}")