Add QAT quantization support
91  inference.py
@@ -2,7 +2,7 @@ import torch
 import cv2
 import numpy as np
 import torchvision
-from yolo11_standalone import YOLO11, YOLOPostProcessor
+from yolo11_standalone import YOLO11, YOLOPostProcessor, YOLOPostProcessorNumpy
 
 CLASSES = [
     "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
@@ -26,7 +26,7 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
     r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
 
     new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
-    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
     dw, dh = dw / 2, dh / 2
 
     if shape[::-1] != new_unpad:
@@ -40,13 +40,13 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
 
 def xywh2xyxy(x):
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
-    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
-    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
-    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
+    y[..., 0] = x[..., 0] - x[..., 2] / 2
+    y[..., 1] = x[..., 1] - x[..., 3] / 2
+    y[..., 2] = x[..., 0] + x[..., 2] / 2
+    y[..., 3] = x[..., 1] + x[..., 3] / 2
     return y
 
-def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+def non_max_suppression(prediction, conf_thres=0.01, iou_thres=0.45, max_det=300):
     prediction = prediction.transpose(1, 2)
 
     bs = prediction.shape[0]
@@ -75,12 +75,63 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300
 
     return output
 
+def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+    bs = prediction.shape[0]
+    output = [np.zeros((0, 6), dtype=np.float32)] * bs
+
+    for xi, x in enumerate(prediction):
+        bbox_xywh = x[:, :4]
+        class_probs = x[:, 4:]
+
+        class_ids = np.argmax(class_probs, axis=1)
+        confidences = np.max(class_probs, axis=1)
+
+        mask = confidences > conf_thres
+        bbox_xywh = bbox_xywh[mask]
+        confidences = confidences[mask]
+        class_ids = class_ids[mask]
+
+        if len(confidences) == 0:
+            continue
+
+        bbox_tlwh = np.copy(bbox_xywh)
+        bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2
+        bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2
+
+        indices = cv2.dnn.NMSBoxes(
+            bboxes=bbox_tlwh.tolist(),
+            scores=confidences.tolist(),
+            score_threshold=conf_thres,
+            nms_threshold=iou_thres
+        )
+
+        if len(indices) > 0:
+            indices = indices.flatten()
+            if len(indices) > max_det:
+                indices = indices[:max_det]
+
+            final_boxes_xywh = bbox_xywh[indices]
+            final_boxes_xyxy = xywh2xyxy(final_boxes_xywh)
+
+            final_scores = confidences[indices]
+            final_classes = class_ids[indices]
+
+            out_tensor = np.concatenate([
+                final_boxes_xyxy,
+                final_scores[:, None],
+                final_classes[:, None]
+            ], axis=1)
+
+            output[xi] = out_tensor
+
+    return output
+
 def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
 
     model = YOLO11(nc=80, scale='s')
-    model.load_weights("yolo11s.pth")
+    model.load_weights("my_yolo_result_qat/best_fp32_converted.pth")
     model.to(device)
     model.eval()
     post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
@@ -104,20 +155,28 @@ def main():
     with torch.no_grad():
         pred = model(img_tensor)
 
-    pred = post_std(pred)
-    pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
-    det = pred[0]
+    # pred = post_std(pred)
+    # pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
+    # det = pred[0]
+    preds_raw_numpy = [p.cpu().numpy() for p in pred]
+    post_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=False)
+    pred_numpy_decoded = post_numpy(preds_raw_numpy)
+    pred_results = non_max_suppression_numpy(pred_numpy_decoded, conf_thres=0.25, iou_thres=0.45)
+    det = pred_results[0]
 
     if len(det):
        det[:, [0, 2]] -= dw  # x padding
        det[:, [1, 3]] -= dh  # y padding
        det[:, :4] /= ratio
 
-       det[:, 0].clamp_(0, img0.shape[1])
-       det[:, 1].clamp_(0, img0.shape[0])
-       det[:, 2].clamp_(0, img0.shape[1])
-       det[:, 3].clamp_(0, img0.shape[0])
+       # det[:, 0].clamp_(0, img0.shape[1])
+       # det[:, 1].clamp_(0, img0.shape[0])
+       # det[:, 2].clamp_(0, img0.shape[1])
+       # det[:, 3].clamp_(0, img0.shape[0])
+       det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1])
+       det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0])
+       det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1])
+       det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0])
 
        print(f"Detected {len(det)} objects")
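
The numpy NMS path above shifts center-format boxes to top-left [x, y, w, h] before calling cv2.dnn.NMSBoxes, because that is the box convention OpenCV expects. A minimal standalone sketch of the call, with made-up boxes rather than real detections:

import cv2

# Two heavily overlapping boxes plus one distant box (illustrative values).
boxes_tlwh = [[10, 10, 50, 50], [12, 12, 50, 50], [200, 200, 40, 40]]  # top-left x, y, w, h
scores = [0.9, 0.8, 0.7]

keep = cv2.dnn.NMSBoxes(bboxes=boxes_tlwh, scores=scores,
                        score_threshold=0.25, nms_threshold=0.45)
print(keep.flatten())  # [0 2]: the lower-scored overlapping box is suppressed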
@@ -4,7 +4,7 @@ import numpy as np
 import torchvision
 from pathlib import Path
 
-from yolo11_standalone import YOLO11E, YOLOPostProcessor
+from yolo11_standalone import YOLO11E, YOLOPostProcessor, YOLOPostProcessorNumpy
 from mobile_clip_standalone import MobileCLIP
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -58,6 +58,56 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300)
         output[xi] = x[i]
     return output
 
+def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+    bs = prediction.shape[0]
+    output = [np.zeros((0, 6), dtype=np.float32)] * bs
+
+    for xi, x in enumerate(prediction):
+        bbox_xywh = x[:, :4]
+        class_probs = x[:, 4:]
+
+        class_ids = np.argmax(class_probs, axis=1)
+        confidences = np.max(class_probs, axis=1)
+
+        mask = confidences > conf_thres
+        bbox_xywh = bbox_xywh[mask]
+        confidences = confidences[mask]
+        class_ids = class_ids[mask]
+
+        if len(confidences) == 0:
+            continue
+
+        bbox_tlwh = np.copy(bbox_xywh)
+        bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2
+        bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2
+
+        indices = cv2.dnn.NMSBoxes(
+            bboxes=bbox_tlwh.tolist(),
+            scores=confidences.tolist(),
+            score_threshold=conf_thres,
+            nms_threshold=iou_thres
+        )
+
+        if len(indices) > 0:
+            indices = indices.flatten()
+            if len(indices) > max_det:
+                indices = indices[:max_det]
+
+            final_boxes_xywh = bbox_xywh[indices]
+            final_boxes_xyxy = xywh2xyxy(final_boxes_xywh)
+            final_scores = confidences[indices]
+            final_classes = class_ids[indices]
+
+            out_tensor = np.concatenate([
+                final_boxes_xyxy,
+                final_scores[:, None],
+                final_classes[:, None]
+            ], axis=1)
+
+            output[xi] = out_tensor
+
+    return output
+
 def main():
     print(f"Using device: {DEVICE}")
 
@@ -94,17 +144,30 @@ def main():
     print("Running inference...")
     with torch.no_grad():
         raw_outputs = yolo_model(img_tensor)
-        decoded_box, mc, p = post_processor(raw_outputs)
+        # decoded_box, mc, p = post_processor(raw_outputs)
+        # pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
+        # det = pred[0]
+        feat_maps, mc, p = raw_outputs
 
-        pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
+        feat_maps_numpy = [f.detach().cpu().numpy() for f in feat_maps]
+        mc_numpy = mc.detach().cpu().numpy()
+        p_numpy = p.detach().cpu().numpy()
+        raw_outputs_numpy = (feat_maps_numpy, mc_numpy, p_numpy)
+        post_processor_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=True)
+        decoded_box_numpy, mc_numpy_out, p_numpy_out = post_processor_numpy(raw_outputs_numpy)
+        pred_results = non_max_suppression_numpy(decoded_box_numpy, conf_thres=0.25, iou_thres=0.7)
+        det = pred_results[0]
 
-    det = pred[0]
     if len(det):
         det[:, [0, 2]] -= dw
         det[:, [1, 3]] -= dh
         det[:, :4] /= ratio
-        det[:, [0, 2]].clamp_(0, img0.shape[1])
-        det[:, [1, 3]].clamp_(0, img0.shape[0])
+        # det[:, [0, 2]].clamp_(0, img0.shape[1])
+        # det[:, [1, 3]].clamp_(0, img0.shape[0])
+        det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1])
+        det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0])
+        det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1])
+        det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0])
 
         print(f"Detected {len(det)} objects:")
         for *xyxy, conf, cls in det:
@@ -119,7 +182,7 @@ def main():
     else:
         print("No objects detected.")
 
-    cv2.imwrite("result_separate.jpg", img0)
+    cv2.imwrite("result_yoloe.jpg", img0)
     print("Result saved.")
 
 if __name__ == "__main__":
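
The point of the numpy-only decode is that post-processing no longer needs torch, which matters once the model itself runs as the quantized ONNX graph this commit exports. A hypothetical sketch of that deployment path, assuming the exported graph keeps the 'images' input name from qat_utils.py and emits the three raw feature maps (the actual output layout should be checked against the graph):

import numpy as np
import onnxruntime as ort
from yolo11_standalone import YOLOPostProcessorNumpy
from inference import non_max_suppression_numpy

sess = ort.InferenceSession("my_yolo_result_qat/model_int8_qdq.onnx",
                            providers=["CPUExecutionProvider"])
img = np.zeros((1, 3, 640, 640), dtype=np.float32)  # letterboxed, normalized frame
feat_maps = sess.run(None, {"images": img})         # assumed: the three raw heads

post = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=False)
decoded = post(feat_maps)
dets = non_max_suppression_numpy(decoded, conf_thres=0.25, iou_thres=0.45)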
324  qat_utils.py  (new file)
@@ -0,0 +1,324 @@
+import os
+import copy
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+import onnx
+from pathlib import Path
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+from torch.ao.quantization import QConfig
+from torch.ao.quantization.fake_quantize import FakeQuantize
+from torch.ao.quantization.observer import (
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+)
+from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+from torch.onnx import register_custom_op_symbolic
+
+# ONNX Runtime Quantization
+from onnxruntime.quantization import (
+    quantize_static,
+    CalibrationDataReader,
+    QuantType,
+    QuantFormat,
+)
+
+LOGGER = logging.getLogger("QAT_Utils")
+
+@dataclass
+class QATConfig:
+    """Unified Configuration for QAT and Export"""
+    # Training Config
+    ignore_layers: List[str] = field(default_factory=list)
+
+    # Export Config
+    img_size: int = 640
+    calibration_image_dir: Optional[str] = None
+    max_calibration_samples: int = 100
+
+    # Paths
+    save_dir: Path = Path('./runs/qat_unified')
+
+    def __post_init__(self):
+        self.save_dir = Path(self.save_dir)
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+
+class YOLOCalibrationDataReader(CalibrationDataReader):
+    """YOLO model calibration data reader for ONNX Runtime"""
+
+    def __init__(self, calibration_image_folder: str, input_name: str = 'images',
+                 img_size: int = 640, max_samples: int = 100):
+        self.input_name = input_name
+        self.img_size = img_size
+        self.max_samples = max_samples
+
+        if not calibration_image_folder or not os.path.exists(calibration_image_folder):
+            raise ValueError(f"Calibration folder not found: {calibration_image_folder}")
+
+        # Collect image paths
+        self.image_paths = []
+        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
+            self.image_paths.extend(Path(calibration_image_folder).glob(ext))
+
+        if not self.image_paths:
+            LOGGER.warning(f"No images found in {calibration_image_folder}")
+
+        self.image_paths = self.image_paths[:max_samples]
+        LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}")
+
+        self.current_idx = 0
+
+    def preprocess(self, img_path):
+        import cv2
+
+        img = cv2.imread(str(img_path))
+        if img is None:
+            return None
+
+        # Resize with letterbox
+        h, w = img.shape[:2]
+        scale = min(self.img_size / h, self.img_size / w)
+        new_h, new_w = int(h * scale), int(w * scale)
+
+        img = cv2.resize(img, (new_w, new_h))
+
+        # Pad to target size
+        canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
+        pad_h = (self.img_size - new_h) // 2
+        pad_w = (self.img_size - new_w) // 2
+        canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = img
+
+        # Convert to model input format
+        img = canvas[:, :, ::-1]  # BGR to RGB
+        img = img.transpose(2, 0, 1)  # HWC to CHW
+        img = img.astype(np.float32) / 255.0
+        img = np.expand_dims(img, axis=0)  # Add batch dimension
+
+        return img
+
+    def get_next(self):
+        if self.current_idx >= len(self.image_paths):
+            return None
+
+        img_path = self.image_paths[self.current_idx]
+        self.current_idx += 1
+
+        img = self.preprocess(img_path)
+        if img is None:
+            return self.get_next()
+
+        return {self.input_name: img}
+
+    def rewind(self):
+        self.current_idx = 0
+
+
+def get_rknn_qconfig() -> QConfig:
+    return QConfig(
+        activation=FakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver,
+            quant_min=0,
+            quant_max=255,
+            reduce_range=False,
+            dtype=torch.quint8,
+        ),
+        weight=FakeQuantize.with_args(
+            observer=MovingAveragePerChannelMinMaxObserver,
+            quant_min=-128,
+            quant_max=127,
+            dtype=torch.qint8,
+            qscheme=torch.per_channel_affine,
+            reduce_range=False,
+        )
+    )
+
+def prepare_model_for_qat(
+    model: nn.Module,
+    example_inputs: torch.Tensor,
+    config: QATConfig
+) -> nn.Module:
+    qconfig = get_rknn_qconfig()
+    qconfig_mapping = QConfigMapping().set_global(qconfig)
+
+    if config.ignore_layers:
+        LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}")
+        for layer_name in config.ignore_layers:
+            qconfig_mapping.set_module_name(layer_name, None)
+
+    model_to_quantize = copy.deepcopy(model)
+    model_to_quantize.eval()
+
+    try:
+        prepared_model = prepare_qat_fx(
+            model_to_quantize,
+            qconfig_mapping=qconfig_mapping,
+            example_inputs=(example_inputs,)
+        )
+        return prepared_model
+    except Exception as e:
+        LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}")
+        raise e
+
+def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
+    """
+    Loads weights from a QAT model (with fake quant layers) into a standard FP32 model.
+
+    FX Graph mode wraps Conv+BN modules, so keys need to be mapped:
+    - QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
+    - QAT: model.X.conv.weight -> FP32: model.X.conv.weight (direct match)
+    """
+    import re
+
+    LOGGER.info("Loading QAT weights into FP32 model...")
+    fp32_model_dict = fp32_model.state_dict()
+
+    # Quantization-related keys to skip
+    quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
+                      '_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
+                      'fake_quant_enabled', 'observer_enabled']
+
+    new_state_dict = {}
+    keys_loaded = []
+    missing_fp32_keys = set(fp32_model_dict.keys())
+
+    for qat_key, qat_value in qat_state_dict.items():
+        # Skip quantization buffers
+        if any(kw in qat_key for kw in quant_keywords):
+            continue
+
+        # Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
+        fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)
+
+        if fp32_key in fp32_model_dict:
+            if qat_value.shape == fp32_model_dict[fp32_key].shape:
+                new_state_dict[fp32_key] = qat_value
+                keys_loaded.append(fp32_key)
+                missing_fp32_keys.discard(fp32_key)
+            else:
+                LOGGER.warning(f"Shape mismatch for {fp32_key}: QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
+        elif qat_key in fp32_model_dict:
+            # Fallback: try direct match
+            if qat_value.shape == fp32_model_dict[qat_key].shape:
+                new_state_dict[qat_key] = qat_value
+                keys_loaded.append(qat_key)
+                missing_fp32_keys.discard(qat_key)
+
+    load_result = fp32_model.load_state_dict(new_state_dict, strict=False)
+
+    LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
+    if missing_fp32_keys:
+        LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}")
+
+    return fp32_model
+
+
+def export_knn_compatible_onnx(
+    config: QATConfig,
+    model_class,
+    best_weights_path: str,
+    nc: int,
+    scale: str
+):
+    """
+    Orchestrates the full export pipeline:
+    1. Load pure FP32 model.
+    2. Load QAT-trained weights (best.pt) into FP32 model.
+    3. Export FP32 ONNX.
+    4. Run ORT Quantization (PTQ) to produce QDQ ONNX.
+    """
+
+    fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
+    qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'
+
+    LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...")
+
+    # 1. Instantiate clean FP32 model
+    fp32_model = model_class(nc=nc, scale=scale)
+
+    # 2. Load QAT weights
+    ckpt = torch.load(best_weights_path, map_location='cpu')
+    if 'model' in ckpt:
+        state_dict = ckpt['model']
+    else:
+        state_dict = ckpt
+
+    fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
+    fp32_model.eval()
+
+    # 3. Export FP32 ONNX
+    dummy_input = torch.randn(1, 3, config.img_size, config.img_size)
+
+    try:
+        torch.onnx.export(
+            fp32_model,
+            dummy_input,
+            str(fp32_onnx_path),
+            opset_version=13,
+            input_names=['images'],
+            output_names=['output'],
+            do_constant_folding=True,
+            dynamo=False
+        )
+        LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")
+
+        # 3.5 Pre-process (Shape Inference & Optimization) to suppress warnings and improve quantization
+        from onnxruntime.quantization import quant_pre_process
+
+        LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...")
+        preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
+        quant_pre_process(
+            input_model_path=str(fp32_onnx_path),
+            output_model_path=str(preprocessed_path),
+            skip_symbolic_shape=False
+        )
+        LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")
+
+        # Update input path for quantization to use the preprocessed model
+        model_input_path = str(preprocessed_path)
+
+    except Exception as e:
+        LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
+        return
+
+    LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...")
+
+    if not config.calibration_image_dir:
+        LOGGER.error("No calibration image directory provided. Skipping PTQ.")
+        return
+
+    calibration_reader = YOLOCalibrationDataReader(
+        calibration_image_folder=config.calibration_image_dir,
+        max_samples=config.max_calibration_samples,
+        img_size=config.img_size
+    )
+
+    try:
+        quantize_static(
+            model_input=model_input_path,  # Use the pre-processed model
+            model_output=str(qdq_onnx_path),
+            calibration_data_reader=calibration_reader,
+            quant_format=QuantFormat.QDQ,
+            activation_type=QuantType.QUInt8,
+            weight_type=QuantType.QInt8,
+            per_channel=True,
+            reduce_range=False,
+            extra_options={
+                'ActivationSymmetric': False,
+                'WeightSymmetric': True
+            }
+        )
+        LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
+        LOGGER.info(">>> Export Workflow Complete.")
+    except Exception as e:
+        LOGGER.error(f"PTQ Failed: {e}")
+
+
+def _aten_copy_symbolic(g, self, src, non_blocking):
+    return g.op("Identity", src)
+
+register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)
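
get_rknn_qconfig pairs asymmetric uint8 fake-quant for activations with per-channel int8 for weights, matching the ActivationSymmetric=False / WeightSymmetric=True options later passed to quantize_static. A small worked example of the asymmetric (affine) mapping the activation observer learns; the input values are illustrative, not from a real run:

import numpy as np

x = np.array([-0.4, 0.0, 1.2, 2.6], dtype=np.float32)  # observed activation range

scale = (x.max() - x.min()) / (255 - 0)        # quantized range is [0, 255]
zero_point = round(-x.min() / scale)           # real 0.0 maps to an exact integer

q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
x_hat = (q.astype(np.float32) - zero_point) * scale    # dequantize
print(q, np.abs(x - x_hat).max())  # codes [0 34 136 255], near-zero round-trip error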
109  train_qat.py  (new file)
@@ -0,0 +1,109 @@
+import torch
+from torch.utils.data import DataLoader
+
+from dataset import YOLODataset
+from yolo11_standalone import YOLO11, YOLOPostProcessor
+from loss import YOLOv8DetectionLoss
+from trainer_qat import YOLO11QATTrainer
+import qat_utils
+from qat_utils import QATConfig
+
+def run_qat_training():
+    # ==========================
+    # 1. Configuration
+    # ==========================
+    img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
+    label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
+    img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
+    label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
+    # img_dir_train = "E:\\Datasets\\coco8\\images\\train"
+    # label_dir_train = "E:\\Datasets\\coco8\\labels\\train"
+    # img_dir_val = "E:\\Datasets\\coco8\\images\\val"
+    # label_dir_val = "E:\\Datasets\\coco8\\labels\\val"
+
+    # Validation/Calibration Set (reused for PTQ calibration)
+    # Using val set for calibration is common, or a subset of train
+    calibration_img_dir = img_dir_val
+
+    weights = "yolo11s.pth"
+    scale = 's'
+    nc = 80
+
+    # QAT Training Hyperparameters
+    epochs = 10
+    batch_size = 16  # Suggest increasing to 32 or 64 if GPU memory allows for better BN stats
+    lr = 0.0001  # Lower LR (1e-4) is safer for QAT fine-tuning to preserve pretrained knowledge
+
+    # Unified QAT Configuration
+    config = QATConfig(
+        img_size=640,
+        save_dir='./my_yolo_result_qat',
+        calibration_image_dir=calibration_img_dir,
+        max_calibration_samples=500,  # Increased for full COCO to ensure robust PTQ quantization parameters
+        ignore_layers=[]  # Add layer names here to skip quantization if needed
+    )
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # ==========================
+    # 2. Data Preparation
+    # ==========================
+    print("Loading Data...")
+    train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=config.img_size, is_train=True)
+    val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=config.img_size, is_train=False)
+
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
+                              collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True)
+    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
+                            collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True)
+
+    # ==========================
+    # 3. Model Initialization
+    # ==========================
+    print("Initializing Model...")
+    model = YOLO11(nc=nc, scale=scale)
+    if weights:
+        model.load_weights(weights)
+
+    post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
+    loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32])
+
+    # ==========================
+    # 4. QAT Training
+    # ==========================
+    print("Creating QAT Trainer...")
+    trainer = YOLO11QATTrainer(
+        model=model,
+        post_processor=post_std,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        loss_fn=loss_fn,
+        config=config,
+        epochs=epochs,
+        lr=lr,
+        device=device
+    )
+
+    trainer.train()
+
+    # ==========================
+    # 5. Export Pipeline
+    # (Load Best -> FP32 -> ONNX -> QDQ PTQ)
+    # ==========================
+    print("\nStarting Export Pipeline...")
+    best_weights = config.save_dir / 'best.pt'
+
+    if best_weights.exists():
+        qat_utils.export_knn_compatible_onnx(
+            config=config,
+            model_class=YOLO11,
+            best_weights_path=str(best_weights),
+            nc=nc,
+            scale=scale
+        )
+    else:
+        print("Error: Best weights not found. Export failed.")
+
+
+if __name__ == "__main__":
+    run_qat_training()
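
Note that qat_utils imports convert_fx but the export path here goes through ONNX instead (FP32 export, then ONNX Runtime static quantization). If a pure-PyTorch int8 model were wanted for a quick accuracy sanity check before export, the prepared model could be converted directly; a sketch under that assumption (to_int8_reference is a hypothetical helper, not part of this commit):

import torch
from torch.ao.quantization.quantize_fx import convert_fx

def to_int8_reference(qat_model):
    # qat_model: the GraphModule returned by prepare_qat_fx, after QAT training.
    qat_model = qat_model.cpu().eval()
    int8_model = convert_fx(qat_model)  # folds fake-quant observers into real int8 ops
    with torch.no_grad():
        _ = int8_model(torch.randn(1, 3, 640, 640))  # smoke test
    return int8_model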
@@ -146,7 +146,6 @@ class YOLO11Trainer:
         iou_thres = 0.7
         iouv = torch.linspace(0.5, 0.95, 10, device=device)
 
-        loss_sum = torch.zeros(3, device=device)
         stats = []
 
         LOGGER.info("\nValidating...")
199  trainer_qat.py  (new file)
@@ -0,0 +1,199 @@
+import math
+import logging
+import time
+import numpy as np
+from pathlib import Path
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm import tqdm
+
+from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
+from qat_utils import prepare_model_for_qat, QATConfig
+
+logging.basicConfig(format="%(message)s", level=logging.INFO)
+LOGGER = logging.getLogger("QAT_Trainer")
+
+class YOLO11QATTrainer:
+    def __init__(self,
+                 model,
+                 post_processor,
+                 train_loader,
+                 val_loader,
+                 loss_fn,
+                 config: QATConfig,
+                 epochs: int = 5,
+                 lr: float = 0.001,
+                 device: str = 'cuda'):
+
+        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+
+        self.config = config
+        self.save_dir = config.save_dir
+
+        self.original_model = model.to(self.device)
+        self.post_processor = post_processor
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.loss_fn = loss_fn
+        self.epochs = epochs
+        self.lr = lr
+
+        self.qat_model = None
+        self.optimizer = None
+        self.scheduler = None
+
+    def setup(self):
+        LOGGER.info(">>> Setting up QAT...")
+
+        example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device)
+        self.qat_model = prepare_model_for_qat(
+            self.original_model,
+            example_input,
+            config=self.config
+        )
+        self.qat_model.to(self.device)
+
+        self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5)
+
+        lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (0.01 - 1) + 1
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)
+
+    def train_epoch(self, epoch):
+        self.qat_model.train()
+        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')
+
+        mean_loss = torch.zeros(3, device=self.device)
+
+        for i, (imgs, targets, _) in enumerate(pbar):
+            imgs = imgs.to(self.device).float()
+            targets = targets.to(self.device)
+
+            self.optimizer.zero_grad()
+
+            preds = self.qat_model(imgs)
+
+            target_batch = {
+                "batch_idx": targets[:, 0],
+                "cls": targets[:, 1],
+                "bboxes": targets[:, 2:],
+            }
+            loss, loss_items = self.loss_fn(preds, target_batch)
+
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
+            self.optimizer.step()
+
+            mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
+            pbar.set_postfix({
+                'box': f"{mean_loss[0]:.4f}",
+                'cls': f"{mean_loss[1]:.4f}",
+                'dfl': f"{mean_loss[2]:.4f}",
+                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
+            })
+
+    def validate(self):
+        conf_thres = 0.001
+        iou_thres = 0.7
+        iouv = torch.linspace(0.5, 0.95, 10, device=self.device)
+
+        stats = []
+
+        LOGGER.info("\nValidating...")
+
+        self.qat_model.eval()
+        pbar = tqdm(self.val_loader, desc="Calc Metrics")
+        with torch.no_grad():
+            for batch in pbar:
+                imgs, targets, _ = batch
+                imgs = imgs.to(self.device, non_blocking=True)
+                targets = targets.to(self.device)
+                _, _, height, width = imgs.shape
+
+                preds = self.qat_model(imgs)
+                preds = self.post_processor(preds)
+                preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
+
+                for si, pred in enumerate(preds):
+                    labels = targets[targets[:, 0] == si]
+                    nl = len(labels)
+                    tcls = labels[:, 1].tolist() if nl else []
+
+                    if len(pred) == 0:
+                        if nl:
+                            stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
+                                          torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
+                        continue
+
+                    predn = pred.clone()
+
+                    if nl:
+                        tbox = xywh2xyxy(labels[:, 2:6])
+                        tbox[:, [0, 2]] *= width
+                        tbox[:, [1, 3]] *= height
+                        labelsn = torch.cat((labels[:, 1:2], tbox), 1)
+                        correct = self._process_batch(predn, labelsn, iouv)
+                    else:
+                        correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
+
+                    stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))
+
+        mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
+        if len(stats) and stats[0][0].any():
+            stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
+            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
+            ap50, ap = ap[:, 0], ap.mean(1)
+            mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
+
+        LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
+
+        return map50
+
+    def _process_batch(self, detections, labels, iouv):
+        correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
+        iou = box_iou(labels[:, 1:], detections[:, :4])
+
+        x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
+
+        if x[0].shape[0]:
+            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
+            if x[0].shape[0] > 1:
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+
+            matches = torch.from_numpy(matches).to(iouv.device)
+            correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
+
+        return correct
+
+    def train(self):
+        self.setup()
+
+        LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
+        start_time = time.time()
+        best_fitness = -1
+
+        for epoch in range(self.epochs):
+            self.train_epoch(epoch)
+            self.scheduler.step()
+            map50 = self.validate()
+
+            ckpt = {
+                'epoch': epoch,
+                'model': self.qat_model.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+            }
+
+            torch.save(ckpt, self.save_dir / 'last_qat.pt')
+
+            if map50 > best_fitness:
+                best_fitness = map50
+                torch.save(ckpt, self.save_dir / 'best.pt')
+                LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")
+
+        LOGGER.info(">>> Training Finished.")
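
One refinement this trainer does not apply, sketched here as an assumption rather than part of the commit: the PyTorch QAT recipe often freezes observer ranges and BN statistics for the last epochs, so the final weights adapt to fixed quantization parameters. A sketch using the eager-mode helpers, assuming they apply cleanly to the FX-prepared GraphModule:

import torch

def freeze_quant_params(qat_model):
    # Stop updating activation ranges, then freeze fused Conv+BN statistics;
    # call from the training loop a couple of epochs before the end.
    qat_model.apply(torch.ao.quantization.disable_observer)
    qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)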
yolo11_standalone.py

@@ -562,6 +562,90 @@ class YOLO11E(YOLO11):
 
         return head(feats, cls_pe)
 
+
+# ==============================================================================
+# [Part 4] PostProcessorNumpy
+# ==============================================================================
+
+import numpy as np
+
+class YOLOPostProcessorNumpy:
+    def __init__(self, strides=[8, 16, 32], reg_max=16, use_segmentation=False):
+        self.strides = np.array(strides, dtype=np.float32)
+        self.reg_max = reg_max
+        self.use_segmentation = use_segmentation
+        self.anchors = None
+        self.strides_array = None
+        self.shape = None
+        self.dfl_weights = np.arange(reg_max, dtype=np.float32).reshape(1, 1, reg_max, 1)
+
+    def sigmoid(self, x):
+        return 1 / (1 + np.exp(-x))
+
+    def softmax(self, x, axis=-1):
+        x_max = np.max(x, axis=axis, keepdims=True)
+        e_x = np.exp(x - x_max)
+        return e_x / np.sum(e_x, axis=axis, keepdims=True)
+
+    def make_anchors(self, feats, strides, grid_cell_offset=0.5):
+        anchor_points, stride_list = [], []
+        for i, stride in enumerate(strides):
+            _, _, h, w = feats[i].shape
+            sx = np.arange(w, dtype=np.float32) + grid_cell_offset
+            sy = np.arange(h, dtype=np.float32) + grid_cell_offset
+            sy, sx = np.meshgrid(sy, sx, indexing='ij')
+
+            anchor_points.append(np.stack((sx, sy), -1).reshape(-1, 2))
+            stride_list.append(np.full((h * w, 1), stride, dtype=np.float32))
+
+        return np.concatenate(anchor_points), np.concatenate(stride_list)
+
+    def dist2bbox(self, distance, anchor_points, xywh=True, dim=-1):
+        lt, rb = np.split(distance, 2, axis=dim)
+        x1y1 = anchor_points - lt
+        x2y2 = anchor_points + rb
+        if xywh:
+            c_xy = (x1y1 + x2y2) / 2
+            wh = x2y2 - x1y1
+            return np.concatenate((c_xy, wh), axis=dim)
+        return np.concatenate((x1y1, x2y2), axis=dim)
+
+    def dfl_decode(self, x):
+        B, C, A = x.shape
+        x = x.reshape(B, 4, self.reg_max, A)
+        x = self.softmax(x, axis=2)
+        return np.sum(x * self.dfl_weights, axis=2)
+
+    def __call__(self, outputs):
+        if self.use_segmentation:
+            x, mc, p = outputs
+        else:
+            x = outputs
+
+        current_no = x[0].shape[1]
+        current_nc = current_no - self.reg_max * 4
+        shape = x[0].shape
+
+        x_cat = np.concatenate([xi.reshape(shape[0], current_no, -1) for xi in x], axis=2)
+
+        if self.anchors is None or self.shape != shape:
+            self.anchors, self.strides_array = self.make_anchors(x, self.strides, 0.5)
+            self.shape = shape
+
+        box, cls = np.split(x_cat, [self.reg_max * 4], axis=1)
+        dist = self.dfl_decode(box)
+        dist = dist.transpose(0, 2, 1)
+        dbox = self.dist2bbox(dist, self.anchors, xywh=True, dim=2) * self.strides_array
+        cls = cls.transpose(0, 2, 1)
+        sigmoid_cls = self.sigmoid(cls)
+        final_box = np.concatenate((dbox, sigmoid_cls), axis=2)
+
+        if self.use_segmentation:
+            return final_box, mc, p
+
+        return final_box
+
 
 if __name__ == "__main__":
     print("Testing Standard YOLO11...")
     model_std = YOLO11(nc=80, scale='n')
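
dfl_decode above implements Distribution Focal Loss decoding: each box side is predicted as a discrete distribution over reg_max bins, and the decoded distance is the expected value of that distribution. A tiny worked example with illustrative logits:

import numpy as np

reg_max = 16
logits = np.full(reg_max, -10.0, dtype=np.float32)
logits[3] = logits[4] = 2.0                # mass split between bins 3 and 4

p = np.exp(logits - logits.max())
p /= p.sum()                               # softmax over the reg_max bins
distance = float((p * np.arange(reg_max)).sum())
print(distance)                            # ~3.5 grid cells; multiplied by the stride later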
BIN  yolo11s_int_qdq.onnx  (new file)
Binary file not shown.