import os
import copy
import logging
import torch
import torch.nn as nn
import numpy as np
import onnx
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional

from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.onnx import register_custom_op_symbolic

# ONNX Runtime Quantization
from onnxruntime.quantization import (
    quantize_static,
    CalibrationDataReader,
    QuantType,
    QuantFormat,
)

LOGGER = logging.getLogger("QAT_Utils")


@dataclass
class QATConfig:
    """Unified Configuration for QAT and Export"""

    # Training Config
    ignore_layers: List[str] = field(default_factory=list)

    # Export Config
    img_size: int = 640
    calibration_image_dir: Optional[str] = None
    max_calibration_samples: int = 100

    # Paths
    save_dir: Path = Path('./runs/qat_unified')

    def __post_init__(self):
        self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
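
# Example: building a config for mixed-precision QAT (illustrative sketch; the
# layer names and paths below are placeholders, not values required by this module):
#
#   cfg = QATConfig(
#       ignore_layers=["model.22", "model.23"],   # keep these layers in FP32
#       img_size=640,
#       calibration_image_dir="datasets/calib_images",
#       max_calibration_samples=100,
#       save_dir="./runs/qat_unified",
#   )
#   # __post_init__ converts save_dir to a Path and creates the directory.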


class YOLOCalibrationDataReader(CalibrationDataReader):
    """YOLO model calibration data reader for ONNX Runtime"""

    def __init__(self, calibration_image_folder: str, input_name: str = 'images',
                 img_size: int = 640, max_samples: int = 100):
        self.input_name = input_name
        self.img_size = img_size
        self.max_samples = max_samples

        if not calibration_image_folder or not os.path.exists(calibration_image_folder):
            raise ValueError(f"Calibration folder not found: {calibration_image_folder}")

        # Collect image paths
        self.image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            self.image_paths.extend(Path(calibration_image_folder).glob(ext))

        if not self.image_paths:
            LOGGER.warning(f"No images found in {calibration_image_folder}")

        self.image_paths = self.image_paths[:max_samples]
        LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}")

        self.current_idx = 0

    def preprocess(self, img_path):
        import cv2

        img = cv2.imread(str(img_path))
        if img is None:
            return None

        # Resize with letterbox
        h, w = img.shape[:2]
        scale = min(self.img_size / h, self.img_size / w)
        new_h, new_w = int(h * scale), int(w * scale)

        img = cv2.resize(img, (new_w, new_h))

        # Pad to target size
        canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        pad_h = (self.img_size - new_h) // 2
        pad_w = (self.img_size - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = img

        # Convert to model input format
        img = canvas[:, :, ::-1]           # BGR to RGB
        img = img.transpose(2, 0, 1)       # HWC to CHW
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)  # Add batch dimension

        return img

    def get_next(self):
        if self.current_idx >= len(self.image_paths):
            return None

        img_path = self.image_paths[self.current_idx]
        self.current_idx += 1

        img = self.preprocess(img_path)
        if img is None:
            return self.get_next()

        return {self.input_name: img}

    def rewind(self):
        self.current_idx = 0
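
# Quick sanity check of the reader outside of quantization (illustrative sketch;
# the folder path is a placeholder):
#
#   reader = YOLOCalibrationDataReader("datasets/calib_images", img_size=640, max_samples=8)
#   batch = reader.get_next()
#   while batch is not None:
#       print(batch["images"].shape)   # expected: (1, 3, 640, 640), float32 in [0, 1]
#       batch = reader.get_next()
#   reader.rewind()                    # ORT may call rewind() between calibration passes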


def get_rknn_qconfig() -> QConfig:
    """FX-mode QConfig: asymmetric uint8 activations, per-channel int8 weights."""
    return QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            reduce_range=False,
            dtype=torch.quint8,
        ),
        weight=FakeQuantize.with_args(
            observer=MovingAveragePerChannelMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
        )
    )
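
# With the observers above, each tensor x is fake-quantized with the standard
# PyTorch affine fake-quant mapping (shown here for reference):
#
#   x_q  = clamp(round(x / scale) + zero_point, quant_min, quant_max)
#   x_fq = (x_q - zero_point) * scale
#
# Activations use quant_min=0..quant_max=255 (uint8); weights use -128..127 (int8)
# with a separate (scale, zero_point) pair per output channel.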


def prepare_model_for_qat(
    model: nn.Module,
    example_inputs: torch.Tensor,
    config: QATConfig
) -> nn.Module:
    qconfig = get_rknn_qconfig()
    qconfig_mapping = QConfigMapping().set_global(qconfig)

    if config.ignore_layers:
        LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}")
        for layer_name in config.ignore_layers:
            qconfig_mapping.set_module_name(layer_name, None)

    model_to_quantize = copy.deepcopy(model)
    # prepare_qat_fx requires the model to be in train mode (it asserts
    # model.training), so switch to train() before tracing.
    model_to_quantize.train()

    try:
        prepared_model = prepare_qat_fx(
            model_to_quantize,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(example_inputs,)
        )
        return prepared_model
    except Exception as e:
        LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}")
        raise e
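
# Typical QAT flow with the helper above (illustrative sketch; `build_model`,
# `train_loader`, and `compute_loss` are placeholders for the caller's own code):
#
#   cfg = QATConfig(ignore_layers=["model.22"], calibration_image_dir="datasets/calib_images")
#   fp32 = build_model()
#   qat_model = prepare_model_for_qat(fp32, torch.randn(1, 3, cfg.img_size, cfg.img_size), cfg)
#   optimizer = torch.optim.SGD(qat_model.parameters(), lr=1e-4, momentum=0.9)
#   qat_model.train()
#   for imgs, targets in train_loader:       # short fine-tune with fake-quant inserted
#       loss = compute_loss(qat_model(imgs), targets)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#   torch.save({"model": qat_model.state_dict()}, cfg.save_dir / "best.pt")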


def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
    """
    Loads weights from a QAT model (with fake quant layers) into a standard FP32 model.

    FX Graph mode wraps Conv+BN modules, so keys need to be mapped:
    - QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
    - QAT: model.X.conv.weight    -> FP32: model.X.conv.weight (direct match)
    """
    import re

    LOGGER.info("Loading QAT weights into FP32 model...")
    fp32_model_dict = fp32_model.state_dict()

    # Quantization-related keys to skip
    quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
                      '_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
                      'fake_quant_enabled', 'observer_enabled']

    new_state_dict = {}
    keys_loaded = []
    missing_fp32_keys = set(fp32_model_dict.keys())

    for qat_key, qat_value in qat_state_dict.items():
        # Skip quantization buffers
        if any(kw in qat_key for kw in quant_keywords):
            continue

        # Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
        fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)

        if fp32_key in fp32_model_dict:
            if qat_value.shape == fp32_model_dict[fp32_key].shape:
                new_state_dict[fp32_key] = qat_value
                keys_loaded.append(fp32_key)
                missing_fp32_keys.discard(fp32_key)
            else:
                LOGGER.warning(f"Shape mismatch for {fp32_key}: QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
        elif qat_key in fp32_model_dict:
            # Fallback: try direct match
            if qat_value.shape == fp32_model_dict[qat_key].shape:
                new_state_dict[qat_key] = qat_value
                keys_loaded.append(qat_key)
                missing_fp32_keys.discard(qat_key)

    load_result = fp32_model.load_state_dict(new_state_dict, strict=False)

    LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
    if missing_fp32_keys:
        LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}")

    return fp32_model


def export_knn_compatible_onnx(
    config: QATConfig,
    model_class,
    best_weights_path: str,
    nc: int,
    scale: str
):
    """
    Orchestrates the full export pipeline:
    1. Load pure FP32 model.
    2. Load QAT-trained weights (best.pt) into FP32 model.
    3. Export FP32 ONNX.
    4. Run ORT Quantization (PTQ) to produce QDQ ONNX.
    """

    fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
    qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'

    LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...")

    # 1. Instantiate clean FP32 model
    fp32_model = model_class(nc=nc, scale=scale)

    # 2. Load QAT weights
    ckpt = torch.load(best_weights_path, map_location='cpu')
    if 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt

    fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
    fp32_model.eval()

    # 3. Export FP32 ONNX
    dummy_input = torch.randn(1, 3, config.img_size, config.img_size)

    try:
        torch.onnx.export(
            fp32_model,
            dummy_input,
            str(fp32_onnx_path),
            opset_version=13,
            input_names=['images'],
            output_names=['output'],
            do_constant_folding=True,
            dynamo=False
        )
        LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")

        # 3.5 Pre-process (Shape Inference & Optimization) to suppress warnings and improve quantization
        from onnxruntime.quantization import quant_pre_process

        LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...")
        preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
        quant_pre_process(
            input_model_path=str(fp32_onnx_path),
            output_model_path=str(preprocessed_path),
            skip_symbolic_shape=False
        )
        LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")

        # Update input path for quantization to use the preprocessed model
        model_input_path = str(preprocessed_path)

    except Exception as e:
        LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
        return

    LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...")

    if not config.calibration_image_dir:
        LOGGER.error("No calibration image directory provided. Skipping PTQ.")
        return

    calibration_reader = YOLOCalibrationDataReader(
        calibration_image_folder=config.calibration_image_dir,
        max_samples=config.max_calibration_samples,
        img_size=config.img_size
    )

    try:
        quantize_static(
            model_input=model_input_path,  # Use the pre-processed model
            model_output=str(qdq_onnx_path),
            calibration_data_reader=calibration_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            extra_options={
                'ActivationSymmetric': False,
                'WeightSymmetric': True
            }
        )
        LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
        LOGGER.info(">>> Export Workflow Complete.")
    except Exception as e:
        LOGGER.error(f"PTQ Failed: {e}")


# ONNX export helper: map aten::copy to an Identity node at opset 13 so copied
# tensors pass through unchanged during torch.onnx.export.
def _aten_copy_symbolic(g, self, src, non_blocking):
    return g.op("Identity", src)


register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)