288 lines
12 KiB
Python
288 lines
12 KiB
Python
import glob
import logging
import os
import random

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
|
||
class YOLODataset(Dataset):
    """YOLO-format detection dataset with Mosaic and Albumentations augmentation.

    Images are collected from ``img_dir`` (``*.jpg``, ``*.png``, ``*.jpeg``) and
    each image's labels are read from ``label_dir/<image stem>.txt``, one
    ``cls cx cy w h`` row per object with coordinates normalised to [0, 1].

    Each item is ``(image, targets, path)`` where ``image`` is the tensor
    produced by the transform pipeline, ``targets`` is an ``(n, 6)`` float
    tensor laid out as ``[batch_idx, cls, cx, cy, w, h]`` (``batch_idx`` is
    filled in by :meth:`collate_fn`) and ``path`` is the source image path.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, img_dir, label_dir, img_size=640, is_train=True, mosaic_prob=1.0):
        """Index the image files and build the augmentation pipeline.

        Args:
            img_dir: Directory containing the images.
            label_dir: Directory containing the YOLO ``.txt`` label files.
            img_size: Side length of the square network input.
            is_train: Build the training pipeline (random augmentation plus
                Mosaic) when true, otherwise the letterbox validation pipeline.
            mosaic_prob: Probability of using Mosaic for a training sample
                while Mosaic is enabled. The default of 1.0 always uses it.
        """
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.is_train = is_train
        self.use_mosaic = is_train  # Mosaic is on by default for training only.
        self.mosaic_prob = mosaic_prob

        # Several image extensions are supported.
        self.img_files = sorted(
            glob.glob(os.path.join(img_dir, "*.jpg"))
            + glob.glob(os.path.join(img_dir, "*.png"))
            + glob.glob(os.path.join(img_dir, "*.jpeg"))
        )

        # --- Albumentations settings mirroring Ultralytics default.yaml ---
        # default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4
        # OpenCV hue range is [0, 179]; saturation/value range is [0, 255].
        h_limit = int(0.015 * 179)  # ~2
        s_limit = int(0.7 * 255)  # ~178
        v_limit = int(0.4 * 255)  # ~102

        # default.yaml: translate: 0.1, scale: 0.5 (0.5~1.5), degrees: 0.0
        if is_train:
            self.transform = A.Compose(
                [
                    # Geometric jitter applied after Mosaic (or to the plain
                    # image when Mosaic is off). Mosaic already yields a large
                    # canvas, so this is only the final random perturbation.
                    A.Affine(
                        scale=(0.5, 1.5),  # scale: 0.5
                        # translate: 0.1 means a random shift within +/-10%.
                        # A (0.1, 0.1) range would always shift by exactly
                        # +10% in the same direction.
                        translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
                        rotate=(0, 0),  # degrees: 0.0 (COCO default)
                        shear=(0, 0),  # shear: 0.0
                        p=0.5,
                    ),
                    # Colour augmentation (limits derived from default.yaml).
                    A.HueSaturationValue(
                        hue_shift_limit=h_limit,
                        sat_shift_limit=s_limit,
                        val_shift_limit=v_limit,
                        p=0.5,
                    ),
                    A.Blur(p=0.01),
                    A.MedianBlur(p=0.01),
                    A.ToGray(p=0.01),
                    A.CLAHE(p=0.01),
                    # Flip
                    A.HorizontalFlip(p=0.5),  # fliplr: 0.5
                    # Final processing: fixed output size, scale to [0, 1].
                    A.Resize(img_size, img_size),
                    A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                    ToTensorV2(),
                ],
                bbox_params=A.BboxParams(
                    format='yolo', min_visibility=0.0, label_fields=['class_labels']
                ),
            )
        else:
            # Validation: letterbox (resize keeping aspect ratio, then pad).
            self.transform = A.Compose(
                [
                    A.LongestMaxSize(max_size=img_size),
                    A.PadIfNeeded(
                        min_height=img_size,
                        min_width=img_size,
                        border_mode=cv2.BORDER_CONSTANT,
                    ),
                    A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                    ToTensorV2(),
                ],
                bbox_params=A.BboxParams(
                    format='yolo', min_visibility=0.1, label_fields=['class_labels']
                ),
            )

    def close_mosaic(self):
        """Disable Mosaic augmentation for all subsequent samples."""
        self.use_mosaic = False
        print("Mosaic augmentation disabled.")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_files)

    def load_image(self, index):
        """Load one image (BGR) and resize its longer side to ``img_size``.

        Returns:
            ``(img, (orig_h, orig_w), (resized_h, resized_w))``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path = self.img_files[index]
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        h, w = img.shape[:2]
        r = self.img_size / max(h, w)
        if r != 1:
            # Clamp to 1 px so extreme aspect ratios never request a 0-sized image.
            new_size = (max(1, int(w * r)), max(1, int(h * r)))
            img = cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)

        return img, (h, w), img.shape[:2]  # img, original_hw, resized_hw

    def load_label(self, index, img_shape):
        """Read the YOLO label file that belongs to image ``index``.

        Args:
            index: Position of the image in ``self.img_files``.
            img_shape: ``(h, w)`` of the image. Unused, because the labels are
                already normalised; kept for interface compatibility.

        Returns:
            A float32 array of shape ``(n, 5)`` with rows ``[cls, cx, cy, w, h]``.
            The array is empty (``(0, 5)``) when the label file is missing or
            has no valid rows. Rows with fewer than five fields are skipped.
        """
        label_path = self._get_label_path(self.img_files[index])
        if not os.path.exists(label_path):
            return np.zeros((0, 5), dtype=np.float32)

        rows = []
        with open(label_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if len(parts) >= 5:
                    # int(float(...)) also accepts class ids written as "0.0".
                    cls = int(float(parts[0]))
                    rows.append([cls, *map(float, parts[1:5])])

        if not rows:
            return np.zeros((0, 5), dtype=np.float32)
        return np.array(rows, dtype=np.float32)

    def load_mosaic(self, index):
        """Build a YOLO Mosaic sample by tiling four images on one canvas.

        Image ``index`` plus three randomly chosen images are pasted around a
        random centre on a ``2 * img_size`` square canvas filled with gray
        (114). Images stay in OpenCV's BGR channel order.

        Returns:
            ``(canvas, labels)`` where ``labels`` is a float32 ``(n, 5)`` array
            of ``[cls, cx, cy, w, h]`` normalised to the canvas. Boxes are
            clipped to the canvas, so some may have zero width or height.
        """
        s = self.img_size
        canvas = 2 * s
        border = -s // 2
        # Mosaic centre (y first, then x), sampled from the middle of the canvas.
        yc, xc = (int(random.uniform(-border, canvas + border)) for _ in range(2))

        # The requested image plus three random extras, in random tile order.
        indices = [index] + [random.randint(0, len(self.img_files) - 1) for _ in range(3)]
        random.shuffle(indices)

        # Canvas of twice the target size.
        img4 = np.full((canvas, canvas, 3), 114, dtype=np.uint8)
        tile_boxes = []

        for i, idx in enumerate(indices):
            img, _, (h, w) = self.load_image(idx)

            # "a" coordinates address the canvas, "b" coordinates the tile.
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, canvas), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, canvas)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            else:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, canvas), min(yc + h, canvas)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            # Paste the visible part of the tile.
            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]
            # Offset from tile coordinates to canvas coordinates.
            padw = x1a - x1b
            padh = y1a - y1b

            labels = self.load_label(idx, (h, w))
            if labels.size > 0:
                # Normalised xywh -> pixel xywh on the resized tile.
                cx = labels[:, 1] * w
                cy = labels[:, 2] * h
                bw = labels[:, 3] * w
                bh = labels[:, 4] * h

                # Pixel xywh -> pixel xyxy on the canvas.
                xyxy = np.copy(labels)
                xyxy[:, 1] = cx - bw / 2 + padw
                xyxy[:, 2] = cy - bh / 2 + padh
                xyxy[:, 3] = cx + bw / 2 + padw
                xyxy[:, 4] = cy + bh / 2 + padh
                tile_boxes.append(xyxy)

        if not tile_boxes:
            return img4, np.zeros((0, 5), dtype=np.float32)

        labels4 = np.concatenate(tile_boxes, 0)
        # Clip to the canvas border.
        np.clip(labels4[:, 1:], 0, canvas, out=labels4[:, 1:])

        # Back to normalised xywh relative to the canvas, which is the format
        # the Albumentations 'yolo' bbox pipeline expects.
        new_labels = np.copy(labels4)
        new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / canvas
        new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / canvas
        new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / canvas
        new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / canvas
        return img4, new_labels

    @staticmethod
    def _sanitize_boxes(labels, img_h, img_w):
        """Clip normalised boxes to the image and drop boxes of 2 px or less.

        Args:
            labels: ``(n, 5)`` array of ``[cls, cx, cy, w, h]`` (normalised).
            img_h: Image height in pixels.
            img_w: Image width in pixels.

        Returns:
            ``(bboxes, class_labels)``: a list of ``[cx, cy, w, h]`` lists and
            the matching list of class ids, for the boxes that survive.
        """
        if len(labels) == 0:
            return [], []

        boxes = np.asarray(labels, dtype=np.float64)
        half_w = boxes[:, 3] / 2
        half_h = boxes[:, 4] / 2

        # 1. Clip the corners to [0, 1] (Mosaic cropping can push boxes out).
        x1 = np.clip(boxes[:, 1] - half_w, 0.0, 1.0)
        y1 = np.clip(boxes[:, 2] - half_h, 0.0, 1.0)
        x2 = np.clip(boxes[:, 1] + half_w, 0.0, 1.0)
        y2 = np.clip(boxes[:, 2] + half_h, 0.0, 1.0)

        # 2. Recompute size and centre from the clipped corners.
        new_w = x2 - x1
        new_h = y2 - y1
        cleaned = np.stack([x1 + new_w / 2, y1 + new_h / 2, new_w, new_h], axis=1)

        # 3. Keep only boxes larger than 2 px in both dimensions.
        keep = (new_w * img_w > 2) & (new_h * img_h > 2)
        return cleaned[keep].tolist(), boxes[keep, 0].tolist()

    def _load_sample(self, index):
        """Load, clean and augment sample ``index``.

        Returns:
            ``(image, targets, path)``, or ``None`` when the sample has no
            valid box left after cleaning. Errors propagate to the caller.
        """
        if self.is_train:
            if self.use_mosaic and random.random() < self.mosaic_prob:
                img, labels = self.load_mosaic(index)
            else:
                img, _, resized_hw = self.load_image(index)
                labels = self.load_label(index, resized_hw)
        else:
            path = self.img_files[index]
            img = cv2.imread(path)
            if img is None:
                raise FileNotFoundError(f"Image not found: {path}")
            labels = self.load_label(index, img.shape[:2])

        # Every branch yields BGR from OpenCV; the pipeline works on RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img_h, img_w = img.shape[:2]
        bboxes, class_labels = self._sanitize_boxes(labels, img_h, img_w)
        if not bboxes:
            return None

        # Apply the augmentation pipeline.
        transformed = self.transform(image=img, bboxes=bboxes, class_labels=class_labels)
        image = transformed['image']
        bboxes = transformed['bboxes']
        class_labels = transformed['class_labels']

        # Target tensor: column 0 is reserved for the batch index.
        n = len(bboxes)
        targets = torch.zeros((n, 6))
        if n > 0:
            targets[:, 1] = torch.as_tensor(np.asarray(class_labels, dtype=np.float32))
            targets[:, 2:] = torch.as_tensor(np.asarray(bboxes, dtype=np.float32))

        return image, targets, self.img_files[index]

    def __getitem__(self, index):
        """Return ``(image, targets, path)`` for ``index``.

        A sample that fails to load, or that has no valid box after cleaning,
        is skipped and the next index (wrapping around) is tried instead, so a
        single bad file does not abort training. The returned ``path`` is that
        of the sample actually loaded. At most ``len(self)`` candidates are
        tried.

        Raises:
            IndexError: If the dataset is empty.
            RuntimeError: If no candidate yields a usable sample.
        """
        n_files = len(self)
        if n_files == 0:
            raise IndexError("YOLODataset is empty")

        for offset in range(n_files):
            idx = (index + offset) % n_files
            try:
                sample = self._load_sample(idx)
            except Exception as exc:
                # Deliberate best effort: report the failure and move on.
                self._logger.warning(
                    "Skipping sample %d (%s): %s", idx, self.img_files[idx], exc
                )
                continue
            if sample is not None:
                return sample

        raise RuntimeError(
            f"No loadable sample with valid boxes among {n_files} files "
            f"(started at index {index})"
        )

    def _get_label_path(self, img_path):
        """Return the label path for ``img_path``: ``label_dir/<stem>.txt``."""
        stem = os.path.splitext(os.path.basename(img_path))[0]
        return os.path.join(self.label_dir, stem + ".txt")

    @staticmethod
    def collate_fn(batch):
        """Collate ``(image, targets, path)`` samples into one batch.

        Images are stacked along a new first dimension. Each non-empty target
        tensor has column 0 set, in place, to its sample's position in the
        batch, and all targets are concatenated into one ``(N, 6)`` tensor
        (``(0, 6)`` when no sample has targets).

        Returns:
            ``(imgs, targets, paths)`` where ``paths`` is a tuple.
        """
        imgs, targets, paths = zip(*batch)
        imgs = torch.stack(imgs, 0)

        indexed = []
        for i, t in enumerate(targets):
            if t.shape[0] > 0:
                t[:, 0] = i
                indexed.append(t)

        targets = torch.cat(indexed, 0) if indexed else torch.zeros((0, 6))
        return imgs, targets, paths