添加qat量化支持
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
||||
import torchvision
|
||||
from pathlib import Path
|
||||
|
||||
from yolo11_standalone import YOLO11E, YOLOPostProcessor
|
||||
from yolo11_standalone import YOLO11E, YOLOPostProcessor, YOLOPostProcessorNumpy
|
||||
from mobile_clip_standalone import MobileCLIP
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -58,6 +58,56 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300)
|
||||
output[xi] = x[i]
|
||||
return output
|
||||
|
||||
def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
|
||||
bs = prediction.shape[0]
|
||||
output = [np.zeros((0, 6), dtype=np.float32)] * bs
|
||||
|
||||
for xi, x in enumerate(prediction):
|
||||
bbox_xywh = x[:, :4]
|
||||
class_probs = x[:, 4:]
|
||||
|
||||
class_ids = np.argmax(class_probs, axis=1)
|
||||
confidences = np.max(class_probs, axis=1)
|
||||
|
||||
mask = confidences > conf_thres
|
||||
bbox_xywh = bbox_xywh[mask]
|
||||
confidences = confidences[mask]
|
||||
class_ids = class_ids[mask]
|
||||
|
||||
if len(confidences) == 0:
|
||||
continue
|
||||
|
||||
bbox_tlwh = np.copy(bbox_xywh)
|
||||
bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2
|
||||
bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2
|
||||
|
||||
indices = cv2.dnn.NMSBoxes(
|
||||
bboxes=bbox_tlwh.tolist(),
|
||||
scores=confidences.tolist(),
|
||||
score_threshold=conf_thres,
|
||||
nms_threshold=iou_thres
|
||||
)
|
||||
|
||||
if len(indices) > 0:
|
||||
indices = indices.flatten()
|
||||
if len(indices) > max_det:
|
||||
indices = indices[:max_det]
|
||||
|
||||
final_boxes_xywh = bbox_xywh[indices]
|
||||
final_boxes_xyxy = xywh2xyxy(final_boxes_xywh)
|
||||
final_scores = confidences[indices]
|
||||
final_classes = class_ids[indices]
|
||||
|
||||
out_tensor = np.concatenate([
|
||||
final_boxes_xyxy,
|
||||
final_scores[:, None],
|
||||
final_classes[:, None]
|
||||
], axis=1)
|
||||
|
||||
output[xi] = out_tensor
|
||||
|
||||
return output
|
||||
|
||||
def main():
|
||||
print(f"Using device: {DEVICE}")
|
||||
|
||||
@@ -94,17 +144,30 @@ def main():
|
||||
print("Running inference...")
|
||||
with torch.no_grad():
|
||||
raw_outputs = yolo_model(img_tensor)
|
||||
decoded_box, mc, p = post_processor(raw_outputs)
|
||||
# decoded_box, mc, p = post_processor(raw_outputs)
|
||||
# pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
|
||||
# det = pred[0]
|
||||
feat_maps, mc, p = raw_outputs
|
||||
|
||||
feat_maps_numpy = [f.detach().cpu().numpy() for f in feat_maps]
|
||||
mc_numpy = mc.detach().cpu().numpy()
|
||||
p_numpy = p.detach().cpu().numpy()
|
||||
raw_outputs_numpy = (feat_maps_numpy, mc_numpy, p_numpy)
|
||||
post_processor_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=True)
|
||||
decoded_box_numpy, mc_numpy_out, p_numpy_out = post_processor_numpy(raw_outputs_numpy)
|
||||
pred_results = non_max_suppression_numpy(decoded_box_numpy, conf_thres=0.25, iou_thres=0.7)
|
||||
det = pred_results[0]
|
||||
|
||||
pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
|
||||
|
||||
det = pred[0]
|
||||
if len(det):
|
||||
det[:, [0, 2]] -= dw
|
||||
det[:, [1, 3]] -= dh
|
||||
det[:, :4] /= ratio
|
||||
det[:, [0, 2]].clamp_(0, img0.shape[1])
|
||||
det[:, [1, 3]].clamp_(0, img0.shape[0])
|
||||
# det[:, [0, 2]].clamp_(0, img0.shape[1])
|
||||
# det[:, [1, 3]].clamp_(0, img0.shape[0])
|
||||
det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1])
|
||||
det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0])
|
||||
det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1])
|
||||
det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0])
|
||||
|
||||
print(f"Detected {len(det)} objects:")
|
||||
for *xyxy, conf, cls in det:
|
||||
@@ -119,7 +182,7 @@ def main():
|
||||
else:
|
||||
print("No objects detected.")
|
||||
|
||||
cv2.imwrite("result_separate.jpg", img0)
|
||||
cv2.imwrite("result_yoloe.jpg", img0)
|
||||
print("Result saved.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user