Add QAT quantization support
qat_utils.py (new file, 324 lines)
@@ -0,0 +1,324 @@
import os
import re
import copy
import logging
import torch
import torch.nn as nn
import numpy as np
import onnx
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Optional

from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.onnx import register_custom_op_symbolic

# ONNX Runtime quantization
from onnxruntime.quantization import (
    quantize_static,
    CalibrationDataReader,
    QuantType,
    QuantFormat,
)

LOGGER = logging.getLogger("QAT_Utils")

@dataclass
class QATConfig:
    """Unified configuration for QAT training and ONNX export."""

    # Training config
    ignore_layers: List[str] = field(default_factory=list)

    # Export config
    img_size: int = 640
    calibration_image_dir: Optional[str] = None
    max_calibration_samples: int = 100

    # Paths
    save_dir: Path = Path('./runs/qat_unified')

    def __post_init__(self):
        self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

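# Illustrative usage (the paths and layer name are placeholders, not values
# shipped with this module):
#
#   config = QATConfig(
#       ignore_layers=['model.22'],                        # e.g. keep a detect head in FP32
#       calibration_image_dir='datasets/calib_images',
#       save_dir='runs/qat_exp1',
#   )
#   # config.save_dir is now a Path and the directory has been created.
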
class YOLOCalibrationDataReader(CalibrationDataReader):
    """YOLO-style calibration data reader for ONNX Runtime static quantization."""

    def __init__(self, calibration_image_folder: str, input_name: str = 'images',
                 img_size: int = 640, max_samples: int = 100):
        self.input_name = input_name
        self.img_size = img_size
        self.max_samples = max_samples

        if not calibration_image_folder or not os.path.exists(calibration_image_folder):
            raise ValueError(f"Calibration folder not found: {calibration_image_folder}")

        # Collect image paths
        self.image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            self.image_paths.extend(Path(calibration_image_folder).glob(ext))

        if not self.image_paths:
            LOGGER.warning(f"No images found in {calibration_image_folder}")

        LOGGER.info(f"Found {len(self.image_paths)} calibration images in "
                    f"{calibration_image_folder}, using up to {max_samples}")
        self.image_paths = self.image_paths[:max_samples]

        self.current_idx = 0

    def preprocess(self, img_path):
        import cv2  # lazy import: cv2 is only required for calibration

        img = cv2.imread(str(img_path))
        if img is None:
            return None

        # Resize with letterbox (preserve aspect ratio)
        h, w = img.shape[:2]
        scale = min(self.img_size / h, self.img_size / w)
        new_h, new_w = int(h * scale), int(w * scale)

        img = cv2.resize(img, (new_w, new_h))

        # Pad to target size
        canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8)
        pad_h = (self.img_size - new_h) // 2
        pad_w = (self.img_size - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = img

        # Convert to model input format
        img = canvas[:, :, ::-1]           # BGR to RGB
        img = img.transpose(2, 0, 1)       # HWC to CHW
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)  # add batch dimension

        return img

    def get_next(self):
        # Iterate instead of recursing so a long run of unreadable images
        # cannot exhaust the recursion limit.
        while self.current_idx < len(self.image_paths):
            img_path = self.image_paths[self.current_idx]
            self.current_idx += 1
            img = self.preprocess(img_path)
            if img is not None:
                return {self.input_name: img}
        return None

    def rewind(self):
        self.current_idx = 0

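# Standalone sketch of the reader (the folder path is a placeholder):
#
#   reader = YOLOCalibrationDataReader('datasets/calib_images', img_size=640)
#   batch = reader.get_next()   # {'images': float32 array of shape (1, 3, 640, 640)}
#   reader.rewind()             # reset before handing the reader to quantize_static
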
def get_rknn_qconfig() -> QConfig:
    """QAT qconfig aligned with the RKNN-style scheme: asymmetric uint8
    activations, per-channel int8 weights."""
    return QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            reduce_range=False,
            dtype=torch.quint8,
        ),
        weight=FakeQuantize.with_args(
            observer=MovingAveragePerChannelMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
        ),
    )


def prepare_model_for_qat(
    model: nn.Module,
    example_inputs: torch.Tensor,
    config: QATConfig,
) -> nn.Module:
    """Insert fake-quant modules into a deep copy of `model` via FX graph mode."""
    qconfig = get_rknn_qconfig()
    qconfig_mapping = QConfigMapping().set_global(qconfig)

    if config.ignore_layers:
        LOGGER.info(f"Mixed precision: ignoring layers: {config.ignore_layers}")
        for layer_name in config.ignore_layers:
            qconfig_mapping.set_module_name(layer_name, None)

    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.train()  # prepare_qat_fx requires the model in train mode

    try:
        prepared_model = prepare_qat_fx(
            model_to_quantize,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(example_inputs,),
        )
        return prepared_model
    except Exception as e:
        LOGGER.error(f"FX preparation failed. Ensure the model is symbolically traceable. Error: {e}")
        raise

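# End-to-end QAT sketch (illustrative; `MyDetector` and the training loop are
# placeholders, not part of this module):
#
#   config = QATConfig(ignore_layers=['model.22'])
#   fp32_model = MyDetector(nc=80, scale='n')
#   example = torch.randn(1, 3, config.img_size, config.img_size)
#   qat_model = prepare_model_for_qat(fp32_model, example, config)
#   # ...fine-tune qat_model with the usual loop (low LR, a few epochs)...
#   torch.save({'model': qat_model.state_dict()}, config.save_dir / 'best.pt')
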
def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module:
    """
    Load weights from a QAT model (with fake-quant layers) into a standard FP32 model.

    FX graph mode wraps Conv+BN modules, so keys need to be mapped:
    - QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight
    - QAT: model.X.conv.weight    -> FP32: model.X.conv.weight (direct match)
    """
    LOGGER.info("Loading QAT weights into FP32 model...")
    fp32_model_dict = fp32_model.state_dict()

    # Quantization-related keys to skip
    quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant',
                      '_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max',
                      'fake_quant_enabled', 'observer_enabled']

    new_state_dict = {}
    keys_loaded = []
    missing_fp32_keys = set(fp32_model_dict.keys())

    for qat_key, qat_value in qat_state_dict.items():
        # Skip quantization buffers
        if any(kw in qat_key for kw in quant_keywords):
            continue

        # Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.'
        fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key)

        if fp32_key in fp32_model_dict:
            if qat_value.shape == fp32_model_dict[fp32_key].shape:
                new_state_dict[fp32_key] = qat_value
                keys_loaded.append(fp32_key)
                missing_fp32_keys.discard(fp32_key)
            else:
                LOGGER.warning(f"Shape mismatch for {fp32_key}: "
                               f"QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}")
        elif qat_key in fp32_model_dict:
            # Fallback: try a direct match
            if qat_value.shape == fp32_model_dict[qat_key].shape:
                new_state_dict[qat_key] = qat_value
                keys_loaded.append(qat_key)
                missing_fp32_keys.discard(qat_key)

    fp32_model.load_state_dict(new_state_dict, strict=False)

    LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.")
    if missing_fp32_keys:
        LOGGER.warning(f"{len(missing_fp32_keys)} FP32 keys had no match in the QAT checkpoint")

    return fp32_model


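# Key-mapping example (illustrative):
#
#   re.sub(r'\.conv\.bn\.', '.bn.', 'model.3.conv.bn.running_mean')
#   # -> 'model.3.bn.running_mean'
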
def export_rknn_compatible_onnx(
    config: QATConfig,
    model_class,
    best_weights_path: str,
    nc: int,
    scale: str,
):
    """
    Orchestrate the full export pipeline:
    1. Load a pure FP32 model.
    2. Load the QAT-trained weights (best.pt) into the FP32 model.
    3. Export FP32 ONNX.
    4. Run ORT quantization (PTQ) to produce a QDQ ONNX model.
    """
    fp32_onnx_path = config.save_dir / 'model_fp32.onnx'
    qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx'

    LOGGER.info(">>> STEP 1: Exporting clean FP32 ONNX with QAT weights...")

    # 1. Instantiate a clean FP32 model
    fp32_model = model_class(nc=nc, scale=scale)

    # 2. Load QAT weights
    ckpt = torch.load(best_weights_path, map_location='cpu')
    state_dict = ckpt['model'] if 'model' in ckpt else ckpt

    fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict)
    fp32_model.eval()

    # 3. Export FP32 ONNX
    dummy_input = torch.randn(1, 3, config.img_size, config.img_size)

    try:
        torch.onnx.export(
            fp32_model,
            dummy_input,
            str(fp32_onnx_path),
            opset_version=13,
            input_names=['images'],
            output_names=['output'],
            do_constant_folding=True,
            dynamo=False,
        )
        LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}")

        # 3.5 Pre-process (shape inference & optimization) to suppress warnings
        # and improve quantization quality
        from onnxruntime.quantization import quant_pre_process

        LOGGER.info(">>> STEP 1.5: Running ONNX Runtime pre-processing...")
        preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx'
        quant_pre_process(
            input_model_path=str(fp32_onnx_path),
            output_model_path=str(preprocessed_path),
            skip_symbolic_shape=False,
        )
        LOGGER.info(f"Pre-processed model saved to {preprocessed_path}")

        # Quantize the pre-processed model, not the raw export
        model_input_path = str(preprocessed_path)

    except Exception as e:
        LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}")
        return

    LOGGER.info(">>> STEP 2: Running ONNX Runtime quantization (PTQ)...")

    if not config.calibration_image_dir:
        LOGGER.error("No calibration image directory provided. Skipping PTQ.")
        return

    calibration_reader = YOLOCalibrationDataReader(
        calibration_image_folder=config.calibration_image_dir,
        max_samples=config.max_calibration_samples,
        img_size=config.img_size,
    )

    try:
        quantize_static(
            model_input=model_input_path,  # use the pre-processed model
            model_output=str(qdq_onnx_path),
            calibration_data_reader=calibration_reader,
            quant_format=QuantFormat.QDQ,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            extra_options={
                'ActivationSymmetric': False,
                'WeightSymmetric': True,
            },
        )
        LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}")
        LOGGER.info(">>> Export workflow complete.")
    except Exception as e:
        LOGGER.error(f"PTQ failed: {e}")

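# Export usage sketch (illustrative; `MyDetector` and the paths are placeholders):
#
#   config = QATConfig(
#       calibration_image_dir='datasets/calib_images',
#       max_calibration_samples=100,
#   )
#   export_rknn_compatible_onnx(
#       config,
#       model_class=MyDetector,
#       best_weights_path=str(config.save_dir / 'best.pt'),
#       nc=80,
#       scale='n',
#   )
#   # Produces model_fp32.onnx, model_fp32_preprocessed.onnx and
#   # model_int8_qdq.onnx under config.save_dir.
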
def _aten_copy_symbolic(g, self, src, non_blocking):
    # Pass `src` straight through; sufficient for exported graphs where the
    # copy target is only consumed downstream.
    return g.op("Identity", src)


# Some traced graphs emit aten::copy (e.g. from slice assignment); register a
# pass-through symbolic so torch.onnx.export does not fail on it at opset 13.
register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13)