移除detect头中动态运算图,动态部分分离出来,并更新相关的推理代码示例

This commit is contained in:
lhr
2025-12-30 17:10:01 +08:00
parent 9df330875d
commit 553a63f521
3 changed files with 203 additions and 621 deletions

View File

@@ -4,25 +4,18 @@ import numpy as np
import torchvision
from pathlib import Path
# 导入你的模块
from yolo11_standalone import YOLO11E
from yolo11_standalone import YOLO11E, YOLOPostProcessor
from mobile_clip_standalone import MobileCLIP
# --- 配置 ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
YOLO_WEIGHTS = "yoloe-11l-seg.pth" # 替换为你的 YOLO 权重路径
CLIP_WEIGHTS = "mobileclip_blt.ts" # 替换为你的 MobileCLIP 权重路径
CLIP_SIZE = "blt" # 对应 MobileCLIP 的 size
IMAGE_PATH = "1.jpg" # 待检测图片
YOLO_WEIGHTS = "yoloe-11l-seg.pth"
CLIP_WEIGHTS = "mobileclip_blt.ts"
CLIP_SIZE = "blt"
IMAGE_PATH = "1.jpg"
# 自定义检测类别 (Open Vocabulary)
CUSTOM_CLASSES = ["girl", "red balloon"]
# 绘图颜色
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))
# --- 辅助函数 (Letterbox, NMS 等) ---
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
shape = im.shape[:2]
if isinstance(new_shape, int): new_shape = (new_shape, new_shape)
@@ -68,99 +61,66 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300)
def main():
print(f"Using device: {DEVICE}")
# 1. 加载 MobileCLIP 文本编码器
print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
if not Path(CLIP_WEIGHTS).exists():
raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")
if not Path(CLIP_WEIGHTS).exists(): raise FileNotFoundError(CLIP_WEIGHTS)
clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)
# 2. 生成文本 Embeddings
print(f"Encoding classes: {CUSTOM_CLASSES}")
prompts = [f"{c}" for c in CUSTOM_CLASSES]
tokens = clip_model.tokenize(prompts)
text_embeddings = clip_model.encode_text(tokens) # Shape: (N, 512)
# 调整维度为 (1, N, 512) 以匹配 YOLO11E 输入
text_embeddings = text_embeddings.unsqueeze(0)
tokens = clip_model.tokenize([f"{c}" for c in CUSTOM_CLASSES])
text_embeddings = clip_model.encode_text(tokens).unsqueeze(0)
# 3. 加载 YOLO11E 检测模型
print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
if not Path(YOLO_WEIGHTS).exists():
raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")
# 注意scale='l' 必须与你的权重文件匹配 (s, m, l, x)
yolo_model = YOLO11E(nc=80, scale='l')
if not Path(YOLO_WEIGHTS).exists(): raise FileNotFoundError(YOLO_WEIGHTS)
yolo_model = YOLO11E(nc=80, scale='l')
yolo_model.load_weights(YOLO_WEIGHTS)
yolo_model.to(DEVICE) # 使用半精度to(DEVICE)
yolo_model.eval()
yolo_model.to(DEVICE).eval()
head = yolo_model.model[-1]
post_processor = YOLOPostProcessor(head, use_segmentation=True)
post_processor.to(DEVICE).eval()
with torch.no_grad():
text_pe = head.get_tpe(text_embeddings) # type: ignore
text_pe = head.get_tpe(text_embeddings)
yolo_model.set_classes(CUSTOM_CLASSES, text_pe)
# 5. 图像预处理
img0 = cv2.imread(IMAGE_PATH)
assert img0 is not None, f"Image Not Found {IMAGE_PATH}"
img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
img = img[:, :, ::-1].transpose(2, 0, 1)
img = np.ascontiguousarray(img)
img_tensor = torch.from_numpy(img).to(DEVICE)
img_tensor = img_tensor.float()
img_tensor /= 255.0
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
img_tensor = torch.from_numpy(img).to(DEVICE).float() / 255.0
if img_tensor.ndim == 3: img_tensor = img_tensor.unsqueeze(0)
# 6. 推理
print("Running inference...")
with torch.no_grad():
pred = yolo_model(img_tensor)
if isinstance(pred, tuple):
pred = pred[0]
nc = len(CUSTOM_CLASSES)
pred = pred[:, :4+nc, :]
raw_outputs = yolo_model(img_tensor)
decoded_box, mc, p = post_processor(raw_outputs)
# 7. 后处理 (NMS)
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)
pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
print(pred)
# 8. 可视化
det = pred[0]
if len(det):
det[:, [0, 2]] -= dw
det[:, [1, 3]] -= dh
det[:, :4] /= ratio
det[:, 0].clamp_(0, img0.shape[1])
det[:, 1].clamp_(0, img0.shape[0])
det[:, 2].clamp_(0, img0.shape[1])
det[:, 3].clamp_(0, img0.shape[0])
det[:, [0, 2]].clamp_(0, img0.shape[1])
det[:, [1, 3]].clamp_(0, img0.shape[0])
print(f"Detected {len(det)} objects:")
for *xyxy, conf, cls in det:
c = int(cls)
class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c)
label = f'{class_name} {conf:.2f}'
label = f'{CUSTOM_CLASSES[c]} {conf:.2f}'
p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
color = COLORS[c]
color = COLORS[c % len(COLORS)]
cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
t_size = cv2.getTextSize(label, 0, 0.5, 1)[0]
p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA)
cv2.putText(img0, label, (p1[0], p1[1] - 5), 0, 0.5, color, 1, cv2.LINE_AA)
print(f" - {label}")
else:
print("No objects detected.")
output_path = "result_full.jpg"
cv2.imwrite(output_path, img0)
print(f"Result saved to {output_path}")
cv2.imwrite("result_separate.jpg", img0)
print("Result saved.")
if __name__ == "__main__":
main()