Remove the dynamic computation graph from the detect head, split the dynamic part into a separate post-processing module, and update the related inference code examples

This commit is contained in:
lhr
2025-12-30 17:10:01 +08:00
parent 9df330875d
commit 553a63f521
3 changed files with 203 additions and 621 deletions
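The detect heads (Detect, YOLOEDetect, YOLOESegment) now return only the raw per-level feature maps; anchor generation, DFL decoding and the sigmoid on class scores move into the new YOLOPostProcessor module. A minimal sketch of the resulting two-stage inference flow, mirroring the updated scripts in this commit (shapes assume a 640x640 input and nc=80):

```python
import torch
from yolo11_standalone import YOLO11, YOLOPostProcessor

model = YOLO11(nc=80, scale='s')
model.load_weights("yolo11s.pth")
model.eval()

# The post-processor is built from the detect head so it reuses its stride,
# reg_max and DFL weights; it owns all the dynamic, shape-dependent logic.
post = YOLOPostProcessor(model.model[-1], use_segmentation=False)

with torch.no_grad():
    raw = model(torch.randn(1, 3, 640, 640))  # List[Tensor], one raw map per FPN level
    pred = post(raw)                           # (1, 4 + nc, 8400): decoded boxes + class scores
```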

View File

@@ -2,9 +2,8 @@ import torch
import cv2 import cv2
import numpy as np import numpy as np
import torchvision import torchvision
from yolo11_standalone import YOLO11 from yolo11_standalone import YOLO11, YOLOPostProcessor
# COCO 80-class category names
CLASSES = [ CLASSES = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
@@ -17,37 +16,29 @@ CLASSES = [
"hair drier", "toothbrush" "hair drier", "toothbrush"
] ]
# Generate random colors for drawing
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3)) COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
""" shape = im.shape[:2]
Resize and pad the image to the target size (keeping aspect ratio)
"""
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int): if isinstance(new_shape, int):
new_shape = (new_shape, new_shape) new_shape = (new_shape, new_shape)
# Compute the scaling ratio
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute the padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
dw, dh = dw / 2, dh / 2 # divide padding into 2 sides dw, dh = dw / 2, dh / 2
if shape[::-1] != new_unpad: # resize if shape[::-1] != new_unpad:
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
# Add the border padding
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
return im, r, (dw, dh) return im, r, (dw, dh)
def xywh2xyxy(x): def xywh2xyxy(x):
"""Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
@@ -56,43 +47,27 @@ def xywh2xyxy(x):
return y return y
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
"""
Non-maximum suppression (NMS)
prediction: [Batch, 84, Anchors]
"""
# 1. Transpose: [Batch, 84, Anchors] -> [Batch, Anchors, 84]
prediction = prediction.transpose(1, 2) prediction = prediction.transpose(1, 2)
bs = prediction.shape[0] # batch size bs = prediction.shape[0]
nc = prediction.shape[2] - 4 # number of classes nc = prediction.shape[2] - 4
# Fix: use max(-1) to take the maximum confidence over the last (class) dimension xc = prediction[..., 4:].max(-1)[0] > conf_thres
# The previous max(1) incorrectly operated over the anchors dimension
xc = prediction[..., 4:].max(-1)[0] > conf_thres # candidates
output = [torch.zeros((0, 6), device=prediction.device)] * bs output = [torch.zeros((0, 6), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction):
x = x[xc[xi]] # confidence filtering x = x[xc[xi]]
if not x.shape[0]: if not x.shape[0]:
continue continue
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4]) box = xywh2xyxy(x[:, :4])
# Confidence and Class
conf, j = x[:, 4:].max(1, keepdim=True) conf, j = x[:, 4:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Check shape
n = x.shape[0] n = x.shape[0]
if not n: if not n:
continue continue
elif n > max_det: elif n > max_det:
x = x[x[:, 4].argsort(descending=True)[:max_det]] x = x[x[:, 4].argsort(descending=True)[:max_det]]
# Batched NMS
c = x[:, 5:6] * 7680 # classes c = x[:, 5:6] * 7680 # classes
boxes, scores = x[:, :4] + c, x[:, 4] boxes, scores = x[:, :4] + c, x[:, 4]
i = torchvision.ops.nms(boxes, scores, iou_thres) i = torchvision.ops.nms(boxes, scores, iou_thres)
@@ -101,28 +76,21 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300
return output return output
def main(): def main():
# 1. Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") print(f"Using device: {device}")
model = YOLO11(nc=80, scale='s') model = YOLO11(nc=80, scale='s')
# Load the clean weights you converted earlier
model.load_weights("yolo11s.pth") model.load_weights("yolo11s.pth")
model.to(device) model.to(device)
model.eval() model.eval()
# model.train() post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
# 2. Read the image img_path = "1.jpg"
img_path = "1.jpg" # replace with the path to your local image
img0 = cv2.imread(img_path) img0 = cv2.imread(img_path)
assert img0 is not None, f"Image Not Found {img_path}" assert img0 is not None, f"Image Not Found {img_path}"
# 3. Preprocessing
# Letterbox resize
img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640)) img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
# BGR to RGB, HWC to CHW
img = img[:, :, ::-1].transpose(2, 0, 1) img = img[:, :, ::-1].transpose(2, 0, 1)
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
@@ -132,25 +100,20 @@ def main():
if img_tensor.ndim == 3: if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0) img_tensor = img_tensor.unsqueeze(0)
# 4. Inference
print("Starting inference...") print("Starting inference...")
with torch.no_grad(): with torch.no_grad():
pred = model(img_tensor) pred = model(img_tensor)
# 5. Post-processing (NMS) pred = post_std(pred)
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45) pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# 6. Draw the results det = pred[0]
det = pred[0] # only process the first image
if len(det): if len(det):
# Map the coordinates back to the original image size
# det[:, :4] is x1, y1, x2, y2
det[:, [0, 2]] -= dw # x padding det[:, [0, 2]] -= dw # x padding
det[:, [1, 3]] -= dh # y padding det[:, [1, 3]] -= dh # y padding
det[:, :4] /= ratio det[:, :4] /= ratio
# Clip the coordinates to stay inside the image
det[:, 0].clamp_(0, img0.shape[1]) det[:, 0].clamp_(0, img0.shape[1])
det[:, 1].clamp_(0, img0.shape[0]) det[:, 1].clamp_(0, img0.shape[0])
det[:, 2].clamp_(0, img0.shape[1]) det[:, 2].clamp_(0, img0.shape[1])
@@ -163,21 +126,17 @@ def main():
label = f'{CLASSES[c]} {conf:.2f}' label = f'{CLASSES[c]} {conf:.2f}'
p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])) p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
# Draw the box
color = COLORS[c] color = COLORS[c]
cv2.rectangle(img0, p1, p2, color, 2, lineType=cv2.LINE_AA) cv2.rectangle(img0, p1, p2, color, 2, lineType=cv2.LINE_AA)
# Draw the label background
t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0] t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3 p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA) cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
# Draw the text
cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA) cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)
print(f" - {label} at {p1}-{p2}") print(f" - {label} at {p1}-{p2}")
# 7. Display/save the result
cv2.imwrite("result.jpg", img0) cv2.imwrite("result.jpg", img0)
print("结果已保存至 result.jpg") print("结果已保存至 result.jpg")

View File

@@ -4,25 +4,18 @@ import numpy as np
import torchvision import torchvision
from pathlib import Path from pathlib import Path
# Import your modules from yolo11_standalone import YOLO11E, YOLOPostProcessor
from yolo11_standalone import YOLO11E
from mobile_clip_standalone import MobileCLIP from mobile_clip_standalone import MobileCLIP
# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
YOLO_WEIGHTS = "yoloe-11l-seg.pth" # 替换为你的 YOLO 权重路径 YOLO_WEIGHTS = "yoloe-11l-seg.pth"
CLIP_WEIGHTS = "mobileclip_blt.ts" # 替换为你的 MobileCLIP 权重路径 CLIP_WEIGHTS = "mobileclip_blt.ts"
CLIP_SIZE = "blt" # 对应 MobileCLIP 的 size CLIP_SIZE = "blt"
IMAGE_PATH = "1.jpg" # 待检测图片 IMAGE_PATH = "1.jpg"
# Custom detection classes (open vocabulary)
CUSTOM_CLASSES = ["girl", "red balloon"] CUSTOM_CLASSES = ["girl", "red balloon"]
# Colors for drawing
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3)) COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))
# --- Helper functions (letterbox, NMS, etc.) ---
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
shape = im.shape[:2] shape = im.shape[:2]
if isinstance(new_shape, int): new_shape = (new_shape, new_shape) if isinstance(new_shape, int): new_shape = (new_shape, new_shape)
@@ -68,99 +61,66 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300)
def main(): def main():
print(f"Using device: {DEVICE}") print(f"Using device: {DEVICE}")
# 1. Load the MobileCLIP text encoder
print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...") print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
if not Path(CLIP_WEIGHTS).exists(): if not Path(CLIP_WEIGHTS).exists(): raise FileNotFoundError(CLIP_WEIGHTS)
raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")
clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE) clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)
# 2. Generate text embeddings
print(f"Encoding classes: {CUSTOM_CLASSES}") print(f"Encoding classes: {CUSTOM_CLASSES}")
prompts = [f"{c}" for c in CUSTOM_CLASSES] tokens = clip_model.tokenize([f"{c}" for c in CUSTOM_CLASSES])
text_embeddings = clip_model.encode_text(tokens).unsqueeze(0)
tokens = clip_model.tokenize(prompts)
text_embeddings = clip_model.encode_text(tokens) # Shape: (N, 512)
# Reshape to (1, N, 512) to match the YOLO11E input
text_embeddings = text_embeddings.unsqueeze(0)
# 3. Load the YOLO11E detection model
print(f"Loading YOLO11E from {YOLO_WEIGHTS}...") print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
if not Path(YOLO_WEIGHTS).exists(): if not Path(YOLO_WEIGHTS).exists(): raise FileNotFoundError(YOLO_WEIGHTS)
raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")
yolo_model = YOLO11E(nc=80, scale='l')
# Note: scale='l' must match your weight file (s, m, l, x)
yolo_model = YOLO11E(nc=80, scale='l')
yolo_model.load_weights(YOLO_WEIGHTS) yolo_model.load_weights(YOLO_WEIGHTS)
yolo_model.to(DEVICE) # move to DEVICE (half precision could optionally be used here) yolo_model.to(DEVICE).eval()
yolo_model.eval()
head = yolo_model.model[-1] head = yolo_model.model[-1]
post_processor = YOLOPostProcessor(head, use_segmentation=True)
post_processor.to(DEVICE).eval()
with torch.no_grad(): with torch.no_grad():
text_pe = head.get_tpe(text_embeddings) # type: ignore text_pe = head.get_tpe(text_embeddings)
yolo_model.set_classes(CUSTOM_CLASSES, text_pe) yolo_model.set_classes(CUSTOM_CLASSES, text_pe)
# 5. Image preprocessing
img0 = cv2.imread(IMAGE_PATH) img0 = cv2.imread(IMAGE_PATH)
assert img0 is not None, f"Image Not Found {IMAGE_PATH}" assert img0 is not None, f"Image Not Found {IMAGE_PATH}"
img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640)) img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
img = img[:, :, ::-1].transpose(2, 0, 1) img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
img = np.ascontiguousarray(img) img_tensor = torch.from_numpy(img).to(DEVICE).float() / 255.0
img_tensor = torch.from_numpy(img).to(DEVICE) if img_tensor.ndim == 3: img_tensor = img_tensor.unsqueeze(0)
img_tensor = img_tensor.float()
img_tensor /= 255.0
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
# 6. Inference
print("Running inference...") print("Running inference...")
with torch.no_grad(): with torch.no_grad():
pred = yolo_model(img_tensor) raw_outputs = yolo_model(img_tensor)
if isinstance(pred, tuple): decoded_box, mc, p = post_processor(raw_outputs)
pred = pred[0]
nc = len(CUSTOM_CLASSES)
pred = pred[:, :4+nc, :]
# 7. Post-processing (NMS) pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)
print(pred)
# 8. Visualization
det = pred[0] det = pred[0]
if len(det): if len(det):
det[:, [0, 2]] -= dw det[:, [0, 2]] -= dw
det[:, [1, 3]] -= dh det[:, [1, 3]] -= dh
det[:, :4] /= ratio det[:, :4] /= ratio
det[:, 0].clamp_(0, img0.shape[1]) det[:, [0, 2]] = det[:, [0, 2]].clamp(0, img0.shape[1])  # clamp_ on a fancy-indexed copy would be a no-op
det[:, 1].clamp_(0, img0.shape[0]) det[:, [1, 3]] = det[:, [1, 3]].clamp(0, img0.shape[0])
det[:, 2].clamp_(0, img0.shape[1])
det[:, 3].clamp_(0, img0.shape[0])
print(f"Detected {len(det)} objects:") print(f"Detected {len(det)} objects:")
for *xyxy, conf, cls in det: for *xyxy, conf, cls in det:
c = int(cls) c = int(cls)
class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c) label = f'{CUSTOM_CLASSES[c]} {conf:.2f}'
label = f'{class_name} {conf:.2f}'
p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])) p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
color = COLORS[c] color = COLORS[c % len(COLORS)]
cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA) cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
t_size = cv2.getTextSize(label, 0, 0.5, 1)[0] cv2.putText(img0, label, (p1[0], p1[1] - 5), 0, 0.5, color, 1, cv2.LINE_AA)
p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA)
print(f" - {label}") print(f" - {label}")
else: else:
print("No objects detected.") print("No objects detected.")
output_path = "result_full.jpg" cv2.imwrite("result_separate.jpg", img0)
cv2.imwrite(output_path, img0) print("Result saved.")
print(f"Result saved to {output_path}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
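The segmentation outputs mc (mask coefficients) and p (prototype masks) returned by the post-processor are not used in the script above. A rough sketch of how they are typically combined into per-detection masks, assuming a hypothetical kept_idx holding the anchor indices that survived NMS (the script does not currently track them):

```python
import torch

def masks_from_protos(mc: torch.Tensor, p: torch.Tensor, kept_idx: torch.Tensor, thr: float = 0.5):
    # mc: (1, nm, num_anchors) mask coefficients, p: (1, nm, mh, mw) prototypes
    nm, mh, mw = p.shape[1], p.shape[2], p.shape[3]
    coeffs = mc[0, :, kept_idx].T                 # (n_det, nm), one coefficient row per kept box
    masks = coeffs @ p[0].reshape(nm, -1)         # linear combination of the prototypes
    masks = masks.sigmoid().reshape(-1, mh, mw)   # low-resolution soft masks
    return masks > thr                            # boolean masks; upsample/crop to the box as needed
```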

View File

@@ -1,11 +1,13 @@
import math import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# ==============================================================================
# [Part 1] Utils & Basic Modules
# ==============================================================================
def make_divisible(x, divisor): def make_divisible(x, divisor):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return x return x
@@ -19,6 +21,7 @@ def autopad(k, p=None, d=1):
return p return p
def make_anchors(feats, strides, grid_cell_offset=0.5): def make_anchors(feats, strides, grid_cell_offset=0.5):
"""生成 Anchor Points用于后处理阶段"""
anchor_points, stride_tensor = [], [] anchor_points, stride_tensor = [], []
assert feats is not None assert feats is not None
dtype, device = feats[0].dtype, feats[0].device dtype, device = feats[0].dtype, feats[0].device
@@ -32,6 +35,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
return torch.cat(anchor_points), torch.cat(stride_tensor) return torch.cat(anchor_points), torch.cat(stride_tensor)
def dist2bbox(distance, anchor_points, xywh=True, dim=-1): def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
"""将预测的距离转换为 BBox用于后处理阶段"""
lt, rb = distance.chunk(2, dim) lt, rb = distance.chunk(2, dim)
x1y1 = anchor_points - lt x1y1 = anchor_points - lt
x2y2 = anchor_points + rb x2y2 = anchor_points + rb
@@ -45,26 +49,19 @@ class Concat(nn.Module):
def __init__(self, dimension=1): def __init__(self, dimension=1):
super().__init__() super().__init__()
self.d = dimension self.d = dimension
def forward(self, x: List[torch.Tensor]): def forward(self, x: List[torch.Tensor]):
return torch.cat(x, self.d) return torch.cat(x, self.d)
class Conv(nn.Module): class Conv(nn.Module):
default_act = nn.SiLU(inplace=True) default_act = nn.SiLU(inplace=True)
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
super().__init__() super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) # type: ignore self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x): def forward(self, x):
return self.act(self.bn(self.conv(x))) return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
return self.act(self.conv(x))
class DWConv(Conv): class DWConv(Conv):
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
@@ -76,21 +73,17 @@ class DFL(nn.Module):
x = torch.arange(c1, dtype=torch.float) x = torch.arange(c1, dtype=torch.float)
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
self.c1 = c1 self.c1 = c1
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
b, _, a = x.shape b, _, a = x.shape
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
def __init__( def __init__(self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5):
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5
):
super().__init__() super().__init__()
c_ = int(c2 * e) c_ = int(c2 * e)
self.cv1 = Conv(c1, c_, k[0], 1) self.cv1 = Conv(c1, c_, k[0], 1)
self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.cv2 = Conv(c_, c2, k[1], 1, g=g)
self.add = shortcut and c1 == c2 self.add = shortcut and c1 == c2
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
@@ -100,22 +93,14 @@ class C2f(nn.Module):
self.c = int(c2 * e) self.c = int(c2 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv((2 + n) * self.c, c2, 1) self.cv2 = Conv((2 + n) * self.c, c2, 1)
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) # type: ignore self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
def forward(self, x): def forward(self, x):
chunk_result = self.cv1(x).chunk(2, 1) chunk_result = self.cv1(x).chunk(2, 1)
y = [chunk_result[0], chunk_result[1]] y = [chunk_result[0], chunk_result[1]]
for m_module in self.m: for m_module in self.m:
y.append(m_module(y[-1])) y.append(m_module(y[-1]))
return self.cv2(torch.cat(y, 1)) return self.cv2(torch.cat(y, 1))
def forward_split(self, x: torch.Tensor) -> torch.Tensor:
y = self.cv1(x).split((self.c, self.c), 1)
y = [y[0], y[1]]
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))
class C3(nn.Module): class C3(nn.Module):
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5): def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
super().__init__() super().__init__()
@@ -123,8 +108,7 @@ class C3(nn.Module):
self.cv1 = Conv(c1, c_, 1, 1) self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) self.cv3 = Conv(2 * c_, c2, 1)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) # type: ignore self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
@@ -135,13 +119,9 @@ class C3k(C3):
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
class C3k2(C2f): class C3k2(C2f):
def __init__( def __init__(self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True):
self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
):
super().__init__(c1, c2, n, shortcut, g, e) super().__init__(c1, c2, n, shortcut, g, e)
self.m = nn.ModuleList( self.m = nn.ModuleList(C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n))
C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
)
class SPPF(nn.Module): class SPPF(nn.Module):
def __init__(self, c1: int, c2: int, k: int = 5): def __init__(self, c1: int, c2: int, k: int = 5):
@@ -150,12 +130,16 @@ class SPPF(nn.Module):
self.cv1 = Conv(c1, c_, 1, 1) self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1) self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
y = [self.cv1(x)] y = [self.cv1(x)]
y.extend(self.m(y[-1]) for _ in range(3)) y.extend(self.m(y[-1]) for _ in range(3))
return self.cv2(torch.cat(y, 1)) return self.cv2(torch.cat(y, 1))
# ==============================================================================
# [Part 2] Advanced Modules & Pure Heads
# ==============================================================================
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5): def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
super().__init__() super().__init__()
@@ -168,7 +152,6 @@ class Attention(nn.Module):
self.qkv = Conv(dim, h, 1, act=False) self.qkv = Conv(dim, h, 1, act=False)
self.proj = Conv(dim, dim, 1, act=False) self.proj = Conv(dim, dim, 1, act=False)
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape B, C, H, W = x.shape
N = H * W N = H * W
@@ -176,7 +159,6 @@ class Attention(nn.Module):
q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
[self.key_dim, self.key_dim, self.head_dim], dim=2 [self.key_dim, self.key_dim, self.head_dim], dim=2
) )
attn = (q.transpose(-2, -1) @ k) * self.scale attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
@@ -189,7 +171,6 @@ class PSABlock(nn.Module):
self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads) self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False)) self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
self.add = shortcut self.add = shortcut
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(x) if self.add else self.attn(x) x = x + self.attn(x) if self.add else self.attn(x)
x = x + self.ffn(x) if self.add else self.ffn(x) x = x + self.ffn(x) if self.add else self.ffn(x)
@@ -202,9 +183,7 @@ class C2PSA(nn.Module):
self.c = int(c1 * e) self.c = int(c1 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv(2 * self.c, c1, 1) self.cv2 = Conv(2 * self.c, c1, 1)
self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))) self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
a, b = self.cv1(x).split((self.c, self.c), dim=1) a, b = self.cv1(x).split((self.c, self.c), dim=1)
b = self.m(b) b = self.m(b)
@@ -217,7 +196,6 @@ class Proto(nn.Module):
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
self.cv2 = Conv(c_, c_, k=3) self.cv2 = Conv(c_, c_, k=3)
self.cv3 = Conv(c_, c2) self.cv3 = Conv(c_, c2)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.cv3(self.cv2(self.upsample(self.cv1(x)))) return self.cv3(self.cv2(self.upsample(self.cv1(x))))
@@ -227,20 +205,16 @@ class BNContrastiveHead(nn.Module):
self.norm = nn.BatchNorm2d(embed_dims) self.norm = nn.BatchNorm2d(embed_dims)
self.bias = nn.Parameter(torch.tensor([-10.0])) self.bias = nn.Parameter(torch.tensor([-10.0]))
self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
def fuse(self): def fuse(self):
del self.norm del self.norm
del self.bias del self.bias
del self.logit_scale del self.logit_scale
self.forward = self.forward_fuse self.forward = self.forward_fuse
def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
return x return x
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
x = self.norm(x) x = self.norm(x)
w = F.normalize(w, dim=-1, p=2) w = F.normalize(w, dim=-1, p=2)
x = torch.einsum("bchw,bkc->bkhw", x, w) x = torch.einsum("bchw,bkc->bkhw", x, w)
return x * self.logit_scale.exp() + self.bias return x * self.logit_scale.exp() + self.bias
@@ -249,7 +223,6 @@ class SwiGLUFFN(nn.Module):
super().__init__() super().__init__()
self.w12 = nn.Linear(gc, e * ec) self.w12 = nn.Linear(gc, e * ec)
self.w3 = nn.Linear(e * ec // 2, ec) self.w3 = nn.Linear(e * ec // 2, ec)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x12 = self.w12(x) x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1) x1, x2 = x12.chunk(2, dim=-1)
@@ -261,10 +234,7 @@ class Residual(nn.Module):
super().__init__() super().__init__()
self.m = m self.m = m
nn.init.zeros_(self.m.w3.bias) nn.init.zeros_(self.m.w3.bias)
# For models with l scale, please change the initialization to
# nn.init.constant_(self.m.w3.weight, 1e-6)
nn.init.zeros_(self.m.w3.weight) nn.init.zeros_(self.m.w3.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.m(x) return x + self.m(x)
@@ -277,12 +247,10 @@ class SAVPE(nn.Module):
) )
for i, x in enumerate(ch) for i, x in enumerate(ch)
) )
self.cv2 = nn.ModuleList( self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()) nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
for i, x in enumerate(ch) for i, x in enumerate(ch)
) )
self.c = 16 self.c = 16
self.cv3 = nn.Conv2d(3 * c3, embed, 1) self.cv3 = nn.Conv2d(3 * c3, embed, 1)
self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1) self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
@@ -292,30 +260,21 @@ class SAVPE(nn.Module):
def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor: def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
y = [self.cv2[i](xi) for i, xi in enumerate(x)] y = [self.cv2[i](xi) for i, xi in enumerate(x)]
y = self.cv4(torch.cat(y, dim=1)) y = self.cv4(torch.cat(y, dim=1))
x = [self.cv1[i](xi) for i, xi in enumerate(x)] x = [self.cv1[i](xi) for i, xi in enumerate(x)]
x = self.cv3(torch.cat(x, dim=1)) x = self.cv3(torch.cat(x, dim=1))
B, C, H, W = x.shape
B, C, H, W = x.shape # type: ignore
Q = vp.shape[1] Q = vp.shape[1]
x = x.view(B, C, -1)
x = x.view(B, C, -1) # type: ignore
y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W) y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W) vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1)) y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
y = y.reshape(B, Q, self.c, -1) y = y.reshape(B, Q, self.c, -1)
vp = vp.reshape(B, Q, 1, -1) vp = vp.reshape(B, Q, 1, -1)
score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
score = F.softmax(score, dim=-1).to(y.dtype) score = F.softmax(score, dim=-1).to(y.dtype)
aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2) aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2) return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
class Detect(nn.Module): class Detect(nn.Module):
dynamic = False dynamic = False
export = False export = False
@@ -347,45 +306,17 @@ class Detect(nn.Module):
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
def forward(self, x): def forward(self, x):
outs = []
for i in range(self.nl): for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) outs.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
return outs
if self.training:
return x
# Inference path
shape = x[0].shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
return torch.cat((dbox, cls.sigmoid()), 1)
def bias_init(self): def bias_init(self):
m = self m = self
for a, b, s in zip(m.cv2, m.cv3, m.stride): for a, b, s in zip(m.cv2, m.cv3, m.stride):
a[-1].bias.data[:] = 1.0 # type: ignore a[-1].bias.data[:] = 1.0
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # type: ignore b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)
def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
batch_size, anchors, _ = preds.shape
boxes, scores = preds.split([4, nc], dim=-1)
index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
scores, index = scores.flatten(1).topk(min(max_det, anchors))
i = torch.arange(batch_size)[..., None]
return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
class YOLOEDetect(Detect): class YOLOEDetect(Detect):
def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()): def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
super().__init__(nc, ch) super().__init__(nc, ch)
@@ -401,49 +332,32 @@ class YOLOEDetect(Detect):
for x in ch for x in ch
) )
) )
self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch) self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)
self.reprta = Residual(SwiGLUFFN(embed, embed)) self.reprta = Residual(SwiGLUFFN(embed, embed))
self.savpe = SAVPE(ch, c3, embed) # type: ignore self.savpe = SAVPE(ch, c3, embed)
self.embed = embed self.embed = embed
def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]: def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2) return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)
def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor: def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
if vpe.shape[1] == 0: # no visual prompt embeddings if vpe.shape[1] == 0:
return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device) return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
if vpe.ndim == 4: # (B, N, H, W) if vpe.ndim == 4:
vpe = self.savpe(x, vpe) vpe = self.savpe(x, vpe)
assert vpe.ndim == 3 # (B, N, D) assert vpe.ndim == 3
return vpe return vpe
def forward( # type: ignore def forward(self, x: List[torch.Tensor], cls_pe: torch.Tensor) -> List[torch.Tensor]:
self, x: List[torch.Tensor], cls_pe: torch.Tensor outs = []
) -> Union[torch.Tensor, Tuple]:
for i in range(self.nl): for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1) outs.append(torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1))
if self.training: return outs
return x # type: ignore
self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
# Inference path
shape = x[0].shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
return torch.cat((dbox, cls.sigmoid()), 1)
def bias_init(self): def bias_init(self):
m = self m = self
for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):
a[-1].bias.data[:] = 1.0 # box a[-1].bias.data[:] = 1.0
b[-1].bias.data[:] = 0.0 b[-1].bias.data[:] = 0.0
c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
@@ -459,465 +373,214 @@ class YOLOESegment(YOLOEDetect):
c5 = max(ch[0] // 4, self.nm) c5 = max(ch[0] // 4, self.nm)
self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch) self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]: def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]:
p = self.proto(x[0]) # mask protos p = self.proto(x[0])
bs = p.shape[0] # batch size bs = p.shape[0]
mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients x_out = YOLOEDetect.forward(self, x, text)
x = YOLOEDetect.forward(self, x, text) return x_out, mc, p
if self.training:
return x, mc, p
return (torch.cat([x, mc], 1), p)
# ==============================================================================
# [Part 3] PostProcessor & Top Level Models
# ==============================================================================
class YOLOPostProcessor(nn.Module):
def __init__(self, detect_head, use_segmentation=False):
super().__init__()
self.reg_max = detect_head.reg_max
self.stride = detect_head.stride
if hasattr(detect_head, 'dfl'):
self.dfl = detect_head.dfl
else:
self.dfl = nn.Identity()
self.use_segmentation = use_segmentation
self.register_buffer('anchors', torch.empty(0))
self.register_buffer('strides', torch.empty(0))
self.shape = None
def forward(self, outputs):
"""
outputs:
- Detect: List[Tensor]
- Segment: (List[Tensor], Tensor, Tensor)
"""
if self.use_segmentation:
x, mc, p = outputs
else:
x = outputs
current_no = x[0].shape[1]
current_nc = current_no - self.reg_max * 4
shape = x[0].shape
x_cat = torch.cat([xi.view(shape[0], current_no, -1) for xi in x], 2)
if self.anchors.device != x[0].device or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
box, cls = x_cat.split((self.reg_max * 4, current_nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
final_box = torch.cat((dbox, cls.sigmoid()), 1)
if self.use_segmentation:
return final_box, mc, p
return final_box
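Because the backbone-plus-head path above is now free of data-dependent anchor logic, it can be traced or exported on its own while the decode step stays in eager PyTorch. A rough sketch under that assumption (strict=False lets torch.jit.trace accept the list output of the head):

```python
import torch

model = YOLO11(nc=80, scale='s').eval()
post = YOLOPostProcessor(model.model[-1], use_segmentation=False)

dummy = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    # Only the static part of the graph is traced; the raw per-level maps come out as a list.
    traced = torch.jit.trace(model, dummy, strict=False)
    decoded = post(traced(dummy))  # anchors/DFL/sigmoid applied outside the traced graph
```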
class YOLO11(nn.Module): class YOLO11(nn.Module):
def __init__(self, nc=80, scale='n'): def __init__(self, nc=80, scale='n'):
super().__init__() super().__init__()
self.nc = nc self.nc = nc
# Scales: [depth, width, max_channels] # Scales: [depth, width, max_channels]
# Matches the scales section of yolo11.yaml scales = {'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512]}
scales = { if scale not in scales: raise ValueError(f"Invalid scale '{scale}'")
'n': [0.50, 0.25, 1024],
's': [0.50, 0.50, 1024],
'm': [0.50, 1.00, 512],
'l': [1.00, 1.00, 512],
'x': [1.00, 1.50, 512],
}
if scale not in scales:
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
depth, width, max_channels = scales[scale] depth, width, max_channels = scales[scale]
c3k_override = True if scale not in ['n', 's'] else False
if scale in ['n', 's']: def gw(channels): return make_divisible(min(channels, max_channels) * width, 8)
c3k_override = False def gd(n): return max(round(n * depth), 1) if n > 1 else n
else:
c3k_override = True
# Helper: compute channel counts (width scaling)
def gw(channels):
return make_divisible(min(channels, max_channels) * width, 8)
# Helper: compute layer repeat counts (depth scaling)
def gd(n):
return max(round(n * depth), 1) if n > 1 else n
self.model = nn.ModuleList() self.model = nn.ModuleList()
# Backbone
# --- Backbone ---
# 0: Conv [64, 3, 2]
self.model.append(Conv(3, gw(64), 3, 2)) self.model.append(Conv(3, gw(64), 3, 2))
# 1: Conv [128, 3, 2]
self.model.append(Conv(gw(64), gw(128), 3, 2)) self.model.append(Conv(gw(64), gw(128), 3, 2))
# 2: C3k2 [256, False, 0.25] -> n=2
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25)) self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
# 3: Conv [256, 3, 2]
self.model.append(Conv(gw(256), gw(256), 3, 2)) self.model.append(Conv(gw(256), gw(256), 3, 2))
# 4: C3k2 [512, False, 0.25] -> n=2
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25)) self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
# 5: Conv [512, 3, 2]
self.model.append(Conv(gw(512), gw(512), 3, 2)) self.model.append(Conv(gw(512), gw(512), 3, 2))
# 6: C3k2 [512, True] -> n=2
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True)) self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
# 7: Conv [1024, 3, 2]
self.model.append(Conv(gw(512), gw(1024), 3, 2)) self.model.append(Conv(gw(512), gw(1024), 3, 2))
# 8: C3k2 [1024, True] -> n=2
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True)) self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
# 9: SPPF [1024, 5]
self.model.append(SPPF(gw(1024), gw(1024), 5)) self.model.append(SPPF(gw(1024), gw(1024), 5))
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2))) self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
# Neck
# --- Head ---
# 11: Upsample
self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
# 12: Concat [-1, 6] (P4)
self.model.append(Concat(dimension=1)) self.model.append(Concat(dimension=1))
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
# 14: Upsample
self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
# 15: Concat [-1, 4] (P3)
self.model.append(Concat(dimension=1)) self.model.append(Concat(dimension=1))
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
# Note: layer 4 outputs gw(512) and layer 13 outputs gw(512)
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override)) self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
# 17: Conv [256, 3, 2]
self.model.append(Conv(gw(256), gw(256), 3, 2)) self.model.append(Conv(gw(256), gw(256), 3, 2))
# 18: Concat [-1, 13] (Head P4)
self.model.append(Concat(dimension=1)) self.model.append(Concat(dimension=1))
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
# 20: Conv [512, 3, 2]
self.model.append(Conv(gw(512), gw(512), 3, 2)) self.model.append(Conv(gw(512), gw(512), 3, 2))
# 21: Concat [-1, 10] (P5)
self.model.append(Concat(dimension=1)) self.model.append(Concat(dimension=1))
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True)) self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
# 23: Detect [nc] # 23: Standard Detect Head
self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)])) self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))
# --- Initialize weights ---
self.initialize_weights() self.initialize_weights()
def initialize_weights(self): def initialize_weights(self):
"""初始化模型权重,特别是 Detect 头的 Bias"""
for m in self.modules(): for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)): if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
# Use Kaiming or another suitable initialization here if needed
pass
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
if isinstance(self.model[-1], Detect):
detect_layer = self.model[-1] self.model[-1].bias_init()
if isinstance(detect_layer, Detect):
detect_layer.bias_init()
def forward(self, x): def forward(self, x):
# Backbone x = self.model[0](x); x = self.model[1](x); x = self.model[2](x); x = self.model[3](x)
x = self.model[0](x) p3 = self.model[4](x)
x = self.model[1](x) x = self.model[5](p3); p4 = self.model[6](x)
x = self.model[2](x) x = self.model[7](p4); x = self.model[8](x); x = self.model[9](x); p5 = self.model[10](x)
x = self.model[3](x) x = self.model[11](p5); x = self.model[12]([x, p4]); h1 = self.model[13](x)
p3 = self.model[4](x) # save P3 (layer 4) x = self.model[14](h1); x = self.model[15]([x, p3]); h2 = self.model[16](x)
x = self.model[5](p3) x = self.model[17](h2); x = self.model[18]([x, h1]); h3 = self.model[19](x)
p4 = self.model[6](x) # save P4 (layer 6) x = self.model[20](h3); x = self.model[21]([x, p5]); h4 = self.model[22](x)
x = self.model[7](p4) return self.model[23]([h2, h3, h4])
x = self.model[8](x)
x = self.model[9](x)
p5 = self.model[10](x) # save P5 (layer 10)
# Head
x = self.model[11](p5) # Upsample
x = self.model[12]([x, p4]) # Concat P4
h1 = self.model[13](x) # Head P4 (layer 13)
x = self.model[14](h1) # Upsample
x = self.model[15]([x, p3]) # Concat P3
h2 = self.model[16](x) # Output P3 (layer 16)
x = self.model[17](h2) # Conv
x = self.model[18]([x, h1]) # Concat Head P4
h3 = self.model[19](x) # Output P4 (layer 19)
x = self.model[20](h3) # Conv
x = self.model[21]([x, p5]) # Concat P5
h4 = self.model[22](x) # Output P5 (layer 22)
return self.model[23]([h2, h3, h4]) # Detect
def load_weights(self, pth_file): def load_weights(self, pth_file):
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False) state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
# Strip a possible 'model.' prefix (for weights from official ultralytics)
# Official weights are usually in the form model.model.0.conv... or directly model.0.conv...
# Do a simple compatibility pass here
new_state_dict = {} new_state_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
# Handle the 'model' key in the ultralytics weight dict
if k == 'model': if k == 'model':
# If this is a full checkpoint, the weights live under the 'model' key if hasattr(v, 'state_dict'): v = v.state_dict()
# and are usually a model.state_dict() elif isinstance(v, dict): pass
if hasattr(v, 'state_dict'): else:
v = v.state_dict() try: v = v.float().state_dict()
elif isinstance(v, dict): except: continue
pass # v is already the state_dict for sub_k, sub_v in v.items(): new_state_dict[sub_k] = sub_v
else:
# it may be the model object itself
try:
v = v.float().state_dict()
except:
continue
for sub_k, sub_v in v.items():
new_state_dict[sub_k] = sub_v
break break
else: else: new_state_dict[k] = v
new_state_dict[k] = v if not new_state_dict: new_state_dict = state_dict
if not new_state_dict:
new_state_dict = state_dict
# Try to load
try: try:
self.load_state_dict(new_state_dict, strict=True) self.load_state_dict(new_state_dict, strict=True)
print(f"Successfully loaded weights from {pth_file}") print(f"Successfully loaded weights from {pth_file}")
except Exception as e: except Exception as e:
print(f"Error loading weights: {e}") print(f"Error loading weights: {e}")
print("Trying to load with strict=False...")
self.load_state_dict(new_state_dict, strict=False) self.load_state_dict(new_state_dict, strict=False)
class YOLO11E(nn.Module): class YOLO11E(YOLO11):
def __init__(self, nc=80, scale='n'): def __init__(self, nc=80, scale='n'):
super().__init__() super().__init__(nc, scale)
self.nc = nc scales = {'n': [0.50, 0.25, 1024], 's': [0.50, 0.50, 1024], 'm': [0.50, 1.00, 512], 'l': [1.00, 1.00, 512], 'x': [1.00, 1.50, 512]}
self.pe = None
# Scales: [depth, width, max_channels]
# Matches the scales section of yolo11.yaml
scales = {
'n': [0.50, 0.25, 1024],
's': [0.50, 0.50, 1024],
'm': [0.50, 1.00, 512],
'l': [1.00, 1.00, 512],
'x': [1.00, 1.50, 512],
}
if scale not in scales:
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
depth, width, max_channels = scales[scale] depth, width, max_channels = scales[scale]
def gw(channels): return make_divisible(min(channels, max_channels) * width, 8)
if scale in ['n', 's']:
c3k_override = False
else:
c3k_override = True
# Helper: compute channel counts (width scaling) self.nc = nc
def gw(channels): self.pe = None
return make_divisible(min(channels, max_channels) * width, 8) self.model[-1] = YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)])
# Helper: compute layer repeat counts (depth scaling)
def gd(n):
return max(round(n * depth), 1) if n > 1 else n
self.model = nn.ModuleList()
# --- Backbone ---
# 0: Conv [64, 3, 2]
self.model.append(Conv(3, gw(64), 3, 2))
# 1: Conv [128, 3, 2]
self.model.append(Conv(gw(64), gw(128), 3, 2))
# 2: C3k2 [256, False, 0.25] -> n=2
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
# 3: Conv [256, 3, 2]
self.model.append(Conv(gw(256), gw(256), 3, 2))
# 4: C3k2 [512, False, 0.25] -> n=2
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
# 5: Conv [512, 3, 2]
self.model.append(Conv(gw(512), gw(512), 3, 2))
# 6: C3k2 [512, True] -> n=2
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
# 7: Conv [1024, 3, 2]
self.model.append(Conv(gw(512), gw(1024), 3, 2))
# 8: C3k2 [1024, True] -> n=2
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
# 9: SPPF [1024, 5]
self.model.append(SPPF(gw(1024), gw(1024), 5))
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
# --- Head ---
# 11: Upsample
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
# 12: Concat [-1, 6] (P4)
self.model.append(Concat(dimension=1))
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
# 14: Upsample
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
# 15: Concat [-1, 4] (P3)
self.model.append(Concat(dimension=1))
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
# Note: layer 4 outputs gw(512) and layer 13 outputs gw(512)
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
# 17: Conv [256, 3, 2]
self.model.append(Conv(gw(256), gw(256), 3, 2))
# 18: Concat [-1, 13] (Head P4)
self.model.append(Concat(dimension=1))
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
# 20: Conv [512, 3, 2]
self.model.append(Conv(gw(512), gw(512), 3, 2))
# 21: Concat [-1, 10] (P5)
self.model.append(Concat(dimension=1))
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
# 23: Detect [nc]
self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)]))
# --- Initialize weights ---
self.initialize_weights() self.initialize_weights()
def initialize_weights(self):
"""初始化模型权重,特别是 Detect 头的 Bias"""
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
# Use Kaiming or another suitable initialization here if needed
pass
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
detect_layer = self.model[-1]
if isinstance(detect_layer, Detect):
detect_layer.bias_init()
def set_classes(self, names: List[str], embeddings: torch.Tensor): def set_classes(self, names: List[str], embeddings: torch.Tensor):
assert embeddings.ndim == 3, "Embeddings must be (1, N, D)" assert embeddings.ndim == 3
self.pe = embeddings self.pe = embeddings
self.model[-1].nc = len(names) # type: ignore self.model[-1].nc = len(names)
self.nc = len(names) self.nc = len(names)
def forward(self, x, tpe=None, vpe=None): def forward(self, x, tpe=None, vpe=None):
# Backbone x = self.model[0](x); x = self.model[1](x); x = self.model[2](x); x = self.model[3](x)
x = self.model[0](x) p3 = self.model[4](x)
x = self.model[1](x) x = self.model[5](p3); p4 = self.model[6](x)
x = self.model[2](x) x = self.model[7](p4); x = self.model[8](x); x = self.model[9](x); p5 = self.model[10](x)
x = self.model[3](x) x = self.model[11](p5); x = self.model[12]([x, p4]); h1 = self.model[13](x)
p3 = self.model[4](x) # save P3 (layer 4) x = self.model[14](h1); x = self.model[15]([x, p3]); h2 = self.model[16](x)
x = self.model[5](p3) x = self.model[17](h2); x = self.model[18]([x, h1]); h3 = self.model[19](x)
p4 = self.model[6](x) # save P4 (layer 6) x = self.model[20](h3); x = self.model[21]([x, p5]); h4 = self.model[22](x)
x = self.model[7](p4)
x = self.model[8](x)
x = self.model[9](x)
p5 = self.model[10](x) # save P5 (layer 10)
# Head
x = self.model[11](p5) # Upsample
x = self.model[12]([x, p4]) # Concat P4
h1 = self.model[13](x) # Head P4 (layer 13)
x = self.model[14](h1) # Upsample
x = self.model[15]([x, p3]) # Concat P3
h2 = self.model[16](x) # Output P3 (layer 16)
x = self.model[17](h2) # Conv
x = self.model[18]([x, h1]) # Concat Head P4
h3 = self.model[19](x) # Output P4 (layer 19)
x = self.model[20](h3) # Conv
x = self.model[21]([x, p5]) # Concat P5
h4 = self.model[22](x) # Output P5 (layer 22)
head = self.model[23] head = self.model[23]
feats = [h2, h3, h4] feats = [h2, h3, h4]
processed_tpe = head.get_tpe(tpe) # type: ignore processed_tpe = head.get_tpe(tpe)
processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None
processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None # type: ignore
all_pe = [] all_pe = []
if processed_tpe is not None: if processed_tpe is not None: all_pe.append(processed_tpe)
all_pe.append(processed_tpe) if processed_vpe is not None: all_pe.append(processed_vpe)
if processed_vpe is not None:
all_pe.append(processed_vpe)
if not all_pe: if not all_pe:
if self.pe is not None: if self.pe is not None: all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
all_pe.append(self.pe.to(device=x.device, dtype=x.dtype)) else: all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))
else:
all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))
cls_pe = torch.cat(all_pe, dim=1) cls_pe = torch.cat(all_pe, dim=1)
b = x.shape[0] b = x.shape[0]
if cls_pe.shape[0] != b: if cls_pe.shape[0] != b: cls_pe = cls_pe.expand(b, -1, -1)
cls_pe = cls_pe.expand(b, -1, -1)
return head(feats, cls_pe) return head(feats, cls_pe)
def load_weights(self, pth_file):
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
# Strip a possible 'model.' prefix (for weights from official ultralytics)
# Official weights are usually in the form model.model.0.conv... or directly model.0.conv...
# Do a simple compatibility pass here
new_state_dict = {}
for k, v in state_dict.items():
# Handle the 'model' key in the ultralytics weight dict
if k == 'model':
# If this is a full checkpoint, the weights live under the 'model' key
# and are usually a model.state_dict()
if hasattr(v, 'state_dict'):
v = v.state_dict()
elif isinstance(v, dict):
pass # v is already the state_dict
else:
# it may be the model object itself
try:
v = v.float().state_dict()
except:
continue
for sub_k, sub_v in v.items():
new_state_dict[sub_k] = sub_v
break
else:
new_state_dict[k] = v
if not new_state_dict:
new_state_dict = state_dict
# Try to load
try:
self.load_state_dict(new_state_dict, strict=True)
print(f"Successfully loaded weights from {pth_file}")
except Exception as e:
print(f"Error loading weights: {e}")
print("Trying to load with strict=False...")
self.load_state_dict(new_state_dict, strict=False)
if __name__ == "__main__": if __name__ == "__main__":
model = YOLO11E(nc=80, scale='l') print("Testing Standard YOLO11...")
model.load_weights("yoloe-11l-seg.pth") model_std = YOLO11(nc=80, scale='n')
model_std.eval()
# Simulate set_classes post_std = YOLOPostProcessor(model_std.model[-1], use_segmentation=False)
# Assume we have 2 classes with an embedding dimension of 512
fake_embeddings = torch.randn(1, 2, 512)
model.set_classes(["class1", "class2"], fake_embeddings)
# Inference input_std = torch.randn(1, 3, 640, 640)
dummy_input = torch.randn(1, 3, 640, 640) out_std_raw = model_std(input_std) # Raw list
model.eval() out_std_dec = post_std(out_std_raw) # Decoded
output = model(dummy_input) print(f"Standard Output: {out_std_dec.shape}") # (1, 84, 8400)
print("Output shape:", output[0].shape) # 应该是 (1, 4+mask_coeffs+num_classes, anchors)
print("\nTesting YOLO11E (Segment)...")
model_seg = YOLO11E(nc=80, scale='n')
model_seg.eval()
post_seg = YOLOPostProcessor(model_seg.model[-1], use_segmentation=True)
model_seg.set_classes(["a", "b"], torch.randn(1, 2, 512))
input_seg = torch.randn(1, 3, 640, 640)
out_seg_raw = model_seg(input_seg) # (feats, mc, p)
out_seg_dec, mc, p = post_seg(out_seg_raw) # Decoded
print(f"Segment Output: {out_seg_dec.shape}") # (1, 4+2, 8400)
print(f"Mask Coeffs: {mc.shape}, Protos: {p.shape}")