"""QAT (quantization-aware training) entry-point script for YOLO11 on COCO."""
from pathlib import Path

import torch
from torch.utils.data import DataLoader

import qat_utils
from dataset import YOLODataset
from loss import YOLOv8DetectionLoss
from qat_utils import QATConfig
from trainer_qat import YOLO11QATTrainer
from yolo11_standalone import YOLO11, YOLOPostProcessor
|
|
|
|
def run_qat_training(
    epochs: int = 10,
    batch_size: int = 16,
    lr: float = 0.0001,
    weights: str = "yolo11s.pth",
    scale: str = 's',
    nc: int = 80,
):
    """Run the full YOLO11 QAT pipeline: data -> model -> QAT training -> export.

    Args:
        epochs: Number of QAT fine-tuning epochs.
        batch_size: Mini-batch size. Increasing to 32 or 64 (if GPU memory
            allows) gives better BatchNorm statistics.
        lr: Learning rate. A low value (1e-4) is safer for QAT fine-tuning,
            preserving the pretrained knowledge.
        weights: Path to pretrained FP32 weights; pass a falsy value to skip
            loading and train from scratch.
        scale: Model scale identifier passed to ``YOLO11`` (e.g. ``'s'``).
        nc: Number of detection classes.
    """
    # ==========================
    # 1. Configuration
    # ==========================
    img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
    label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
    img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
    label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
    # NOTE: for a quick smoke test, point the four paths above at the tiny
    # coco8 dataset instead (E:\Datasets\coco8\{images,labels}\{train,val}).

    # Validation/Calibration set (reused for PTQ calibration).
    # Using the val set for calibration is common; a subset of train also works.
    calibration_img_dir = img_dir_val

    # Unified QAT configuration.
    config = QATConfig(
        img_size=640,
        save_dir='./my_yolo_result_qat',
        calibration_image_dir=calibration_img_dir,
        # 500 samples: increased for full COCO to ensure robust PTQ
        # quantization parameters.
        max_calibration_samples=500,
        ignore_layers=[],  # add layer names here to skip quantization if needed
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ==========================
    # 2. Data Preparation
    # ==========================
    print("Loading Data...")
    train_dataset = YOLODataset(img_dir_train, label_dir_train,
                                img_size=config.img_size, is_train=True)
    val_dataset = YOLODataset(img_dir_val, label_dir_val,
                              img_size=config.img_size, is_train=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=YOLODataset.collate_fn,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=YOLODataset.collate_fn,
                            num_workers=4, pin_memory=True)

    # ==========================
    # 3. Model Initialization
    # ==========================
    print("Initializing Model...")
    model = YOLO11(nc=nc, scale=scale)
    if weights:
        model.load_weights(weights)

    # Post-processor wraps the detection head (last module of the model).
    post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
    loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32])

    # ==========================
    # 4. QAT Training
    # ==========================
    print("Creating QAT Trainer...")
    trainer = YOLO11QATTrainer(
        model=model,
        post_processor=post_std,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        config=config,
        epochs=epochs,
        lr=lr,
        device=device,
    )

    trainer.train()

    # ==========================
    # 5. Export Pipeline
    # (Load Best -> FP32 -> ONNX -> QDQ PTQ)
    # ==========================
    print("\nStarting Export Pipeline...")
    # Wrap in Path() defensively: save_dir was handed to QATConfig as a str,
    # and `save_dir / 'best.pt'` only works if QATConfig converts it to a
    # pathlib.Path internally. Path(Path(...)) is a no-op, so this is safe
    # either way.
    best_weights = Path(config.save_dir) / 'best.pt'

    if best_weights.exists():
        qat_utils.export_knn_compatible_onnx(
            config=config,
            model_class=YOLO11,
            best_weights_path=str(best_weights),
            nc=nc,
            scale=scale,
        )
    else:
        print("Error: Best weights not found. Export failed.")
|
|
|
|
|
|
# Script entry point: run the full QAT pipeline when executed directly
# (not when imported as a module).
if __name__ == "__main__":
    run_qat_training()
|