Files
Yolo-standalone/inference_yoloe.py
2025-12-27 02:14:11 +08:00

168 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import cv2
import numpy as np
import torchvision
from pathlib import Path
# Project-local modules
from yolo11_standalone import YOLO11E
from mobile_clip_standalone import MobileCLIP
from ultralytics import YOLOE
# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
YOLO_WEIGHTS = "yoloe-11l-seg.pth" # replace with the path to your YOLO weights
CLIP_WEIGHTS = "mobileclip_blt.ts" # replace with the path to your MobileCLIP weights
CLIP_SIZE = "blt" # MobileCLIP size variant; passed as `size=` to MobileCLIP
IMAGE_PATH = "1.jpg" # image to run detection on
# Custom detection classes (open vocabulary)
CUSTOM_CLASSES = ["girl", "red balloon"]
# Drawing colors: one random 3-component value in [0, 255) per class
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))
# --- Helper functions (letterbox, NMS, etc.) ---
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    """Resize an image to fit ``new_shape`` without distortion, then pad it.

    Args:
        im: Image array; ``im.shape[:2]`` is read as (height, width).
        new_shape: Target (height, width), or a single int for a square target.
        color: Border fill value handed to ``cv2.copyMakeBorder``.

    Returns:
        ``(image, ratio, (dw, dh))`` where ``ratio`` is the scale factor that
        was applied and ``dw`` / ``dh`` are the per-side horizontal / vertical
        padding amounts (they may be fractional).
    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Use the smaller of the two scale factors so the whole image fits.
    r = min(dst_h / src_h, dst_w / src_w)
    scaled_w = int(round(src_w * r))
    scaled_h = int(round(src_h * r))

    # Split the leftover space evenly between the two sides.
    dw = (dst_w - scaled_w) / 2
    dh = (dst_h - scaled_h) / 2

    if (src_w, src_h) != (scaled_w, scaled_h):
        im = cv2.resize(im, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)

    # The -0.1 / +0.1 nudges send an odd leftover pixel to the bottom/right.
    pad_top = int(round(dh - 0.1))
    pad_bottom = int(round(dh + 0.1))
    pad_left = int(round(dw - 0.1))
    pad_right = int(round(dw + 0.1))
    im = cv2.copyMakeBorder(
        im, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color
    )
    return im, r, (dw, dh)
def xywh2xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: A ``torch.Tensor`` (cloned) or any other array (copied with
            ``np.copy``) whose last axis starts with ``[cx, cy, w, h]``.

    Returns:
        A copy of ``x`` in which the first four entries of the last axis are
        ``[x1, y1, x2, y2]``. The input itself is not modified.
    """
    if isinstance(x, torch.Tensor):
        out = x.clone()
    else:
        out = np.copy(x)

    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w  # left
    out[..., 1] = x[..., 1] - half_h  # top
    out[..., 2] = x[..., 0] + half_w  # right
    out[..., 3] = x[..., 1] + half_h  # bottom
    return out
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300,
                        max_nms=30000, max_wh=7680):
    """Filter raw detector output with class-aware non-maximum suppression.

    Args:
        prediction: Tensor indexed as (batch, 4 + num_classes, num_anchors);
            along axis 1 the first four rows are read as (cx, cy, w, h) and
            the remaining rows as per-class scores.
        conf_thres: Minimum best-class score for an anchor to be considered.
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        max_det: Maximum number of detections kept per image *after* NMS.
        max_nms: Maximum number of candidates fed into NMS per image (the
            highest-scoring ones are kept when there are more).
        max_wh: Per-class coordinate offset used to keep boxes of different
            classes from suppressing each other; must exceed the largest
            box coordinate.

    Returns:
        A list with one ``(n, 6)`` tensor per image whose columns are
        ``[x1, y1, x2, y2, confidence, class_index]``. Images without
        detections get their own empty ``(0, 6)`` tensor.
    """
    prediction = prediction.transpose(1, 2)  # -> (batch, anchors, 4 + classes)
    bs = prediction.shape[0]
    candidates = prediction[..., 4:].max(-1)[0] > conf_thres

    # Build one distinct empty tensor per image. ``[t] * bs`` would make every
    # slot alias the same object, so an in-place edit on one result would
    # silently show up in all the others.
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(bs)]

    for xi, x in enumerate(prediction):
        x = x[candidates[xi]]
        if not x.shape[0]:
            continue

        box = xywh2xyxy(x[:, :4])
        # Every surviving row already has best-class score > conf_thres, so no
        # second confidence filter is needed here.
        conf, cls_idx = x[:, 4:].max(1, keepdim=True)
        x = torch.cat((box, conf, cls_idx.float()), 1)

        # Cap the NMS workload only. Truncating to ``max_det`` at this point
        # would let near-duplicate high-score boxes crowd out other objects
        # before suppression has had a chance to remove the duplicates.
        if x.shape[0] > max_nms:
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Shift each class into its own coordinate range so a single NMS call
        # only suppresses boxes of the same class.
        class_offset = x[:, 5:6] * max_wh
        keep = torchvision.ops.nms(x[:, :4] + class_offset, x[:, 4], iou_thres)

        # ``nms`` returns indices sorted by decreasing score, so slicing keeps
        # the ``max_det`` most confident detections.
        output[xi] = x[keep[:max_det]]
    return output
def _restore_boxes(det, ratio, pad, image_shape):
    """Map xyxy boxes from letterboxed coordinates back to the source image, in place."""
    dw, dh = pad
    img_h, img_w = image_shape[0], image_shape[1]
    # Undo the letterbox padding, then the resize.
    det[:, [0, 2]] -= dw
    det[:, [1, 3]] -= dh
    det[:, :4] /= ratio
    # Keep every corner inside the source image.
    det[:, 0].clamp_(0, img_w)
    det[:, 1].clamp_(0, img_h)
    det[:, 2].clamp_(0, img_w)
    det[:, 3].clamp_(0, img_h)


def _draw_labeled_box(image, p1, p2, label, color):
    """Draw one bounding box with a filled caption strip onto ``image`` in place."""
    cv2.rectangle(image, p1, p2, color, 2, cv2.LINE_AA)
    text_w, text_h = cv2.getTextSize(label, 0, 0.5, 1)[0]
    if p1[1] - text_h - 3 >= 0:
        # Enough room above the box: put the caption on top of it.
        strip_corner = (p1[0] + text_w, p1[1] - text_h - 3)
        text_origin = (p1[0], p1[1] - 2)
    else:
        # Box touches the top edge: move the caption inside the box so it is
        # not drawn outside the image.
        strip_corner = (p1[0] + text_w, p1[1] + text_h + 3)
        text_origin = (p1[0], p1[1] + text_h + 2)
    cv2.rectangle(image, p1, strip_corner, color, -1, cv2.LINE_AA)
    cv2.putText(image, label, text_origin, 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA)


def main():
    """Run open-vocabulary detection on ``IMAGE_PATH`` and save an annotated copy.

    Encodes ``CUSTOM_CLASSES`` with MobileCLIP, installs the resulting text
    embeddings into the YOLO11E model, runs one forward pass on the
    letterboxed image, applies NMS, and writes the annotated image to
    ``result_full.jpg``.

    Raises:
        FileNotFoundError: If a weight file is missing or the image cannot
            be read.
        OSError: If the annotated image cannot be written.
    """
    print(f"Using device: {DEVICE}")

    # 1. Load the MobileCLIP text encoder.
    print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
    if not Path(CLIP_WEIGHTS).exists():
        raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")
    clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)

    # 2. Build the text embeddings for the custom classes.
    print(f"Encoding classes: {CUSTOM_CLASSES}")
    prompts = [f"{c}" for c in CUSTOM_CLASSES]
    tokens = clip_model.tokenize(prompts)
    text_embeddings = clip_model.encode_text(tokens)
    # Add a leading batch axis; the YOLO11E head is fed (1, N, dim) here.
    text_embeddings = text_embeddings.unsqueeze(0)

    # 3. Load the YOLO11E detector.
    print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
    if not Path(YOLO_WEIGHTS).exists():
        raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")
    # NOTE: scale='l' must match the weight file (s, m, l, x).
    yolo_model = YOLO11E(nc=80, scale='l')
    yolo_model.load_weights(YOLO_WEIGHTS)
    yolo_model.to(DEVICE)  # full precision; no .half() conversion is applied
    yolo_model.eval()

    # 4. Project the text embeddings through the head and install the classes.
    head = yolo_model.model[-1]
    with torch.no_grad():
        text_pe = head.get_tpe(text_embeddings)  # type: ignore
        yolo_model.set_classes(CUSTOM_CLASSES, text_pe)

    # 5. Image preprocessing.
    img0 = cv2.imread(IMAGE_PATH)
    if img0 is None:
        # Raised explicitly (not via assert) so the check survives `python -O`.
        raise FileNotFoundError(f"Image Not Found {IMAGE_PATH}")
    img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)
    img_tensor = torch.from_numpy(img).to(DEVICE)
    img_tensor = img_tensor.float()
    img_tensor /= 255.0
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)

    # 6. Inference.
    print("Running inference...")
    with torch.no_grad():
        pred = yolo_model(img_tensor)
    if isinstance(pred, tuple):
        pred = pred[0]
    # Keep only the box rows and the score rows of the custom classes.
    nc = len(CUSTOM_CLASSES)
    pred = pred[:, :4+nc, :]

    # 7. Post-processing (NMS).
    pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)
    print(pred)

    # 8. Visualization.
    det = pred[0]
    if len(det):
        _restore_boxes(det, ratio, (dw, dh), img0.shape)
        print(f"Detected {len(det)} objects:")
        for *xyxy, conf, cls in det:
            c = int(cls)
            class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c)
            label = f'{class_name} {conf:.2f}'
            p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
            # Wrap the index so an unexpected class id cannot raise IndexError,
            # mirroring the fallback used for the class name above.
            color = COLORS[c % len(COLORS)].tolist()
            _draw_labeled_box(img0, p1, p2, label, color)
            print(f" - {label}")
    else:
        print("No objects detected.")

    output_path = "result_full.jpg"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_path, img0):
        raise OSError(f"Failed to write {output_path}")
    print(f"Result saved to {output_path}")


if __name__ == "__main__":
    main()