First commit of the YOLO project
dataset.py · 288 lines · Normal file
@@ -0,0 +1,288 @@
import os
import glob
import cv2
import numpy as np
import torch
import random
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

class YOLODataset(Dataset):
    def __init__(self, img_dir, label_dir, img_size=640, is_train=True):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.is_train = is_train
        self.use_mosaic = is_train  # mosaic flag; enabled by default during training

        # Support multiple image formats
        self.img_files = sorted(
            glob.glob(os.path.join(img_dir, "*.jpg")) +
            glob.glob(os.path.join(img_dir, "*.png")) +
            glob.glob(os.path.join(img_dir, "*.jpeg"))
        )

        # --- 1. Align the Albumentations config with Ultralytics ---
        # default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4
        # OpenCV hue range is [0, 179], saturation/value are [0, 255]
        h_limit = int(0.015 * 179)  # ~2
        s_limit = int(0.7 * 255)    # ~178
        v_limit = int(0.4 * 255)    # ~102

        # default.yaml: translate: 0.1, scale: 0.5 (0.5~1.5), degrees: 0.0
        if is_train:
            self.transform = A.Compose([
                # Geometric augmentation (final jitter after Mosaic, or for the non-Mosaic path)
                # Note: the Mosaic output is already a large image; this step only adds the last random perturbation
                A.Affine(
                    scale=(0.5, 1.5),               # scale: 0.5
                    translate_percent=(-0.1, 0.1),  # translate: 0.1 (sampled within +/-10%)
                    rotate=(0.0, 0.0),              # degrees: 0.0 (COCO default)
                    shear=(0.0, 0.0),               # shear: 0.0
                    p=0.5
                ),

                # Color augmentation (strictly aligned with default.yaml)
                A.HueSaturationValue(
                    hue_shift_limit=h_limit,
                    sat_shift_limit=s_limit,
                    val_shift_limit=v_limit,
                    p=0.5
                ),

                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),

                # Flip
                A.HorizontalFlip(p=0.5),  # fliplr: 0.5

                # Final processing
                A.Resize(img_size, img_size),  # ensure a consistent final size
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.0, label_fields=['class_labels']))
        else:
            # Validation: letterbox (keep aspect ratio, pad to square)
            self.transform = A.Compose([
                A.LongestMaxSize(max_size=img_size),
                A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT),
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.1, label_fields=['class_labels']))
    def close_mosaic(self):
        """Disable Mosaic augmentation."""
        self.use_mosaic = False
        print("Mosaic augmentation disabled.")

    def __len__(self):
        return len(self.img_files)
    def load_image(self, index):
        """Load a single image and resize its longer side to img_size."""
        img_path = self.img_files[index]
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        h, w = img.shape[:2]
        r = self.img_size / max(h, w)
        if r != 1:
            img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)

        return img, (h, w), img.shape[:2]  # img, original_hw, resized_hw
    def load_label(self, index, img_shape):
        """Load YOLO-format labels (already normalized: cls, x, y, w, h)."""
        img_path = self.img_files[index]
        label_path = self._get_label_path(img_path)
        h, w = img_shape

        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        cls = int(parts[0])
                        bx, by, bw, bh = map(float, parts[1:5])
                        labels.append([cls, bx, by, bw, bh])

        return np.array(labels, dtype=np.float32) if labels else np.zeros((0, 5), dtype=np.float32)
    def load_mosaic(self, index):
        """YOLO Mosaic augmentation: stitch 4 images into one."""
        labels4 = []
        s = self.img_size
        # The comprehension iterates over two border values, yielding both yc and xc
        yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in [-s // 2, -s // 2]]

        # Pick 3 extra random indices
        indices = [index] + [random.randint(0, len(self.img_files) - 1) for _ in range(3)]
        random.shuffle(indices)

        # Initialize the mosaic canvas (2x size, gray value 114)
        img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8)

        for i, idx in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(idx)

            # Placement: top-left, top-right, bottom-left, bottom-right
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            # Paste the tile into the canvas
            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # type: ignore
            padw = x1a - x1b  # type: ignore
            padh = y1a - y1b  # type: ignore

            # Process labels
            labels = self.load_label(idx, (h, w))
            if labels.size > 0:
                # Normalized xywh -> pixel xywh
                labels[:, 1] = labels[:, 1] * w
                labels[:, 2] = labels[:, 2] * h
                labels[:, 3] = labels[:, 3] * w
                labels[:, 4] = labels[:, 4] * h

                # xywh -> xyxy (pixel), shifted by the paste offset
                xyxy = np.copy(labels)
                xyxy[:, 1] = labels[:, 1] - labels[:, 3] / 2 + padw
                xyxy[:, 2] = labels[:, 2] - labels[:, 4] / 2 + padh
                xyxy[:, 3] = labels[:, 1] + labels[:, 3] / 2 + padw
                xyxy[:, 4] = labels[:, 2] + labels[:, 4] / 2 + padh

                labels4.append(xyxy)

        # Concatenate labels from all 4 tiles
        if len(labels4):
            labels4 = np.concatenate(labels4, 0)
            # Clip to the mosaic image border
            np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])

            # Convert back to normalized xywh (relative to the 2*s mosaic)
            # Albumentations expects normalized xywh
            new_labels = np.copy(labels4)
            w_mosaic, h_mosaic = s * 2, s * 2

            # xyxy -> xywh
            new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / w_mosaic
            new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / h_mosaic
            new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / w_mosaic
            new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / h_mosaic

            return img4, new_labels
        else:
            return img4, np.zeros((0, 5))
    def __getitem__(self, index):
        try:
            if self.is_train:
                # Check both is_train and use_mosaic
                if self.use_mosaic and random.random() < 1.0:  # mosaic probability kept at 1.0
                    img, labels = self.load_mosaic(index)
                else:
                    img, _, _ = self.load_image(index)
                    h, w = img.shape[:2]
                    labels = self.load_label(index, (h, w))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB for both branches
            else:
                img = cv2.imread(self.img_files[index])
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # type: ignore
                h, w = img.shape[:2]
                labels = self.load_label(index, (h, w))

            # --- Sanitize bounding boxes ---
            # labels format: [cls, x, y, w, h] (normalized)
            valid_bboxes = []
            valid_labels = []

            h_img, w_img = img.shape[:2]

            for i in range(len(labels)):
                cls = labels[i, 0]
                x, y, w, h = labels[i, 1:]

                # 1. Clip to [0, 1] (handles out-of-range boxes produced by Mosaic cropping)
                x1 = np.clip(x - w / 2, 0, 1)
                y1 = np.clip(y - h / 2, 0, 1)
                x2 = np.clip(x + w / 2, 0, 1)
                y2 = np.clip(y + h / 2, 0, 1)

                # 2. Recompute width/height
                w_new = x2 - x1
                h_new = y2 - y1

                # 3. Recompute the center
                x_new = x1 + w_new / 2
                y_new = y1 + h_new / 2

                # 4. Drop tiny boxes (e.g. smaller than 2 pixels)
                if w_new * w_img > 2 and h_new * h_img > 2:
                    valid_bboxes.append([x_new, y_new, w_new, h_new])
                    valid_labels.append(cls)

            if len(valid_bboxes) == 0:
                # If every box in this image was filtered out, try the next image
                return self.__getitem__((index + 1) % len(self))

            bboxes = valid_bboxes
            class_labels = valid_labels
            # --- End of bbox sanitation ---

            # Apply augmentation
            transformed = self.transform(image=img, bboxes=bboxes, class_labels=class_labels)
            image = transformed['image']
            bboxes = transformed['bboxes']
            class_labels = transformed['class_labels']

            # Build the target tensor: [batch_idx, cls, x, y, w, h]
            n = len(bboxes)
            targets = torch.zeros((n, 6))
            if n > 0:
                targets[:, 1] = torch.tensor(class_labels)
                targets[:, 2:] = torch.tensor(bboxes)

            return image, targets, self.img_files[index]

        except Exception as e:
            # Print detailed error info for debugging, but don't interrupt training
            # print(f"Error loading data {index}: {e}")
            return self.__getitem__((index + 1) % len(self))
    def _get_label_path(self, img_path):
        filename = os.path.basename(img_path).rsplit('.', 1)[0] + ".txt"
        return os.path.join(self.label_dir, filename)
    @staticmethod
    def collate_fn(batch):
        imgs, targets, paths = zip(*batch)
        imgs = torch.stack(imgs, 0)
        new_targets = []
        for i, t in enumerate(targets):
            if t.shape[0] > 0:
                t[:, 0] = i  # write the batch index into column 0
                new_targets.append(t)
        if new_targets:
            targets = torch.cat(new_targets, 0)
        else:
            targets = torch.zeros((0, 6))
        return imgs, targets, paths
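
# A minimal usage sketch (not part of the original commit): it shows how this dataset is
# typically wired to a torch DataLoader via the custom collate_fn. The directory paths and
# batch size below are hypothetical placeholders, not assumptions about the real project layout.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_set = YOLODataset("data/images/train", "data/labels/train", img_size=640, is_train=True)
    loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=YOLODataset.collate_fn)

    imgs, targets, paths = next(iter(loader))
    # imgs: (4, 3, 640, 640) float tensor; targets: (N, 6) rows of [batch_idx, cls, x, y, w, h]
    print(imgs.shape, targets.shape)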