diff --git a/inference.py b/inference.py index 74bab2b..85cfd15 100644 --- a/inference.py +++ b/inference.py @@ -2,7 +2,7 @@ import torch import cv2 import numpy as np import torchvision -from yolo11_standalone import YOLO11, YOLOPostProcessor +from yolo11_standalone import YOLO11, YOLOPostProcessor, YOLOPostProcessorNumpy CLASSES = [ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", @@ -26,7 +26,7 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) - dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] dw, dh = dw / 2, dh / 2 if shape[::-1] != new_unpad: @@ -40,13 +40,13 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): def xywh2xyxy(x): y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x - y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y - y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x - y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y + y[..., 0] = x[..., 0] - x[..., 2] / 2 + y[..., 1] = x[..., 1] - x[..., 3] / 2 + y[..., 2] = x[..., 0] + x[..., 2] / 2 + y[..., 3] = x[..., 1] + x[..., 3] / 2 return y -def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): +def non_max_suppression(prediction, conf_thres=0.01, iou_thres=0.45, max_det=300): prediction = prediction.transpose(1, 2) bs = prediction.shape[0] @@ -75,12 +75,63 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300 return output +def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): + bs = prediction.shape[0] + output = [np.zeros((0, 6), dtype=np.float32)] * bs + + for xi, x in enumerate(prediction): + bbox_xywh = x[:, :4] + class_probs = x[:, 4:] + + class_ids = np.argmax(class_probs, axis=1) + confidences = np.max(class_probs, axis=1) + + mask = confidences > conf_thres + bbox_xywh = bbox_xywh[mask] + confidences = confidences[mask] + class_ids = class_ids[mask] + + if len(confidences) == 0: + continue + + bbox_tlwh = np.copy(bbox_xywh) + bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2 + bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2 + + indices = cv2.dnn.NMSBoxes( + bboxes=bbox_tlwh.tolist(), + scores=confidences.tolist(), + score_threshold=conf_thres, + nms_threshold=iou_thres + ) + + if len(indices) > 0: + indices = indices.flatten() + if len(indices) > max_det: + indices = indices[:max_det] + + final_boxes_xywh = bbox_xywh[indices] + final_boxes_xyxy = xywh2xyxy(final_boxes_xywh) + + final_scores = confidences[indices] + final_classes = class_ids[indices] + + out_tensor = np.concatenate([ + final_boxes_xyxy, + final_scores[:, None], + final_classes[:, None] + ], axis=1) + + output[xi] = out_tensor + + return output + def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") model = YOLO11(nc=80, scale='s') - model.load_weights("yolo11s.pth") + model.load_weights("my_yolo_result_qat/best_fp32_converted.pth") model.to(device) model.eval() post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False) @@ -104,20 +155,28 @@ def main(): with torch.no_grad(): pred = model(img_tensor) - pred = post_std(pred) - pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45) - - det = 
pred[0] + # pred = post_std(pred) + # pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45) + # det = pred[0] + preds_raw_numpy = [p.cpu().numpy() for p in pred] + post_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=False) + pred_numpy_decoded = post_numpy(preds_raw_numpy) + pred_results = non_max_suppression_numpy(pred_numpy_decoded, conf_thres=0.25, iou_thres=0.45) + det = pred_results[0] if len(det): det[:, [0, 2]] -= dw # x padding det[:, [1, 3]] -= dh # y padding det[:, :4] /= ratio - det[:, 0].clamp_(0, img0.shape[1]) - det[:, 1].clamp_(0, img0.shape[0]) - det[:, 2].clamp_(0, img0.shape[1]) - det[:, 3].clamp_(0, img0.shape[0]) + # det[:, 0].clamp_(0, img0.shape[1]) + # det[:, 1].clamp_(0, img0.shape[0]) + # det[:, 2].clamp_(0, img0.shape[1]) + # det[:, 3].clamp_(0, img0.shape[0]) + det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1]) + det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0]) + det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1]) + det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0]) print(f"检测到 {len(det)} 个目标") diff --git a/inference_yoloe.py b/inference_yoloe.py index 1315dff..178ff0b 100644 --- a/inference_yoloe.py +++ b/inference_yoloe.py @@ -4,7 +4,7 @@ import numpy as np import torchvision from pathlib import Path -from yolo11_standalone import YOLO11E, YOLOPostProcessor +from yolo11_standalone import YOLO11E, YOLOPostProcessor, YOLOPostProcessorNumpy from mobile_clip_standalone import MobileCLIP DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -58,6 +58,56 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300) output[xi] = x[i] return output +def non_max_suppression_numpy(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): + bs = prediction.shape[0] + output = [np.zeros((0, 6), dtype=np.float32)] * bs + + for xi, x in enumerate(prediction): + bbox_xywh = x[:, :4] + class_probs = x[:, 4:] + + class_ids = np.argmax(class_probs, axis=1) + confidences = np.max(class_probs, axis=1) + + mask = confidences > conf_thres + bbox_xywh = bbox_xywh[mask] + confidences = confidences[mask] + class_ids = class_ids[mask] + + if len(confidences) == 0: + continue + + bbox_tlwh = np.copy(bbox_xywh) + bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2 + bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2 + + indices = cv2.dnn.NMSBoxes( + bboxes=bbox_tlwh.tolist(), + scores=confidences.tolist(), + score_threshold=conf_thres, + nms_threshold=iou_thres + ) + + if len(indices) > 0: + indices = indices.flatten() + if len(indices) > max_det: + indices = indices[:max_det] + + final_boxes_xywh = bbox_xywh[indices] + final_boxes_xyxy = xywh2xyxy(final_boxes_xywh) + final_scores = confidences[indices] + final_classes = class_ids[indices] + + out_tensor = np.concatenate([ + final_boxes_xyxy, + final_scores[:, None], + final_classes[:, None] + ], axis=1) + + output[xi] = out_tensor + + return output + def main(): print(f"Using device: {DEVICE}") @@ -94,17 +144,30 @@ def main(): print("Running inference...") with torch.no_grad(): raw_outputs = yolo_model(img_tensor) - decoded_box, mc, p = post_processor(raw_outputs) + # decoded_box, mc, p = post_processor(raw_outputs) + # pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7) + # det = pred[0] + feat_maps, mc, p = raw_outputs + + feat_maps_numpy = [f.detach().cpu().numpy() for f in feat_maps] + mc_numpy = mc.detach().cpu().numpy() + p_numpy = p.detach().cpu().numpy() + raw_outputs_numpy = (feat_maps_numpy, mc_numpy, p_numpy) + 
post_processor_numpy = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=True) + decoded_box_numpy, mc_numpy_out, p_numpy_out = post_processor_numpy(raw_outputs_numpy) + pred_results = non_max_suppression_numpy(decoded_box_numpy, conf_thres=0.25, iou_thres=0.7) + det = pred_results[0] - pred = non_max_suppression(decoded_box, conf_thres=0.25, iou_thres=0.7) - - det = pred[0] if len(det): det[:, [0, 2]] -= dw det[:, [1, 3]] -= dh det[:, :4] /= ratio - det[:, [0, 2]].clamp_(0, img0.shape[1]) - det[:, [1, 3]].clamp_(0, img0.shape[0]) + # det[:, [0, 2]].clamp_(0, img0.shape[1]) + # det[:, [1, 3]].clamp_(0, img0.shape[0]) + det[:, 0] = np.clip(det[:, 0], 0, img0.shape[1]) + det[:, 1] = np.clip(det[:, 1], 0, img0.shape[0]) + det[:, 2] = np.clip(det[:, 2], 0, img0.shape[1]) + det[:, 3] = np.clip(det[:, 3], 0, img0.shape[0]) print(f"Detected {len(det)} objects:") for *xyxy, conf, cls in det: @@ -119,7 +182,7 @@ def main(): else: print("No objects detected.") - cv2.imwrite("result_separate.jpg", img0) + cv2.imwrite("result_yoloe.jpg", img0) print("Result saved.") if __name__ == "__main__": diff --git a/qat_utils.py b/qat_utils.py new file mode 100644 index 0000000..b975d9e --- /dev/null +++ b/qat_utils.py @@ -0,0 +1,324 @@ +import os +import copy +import logging +import torch +import torch.nn as nn +import numpy as np +import onnx +from pathlib import Path +from dataclasses import dataclass, field +from typing import List, Optional + +from torch.ao.quantization import QConfig +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.observer import ( + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, +) +from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.onnx import register_custom_op_symbolic + +# ONNX Runtime Quantization +from onnxruntime.quantization import ( + quantize_static, + CalibrationDataReader, + QuantType, + QuantFormat, +) + +LOGGER = logging.getLogger("QAT_Utils") + +@dataclass +class QATConfig: + """Unified Configuration for QAT and Export""" + # Training Config + ignore_layers: List[str] = field(default_factory=list) + + # Export Config + img_size: int = 640 + calibration_image_dir: Optional[str] = None + max_calibration_samples: int = 100 + + # Paths + save_dir: Path = Path('./runs/qat_unified') + + def __post_init__(self): + self.save_dir = Path(self.save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + +class YOLOCalibrationDataReader(CalibrationDataReader): + """YOLO model calibration data reader for ONNX Runtime""" + + def __init__(self, calibration_image_folder: str, input_name: str = 'images', + img_size: int = 640, max_samples: int = 100): + self.input_name = input_name + self.img_size = img_size + self.max_samples = max_samples + + if not calibration_image_folder or not os.path.exists(calibration_image_folder): + raise ValueError(f"Calibration folder not found: {calibration_image_folder}") + + # Collect image paths + self.image_paths = [] + for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']: + self.image_paths.extend(Path(calibration_image_folder).glob(ext)) + + if not self.image_paths: + LOGGER.warning(f"No images found in {calibration_image_folder}") + + self.image_paths = self.image_paths[:max_samples] + LOGGER.info(f"Found {len(self.image_paths)} calibration images in {calibration_image_folder}") + + self.current_idx = 0 + + def preprocess(self, img_path): + import cv2 + + img = 
cv2.imread(str(img_path)) + if img is None: + return None + + # Resize with letterbox + h, w = img.shape[:2] + scale = min(self.img_size / h, self.img_size / w) + new_h, new_w = int(h * scale), int(w * scale) + + img = cv2.resize(img, (new_w, new_h)) + + # Pad to target size + canvas = np.full((self.img_size, self.img_size, 3), 114, dtype=np.uint8) + pad_h = (self.img_size - new_h) // 2 + pad_w = (self.img_size - new_w) // 2 + canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = img + + # Convert to model input format + img = canvas[:, :, ::-1] # BGR to RGB + img = img.transpose(2, 0, 1) # HWC to CHW + img = img.astype(np.float32) / 255.0 + img = np.expand_dims(img, axis=0) # Add batch dimension + + return img + + def get_next(self): + if self.current_idx >= len(self.image_paths): + return None + + img_path = self.image_paths[self.current_idx] + self.current_idx += 1 + + img = self.preprocess(img_path) + if img is None: + return self.get_next() + + return {self.input_name: img} + + def rewind(self): + self.current_idx = 0 + + +def get_rknn_qconfig() -> QConfig: + return QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + dtype=torch.quint8, + ), + weight=FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + ) + ) + +def prepare_model_for_qat( + model: nn.Module, + example_inputs: torch.Tensor, + config: QATConfig +) -> nn.Module: + qconfig = get_rknn_qconfig() + qconfig_mapping = QConfigMapping().set_global(qconfig) + + if config.ignore_layers: + LOGGER.info(f"Mixed Precision: Ignoring layers: {config.ignore_layers}") + for layer_name in config.ignore_layers: + qconfig_mapping.set_module_name(layer_name, None) + + model_to_quantize = copy.deepcopy(model) + model_to_quantize.eval() + + try: + prepared_model = prepare_qat_fx( + model_to_quantize, + qconfig_mapping=qconfig_mapping, + example_inputs=(example_inputs,) + ) + return prepared_model + except Exception as e: + LOGGER.error(f"FX Preparation failed. Ensure model is traceable. Error: {e}") + raise e + +def load_qat_weights_to_fp32(fp32_model: nn.Module, qat_state_dict: dict) -> nn.Module: + """ + Loads weights from a QAT model (with fake quant layers) into a standard FP32 model. + + FX Graph mode wraps Conv+BN modules, so keys need to be mapped: + - QAT: model.X.conv.bn.weight -> FP32: model.X.bn.weight + - QAT: model.X.conv.weight -> FP32: model.X.conv.weight (direct match) + """ + import re + + LOGGER.info("Loading QAT weights into FP32 model...") + fp32_model_dict = fp32_model.state_dict() + + # Quantization-related keys to skip + quant_keywords = ['scale', 'zero_point', 'activation_post_process', 'fake_quant', + '_observer', 'eps', 'min_val', 'max_val', 'quant_min', 'quant_max', + 'fake_quant_enabled', 'observer_enabled'] + + new_state_dict = {} + keys_loaded = [] + missing_fp32_keys = set(fp32_model_dict.keys()) + + for qat_key, qat_value in qat_state_dict.items(): + # Skip quantization buffers + if any(kw in qat_key for kw in quant_keywords): + continue + + # Map QAT key to FP32 key: replace '.conv.bn.' with '.bn.' 
+ fp32_key = re.sub(r'\.conv\.bn\.', '.bn.', qat_key) + + if fp32_key in fp32_model_dict: + if qat_value.shape == fp32_model_dict[fp32_key].shape: + new_state_dict[fp32_key] = qat_value + keys_loaded.append(fp32_key) + missing_fp32_keys.discard(fp32_key) + else: + LOGGER.warning(f"Shape mismatch for {fp32_key}: QAT {qat_value.shape} vs FP32 {fp32_model_dict[fp32_key].shape}") + elif qat_key in fp32_model_dict: + # Fallback: try direct match + if qat_value.shape == fp32_model_dict[qat_key].shape: + new_state_dict[qat_key] = qat_value + keys_loaded.append(qat_key) + missing_fp32_keys.discard(qat_key) + + load_result = fp32_model.load_state_dict(new_state_dict, strict=False) + + LOGGER.info(f"Loaded {len(keys_loaded)} params from QAT checkpoint.") + if missing_fp32_keys: + LOGGER.warning(f"Missing keys: {len(missing_fp32_keys)}") + + return fp32_model + + +def export_knn_compatible_onnx( + config: QATConfig, + model_class, + best_weights_path: str, + nc: int, + scale: str +): + """ + Orchestrates the full export pipeline: + 1. Load pure FP32 model. + 2. Load QAT-trained weights (best.pt) into FP32 model. + 3. Export FP32 ONNX. + 4. Run ORT Quantization (PTQ) to produce QDQ ONNX. + """ + + fp32_onnx_path = config.save_dir / 'model_fp32.onnx' + qdq_onnx_path = config.save_dir / 'model_int8_qdq.onnx' + + LOGGER.info(">>> STEP 1: Exporting Clean FP32 ONNX with QAT weights...") + + # 1. Instantiate clean FP32 model + fp32_model = model_class(nc=nc, scale=scale) + + # 2. Load QAT weights + ckpt = torch.load(best_weights_path, map_location='cpu') + if 'model' in ckpt: + state_dict = ckpt['model'] + else: + state_dict = ckpt + + fp32_model = load_qat_weights_to_fp32(fp32_model, state_dict) + fp32_model.eval() + + # 3. Export FP32 ONNX + dummy_input = torch.randn(1, 3, config.img_size, config.img_size) + + try: + torch.onnx.export( + fp32_model, + dummy_input, + str(fp32_onnx_path), + opset_version=13, + input_names=['images'], + output_names=['output'], + do_constant_folding=True, + dynamo=False + ) + LOGGER.info(f"FP32 ONNX exported to {fp32_onnx_path}") + + # 3.5 Pre-process (Shape Inference & Optimization) to suppress warnings and improve quantization + from onnxruntime.quantization import quant_pre_process + + LOGGER.info(">>> STEP 1.5: Running ONNX Runtime Pre-processing...") + preprocessed_path = config.save_dir / 'model_fp32_preprocessed.onnx' + quant_pre_process( + input_model_path=str(fp32_onnx_path), + output_model_path=str(preprocessed_path), + skip_symbolic_shape=False + ) + LOGGER.info(f"Pre-processed model saved to {preprocessed_path}") + + # Update input path for quantization to use the preprocessed model + model_input_path = str(preprocessed_path) + + except Exception as e: + LOGGER.error(f"Failed to export or pre-process FP32 ONNX: {e}") + return + + LOGGER.info(">>> STEP 2: Running ONNX Runtime Quantization (PTQ)...") + + if not config.calibration_image_dir: + LOGGER.error("No calibration image directory provided. 
Skipping PTQ.") + return + + calibration_reader = YOLOCalibrationDataReader( + calibration_image_folder=config.calibration_image_dir, + max_samples=config.max_calibration_samples, + img_size=config.img_size + ) + + try: + quantize_static( + model_input=model_input_path, # Use the pre-processed model + model_output=str(qdq_onnx_path), + calibration_data_reader=calibration_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + per_channel=True, + reduce_range=False, + extra_options={ + 'ActivationSymmetric': False, + 'WeightSymmetric': True + } + ) + LOGGER.info(f"QDQ ONNX exported to {qdq_onnx_path}") + LOGGER.info(">>> Export Workflow Complete.") + except Exception as e: + LOGGER.error(f"PTQ Failed: {e}") + +def _aten_copy_symbolic(g, self, src, non_blocking): + return g.op("Identity", src) + +register_custom_op_symbolic("aten::copy", _aten_copy_symbolic, 13) + diff --git a/train_qat.py b/train_qat.py new file mode 100644 index 0000000..99003c1 --- /dev/null +++ b/train_qat.py @@ -0,0 +1,109 @@ +import torch +from torch.utils.data import DataLoader + +from dataset import YOLODataset +from yolo11_standalone import YOLO11, YOLOPostProcessor +from loss import YOLOv8DetectionLoss +from trainer_qat import YOLO11QATTrainer +import qat_utils +from qat_utils import QATConfig + +def run_qat_training(): + # ========================== + # 1. Configuration + # ========================== + img_dir_train = "E:\\Datasets\\coco\\images\\train2017" + label_dir_train = "E:\\Datasets\\coco\\labels\\train2017" + img_dir_val = "E:\\Datasets\\coco\\images\\val2017" + label_dir_val = "E:\\Datasets\\coco\\labels\\val2017" + # img_dir_train = "E:\\Datasets\\coco8\\images\\train" + # label_dir_train = "E:\\Datasets\\coco8\\labels\\train" + # img_dir_val = "E:\\Datasets\\coco8\\images\\val" + # label_dir_val = "E:\\Datasets\\coco8\\labels\\val" + + # Validation/Calibration Set (reused for PTQ calibration) + # Using val set for calibration is common, or a subset of train + calibration_img_dir = img_dir_val + + weights = "yolo11s.pth" + scale = 's' + nc = 80 + + # QAT Training Hyperparameters + epochs = 10 + batch_size = 16 # Suggest increasing to 32 or 64 if GPU memory allows for better BN stats + lr = 0.0001 # Lower LR (1e-4) is safer for QAT fine-tuning to preserve pretrained knowledge + + # Unified QAT Configuration + config = QATConfig( + img_size=640, + save_dir='./my_yolo_result_qat', + calibration_image_dir=calibration_img_dir, + max_calibration_samples=500, # Increased for full COCO to ensure robust PTQ quantization parameters + ignore_layers=[] # Add layer names here to skip quantization if needed + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # ========================== + # 2. Data Preparation + # ========================== + print("Loading Data...") + train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=config.img_size, is_train=True) + val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=config.img_size, is_train=False) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, + collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, + collate_fn=YOLODataset.collate_fn, num_workers=4, pin_memory=True) + + # ========================== + # 3. 
Model Initialization + # ========================== + print("Initializing Model...") + model = YOLO11(nc=nc, scale=scale) + if weights: + model.load_weights(weights) + + post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False) + loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32]) + + # ========================== + # 4. QAT Training + # ========================== + print("Creating QAT Trainer...") + trainer = YOLO11QATTrainer( + model=model, + post_processor=post_std, + train_loader=train_loader, + val_loader=val_loader, + loss_fn=loss_fn, + config=config, + epochs=epochs, + lr=lr, + device=device + ) + + trainer.train() + + # ========================== + # 5. Export Pipeline + # (Load Best -> FP32 -> ONNX -> QDQ PTQ) + # ========================== + print("\nStarting Export Pipeline...") + best_weights = config.save_dir / 'best.pt' + + if best_weights.exists(): + qat_utils.export_knn_compatible_onnx( + config=config, + model_class=YOLO11, + best_weights_path=str(best_weights), + nc=nc, + scale=scale + ) + else: + print("Error: Best weights not found. Export failed.") + + +if __name__ == "__main__": + run_qat_training() diff --git a/trainer.py b/trainer.py index d06cb3d..1435055 100644 --- a/trainer.py +++ b/trainer.py @@ -146,7 +146,6 @@ class YOLO11Trainer: iou_thres = 0.7 iouv = torch.linspace(0.5, 0.95, 10, device=device) - loss_sum = torch.zeros(3, device=device) stats = [] LOGGER.info("\nValidating...") diff --git a/trainer_qat.py b/trainer_qat.py new file mode 100644 index 0000000..92f3222 --- /dev/null +++ b/trainer_qat.py @@ -0,0 +1,199 @@ +import math +import logging +import time +import numpy as np +from pathlib import Path +from typing import Optional, List + +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy +from qat_utils import prepare_model_for_qat, QATConfig + +logging.basicConfig(format="%(message)s", level=logging.INFO) +LOGGER = logging.getLogger("QAT_Trainer") + +class YOLO11QATTrainer: + def __init__(self, + model, + post_processor, + train_loader, + val_loader, + loss_fn, + config: QATConfig, + epochs: int = 5, + lr: float = 0.001, + device: str = 'cuda'): + + self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + + self.config = config + self.save_dir = config.save_dir + + self.original_model = model.to(self.device) + self.post_processor = post_processor + self.train_loader = train_loader + self.val_loader = val_loader + self.loss_fn = loss_fn + self.epochs = epochs + self.lr = lr + + self.qat_model = None + self.optimizer = None + self.scheduler = None + + def setup(self): + LOGGER.info(">>> Setting up QAT...") + + example_input = torch.randn(1, 3, self.config.img_size, self.config.img_size).to(self.device) + self.qat_model = prepare_model_for_qat( + self.original_model, + example_input, + config=self.config + ) + self.qat_model.to(self.device) + + self.optimizer = optim.Adam(self.qat_model.parameters(), lr=self.lr, weight_decay=1e-5) + + lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (0.01 - 1) + 1 + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lf) + + def train_epoch(self, epoch): + self.qat_model.train() + pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{self.epochs}') + + mean_loss = torch.zeros(3, device=self.device) + + for i, (imgs, targets, _) in enumerate(pbar): + imgs = imgs.to(self.device).float() + targets = targets.to(self.device) + + 
self.optimizer.zero_grad() + + preds = self.qat_model(imgs) + + target_batch = { + "batch_idx": targets[:, 0], + "cls": targets[:, 1], + "bboxes": targets[:, 2:], + } + loss, loss_items = self.loss_fn(preds, target_batch) + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.qat_model.parameters(), max_norm=10.0) + self.optimizer.step() + + mean_loss = (mean_loss * i + loss_items.detach()) / (i + 1) + pbar.set_postfix({ + 'box': f"{mean_loss[0]:.4f}", + 'cls': f"{mean_loss[1]:.4f}", + 'dfl': f"{mean_loss[2]:.4f}", + 'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}" + }) + + def validate(self): + conf_thres = 0.001 + iou_thres = 0.7 + iouv = torch.linspace(0.5, 0.95, 10, device=self.device) + + stats = [] + + LOGGER.info("\nValidating...") + + self.qat_model.eval() + pbar = tqdm(self.val_loader, desc="Calc Metrics") + with torch.no_grad(): + for batch in pbar: + imgs, targets, _ = batch + imgs = imgs.to(self.device, non_blocking=True) + targets = targets.to(self.device) + _, _, height, width = imgs.shape + + preds = self.qat_model(imgs) + preds = self.post_processor(preds) + preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres) + + for si, pred in enumerate(preds): + labels = targets[targets[:, 0] == si] + nl = len(labels) + tcls = labels[:, 1].tolist() if nl else [] + + if len(pred) == 0: + if nl: + stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool), + torch.Tensor(), torch.Tensor(), torch.tensor(tcls))) + continue + + predn = pred.clone() + + if nl: + tbox = xywh2xyxy(labels[:, 2:6]) + tbox[:, [0, 2]] *= width + tbox[:, [1, 3]] *= height + labelsn = torch.cat((labels[:, 1:2], tbox), 1) + correct = self._process_batch(predn, labelsn, iouv) + else: + correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool) + + stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls))) + + mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0 + if len(stats) and stats[0][0].any(): + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] + tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats) + ap50, ap = ap[:, 0], ap.mean(1) + mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean() + + LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}") + + return map50 + + def _process_batch(self, detections, labels, iouv): + correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device) + iou = box_iou(labels[:, 1:], detections[:, :4]) + + x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) + + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + matches = torch.from_numpy(matches).to(iouv.device) + correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv + + return correct + + def train(self): + self.setup() + + LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...") + start_time = time.time() + best_fitness = -1 + + for epoch in range(self.epochs): + self.train_epoch(epoch) + self.scheduler.step() + map50 = self.validate() + + ckpt = { + 'epoch': epoch, + 'model': self.qat_model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + } + + torch.save(ckpt, self.save_dir / 'last_qat.pt') + + if 
map50 > best_fitness: + best_fitness = map50 + torch.save(ckpt, self.save_dir / 'best.pt') + LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}") + + LOGGER.info(">>> Training Finished.") + diff --git a/yolo11_standalone.py b/yolo11_standalone.py index aad9906..d5db89b 100644 --- a/yolo11_standalone.py +++ b/yolo11_standalone.py @@ -562,6 +562,90 @@ class YOLO11E(YOLO11): return head(feats, cls_pe) + +# ============================================================================== +# [Part 4] PostProcessorNumpy +# ============================================================================== + +import numpy as np + +class YOLOPostProcessorNumpy: + def __init__(self, strides=[8, 16, 32], reg_max=16, use_segmentation=False): + self.strides = np.array(strides, dtype=np.float32) + self.reg_max = reg_max + self.use_segmentation = use_segmentation + self.anchors = None + self.strides_array = None + self.shape = None + self.dfl_weights = np.arange(reg_max, dtype=np.float32).reshape(1, 1, reg_max, 1) + + def sigmoid(self, x): + return 1 / (1 + np.exp(-x)) + + def softmax(self, x, axis=-1): + x_max = np.max(x, axis=axis, keepdims=True) + e_x = np.exp(x - x_max) + return e_x / np.sum(e_x, axis=axis, keepdims=True) + + def make_anchors(self, feats, strides, grid_cell_offset=0.5): + anchor_points, stride_list = [], [] + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = np.arange(w, dtype=np.float32) + grid_cell_offset + sy = np.arange(h, dtype=np.float32) + grid_cell_offset + sy, sx = np.meshgrid(sy, sx, indexing='ij') + + anchor_points.append(np.stack((sx, sy), -1).reshape(-1, 2)) + stride_list.append(np.full((h * w, 1), stride, dtype=np.float32)) + + return np.concatenate(anchor_points), np.concatenate(stride_list) + + def dist2bbox(self, distance, anchor_points, xywh=True, dim=-1): + lt, rb = np.split(distance, 2, axis=dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return np.concatenate((c_xy, wh), axis=dim) + return np.concatenate((x1y1, x2y2), axis=dim) + + def dfl_decode(self, x): + B, C, A = x.shape + x = x.reshape(B, 4, self.reg_max, A) + x = self.softmax(x, axis=2) + return np.sum(x * self.dfl_weights, axis=2) + + def __call__(self, outputs): + if self.use_segmentation: + x, mc, p = outputs + else: + x = outputs + + current_no = x[0].shape[1] + current_nc = current_no - self.reg_max * 4 + shape = x[0].shape + + x_cat = np.concatenate([xi.reshape(shape[0], current_no, -1) for xi in x], axis=2) + + if self.anchors is None or self.shape != shape: + self.anchors, self.strides_array = self.make_anchors(x, self.strides, 0.5) + self.shape = shape + + box, cls = np.split(x_cat, [self.reg_max * 4], axis=1) + dist = self.dfl_decode(box) + dist = dist.transpose(0, 2, 1) + dbox = self.dist2bbox(dist, self.anchors, xywh=True, dim=2) * self.strides_array + cls = cls.transpose(0, 2, 1) + sigmoid_cls = self.sigmoid(cls) + final_box = np.concatenate((dbox, sigmoid_cls), axis=2) + + if self.use_segmentation: + return final_box, mc, p + + return final_box + + if __name__ == "__main__": print("Testing Standard YOLO11...") model_std = YOLO11(nc=80, scale='n') diff --git a/yolo11s_int_qdq.onnx b/yolo11s_int_qdq.onnx new file mode 100644 index 0000000..90a456a Binary files /dev/null and b/yolo11s_int_qdq.onnx differ
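
Usage note (not part of the patch): a minimal, illustrative sketch of how the exported QDQ model and the new NumPy post-processing path introduced above might be wired together with onnxruntime. It assumes the exported ONNX graph returns the three raw feature maps (strides 8/16/32) exactly as the PyTorch model does before post-processing, that inference.py exposes letterbox() and non_max_suppression_numpy() as importable helpers, and that the QDQ file sits at my_yolo_result_qat/model_int8_qdq.onnx as configured in train_qat.py; file names and the image path are placeholders.

import cv2
import numpy as np
import onnxruntime as ort

from yolo11_standalone import YOLOPostProcessorNumpy
from inference import letterbox, non_max_suppression_numpy  # assumed importable helpers

# Load the quantized QDQ model produced by the export pipeline (path is an assumption)
session = ort.InferenceSession("my_yolo_result_qat/model_int8_qdq.onnx",
                               providers=["CPUExecutionProvider"])

img0 = cv2.imread("bus.jpg")                         # any test image
img, ratio, (dw, dh) = letterbox(img0, (640, 640))   # same preprocessing as inference.py
blob = img[:, :, ::-1].transpose(2, 0, 1)[None].astype(np.float32) / 255.0

# Assumption: the graph emits the three raw feature maps, matching the PyTorch forward pass
feat_maps = session.run(None, {"images": np.ascontiguousarray(blob)})

post = YOLOPostProcessorNumpy(strides=[8, 16, 32], reg_max=16, use_segmentation=False)
decoded = post(feat_maps)
det = non_max_suppression_numpy(decoded, conf_thres=0.25, iou_thres=0.45)[0]

# Undo letterbox padding/scaling, then clip boxes to the original image bounds
det[:, [0, 2]] -= dw
det[:, [1, 3]] -= dh
det[:, :4] /= ratio
det[:, :4] = np.clip(det[:, :4], 0,
                     [img0.shape[1], img0.shape[0], img0.shape[1], img0.shape[0]])
print(f"Detected {len(det)} objects")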