# Yolo-standalone/train_qat.py
# YOLO11 quantization-aware training (QAT) entry-point script.
import torch
from torch.utils.data import DataLoader
from dataset import YOLODataset
from yolo11_standalone import YOLO11, YOLOPostProcessor
from loss import YOLOv8DetectionLoss
from trainer_qat import YOLO11QATTrainer
import qat_utils
from qat_utils import QATConfig
def run_qat_training():
    """Run YOLO11 quantization-aware training (QAT) end to end.

    Pipeline:
      1. Build COCO-format train/val datasets and loaders.
      2. Load a pretrained FP32 YOLO11 model and attach loss/post-processing.
      3. Fine-tune with QAT via ``YOLO11QATTrainer``.
      4. Export the best checkpoint through FP32 -> ONNX -> QDQ PTQ.
    """
    # ==========================
    # 1. Configuration
    # ==========================
    img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
    label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
    img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
    label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
    # Smaller coco8 paths — handy for quick smoke tests:
    # img_dir_train = "E:\\Datasets\\coco8\\images\\train"
    # label_dir_train = "E:\\Datasets\\coco8\\labels\\train"
    # img_dir_val = "E:\\Datasets\\coco8\\images\\val"
    # label_dir_val = "E:\\Datasets\\coco8\\labels\\val"

    # Validation/Calibration set (reused for PTQ calibration).
    # Using the val set for calibration is common; a train subset also works.
    calibration_img_dir = img_dir_val

    weights = "yolo11s.pth"  # pretrained FP32 checkpoint to fine-tune from
    scale = 's'              # YOLO11 model scale variant
    nc = 80                  # number of classes (COCO)

    # QAT training hyperparameters
    epochs = 10
    batch_size = 16  # 32/64 give better BN stats if GPU memory allows
    lr = 0.0001      # low LR (1e-4) is safer for QAT fine-tuning: it
                     # preserves the pretrained knowledge

    # Unified QAT configuration
    config = QATConfig(
        img_size=640,
        save_dir='./my_yolo_result_qat',
        calibration_image_dir=calibration_img_dir,
        max_calibration_samples=500,  # increased for full COCO to get robust
                                      # PTQ quantization parameters
        ignore_layers=[]  # add layer names here to skip quantization if needed
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ==========================
    # 2. Data Preparation
    # ==========================
    print("Loading Data...")
    train_dataset = YOLODataset(img_dir_train, label_dir_train,
                                img_size=config.img_size, is_train=True)
    val_dataset = YOLODataset(img_dir_val, label_dir_val,
                              img_size=config.img_size, is_train=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=YOLODataset.collate_fn,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=YOLODataset.collate_fn,
                            num_workers=4, pin_memory=True)

    # ==========================
    # 3. Model Initialization
    # ==========================
    print("Initializing Model...")
    model = YOLO11(nc=nc, scale=scale)
    if weights:
        model.load_weights(weights)
    # Post-processor wraps the detection head (model.model[-1]); detection only.
    post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
    loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32])

    # ==========================
    # 4. QAT Training
    # ==========================
    print("Creating QAT Trainer...")
    trainer = YOLO11QATTrainer(
        model=model,
        post_processor=post_std,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        config=config,
        epochs=epochs,
        lr=lr,
        device=device
    )
    trainer.train()

    # ==========================
    # 5. Export Pipeline
    #    (Load Best -> FP32 -> ONNX -> QDQ PTQ)
    # ==========================
    print("\nStarting Export Pipeline...")
    # NOTE(review): the `/` operator and `.exists()` below assume QATConfig
    # normalizes `save_dir` to a pathlib.Path even though a plain string is
    # passed above — confirm in qat_utils.QATConfig.
    best_weights = config.save_dir / 'best.pt'
    if best_weights.exists():
        qat_utils.export_knn_compatible_onnx(
            config=config,
            model_class=YOLO11,
            best_weights_path=str(best_weights),
            nc=nc,
            scale=scale
        )
    else:
        print("Error: Best weights not found. Export failed.")


if __name__ == "__main__":
    run_qat_training()