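"""Open-vocabulary detection with a standalone YOLO11E detector and MobileCLIP.

The class list is free-form text: MobileCLIP encodes the prompts, the
embeddings are projected into the detection head's prompt space, and the
detector then localizes those classes in a single image.
"""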
import cv2
import numpy as np
import torch
import torchvision
from pathlib import Path

from yolo11_standalone import YOLO11E, YOLOPostProcessor
from mobile_clip_standalone import MobileCLIP

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint paths and the input image to run on.
YOLO_WEIGHTS = "yoloe-11l-seg.pth"
CLIP_WEIGHTS = "mobileclip_blt.ts"
CLIP_SIZE = "blt"
IMAGE_PATH = "1.jpg"

# Free-form text prompts to detect, and one random BGR color per class.
CUSTOM_CLASSES = ["girl", "red balloon"]
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
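    """Resize and pad `im` to `new_shape` while preserving aspect ratio.

    Returns the padded image, the scale ratio `r`, and the per-side padding
    `(dw, dh)` needed to map boxes back to the original image.
    """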
    shape = im.shape[:2]  # current shape: (height, width)
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])  # scale ratio
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
    dw, dh = dw / 2, dh / 2  # split padding evenly between both sides
    if shape[::-1] != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, r, (dw, dh)


def xywh2xyxy(x):
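    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2)."""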
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # left
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # right
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom
    return y


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300):
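    """Run class-aware NMS on raw head output of shape (bs, 4 + nc, num_anchors).

    Returns a list with one (n, 6) tensor per image:
    (x1, y1, x2, y2, confidence, class index).
    """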
    prediction = prediction.transpose(1, 2)  # -> (bs, num_anchors, 4 + nc)
    bs = prediction.shape[0]
    xc = prediction[..., 4:].max(-1)[0] > conf_thres  # candidates above threshold
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):
        x = x[xc[xi]]
        if not x.shape[0]:
            continue
        box = xywh2xyxy(x[:, :4])
        conf, j = x[:, 4:].max(1, keepdim=True)  # best class score and index
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
        n = x.shape[0]
        if not n:
            continue
        if n > max_det:  # keep only the max_det highest-confidence boxes
            x = x[x[:, 4].argsort(descending=True)[:max_det]]
        # Offset boxes by class index so NMS never suppresses across classes.
        c = x[:, 5:6] * 7680
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        output[xi] = x[i]
    return output


def main():
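    """Encode the prompts, run detection on one image, and save the result."""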
    print(f"Using device: {DEVICE}")

    # 1. Load MobileCLIP and encode the class prompts into text embeddings.
    print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
    if not Path(CLIP_WEIGHTS).exists():
        raise FileNotFoundError(CLIP_WEIGHTS)
    clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)

    print(f"Encoding classes: {CUSTOM_CLASSES}")
    tokens = clip_model.tokenize(CUSTOM_CLASSES)
    text_embeddings = clip_model.encode_text(tokens).unsqueeze(0)  # add batch dim

    # 2. Load the YOLO11E detector and its post-processing head.
    print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
    if not Path(YOLO_WEIGHTS).exists():
        raise FileNotFoundError(YOLO_WEIGHTS)

    yolo_model = YOLO11E(nc=80, scale="l")
    yolo_model.load_weights(YOLO_WEIGHTS)
    yolo_model.to(DEVICE).eval()

    head = yolo_model.model[-1]  # detection head
    post_processor = YOLOPostProcessor(head, use_segmentation=True)
    post_processor.to(DEVICE).eval()

    # 3. Project the text embeddings into the head's prompt space and
    #    register them as the detector's classes.
    with torch.no_grad():
        text_pe = head.get_tpe(text_embeddings)
        yolo_model.set_classes(CUSTOM_CLASSES, text_pe)

    # 4. Read and letterbox the image, then convert to a normalized CHW tensor.
    img0 = cv2.imread(IMAGE_PATH)
    assert img0 is not None, f"Image not found: {IMAGE_PATH}"
    img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
    img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))  # BGR->RGB, HWC->CHW
    img_tensor = torch.from_numpy(img).to(DEVICE).float() / 255.0
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)  # add batch dim

    # 5. Forward pass and NMS. mc and p are presumably the segmentation
    #    outputs (mask coefficients and prototypes); they are unused in this
    #    box-only pipeline.
    print("Running inference...")
    with torch.no_grad():
        raw_outputs = yolo_model(img_tensor)
        decoded_box, mc, p = post_processor(raw_outputs)

    pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)

    # 6. Map boxes from the letterboxed frame back to the original image.
    det = pred[0]
    if len(det):
        det[:, [0, 2]] -= dw
        det[:, [1, 3]] -= dh
        det[:, :4] /= ratio
        # Clamp column by column: det[:, [0, 2]].clamp_() would operate on a
        # copy, since list indexing does not return a view.
        det[:, 0].clamp_(0, img0.shape[1])
        det[:, 2].clamp_(0, img0.shape[1])
        det[:, 1].clamp_(0, img0.shape[0])
        det[:, 3].clamp_(0, img0.shape[0])

        # 7. Draw labeled boxes on the original image.
        print(f"Detected {len(det)} objects:")
        for *xyxy, conf, cls in det:
            c = int(cls)
            label = f"{CUSTOM_CLASSES[c]} {conf:.2f}"
            p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
            color = COLORS[c % len(COLORS)]

            cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
            cv2.putText(img0, label, (p1[0], p1[1] - 5), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, color, 1, cv2.LINE_AA)
            print(f"  - {label}")
    else:
        print("No objects detected.")

    cv2.imwrite("result_separate.jpg", img0)
    print("Result saved to result_separate.jpg")


if __name__ == "__main__":
    main()