import torch
import cv2
import numpy as np
import torchvision
from pathlib import Path

# Import the standalone modules
from yolo11_standalone import YOLO11E
from mobile_clip_standalone import MobileCLIP

# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
YOLO_WEIGHTS = "yoloe-11l-seg.pth"   # replace with your YOLO weights path
CLIP_WEIGHTS = "mobileclip_blt.ts"   # replace with your MobileCLIP weights path
CLIP_SIZE = "blt"                    # MobileCLIP size matching the weights
IMAGE_PATH = "1.jpg"                 # image to run detection on

# Custom detection classes (open vocabulary)
CUSTOM_CLASSES = ["girl", "red balloon"]

# One random drawing color per class
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))


# --- Helpers (letterbox, NMS, etc.) ---
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    """Resize to new_shape with unchanged aspect ratio, padding the borders."""
    shape = im.shape[:2]  # current (height, width)
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
    dw, dh = dw / 2, dh / 2  # split padding between both sides
    if shape[::-1] != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, r, (dw, dh)


def xywh2xyxy(x):
    """Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2)."""
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300):
    """Class-aware NMS on raw head output of shape (bs, 4 + nc, num_anchors)."""
    prediction = prediction.transpose(1, 2)  # -> (bs, num_anchors, 4 + nc)
    bs = prediction.shape[0]
    xc = prediction[..., 4:].max(-1)[0] > conf_thres  # candidates above threshold
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):
        x = x[xc[xi]]
        if not x.shape[0]:
            continue
        box = xywh2xyxy(x[:, :4])
        conf, j = x[:, 4:].max(1, keepdim=True)  # best class score and index
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
        n = x.shape[0]
        if not n:
            continue
        elif n > max_det:
            x = x[x[:, 4].argsort(descending=True)[:max_det]]
        # Offset boxes by class index so one NMS pass never suppresses across
        # classes (see the equivalent batched_nms sketch below)
        c = x[:, 5:6] * 7680
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        output[xi] = x[i]
    return output
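
# Design note: the class-index offset above (cls * 7680, larger than any box
# coordinate at 640x640 input) shifts each class's boxes into a disjoint region,
# which is exactly what torchvision's built-in per-class NMS does internally.
# A minimal equivalent sketch (nms_per_class is a hypothetical helper, not
# called by main below):
def nms_per_class(x, iou_thres=0.7):
    """Per-class NMS via torchvision.ops.batched_nms.

    x: tensor of shape (n, 6) with columns (x1, y1, x2, y2, conf, cls).
    Returns the rows of x kept after class-aware suppression.
    """
    keep = torchvision.ops.batched_nms(x[:, :4], x[:, 4], x[:, 5].long(), iou_thres)
    return x[keep]
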
def main():
    print(f"Using device: {DEVICE}")

    # 1. Load the MobileCLIP text encoder
    print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
    if not Path(CLIP_WEIGHTS).exists():
        raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")
    clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)

    # 2. Generate text embeddings for the custom classes
    print(f"Encoding classes: {CUSTOM_CLASSES}")
    tokens = clip_model.tokenize(list(CUSTOM_CLASSES))
    text_embeddings = clip_model.encode_text(tokens)  # shape: (N, 512)
    # Add a batch dimension -> (1, N, 512) to match the YOLO11E input
    text_embeddings = text_embeddings.unsqueeze(0)

    # 3. Load the YOLO11E detection model
    print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
    if not Path(YOLO_WEIGHTS).exists():
        raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")
    # Note: scale='l' must match your weights file (s, m, l, x)
    yolo_model = YOLO11E(nc=80, scale='l')
    yolo_model.load_weights(YOLO_WEIGHTS)
    yolo_model.to(DEVICE)
    yolo_model.eval()

    # 4. Bind the text embeddings to the detection head
    head = yolo_model.model[-1]
    with torch.no_grad():
        text_pe = head.get_tpe(text_embeddings)  # type: ignore
    yolo_model.set_classes(CUSTOM_CLASSES, text_pe)

    # 5. Image preprocessing
    img0 = cv2.imread(IMAGE_PATH)
    assert img0 is not None, f"Image not found: {IMAGE_PATH}"
    img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)
    img_tensor = torch.from_numpy(img).to(DEVICE)
    img_tensor = img_tensor.float()
    img_tensor /= 255.0
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)  # add batch dimension

    # 6. Inference
    print("Running inference...")
    with torch.no_grad():
        pred = yolo_model(img_tensor)
    if isinstance(pred, tuple):
        pred = pred[0]
    nc = len(CUSTOM_CLASSES)
    pred = pred[:, :4 + nc, :]  # keep box channels plus the custom class scores

    # 7. Post-processing (NMS)
    pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)

    # 8. Visualization
    det = pred[0]
    if len(det):
        # Undo the letterbox: remove padding, rescale, clip to the original image
        det[:, [0, 2]] -= dw
        det[:, [1, 3]] -= dh
        det[:, :4] /= ratio
        det[:, 0].clamp_(0, img0.shape[1])
        det[:, 1].clamp_(0, img0.shape[0])
        det[:, 2].clamp_(0, img0.shape[1])
        det[:, 3].clamp_(0, img0.shape[0])
        print(f"Detected {len(det)} objects:")
        for *xyxy, conf, cls in det:
            c = int(cls)
            class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c)
            label = f"{class_name} {conf:.2f}"
            p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
            color = COLORS[c].tolist()  # cv2 expects a plain scalar sequence
            cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
            t_size = cv2.getTextSize(label, 0, 0.5, 1)[0]
            p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
            cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
            cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
            print(f"  - {label}")
    else:
        print("No objects detected.")

    output_path = "result_full.jpg"
    cv2.imwrite(output_path, img0)
    print(f"Result saved to {output_path}")


if __name__ == "__main__":
    main()