189 lines
6.5 KiB
Python
189 lines
6.5 KiB
Python
import torch
|
|
import cv2
|
|
import numpy as np
|
|
import torchvision
|
|
from yolo11_standalone import YOLO11
|
|
|
|
# The 80 COCO class names, indexed by class id.
CLASSES = [
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
|
|
|
|
# One random drawing color per class: array of shape (len(CLASSES), 3), values in [0, 255).
# A dedicated, seeded generator keeps the colors identical from run to run (so results
# are comparable) and avoids consuming/mutating the global NumPy RNG at import time.
COLORS = np.random.default_rng(0).uniform(0, 255, size=(len(CLASSES), 3))
|
|
|
|
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    """
    Resize an image to fit inside ``new_shape`` while keeping its aspect ratio,
    then pad it with a constant border up to exactly ``new_shape``.

    Args:
        im: image array whose first two dimensions are (height, width).
        new_shape: target (height, width), or a single int for a square target.
        color: border fill value passed to ``cv2.copyMakeBorder``.

    Returns:
        (padded_image, r, (dw, dh)): ``r`` is the scale factor applied to the
        original image; ``dw`` / ``dh`` are the (possibly fractional) padding
        added on each side horizontally / vertically.
    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]

    # Largest scale that still fits both dimensions.
    scale = min(target_h / src_h, target_w / src_w)

    resized_w = int(round(src_w * scale))
    resized_h = int(round(src_h * scale))

    # Leftover space, split evenly between the two opposite sides.
    pad_w = (target_w - resized_w) / 2
    pad_h = (target_h - resized_h) / 2

    if (src_w, src_h) != (resized_w, resized_h):
        im = cv2.resize(im, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)

    # The -/+0.1 nudge sends the odd pixel of an odd total padding to the bottom/right.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))

    padded = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return padded, scale, (pad_w, pad_h)
|
|
|
|
def xywh2xyxy(x):
    """
    Convert boxes from (center x, center y, width, height) to (x1, y1, x2, y2).

    Operates on the last axis of ``x`` (a torch.Tensor or a NumPy array) and returns a
    new object of the same kind; the input is not modified.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w  # left
    out[..., 1] = x[..., 1] - half_h  # top
    out[..., 2] = x[..., 0] + half_w  # right
    out[..., 3] = x[..., 1] + half_h  # bottom
    return out
|
|
|
|
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300, max_nms=30000):
    """
    Class-aware non-maximum suppression (NMS) for raw YOLO head output.

    Args:
        prediction: tensor of shape [batch, 4 + num_classes, anchors]. Channels 0-3 are
            boxes as (center x, center y, width, height); the remaining channels are
            per-class scores.
        conf_thres: anchors whose best class score is not above this value are dropped.
        iou_thres: IoU threshold handed to torchvision.ops.nms.
        max_det: maximum number of detections kept per image, applied AFTER NMS.
        max_nms: maximum number of highest-scoring candidates fed into NMS per image;
            bounds NMS cost when very many anchors pass ``conf_thres``.

    Returns:
        A list with one tensor per image, each of shape [n, 6] holding
        (x1, y1, x2, y2, confidence, class_index), ordered by descending confidence.
    """
    # [batch, 4 + nc, anchors] -> [batch, anchors, 4 + nc]: one row per anchor.
    prediction = prediction.transpose(1, 2)
    bs = prediction.shape[0]

    # Candidate mask: best class score of each anchor, taken over the last (class) axis.
    xc = prediction[..., 4:].max(-1)[0] > conf_thres

    # One independent empty result per image (``[t] * bs`` would alias a single tensor).
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(bs)]

    for xi, x in enumerate(prediction):
        x = x[xc[xi]]  # confidence filtering
        if not x.shape[0]:
            continue

        box = xywh2xyxy(x[:, :4])

        # Best class and its score per candidate -> rows of (x1, y1, x2, y2, conf, cls).
        conf, j = x[:, 4:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        n = x.shape[0]
        if not n:
            continue
        # Cap the NMS input at max_nms, NOT max_det: truncating to max_det here could keep
        # only duplicates of a few strong objects and silently drop every other object.
        if n > max_nms:
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Shift boxes by class index so boxes of different classes never overlap and are
        # suppressed independently (assumes coordinates stay below 7680 px).
        c = x[:, 5:6] * 7680
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # indices sorted by score, descending
        output[xi] = x[i[:max_det]]

    return output
|
|
|
|
def _preprocess(img0, device, new_shape=(640, 640)):
    """Letterbox a BGR HWC image and convert it to a normalized [1, 3, H, W] float tensor.

    Returns (tensor, ratio, (dw, dh)); ratio and (dw, dh) are the scale and per-side padding
    from letterbox(), needed later to map boxes back onto ``img0``.
    """
    img, ratio, (dw, dh) = letterbox(img0, new_shape=new_shape)
    # BGR -> RGB, HWC -> CHW; the reversed-channel view is non-contiguous, so copy it.
    img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1))
    img_tensor = torch.from_numpy(img).to(device).float() / 255.0  # 0-255 -> 0.0-1.0
    return img_tensor.unsqueeze(0), ratio, (dw, dh)


def _scale_boxes_to_original(det, ratio, pad, orig_shape):
    """Undo letterbox padding/scaling on det[:, :4] (x1, y1, x2, y2) in place and clip to orig_shape (h, w, ...)."""
    dw, dh = pad
    det[:, [0, 2]] -= dw  # remove horizontal padding
    det[:, [1, 3]] -= dh  # remove vertical padding
    det[:, :4] /= ratio
    h, w = orig_shape[:2]
    # Slicing a single column gives a view, so clamp_ writes back into det.
    det[:, 0].clamp_(0, w)
    det[:, 1].clamp_(0, h)
    det[:, 2].clamp_(0, w)
    det[:, 3].clamp_(0, h)
    return det


def _draw_detection(img, xyxy, label, color):
    """Draw one box plus a filled label onto ``img`` in place; returns the box corners (p1, p2)."""
    p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
    cv2.rectangle(img, p1, p2, color, 2, lineType=cv2.LINE_AA)

    (tw, th), _ = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)
    # Put the label above the box; when the box touches the top edge there is no room
    # (the label would be drawn at negative y), so put it just inside the box instead.
    if p1[1] - th - 3 >= 0:
        bg_corner = (p1[0] + tw, p1[1] - th - 3)
        text_org = (p1[0], p1[1] - 2)
    else:
        bg_corner = (p1[0] + tw, p1[1] + th + 3)
        text_org = (p1[0], p1[1] + th + 2)
    cv2.rectangle(img, p1, bg_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, text_org, 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)
    return p1, p2


def main(img_path="1.jpg", weights="yolo11s.pth", out_path="result.jpg"):
    """Run YOLO11-s on one image, print the detections and save an annotated copy.

    Args:
        img_path: image file to run on (read with cv2.imread).
        weights: weight file passed to ``YOLO11.load_weights``.
        out_path: where the annotated image is written.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
        OSError: if ``cv2.imwrite`` reports that ``out_path`` could not be written.
    """
    # 1. Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = YOLO11(nc=80, scale='s')
    model.load_weights(weights)  # previously converted plain weights
    model.to(device)
    model.eval()

    # 2. Image. cv2.imread returns None instead of raising, and an assert would be
    # stripped under `python -O`, so check explicitly.
    img0 = cv2.imread(img_path)
    if img0 is None:
        raise FileNotFoundError(f"Image Not Found {img_path}")

    # 3. Preprocess
    img_tensor, ratio, pad = _preprocess(img0, device)

    # 4. Inference
    print("开始推理...")
    with torch.no_grad():
        pred = model(img_tensor)

    # 5. Post-process (NMS); only one image was fed in.
    det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]

    # 6. Draw
    if len(det):
        _scale_boxes_to_original(det, ratio, pad, img0.shape)
        print(f"检测到 {len(det)} 个目标")
        for *xyxy, conf, cls in det:
            c = int(cls)
            label = f'{CLASSES[c]} {conf:.2f}'
            color = tuple(int(v) for v in COLORS[c])  # plain ints are accepted by every cv2 version
            p1, p2 = _draw_detection(img0, xyxy, label, color)
            print(f" - {label} at {p1}-{p2}")

    # 7. Save. cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(out_path, img0):
        raise OSError(f"Failed to write {out_path}")
    print(f"结果已保存至 {out_path}")
|
|
|
|
def import_os_exists(path):
    """Return True if ``path`` names an existing filesystem entry (thin wrapper over os.path.exists)."""
    from os.path import exists
    return exists(path)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()