import os
import copy
import logging
import torch
import torch.nn as nn
import numpy as np
import onnx
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional

from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.onnx import register_custom_op_symbolic

# ONNX Runtime Quantization
from onnxruntime.quantization import (
    quantize_static,
    CalibrationDataReader,
    QuantType,
    QuantFormat,
)

LOGGER = logging.getLogger("QAT_Utils")


@dataclass
class QATConfig:
    """Unified Configuration for QAT and Export"""

    # Training Config
    ignore_layers: List[str] = field(default_factory=list)

    # Export Config
    img_size: int = 640
    calibration_image_dir: Optional[str] = None
    max_calibration_samples: int = 100

    # Paths
    save_dir: Path = Path('./runs/qat_unified')

    def __post_init__(self):
        self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)


class YOLOCalibrationDataReader(CalibrationDataReader):
    """YOLO model calibration data reader for ONNX Runtime"""

    def __init__(self, calibration_image_folder: str, input_name: str = 'images',
                 img_size: int = 640, max_samples: int = 100):
        self.input_name = input_name
        self.img_size = img_size
        self.max_samples = max_samples

        if not calibration_image_folder or not os.path.exists(calibration_image_folder):
            raise ValueError(f"Calibration folder not found: {calibration_image_folder}")

        # Collect image paths
        self.image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            self.image_paths.extend(Path(calibration_image_folder).glob(ext))

        if not self.image_paths:
            LOGGER.warning(f"No images found in {calibration_image_folder}")

        self.image_paths = self.image_paths[:max_samples]
        LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}")
        self.current_idx = 0

    def preprocess(self, img_path):
        import cv2
        img = cv2.imread(str(img_path))
        if img is None:
            return None

        # Resize with letterbox
        h, w = img.shape[:2]
        scale = min(self.img_size / h, self.img_size / w)
        new_h, new_w = int(h * scale), int(w * scale)
        img = cv2.resize(img, (new_w, new_h))

        # Pad to target size
        canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        pad_h = (self.img_size - new_h) // 2
        pad_w = (self.img_size - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = img

        # Convert to model input format
        img = canvas[:, :, ::-1]           # BGR to RGB
        img = img.transpose(2, 0, 1)       # HWC to CHW
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)  # Add batch dimension
        return img

    def get_next(self):
        if self.current_idx >= len(self.image_paths):
            return None
        img_path = self.image_paths[self.current_idx]
        self.current_idx += 1
        img = self.preprocess(img_path)
        if img is None:
            return self.get_next()
        return {self.input_name: img}

    def rewind(self):
        self.current_idx = 0


def get_rknn_qconfig() -> QConfig:
    return QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            reduce_range=False,
            dtype=torch.quint8,
        ),
        weight=FakeQuantize.with_args(
            observer=MovingAveragePerChannelMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
        ),
    )
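
# ---------------------------------------------------------------------------
# Illustrative helper (not part of the original pipeline): a minimal sketch
# that instantiates the fake-quant modules produced by get_rknn_qconfig() and
# logs the ranges/dtypes they emulate, i.e. asymmetric uint8 activations and
# per-channel int8 weights. Useful as a quick sanity check before training.
# ---------------------------------------------------------------------------
def _describe_rknn_qconfig() -> None:
    qconfig = get_rknn_qconfig()
    act_fq = qconfig.activation()  # FakeQuantize instance for activations
    wt_fq = qconfig.weight()       # FakeQuantize instance for weights
    LOGGER.info("Activation fake-quant: dtype=%s, range=[%d, %d]",
                act_fq.dtype, act_fq.quant_min, act_fq.quant_max)
    LOGGER.info("Weight fake-quant: dtype=%s, range=[%d, %d], qscheme=%s",
                wt_fq.dtype, wt_fq.quant_min, wt_fq.quant_max, wt_fq.qscheme)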
def prepare_model_for_qat(
    model: nn.Module,
    example_inputs: torch.Tensor,
    config: QATConfig,
) -> nn.Module:
    qconfig = get_rknn_qconfig()
    qconfig_mapping = QConfigMapping().set_global(qconfig)

    if config.ignore_layers:
        LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}")
        for layer_name in config.ignore_layers:
            qconfig_mapping.set_module_name(layer_name, None)

    model_to_quantize = copy.deepcopy(model)
    # prepare_qat_fx requires the model to be in train mode
    model_to_quantize.train()

    try:
        prepared_model = prepare_qat_fx(
            model_to_quantize,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(example_inputs,),
        )
        return prepared_model
    except Exception as e:
        LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}")
        raise e


def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
    """
    Loads weights from a QAT model (with fake-quant layers) into a standard FP32 model.

    FX graph mode fuses Conv+BN modules, so keys need to be mapped:
    - QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
    - QAT: model.X.conv.weight    -> FP32: model.X.conv.weight (direct match)
    """
    import re

    LOGGER.info("Loading QAT weights into FP32 model...")
    fp32_model_dict = fp32_model.state_dict()

    # Quantization-related keys to skip
    quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
                      '_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
                      'fake_quant_enabled', 'observer_enabled']

    new_state_dict = {}
    keys_loaded = []
    missing_fp32_keys = set(fp32_model_dict.keys())

    for qat_key, qat_value in qat_state_dict.items():
        # Skip quantization buffers
        if any(kw in qat_key for kw in quant_keywords):
            continue

        # Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
        fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)

        if fp32_key in fp32_model_dict:
            if qat_value.shape == fp32_model_dict[fp32_key].shape:
                new_state_dict[fp32_key] = qat_value
                keys_loaded.append(fp32_key)
                missing_fp32_keys.discard(fp32_key)
            else:
                LOGGER.warning(f"Shape mismatch for {fp32_key}: "
                               f"QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
        elif qat_key in fp32_model_dict:
            # Fallback: try a direct match
            if qat_value.shape == fp32_model_dict[qat_key].shape:
                new_state_dict[qat_key] = qat_value
                keys_loaded.append(qat_key)
                missing_fp32_keys.discard(qat_key)

    fp32_model.load_state_dict(new_state_dict, strict=False)
    LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
    if missing_fp32_keys:
        LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}")

    return fp32_model
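
# ---------------------------------------------------------------------------
# Illustrative helper (not called by the pipeline): a minimal sketch that
# spot-checks the key remapping performed by load_qat_weights_to_fp32 by
# comparing a few transferred convolution/BN weight tensors against the
# checkpoint. The `.float()` casts are only there to make allclose tolerant
# of half-precision checkpoints.
# ---------------------------------------------------------------------------
def _spot_check_weight_transfer(fp32_model: nn.Module, qat_state_dict: dict, n: int = 3) -> None:
    import re
    fp32_dict = fp32_model.state_dict()
    checked = 0
    for qat_key, qat_value in qat_state_dict.items():
        # Only look at plain weight tensors, not fake-quant buffers
        if not qat_key.endswith('.weight') or 'fake_quant' in qat_key:
            continue
        fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)
        if fp32_key in fp32_dict and qat_value.shape == fp32_dict[fp32_key].shape:
            same = torch.allclose(qat_value.float(), fp32_dict[fp32_key].float())
            LOGGER.info(f"{fp32_key}: transferred={'yes' if same else 'NO'}")
            checked += 1
            if checked >= n:
                break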
def export_rknn_compatible_onnx(
    config: QATConfig,
    model_class,
    best_weights_path: str,
    nc: int,
    scale: str,
):
    """
    Orchestrates the full export pipeline:
    1. Load a pure FP32 model.
    2. Load the QAT-trained weights (best.pt) into the FP32 model.
    3. Export FP32 ONNX.
    4. Run ORT quantization (PTQ) to produce a QDQ ONNX.
    """
    fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
    qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'

    LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...")

    # 1. Instantiate a clean FP32 model
    fp32_model = model_class(nc=nc, scale=scale)

    # 2. Load QAT weights
    ckpt = torch.load(best_weights_path, map_location='cpu')
    state_dict = ckpt['model'] if 'model' in ckpt else ckpt
    fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
    fp32_model.eval()

    # 3. Export FP32 ONNX
    dummy_input = torch.randn(1, 3, config.img_size, config.img_size)
    try:
        torch.onnx.export(
            fp32_model,
            dummy_input,
            str(fp32_onnx_path),
            opset_version=13,
            input_names=['images'],
            output_names=['output'],
            do_constant_folding=True,
            dynamo=False,
        )
        LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")

        # 3.5 Pre-process (shape inference & optimization) to suppress warnings
        #     and improve quantization quality
        from onnxruntime.quantization import quant_pre_process

        LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...")
        preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
        quant_pre_process(
            input_model_path=str(fp32_onnx_path),
            output_model_path=str(preprocessed_path),
            skip_symbolic_shape=False,
        )
        LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")

        # Quantize the pre-processed model rather than the raw export
        model_input_path = str(preprocessed_path)
    except Exception as e:
        LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
        return

    LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...")

    if not config.calibration_image_dir:
        LOGGER.error("No calibration image directory provided. Skipping PTQ.")
        return

    calibration_reader = YOLOCalibrationDataReader(
        calibration_image_folder=config.calibration_image_dir,
        max_samples=config.max_calibration_samples,
        img_size=config.img_size,
    )

    try:
        quantize_static(
            model_input=model_input_path,  # use the pre-processed model
            model_output=str(qdq_onnx_path),
            calibration_data_reader=calibration_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            extra_options={
                'ActivationSymmetric': False,
                'WeightSymmetric': True,
            },
        )
        LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
        LOGGER.info(">>> Export Workflow Complete.")
    except Exception as e:
        LOGGER.error(f"PTQ Failed: {e}")


def _aten_copy_symbolic(g, self, src, non_blocking):
    # Export aten::copy as a plain Identity so traced in-place copies do not
    # break ONNX export.
    return g.op("Identity", src)


register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)
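

# ---------------------------------------------------------------------------
# Example end-to-end usage: a minimal sketch, not executed on import.
# `my_project.models.YOLOModel`, the ignore layer "model.22", and the dataset
# path below are placeholders; substitute the model class and paths used by
# your project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from my_project.models import YOLOModel  # hypothetical import path

    config = QATConfig(
        ignore_layers=["model.22"],                   # example: keep the detect head in FP32
        img_size=640,
        calibration_image_dir="datasets/coco/calib",  # placeholder path
        max_calibration_samples=100,
    )

    # 1. Prepare an FP32 model for QAT (FX inserts fake-quant modules).
    model = YOLOModel(nc=80, scale="n")
    example_inputs = torch.randn(1, 3, config.img_size, config.img_size)
    qat_model = prepare_model_for_qat(model, example_inputs, config)

    # 2. Fine-tune `qat_model` with the usual training loop and save the
    #    checkpoint (e.g. best.pt); omitted here.

    # 3. Export: reload the QAT weights into a clean FP32 model, dump ONNX,
    #    then run ONNX Runtime static quantization on top of it.
    export_rknn_compatible_onnx(
        config=config,
        model_class=YOLOModel,
        best_weights_path=str(config.save_dir / "best.pt"),
        nc=80,
        scale="n",
    )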