# Yolo-standalone/qat_utils.py
import os
import copy
import logging
import torch
import torch.nn as nn
import numpy as np
import onnx
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional
from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.onnx import register_custom_op_symbolic
# ONNX Runtime Quantization
from onnxruntime.quantization import (
quantize_static,
CalibrationDataReader,
QuantType,
QuantFormat,
)
LOGGER = logging.getLogger("QAT_Utils")
@dataclass
class QATConfig:
"""Unified Configuration for QAT and Export"""
# Training Config
ignore_layers: List[str] = field(default_factory=list)
# Export Config
img_size: int = 640
calibration_image_dir: Optional[str] = None
max_calibration_samples: int = 100
# Paths
save_dir: Path = Path('./runs/qat_unified')
def __post_init__(self):
self.save_dir = Path(self.save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
class YOLOCalibrationDataReader(CalibrationDataReader):
"""YOLO model calibration data reader for ONNX Runtime"""
def __init__(self, calibration_image_folder: str, input_name: str = 'images',
img_size: int = 640, max_samples: int = 100):
self.input_name = input_name
self.img_size = img_size
self.max_samples = max_samples
if not calibration_image_folder or not os.path.exists(calibration_image_folder):
raise ValueError(f"Calibration folder not found: {calibration_image_folder}")
# Collect image paths
self.image_paths = []
for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
self.image_paths.extend(Path(calibration_image_folder).glob(ext))
if not self.image_paths:
LOGGER.warning(f"No images found in {calibration_image_folder}")
        # Sort for a deterministic calibration order before truncating
        self.image_paths = sorted(self.image_paths)[:max_samples]
LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}")
self.current_idx = 0
def preprocess(self, img_path):
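        """Letterbox the image to ``img_size``, convert BGR->RGB and HWC->CHW,
        scale to [0, 1] and add a batch dimension. Returns None if unreadable."""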
import cv2
img = cv2.imread(str(img_path))
if img is None:
return None
# Resize with letterbox
h, w = img.shape[:2]
scale = min(self.img_size / h, self.img_size / w)
new_h, new_w = int(h * scale), int(w * scale)
img = cv2.resize(img, (new_w, new_h))
# Pad to target size
canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
pad_h = (self.img_size - new_h) // 2
pad_w = (self.img_size - new_w) // 2
canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = img
# Convert to model input format
img = canvas[:, :, ::-1] # BGR to RGB
img = img.transpose(2, 0, 1) # HWC to CHW
img = img.astype(np.float32) / 255.0
img = np.expand_dims(img, axis=0) # Add batch dimension
return img
def get_next(self):
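        """Return the next sample as ``{input_name: array}``, skipping
        unreadable images; returns None once all samples are consumed."""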
if self.current_idx >= len(self.image_paths):
return None
img_path = self.image_paths[self.current_idx]
self.current_idx += 1
img = self.preprocess(img_path)
if img is None:
return self.get_next()
return {self.input_name: img}
def rewind(self):
self.current_idx = 0
def get_rknn_qconfig() -> QConfig:
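    """Build an RKNN-style QConfig: per-tensor asymmetric uint8 fake-quant for
    activations and per-channel int8 fake-quant for weights."""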
return QConfig(
activation=FakeQuantize.with_args(
observer=MovingAverageMinMaxObserver,
quant_min=0,
quant_max=255,
reduce_range=False,
dtype=torch.quint8,
),
weight=FakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_channel_affine,
reduce_range=False,
)
)
def prepare_model_for_qat(
model: nn.Module,
example_inputs: torch.Tensor,
config: QATConfig
) -> nn.Module:
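    """Prepare ``model`` for QAT via FX graph mode.

    The RKNN-style QConfig is applied globally; any module named in
    ``config.ignore_layers`` keeps FP32 precision (mixed precision).
    """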
qconfig = get_rknn_qconfig()
qconfig_mapping = QConfigMapping().set_global(qconfig)
if config.ignore_layers:
LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}")
for layer_name in config.ignore_layers:
qconfig_mapping.set_module_name(layer_name, None)
model_to_quantize = copy.deepcopy(model)
model_to_quantize.eval()
try:
prepared_model = prepare_qat_fx(
model_to_quantize,
qconfig_mapping=qconfig_mapping,
example_inputs=(example_inputs,)
)
return prepared_model
except Exception as e:
LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}")
raise e
def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
"""
Loads weights from a QAT model (with fake quant layers) into a standard FP32 model.
FX Graph mode wraps Conv+BN modules, so keys need to be mapped:
- QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
- QAT: model.X.conv.weight -> FP32: model.X.conv.weight (direct match)
"""
import re
LOGGER.info("Loading QAT weights into FP32 model...")
fp32_model_dict = fp32_model.state_dict()
# Quantization-related keys to skip
quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
'_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
'fake_quant_enabled', 'observer_enabled']
new_state_dict = {}
keys_loaded = []
missing_fp32_keys = set(fp32_model_dict.keys())
for qat_key, qat_value in qat_state_dict.items():
# Skip quantization buffers
if any(kw in qat_key for kw in quant_keywords):
continue
# Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)
if fp32_key in fp32_model_dict:
if qat_value.shape == fp32_model_dict[fp32_key].shape:
new_state_dict[fp32_key] = qat_value
keys_loaded.append(fp32_key)
missing_fp32_keys.discard(fp32_key)
else:
LOGGER.warning(f"Shape mismatch for {fp32_key}: QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
elif qat_key in fp32_model_dict:
# Fallback: try direct match
if qat_value.shape == fp32_model_dict[qat_key].shape:
new_state_dict[qat_key] = qat_value
keys_loaded.append(qat_key)
missing_fp32_keys.discard(qat_key)
load_result = fp32_model.load_state_dict(new_state_dict, strict=False)
LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
if missing_fp32_keys:
LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}")
return fp32_model
def export_knn_compatible_onnx(
config: QATConfig,
model_class,
best_weights_path: str,
nc: int,
scale: str
):
"""
Orchestrates the full export pipeline:
1. Load pure FP32 model.
2. Load QAT-trained weights (best.pt) into FP32 model.
3. Export FP32 ONNX.
4. Run ORT Quantization (PTQ) to produce QDQ ONNX.
"""
fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'
LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...")
# 1. Instantiate clean FP32 model
fp32_model = model_class(nc=nc, scale=scale)
# 2. Load QAT weights
ckpt = torch.load(best_weights_path, map_location='cpu')
    state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
    # Some checkpoints store a full nn.Module under 'model'; unwrap it to a state_dict
    if isinstance(state_dict, nn.Module):
        state_dict = state_dict.state_dict()
fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
fp32_model.eval()
# 3. Export FP32 ONNX
dummy_input = torch.randn(1, 3, config.img_size, config.img_size)
try:
torch.onnx.export(
fp32_model,
dummy_input,
str(fp32_onnx_path),
opset_version=13,
input_names=['images'],
output_names=['output'],
do_constant_folding=True,
dynamo=False
)
LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")
# 3.5 Pre-process (Shape Inference & Optimization) to suppress warnings and improve quantization
from onnxruntime.quantization import quant_pre_process
LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...")
preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
quant_pre_process(
input_model_path=str(fp32_onnx_path),
output_model_path=str(preprocessed_path),
skip_symbolic_shape=False
)
LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")
# Update input path for quantization to use the preprocessed model
model_input_path = str(preprocessed_path)
except Exception as e:
LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
return
LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...")
if not config.calibration_image_dir:
LOGGER.error("No calibration image directory provided. Skipping PTQ.")
return
calibration_reader = YOLOCalibrationDataReader(
calibration_image_folder=config.calibration_image_dir,
max_samples=config.max_calibration_samples,
img_size=config.img_size
)
try:
quantize_static(
model_input=model_input_path, # Use the pre-processed model
model_output=str(qdq_onnx_path),
calibration_data_reader=calibration_reader,
quant_format=QuantFormat.QDQ,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QInt8,
per_channel=True,
reduce_range=False,
extra_options={
'ActivationSymmetric': False,
'WeightSymmetric': True
}
)
LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
LOGGER.info(">>> Export Workflow Complete.")
except Exception as e:
LOGGER.error(f"PTQ Failed: {e}")
def _aten_copy_symbolic(g, self, src, non_blocking):
return g.op("Identity", src)
register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)
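# --- Example workflow (sketch) -------------------------------------------------
# A minimal sketch of how these utilities fit together, assuming a YOLO-style
# `DetectionModel` class and a training loop that saves `best.pt`; the class
# name, layer names and paths below are hypothetical placeholders.
#
#   cfg = QATConfig(
#       ignore_layers=["model.22"],                    # hypothetical detect head kept in FP32
#       calibration_image_dir="datasets/coco/calib",   # hypothetical calibration folder
#   )
#   dummy = torch.randn(1, 3, cfg.img_size, cfg.img_size)
#   qat_model = prepare_model_for_qat(DetectionModel(nc=80, scale="n"), dummy, cfg)
#   # ... fine-tune qat_model, then save its state_dict() as best.pt ...
#   export_knn_compatible_onnx(cfg, DetectionModel, "best.pt", nc=80, scale="n")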