import glob
import logging
import os
import random

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


class YOLODataset(Dataset):
    """Detection dataset reading images plus YOLO-format ``.txt`` label files.

    Each label line is ``cls cx cy w h`` with the box given as normalized
    centre/size values. Training samples are optionally built with 4-image
    Mosaic and then passed through an Albumentations pipeline; validation
    samples are letterboxed to ``img_size``.

    ``__getitem__`` returns ``(image, targets, path)`` where ``targets`` has
    one row per box laid out as ``[batch_idx, cls, cx, cy, w, h]``; column 0
    is left at zero here and filled in by :meth:`collate_fn`.
    """

    def __init__(self, img_dir, label_dir, img_size=640, is_train=True,
                 mosaic_prob=1.0):
        """Index the image files and build the augmentation pipeline.

        Args:
            img_dir: Directory searched for ``*.jpg``, ``*.png``, ``*.jpeg``.
            label_dir: Directory holding one ``<image stem>.txt`` per image.
            img_size: Side length of the square output image.
            is_train: Use the training pipeline (Mosaic + augmentation) when
                true, the letterbox validation pipeline otherwise.
            mosaic_prob: Probability of building a Mosaic for a training
                sample while Mosaic is enabled. The default ``1.0`` keeps the
                previous "always Mosaic" behaviour.
        """
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.is_train = is_train
        # Mosaic is on by default for training; see close_mosaic().
        self.use_mosaic = is_train
        self.mosaic_prob = mosaic_prob

        # Several image extensions are supported.
        self.img_files = sorted(
            glob.glob(os.path.join(img_dir, "*.jpg"))
            + glob.glob(os.path.join(img_dir, "*.png"))
            + glob.glob(os.path.join(img_dir, "*.jpeg"))
        )

        if is_train:
            self.transform = self._build_train_transform(img_size)
        else:
            self.transform = self._build_val_transform(img_size)

    @staticmethod
    def _build_train_transform(img_size):
        """Build the training pipeline mirroring Ultralytics' default.yaml."""
        # default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4
        # OpenCV hue range is [0, 179]; saturation/value range is [0, 255].
        h_limit = int(0.015 * 179)  # ~2
        s_limit = int(0.7 * 255)    # ~178
        v_limit = int(0.4 * 255)    # ~102

        return A.Compose(
            [
                # Geometric jitter applied after Mosaic (or to the plain
                # image when Mosaic is off).
                # default.yaml: translate: 0.1, scale: 0.5, degrees: 0.0
                A.Affine(
                    scale=(0.5, 1.5),
                    # A tuple (a, b) is a sampling interval, so (0.1, 0.1)
                    # would always shift by exactly +10 %. Sample each axis
                    # from +/-10 % instead.
                    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
                    rotate=(-0, 0),
                    shear=(-0, 0),
                    p=0.5,
                ),
                # Colour augmentation (values taken from default.yaml).
                A.HueSaturationValue(
                    hue_shift_limit=h_limit,
                    sat_shift_limit=s_limit,
                    val_shift_limit=v_limit,
                    p=0.5,
                ),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                # default.yaml: fliplr: 0.5
                A.HorizontalFlip(p=0.5),
                # Final step: force a consistent output size.
                A.Resize(img_size, img_size),
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(
                format='yolo', min_visibility=0.0, label_fields=['class_labels']
            ),
        )

    @staticmethod
    def _build_val_transform(img_size):
        """Build the validation pipeline: aspect-preserving resize plus padding."""
        return A.Compose(
            [
                A.LongestMaxSize(max_size=img_size),
                A.PadIfNeeded(
                    min_height=img_size,
                    min_width=img_size,
                    border_mode=cv2.BORDER_CONSTANT,
                ),
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(
                format='yolo', min_visibility=0.1, label_fields=['class_labels']
            ),
        )

    def close_mosaic(self):
        """Disable Mosaic augmentation for all subsequent samples."""
        self.use_mosaic = False
        print("Mosaic augmentation disabled.")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_files)

    def _read_image(self, index):
        """Read image ``index`` from disk as a BGR array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img_path = self.img_files[index]
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        return img

    def load_image(self, index):
        """Load one image and scale its longer side to ``img_size``.

        Returns:
            ``(img, (orig_h, orig_w), (resized_h, resized_w))`` with ``img``
            in BGR channel order.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img = self._read_image(index)
        h, w = img.shape[:2]
        r = self.img_size / max(h, w)
        if r != 1:
            # Guard against a zero-pixel side on extremely elongated images.
            new_w = max(1, int(w * r))
            new_h = max(1, int(h * r))
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        return img, (h, w), img.shape[:2]

    def load_label(self, index, img_shape):
        """Read the label file belonging to image ``index``.

        Values are returned exactly as stored in the file; no scaling is
        applied here. Lines with fewer than five fields are ignored, and a
        missing label file yields an empty array.

        Args:
            index: Index into ``self.img_files``.
            img_shape: ``(h, w)`` of the image. Accepted for interface
                compatibility; not used by the parser.

        Returns:
            ``float32`` array of shape ``(n, 5)`` with rows
            ``[cls, cx, cy, w, h]``.
        """
        label_path = self._get_label_path(self.img_files[index])
        rows = []
        if os.path.exists(label_path):
            with open(label_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        cls = int(parts[0])
                        bx, by, bw, bh = map(float, parts[1:5])
                        rows.append([cls, bx, by, bw, bh])
        if not rows:
            return np.zeros((0, 5), dtype=np.float32)
        return np.array(rows, dtype=np.float32)

    def load_mosaic(self, index):
        """Build a YOLO-style Mosaic by tiling four images around a random centre.

        Image ``index`` plus three randomly chosen images are pasted onto a
        ``2 * img_size`` square canvas filled with grey (114).

        Returns:
            ``(img4, labels)`` where ``img4`` is the BGR canvas and ``labels``
            is a ``float32`` ``(n, 5)`` array ``[cls, cx, cy, w, h]``
            normalized to the canvas size.
        """
        s = self.img_size
        border = -s // 2
        # Mosaic centre, sampled so that it stays in the middle half of the
        # canvas (y first, then x).
        yc, xc = [int(random.uniform(-b, 2 * s + b)) for b in (border, border)]

        # Three extra random images, then shuffle the tile order.
        indices = [index] + [
            random.randint(0, len(self.img_files) - 1) for _ in range(3)
        ]
        random.shuffle(indices)

        img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8)
        tile_boxes = []

        for i, idx in enumerate(indices):
            img, _, (h, w) = self.load_image(idx)

            # (x1a, y1a, x2a, y2a): destination rectangle on the canvas.
            # (x1b, y1b, x2b, y2b): source rectangle inside the tile image.
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            else:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]
            # Offset that maps tile pixel coordinates to canvas coordinates.
            padw = x1a - x1b
            padh = y1a - y1b

            labels = self.load_label(idx, (h, w))
            if labels.size > 0:
                # Normalized xywh -> pixel xywh within the tile.
                cx = labels[:, 1] * w
                cy = labels[:, 2] * h
                bw = labels[:, 3] * w
                bh = labels[:, 4] * h
                # Pixel xywh -> pixel xyxy on the canvas.
                xyxy = np.copy(labels)
                xyxy[:, 1] = cx - bw / 2 + padw
                xyxy[:, 2] = cy - bh / 2 + padh
                xyxy[:, 3] = cx + bw / 2 + padw
                xyxy[:, 4] = cy + bh / 2 + padh
                tile_boxes.append(xyxy)

        if not tile_boxes:
            return img4, np.zeros((0, 5), dtype=np.float32)

        labels4 = np.concatenate(tile_boxes, 0)
        # Clip boxes to the canvas border.
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])

        # Back to normalized xywh relative to the 2*s canvas, which is the
        # format the Albumentations pipeline is configured for.
        canvas_w, canvas_h = s * 2, s * 2
        new_labels = np.copy(labels4)
        new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / canvas_w
        new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / canvas_h
        new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / canvas_w
        new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / canvas_h
        return img4, new_labels

    @staticmethod
    def _sanitize_boxes(labels, img_h, img_w, min_size_px=2):
        """Clip normalized xywh boxes to the image and drop tiny ones.

        Args:
            labels: ``(n, 5)`` array ``[cls, cx, cy, w, h]`` (normalized).
            img_h: Image height in pixels.
            img_w: Image width in pixels.
            min_size_px: Boxes must be strictly larger than this many pixels
                on both sides to be kept.

        Returns:
            ``(bboxes, class_labels)`` as parallel lists.
        """
        if len(labels) == 0:
            return [], []

        labels = np.asarray(labels, dtype=np.float32)
        cx, cy, bw, bh = labels[:, 1], labels[:, 2], labels[:, 3], labels[:, 4]

        # Clip the corners to [0, 1] (Mosaic cropping can push them outside),
        # then rebuild size and centre from the clipped corners.
        x1 = np.clip(cx - bw / 2, 0, 1)
        y1 = np.clip(cy - bh / 2, 0, 1)
        x2 = np.clip(cx + bw / 2, 0, 1)
        y2 = np.clip(cy + bh / 2, 0, 1)
        new_w = x2 - x1
        new_h = y2 - y1

        keep = (new_w * img_w > min_size_px) & (new_h * img_h > min_size_px)
        boxes = np.stack(
            [x1 + new_w / 2, y1 + new_h / 2, new_w, new_h], axis=1
        )[keep]
        return [list(box) for box in boxes], list(labels[keep, 0])

    def _load_sample(self, index):
        """Load, clean and augment sample ``index``.

        Returns:
            ``(image, targets, path)``, or ``None`` when the sample has no
            usable box left after cleaning.
        """
        if self.is_train:
            if self.use_mosaic and random.random() < self.mosaic_prob:
                img, labels = self.load_mosaic(index)
            else:
                img, _, _ = self.load_image(index)
                labels = self.load_label(index, img.shape[:2])
        else:
            img = self._read_image(index)
            labels = self.load_label(index, img.shape[:2])
        # OpenCV loads BGR; the pipeline works on RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img_h, img_w = img.shape[:2]
        bboxes, class_labels = self._sanitize_boxes(labels, img_h, img_w)
        if not bboxes:
            return None

        transformed = self.transform(
            image=img, bboxes=bboxes, class_labels=class_labels
        )
        image = transformed['image']
        bboxes = transformed['bboxes']
        class_labels = transformed['class_labels']

        # Target rows: [batch_idx (filled by collate_fn), cls, cx, cy, w, h].
        n = len(bboxes)
        targets = torch.zeros((n, 6))
        if n > 0:
            targets[:, 1] = torch.as_tensor(
                np.asarray(class_labels, dtype=np.float32)
            )
            targets[:, 2:] = torch.as_tensor(np.asarray(bboxes, dtype=np.float32))
        return image, targets, self.img_files[index]

    def __getitem__(self, index):
        """Return ``(image, targets, path)`` for ``index``.

        Samples that fail to load or that have no valid box are skipped and
        the following index is tried instead, so training is not interrupted
        by a few bad files. Each sample is tried at most once per call.

        Raises:
            IndexError: If the dataset contains no images.
            RuntimeError: If no sample in the dataset could be loaded.
        """
        total = len(self)
        if total == 0:
            raise IndexError("YOLODataset contains no images")

        last_error = None
        for offset in range(total):
            idx = (index + offset) % total
            try:
                sample = self._load_sample(idx)
            except Exception as exc:  # best-effort: skip the broken sample
                logger.warning(
                    "Skipping sample %d (%s): %s", idx, self.img_files[idx], exc
                )
                last_error = exc
                continue
            if sample is not None:
                return sample

        raise RuntimeError(
            "No loadable sample with valid boxes found in dataset"
        ) from last_error

    def _get_label_path(self, img_path):
        """Return the ``.txt`` path in ``label_dir`` matching ``img_path``."""
        stem = os.path.basename(img_path).rsplit('.', 1)[0]
        return os.path.join(self.label_dir, stem + ".txt")

    @staticmethod
    def collate_fn(batch):
        """Stack a list of samples into a batch.

        Writes each sample's position in the batch into column 0 of its
        target rows (in place) and concatenates all rows.

        Returns:
            ``(imgs, targets, paths)`` where ``imgs`` is the stacked image
            tensor, ``targets`` is ``(total_boxes, 6)`` and ``paths`` is a
            tuple of file paths.
        """
        imgs, targets, paths = zip(*batch)
        imgs = torch.stack(imgs, 0)

        indexed = []
        for i, t in enumerate(targets):
            if t.shape[0] > 0:
                t[:, 0] = i
                indexed.append(t)

        if indexed:
            targets = torch.cat(indexed, 0)
        else:
            targets = torch.zeros((0, 6))
        return imgs, targets, paths