import torch
from torch.utils.data import DataLoader

from dataset import YOLODataset
from yolo11_standalone import YOLO11, YOLOPostProcessor
from loss import YOLOv8DetectionLoss
from trainer_qat import YOLO11QATTrainer
import qat_utils
from qat_utils import QATConfig


def run_qat_training():
    """Run the end-to-end QAT pipeline for YOLO11.

    Steps: build COCO dataloaders, load the pretrained FP32 model,
    fine-tune it with quantization-aware training, then export the best
    checkpoint through the FP32 -> ONNX -> QDQ-PTQ pipeline.
    """
    # ------------------------------------------------------------------
    # 1. Configuration
    # ------------------------------------------------------------------
    train_images = "E:\\Datasets\\coco\\images\\train2017"
    train_labels = "E:\\Datasets\\coco\\labels\\train2017"
    val_images = "E:\\Datasets\\coco\\images\\val2017"
    val_labels = "E:\\Datasets\\coco\\labels\\val2017"

    # Alternative tiny dataset for quick smoke tests:
    # train_images = "E:\\Datasets\\coco8\\images\\train"
    # train_labels = "E:\\Datasets\\coco8\\labels\\train"
    # val_images = "E:\\Datasets\\coco8\\images\\val"
    # val_labels = "E:\\Datasets\\coco8\\labels\\val"

    # PTQ calibration reuses the validation images; using the val set (or a
    # train subset) for calibration is common practice.
    calib_image_dir = val_images

    pretrained_weights = "yolo11s.pth"
    model_scale = 's'
    num_classes = 80

    # QAT fine-tuning hyperparameters.
    num_epochs = 10
    batch = 16        # Consider 32/64 if GPU memory allows — better BN statistics.
    learning_rate = 0.0001  # A low LR (1e-4) preserves pretrained knowledge during QAT.

    # Unified QAT configuration object consumed by the trainer and exporter.
    # NOTE(review): `config.save_dir / 'best.pt'` below implies QATConfig
    # converts save_dir to a pathlib.Path — confirm in qat_utils.
    qat_cfg = QATConfig(
        img_size=640,
        save_dir='./my_yolo_result_qat',
        calibration_image_dir=calib_image_dir,
        # Raised for full COCO so the PTQ quantization parameters are robust.
        max_calibration_samples=500,
        ignore_layers=[]  # Layer names listed here are skipped during quantization.
    )

    if torch.cuda.is_available():
        run_device = "cuda"
    else:
        run_device = "cpu"

    # ------------------------------------------------------------------
    # 2. Data preparation
    # ------------------------------------------------------------------
    print("Loading Data...")
    ds_train = YOLODataset(train_images, train_labels,
                           img_size=qat_cfg.img_size, is_train=True)
    ds_val = YOLODataset(val_images, val_labels,
                         img_size=qat_cfg.img_size, is_train=False)

    loader_train = DataLoader(ds_train, batch_size=batch, shuffle=True,
                              collate_fn=YOLODataset.collate_fn,
                              num_workers=4, pin_memory=True)
    loader_val = DataLoader(ds_val, batch_size=batch, shuffle=False,
                            collate_fn=YOLODataset.collate_fn,
                            num_workers=4, pin_memory=True)

    # ------------------------------------------------------------------
    # 3. Model initialization
    # ------------------------------------------------------------------
    print("Initializing Model...")
    net = YOLO11(nc=num_classes, scale=model_scale)
    if pretrained_weights:
        net.load_weights(pretrained_weights)

    # Post-processor is built from the detection head (last module of the model).
    post_processor = YOLOPostProcessor(net.model[-1], use_segmentation=False)
    criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=16, stride=[8, 16, 32])

    # ------------------------------------------------------------------
    # 4. QAT training
    # ------------------------------------------------------------------
    print("Creating QAT Trainer...")
    qat_trainer = YOLO11QATTrainer(
        model=net,
        post_processor=post_processor,
        train_loader=loader_train,
        val_loader=loader_val,
        loss_fn=criterion,
        config=qat_cfg,
        epochs=num_epochs,
        lr=learning_rate,
        device=run_device
    )
    qat_trainer.train()

    # ------------------------------------------------------------------
    # 5. Export pipeline (load best -> FP32 -> ONNX -> QDQ PTQ)
    # ------------------------------------------------------------------
    print("\nStarting Export Pipeline...")
    best_ckpt = qat_cfg.save_dir / 'best.pt'
    if not best_ckpt.exists():
        print("Error: Best weights not found. Export failed.")
    else:
        qat_utils.export_knn_compatible_onnx(
            config=qat_cfg,
            model_class=YOLO11,
            best_weights_path=str(best_ckpt),
            nc=num_classes,
            scale=model_scale
        )


if __name__ == "__main__":
    run_qat_training()