添加qat量化支持
This commit is contained in:
109
train_qat.py
Normal file
109
train_qat.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import YOLODataset
|
||||
from yolo11_standalone import YOLO11, YOLOPostProcessor
|
||||
from loss import YOLOv8DetectionLoss
|
||||
from trainer_qat import YOLO11QATTrainer
|
||||
import qat_utils
|
||||
from qat_utils import QATConfig
|
||||
|
||||
def run_qat_training():
    """Run the YOLO11 QAT fine-tuning pipeline, then export the best checkpoint.

    Steps, in order:
      1. Set dataset paths, model settings and QAT hyperparameters.
      2. Build the train/val ``YOLODataset`` objects and their ``DataLoader``s.
      3. Create the ``YOLO11`` model, optionally load pretrained weights, and
         build the post-processor and detection loss.
      4. Train with ``YOLO11QATTrainer``.
      5. If ``best.pt`` exists under ``config.save_dir``, pass it to
         ``qat_utils.export_knn_compatible_onnx``; otherwise print an error.

    Every setting is hard-coded in the body; edit the values in place to
    change the run. The function takes no arguments and returns ``None``.
    """
    # Local import keeps this function self-contained (no new module-level
    # dependency is required for the path handling below).
    from pathlib import Path

    # ==========================
    # 1. Configuration
    # ==========================
    img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
    label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
    img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
    label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
    # img_dir_train = "E:\\Datasets\\coco8\\images\\train"
    # label_dir_train = "E:\\Datasets\\coco8\\labels\\train"
    # img_dir_val = "E:\\Datasets\\coco8\\images\\val"
    # label_dir_val = "E:\\Datasets\\coco8\\labels\\val"

    # Validation/Calibration Set (reused for PTQ calibration)
    # Using val set for calibration is common, or a subset of train
    calibration_img_dir = img_dir_val

    weights = "yolo11s.pth"
    scale = 's'
    nc = 80

    # QAT Training Hyperparameters
    epochs = 10
    batch_size = 16  # Suggest increasing to 32 or 64 if GPU memory allows for better BN stats
    lr = 0.0001  # Lower LR (1e-4) is safer for QAT fine-tuning to preserve pretrained knowledge

    # Unified QAT Configuration
    config = QATConfig(
        img_size=640,
        save_dir='./my_yolo_result_qat',
        calibration_image_dir=calibration_img_dir,
        max_calibration_samples=500,  # Increased for full COCO to ensure robust PTQ quantization parameters
        ignore_layers=[],  # Add layer names here to skip quantization if needed
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a
    # CPU-only run has no benefit, so enable it only when CUDA was selected.
    pin_memory = device == "cuda"

    # ==========================
    # 2. Data Preparation
    # ==========================
    print("Loading Data...")
    train_dataset = YOLODataset(
        img_dir_train, label_dir_train, img_size=config.img_size, is_train=True
    )
    val_dataset = YOLODataset(
        img_dir_val, label_dir_val, img_size=config.img_size, is_train=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=YOLODataset.collate_fn,
        num_workers=4,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=YOLODataset.collate_fn,
        num_workers=4,
        pin_memory=pin_memory,
    )

    # ==========================
    # 3. Model Initialization
    # ==========================
    print("Initializing Model...")
    model = YOLO11(nc=nc, scale=scale)
    if weights:
        model.load_weights(weights)

    # The post-processor is built from the last entry of ``model.model``.
    post_std = YOLOPostProcessor(model.model[-1], use_segmentation=False)
    loss_fn = YOLOv8DetectionLoss(nc=nc, reg_max=16, stride=[8, 16, 32])

    # ==========================
    # 4. QAT Training
    # ==========================
    print("Creating QAT Trainer...")
    trainer = YOLO11QATTrainer(
        model=model,
        post_processor=post_std,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        config=config,
        epochs=epochs,
        lr=lr,
        device=device,
    )

    trainer.train()

    # ==========================
    # 5. Export Pipeline
    # (Load Best -> FP32 -> ONNX -> QDQ PTQ)
    # ==========================
    print("\nStarting Export Pipeline...")
    # ``save_dir`` was handed to QATConfig as a plain string above. Wrapping it
    # in Path makes the ``/`` join and ``.exists()`` work whether or not
    # QATConfig converts it, instead of failing only after training finished.
    best_weights = Path(config.save_dir) / 'best.pt'

    if best_weights.exists():
        qat_utils.export_knn_compatible_onnx(
            config=config,
            model_class=YOLO11,
            best_weights_path=str(best_weights),
            nc=nc,
            scale=scale,
        )
    else:
        print("Error: Best weights not found. Export failed.")
|
||||
|
||||
|
||||
# Script entry point: start QAT training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_qat_training()
|
||||
Reference in New Issue
Block a user