import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor (tensors pass through unchanged)."""
    if isinstance(x, torch.Tensor):
        return x
    return math.ceil(x / divisor) * divisor


def autopad(k, p=None, d=1):
    """Compute 'same'-style padding for a (possibly dilated) kernel size k."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate one anchor point per feature-map cell, plus the matching stride tensor."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Decode (left, top, right, bottom) distances around anchor points into boxes."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)
    return torch.cat((x1y1, x2y2), dim)

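# Illustrative sanity sketch (not used by the model above): how make_anchors and
# dist2bbox fit together, assuming a 640x640 input and the usual strides 8/16/32
# (80x80 + 40x40 + 20x20 = 8400 anchors).
def _demo_anchor_decoding():
    feats = [torch.zeros(1, 64, 640 // s, 640 // s) for s in (8, 16, 32)]
    anchors, strides = make_anchors(feats, (8, 16, 32))
    assert anchors.shape == (8400, 2) and strides.shape == (8400, 1)
    # One unit of distance on each side -> 2x2 xywh boxes centred on the anchors.
    boxes = dist2bbox(torch.ones(8400, 4), anchors)
    assert torch.allclose(boxes[:, 2:], torch.full((8400, 2), 2.0))
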
class Concat(nn.Module):
    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension

    def forward(self, x: List[torch.Tensor]):
        return torch.cat(x, self.d)


class Conv(nn.Module):
    default_act = nn.SiLU(inplace=True)

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)  # type: ignore
        self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))


class DWConv(Conv):
    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class DFL(nn.Module):
    """Distribution Focal Loss decoder: a frozen 1x1 conv takes the expectation of a softmax distribution over bin indices 0..c1-1."""

    def __init__(self, c1: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = x.view(1, c1, 1, 1)  # fixed weights 0, 1, ..., c1 - 1
        self.c1 = c1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, a = x.shape  # (batch, 4 * c1, anchors)
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

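# Illustrative sketch (not used by the model above): the frozen conv after the
# softmax is exactly the expected bin index, sum_i(i * p_i), computed per box
# side and per anchor.
def _demo_dfl():
    dfl = DFL(c1=16)
    logits = torch.randn(2, 64, 8400)  # (batch, 4 sides * 16 bins, anchors)
    out = dfl(logits)
    probs = logits.view(2, 4, 16, 8400).softmax(2)
    expected = (probs * torch.arange(16.0).view(1, 1, 16, 1)).sum(2)
    assert torch.allclose(out, expected, atol=1e-5)
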
class Bottleneck(nn.Module):
    def __init__(
        self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5
    ):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2  # residual only when shapes match

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C2f(nn.Module):
    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))  # type: ignore

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x: torch.Tensor) -> torch.Tensor:
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

class C3(nn.Module):
    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class C3k(C3):
    def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))


class C3k2(C2f):
    def __init__(
        self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
    ):
        super().__init__(c1, c2, n, shortcut, g, e)
        # Swap the plain Bottlenecks for C3k blocks when c3k is True (m/l/x scales).
        self.m = nn.ModuleList(
            C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
        )


class SPPF(nn.Module):
    def __init__(self, c1: int, c2: int, k: int = 5):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = [self.cv1(x)]
        y.extend(self.m(y[-1]) for _ in range(3))  # chained pools: 5, 9, 13 receptive fields
        return self.cv2(torch.cat(y, 1))

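# Illustrative sketch (not used by the model above): chaining one 5x5 stride-1
# max-pool two or three times matches single 9x9 and 13x13 pools, which is why
# SPPF gets away with a single shared nn.MaxPool2d.
def _demo_sppf_equivalence():
    x = torch.randn(1, 8, 32, 32)
    p5 = nn.MaxPool2d(5, 1, 2)
    p9 = nn.MaxPool2d(9, 1, 4)
    p13 = nn.MaxPool2d(13, 1, 6)
    assert torch.equal(p5(p5(x)), p9(x))
    assert torch.equal(p5(p5(p5(x))), p13(x))
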
class Attention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2
        self.qkv = Conv(dim, h, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)  # depthwise positional encoding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )

        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
        x = self.proj(x)
        return x


class PSABlock(nn.Module):
    def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
        super().__init__()
        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
        self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
        self.add = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(x) if self.add else self.attn(x)
        x = x + self.ffn(x) if self.add else self.ffn(x)
        return x


class C2PSA(nn.Module):
    def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = self.m(b)  # attention branch on half the channels
        return self.cv2(torch.cat((a, b), 1))

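# Illustrative sketch (not used by the model above; assumes self.c >= 64 so at
# least one head survives the self.c // 64 heuristic): Attention runs over the
# flattened H*W positions, and C2PSA applies it to half of the channels only.
def _demo_c2psa():
    m = C2PSA(256, 256, n=1)  # self.c = 128 -> 2 heads of width 64
    x = torch.randn(1, 256, 20, 20)
    assert m(x).shape == x.shape
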
class Proto(nn.Module):
    def __init__(self, c1: int, c_: int = 256, c2: int = 32):
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.cv3(self.cv2(self.upsample(self.cv1(x))))


class BNContrastiveHead(nn.Module):
    def __init__(self, embed_dims: int):
        super().__init__()
        self.norm = nn.BatchNorm2d(embed_dims)
        self.bias = nn.Parameter(torch.tensor([-10.0]))
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def fuse(self):
        del self.norm
        del self.bias
        del self.logit_scale
        self.forward = self.forward_fuse

    def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return x

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        w = F.normalize(w, dim=-1, p=2)

        x = torch.einsum("bchw,bkc->bkhw", x, w)  # per-pixel similarity to each class embedding
        return x * self.logit_scale.exp() + self.bias

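# Illustrative sketch (not used by the model above): the contrastive head scores
# each pixel embedding against K class/prompt embeddings, yielding one logit map
# per class.
def _demo_contrastive_head():
    head = BNContrastiveHead(64)
    x = torch.randn(2, 64, 20, 20)  # pixel embeddings
    w = torch.randn(2, 5, 64)       # 5 class/prompt embeddings
    assert head(x, w).shape == (2, 5, 20, 20)
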
class SwiGLUFFN(nn.Module):
    def __init__(self, gc: int, ec: int, e: int = 4) -> None:
        super().__init__()
        self.w12 = nn.Linear(gc, e * ec)
        self.w3 = nn.Linear(e * ec // 2, ec)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2  # SwiGLU gating
        return self.w3(hidden)


class Residual(nn.Module):
    def __init__(self, m: nn.Module) -> None:
        super().__init__()
        self.m = m
        nn.init.zeros_(self.m.w3.bias)
        # For l-scale models, change the initialization to
        # nn.init.constant_(self.m.w3.weight, 1e-6)
        nn.init.zeros_(self.m.w3.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.m(x)

class SAVPE(nn.Module):
    def __init__(self, ch: List[int], c3: int, embed: int):
        super().__init__()
        # Per-level branches; the deeper levels are upsampled back to the P3 grid.
        self.cv1 = nn.ModuleList(
            nn.Sequential(
                Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
            )
            for i, x in enumerate(ch)
        )

        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
            for i, x in enumerate(ch)
        )

        self.c = 16
        self.cv3 = nn.Conv2d(3 * c3, embed, 1)
        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))

    def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
        y = [self.cv2[i](xi) for i, xi in enumerate(x)]
        y = self.cv4(torch.cat(y, dim=1))  # (B, self.c, H, W) attention logits

        x = [self.cv1[i](xi) for i, xi in enumerate(x)]
        x = self.cv3(torch.cat(x, dim=1))  # (B, embed, H, W) value features

        B, C, H, W = x.shape  # type: ignore

        Q = vp.shape[1]  # number of visual prompts

        x = x.view(B, C, -1)  # type: ignore

        y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
        vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)

        y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))

        y = y.reshape(B, Q, self.c, -1)
        vp = vp.reshape(B, Q, 1, -1)

        # Mask out positions outside each prompt before the softmax.
        score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
        score = F.softmax(score, dim=-1).to(y.dtype)
        aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)

        return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)

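# Illustrative shape walk-through (not used by the model above; assumes three
# pyramid levels whose sizes differ by factors of two and binary prompt masks
# given at P3 resolution): SAVPE turns Q masks into Q normalized embeddings.
def _demo_savpe():
    savpe = SAVPE([64, 128, 256], c3=32, embed=128)
    feats = [
        torch.randn(1, 64, 80, 80),
        torch.randn(1, 128, 40, 40),
        torch.randn(1, 256, 20, 20),
    ]
    masks = (torch.rand(1, 3, 80, 80) > 0.5).float()  # 3 visual prompts
    assert savpe(feats, masks).shape == (1, 3, 128)  # (B, Q, embed)
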
class Detect(nn.Module):
    dynamic = False
    export = False
    end2end = False  # one-to-one (NMS-free) variant; always False in this implementation
    xyxy = False  # decode to xyxy instead of xywh; always False in this implementation
    shape = None
    anchors = torch.empty(0)
    strides = torch.empty(0)

    def __init__(self, nc=80, ch=()):
        super().__init__()
        self.nc = nc
        self.nl = len(ch)  # number of detection levels
        self.reg_max = 16  # DFL bins per box side
        self.no = nc + self.reg_max * 4  # outputs per anchor
        self.stride = torch.tensor([8., 16., 32.])  # P3/P4/P5 strides

        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))

        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                nn.Conv2d(c3, self.nc, 1),
            )
            for x in ch
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)

        if self.training:
            return x

        # Inference path
        shape = x[0].shape
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)

        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        m = self
        for a, b, s in zip(m.cv2, m.cv3, m.stride):
            a[-1].bias.data[:] = 1.0  # type: ignore
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # type: ignore

    def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
        return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
        batch_size, anchors, _ = preds.shape
        boxes, scores = preds.split([4, nc], dim=-1)
        index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
        boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
        scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
        scores, index = scores.flatten(1).topk(min(max_det, anchors))
        i = torch.arange(batch_size)[..., None]
        return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

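# Illustrative sketch (not used by the model above): Detect.postprocess expects
# (batch, anchors, 4 + nc) predictions, i.e. the inference output transposed,
# and returns the top max_det rows as (x, y, w, h, score, class).
def _demo_postprocess():
    preds = torch.randn(2, 8400, 4 + 80)
    out = Detect.postprocess(preds, max_det=300, nc=80)
    assert out.shape == (2, 300, 6)
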
class YOLOEDetect(Detect):
    def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        assert c3 <= embed
        # Replace the class branch: it now emits embed-dim features per anchor...
        self.cv3 = (
            nn.ModuleList(
                nn.Sequential(
                    nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                    nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                    nn.Conv2d(c3, embed, 1),
                )
                for x in ch
            )
        )

        # ...which the contrastive heads score against the class embeddings.
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)

        self.reprta = Residual(SwiGLUFFN(embed, embed))  # text-prompt refinement
        self.savpe = SAVPE(ch, c3, embed)  # type: ignore
        self.embed = embed

    def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)

    def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
        if vpe.shape[1] == 0:  # no visual prompt embeddings
            return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
        if vpe.ndim == 4:  # (B, N, H, W) masks -> aggregate into embeddings
            vpe = self.savpe(x, vpe)
        assert vpe.ndim == 3  # (B, N, D)
        return vpe

    def forward(  # type: ignore
        self, x: List[torch.Tensor], cls_pe: torch.Tensor
    ) -> Union[torch.Tensor, Tuple]:
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
        if self.training:
            return x  # type: ignore
        self.no = self.nc + self.reg_max * 4  # self.nc can change when inferring with different prompts
        # Inference path
        shape = x[0].shape
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)

        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        m = self
        for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:] = 0.0
            c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)

class YOLOESegment(YOLOEDetect):
    def __init__(
        self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = ()
    ):
        super().__init__(nc, embed, ch)
        self.nm = nm  # number of mask coefficients
        self.npr = npr  # number of proto channels
        self.proto = Proto(ch[0], self.npr, self.nm)

        c5 = max(ch[0] // 4, self.nm)
        self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)

    def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size

        mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
        x = YOLOEDetect.forward(self, x, text)

        if self.training:
            return x, mc, p

        return (torch.cat([x, mc], 1), p)

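# Illustrative sketch (not used by the model above): the prompt-conditioned head
# scores boxes against whatever class embeddings are passed in, so the number of
# classes at inference is simply cls_pe.shape[1].
def _demo_yoloe_detect():
    head = YOLOEDetect(nc=2, embed=512, ch=(64, 128, 256))
    head.eval()
    feats = [
        torch.randn(1, 64, 80, 80),
        torch.randn(1, 128, 40, 40),
        torch.randn(1, 256, 20, 20),
    ]
    cls_pe = torch.randn(1, 2, 512)  # 2 prompt embeddings
    out = head(feats, cls_pe)
    assert out.shape == (1, 4 + 2, 8400)
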
class YOLO11(nn.Module):
    def __init__(self, nc=80, scale='n'):
        super().__init__()
        self.nc = nc

        # Scales: [depth, width, max_channels]
        # Mirrors the scales table in yolo11.yaml
        scales = {
            'n': [0.50, 0.25, 1024],
            's': [0.50, 0.50, 1024],
            'm': [0.50, 1.00, 512],
            'l': [1.00, 1.00, 512],
            'x': [1.00, 1.50, 512],
        }

        if scale not in scales:
            raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")

        depth, width, max_channels = scales[scale]

        # n/s scales use plain Bottlenecks in C3k2; m/l/x force C3k blocks.
        c3k_override = scale not in ['n', 's']

        # Helper: channel count under width scaling
        def gw(channels):
            return make_divisible(min(channels, max_channels) * width, 8)

        # Helper: block repeat count under depth scaling
        def gd(n):
            return max(round(n * depth), 1) if n > 1 else n

        self.model = nn.ModuleList()

        # --- Backbone ---
        # 0: Conv [64, 3, 2]
        self.model.append(Conv(3, gw(64), 3, 2))

        # 1: Conv [128, 3, 2]
        self.model.append(Conv(gw(64), gw(128), 3, 2))

        # 2: C3k2 [256, False, 0.25] -> n=2
        self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=c3k_override, e=0.25))

        # 3: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))

        # 4: C3k2 [512, False, 0.25] -> n=2
        self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=c3k_override, e=0.25))

        # 5: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))

        # 6: C3k2 [512, True] -> n=2
        self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))

        # 7: Conv [1024, 3, 2]
        self.model.append(Conv(gw(512), gw(1024), 3, 2))

        # 8: C3k2 [1024, True] -> n=2
        self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))

        # 9: SPPF [1024, 5]
        self.model.append(SPPF(gw(1024), gw(1024), 5))

        # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
        self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))

        # --- Head ---

        # 11: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))

        # 12: Concat [-1, 6] (P4)
        self.model.append(Concat(dimension=1))

        # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
        self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=c3k_override))

        # 14: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))

        # 15: Concat [-1, 4] (P3)
        self.model.append(Concat(dimension=1))

        # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
        # Note: layer 4 and layer 13 both output gw(512) channels.
        self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=c3k_override))

        # 17: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))

        # 18: Concat [-1, 13] (Head P4)
        self.model.append(Concat(dimension=1))

        # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
        self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=c3k_override))

        # 20: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))

        # 21: Concat [-1, 10] (P5)
        self.model.append(Concat(dimension=1))

        # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
        self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))

        # 23: Detect [nc]
        self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))

        # --- Initialize weights ---
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize model weights, in particular the Detect head biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Keep PyTorch's default (Kaiming-style) initialization.
                pass
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        detect_layer = self.model[-1]
        if isinstance(detect_layer, Detect):
            detect_layer.bias_init()

    def forward(self, x):
        # Backbone
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        p3 = self.model[4](x)  # keep P3 (layer 4)
        x = self.model[5](p3)
        p4 = self.model[6](x)  # keep P4 (layer 6)
        x = self.model[7](p4)
        x = self.model[8](x)
        x = self.model[9](x)
        p5 = self.model[10](x)  # keep P5 (layer 10)

        # Head
        x = self.model[11](p5)  # Upsample
        x = self.model[12]([x, p4])  # Concat P4
        h1 = self.model[13](x)  # Head P4 (layer 13)

        x = self.model[14](h1)  # Upsample
        x = self.model[15]([x, p3])  # Concat P3
        h2 = self.model[16](x)  # Output P3 (layer 16)

        x = self.model[17](h2)  # Conv
        x = self.model[18]([x, h1])  # Concat Head P4
        h3 = self.model[19](x)  # Output P4 (layer 19)

        x = self.model[20](h3)  # Conv
        x = self.model[21]([x, p5])  # Concat P5
        h4 = self.model[22](x)  # Output P5 (layer 22)

        return self.model[23]([h2, h3, h4])  # Detect

    def load_weights(self, pth_file):
        state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
        # Official ultralytics checkpoints usually key weights as model.model.0.conv...
        # or model.0.conv..., so apply a simple compatibility shim here.
        new_state_dict = {}
        for k, v in state_dict.items():
            # Handle the 'model' key of a full ultralytics checkpoint
            if k == 'model':
                # In a full checkpoint the weights live under 'model',
                # usually as a state_dict or the model object itself.
                if hasattr(v, 'state_dict'):
                    v = v.state_dict()
                elif isinstance(v, dict):
                    pass  # v is already a state_dict
                else:
                    try:
                        v = v.float().state_dict()
                    except Exception:
                        continue

                for sub_k, sub_v in v.items():
                    new_state_dict[sub_k] = sub_v
                break
            else:
                new_state_dict[k] = v

        if not new_state_dict:
            new_state_dict = state_dict

        # Attempt a strict load first, then fall back to strict=False.
        try:
            self.load_state_dict(new_state_dict, strict=True)
            print(f"Successfully loaded weights from {pth_file}")
        except Exception as e:
            print(f"Error loading weights: {e}")
            print("Trying to load with strict=False...")
            self.load_state_dict(new_state_dict, strict=False)

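# Illustrative sketch (not used by the model above): with a 640x640 input the
# three detection scales are 80x80, 40x40 and 20x20, so closed-set inference
# yields (1, 4 + nc, 8400).
def _demo_yolo11():
    model = YOLO11(nc=80, scale='n')
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 640, 640))
    assert out.shape == (1, 84, 8400)
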
class YOLO11E(nn.Module):
    def __init__(self, nc=80, scale='n'):
        super().__init__()
        self.nc = nc
        self.pe = None  # cached class embeddings, set via set_classes()

        # Scales: [depth, width, max_channels]
        # Mirrors the scales table in yolo11.yaml
        scales = {
            'n': [0.50, 0.25, 1024],
            's': [0.50, 0.50, 1024],
            'm': [0.50, 1.00, 512],
            'l': [1.00, 1.00, 512],
            'x': [1.00, 1.50, 512],
        }

        if scale not in scales:
            raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")

        depth, width, max_channels = scales[scale]

        # n/s scales use plain Bottlenecks in C3k2; m/l/x force C3k blocks.
        c3k_override = scale not in ['n', 's']

        # Helper: channel count under width scaling
        def gw(channels):
            return make_divisible(min(channels, max_channels) * width, 8)

        # Helper: block repeat count under depth scaling
        def gd(n):
            return max(round(n * depth), 1) if n > 1 else n

        self.model = nn.ModuleList()

        # --- Backbone ---
        # 0: Conv [64, 3, 2]
        self.model.append(Conv(3, gw(64), 3, 2))

        # 1: Conv [128, 3, 2]
        self.model.append(Conv(gw(64), gw(128), 3, 2))

        # 2: C3k2 [256, False, 0.25] -> n=2
        self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=c3k_override, e=0.25))

        # 3: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))

        # 4: C3k2 [512, False, 0.25] -> n=2
        self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=c3k_override, e=0.25))

        # 5: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))

        # 6: C3k2 [512, True] -> n=2
        self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))

        # 7: Conv [1024, 3, 2]
        self.model.append(Conv(gw(512), gw(1024), 3, 2))

        # 8: C3k2 [1024, True] -> n=2
        self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))

        # 9: SPPF [1024, 5]
        self.model.append(SPPF(gw(1024), gw(1024), 5))

        # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
        self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))

        # --- Head ---

        # 11: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))

        # 12: Concat [-1, 6] (P4)
        self.model.append(Concat(dimension=1))

        # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
        self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=c3k_override))

        # 14: Upsample
        self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))

        # 15: Concat [-1, 4] (P3)
        self.model.append(Concat(dimension=1))

        # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
        # Note: layer 4 and layer 13 both output gw(512) channels.
        self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=c3k_override))

        # 17: Conv [256, 3, 2]
        self.model.append(Conv(gw(256), gw(256), 3, 2))

        # 18: Concat [-1, 13] (Head P4)
        self.model.append(Concat(dimension=1))

        # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
        self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=c3k_override))

        # 20: Conv [512, 3, 2]
        self.model.append(Conv(gw(512), gw(512), 3, 2))

        # 21: Concat [-1, 10] (P5)
        self.model.append(Concat(dimension=1))

        # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
        self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))

        # 23: YOLOESegment [nc]
        self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)]))

        # --- Initialize weights ---
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize model weights, in particular the Detect head biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Keep PyTorch's default (Kaiming-style) initialization.
                pass
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        detect_layer = self.model[-1]
        if isinstance(detect_layer, Detect):
            detect_layer.bias_init()

    def set_classes(self, names: List[str], embeddings: torch.Tensor):
        assert embeddings.ndim == 3, "Embeddings must be (1, N, D)"
        self.pe = embeddings
        self.model[-1].nc = len(names)  # type: ignore
        self.nc = len(names)

    def forward(self, x, tpe=None, vpe=None):
        # Backbone
        x = self.model[0](x)
        x = self.model[1](x)
        x = self.model[2](x)
        x = self.model[3](x)
        p3 = self.model[4](x)  # keep P3 (layer 4)
        x = self.model[5](p3)
        p4 = self.model[6](x)  # keep P4 (layer 6)
        x = self.model[7](p4)
        x = self.model[8](x)
        x = self.model[9](x)
        p5 = self.model[10](x)  # keep P5 (layer 10)

        # Head
        x = self.model[11](p5)  # Upsample
        x = self.model[12]([x, p4])  # Concat P4
        h1 = self.model[13](x)  # Head P4 (layer 13)

        x = self.model[14](h1)  # Upsample
        x = self.model[15]([x, p3])  # Concat P3
        h2 = self.model[16](x)  # Output P3 (layer 16)

        x = self.model[17](h2)  # Conv
        x = self.model[18]([x, h1])  # Concat Head P4
        h3 = self.model[19](x)  # Output P4 (layer 19)

        x = self.model[20](h3)  # Conv
        x = self.model[21]([x, p5])  # Concat P5
        h4 = self.model[22](x)  # Output P5 (layer 22)

        head = self.model[23]
        feats = [h2, h3, h4]

        processed_tpe = head.get_tpe(tpe)  # type: ignore

        processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None  # type: ignore

        all_pe = []
        if processed_tpe is not None:
            all_pe.append(processed_tpe)
        if processed_vpe is not None:
            all_pe.append(processed_vpe)

        # Fall back to the cached set_classes() embeddings, then to zeros.
        if not all_pe:
            if self.pe is not None:
                all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
            else:
                all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))

        cls_pe = torch.cat(all_pe, dim=1)

        b = x.shape[0]
        if cls_pe.shape[0] != b:
            cls_pe = cls_pe.expand(b, -1, -1)

        return head(feats, cls_pe)

    def load_weights(self, pth_file):
        state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
        # Official ultralytics checkpoints usually key weights as model.model.0.conv...
        # or model.0.conv..., so apply a simple compatibility shim here.
        new_state_dict = {}
        for k, v in state_dict.items():
            # Handle the 'model' key of a full ultralytics checkpoint
            if k == 'model':
                # In a full checkpoint the weights live under 'model',
                # usually as a state_dict or the model object itself.
                if hasattr(v, 'state_dict'):
                    v = v.state_dict()
                elif isinstance(v, dict):
                    pass  # v is already a state_dict
                else:
                    try:
                        v = v.float().state_dict()
                    except Exception:
                        continue

                for sub_k, sub_v in v.items():
                    new_state_dict[sub_k] = sub_v
                break
            else:
                new_state_dict[k] = v

        if not new_state_dict:
            new_state_dict = state_dict

        # Attempt a strict load first, then fall back to strict=False.
        try:
            self.load_state_dict(new_state_dict, strict=True)
            print(f"Successfully loaded weights from {pth_file}")
        except Exception as e:
            print(f"Error loading weights: {e}")
            print("Trying to load with strict=False...")
            self.load_state_dict(new_state_dict, strict=False)

if __name__ == "__main__":
    model = YOLO11E(nc=80, scale='l')
    model.load_weights("yoloe-11l-seg.pth")

    # Simulate set_classes: assume 2 classes with 512-dim embeddings
    fake_embeddings = torch.randn(1, 2, 512)
    model.set_classes(["class1", "class2"], fake_embeddings)

    # Inference
    dummy_input = torch.randn(1, 3, 640, 640)
    model.eval()
    output = model(dummy_input)
    print("Output shape:", output[0].shape)  # expected (1, 4 + nc + nm, anchors)