Add QAT quantization support

lhr
2026-01-08 15:12:27 +08:00
parent f4b1f341fc
commit 546a510eb2
8 changed files with 862 additions and 25 deletions

View File

@@ -2,7 +2,7 @@ import torch
import cv2
import numpy as np
import torchvision
-from yolo11_standalone import YOLO11, YOLOPostProcessor
+from yolo11_standalone import YOLO11, YOLOPostProcessor, YOLOPostProcessorNumpy
CLASSES = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
@@ -26,7 +26,7 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
-dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
dw, dh = dw / 2, dh / 2
if shape[::-1] != new_unpad:
@@ -40,13 +40,13 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
def xywh2xyxy(x):
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
-y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
-y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
-y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
+y[..., 0] = x[..., 0] - x[..., 2] / 2
+y[..., 1] = x[..., 1] - x[..., 3] / 2
+y[..., 2] = x[..., 0] + x[..., 2] / 2
+y[..., 3] = x[..., 1] + x[..., 3] / 2
return y
-def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
+def non_max_suppression(prediction, conf_thres=0.01, iou_thres=0.45, max_det=300):
prediction = prediction.transpose(1, 2)
bs = prediction.shape[0]
@@ -75,12 +75,63 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300
return output
def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
bs = prediction.shape[0]
output = [np.zeros((0, 6), dtype=np.float32)] * bs
for xi, x in enumerate(prediction):
bbox_xywh = x[:, :4]
class_probs = x[:, 4:]
class_ids = np.argmax(class_probs, axis=1)
confidences = np.max(class_probs, axis=1)
mask = confidences > conf_thres
bbox_xywh = bbox_xywh[mask]
confidences = confidences[mask]
class_ids = class_ids[mask]
if len(confidences) == 0:
continue
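# cv2.dnn.NMSBoxes expects top-left (x, y, w, h) boxes, so convert from center-format xywh first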
bbox_tlwh = np.copy(bbox_xywh)
bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2
bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2
indices = cv2.dnn.NMSBoxes(
bboxes=bbox_tlwh.tolist(),
scores=confidences.tolist(),
score_threshold=conf_thres,
nms_threshold=iou_thres
)
if len(indices) > 0:
indices = indices.flatten()
if len(indices) > max_det:
indices = indices[:max_det]
final_boxes_xywh = bbox_xywh[indices]
final_boxes_xyxy = xywh2xyxy(final_boxes_xywh)
final_scores = confidences[indices]
final_classes = class_ids[indices]
out_tensor = np.concatenate([
final_boxes_xyxy,
final_scores[:, None],
final_classes[:, None]
], axis=1)
output[xi] = out_tensor
return output
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = YOLO11(nc=80, scale='s')
-model.load_weights("yolo11s.pth")
+model.load_weights("my_yolo_result_qat/best_fp32_converted.pth")
model.to(device)
model.eval()
post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
@@ -104,20 +155,28 @@ def main():
with torch.no_grad():
pred = model(img_tensor)
-pred = post_std(pred)
+# pred = post_std(pred)
-pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
+# pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# det = pred[0]
-det = pred[0]
+preds_raw_numpy = [p.cpu().numpy() for p in pred]
post_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=False)
pred_numpy_decoded = post_numpy(preds_raw_numpy)
pred_results = non_max_suppression_numpy(pred_numpy_decoded, conf_thres=0.25, iou_thres=0.45)
det = pred_results[0]
if len(det):
det[:, [0, 2]] -= dw # x padding
det[:, [1, 3]] -= dh # y padding
det[:, :4] /= ratio
-det[:, 0].clamp_(0, img0.shape[1])
+# det[:, 0].clamp_(0, img0.shape[1])
-det[:, 1].clamp_(0, img0.shape[0])
+# det[:, 1].clamp_(0, img0.shape[0])
-det[:, 2].clamp_(0, img0.shape[1])
+# det[:, 2].clamp_(0, img0.shape[1])
-det[:, 3].clamp_(0, img0.shape[0])
+# det[:, 3].clamp_(0, img0.shape[0])
det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1])
det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0])
det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1])
det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0])
print(f"检测到 {len(det)} 个目标") print(f"检测到 {len(det)} 个目标")

View File

@@ -4,7 +4,7 @@ import numpy as np
import torchvision
from pathlib import Path
-from yolo11_standalone import YOLO11E, YOLOPostProcessor
+from yolo11_standalone import YOLO11E, YOLOPostProcessor, YOLOPostProcessorNumpy
from mobile_clip_standalone import MobileCLIP
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -58,6 +58,56 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300)
output[xi] = x[i]
return output
def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
bs = prediction.shape[0]
output = [np.zeros((0, 6), dtype=np.float32)] * bs
for xi, x in enumerate(prediction):
bbox_xywh = x[:, :4]
class_probs = x[:, 4:]
class_ids = np.argmax(class_probs, axis=1)
confidences = np.max(class_probs, axis=1)
mask = confidences > conf_thres
bbox_xywh = bbox_xywh[mask]
confidences = confidences[mask]
class_ids = class_ids[mask]
if len(confidences) == 0:
continue
bbox_tlwh = np.copy(bbox_xywh)
bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2
bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2
indices = cv2.dnn.NMSBoxes(
bboxes=bbox_tlwh.tolist(),
scores=confidences.tolist(),
score_threshold=conf_thres,
nms_threshold=iou_thres
)
if len(indices) > 0:
indices = indices.flatten()
if len(indices) > max_det:
indices = indices[:max_det]
final_boxes_xywh = bbox_xywh[indices]
final_boxes_xyxy = xywh2xyxy(final_boxes_xywh)
final_scores = confidences[indices]
final_classes = class_ids[indices]
out_tensor = np.concatenate([
final_boxes_xyxy,
final_scores[:, None],
final_classes[:, None]
], axis=1)
output[xi] = out_tensor
return output
def main():
print(f"Using device: {DEVICE}")
@@ -94,17 +144,30 @@ def main():
print("Running inference...") print("Running inference...")
with torch.no_grad(): with torch.no_grad():
raw_outputs = yolo_model(img_tensor) raw_outputs = yolo_model(img_tensor)
decoded_box, mc, p = post_processor(raw_outputs) # decoded_box, mc, p = post_processor(raw_outputs)
# pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
# det = pred[0]
feat_maps, mc, p = raw_outputs
-pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7)
+feat_maps_numpy = [f.detach().cpu().numpy() for f in feat_maps]
mc_numpy = mc.detach().cpu().numpy()
p_numpy = p.detach().cpu().numpy()
raw_outputs_numpy = (feat_maps_numpy, mc_numpy, p_numpy)
post_processor_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=True)
decoded_box_numpy, mc_numpy_out, p_numpy_out = post_processor_numpy(raw_outputs_numpy)
pred_results = non_max_suppression_numpy(decoded_box_numpy, conf_thres=0.25, iou_thres=0.7)
det = pred_results[0]
-det = pred[0]
if len(det):
det[:, [0, 2]] -= dw
det[:, [1, 3]] -= dh
det[:, :4] /= ratio
-det[:, [0, 2]].clamp_(0, img0.shape[1])
+# det[:, [0, 2]].clamp_(0, img0.shape[1])
-det[:, [1, 3]].clamp_(0, img0.shape[0])
+# det[:, [1, 3]].clamp_(0, img0.shape[0])
det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1])
det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0])
det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1])
det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0])
print(f"Detected {len(det)} objects:") print(f"Detected {len(det)} objects:")
for *xyxy, conf, cls in det: for *xyxy, conf, cls in det:
@@ -119,7 +182,7 @@ def main():
else:
print("No objects detected.")
-cv2.imwrite("result_separate.jpg", img0)
+cv2.imwrite("result_yoloe.jpg", img0)
print("Result saved.")
if __name__ == "__main__":

qat_utils.py (new file, 324 lines)
View File

@@ -0,0 +1,324 @@
import os
import copy
import logging
import torch
import torch.nn as nn
import numpy as np
import onnx
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional
from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.onnx import register_custom_op_symbolic
# ONNX Runtime Quantization
from onnxruntime.quantization import (
quantize_static,
CalibrationDataReader,
QuantType,
QuantFormat,
)
LOGGER = logging.getLogger("QAT_Utils")
@dataclass
class QATConfig:
"""Unified Configuration for QAT and Export"""
# Training Config
ignore_layers: List[str] = field(default_factory=list)
# Export Config
img_size: int = 640
calibration_image_dir: Optional[str] = None
max_calibration_samples: int = 100
# Paths
save_dir: Path = Path('./runs/qat_unified')
def __post_init__(self):
self.save_dir = Path(self.save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
class YOLOCalibrationDataReader(CalibrationDataReader):
"""YOLO model calibration data reader for ONNX Runtime"""
def __init__(self, calibration_image_folder: str, input_name: str = 'images',
img_size: int = 640, max_samples: int = 100):
self.input_name = input_name
self.img_size = img_size
self.max_samples = max_samples
if not calibration_image_folder or not os.path.exists(calibration_image_folder):
raise ValueError(f"Calibration folder not found: {calibration_image_folder}")
# Collect image paths
self.image_paths = []
for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
self.image_paths.extend(Path(calibration_image_folder).glob(ext))
if not self.image_paths:
LOGGER.warning(f"No images found in {calibration_image_folder}")
self.image_paths = self.image_paths[:max_samples]
LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}")
self.current_idx = 0
def preprocess(self, img_path):
import cv2
img = cv2.imread(str(img_path))
if img is None:
return None
# Resize with letterbox
h, w = img.shape[:2]
scale = min(self.img_size / h, self.img_size / w)
new_h, new_w = int(h * scale), int(w * scale)
img = cv2.resize(img, (new_w, new_h))
# Pad to target size
canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
pad_h = (self.img_size - new_h) // 2
pad_w = (self.img_size - new_w) // 2
canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = img
# Convert to model input format
img = canvas[:, :, ::-1] # BGR to RGB
img = img.transpose(2, 0, 1) # HWC to CHW
img = img.astype(np.float32) / 255.0
img = np.expand_dims(img, axis=0) # Add batch dimension
return img
def get_next(self):
if self.current_idx >= len(self.image_paths):
return None
img_path = self.image_paths[self.current_idx]
self.current_idx += 1
img = self.preprocess(img_path)
if img is None:
return self.get_next()
return {self.input_name: img}
def rewind(self):
self.current_idx = 0
def get_rknn_qconfig() -> QConfig:
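# Asymmetric uint8 activations in [0, 255] and per-channel int8 weights in [-128, 127];
# these ranges match the QUInt8 activation / QInt8 weight types used in the ONNX Runtime PTQ step below.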
return QConfig(
activation=FakeQuantize.with_args(
observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255,
reduce_range=False,
dtype=torch.quint8,
),
weight=FakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_affine,
reduce_range=False,
)
)
def prepare_model_for_qat(
model: nn.Module,
example_inputs: torch.Tensor,
config: QATConfig
) -> nn.Module:
qconfig = get_rknn_qconfig()
qconfig_mapping = QConfigMapping().set_global(qconfig)
if config.ignore_layers:
LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}")
for layer_name in config.ignore_layers:
qconfig_mapping.set_module_name(layer_name, None)
model_to_quantize = copy.deepcopy(model)
model_to_quantize.eval()
try:
prepared_model = prepare_qat_fx(
model_to_quantize,
qconfig_mapping=qconfig_mapping,
example_inputs=(example_inputs,)
)
return prepared_model
except Exception as e:
LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}")
raise e
def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
"""
Loads weights from a QAT model (with fake quant layers) into a standard FP32 model.
FX Graph mode wraps Conv+BN modules, so keys need to be mapped:
- QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
- QAT: model.X.conv.weight -> FP32: model.X.conv.weight (direct match)
"""
import re
LOGGER.info("Loading QAT weights into FP32 model...")
fp32_model_dict = fp32_model.state_dict()
# Quantization-related keys to skip
quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
'_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
'fake_quant_enabled', 'observer_enabled']
new_state_dict = {}
keys_loaded = []
missing_fp32_keys = set(fp32_model_dict.keys())
for qat_key, qat_value in qat_state_dict.items():
# Skip quantization buffers
if any(kw in qat_key for kw in quant_keywords):
continue
# Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)
if fp32_key in fp32_model_dict:
if qat_value.shape == fp32_model_dict[fp32_key].shape:
new_state_dict[fp32_key] = qat_value
keys_loaded.append(fp32_key)
missing_fp32_keys.discard(fp32_key)
else:
LOGGER.warning(f"Shape mismatch for {fp32_key}: QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
elif qat_key in fp32_model_dict:
# Fallback: try direct match
if qat_value.shape == fp32_model_dict[qat_key].shape:
new_state_dict[qat_key] = qat_value
keys_loaded.append(qat_key)
missing_fp32_keys.discard(qat_key)
load_result = fp32_model.load_state_dict(new_state_dict, strict=False)
LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
if missing_fp32_keys:
LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}")
return fp32_model
def export_knn_compatible_onnx(
config: QATConfig,
model_class,
best_weights_path: str,
nc: int,
scale: str
):
"""
Orchestrates the full export pipeline:
1. Load pure FP32 model.
2. Load QAT-trained weights (best.pt) into FP32 model.
3. Export FP32 ONNX.
4. Run ORT Quantization (PTQ) to produce QDQ ONNX.
"""
fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'
LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...")
# 1. Instantiate clean FP32 model
fp32_model = model_class(nc=nc, scale=scale)
# 2. Load QAT weights
ckpt = torch.load(best_weights_path, map_location='cpu')
if 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
fp32_model.eval()
# 3. Export FP32 ONNX
dummy_input = torch.randn(1, 3, config.img_size, config.img_size)
try:
torch.onnx.export(
fp32_model,
dummy_input,
str(fp32_onnx_path),
opset_version=13,
input_names=['images'],
output_names=['output'],
do_constant_folding=True,
dynamo=False
)
LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")
# 3.5 Pre-process (Shape Inference & Optimization) to suppress warnings and improve quantization
from onnxruntime.quantization import quant_pre_process
LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...")
preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
quant_pre_process(
input_model_path=str(fp32_onnx_path),
output_model_path=str(preprocessed_path),
skip_symbolic_shape=False
)
LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")
# Update input path for quantization to use the preprocessed model
model_input_path = str(preprocessed_path)
except Exception as e:
LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
return
LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...")
if not config.calibration_image_dir:
LOGGER.error("No calibration image directory provided. Skipping PTQ.")
return
calibration_reader = YOLOCalibrationDataReader(
calibration_image_folder=config.calibration_image_dir,
max_samples=config.max_calibration_samples,
img_size=config.img_size
)
try:
quantize_static(
model_input=model_input_path, # Use the pre-processed model
model_output=str(qdq_onnx_path),
calibration_data_reader=calibration_reader,
quant_format=QuantFormat.QDQ,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QInt8,
per_channel=True,
reduce_range=False,
extra_options={
'ActivationSymmetric': False,
'WeightSymmetric': True
}
)
LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
LOGGER.info(">>> Export Workflow Complete.")
except Exception as e:
LOGGER.error(f"PTQ Failed: {e}")
def _aten_copy_symbolic(g, self, src, non_blocking):
return g.op("Identity", src)
register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)

train_qat.py (new file, 109 lines)
View File

@@ -0,0 +1,109 @@
import torch
from torch.utils.data import DataLoader
from dataset import YOLODataset
from yolo11_standalone import YOLO11, YOLOPostProcessor
from loss import YOLOv8DetectionLoss
from trainer_qat import YOLO11QATTrainer
import qat_utils
from qat_utils import QATConfig
def run_qat_training():
# ==========================
# 1. Configuration
# ==========================
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
# img_dir_train = "E:\\Datasets\\coco8\\images\\train"
# label_dir_train = "E:\\Datasets\\coco8\\labels\\train"
# img_dir_val = "E:\\Datasets\\coco8\\images\\val"
# label_dir_val = "E:\\Datasets\\coco8\\labels\\val"
# Validation/Calibration Set (reused for PTQ calibration)
# Using val set for calibration is common, or a subset of train
calibration_img_dir = img_dir_val
weights = "yolo11s.pth"
scale = 's'
nc = 80
# QAT Training Hyperparameters
epochs = 10
batch_size = 16 # Suggest increasing to 32 or 64 if GPU memory allows for better BN stats
lr = 0.0001 # Lower LR (1e-4) is safer for QAT fine-tuning to preserve pretrained knowledge
# Unified QAT Configuration
config = QATConfig(
img_size=640,
save_dir='./my_yolo_result_qat',
calibration_image_dir=calibration_img_dir,
max_calibration_samples=500, # Increased for full COCO to ensure robust PTQ quantization parameters
ignore_layers=[] # Add layer names here to skip quantization if needed
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================
# 2. Data Preparation
# ==========================
print("Loading Data...")
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=config.img_size, is_train=True)
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=config.img_size, is_train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True)
# ==========================
# 3. Model Initialization
# ==========================
print("Initializing Model...")
model = YOLO11(nc=nc, scale=scale)
if weights:
model.load_weights(weights)
post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32])
# ==========================
# 4. QAT Training
# ==========================
print("Creating QAT Trainer...")
trainer = YOLO11QATTrainer(
model=model,
post_processor=post_std,
train_loader=train_loader,
val_loader=val_loader,
loss_fn=loss_fn,
config=config,
epochs=epochs,
lr=lr,
device=device
)
trainer.train()
# ==========================
# 5. Export Pipeline
# (Load Best -> FP32 -> ONNX -> QDQ PTQ)
# ==========================
print("\nStarting Export Pipeline...")
best_weights = config.save_dir / 'best.pt'
if best_weights.exists():
qat_utils.export_knn_compatible_onnx(
config=config,
model_class=YOLO11,
best_weights_path=str(best_weights),
nc=nc,
scale=scale
)
else:
print("Error: Best weights not found. Export failed.")
if __name__ == "__main__":
run_qat_training()

View File

@@ -146,7 +146,6 @@ class YOLO11Trainer:
iou_thres = 0.7
iouv = torch.linspace(0.5, 0.95, 10, device=device)
-loss_sum = torch.zeros(3, device=device)
stats = []
LOGGER.info("\nValidating...")

trainer_qat.py (new file, 199 lines)
View File

@@ -0,0 +1,199 @@
import math
import logging
import time
import numpy as np
from pathlib import Path
from typing import Optional, List
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
from qat_utils import prepare_model_for_qat, QATConfig
logging.basicConfig(format="%(message)s", level=logging.INFO)
LOGGER = logging.getLogger("QAT_Trainer")
class YOLO11QATTrainer:
def __init__(self,
model,
post_processor,
train_loader,
val_loader,
loss_fn,
config: QATConfig,
epochs: int = 5,
lr: float = 0.001,
device: str = 'cuda'):
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.config = config
self.save_dir = config.save_dir
self.original_model = model.to(self.device)
self.post_processor = post_processor
self.train_loader = train_loader
self.val_loader = val_loader
self.loss_fn = loss_fn
self.epochs = epochs
self.lr = lr
self.qat_model = None
self.optimizer = None
self.scheduler = None
def setup(self):
LOGGER.info(">>> Setting up QAT...")
example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device)
self.qat_model = prepare_model_for_qat(
self.original_model,
example_input,
config=self.config
)
self.qat_model.to(self.device)
self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5)
lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (0.01 - 1) + 1
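# Cosine anneal of the LR multiplier from 1.0 at epoch 0 down to 0.01 at the final epoch.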
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf)
def train_epoch(self, epoch):
self.qat_model.train()
pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}')
mean_loss = torch.zeros(3, device=self.device)
for i, (imgs, targets, _) in enumerate(pbar):
imgs = imgs.to(self.device).float()
targets = targets.to(self.device)
self.optimizer.zero_grad()
preds = self.qat_model(imgs)
target_batch = {
"batch_idx": targets[:, 0],
"cls": targets[:, 1],
"bboxes": targets[:, 2:],
}
loss, loss_items = self.loss_fn(preds, target_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0)
self.optimizer.step()
mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1)
pbar.set_postfix({
'box': f"{mean_loss[0]:.4f}",
'cls': f"{mean_loss[1]:.4f}",
'dfl': f"{mean_loss[2]:.4f}",
'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
})
def validate(self):
conf_thres = 0.001
iou_thres = 0.7
iouv = torch.linspace(0.5, 0.95, 10, device=self.device)
stats = []
LOGGER.info("\nValidating...")
self.qat_model.eval()
pbar = tqdm(self.val_loader, desc="Calc Metrics")
with torch.no_grad():
for batch in pbar:
imgs, targets, _ = batch
imgs = imgs.to(self.device, non_blocking=True)
targets = targets.to(self.device)
_, _, height, width = imgs.shape
preds = self.qat_model(imgs)
preds = self.post_processor(preds)
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
for si, pred in enumerate(preds):
labels = targets[targets[:, 0] == si]
nl = len(labels)
tcls = labels[:, 1].tolist() if nl else []
if len(pred) == 0:
if nl:
stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
continue
predn = pred.clone()
if nl:
tbox = xywh2xyxy(labels[:, 2:6])
tbox[:, [0, 2]] *= width
tbox[:, [1, 3]] *= height
labelsn = torch.cat((labels[:, 1:2], tbox), 1)
correct = self._process_batch(predn, labelsn, iouv)
else:
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
if len(stats) and stats[0][0].any():
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
ap50, ap = ap[:, 0], ap.mean(1)
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
return map50
def _process_batch(self, detections, labels, iouv):
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
iou = box_iou(labels[:, 1:], detections[:, :4])
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
matches = torch.from_numpy(matches).to(iouv.device)
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
return correct
def train(self):
self.setup()
LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
start_time = time.time()
best_fitness = -1
for epoch in range(self.epochs):
self.train_epoch(epoch)
self.scheduler.step()
map50 = self.validate()
ckpt = {
'epoch': epoch,
'model': self.qat_model.state_dict(),
'optimizer': self.optimizer.state_dict(),
}
torch.save(ckpt, self.save_dir / 'last_qat.pt')
if map50 > best_fitness:
best_fitness = map50
torch.save(ckpt, self.save_dir / 'best.pt')
LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")
LOGGER.info(">>> Training Finished.")

View File

@@ -562,6 +562,90 @@ class YOLO11E(YOLO11):
return head(feats, cls_pe)
# ==============================================================================
# [Part 4] PostProcessorNumpy
# ==============================================================================
import numpy as np
class YOLOPostProcessorNumpy:
def __init__(self, strides=[8, 16, 32], reg_max=16, use_segmentation=False):
self.strides = np.array(strides, dtype=np.float32)
self.reg_max = reg_max
self.use_segmentation = use_segmentation
self.anchors = None
self.strides_array = None
self.shape = None
self.dfl_weights = np.arange(reg_max, dtype=np.float32).reshape(1, 1, reg_max, 1)
def sigmoid(self, x):
return 1 / (1 + np.exp(-x))
def softmax(self, x, axis=-1):
x_max = np.max(x, axis=axis, keepdims=True)
e_x = np.exp(x - x_max)
return e_x / np.sum(e_x, axis=axis, keepdims=True)
def make_anchors(self, feats, strides, grid_cell_offset=0.5):
anchor_points, stride_list = [], []
for i, stride in enumerate(strides):
_, _, h, w = feats[i].shape
sx = np.arange(w, dtype=np.float32) + grid_cell_offset
sy = np.arange(h, dtype=np.float32) + grid_cell_offset
sy, sx = np.meshgrid(sy, sx, indexing='ij')
anchor_points.append(np.stack((sx, sy), -1).reshape(-1, 2))
stride_list.append(np.full((h * w, 1), stride, dtype=np.float32))
return np.concatenate(anchor_points), np.concatenate(stride_list)
def dist2bbox(self, distance, anchor_points, xywh=True, dim=-1):
lt, rb = np.split(distance, 2, axis=dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
c_xy = (x1y1 + x2y2) / 2
wh = x2y2 - x1y1
return np.concatenate((c_xy, wh), axis=dim)
return np.concatenate((x1y1, x2y2), axis=dim)
def dfl_decode(self, x):
B, C, A = x.shape
x = x.reshape(B, 4, self.reg_max, A)
x = self.softmax(x, axis=2)
return np.sum(x * self.dfl_weights, axis=2)
def __call__(self, outputs):
if self.use_segmentation:
x, mc, p = outputs
else:
x = outputs
current_no = x[0].shape[1]
current_nc = current_no - self.reg_max * 4
shape = x[0].shape
x_cat = np.concatenate([xi.reshape(shape[0], current_no, -1) for xi in x], axis=2)
if self.anchors is None or self.shape != shape:
self.anchors, self.strides_array = self.make_anchors(x, self.strides, 0.5)
self.shape = shape
box, cls = np.split(x_cat, [self.reg_max * 4], axis=1)
dist = self.dfl_decode(box)
dist = dist.transpose(0, 2, 1)
dbox = self.dist2bbox(dist, self.anchors, xywh=True, dim=2) * self.strides_array
cls = cls.transpose(0, 2, 1)
sigmoid_cls = self.sigmoid(cls)
final_box = np.concatenate((dbox, sigmoid_cls), axis=2)
if self.use_segmentation:
return final_box, mc, p
return final_box
if __name__ == "__main__": if __name__ == "__main__":
print("Testing Standard YOLO11...") print("Testing Standard YOLO11...")
model_std = YOLO11(nc=80, scale='n') model_std = YOLO11(nc=80, scale='n')
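A quick numerical check of the DFL decode implemented by YOLOPostProcessorNumpy above: each box side is predicted as a distribution over reg_max bins, and the decoded distance is the softmax-weighted expectation of the bin indices (scaled by the stride afterwards). A minimal standalone sketch:

import numpy as np

reg_max = 16
logits = np.zeros((1, 4 * reg_max, 1), dtype=np.float32)  # batch=1, one anchor
logits[0, 0 * reg_max + 3, 0] = 10.0   # left distance ~ bin 3
logits[0, 1 * reg_max + 5, 0] = 10.0   # top distance ~ bin 5
logits[0, 2 * reg_max + 2, 0] = 10.0   # right distance ~ bin 2
logits[0, 3 * reg_max + 7, 0] = 10.0   # bottom distance ~ bin 7

# Same computation as dfl_decode: softmax over the reg_max axis, then expectation.
x = logits.reshape(1, 4, reg_max, 1)
x = np.exp(x - x.max(axis=2, keepdims=True))
x = x / x.sum(axis=2, keepdims=True)
weights = np.arange(reg_max, dtype=np.float32).reshape(1, 1, reg_max, 1)
dist = (x * weights).sum(axis=2)
print(dist.squeeze())  # approximately [3, 5, 2, 7], in grid units before stride scaling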

yolo11s_int_qdq.onnx (new binary file)

Binary file not shown.