commit 604951f9c2e8af1913c79b42cee20b04226e9123 Author: lhr Date: Sat Dec 27 02:14:11 2025 +0800 第一次提交Yolo项目 diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..ac3e745 --- /dev/null +++ b/dataset.py @@ -0,0 +1,288 @@ +import os +import glob +import cv2 +import numpy as np +import torch +import random +from torch.utils.data import Dataset +import albumentations as A +from albumentations.pytorch import ToTensorV2 + +class YOLODataset(Dataset): + def __init__(self, img_dir, label_dir, img_size=640, is_train=True): + self.img_dir = img_dir + self.label_dir = label_dir + self.img_size = img_size + self.is_train = is_train + self.use_mosaic = is_train # 新增标志位,默认训练时开启 + + # 支持多种图片格式 + self.img_files = sorted( + glob.glob(os.path.join(img_dir, "*.jpg")) + + glob.glob(os.path.join(img_dir, "*.png")) + + glob.glob(os.path.join(img_dir, "*.jpeg")) + ) + + # --- 1. 对齐 Ultralytics 的 Albumentations 配置 --- + # default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4 + # OpenCV Hue range is [0, 179], Sat/Val is [0, 255] + h_limit = int(0.015 * 179) # ~2 + s_limit = int(0.7 * 255) # ~178 + v_limit = int(0.4 * 255) # ~102 + + # default.yaml: translate: 0.1, scale: 0.5 (0.5~1.5), degrees: 0.0 + + if is_train: + self.transform = A.Compose([ + # 几何增强 (Mosaic 之后再做一次微调,或者处理非 Mosaic 的情况) + # 注意:Mosaic 输出已经是大图,这里主要负责最后的随机扰动 + A.Affine( + scale=(0.5, 1.5), # scale: 0.5 + translate_percent=(0.1, 0.1), # translate: 0.1 + rotate=(-0, 0), # degrees: 0.0 (COCO default) + shear=(-0, 0), # shear: 0.0 + p=0.5 + ), + + # 色彩增强 (严格对齐 default.yaml) + A.HueSaturationValue( + hue_shift_limit=h_limit, + sat_shift_limit=s_limit, + val_shift_limit=v_limit, + p=0.5 + ), + + A.Blur(p=0.01), + A.MedianBlur(p=0.01), + A.ToGray(p=0.01), + A.CLAHE(p=0.01), + + # 翻转 + A.HorizontalFlip(p=0.5), # fliplr: 0.5 + + # 最终处理 + A.Resize(img_size, img_size), # 确保最后尺寸一致 + A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0), + ToTensorV2() + ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.0, label_fields=['class_labels'])) + else: + # 验证集:Letterbox (保持比例填充) + self.transform = A.Compose([ + A.LongestMaxSize(max_size=img_size), + A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT), + A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0), + ToTensorV2() + ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.1, label_fields=['class_labels'])) + + def close_mosaic(self): + """关闭 Mosaic 增强""" + self.use_mosaic = False + print("Mosaic augmentation disabled.") + + def __len__(self): + return len(self.img_files) + + def load_image(self, index): + """加载单张图片并调整长边到 img_size""" + img_path = self.img_files[index] + img = cv2.imread(img_path) + if img is None: + raise FileNotFoundError(f"Image not found: {img_path}") + + h, w = img.shape[:2] + r = self.img_size / max(h, w) + if r != 1: + img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) + + return img, (h, w), img.shape[:2] # img, original_hw, resized_hw + + def load_label(self, index, img_shape): + """加载标签并归一化""" + img_path = self.img_files[index] + label_path = self._get_label_path(img_path) + h, w = img_shape + + labels = [] + if os.path.exists(label_path): + with open(label_path, 'r') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 5: + cls = int(parts[0]) + bx, by, bw, bh = map(float, parts[1:5]) + labels.append([cls, bx, by, bw, bh]) + + return np.array(labels, dtype=np.float32) if labels else np.zeros((0, 5), dtype=np.float32) + + def load_mosaic(self, index): + """ + 实现 YOLO 的 Mosaic 增强 (4张图拼成一张) + """ + labels4 = [] + s = self.img_size + # 修复: 列表推导式需要遍历两次以生成 yc 和 xc + yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in [-s // 2, -s // 2]] + + # 随机选3个额外索引 + indices = [index] + [random.randint(0, len(self.img_files) - 1) for _ in range(3)] + random.shuffle(indices) + + # 初始化大图 (2x size) + img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8) + + for i, idx in enumerate(indices): + # 加载图片 + img, _, (h, w) = self.load_image(idx) + + # 放置位置: top-left, top-right, bottom-left, bottom-right + if i == 0: # top left + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, s * 2) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(yc + h, s * 2) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + # 贴图 + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # type: ignore + padw = x1a - x1b # type: ignore + padh = y1a - y1b # type: ignore + + # 处理标签 + labels = self.load_label(idx, (h, w)) + if labels.size > 0: + # Normalized xywh -> Pixel xywh + labels[:, 1] = labels[:, 1] * w + labels[:, 2] = labels[:, 2] * h + labels[:, 3] = labels[:, 3] * w + labels[:, 4] = labels[:, 4] * h + + # xywh -> xyxy (Pixel) + xyxy = np.copy(labels) + xyxy[:, 1] = labels[:, 1] - labels[:, 3] / 2 + padw + xyxy[:, 2] = labels[:, 2] - labels[:, 4] / 2 + padh + xyxy[:, 3] = labels[:, 1] + labels[:, 3] / 2 + padw + xyxy[:, 4] = labels[:, 2] + labels[:, 4] / 2 + padh + + labels4.append(xyxy) + + # Concat labels + if len(labels4): + labels4 = np.concatenate(labels4, 0) + # Clip to mosaic image border + np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) + + # 转换回 Normalized xywh (相对于 2*s 的大图) + # Albumentations 需要 normalized xywh + new_labels = np.copy(labels4) + w_mosaic, h_mosaic = s * 2, s * 2 + + # xyxy -> xywh + new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / w_mosaic + new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / h_mosaic + new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / w_mosaic + new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / h_mosaic + + return img4, new_labels + else: + return img4, np.zeros((0, 5)) + + def __getitem__(self, index): + try: + if self.is_train: + # 修改判断逻辑:同时检查 is_train 和 use_mosaic + if self.use_mosaic and random.random() < 1.0: + img, labels = self.load_mosaic(index) + else: + img, _, _ = self.load_image(index) + h, w = img.shape[:2] + labels = self.load_label(index, (h, w)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # type: ignore + else: + img = cv2.imread(self.img_files[index]) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # type: ignore + h, w = img.shape[:2] + labels = self.load_label(index, (h, w)) + + # --- 修复开始: 清洗边界框 --- + # labels 格式: [cls, x, y, w, h] (Normalized) + + valid_bboxes = [] + valid_labels = [] + + h_img, w_img = img.shape[:2] + + for i in range(len(labels)): + cls = labels[i, 0] + x, y, w, h = labels[i, 1:] + + # 1. 限制在 [0, 1] 范围内 (处理 Mosaic 裁剪产生的越界) + x1 = np.clip(x - w / 2, 0, 1) + y1 = np.clip(y - h / 2, 0, 1) + x2 = np.clip(x + w / 2, 0, 1) + y2 = np.clip(y + h / 2, 0, 1) + + # 2. 重新计算宽高 + w_new = x2 - x1 + h_new = y2 - y1 + + # 3. 重新计算中心点 + x_new = x1 + w_new / 2 + y_new = y1 + h_new / 2 + + # 4. 过滤掉极小的框 (例如小于 2 个像素) + if w_new * w_img > 2 and h_new * h_img > 2: + valid_bboxes.append([x_new, y_new, w_new, h_new]) + valid_labels.append(cls) + + if len(valid_bboxes) == 0: + # 如果这张图的所有框都被过滤掉了,尝试下一张 + return self.__getitem__((index + 1) % len(self)) + + bboxes = valid_bboxes + class_labels = valid_labels + # --- 修复结束 --- + + # 应用增强 + transformed = self.transform(image=img, bboxes=bboxes, class_labels=class_labels) + image = transformed['image'] + bboxes = transformed['bboxes'] + class_labels = transformed['class_labels'] + + # 构建 Target Tensor + n = len(bboxes) + targets = torch.zeros((n, 6)) + if n > 0: + targets[:, 1] = torch.tensor(class_labels) + targets[:, 2:] = torch.tensor(bboxes) + + return image, targets, self.img_files[index] + + except Exception as e: + # 打印更详细的错误信息以便调试,但不要中断训练 + # print(f"Error loading data {index}: {e}") + return self.__getitem__((index + 1) % len(self)) + + def _get_label_path(self, img_path): + filename = os.path.basename(img_path).rsplit('.', 1)[0] + ".txt" + return os.path.join(self.label_dir, filename) + + @staticmethod + def collate_fn(batch): + imgs, targets, paths = zip(*batch) + imgs = torch.stack(imgs, 0) + new_targets = [] + for i, t in enumerate(targets): + if t.shape[0] > 0: + t[:, 0] = i + new_targets.append(t) + if new_targets: + targets = torch.cat(new_targets, 0) + else: + targets = torch.zeros((0, 6)) + return imgs, targets, paths \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..993db1e --- /dev/null +++ b/inference.py @@ -0,0 +1,189 @@ +import torch +import cv2 +import numpy as np +import torchvision +from yolo11_standalone import YOLO11 + +# COCO 80类 类别名称 +CLASSES = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", + "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush" +] + +# 生成随机颜色用于绘图 +COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3)) + +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): + """ + 将图像缩放并填充到指定大小 (保持纵横比) + """ + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # 计算缩放比例 + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + + # 计算padding + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + dw, dh = dw / 2, dh / 2 # divide padding into 2 sides + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + + # 添加边框 + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) + return im, r, (dw, dh) + +def xywh2xyxy(x): + """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]""" + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x + y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y + return y + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): + """ + 非极大值抑制 (NMS) + prediction: [Batch, 84, Anchors] + """ + # 1. 转置: [Batch, 84, Anchors] -> [Batch, Anchors, 84] + prediction = prediction.transpose(1, 2) + + bs = prediction.shape[0] # batch size + nc = prediction.shape[2] - 4 # number of classes + + # 修复: 使用 max(-1) 在最后一个维度(类别)上寻找最大置信度 + # 之前的 max(1) 错误地在 Anchors 维度上操作了 + xc = prediction[..., 4:].max(-1)[0] > conf_thres # candidates + + output = [torch.zeros((0, 6), device=prediction.device)] * bs + + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence filtering + + if not x.shape[0]: + continue + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Confidence and Class + conf, j = x[:, 4:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Check shape + n = x.shape[0] + if not n: + continue + elif n > max_det: + x = x[x[:, 4].argsort(descending=True)[:max_det]] + + # Batched NMS + c = x[:, 5:6] * 7680 # classes + boxes, scores = x[:, :4] + c, x[:, 4] + i = torchvision.ops.nms(boxes, scores, iou_thres) + output[xi] = x[i] + + return output + +def main(): + # 1. 初始化模型 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + model = YOLO11(nc=80, scale='s') + # 加载你之前转换好的纯净权重 + model.load_weights("yolo11s.pth") + model.to(device) + model.eval() + # model.train() + + # 2. 读取图片 + img_path = "1.jpg" # 请替换为你本地的图片路径 + + img0 = cv2.imread(img_path) + assert img0 is not None, f"Image Not Found {img_path}" + + # 3. 预处理 + # Letterbox resize + img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640)) + + # BGR to RGB, HWC to CHW + img = img[:, :, ::-1].transpose(2, 0, 1) + img = np.ascontiguousarray(img) + + img_tensor = torch.from_numpy(img).to(device) + img_tensor = img_tensor.float() + img_tensor /= 255.0 # 0 - 255 to 0.0 - 1.0 + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) + + # 4. 推理 + print("开始推理...") + with torch.no_grad(): + pred = model(img_tensor) + + # 5. 后处理 (NMS) + pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45) + + # 6. 绘制结果 + det = pred[0] # 仅处理第一张图片 + + if len(det): + # 将坐标映射回原图尺寸 + # det[:, :4] 是 x1, y1, x2, y2 + det[:, [0, 2]] -= dw # x padding + det[:, [1, 3]] -= dh # y padding + det[:, :4] /= ratio + + # 裁剪坐标防止越界 + det[:, 0].clamp_(0, img0.shape[1]) + det[:, 1].clamp_(0, img0.shape[0]) + det[:, 2].clamp_(0, img0.shape[1]) + det[:, 3].clamp_(0, img0.shape[0]) + + print(f"检测到 {len(det)} 个目标") + + for *xyxy, conf, cls in det: + c = int(cls) + label = f'{CLASSES[c]} {conf:.2f}' + p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])) + + # 画框 + color = COLORS[c] + cv2.rectangle(img0, p1, p2, color, 2, lineType=cv2.LINE_AA) + + # 画标签背景 + t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0] + p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3 + cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA) + + # 画文字 + cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA) + + print(f" - {label} at {p1}-{p2}") + + # 7. 显示/保存结果 + cv2.imwrite("result.jpg", img0) + print("结果已保存至 result.jpg") + +def import_os_exists(path): + import os + return os.path.exists(path) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inference_yoloe.py b/inference_yoloe.py new file mode 100644 index 0000000..49411db --- /dev/null +++ b/inference_yoloe.py @@ -0,0 +1,168 @@ +import torch +import cv2 +import numpy as np +import torchvision +from pathlib import Path + +# 导入你的模块 +from yolo11_standalone import YOLO11E +from mobile_clip_standalone import MobileCLIP + +from ultralytics import YOLOE + + +# --- 配置 --- +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +YOLO_WEIGHTS = "yoloe-11l-seg.pth" # 替换为你的 YOLO 权重路径 +CLIP_WEIGHTS = "mobileclip_blt.ts" # 替换为你的 MobileCLIP 权重路径 +CLIP_SIZE = "blt" # 对应 MobileCLIP 的 size +IMAGE_PATH = "1.jpg" # 待检测图片 + +# 自定义检测类别 (Open Vocabulary) +CUSTOM_CLASSES = ["girl", "red balloon"] + +# 绘图颜色 +COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3)) + +# --- 辅助函数 (Letterbox, NMS 等) --- +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)): + shape = im.shape[:2] + if isinstance(new_shape, int): new_shape = (new_shape, new_shape) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + dw, dh = dw / 2, dh / 2 + if shape[::-1] != new_unpad: + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) + return im, r, (dw, dh) + +def xywh2xyxy(x): + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 + y[..., 1] = x[..., 1] - x[..., 3] / 2 + y[..., 2] = x[..., 0] + x[..., 2] / 2 + y[..., 3] = x[..., 1] + x[..., 3] / 2 + return y + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300): + prediction = prediction.transpose(1, 2) + bs = prediction.shape[0] + xc = prediction[..., 4:].max(-1)[0] > conf_thres + output = [torch.zeros((0, 6), device=prediction.device)] * bs + for xi, x in enumerate(prediction): + x = x[xc[xi]] + if not x.shape[0]: continue + box = xywh2xyxy(x[:, :4]) + conf, j = x[:, 4:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + n = x.shape[0] + if not n: continue + elif n > max_det: x = x[x[:, 4].argsort(descending=True)[:max_det]] + c = x[:, 5:6] * 7680 + boxes, scores = x[:, :4] + c, x[:, 4] + i = torchvision.ops.nms(boxes, scores, iou_thres) + output[xi] = x[i] + return output + +def main(): + print(f"Using device: {DEVICE}") + + # 1. 加载 MobileCLIP 文本编码器 + print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...") + if not Path(CLIP_WEIGHTS).exists(): + raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}") + + clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE) + + # 2. 生成文本 Embeddings + print(f"Encoding classes: {CUSTOM_CLASSES}") + prompts = [f"{c}" for c in CUSTOM_CLASSES] + + tokens = clip_model.tokenize(prompts) + text_embeddings = clip_model.encode_text(tokens) # Shape: (N, 512) + + # 调整维度为 (1, N, 512) 以匹配 YOLO11E 输入 + text_embeddings = text_embeddings.unsqueeze(0) + + # 3. 加载 YOLO11E 检测模型 + print(f"Loading YOLO11E from {YOLO_WEIGHTS}...") + if not Path(YOLO_WEIGHTS).exists(): + raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}") + + # 注意:scale='l' 必须与你的权重文件匹配 (s, m, l, x) + yolo_model = YOLO11E(nc=80, scale='l') + yolo_model.load_weights(YOLO_WEIGHTS) + yolo_model.to(DEVICE) # 使用半精度to(DEVICE) + yolo_model.eval() + + head = yolo_model.model[-1] + + with torch.no_grad(): + text_pe = head.get_tpe(text_embeddings) # type: ignore + + yolo_model.set_classes(CUSTOM_CLASSES, text_pe) + + # 5. 图像预处理 + img0 = cv2.imread(IMAGE_PATH) + assert img0 is not None, f"Image Not Found {IMAGE_PATH}" + img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640)) + img = img[:, :, ::-1].transpose(2, 0, 1) + img = np.ascontiguousarray(img) + img_tensor = torch.from_numpy(img).to(DEVICE) + img_tensor = img_tensor.float() + img_tensor /= 255.0 + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) + + # 6. 推理 + print("Running inference...") + with torch.no_grad(): + pred = yolo_model(img_tensor) + if isinstance(pred, tuple): + pred = pred[0] + nc = len(CUSTOM_CLASSES) + pred = pred[:, :4+nc, :] + + # 7. 后处理 (NMS) + pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7) + + print(pred) + + # 8. 可视化 + det = pred[0] + if len(det): + det[:, [0, 2]] -= dw + det[:, [1, 3]] -= dh + det[:, :4] /= ratio + det[:, 0].clamp_(0, img0.shape[1]) + det[:, 1].clamp_(0, img0.shape[0]) + det[:, 2].clamp_(0, img0.shape[1]) + det[:, 3].clamp_(0, img0.shape[0]) + + print(f"Detected {len(det)} objects:") + for *xyxy, conf, cls in det: + c = int(cls) + class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c) + label = f'{class_name} {conf:.2f}' + + p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])) + color = COLORS[c] + + cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA) + t_size = cv2.getTextSize(label, 0, 0.5, 1)[0] + p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3 + cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA) + cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA) + print(f" - {label}") + else: + print("No objects detected.") + + output_path = "result_full.jpg" + cv2.imwrite(output_path, img0) + print(f"Result saved to {output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..22d9973 --- /dev/null +++ b/loss.py @@ -0,0 +1,448 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import numpy as np + +# ============================================================================== +# 1. 基础工具函数 (Utils) +# ============================================================================== + +def xywh2xyxy(x): + """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.""" + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x + y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y + return y + +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): + """ + Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4). + Args: + box1 (torch.Tensor): predicted, shape (N, 4) + box2 (torch.Tensor): target, shape (N, 4) + xywh (bool): If True, input boxes are (x, y, w, h). If False, (x1, y1, x2, y2). + CIoU (bool): If True, calculate Complete IoU. + """ + # Get the coordinates of bounding boxes + if xywh: # transform from xywh to xyxy + (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1) + w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 + b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ + b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ + else: # x1, y1, x2, y2 + b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1) + b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1) + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # Intersection area + inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ + (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) + + # Union Area + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union + if CIoU or DIoU or GIoU: + cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width + ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - (rho2 / c2 + v * alpha) # CIoU + return iou - rho2 / c2 # DIoU + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU + return iou + +def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + h, w = feats[i].shape[2:] + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset + sy, sx = torch.meshgrid(sy, sx, indexing="ij") + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + +def dist2bbox(distance, anchor_points, xywh=True, dim=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = distance.chunk(2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return torch.cat([c_xy, wh], dim) + return torch.cat((x1y1, x2y2), dim) + +def bbox2dist(anchor_points, bbox, reg_max): + """Transform bbox(xyxy) to dist(ltrb).""" + x1y1, x2y2 = bbox.chunk(2, -1) + return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) + + +class TaskAlignedAssigner(nn.Module): + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + super().__init__() + self.topk = topk + self.num_classes = num_classes + self.alpha = alpha + self.beta = beta + self.eps = eps + + @torch.no_grad() + def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + self.bs = pd_scores.shape[0] + self.n_max_boxes = gt_bboxes.shape[1] + + if self.n_max_boxes == 0: + return ( + torch.full_like(pd_scores[..., 0], self.num_classes), + torch.zeros_like(pd_bboxes), + torch.zeros_like(pd_scores), + torch.zeros_like(pd_scores[..., 0]), + torch.zeros_like(pd_scores[..., 0]), + ) + + mask_pos, align_metric, overlaps = self.get_pos_mask( + pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt + ) + + target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) + + target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) + + # Normalize + align_metric *= mask_pos + pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) + pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) + norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) + target_scores = target_scores * norm_align_metric + + return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx + + def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): + mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes) + align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt) + mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool()) + mask_pos = mask_topk * mask_in_gts * mask_gt + return mask_pos, align_metric, overlaps + + def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt): + na = pd_bboxes.shape[-2] + mask_gt = mask_gt.bool() + overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device) + bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device) + + ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) + ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) + ind[1] = gt_labels.squeeze(-1) + + # Clamp labels to handle case where no GT exists or background index issues + ind[1] = ind[1].clamp(max=self.num_classes - 1) + + bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] + + pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt] + gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt] + overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0) + + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + return align_metric, overlaps + + def select_topk_candidates(self, metrics, topk_mask=None): + topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True) + if topk_mask is None: + topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs) + topk_idxs.masked_fill_(~topk_mask, 0) + + count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device) + ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) + for k in range(self.topk): + count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) + count_tensor.masked_fill_(count_tensor > 1, 0) + return count_tensor.to(metrics.dtype) + + def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): + batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] + target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes + target_labels = gt_labels.long().flatten()[target_gt_idx] + + target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx] + target_labels.clamp_(0) + + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.num_classes), + dtype=torch.int64, + device=target_labels.device, + ) + target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) + + fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) + target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) + + return target_labels, target_bboxes, target_scores + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): + n_anchors = xy_centers.shape[0] + bs, n_boxes, _ = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + return bbox_deltas.amin(3).gt_(eps) + + @staticmethod + def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) + max_overlaps_idx = overlaps.argmax(1) + + is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) + is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) + + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() + fg_mask = mask_pos.sum(-2) + target_gt_idx = mask_pos.argmax(-2) + return target_gt_idx, fg_mask, mask_pos + +# ============================================================================== +# 3. Loss 模块 (Loss Modules - from loss.py) +# ============================================================================== + +class DFLoss(nn.Module): + """Distribution Focal Loss (DFL).""" + def __init__(self, reg_max: int = 16) -> None: + super().__init__() + self.reg_max = reg_max + + def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + target = target.clamp_(0, self.reg_max - 1 - 0.01) + tl = target.long() # target left + tr = tl + 1 # target right + wl = tr - target # weight left + wr = 1 - wl # weight right + return ( + F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr + ).mean(-1, keepdim=True) + +class BboxLoss(nn.Module): + """Criterion for computing training losses for bounding boxes (IoU + DFL).""" + def __init__(self, reg_max: int = 16): + super().__init__() + self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None + + def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) + iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) + loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum + + if self.dfl_loss: + target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1) + loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight + loss_dfl = loss_dfl.sum() / target_scores_sum + else: + loss_dfl = torch.tensor(0.0).to(pred_dist.device) + + return loss_iou, loss_dfl + +class YOLOv8DetectionLoss: + """ + 独立版的 YOLOv8/v11 Detection Loss。 + Refactored to assume specific inputs and remove reliance on full Model object. + """ + def __init__(self, nc=80, reg_max=16, stride=None, hyp=None): + """ + Args: + nc (int): Number of classes. + reg_max (int): DFL channels (default 16). + stride (list/torch.Tensor): Model strides (e.g., [8, 16, 32]). + hyp (dict): Hyperparameters (box, cls, dfl gains). + """ + if stride is None: + stride = [8, 16, 32] # Default strides + if hyp is None: + # Default YOLOv8 hyperparameters + hyp = {'box': 7.5, 'cls': 0.5, 'dfl': 1.5} + + self.nc = nc + self.reg_max = reg_max + self.stride = stride + self.hyp = hyp + self.no = nc + reg_max * 4 # Output channels per anchor + self.use_dfl = reg_max > 1 + + # Components + self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(reg_max) + self.bce = nn.BCEWithLogitsLoss(reduction="none") + self.proj = torch.arange(reg_max, dtype=torch.float) # Will move to device in method + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocesses targets: converts [idx, cls, xywh] to internal format.""" + if targets.shape[0] == 0: + out = torch.zeros(batch_size, 0, 5, device=targets.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), 5, device=targets.device) + for j in range(batch_size): + matches = i == j + n = matches.sum() + if n: + out[j, :n] = targets[matches, 1:] # cls, x, y, w, h + out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) + return out + + def bbox_decode(self, anchor_points, pred_dist): + """Decode predicted object bounding box coordinates from anchor points and distribution.""" + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + # Project distributions to scalars using self.proj (0, 1, 2, ..., reg_max-1) + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.to(pred_dist.device).type(pred_dist.dtype)) + return dist2bbox(pred_dist, anchor_points, xywh=False) + + def __call__(self, preds, batch): + """ + Args: + preds: List of prediction tensors [B, C, H, W] for each stride level. + C should be reg_max * 4 + nc. + batch: Dict containing: + 'batch_idx': Tensor [N], image index for each target + 'cls': Tensor [N], class index for each target + 'bboxes': Tensor [N, 4], normalized xywh format + Returns: + total_loss: scalar tensor (sum of box, cls, dfl), multiplied by batch_size as per original logic if needed, + but here we return mean or sum based on reduction. + WARNING: Ultralytics returns (loss * bs) usually. + loss_items: Tensor used for logging [box, cls, dfl] + """ + # Ensure proj is on correct device + self.device = preds[0].device + loss = torch.zeros(3, device=self.device) # box, cls, dfl + + feats = preds + + # Concat features from different strides: (B, no, Total_Anchors) + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + pred_scores = pred_scores.permute(0, 2, 1).contiguous() # (B, Anchors, nc) + pred_distri = pred_distri.permute(0, 2, 1).contiguous() # (B, Anchors, reg_max * 4) + + dtype = pred_scores.dtype + batch_size = pred_scores.shape[0] + + # Calculate image size from features and first stride (assuming square output implies inputs) + # Ultralytics uses provided imgsz, but we can infer roughly or pass it. + # Here assuming feature map H*stride[0] = img H + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] + + # Generate anchors + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Process Targets + if "batch_idx" not in batch: + # Fallback if simplified batch input passed (targets assumed [N, 6] -> idx, cls, x, y, w, h) + # This handles cases where user passes raw target tensor instead of dict + raise ValueError("Batch dict must contain 'batch_idx', 'cls', and 'bboxes'") + + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + # Check if Normalized (xywh) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + + # Decode Boxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + + # Task Aligned Assignment + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # CLS Loss + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum + + # BBox Loss (IoU + DFL) + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + + # Apply Gains + loss[0] *= self.hyp['box'] + loss[1] *= self.hyp['cls'] + loss[2] *= self.hyp['dfl'] + + return loss.sum() * batch_size, loss.detach() # scaling by batch_size to match Ultralytics behavior + +# ============================================================================== +# Usage Example +# ============================================================================== +if __name__ == "__main__": + # 配置 + num_classes = 80 + reg_max = 16 + strides = [8, 16, 32] + device = "cuda" if torch.cuda.is_available() else "cpu" + + # 初始化 Loss + criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides) + + # 模拟输入 (Batch Size = 2) + bs = 2 + # 模拟模型输出: 3 个层级的特征图 + # [BS, Channels, H, W], Channels = reg_max * 4 + num_classes + channels = reg_max * 4 + num_classes + preds = [ + torch.randn(bs, channels, 80, 80, device=device, requires_grad=True), # stride 8 + torch.randn(bs, channels, 40, 40, device=device, requires_grad=True), # stride 16 + torch.randn(bs, channels, 20, 20, device=device, requires_grad=True), # stride 32 + ] + + # 模拟 Ground Truth + # batch: idx, cls, x, y, w, h (normalized 0-1) + target_batch = { + "batch_idx": torch.tensor([0, 0, 1], device=device), # Image indices + "cls": torch.tensor([1, 5, 10], device=device), # Class indices + "bboxes": torch.tensor([ # Normalized xywh + [0.5, 0.5, 0.2, 0.3], + [0.1, 0.1, 0.1, 0.1], + [0.8, 0.8, 0.2, 0.2] + ], device=device) + } + + # 计算 Loss + total_loss, loss_items = criterion(preds, target_batch) + + print(f"Total Loss: {total_loss.item()}") + print(f"Loss Items (Box, Cls, DFL): {loss_items}") + + # 反向传播测试 + total_loss.backward() + print("Backward pass successful.") \ No newline at end of file diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..00231ef --- /dev/null +++ b/metrics.py @@ -0,0 +1,148 @@ +import torch +import numpy as np +import torchvision + +def xywh2xyxy(x): + """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right""" + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x + y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y + return y + +def box_iou(box1, box2): + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + box1: [N, 4] + box2: [M, 4] + Returns: [N, M] + """ + def box_area(box): + return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + + area1 = box_area(box1) + area2 = box_area(box2) + + lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] + rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + return inter / (union + 1e-6) + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300): + """ + Non-Maximum Suppression (NMS) on inference results + prediction: [Batch, 84, 8400] (for YOLOv8/11) + """ + # [Batch, 84, Anchors] -> [Batch, Anchors, 84] + prediction = prediction.transpose(1, 2) + + bs = prediction.shape[0] + nc = prediction.shape[2] - 4 + xc = prediction[..., 4:].max(-1)[0] > conf_thres + + output = [torch.zeros((0, 6), device=prediction.device)] * bs + + for xi, x in enumerate(prediction): + x = x[xc[xi]] + if not x.shape[0]: + continue + + # Box decoding + box = xywh2xyxy(x[:, :4]) + + # Confidence and Class + conf, j = x[:, 4:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + n = x.shape[0] + if not n: + continue + elif n > max_det: + x = x[x[:, 4].argsort(descending=True)[:max_det]] + + # Batched NMS + c = x[:, 5:6] * 7680 + boxes, scores = x[:, :4] + c, x[:, 4] + i = torchvision.ops.nms(boxes, scores, iou_thres) + output[xi] = x[i] + + return output + +def compute_ap(recall, precision): + """ Compute the average precision, given the recall and precision curves """ + # Append sentinel values to beginning and end + mrec = np.concatenate(([0.0], recall, [1.0])) + mpre = np.concatenate(([1.0], precision, [0.0])) + + # Compute the precision envelope + mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) + + # Integrate area under curve + method = 'interp' # methods: 'continuous', 'interp' + if method == 'interp': + x = np.linspace(0, 1, 101) # 101-point interp (COCO) + ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate + else: # 'continuous' + i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve + + return ap, mpre, mrec + +def smooth(y, f=0.05): + """Box filter of fraction f""" + nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) + p = np.ones(nf // 2) # ones padding + yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded + return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed + +def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16): + """ Compute the average precision, given the recall and precision curves. """ + # Sort by objectness + i = np.argsort(-conf) + tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + + # Find unique classes + unique_classes, nt = np.unique(target_cls, return_counts=True) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + px, py = np.linspace(0, 1, 1000), [] # for plotting + ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + for ci, c in enumerate(unique_classes): + i = pred_cls == c + n_l = (target_cls == c).sum() # number of labels + n_p = i.sum() # number of predictions + + if n_p == 0 or n_l == 0: + continue + + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (n_l + eps) # recall curve + r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) + + # Compute F1 (harmonic mean of precision and recall) + f1 = 2 * p * r / (p + r + eps) + i = smooth(f1.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p[:, i], r[:, i], f1[:, i] + tp = (r * nt).round().astype(int) + fp = (tp / (p + eps) - tp).astype(int) + + return tp, fp, p, r, f1, ap, unique_classes.astype(int) \ No newline at end of file diff --git a/mobile_clip_standalone.py b/mobile_clip_standalone.py new file mode 100644 index 0000000..84945b9 --- /dev/null +++ b/mobile_clip_standalone.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +from pathlib import Path +from typing import List, Union +import mobileclip + +class TextModel(nn.Module): + """TextModel 基类,定义接口""" + def __init__(self): + super().__init__() + + def tokenize(self, texts): + raise NotImplementedError + + def encode_text(self, texts, dtype): + raise NotImplementedError + +class MobileCLIP(TextModel): + """ + MobileCLIP 文本编码器。 + """ + config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"} + + def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None: + """ + 初始化 MobileCLIP 文本编码器。 + + Args: + checkpoint (str): 模型权重文件路径 (.pt 或 .ts). + size (str): 模型大小标识符 ('s0', 's1', 's2', 'b', 'blt'). + device (torch.device): 加载模型的设备. + """ + super().__init__() + + if isinstance(device, str): + device = torch.device(device) + + if not Path(checkpoint).exists(): + raise FileNotFoundError(f"找不到权重文件: {checkpoint}") + + if size not in self.config_size_map: + raise ValueError(f"不支持的大小: {size}. 可选: {list(self.config_size_map.keys())}") + + config = self.config_size_map[size] + + # 1. 加载 Tokenizer + self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}") + + # 2. 加载模型 + if str(checkpoint).endswith('.ts'): + # TorchScript 格式 (例如 mobileclip_blt.ts) + print(f"Loading TorchScript model from {checkpoint}...") + self.model = torch.jit.load(checkpoint, map_location=device) + self.is_torchscript = True + else: + # PyTorch 格式 (.pt) + print(f"Loading PyTorch model from {checkpoint}...") + self.model = mobileclip.create_model_and_transforms( + f"mobileclip_{config}", + pretrained=checkpoint, + device=device + )[0] + self.is_torchscript = False + + self.to(device) + self.device = device + self.eval() + + def tokenize(self, texts: List[str]) -> torch.Tensor: + """ + 将文本转换为 token。 + """ + return self.tokenizer(texts).to(self.device) + + def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + 编码 tokenized 文本并进行归一化。 + """ + with torch.no_grad(): + if self.is_torchscript: + return self.model(texts).to(dtype) + + text_features = self.model.encode_text(texts).to(dtype) # type: ignore + text_features /= text_features.norm(p=2, dim=-1, keepdim=True) + return text_features + +# --- 使用示例 --- +if __name__ == "__main__": + # 1. 设置设备 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # 指定本地模型路径 + checkpoint_path = "mobileclip_blt.ts" + + try: + if Path(checkpoint_path).exists(): + # 2. 初始化模型 (指定本地路径和对应大小) + # 注意:blt 对应 size="blt" + model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device) + + # 3. 准备文本 + input_texts = ["a photo of a cat", "a photo of a dog"] + + # 4. Tokenize + tokens = model.tokenize(input_texts) + print(f"Tokens shape: {tokens.shape}") + + # 5. Encode + features = model.encode_text(tokens) + print(f"Features shape: {features.shape}") + print("运行成功!") + else: + print(f"权重文件不存在: {checkpoint_path}") + + except Exception as e: + print(f"发生错误: {e}") \ No newline at end of file diff --git a/mobileclip/__init__.py b/mobileclip/__init__.py new file mode 100644 index 0000000..956749d --- /dev/null +++ b/mobileclip/__init__.py @@ -0,0 +1,98 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import json +import os +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Resize, + ToTensor, +) + +from mobileclip.clip import CLIP +from mobileclip.modules.common.mobileone import reparameterize_model +from mobileclip.modules.text.tokenizer import ( + ClipTokenizer, +) + + +def create_model_and_transforms( + model_name: str, + pretrained: str | None = None, + reparameterize: bool | None = True, + device: str | torch.device = "cpu", +) -> tuple[nn.Module, Any, Any]: + """Method to instantiate model and pre-processing transforms necessary for inference. + + Args: + model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b'] + pretrained: Location of pretrained checkpoint. + reparameterize: When set to True, re-parameterizable branches get folded for faster inference. + device: Device identifier for model placement. + + Returns: + Tuple of instantiated model, and preprocessing transforms for inference. + """ + # Config files + root_dir = os.path.dirname(os.path.abspath(__file__)) + configs_dir = os.path.join(root_dir, "configs") + model_cfg_file = os.path.join(configs_dir, model_name + ".json") + + # Get config from yaml file + if not os.path.exists(model_cfg_file): + raise ValueError(f"Unsupported model name: {model_name}") + model_cfg = json.load(open(model_cfg_file)) + + # Build preprocessing transforms for inference + resolution = model_cfg["image_cfg"]["image_size"] + resize_size = resolution + centercrop_size = resolution + aug_list = [ + Resize( + resize_size, + interpolation=InterpolationMode.BILINEAR, + ), + CenterCrop(centercrop_size), + ToTensor(), + ] + preprocess = Compose(aug_list) + + # Build model + model = CLIP(cfg=model_cfg) + model.to(device) + model.eval() + + # Load checkpoint + if pretrained is not None: + chkpt = torch.load(pretrained) + model.load_state_dict(chkpt) + + # Reparameterize model for inference (if specified) + if reparameterize: + model = reparameterize_model(model) + + return model, None, preprocess + + +def get_tokenizer(model_name: str) -> nn.Module: + # Config files + root_dir = os.path.dirname(os.path.abspath(__file__)) + configs_dir = os.path.join(root_dir, "configs") + model_cfg_file = os.path.join(configs_dir, model_name + ".json") + + # Get config from yaml file + model_cfg = json.load(open(model_cfg_file)) + + # Build tokenizer + text_tokenizer = ClipTokenizer(model_cfg) + return text_tokenizer diff --git a/mobileclip/clip.py b/mobileclip/clip.py new file mode 100644 index 0000000..518ed38 --- /dev/null +++ b/mobileclip/clip.py @@ -0,0 +1,69 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# +"""Model schema in open_clip format for inference only.""" + +from __future__ import annotations + +import math +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn + +from mobileclip.text_encoder import ( + TextTransformer, +) + +from .image_encoder import MCi + + +class CLIP(nn.Module): + """Base class for multi-modal image-text data.""" + + def __init__(self, cfg: dict, output_dict: bool = False, *args, **kwargs) -> None: + super().__init__() + self.output_dict = output_dict + self.projection_dim = cfg["embed_dim"] + if self.projection_dim is None: + raise ValueError("Please specify `embed_dim` in model config.") + + self.image_encoder = MCi( + model_name=cfg["image_cfg"]["model_name"], + projection_dim=self.projection_dim, + ) + self.text_encoder = TextTransformer(cfg=cfg["text_cfg"], projection_dim=self.projection_dim) + self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07)) + + def _exponentiate_and_clip_logits(self, max_scale: float = 100.0): + scale = self.logit_scale.exp() + scale = torch.clamp(scale, 0, max_scale) + return scale + + def encode_image(self, image: torch.Tensor, normalize: bool = False): + image_encoder_out = self.image_encoder(image) + if isinstance(image_encoder_out, dict): + features = image_encoder_out["logits"] + else: + features = image_encoder_out + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text: torch.Tensor, normalize: bool = False): + text_features = self.text_encoder(text_tokens=text, key_padding_mask=None) + return F.normalize(text_features, dim=-1) if normalize else text_features + + def forward(self, image: torch.Tensor | None = None, text: torch.Tensor | None = None, *args, **kwargs) -> Any: + image_embeddings = self.encode_image(image, normalize=True) if image is not None else None + text_embeddings = self.encode_text(text, normalize=True) if text is not None else None + + if self.output_dict: + return { + "image_features": image_embeddings, + "text_features": text_embeddings, + "logit_scale": self._exponentiate_and_clip_logits(), + } + return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits() diff --git a/mobileclip/configs/mobileclip_b.json b/mobileclip/configs/mobileclip_b.json new file mode 100644 index 0000000..1d7abd9 --- /dev/null +++ b/mobileclip/configs/mobileclip_b.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "image_cfg": { + "image_size": 224, + "model_name": "vit_b16" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "dim": 512, + "ffn_multiplier_per_layer": 4.0, + "n_heads_per_layer": 8, + "n_transformer_layers": 12, + "norm_layer": "layer_norm_fp32", + "causal_masking": true, + "model_name": "base" + } +} diff --git a/mobileclip/configs/mobileclip_s0.json b/mobileclip/configs/mobileclip_s0.json new file mode 100644 index 0000000..ca289a4 --- /dev/null +++ b/mobileclip/configs/mobileclip_s0.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "image_cfg": { + "image_size": 256, + "model_name": "mci0" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "dim": 512, + "ffn_multiplier_per_layer": 4.0, + "n_heads_per_layer": 8, + "n_transformer_layers": 4, + "norm_layer": "layer_norm_fp32", + "causal_masking": false, + "model_name": "mct" + } +} diff --git a/mobileclip/configs/mobileclip_s1.json b/mobileclip/configs/mobileclip_s1.json new file mode 100644 index 0000000..8ea7809 --- /dev/null +++ b/mobileclip/configs/mobileclip_s1.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "image_cfg": { + "image_size": 256, + "model_name": "mci1" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "dim": 512, + "ffn_multiplier_per_layer": 4.0, + "n_heads_per_layer": 8, + "n_transformer_layers": 12, + "norm_layer": "layer_norm_fp32", + "causal_masking": false, + "model_name": "base" + } +} diff --git a/mobileclip/configs/mobileclip_s2.json b/mobileclip/configs/mobileclip_s2.json new file mode 100644 index 0000000..7ad5e7d --- /dev/null +++ b/mobileclip/configs/mobileclip_s2.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "image_cfg": { + "image_size": 256, + "model_name": "mci2" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "dim": 512, + "ffn_multiplier_per_layer": 4.0, + "n_heads_per_layer": 8, + "n_transformer_layers": 12, + "norm_layer": "layer_norm_fp32", + "causal_masking": false, + "model_name": "base" + } +} diff --git a/mobileclip/image_encoder.py b/mobileclip/image_encoder.py new file mode 100644 index 0000000..c957b65 --- /dev/null +++ b/mobileclip/image_encoder.py @@ -0,0 +1,63 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# +from typing import Any + +import torch.nn as nn +from timm.models import create_model + +from mobileclip import models # noqa added to register models +from mobileclip.modules.image.image_projection import GlobalPool2D + + +class MCi(nn.Module): + """This class implements `MCi Models `_.""" + + def __init__(self, model_name: str, *args, **kwargs) -> None: + super().__init__() + self.projection_dim = None + if "projection_dim" in kwargs: + self.projection_dim = kwargs.get("projection_dim") + + # Create model + self.model = create_model(model_name, projection_dim=self.projection_dim) + + # Build out projection head. + if self.projection_dim is not None: + if hasattr(self.model, "head"): + self.model.head = MCi._update_image_classifier( + image_classifier=self.model.head, projection_dim=self.projection_dim + ) + + def forward(self, x: Any, *args, **kwargs) -> Any: + """A forward function of the model.""" + x = self.model(x) + return x + + @staticmethod + def _get_in_feature_dimension(image_classifier: nn.Module) -> int: + """Return the input feature dimension to the image classification head.""" + in_features = None + if isinstance(image_classifier, nn.Sequential): + # Classifier that uses nn.Sequential usually has global pooling and + # multiple linear layers. Find the first linear layer and get its + # in_features + for layer in image_classifier: + if isinstance(layer, nn.Linear): + in_features = layer.in_features + break + elif isinstance(image_classifier, nn.Linear): + in_features = image_classifier.in_features + + if in_features is None: + raise NotImplementedError(f"Cannot get input feature dimension of {image_classifier}.") + return in_features + + @staticmethod + def _update_image_classifier(image_classifier: nn.Module, projection_dim: int, *args, **kwargs) -> nn.Module: + in_features = MCi._get_in_feature_dimension(image_classifier) + new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim) + return new_img_classifier diff --git a/mobileclip/logger.py b/mobileclip/logger.py new file mode 100644 index 0000000..ff33562 --- /dev/null +++ b/mobileclip/logger.py @@ -0,0 +1,120 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import os +import sys +import time +import traceback + +text_colors = { + "logs": "\033[34m", # 033 is the escape code and 34 is the color code + "info": "\033[32m", + "warning": "\033[33m", + "debug": "\033[93m", + "error": "\033[31m", + "bold": "\033[1m", + "end_color": "\033[0m", + "light_red": "\033[36m", +} + + +def get_curr_time_stamp() -> str: + return time.strftime("%Y-%m-%d %H:%M:%S") + + +def error(message: str) -> None: + time_stamp = get_curr_time_stamp() + error_str = text_colors["error"] + text_colors["bold"] + "ERROR " + text_colors["end_color"] + + # exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss). + # For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE). + # This allows us to handle specific exceptions in the tests. + + # print("{} - {} - {}".format(time_stamp, error_str, message), flush=True) + # print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True) + # exit(-1) + + if sys.exc_info()[0] is None: + traceback.print_stack() + else: + traceback.print_exc() + sys.exit(f"{time_stamp} - {error_str} - {message}. Exiting!!!") + + +def color_text(in_text: str) -> str: + return text_colors["light_red"] + in_text + text_colors["end_color"] + + +def log(message: str, end="\n") -> None: + time_stamp = get_curr_time_stamp() + log_str = text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"] + print(f"{time_stamp} - {log_str} - {message}", end=end) + + +def warning(message: str | Warning) -> None: + if isinstance(message, Warning): + message = f"{type(message).__name__}({','.join(map(repr, message.args))}" + + time_stamp = get_curr_time_stamp() + warn_str = text_colors["warning"] + text_colors["bold"] + "WARNING" + text_colors["end_color"] + print(f"{time_stamp} - {warn_str} - {message}") + + +def ignore_exception_with_warning(message: str) -> None: + """After catching a tolerable exception E1 (e.g. when Model.forward() fails during profiling with try-catch, it'll + be helpful to log the exception for future investigation. But printing the error stack trace, as is, could be + confusing when an uncaught (non-tolerable) exception "E2" raises down the road. Then, the log will contain two + stack traces for E1, E2. When looking for errors in logs, users should look for E2, but they may find E1. + + This function appends "(WARNING)" at the end of all lines of the E1 traceback, so that the user can distinguish E1 + from uncaught exception E2. + + Args: + message: Extra explanation and context for debugging. (Note: the exception obj will be automatically fetched + from python. No need to pass it as an argument or as message) + """ + warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)")) + + +def info(message: str, print_line: bool | None = False) -> None: + time_stamp = get_curr_time_stamp() + info_str = text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"] + print(f"{time_stamp} - {info_str} - {message}") + if print_line: + double_dash_line(dashes=150) + + +def debug(message: str) -> None: + time_stamp = get_curr_time_stamp() + log_str = text_colors["debug"] + text_colors["bold"] + "DEBUG " + text_colors["end_color"] + print(f"{time_stamp} - {log_str} - {message}") + + +def double_dash_line(dashes: int | None = 75) -> None: + print(text_colors["error"] + "=" * dashes + text_colors["end_color"]) + + +def singe_dash_line(dashes: int | None = 67) -> None: + print("-" * dashes) + + +def print_header(header: str) -> None: + double_dash_line() + print(text_colors["info"] + text_colors["bold"] + "=" * 50 + str(header) + text_colors["end_color"]) + double_dash_line() + + +def print_header_minor(header: str) -> None: + print(text_colors["warning"] + text_colors["bold"] + "=" * 25 + str(header) + text_colors["end_color"]) + + +def disable_printing(): + sys.stdout = open(os.devnull, "w") + + +def enable_printing(): + sys.stdout = sys.__stdout__ diff --git a/mobileclip/models/__init__.py b/mobileclip/models/__init__.py new file mode 100644 index 0000000..47d3f15 --- /dev/null +++ b/mobileclip/models/__init__.py @@ -0,0 +1,12 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# +from .mci import ( + mci0, + mci1, + mci2, +) +from .vit import vit_b16 diff --git a/mobileclip/models/mci.py b/mobileclip/models/mci.py new file mode 100644 index 0000000..516789e --- /dev/null +++ b/mobileclip/models/mci.py @@ -0,0 +1,888 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import copy +from functools import partial + +import torch +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models import register_model +from timm.models.layers import DropPath, trunc_normal_ + +from mobileclip.modules.common.mobileone import MobileOneBlock +from mobileclip.modules.image.replknet import ReparamLargeKernelConv + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 256, 256), + "pool_size": None, + "crop_pct": 0.95, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "fastvit_t": _cfg(crop_pct=0.9), + "fastvit_s": _cfg(crop_pct=0.9), + "fastvit_m": _cfg(crop_pct=0.95), +} + + +def convolutional_stem(in_channels: int, out_channels: int, inference_mode: bool = False) -> nn.Sequential: + """Build convolutional stem with MobileOne blocks. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + + Returns: + nn.Sequential object with stem elements. + """ + return nn.Sequential( + MobileOneBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + groups=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + MobileOneBlock( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ), + ) + + +class MHSA(nn.Module): + """Multi-headed Self Attention module. + + Source modified from: + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + """ + + def __init__( + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Build MHSA module that can handle 3D or 4D input tensors. + + Args: + dim: Number of embedding dimensions. + head_dim: Number of hidden dimensions per head. Default: ``32`` + qkv_bias: Use bias or not. Default: ``False`` + attn_drop: Dropout rate for attention tensor. + proj_drop: Dropout rate for projection tensor. + """ + super().__init__() + assert dim % head_dim == 0, "dim should be divisible by head_dim" + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = x.shape + B, C, H, W = shape + N = H * W + if len(shape) == 4: + x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # trick here to make q@k.t more stable + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + if len(shape) == 4: + x = x.transpose(-2, -1).reshape(B, C, H, W) + + return x + + +class PatchEmbed(nn.Module): + """Convolutional patch embedding layer.""" + + def __init__( + self, + patch_size: int, + stride: int, + in_channels: int, + embed_dim: int, + inference_mode: bool = False, + use_se: bool = False, + ) -> None: + """Build patch embedding layer. + + Args: + patch_size: Patch size for embedding computation. + stride: Stride for convolutional embedding layer. + in_channels: Number of channels of input tensor. + embed_dim: Number of embedding dimensions. + inference_mode: Flag to instantiate model in inference mode. Default: ``False`` + use_se: If ``True`` SE block will be used. + """ + super().__init__() + block = list() + block.append( + ReparamLargeKernelConv( + in_channels=in_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=stride, + groups=in_channels, + small_kernel=3, + inference_mode=inference_mode, + use_se=use_se, + ) + ) + block.append( + MobileOneBlock( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, + inference_mode=inference_mode, + use_se=False, + num_conv_branches=1, + ) + ) + self.proj = nn.Sequential(*block) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x + + +class RepMixer(nn.Module): + """Reparameterizable token mixer. + + For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural + Reparameterization `_ + """ + + def __init__( + self, + dim, + kernel_size=3, + use_layer_scale=True, + layer_scale_init_value=1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Module. + + Args: + dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. + kernel_size: Kernel size for spatial mixing. Default: 3 + use_layer_scale: If True, learnable layer scale is used. Default: ``True`` + layer_scale_init_value: Initial value for layer scale. Default: 1e-5 + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + """ + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.inference_mode = inference_mode + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + groups=self.dim, + bias=True, + ) + else: + self.norm = MobileOneBlock( + dim, + dim, + kernel_size, + padding=kernel_size // 2, + groups=dim, + use_act=False, + use_scale_branch=False, + num_conv_branches=0, + ) + self.mixer = MobileOneBlock( + dim, + dim, + kernel_size, + padding=kernel_size // 2, + groups=dim, + use_act=False, + ) + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "reparam_conv"): + x = self.reparam_conv(x) + return x + else: + if self.use_layer_scale: + x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) + else: + x = x + self.mixer(x) - self.norm(x) + return x + + def reparameterize(self) -> None: + """Reparameterize mixer and norm into a single convolutional layer for efficient inference.""" + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if self.use_layer_scale: + w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias) + else: + w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=self.kernel_size, + stride=1, + padding=self.kernel_size // 2, + groups=self.dim, + bias=True, + ) + self.reparam_conv.weight.data = w + self.reparam_conv.bias.data = b + + for para in self.parameters(): + para.detach_() + self.__delattr__("mixer") + self.__delattr__("norm") + if self.use_layer_scale: + self.__delattr__("layer_scale") + + +class ConvFFN(nn.Module): + """Convolutional FFN Module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int | None = None, + out_channels: int | None = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_channels: Number of input channels. + hidden_channels: Number of channels after expansion. Default: None + out_channels: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. + """ + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.conv = nn.Sequential() + self.conv.add_module( + "conv", + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=7, + padding=3, + groups=in_channels, + bias=False, + ), + ) + self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) + self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepCPE(nn.Module): + """Implementation of conditional positional encoding. + + For more details refer to paper: `Conditional Positional Encodings for Vision Transformers + `_ + + In our implementation, we can reparameterize this module to eliminate a skip connection. + """ + + def __init__( + self, + in_channels: int, + embed_dim: int = 768, + spatial_shape: int | tuple[int, int] = (7, 7), + inference_mode=False, + ) -> None: + """Build reparameterizable conditional positional encoding. + + Args: + in_channels: Number of input channels. + embed_dim: Number of embedding dimensions. Default: 768 + spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super().__init__() + if isinstance(spatial_shape, int): + spatial_shape = tuple([spatial_shape] * 2) + assert isinstance(spatial_shape, tuple), ( + f'"spatial_shape" must by a sequence or int, get {type(spatial_shape)} instead.' + ) + assert len(spatial_shape) == 2, f'Length of "spatial_shape" should be 2, got {len(spatial_shape)} instead.' + + self.spatial_shape = spatial_shape + self.embed_dim = embed_dim + self.in_channels = in_channels + self.groups = embed_dim + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + padding=int(self.spatial_shape[0] // 2), + groups=self.embed_dim, + bias=True, + ) + else: + self.pe = nn.Conv2d( + in_channels, + embed_dim, + spatial_shape, + 1, + int(spatial_shape[0] // 2), + bias=True, + groups=embed_dim, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "reparam_conv"): + x = self.reparam_conv(x) + return x + else: + x = self.pe(x) + x + return x + + def reparameterize(self) -> None: + # Build equivalent Id tensor + input_dim = self.in_channels // self.groups + kernel_value = torch.zeros( + ( + self.in_channels, + input_dim, + self.spatial_shape[0], + self.spatial_shape[1], + ), + dtype=self.pe.weight.dtype, + device=self.pe.weight.device, + ) + for i in range(self.in_channels): + kernel_value[ + i, + i % input_dim, + self.spatial_shape[0] // 2, + self.spatial_shape[1] // 2, + ] = 1 + id_tensor = kernel_value + + # Reparameterize Id tensor and conv + w_final = id_tensor + self.pe.weight + b_final = self.pe.bias + + # Introduce reparam conv + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embed_dim, + kernel_size=self.spatial_shape, + stride=1, + padding=int(self.spatial_shape[0] // 2), + groups=self.embed_dim, + bias=True, + ) + self.reparam_conv.weight.data = w_final + self.reparam_conv.bias.data = b_final + + for para in self.parameters(): + para.detach_() + self.__delattr__("pe") + + +class RepMixerBlock(nn.Module): + """Implementation of Metaformer block with RepMixer as token mixer. + + For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision + `_ + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Block. + + Args: + dim: Number of embedding dimensions. + kernel_size: Kernel size for repmixer. Default: 3 + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super().__init__() + + self.token_mixer = RepMixer( + dim, + kernel_size=kernel_size, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + + assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}" + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvFFN( + in_channels=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop Path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + + def forward(self, x): + if self.use_layer_scale: + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale * self.convffn(x)) + else: + x = self.token_mixer(x) + x = x + self.drop_path(self.convffn(x)) + return x + + +class AttentionBlock(nn.Module): + """Implementation of metaformer block with MHSA as token mixer. + + For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision + `_ + """ + + def __init__( + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + ): + """Build Attention Block. + + Args: + dim: Number of embedding dimensions. + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + """ + super().__init__() + + self.norm = norm_layer(dim) + self.token_mixer = MHSA(dim=dim) + + assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}" + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvFFN( + in_channels=dim, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) + else: + x = x + self.drop_path(self.token_mixer(self.norm(x))) + x = x + self.drop_path(self.convffn(x)) + return x + + +def basic_blocks( + dim: int, + block_index: int, + num_blocks: list[int], + token_mixer_type: str, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode=False, +) -> nn.Sequential: + """Build FastViT blocks within a stage. + + Args: + dim: Number of embedding dimensions. + block_index: block index. + num_blocks: List containing number of blocks per stage. + token_mixer_type: Token mixer type. + kernel_size: Kernel size for repmixer. + mlp_ratio: MLP expansion ratio. + act_layer: Activation layer. + norm_layer: Normalization layer. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + use_layer_scale: Flag to turn on layer scale regularization. + layer_scale_init_value: Layer scale value at initialization. + inference_mode: Flag to instantiate block in inference mode. + + Returns: + nn.Sequential object of all the blocks within the stage. + """ + blocks = [] + for block_idx in range(num_blocks[block_index]): + block_dpr = drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1) + if token_mixer_type == "repmixer": + blocks.append( + RepMixerBlock( + dim, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + ) + elif token_mixer_type == "attention": + blocks.append( + AttentionBlock( + dim, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + ) + ) + else: + raise ValueError(f"Token mixer type: {token_mixer_type} not supported") + blocks = nn.Sequential(*blocks) + + return blocks + + +class FastViT(nn.Module): + """This class implements `FastViT architecture `_.""" + + def __init__( + self, + layers, + token_mixers: tuple[str, ...], + embed_dims=None, + mlp_ratios=None, + downsamples=None, + se_downsamples=None, + repmixer_kernel_size=3, + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + num_classes=1000, + pos_embs=None, + down_patch_size=7, + down_stride=2, + drop_rate=0.0, + drop_path_rate=0.0, + use_layer_scale=True, + layer_scale_init_value=1e-5, + init_cfg=None, + pretrained=None, + cls_ratio=2.0, + inference_mode=False, + **kwargs, + ) -> None: + super().__init__() + + self.num_classes = num_classes + if pos_embs is None: + pos_embs = [None] * len(layers) + + if se_downsamples is None: + se_downsamples = [False] * len(layers) + + # Convolutional stem + self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode) + + # Build the main stages of the network architecture + network = [] + for i in range(len(layers)): + # Add position embeddings if requested + if pos_embs[i] is not None: + network.append(pos_embs[i](embed_dims[i], embed_dims[i], inference_mode=inference_mode)) + stage = basic_blocks( + embed_dims[i], + i, + layers, + token_mixer_type=token_mixers[i], + kernel_size=repmixer_kernel_size, + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + norm_layer=norm_layer, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + network.append(stage) + if i >= len(layers) - 1: + break + + # Patch merging/downsampling between stages. + if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + network.append( + PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_channels=embed_dims[i], + embed_dim=embed_dims[i + 1], + inference_mode=inference_mode, + use_se=se_downsamples[i + 1], + ) + ) + self.network = nn.ModuleList(network) + + # Classifier head + self.conv_exp = MobileOneBlock( + in_channels=embed_dims[-1], + out_channels=int(embed_dims[-1] * cls_ratio), + kernel_size=3, + stride=1, + padding=1, + groups=embed_dims[-1], + inference_mode=inference_mode, + use_se=True, + num_conv_branches=1, + ) + self.head = nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity() + self.apply(self.cls_init_weights) + self.init_cfg = copy.deepcopy(init_cfg) + + def cls_init_weights(self, m: nn.Module) -> None: + """Init. + + for classification. + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + return x + + def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: + for idx, block in enumerate(self.network): + x = block(x) + # output only the features of last layer for image classification + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input embedding + x = self.forward_embeddings(x) + # through backbone + x = self.forward_tokens(x) + # for image classification + x = self.conv_exp(x) + cls_out = self.head(x) + return cls_out + + +@register_model +def mci0(pretrained=False, **kwargs): + """Instantiate MCi0 model variant.""" + layers = [2, 6, 10, 2] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + se_downsamples = [False, False, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + se_downsamples=se_downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def mci1(pretrained=False, **kwargs): + """Instantiate MCi1 model variant.""" + layers = [4, 12, 20, 4] + embed_dims = [64, 128, 256, 512] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + se_downsamples = [False, False, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + se_downsamples=se_downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_s"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model + + +@register_model +def mci2(pretrained=False, **kwargs): + """Instantiate MCi2 model variant.""" + layers = [4, 12, 24, 4] + embed_dims = [80, 160, 320, 640] + mlp_ratios = [3, 3, 3, 3] + downsamples = [True, True, True, True] + se_downsamples = [False, False, True, True] + pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] + token_mixers = ("repmixer", "repmixer", "repmixer", "attention") + model = FastViT( + layers, + token_mixers=token_mixers, + embed_dims=embed_dims, + pos_embs=pos_embs, + mlp_ratios=mlp_ratios, + downsamples=downsamples, + se_downsamples=se_downsamples, + **kwargs, + ) + model.default_cfg = default_cfgs["fastvit_m"] + if pretrained: + raise ValueError("Functionality not implemented.") + return model diff --git a/mobileclip/models/vit.py b/mobileclip/models/vit.py new file mode 100644 index 0000000..4646d79 --- /dev/null +++ b/mobileclip/models/vit.py @@ -0,0 +1,389 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# +""" +Implementation of the following modules is borrowed from ml-cvnets repo: +https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py. + +Please see ACKNOWLEDGMENTS for license details. +""" + +from __future__ import annotations + +import numpy as np +import torch +from timm.models import register_model +from torch import Tensor, nn + +from mobileclip import logger +from mobileclip.modules.common.transformer import ( + PositionalEmbedding, + TransformerEncoder, + get_normalization_layer, +) +from mobileclip.modules.image.image_projection import SimpleImageProjectionHead + + +class ConvNormAct(nn.Module): + """Applies an N-dimensional convolution over an input. + + Args: + cfg: Model configuration. + in_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{in}, X_{1}, ..., X_{N})`. + out_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`. + kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``. + stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1. + dilation: Dilation rate for convolution. An integer, or tuple of length ``N``. Default: ``1``. + padding: Padding for convolution. An integer, or tuple of length ``N``. If not specified, padding is + automatically computed based on kernel size and dilation range. Default : ``None`` (equivalent to ``[ + int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``). + groups: Number of groups in convolution. Default: ``1``. + bias: Use bias. Default: ``False``. + padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular'). Default: ``zeros``. + use_norm: Use normalization layer after convolution. Default: ``True``. + use_act: Use activation layer after convolution (or convolution and normalization). Default: ``True``. + norm_layer: If not None, the provided normalization layer object will be used. Otherwise, a normalization object + will be created based on config ``model.normalization.*`` opts. + act_layer: If not None, the provided activation function will be used. Otherwise, an activation function will be + created based on config ``model.activation.*`` opts. + + Notes: + - Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`. + - Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`. + - For depth-wise convolution, `groups=C_{in}=C_{out}`. + """ + + def __init__( + self, + cfg: dict, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, ...], + stride: int | tuple[int, ...] = 1, + dilation: int | tuple[int, ...] = 1, + padding: int | tuple[int, ...] | None = None, + groups: int = 1, + bias: bool = False, + padding_mode: str = "zeros", + use_norm: bool = True, + use_act: bool = True, + norm_layer: nn.Module | None = None, + act_layer: nn.Module | None = None, + *args, + **kwargs, + ) -> None: + super().__init__() + self.ndim = 2 + + if norm_layer is None and use_norm: + norm_type = cfg.get("normalization", "batch_norm") + if norm_type == "batch_norm": + norm_layer = nn.BatchNorm2d( + num_features=out_channels, + momentum=cfg.get("momentum", 0.1), + ) + else: + norm_layer = get_normalization_layer(num_features=out_channels, norm_type=norm_type) + elif norm_layer is not None and use_norm: + logger.error(f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided.") + + if act_layer is None and use_act: + act_layer = nn.GELU() # Default to GELU + elif act_layer is not None and use_act: + logger.error(f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided.") + + if use_norm and any(param[0] == "bias" for param in norm_layer.named_parameters()) and bias: + assert not bias, "Do not use bias when using normalization layers with bias." + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * self.ndim + + if isinstance(stride, int): + stride = (stride,) * self.ndim + + if isinstance(dilation, int): + dilation = (dilation,) * self.ndim + + assert isinstance(kernel_size, tuple) + assert isinstance(stride, tuple) + assert isinstance(dilation, tuple) + + if padding is None: + padding = (int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim)) + + if in_channels % groups != 0: + logger.error(f"Input channels are not divisible by groups. {in_channels}%{groups} != 0 ") + if out_channels % groups != 0: + logger.error(f"Output channels are not divisible by groups. {out_channels}%{groups} != 0 ") + + block = nn.Sequential() + + conv_layer = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, # type: ignore + stride=stride, # type: ignore + padding=padding, + dilation=dilation, # type: ignore + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + block.add_module(name="conv", module=conv_layer) + + self.norm_name = None + if use_norm: + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.__class__.__name__ + + self.act_name = None + if use_act: + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.__class__.__name__ + + self.block = block + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = conv_layer.kernel_size + self.bias = bias + self.dilation = dilation + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class VisionTransformer(nn.Module): + """This class defines the `Vision Transformer architecture `_. Our model + implementation is inspired from `Early Convolutions Help Transformers See + Better `_. + + .. note:: + Our implementation is different from the original implementation in two ways: + 1. Kernel size is odd. + 2. Our positional encoding implementation allows us to use ViT with any multiple input scales + 3. We do not use StochasticDepth + 4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper `_ + """ + + def __init__(self, cfg, *args, **kwargs) -> None: + super().__init__() + image_channels = 3 + num_classes = cfg.get("n_classes", 1000) + + self.projection_dim = None + if "projection_dim" in kwargs: + self.projection_dim = kwargs.get("projection_dim") + + kernel_sizes_conv_stem = [4, 2, 2] + strides_conv_stem = [4, 2, 2] + + # Typically, in the ImageNet dataset, we use 224x224 as a resolution. + # For out ViT implementation, patch size is 16 (16 = 4 * 2 * 2) + # Therefore, total number of embeddings along width and height are (224 / 16)^2 + num_embeddings = (224 // 16) ** 2 + + embed_dim = cfg["embed_dim"] + ffn_dim = cfg["embed_dim"] * 4 + pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0) + n_transformer_layers = cfg["n_transformer_layers"] + num_heads = cfg["n_attn_heads"] + attn_dropout = cfg.get("attn_dropout", 0.0) + dropout = cfg.get("dropout", 0.0) + ffn_dropout = cfg.get("ffn_dropout", 0.0) + norm_layer = cfg.get("norm_layer", "layer_norm") + + conv_stem_proj_dim = max(32, embed_dim // 4) + patch_emb = [ + ConvNormAct( + cfg=cfg, + in_channels=image_channels, + out_channels=conv_stem_proj_dim, + kernel_size=kernel_sizes_conv_stem[0], + stride=strides_conv_stem[0], + bias=False, + use_norm=True, + use_act=True, + ), + ConvNormAct( + cfg=cfg, + in_channels=conv_stem_proj_dim, + out_channels=conv_stem_proj_dim, + kernel_size=kernel_sizes_conv_stem[1], + stride=strides_conv_stem[1], + bias=False, + use_norm=True, + use_act=True, + ), + ConvNormAct( + cfg=cfg, + in_channels=conv_stem_proj_dim, + out_channels=embed_dim, + kernel_size=kernel_sizes_conv_stem[2], + stride=strides_conv_stem[2], + bias=True, + use_norm=False, + use_act=False, + ), + ] + + self.patch_emb = nn.Sequential(*patch_emb) + + use_cls_token = not cfg.get("no_cls_token", False) + stochastic_dropout = cfg.get("stochastic_dropout", 0.0) + per_layer_stochastic_drop_rate = [round(x, 3) for x in np.linspace(0, stochastic_dropout, n_transformer_layers)] + transformer_blocks = [ + TransformerEncoder( + embed_dim=embed_dim, + ffn_latent_dim=ffn_dim, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + ffn_dropout=ffn_dropout, + transformer_norm_layer=norm_layer, + stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx], + ) + for layer_idx in range(n_transformer_layers) + ] + + self.post_transformer_norm = get_normalization_layer(num_features=embed_dim, norm_type=norm_layer) + + self.transformer = nn.Sequential(*transformer_blocks) + + if self.projection_dim is None: + self.classifier = nn.Linear(embed_dim, num_classes) + else: + self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim) + + if use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim))) + torch.nn.init.trunc_normal_(self.cls_token, std=0.02) + else: + self.cls_token = None + + self.pos_embed = PositionalEmbedding( + num_embeddings=num_embeddings, + embedding_dim=embed_dim, + padding_idx=None, + interpolation_mode="bilinear", + ) + self.emb_dropout = nn.Dropout(p=pos_emb_drop_p) + + def extract_patch_embeddings(self, x: Tensor) -> tuple[Tensor, tuple[int, int]]: + # input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images) + batch_size = x.shape[0] + + # [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width] + patch_emb = self.patch_emb(x) + n_h, n_w = patch_emb.shape[-2:] + + # [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches] + patch_emb = patch_emb.flatten(2) + # [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim] + patch_emb = patch_emb.transpose(1, 2).contiguous() + + n_patches = patch_emb.shape[1] + # we resize the positional encodings dynamically. + pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype) + + # add positional encodings + patch_emb = pos_emb + patch_emb + + # add classification token + if self.cls_token is not None: + # [1, 1, emb_dim] --> [Batch, 1, emb_dim] + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + # Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim] + patch_emb = torch.cat((cls_tokens, patch_emb), dim=1) + + # dropout + patch_emb = self.emb_dropout(patch_emb) + return patch_emb, (n_h, n_w) + + def _features_from_transformer(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, tuple[int, int]]: + # this function extract patch embeddings and then apply transformer module to learn + # inter-patch representations + + # [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens, + # and embed_dim is feature dim + x, (n_h, n_w) = self.extract_patch_embeddings(x) + + for layer in self.transformer: + x = layer(x) + x = self.post_transformer_norm(x) + + return x, (n_h, n_w) + + def extract_features(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor | None]: + # The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token + # and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16. + return_image_embeddings = kwargs.get("return_image_embeddings", False) + + # [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim] + # here, B is batch size, C is input channels + # H and W are input height and width + # N is the number of pixels (or tokens) after processing input with conv stem and reshaping + # We add +1 for cls token (if applicable) + # embed_dim --> embedding dimension + x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs) + + if self.cls_token is not None: + # [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim] + cls_embedding, image_embedding = torch.split(x, split_size_or_sections=[1, x.shape[1] - 1], dim=1) + cls_embedding = cls_embedding.squeeze(1) + else: + # [B, N, embed_dim] -> [B, embed_dim] + cls_embedding = torch.mean(x, dim=1) + # [B, N, embed_dim] + image_embedding = x + + if return_image_embeddings: + # reshape image embedding to 4-D tensor + # [B, N, C] --> [B, C, N] + image_embedding = image_embedding.transpose(1, 2).contiguous() + image_embedding = image_embedding.reshape(image_embedding.shape[0], -1, n_h, n_w) + + return cls_embedding, image_embedding + else: + return cls_embedding, None + + def forward_classifier(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor]: + cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs) + # classify based on CLS token + cls_embedding = self.classifier(cls_embedding) + return cls_embedding, image_embedding + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor | dict[str, Tensor]: + # In ViT model, we can return either classifier embeddings (logits) or image embeddings or both. + # To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True. + if kwargs.get("return_image_embeddings", False): + out_dict = dict() + prediction, image_embedding = self.forward_classifier(x, *args, **kwargs) + out_dict.update({"logits": prediction}) + if image_embedding is not None: + out_dict.update({"image_embeddings": image_embedding}) + return out_dict + else: + prediction, _ = self.forward_classifier(x, *args, **kwargs) + return prediction + + +@register_model +def vit_b16(pretrained=False, **kwargs): + # Vision transformer config + cfg = { + "norm_layer": "layer_norm_fp32", + "act_layer": "gelu", + "embed_dim": 768, + "n_transformer_layers": 12, + "n_attn_heads": 12, + } + model = VisionTransformer(cfg=cfg, **kwargs) + if pretrained: + raise ValueError("Functionality not implemented.") + return model diff --git a/mobileclip/modules/__init__.py b/mobileclip/modules/__init__.py new file mode 100644 index 0000000..9a1fd1c --- /dev/null +++ b/mobileclip/modules/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# diff --git a/mobileclip/modules/common/__init__.py b/mobileclip/modules/common/__init__.py new file mode 100644 index 0000000..9a1fd1c --- /dev/null +++ b/mobileclip/modules/common/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# diff --git a/mobileclip/modules/common/mobileone.py b/mobileclip/modules/common/mobileone.py new file mode 100644 index 0000000..770f593 --- /dev/null +++ b/mobileclip/modules/common/mobileone.py @@ -0,0 +1,330 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["MobileOneBlock", "reparameterize_model"] + + +class SEBlock(nn.Module): + """Squeeze and Excite module. + + Pytorch implementation of `Squeeze-and-Excitation Networks` - https://arxiv.org/pdf/1709.01507.pdf + """ + + def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None: + """Construct a Squeeze and Excite Module. + + Args: + in_channels: Number of input channels. + rd_ratio: Input channel reduction ratio. + """ + super().__init__() + self.reduce = nn.Conv2d( + in_channels=in_channels, + out_channels=int(in_channels * rd_ratio), + kernel_size=1, + stride=1, + bias=True, + ) + self.expand = nn.Conv2d( + in_channels=int(in_channels * rd_ratio), + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True, + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + _b, c, h, w = inputs.size() + x = F.avg_pool2d(inputs, kernel_size=[h, w]) + x = self.reduce(x) + x = F.relu(x) + x = self.expand(x) + x = torch.sigmoid(x) + x = x.view(-1, c, 1, 1) + return inputs * x + + +class MobileOneBlock(nn.Module): + """MobileOne building block. + + This block has a multi-branched architecture at train-time and plain-CNN style architecture at inference time For + more details, please refer to our paper: `An Improved One millisecond Mobile Backbone` - + https://arxiv.org/pdf/2206.04040.pdf + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + activation: nn.Module = nn.GELU(), + ) -> None: + """Construct a MobileOneBlock module. + + Args: + in_channels: Number of channels in the input. + out_channels: Number of channels produced by the block. + kernel_size: Size of the convolution kernel. + stride: Stride size. + padding: Zero-padding size. + dilation: Kernel dilation factor. + groups: Group number. + inference_mode: If True, instantiates model in inference mode. + use_se: Whether to use SE-ReLU activations. + use_act: Whether to use activation. Default: ``True`` + use_scale_branch: Whether to use scale branch. Default: ``True`` + num_conv_branches: Number of linear conv branches. + """ + super().__init__() + self.inference_mode = inference_mode + self.groups = groups + self.stride = stride + self.padding = padding + self.dilation = dilation + self.kernel_size = kernel_size + self.in_channels = in_channels + self.out_channels = out_channels + self.num_conv_branches = num_conv_branches + + # Check if SE-ReLU is requested + if use_se: + self.se = SEBlock(out_channels) + else: + self.se = nn.Identity() + + if use_act: + self.activation = activation + else: + self.activation = nn.Identity() + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + ) + else: + # Re-parameterizable skip connection + self.rbr_skip = ( + nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None + ) + + # Re-parameterizable conv branches + if num_conv_branches > 0: + rbr_conv = list() + for _ in range(self.num_conv_branches): + rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding)) + self.rbr_conv = nn.ModuleList(rbr_conv) + else: + self.rbr_conv = None + + # Re-parameterizable scale branch + self.rbr_scale = None + if not isinstance(kernel_size, int): + kernel_size = kernel_size[0] + if (kernel_size > 1) and use_scale_branch: + self.rbr_scale = self._conv_bn(kernel_size=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + # Inference mode forward pass. + if self.inference_mode: + return self.activation(self.se(self.reparam_conv(x))) + + # Multi-branched train-time forward pass. + # Skip branch output + identity_out = 0 + if self.rbr_skip is not None: + identity_out = self.rbr_skip(x) + + # Scale branch output + scale_out = 0 + if self.rbr_scale is not None: + scale_out = self.rbr_scale(x) + + # Other branches + out = scale_out + identity_out + if self.rbr_conv is not None: + for ix in range(self.num_conv_branches): + out += self.rbr_conv[ix](x) + + return self.activation(self.se(out)) + + def reparameterize(self): + """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. + We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like + structure for inference. + """ + if self.inference_mode: + return + kernel, bias = self._get_kernel_bias() + self.reparam_conv = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + bias=True, + ) + self.reparam_conv.weight.data = kernel + self.reparam_conv.bias.data = bias + + # Delete un-used branches + for para in self.parameters(): + para.detach_() + self.__delattr__("rbr_conv") + self.__delattr__("rbr_scale") + if hasattr(self, "rbr_skip"): + self.__delattr__("rbr_skip") + + self.inference_mode = True + + def _get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. Reference: + https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83. + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.rbr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.rbr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + if self.rbr_conv is not None: + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch: nn.Sequential | nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with preceeding conv layer. Reference: + https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95. + + Args: + branch: Sequence of ops to be fused. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + + kernel_size = self.kernel_size + if isinstance(self.kernel_size, int): + kernel_size = (self.kernel_size, self.kernel_size) + + kernel_value = torch.zeros( + (self.in_channels, input_dim, kernel_size[0], kernel_size[1]), + dtype=branch.weight.dtype, + device=branch.weight.device, + ) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential: + """Helper method to construct conv-batchnorm layers. + + Args: + kernel_size: Size of the convolution kernel. + padding: Zero-padding size. + + Returns: + Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module( + "conv", + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + padding=padding, + groups=self.groups, + bias=False, + ), + ) + mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) + return mod_list + + +def reparameterize_model(model: torch.nn.Module) -> nn.Module: + """Method returns a model where a multi-branched structure used in training is re-parameterized into a single branch + for inference. + + Args: + model: MobileOne model in train mode. + + Returns: + MobileOne model in inference mode. + """ + # Avoid editing original graph + model = copy.deepcopy(model) + for module in model.modules(): + if hasattr(module, "reparameterize"): + module.reparameterize() + return model diff --git a/mobileclip/modules/common/transformer.py b/mobileclip/modules/common/transformer.py new file mode 100644 index 0000000..4e316d5 --- /dev/null +++ b/mobileclip/modules/common/transformer.py @@ -0,0 +1,410 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# +""" +Implementation of the following modules is borrowed from ml-cvnets repo: +https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py +https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py. + +Please see ACKNOWLEDGMENTS for license details. +""" + +from __future__ import annotations + +import torch +from torch import Size, Tensor, nn +from torch.nn import functional as F +from torchvision.ops import StochasticDepth + +from mobileclip import logger + + +class LayerNormFP32(nn.LayerNorm): + """Applies `Layer Normalization `_ over a input tensor with FP32 precision.""" + + def __init__( + self, + normalized_shape: int | list[int] | Size, + eps: float | None = 1e-5, + elementwise_affine: bool | None = True, + *args, + **kwargs, + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + *args, + **kwargs, + ) + + def forward(self, x: Tensor) -> Tensor: + # Convert input from dtype X to FP32 and perform normalization operation. + # This may help with underflow/overflow issues that we typically see with normalization layers + inp_dtype = x.dtype + return super().forward(x.to(torch.float32)).to(inp_dtype) + + +def get_normalization_layer(norm_type, num_features): + if norm_type == "layer_norm": + return nn.LayerNorm(num_features) + elif norm_type == "layer_norm_fp32": + return LayerNormFP32(num_features) + else: + raise NotImplementedError(f"Option: {norm_type} not supported.") + + +class PositionalEmbedding(nn.Module): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None = None, + is_learnable: bool | None = False, + interpolation_mode: str | None = "bilinear", + *args, + **kwargs, + ): + super().__init__() + # Add other pos embedding here and logic to choose between them + module = LearnablePositionalEmbedding + + self.pos_embed = module( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + interpolation_mode=interpolation_mode, + *args, + **kwargs, + ) + + def forward(self, seq_len: int, *args, **kwargs) -> Tensor: + return self.pos_embed(seq_len, *args, **kwargs) + + def __repr__(self): + return self.pos_embed.__repr__() + + +class LearnablePositionalEmbedding(nn.Module): + """Learnable Positional embedding.""" + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int | None = None, + interpolation_mode: str | None = "bilinear", + *args, + **kwargs, + ): + super().__init__() + self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim)) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.interpolation_mode = interpolation_mode + + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + with torch.no_grad(): + self.pos_embed[:, :, self.padding_idx, ...] = 0.0 + + def forward(self, seq_len: int, *args, **kwargs) -> Tensor: + # scale pos embedding + pos_embed = self.pos_embed + if self.padding_idx is not None: + with torch.no_grad(): + pos_embed[:, :, self.padding_idx, ...] = 0.0 + + if seq_len != self.num_embeddings: + pos_embed = F.interpolate( + pos_embed, + size=(seq_len, self.embedding_dim), + mode=self.interpolation_mode, + ) + + # Input is of the form [Batch, Seq_len, Embedding_dim] + return pos_embed.reshape(1, seq_len, self.embedding_dim) + + def __repr__(self): + return f"{self.__class__.__name__}(num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, padding_idx={self.padding_idx})" + + +class MultiHeadAttention(nn.Module): + """This layer applies a multi-head self- or cross-attention as described in `Attention is all you need + `_ paper. + + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})` + num_heads (int): Number of heads in multi-head attention + attn_dropout (Optional[float]): Attention dropout. Default: 0.0 + bias (Optional[bool]): Use bias or not. Default: ``True`` + + Notes: + - Input: + - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens, + and: math:`C_{in}` is input embedding dim + - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + attn_dropout: float | None = 0.0, + bias: bool | None = True, + output_dim: int | None = None, + *args, + **kwargs, + ) -> None: + if output_dim is None: + output_dim = embed_dim + super().__init__() + if embed_dim % num_heads != 0: + logger.error( + f"Embedding dim must be divisible by number of heads in {self.__class__.__name__}. Got: embed_dim={embed_dim} and num_heads={num_heads}" + ) + + self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias) + + self.attn_dropout = nn.Dropout(p=attn_dropout) + self.out_proj = nn.Linear(in_features=embed_dim, out_features=output_dim, bias=bias) + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.num_heads = num_heads + self.embed_dim = embed_dim + self.use_separate_proj_weight = embed_dim != output_dim + + def __repr__(self): + return f"{self.__class__.__name__}(head_dim={self.head_dim}, num_heads={self.num_heads}, attn_dropout={self.attn_dropout.p})" + + def _forward_impl( + self, + x_q: Tensor, + x_kv: Tensor | None = None, + key_padding_mask: Tensor | None = None, + attn_mask: Tensor | None = None, + ) -> Tensor: + # [N, S, C] + b_sz, S_len, _in_channels = x_q.shape + + if x_kv is None: + # self-attention + # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc + qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1) + # [N, S, 3, h, c] --> [N, h, 3, S, C] + qkv = qkv.transpose(1, 3).contiguous() + + # [N, h, 3, S, C] --> [N, h, S, C] x 3 + query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + else: + T_len = x_kv.shape[1] + + # cross-attention + # [N, S, C] + query = F.linear( + x_q, + weight=self.qkv_proj.weight[: self.embed_dim, ...], + bias=self.qkv_proj.bias[: self.embed_dim] if self.qkv_proj.bias is not None else None, + ) + # [N, S, C] --> [N, S, h, c] --> [N, h, S, c] + query = query.reshape(b_sz, S_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + # [N, T, C] --> [N, T, 2C] + kv = F.linear( + x_kv, + weight=self.qkv_proj.weight[self.embed_dim :, ...], + bias=self.qkv_proj.bias[self.embed_dim :] if self.qkv_proj.bias is not None else None, + ) + # [N, T, 2C] --> [N, T, 2, h, c] + kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim) + # [N, T, 2, h, c] --> [N, h, 2, T, c] + kv = kv.transpose(1, 3).contiguous() + key, value = kv[:, :, 0], kv[:, :, 1] + + query = query * self.scaling + + # [N h, T, c] --> [N, h, c, T] + key = key.transpose(-1, -2) + + # QK^T + # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T] + attn = torch.matmul(query, key) + + batch_size, _num_heads, num_src_tokens, num_tgt_tokens = attn.shape + if attn_mask is not None: + # attn_mask shape should be the same as attn + assert list(attn_mask.shape) == [ + batch_size, + num_src_tokens, + num_tgt_tokens, + ], ( + f"Shape of attention mask should be [{batch_size}, {num_src_tokens}, {num_tgt_tokens}]. Got: {attn_mask.shape}" + ) + # [N, S, T] --> [N, 1, S, T] + attn_mask = attn_mask.unsqueeze(1) + attn = attn + attn_mask + + if key_padding_mask is not None: + # Do not attend to padding positions + # key padding mask size is [N, T] + assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [ + batch_size, + num_tgt_tokens, + ], ( + f"Key_padding_mask should be 2-dimension with shape [{batch_size}, {num_tgt_tokens}]. Got: {key_padding_mask.shape}" + ) + attn = attn.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), # [N, T] --> [N, 1, 1, T] + float("-inf"), + ) + + attn_dtype = attn.dtype + attn_as_float = self.softmax(attn.float()) + attn = attn_as_float.to(attn_dtype) + attn = self.attn_dropout(attn) + + # weighted sum + # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c] + out = torch.matmul(attn, value) + + # [N, h, S, c] --> [N, S, h, c] --> [N, S, C] + out = out.transpose(1, 2).reshape(b_sz, S_len, -1) + out = self.out_proj(out) + + return out + + def forward( + self, + x_q: Tensor, + x_kv: Tensor | None = None, + key_padding_mask: Tensor | None = None, + attn_mask: Tensor | None = None, + *args, + **kwargs, + ) -> Tensor: + # [Batch , Sequence, Hidden_dim] + return self._forward_impl( + x_q=x_q, + x_kv=x_kv, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + + +class TransformerEncoder(nn.Module): + """This class defines the pre-norm `Transformer encoder `_. + + Args: + embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`. + ffn_latent_dim: Inner dimension of the FFN. + num_heads: Number of heads in multi-head attention. Default: 8. + attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0 + dropout: Dropout rate. Default: 0.0. + ffn_dropout: Dropout between FFN layers. Default: 0.0. + transformer_norm_layer: Normalization layer. Default: layer_norm. + stochastic_dropout: Stochastic dropout setting. Default: 0.0. + + Notes: + - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, + and: math:`C_{in}` is input embedding dim + - Output: same shape as the input + """ + + def __init__( + self, + embed_dim: int, + ffn_latent_dim: int, + num_heads: int | None = 8, + attn_dropout: float | None = 0.0, + dropout: float | None = 0.0, + ffn_dropout: float | None = 0.0, + transformer_norm_layer: str | None = "layer_norm", + stochastic_dropout: float | None = 0.0, + *args, + **kwargs, + ) -> None: + super().__init__() + + # Build attention layer + attn_unit = MultiHeadAttention( + embed_dim, + num_heads, + attn_dropout=attn_dropout, + bias=True, + ) + + self.pre_norm_mha = nn.Sequential( + get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim), + attn_unit, + nn.Dropout(p=dropout), + ) + + act_name = nn.GELU() + self.pre_norm_ffn = nn.Sequential( + get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim), + nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True), + act_name, + nn.Dropout(p=ffn_dropout), + nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), + nn.Dropout(p=dropout), + ) + + self.drop_path = nn.Identity() + if stochastic_dropout > 0.0: + if dropout > 0.0: + logger.error( + "Stochastic dropout and dropout are mutually exclusive. " + "Use either of them, but not both." + f"Got: {stochastic_dropout} and {dropout}" + ) + self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row") + + self.embed_dim = embed_dim + self.ffn_dim = ffn_latent_dim + self.ffn_dropout = ffn_dropout + self.stochastic_dropout = stochastic_dropout + self.std_dropout = dropout + self.attn_fn_name = attn_unit.__class__.__name__ + self.act_fn_name = act_name.__class__.__name__ + self.norm_type = transformer_norm_layer + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(embed_dim={self.embed_dim}, ffn_dim={self.ffn_dim}, dropout={self.std_dropout}, ffn_dropout={self.ffn_dropout}, stochastic_dropout={self.stochastic_dropout}, attn_fn={self.attn_fn_name}, act_fn={self.act_fn_name}, norm_fn={self.norm_type})" + + def forward( + self, + x: Tensor, + x_prev: Tensor | None = None, + key_padding_mask: Tensor | None = None, + attn_mask: Tensor | None = None, + *args, + **kwargs, + ) -> Tensor: + # Multi-head attention + res = x + x = self.pre_norm_mha[0](x) # norm + x = self.pre_norm_mha[1]( + x_q=x, + x_kv=x_prev, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + *args, + **kwargs, + ) # mha + + x = self.drop_path(self.pre_norm_mha[2](x)) # applying stochastic depth + x = x + res + + # Feed forward network + x = x + self.drop_path(self.pre_norm_ffn(x)) + return x diff --git a/mobileclip/modules/image/__init__.py b/mobileclip/modules/image/__init__.py new file mode 100644 index 0000000..9a1fd1c --- /dev/null +++ b/mobileclip/modules/image/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# diff --git a/mobileclip/modules/image/image_projection.py b/mobileclip/modules/image/image_projection.py new file mode 100644 index 0000000..dbf0f1a --- /dev/null +++ b/mobileclip/modules/image/image_projection.py @@ -0,0 +1,97 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch import Tensor + +from mobileclip import logger + + +class GlobalPool(nn.Module): + """This layers applies global pooling over a 4D or 5D input tensor. + + Args: + pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean` + keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False` + + Notes: + - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)` + """ + + pool_types = ["mean", "rms", "abs"] + + def __init__(self, pool_type: str | None = "mean", keep_dim: bool | None = False, *args, **kwargs) -> None: + super().__init__() + if pool_type not in self.pool_types: + logger.error(f"Supported pool types are: {self.pool_types}. Got {pool_type}") + self.pool_type = pool_type + self.keep_dim = keep_dim + + def _global_pool(self, x: Tensor, dims: list): + if self.pool_type == "rms": # root mean square + x = x**2 + x = torch.mean(x, dim=dims, keepdim=self.keep_dim) + x = x**-0.5 + elif self.pool_type == "abs": # absolute + x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim) + else: + # default is mean + # same as AdaptiveAvgPool + x = torch.mean(x, dim=dims, keepdim=self.keep_dim) + return x + + def forward(self, x: Tensor) -> Tensor: + if x.dim() == 4: + dims = [-2, -1] + elif x.dim() == 5: + dims = [-3, -2, -1] + else: + raise NotImplementedError("Currently 2D and 3D global pooling supported") + return self._global_pool(x, dims=dims) + + +class GlobalPool2D(nn.Module): + """This class implements global pooling with linear projection.""" + + def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None: + super().__init__() + scale = in_dim**-0.5 + self.pool = GlobalPool(pool_type="mean", keep_dim=False) + self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) + self.in_dim = in_dim + self.out_dim = out_dim + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + # x is of shape [batch, in_dim] + assert x.dim() == 4, f"Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {x.shape}" + + # [batch, in_dim, in_height, in_width] --> [batch, in_dim] + x = self.pool(x) + # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] + x = x @ self.proj + return x + + +class SimpleImageProjectionHead(nn.Module): + """This class implements linear projection head.""" + + def __init__(self, in_dim: int, out_dim: int) -> None: + super().__init__() + scale = in_dim**-0.5 + self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) + self.in_dim = in_dim + self.out_dim = out_dim + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + # x is of shape [batch, in_dim] + assert x.dim() == 2, f"Input should be 2-dimensional (Batch x in_dim). Got: {x.shape}" + + # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] + x = x @ self.proj + return x diff --git a/mobileclip/modules/image/replknet.py b/mobileclip/modules/image/replknet.py new file mode 100644 index 0000000..8b148e2 --- /dev/null +++ b/mobileclip/modules/image/replknet.py @@ -0,0 +1,177 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For acknowledgment see accompanying ACKNOWLEDGMENTS file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# + +import torch +import torch.nn as nn +from timm.models.layers import SqueezeExcite + +__all__ = ["ReparamLargeKernelConv"] + + +class ReparamLargeKernelConv(nn.Module): + """Building Block of RepLKNet. + + This class defines overparameterized large kernel conv block introduced in `RepLKNet + `_ + + Reference: https://github.com/DingXiaoH/RepLKNet-pytorch + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + groups: int, + small_kernel: int, + inference_mode: bool = False, + use_se: bool = False, + activation: nn.Module = nn.GELU(), + ) -> None: + """Construct a ReparamLargeKernelConv module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Kernel size of the large kernel conv branch. + stride: Stride size. Default: 1 + groups: Group number. Default: 1 + small_kernel: Kernel size of small kernel conv branch. + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + activation: Activation module. Default: ``nn.GELU`` + """ + super().__init__() + + self.stride = stride + self.groups = groups + self.in_channels = in_channels + self.out_channels = out_channels + self.activation = activation + + self.kernel_size = kernel_size + self.small_kernel = small_kernel + self.padding = kernel_size // 2 + + # Check if SE is requested + if use_se: + self.se = SqueezeExcite(out_channels, rd_ratio=0.25) + else: + self.se = nn.Identity() + + if inference_mode: + self.lkb_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + dilation=1, + groups=groups, + bias=True, + ) + else: + self.lkb_origin = self._conv_bn(kernel_size=kernel_size, padding=self.padding) + if small_kernel is not None: + assert small_kernel <= kernel_size, ( + "The kernel size for re-param cannot be larger than the large kernel!" + ) + self.small_conv = self._conv_bn(kernel_size=small_kernel, padding=small_kernel // 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply forward pass.""" + if hasattr(self, "lkb_reparam"): + out = self.lkb_reparam(x) + else: + out = self.lkb_origin(x) + if hasattr(self, "small_conv"): + out += self.small_conv(x) + + return self.activation(self.se(out)) + + def get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]: + """Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepLKNet-pytorch. + + Returns: + Tuple of (kernel, bias) after fusing branches. + """ + eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, "small_conv"): + small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4) + return eq_k, eq_b + + def reparameterize(self) -> None: + """Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. + We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like + structure for inference. + """ + eq_k, eq_b = self.get_kernel_bias() + self.lkb_reparam = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.groups, + bias=True, + ) + + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__("lkb_origin") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + @staticmethod + def _fuse_bn(conv: torch.Tensor, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]: + """Method to fuse batchnorm layer with conv layer. + + Args: + conv: Convolutional kernel weights. + bn: Batchnorm 2d layer. + + Returns: + Tuple of (kernel, bias) after fusing batchnorm. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential: + """Helper method to construct conv-batchnorm layers. + + Args: + kernel_size: Size of the convolution kernel. + padding: Zero-padding size. + + Returns: + A nn.Sequential Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module( + "conv", + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=kernel_size, + stride=self.stride, + padding=padding, + groups=self.groups, + bias=False, + ), + ) + mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels)) + return mod_list diff --git a/mobileclip/modules/text/__init__.py b/mobileclip/modules/text/__init__.py new file mode 100644 index 0000000..9a1fd1c --- /dev/null +++ b/mobileclip/modules/text/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +# diff --git a/mobileclip/modules/text/repmixer.py b/mobileclip/modules/text/repmixer.py new file mode 100644 index 0000000..7b53e60 --- /dev/null +++ b/mobileclip/modules/text/repmixer.py @@ -0,0 +1,265 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import torch +import torch.nn as nn +from timm.models.layers import DropPath, trunc_normal_ + +from mobileclip.modules.common.mobileone import MobileOneBlock + + +class ConvFFN(nn.Module): + """Convolutional FFN Module.""" + + def __init__( + self, + in_channels: int, + context_size: int, + hidden_channels: int | None = None, + out_channels: int | None = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + ) -> None: + """Build convolutional FFN module. + + Args: + in_channels: Number of input channels. + context_size: Context size for 1D signals. + hidden_channels: Number of channels after expansion. Default: None + out_channels: Number of output channels. Default: None + act_layer: Activation layer. Default: ``GELU`` + drop: Dropout rate. Default: ``0.0``. + """ + super().__init__() + out_channels = out_channels or in_channels + hidden_channels = hidden_channels or in_channels + self.conv = nn.Sequential() + self.conv.add_module( + "conv", + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, int(context_size)), + padding=(0, int(context_size // 2)), + groups=in_channels, + bias=False, + ), + ) + self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) + self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class RepMixer(nn.Module): + """Reparameterizable token mixer. + + For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural + Reparameterization `_ + """ + + def __init__( + self, + dim, + kernel_size=3, + use_layer_scale=True, + layer_scale_init_value=1e-5, + inference_mode: bool = False, + ): + """Build RepMixer Module. + + Args: + dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. + kernel_size: Kernel size for spatial mixing. Default: 3 + use_layer_scale: If True, learnable layer scale is used. Default: ``True`` + layer_scale_init_value: Initial value for layer scale. Default: 1e-5 + inference_mode: If True, instantiates model in inference mode. Default: ``False`` + """ + super().__init__() + self.dim = dim + self.kernel_size = kernel_size + self.inference_mode = inference_mode + + if inference_mode: + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=(1, self.kernel_size), + stride=1, + padding=(0, self.kernel_size // 2), + groups=self.dim, + bias=True, + ) + else: + self.norm = MobileOneBlock( + dim, + dim, + (1, kernel_size), + padding=(0, kernel_size // 2), + groups=dim, + use_act=False, + use_scale_branch=False, + num_conv_branches=0, + ) + self.mixer = MobileOneBlock( + dim, + dim, + (1, kernel_size), + padding=(0, kernel_size // 2), + groups=dim, + use_act=False, + ) + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "reparam_conv"): + x = self.reparam_conv(x) + return x + else: + if self.use_layer_scale: + x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) + else: + x = x + self.mixer(x) - self.norm(x) + return x + + def reparameterize(self) -> None: + """Reparameterize mixer and norm into a single convolutional layer for efficient inference.""" + if self.inference_mode: + return + + self.mixer.reparameterize() + self.norm.reparameterize() + + if self.use_layer_scale: + w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + ) + b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias) + else: + w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight + b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias + + self.reparam_conv = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim, + kernel_size=(1, self.kernel_size), + stride=1, + padding=(0, self.kernel_size // 2), + groups=self.dim, + bias=True, + ) + self.reparam_conv.weight.data = w + self.reparam_conv.bias.data = b + + for para in self.parameters(): + para.detach_() + self.__delattr__("mixer") + self.__delattr__("norm") + if self.use_layer_scale: + self.__delattr__("layer_scale") + + +class RepMixerBlock(nn.Module): + """Implementation of Metaformer block with RepMixer as token mixer. + + For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision + `_ + """ + + def __init__( + self, + dim: int, + kernel_size: int = 11, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, + drop_path: float = 0.0, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, + *args, + **kwargs, + ): + """Build RepMixer Block. + + Args: + dim: Number of embedding dimensions. + kernel_size: Kernel size for repmixer. Default: 3 + mlp_ratio: MLP expansion ratio. Default: 4.0 + act_layer: Activation layer. Default: ``nn.GELU`` + drop: Dropout rate. Default: 0.0 + drop_path: Drop path rate. Default: 0.0 + use_layer_scale: Flag to turn on layer scale. Default: ``True`` + layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 + inference_mode: Flag to instantiate block in inference mode. Default: ``False`` + """ + super().__init__() + + self.token_mixer = RepMixer( + dim, + kernel_size=kernel_size, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + ) + + assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}" + mlp_hidden_dim = int(dim * mlp_ratio) + self.convffn = ConvFFN( + in_channels=dim, + context_size=kernel_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # Drop Path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # Layer Scale + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True) + + def forward(self, x, *args, **kwargs): + if x.dim() == 3: + # B, C, D --- where C is the context length + # Convert to B, D, C --- to match RepMixer impl. + x = x.permute(0, 2, 1) + x = torch.unsqueeze(x, dim=2) + else: + raise ValueError(f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}") + + if self.use_layer_scale: + x = self.token_mixer(x) + x = x + self.drop_path(self.layer_scale * self.convffn(x)) + else: + x = self.token_mixer(x) + x = x + self.drop_path(self.convffn(x)) + + # Convert tensors back + x = x.squeeze(dim=2).permute(0, 2, 1) + return x diff --git a/mobileclip/modules/text/tokenizer.py b/mobileclip/modules/text/tokenizer.py new file mode 100644 index 0000000..154face --- /dev/null +++ b/mobileclip/modules/text/tokenizer.py @@ -0,0 +1,39 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# + +import open_clip +from torch import Tensor, nn + + +class ClipTokenizer(nn.Module): + def __init__(self, cfg, *args, **kwargs): + super().__init__() + self.context_length = cfg["text_cfg"]["context_length"] + model_name = getattr(cfg["text_cfg"], "open_clip_tokenizer", "ViT-B-16") + self.tokenizer = open_clip.get_tokenizer(model_name) + + def get_vocab_size(self) -> int: + return len(self.tokenizer.encoder) + + def get_encodings(self) -> dict[str, int]: + return self.tokenizer.encoder + + def get_eot_token(self) -> int: + # Tokenizing an empty string returns a list [sot_id, eot_id] + return self.tokenizer("")[1] + + def get_sot_token(self) -> int: + # Tokenizing an empty string returns a list [sot_id, eot_id] + return self.tokenizer("")[0] + + def forward(self, input_sentence: str, *args, **kwargs) -> Tensor: + # tokenizer returns indices as a string + tokenized_sentence = self.tokenizer(input_sentence, self.context_length) + assert tokenized_sentence.shape[-1] == self.context_length, ( + "Tokenized tensor should be exactly `context_length` long." + ) + return tokenized_sentence diff --git a/mobileclip/text_encoder.py b/mobileclip/text_encoder.py new file mode 100644 index 0000000..62fc3e3 --- /dev/null +++ b/mobileclip/text_encoder.py @@ -0,0 +1,218 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +from torch import Tensor, nn + +from mobileclip import logger +from mobileclip.modules.common.transformer import ( + PositionalEmbedding, + TransformerEncoder, + get_normalization_layer, +) +from mobileclip.modules.text.repmixer import RepMixerBlock + + +class TextTransformer(nn.Module): + def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None: + super().__init__() + + model_dim = cfg["dim"] + no_scale_embedding = cfg.get("no_scale_embedding", False) + no_pos_embedding = cfg.get("no_pos_embedding", False) + embed_dropout = cfg.get("embed_dropout", 0.0) + norm_layer = cfg["norm_layer"] + variant = cfg["model_name"] + self.vocab_size = cfg["vocab_size"] + self.projection_dim = projection_dim + + # Token embedding layer + self.embedding_layer = nn.Embedding(embedding_dim=model_dim, num_embeddings=self.vocab_size) + self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5 + + # Context length + context_length = cfg["context_length"] + assert context_length is not None, "Context length can't be None. Please set value accordingly." + + self.positional_embedding = ( + None if no_pos_embedding else PositionalEmbedding(num_embeddings=context_length, embedding_dim=model_dim) + ) + + self.embedding_dropout = nn.Dropout(p=embed_dropout) + + # Transformer layer + n_transformer_layers = cfg["n_transformer_layers"] + + # FFN multipliers for transformer layer + ffn_multipliers = cfg["ffn_multiplier_per_layer"] + if isinstance(ffn_multipliers, (float, int)): + ffn_multipliers = [ffn_multipliers] * n_transformer_layers + + if not isinstance(ffn_multipliers, Sequence): + logger.error( + f"{self.__class__.__name__} expects FFN multipliers as a list, whose length is the same as" + f" number of transformer layers. Got: {type(ffn_multipliers)}" + ) + elif isinstance(ffn_multipliers, Sequence) and len(ffn_multipliers) != n_transformer_layers: + logger.error( + f"We need FFN multiplier for each transformer layer. Got {len(ffn_multipliers)} ffn" + f" multipliers while number of transformer layers = {n_transformer_layers}" + ) + ffn_dims = [int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) for ffn_mult in ffn_multipliers] + + # Heads for transformer layers + mha_heads = cfg["n_heads_per_layer"] + if isinstance(mha_heads, int): + mha_heads = [mha_heads] * n_transformer_layers + + if not isinstance(mha_heads, Sequence): + logger.error( + f"{self.__class__.__name__} expects MHA heads as a list, whose length is the same as number of " + f"transformer layers. Got: {type(mha_heads)}" + ) + elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers: + logger.error( + f"{self.__class__.__name__} needs MHA heads for each transformer layer. Got {len(mha_heads)} mha heads while" + f" number of transformer layers = {n_transformer_layers}" + ) + + if variant == "base": + self.transformer = nn.ModuleList( + [ + TransformerEncoder( + embed_dim=model_dim, + num_heads=mha_heads[layer_idx], + ffn_latent_dim=ffn_dims[layer_idx], + transformer_norm_layer=norm_layer, + ) + for layer_idx in range(n_transformer_layers) + ] + ) + elif variant == "mct": + self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)]) + self.transformer.extend( + [ + TransformerEncoder( + embed_dim=model_dim, + num_heads=mha_heads[layer_idx], + ffn_latent_dim=ffn_dims[layer_idx], + transformer_norm_layer=norm_layer, + ) + for layer_idx in range(n_transformer_layers) + ] + ) + self.transformer.extend([RepMixerBlock(dim=model_dim)]) + else: + raise ValueError(f"Unrecognized text encoder variant {variant}") + + self.final_layer_norm = get_normalization_layer(num_features=model_dim, norm_type=norm_layer) + + self.projection_layer = nn.Parameter(torch.empty(model_dim, self.projection_dim)) + self.model_dim = model_dim + self.causal_masking = cfg["causal_masking"] + + def forward_embedding(self, text_tokens: Tensor) -> Tensor: + """Return text embedding for all tokens. + + Args: + text_tokens: a tensor of token indices. Shape: [batch_size, context_length] + + Returns: + A tensor of [batch_size, context_length, hidden_dim]. + """ + # [batch_size, context_length] --> [batch_size, context_length, hidden_dim] + token_emb = self.embedding_layer(text_tokens) + seq_len = token_emb.shape[1] + if self.positional_embedding is not None: + token_emb = token_emb + self.positional_embedding(seq_len).to(token_emb.dtype) + token_emb = self.embedding_dropout(token_emb) + return token_emb + + @staticmethod + @torch.jit.script # use scripting to avoid device constant + def build_attention_mask(text_tokens: torch.Tensor) -> Tensor: + """Build causal attention mask [batch_size, context_length, context_length].""" + # Build mask with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + batch_size, context_length = text_tokens.shape + mask = torch.empty(context_length, context_length, device=text_tokens.device, dtype=torch.float32) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(0) # add dummy batch dimension + mask = mask.expand(batch_size, -1, -1) + return mask + + def encode_text( + self, + text_tokens: Tensor, + key_padding_mask: Tensor | None = None, + return_all_tokens: bool = False, + *args, + **kwargs, + ) -> Tensor: + """Return text token embeddings. + + Args: + text_tokens: a tensor of token indices. Shape: [batch_size, context_length] + key_padding_mask: a tensor of boolean values as the padding mask of shape [batch_size, context_length] + return_all_tokens: a boolean flag to return all tokens, defaults to False to return only EOT token + embedding. + + Returns: + A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is + True, otherwise a tensor of [batch_size, hidden_dim]. + """ + # Discrete tokens to continuous embeddings + # [batch_size, context_length] --> [batch_size, context_length, hidden_dim] + token_emb = self.forward_embedding(text_tokens) + + # [1, context_length, context_length] + attn_mask = None + if self.causal_masking: + attn_mask = self.build_attention_mask(text_tokens=text_tokens) + key_padding_mask = None + + for layer in self.transformer: + token_emb = layer( + token_emb, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + + # Apply layer norm + token_emb = self.final_layer_norm(token_emb) + + if return_all_tokens: + return token_emb + + # Take features from the eot embedding (eot_token is the highest number in each sequence) + token_emb = token_emb[torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)] + + token_emb = token_emb @ self.projection_layer + return token_emb + + def forward( + self, + text_tokens: Tensor, + key_padding_mask: Tensor | None = None, + return_all_tokens: bool = False, + *args, + **kwargs, + ) -> Tensor: + # Image-text pair data with single caption + # [B, CL] --> [B, d] + text_tokens = self.encode_text( + text_tokens=text_tokens, + key_padding_mask=key_padding_mask, + return_all_tokens=return_all_tokens, + *args, + **kwargs, + ) + return text_tokens diff --git a/train.py b/train.py new file mode 100644 index 0000000..d2e05da --- /dev/null +++ b/train.py @@ -0,0 +1,64 @@ +import torch +from torch.utils.data import DataLoader + +# --- 引入你的模块 --- +from dataset import YOLODataset +from yolo11_standalone import YOLO11 +from loss import YOLOv8DetectionLoss + +# --- 引入刚刚写的 Trainer --- +from trainer import YOLO11Trainer + +def run_training(): + # --- 1.全局配置 --- + img_dir_train = "E:\\Datasets\\coco\\images\\train2017" + label_dir_train = "E:\\Datasets\\coco\\labels\\train2017" + img_dir_val = "E:\\Datasets\\coco\\images\\val2017" + label_dir_val = "E:\\Datasets\\coco\\labels\\val2017" + + epochs = 50 + batch_size = 36 + img_size = 640 + device = "cuda" if torch.cuda.is_available() else "cpu" + + # --- 2. 准备数据 --- + print("Loading Data...") + train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True) + val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, + collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, + collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True) + + # --- 3. 初始化模型 --- + print("Initializing Model...") + model = YOLO11(nc=80, scale='s') + # model.load_weights("yolo11s.pth") + # model.to(device) + + strides = [8, 16, 32] + hyp = { + 'box': 7.5, + 'cls': 0.5, + 'dfl': 1.5 + } + loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp) + + # --- 5. 初始化 Trainer 并开始训练 --- + trainer = YOLO11Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + loss_fn=loss_fn, + epochs=epochs, + lr0=0.01, + device=device, + save_dir='./my_yolo_result' + ) + + trainer.train() + +if __name__ == "__main__": + # Windows下多进程dataloader需要这个保护 + run_training() \ No newline at end of file diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..c08bbbb --- /dev/null +++ b/trainer.py @@ -0,0 +1,275 @@ +import math +import copy +import time +import logging +from pathlib import Path +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy + +# 配置日志 +logging.basicConfig(format="%(message)s", level=logging.INFO) +LOGGER = logging.getLogger("YOLO_Trainer") + +# ============================================================================== +# Helper Class: Model EMA (Exponential Moving Average) +# ============================================================================== +class ModelEMA: + """ Updated Exponential Moving Average (EMA) from Ultralytics """ + def __init__(self, model, decay=0.9999, tau=2000, updates=0): + self.ema = copy.deepcopy(model).eval() # FP32 EMA + self.updates = updates + # decay exponential ramp (to help early epochs) + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def update(self, model): + self.updates += 1 + d = self.decay(self.updates) + + msd = model.state_dict() + for k, v in self.ema.state_dict().items(): + if k in msd: + tmp = msd[k].to(v.device) + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * tmp + +# ============================================================================== +# Main Trainer Class +# ============================================================================== +class YOLO11Trainer: + def __init__(self, + model, + train_loader, + val_loader, + loss_fn, + epochs=100, + lr0=0.01, + lrf=0.01, + device='cuda', + save_dir='./runs/train', + warmup_epochs=3.0): + + self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + self.model = model.to(self.device) + self.train_loader = train_loader + self.val_loader = val_loader + self.loss_fn = loss_fn + self.epochs = epochs + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + self.warmup_epochs = warmup_epochs + self.start_epoch = 0 + + # --- Optimizer Building --- + g = [], [], [] # optimizer parameter groups + bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) + + for v in self.model.modules(): + for p_name, p in v.named_parameters(recurse=False): + if p_name == 'bias': + g[2].append(p) # biases + elif isinstance(v, bn): + g[1].append(p) # bn weights (no decay) + else: + g[0].append(p) # weights (decay) + + self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True) + self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005}) + self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) + + LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}") + + # --- Scheduler --- + self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1 + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) + + # --- AMP & EMA --- + self.scaler = torch.amp.GradScaler('cuda', enabled=True) + self.ema = ModelEMA(self.model) + + def train_one_epoch(self, epoch): + self.model.train() + pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), + desc=f'Epoch {epoch+1}/{self.epochs}', leave=True) + + mloss = torch.zeros(3, device=self.device) + self.optimizer.zero_grad() + nb = len(self.train_loader) + + for i, batch in pbar: + ni = i + nb * epoch + + imgs, targets, paths = batch + imgs = imgs.to(self.device, non_blocking=True) + targets = targets.to(self.device) + + # --- Warmup --- + if ni <= nb * self.warmup_epochs: + xp = [0, nb * self.warmup_epochs] + for j, x in enumerate(self.optimizer.param_groups): + lr_target = x['initial_lr'] * self.lf(epoch) + x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target]) + if 'momentum' in x: + x['momentum'] = np.interp(ni, xp, [0.8, 0.937]) + + # --- Forward --- + with torch.amp.autocast('cuda', enabled=True): + preds = self.model(imgs) + + target_batch = { + "batch_idx": targets[:, 0], + "cls": targets[:, 1], + "bboxes": targets[:, 2:], + } + + loss, loss_items = self.loss_fn(preds, target_batch) + + # --- Backward --- + self.scaler.scale(loss).backward() + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + + # --- EMA --- + self.ema.update(self.model) + + # --- Logging --- + loss_items = loss_items.detach() + mloss = (mloss * i + loss_items) / (i + 1) + mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' + + pbar.set_postfix({ + 'mem': mem, + 'box': f"{mloss[0]:.4f}", + 'cls': f"{mloss[1]:.4f}", + 'dfl': f"{mloss[2]:.4f}", + 'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}" + }) + + def validate(self): + model = self.ema.ema + device = self.device + + # --- Metrics Config --- + conf_thres = 0.001 # Low threshold for mAP calculation + iou_thres = 0.7 # NMS IoU threshold + iouv = torch.linspace(0.5, 0.95, 10, device=device) # IoU vector for mAP@0.5:0.95 + + loss_sum = torch.zeros(3, device=device) + stats = [] # [(correct, conf, pred_cls, target_cls)] + + LOGGER.info("\nValidating...") + + model.eval() + + pbar = tqdm(self.val_loader, desc="Calc Metrics") + with torch.no_grad(): + for batch in pbar: + imgs, targets, _ = batch + imgs = imgs.to(device, non_blocking=True) + targets = targets.to(device) + _, _, height, width = imgs.shape + + # Inference + preds = model(imgs) + + # NMS + preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres) + + # Metrics Processing + for si, pred in enumerate(preds): + labels = targets[targets[:, 0] == si] + nl = len(labels) + tcls = labels[:, 1].tolist() if nl else [] + + if len(pred) == 0: + if nl: + stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool), + torch.Tensor(), torch.Tensor(), torch.tensor(tcls))) + continue + + # Predictions + predn = pred.clone() + + # Ground Truth + if nl: + tbox = xywh2xyxy(labels[:, 2:6]) + tbox[:, [0, 2]] *= width + tbox[:, [1, 3]] *= height + labelsn = torch.cat((labels[:, 1:2], tbox), 1) # [cls, x1, y1, x2, y2] + + # Match predictions to GT + correct = self._process_batch(predn, labelsn, iouv) + else: + correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool) + + stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls))) + + mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0 + if len(stats) and stats[0][0].any(): + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy + tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats) + ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 + mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean() + + LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}") + + return map50 + + def _process_batch(self, detections, labels, iouv): + """ + Return correct prediction matrix + detections: [N, 6] (x1, y1, x2, y2, conf, cls) + labels: [M, 5] (cls, x1, y1, x2, y2) + """ + correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device) + iou = box_iou(labels[:, 1:], detections[:, :4]) + + x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) + + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + matches = torch.from_numpy(matches).to(iouv.device) + correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv + + return correct + + def train(self): + LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...") + start_time = time.time() + best_fitness = 0.0 + + for epoch in range(self.start_epoch, self.epochs): + self.train_one_epoch(epoch) + self.scheduler.step() + map50 = self.validate() + + ckpt = { + 'epoch': epoch, + 'model': self.model.state_dict(), + 'ema': self.ema.ema.state_dict(), + 'optimizer': self.optimizer.state_dict(), + } + + torch.save(ckpt, self.save_dir / 'last.pt') + + if map50 > best_fitness: + best_fitness = map50 + torch.save(ckpt, self.save_dir / 'best.pt') + LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}") + + LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.") \ No newline at end of file diff --git a/yolo11_standalone.py b/yolo11_standalone.py new file mode 100644 index 0000000..0b59075 --- /dev/null +++ b/yolo11_standalone.py @@ -0,0 +1,923 @@ +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +def make_divisible(x, divisor): + if isinstance(x, torch.Tensor): + return x + return math.ceil(x / divisor) * divisor + +def autopad(k, p=None, d=1): + if d > 1: + k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] + return p + +def make_anchors(feats, strides, grid_cell_offset=0.5): + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + _, _, h, w = feats[i].shape + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset + sy, sx = torch.meshgrid(sy, sx, indexing="ij") + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + +def dist2bbox(distance, anchor_points, xywh=True, dim=-1): + lt, rb = distance.chunk(2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return torch.cat((c_xy, wh), dim) + return torch.cat((x1y1, x2y2), dim) + +class Concat(nn.Module): + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x: List[torch.Tensor]): + return torch.cat(x, self.d) + + +class Conv(nn.Module): + default_act = nn.SiLU(inplace=True) + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) # type: ignore + self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def forward_fuse(self, x): + return self.act(self.conv(x)) + +class DWConv(Conv): + def __init__(self, c1, c2, k=1, s=1, d=1, act=True): + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) + +class DFL(nn.Module): + def __init__(self, c1: int = 16): + super().__init__() + self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) + x = torch.arange(c1, dtype=torch.float) + self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) + self.c1 = c1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, a = x.shape + return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) + +class Bottleneck(nn.Module): + def __init__( + self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5 + ): + super().__init__() + c_ = int(c2 * e) + self.cv1 = Conv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + +class C2f(nn.Module): + def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5): + super().__init__() + self.c = int(c2 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((2 + n) * self.c, c2, 1) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) # type: ignore + + def forward(self, x): + chunk_result = self.cv1(x).chunk(2, 1) + y = [chunk_result[0], chunk_result[1]] + + for m_module in self.m: + y.append(m_module(y[-1])) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x: torch.Tensor) -> torch.Tensor: + y = self.cv1(x).split((self.c, self.c), 1) + y = [y[0], y[1]] + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + +class C3(nn.Module): + def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5): + super().__init__() + c_ = int(c2 * e) + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + +class C3k(C3): + def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) + +class C3k2(C2f): + def __init__( + self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True + ): + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList( + C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n) + ) + +class SPPF(nn.Module): + def __init__(self, c1: int, c2: int, k: int = 5): + super().__init__() + c_ = c1 // 2 + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = [self.cv1(x)] + y.extend(self.m(y[-1]) for _ in range(3)) + return self.cv2(torch.cat(y, 1)) + +class Attention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.key_dim = int(self.head_dim * attn_ratio) + self.scale = self.key_dim**-0.5 + nh_kd = self.key_dim * num_heads + h = dim + nh_kd * 2 + self.qkv = Conv(dim, h, 1, act=False) + self.proj = Conv(dim, dim, 1, act=False) + self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + N = H * W + qkv = self.qkv(x) + q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( + [self.key_dim, self.key_dim, self.head_dim], dim=2 + ) + + attn = (q.transpose(-2, -1) @ k) * self.scale + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) + x = self.proj(x) + return x + +class PSABlock(nn.Module): + def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None: + super().__init__() + self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads) + self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False)) + self.add = shortcut + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(x) if self.add else self.attn(x) + x = x + self.ffn(x) if self.add else self.ffn(x) + return x + +class C2PSA(nn.Module): + def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5): + super().__init__() + assert c1 == c2 + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = self.m(b) + return self.cv2(torch.cat((a, b), 1)) + +class Proto(nn.Module): + def __init__(self, c1: int, c_: int = 256, c2: int = 32): + super().__init__() + self.cv1 = Conv(c1, c_, k=3) + self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) + self.cv2 = Conv(c_, c_, k=3) + self.cv3 = Conv(c_, c2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) + +class BNContrastiveHead(nn.Module): + def __init__(self, embed_dims: int): + super().__init__() + self.norm = nn.BatchNorm2d(embed_dims) + self.bias = nn.Parameter(torch.tensor([-10.0])) + self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) + + def fuse(self): + del self.norm + del self.bias + del self.logit_scale + self.forward = self.forward_fuse + + def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return x + + def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + w = F.normalize(w, dim=-1, p=2) + + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + +class SwiGLUFFN(nn.Module): + def __init__(self, gc: int, ec: int, e: int = 4) -> None: + super().__init__() + self.w12 = nn.Linear(gc, e * ec) + self.w3 = nn.Linear(e * ec // 2, ec) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + +class Residual(nn.Module): + def __init__(self, m: nn.Module) -> None: + super().__init__() + self.m = m + nn.init.zeros_(self.m.w3.bias) + # For models with l scale, please change the initialization to + # nn.init.constant_(self.m.w3.weight, 1e-6) + nn.init.zeros_(self.m.w3.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.m(x) + +class SAVPE(nn.Module): + def __init__(self, ch: List[int], c3: int, embed: int): + super().__init__() + self.cv1 = nn.ModuleList( + nn.Sequential( + Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity() + ) + for i, x in enumerate(ch) + ) + + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()) + for i, x in enumerate(ch) + ) + + self.c = 16 + self.cv3 = nn.Conv2d(3 * c3, embed, 1) + self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1) + self.cv5 = nn.Conv2d(1, self.c, 3, padding=1) + self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1)) + + def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor: + y = [self.cv2[i](xi) for i, xi in enumerate(x)] + y = self.cv4(torch.cat(y, dim=1)) + + x = [self.cv1[i](xi) for i, xi in enumerate(x)] + x = self.cv3(torch.cat(x, dim=1)) + + B, C, H, W = x.shape # type: ignore + + Q = vp.shape[1] + + x = x.view(B, C, -1) # type: ignore + + y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W) + vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W) + + y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1)) + + y = y.reshape(B, Q, self.c, -1) + vp = vp.reshape(B, Q, 1, -1) + + score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min + score = F.softmax(score, dim=-1).to(y.dtype) + aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2) + + return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2) + +class Detect(nn.Module): + dynamic = False + export = False + shape = None + anchors = torch.empty(0) + strides = torch.empty(0) + + def __init__(self, nc=80, ch=()): + super().__init__() + self.nc = nc + self.nl = len(ch) + self.reg_max = 16 + self.no = nc + self.reg_max * 4 + self.stride = torch.tensor([8., 16., 32.]) + + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) + + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) + self.cv3 = nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() + + def forward(self, x): + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + + if self.training: + return x + + # Inference path + shape = x[0].shape + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + + if self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + + return torch.cat((dbox, cls.sigmoid()), 1) + + def bias_init(self): + m = self + for a, b, s in zip(m.cv2, m.cv3, m.stride): + a[-1].bias.data[:] = 1.0 # type: ignore + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # type: ignore + + def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor: + return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1) + + @staticmethod + def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor: + batch_size, anchors, _ = preds.shape + boxes, scores = preds.split([4, nc], dim=-1) + index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1) + boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4)) + scores = scores.gather(dim=1, index=index.repeat(1, 1, nc)) + scores, index = scores.flatten(1).topk(min(max_det, anchors)) + i = torch.arange(batch_size)[..., None] + return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1) + +class YOLOEDetect(Detect): + def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()): + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) + assert c3 <= embed + self.cv3 = ( + nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, embed, 1), + ) + for x in ch + ) + ) + + self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch) + + self.reprta = Residual(SwiGLUFFN(embed, embed)) + self.savpe = SAVPE(ch, c3, embed) # type: ignore + self.embed = embed + + def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2) + + def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor: + if vpe.shape[1] == 0: # no visual prompt embeddings + return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device) + if vpe.ndim == 4: # (B, N, H, W) + vpe = self.savpe(x, vpe) + assert vpe.ndim == 3 # (B, N, D) + return vpe + + def forward( # type: ignore + self, x: List[torch.Tensor], cls_pe: torch.Tensor + ) -> Union[torch.Tensor, Tuple]: + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1) + if self.training: + return x # type: ignore + self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts + # Inference path + shape = x[0].shape + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + + if self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + + return torch.cat((dbox, cls.sigmoid()), 1) + + def bias_init(self): + m = self + for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[:] = 0.0 + c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) + +class YOLOESegment(YOLOEDetect): + def __init__( + self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = () + ): + super().__init__(nc, embed, ch) + self.nm = nm + self.npr = npr + self.proto = Proto(ch[0], self.npr, self.nm) + + c5 = max(ch[0] // 4, self.nm) + self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch) + + def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]: + p = self.proto(x[0]) # mask protos + bs = p.shape[0] # batch size + + mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients + x = YOLOEDetect.forward(self, x, text) + + if self.training: + return x, mc, p + + return (torch.cat([x, mc], 1), p) + + + +class YOLO11(nn.Module): + def __init__(self, nc=80, scale='n'): + super().__init__() + self.nc = nc + + # Scales: [depth, width, max_channels] + # 对应 yolo11.yaml 中的 scales 参数 + scales = { + 'n': [0.50, 0.25, 1024], + 's': [0.50, 0.50, 1024], + 'm': [0.50, 1.00, 512], + 'l': [1.00, 1.00, 512], + 'x': [1.00, 1.50, 512], + } + + if scale not in scales: + raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}") + + depth, width, max_channels = scales[scale] + + if scale in ['n', 's']: + c3k_override = False + else: + c3k_override = True + + # 辅助函数:计算通道数 (Width Scaling) + def gw(channels): + return make_divisible(min(channels, max_channels) * width, 8) + + # 辅助函数:计算层重复次数 (Depth Scaling) + def gd(n): + return max(round(n * depth), 1) if n > 1 else n + + self.model = nn.ModuleList() + + # --- Backbone --- + # 0: Conv [64, 3, 2] + self.model.append(Conv(3, gw(64), 3, 2)) + + # 1: Conv [128, 3, 2] + self.model.append(Conv(gw(64), gw(128), 3, 2)) + + # 2: C3k2 [256, False, 0.25] -> n=2 + self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25)) + + # 3: Conv [256, 3, 2] + self.model.append(Conv(gw(256), gw(256), 3, 2)) + + # 4: C3k2 [512, False, 0.25] -> n=2 + self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25)) + + # 5: Conv [512, 3, 2] + self.model.append(Conv(gw(512), gw(512), 3, 2)) + + # 6: C3k2 [512, True] -> n=2 + self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True)) + + # 7: Conv [1024, 3, 2] + self.model.append(Conv(gw(512), gw(1024), 3, 2)) + + # 8: C3k2 [1024, True] -> n=2 + self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True)) + + # 9: SPPF [1024, 5] + self.model.append(SPPF(gw(1024), gw(1024), 5)) + + # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2) + self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2))) + + # --- Head --- + + # 11: Upsample + self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) + + # 12: Concat [-1, 6] (P4) + self.model.append(Concat(dimension=1)) + + # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512)) + self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) + + # 14: Upsample + self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) + + # 15: Concat [-1, 4] (P3) + self.model.append(Concat(dimension=1)) + + # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512)) + # 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512) + self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override)) + + # 17: Conv [256, 3, 2] + self.model.append(Conv(gw(256), gw(256), 3, 2)) + + # 18: Concat [-1, 13] (Head P4) + self.model.append(Concat(dimension=1)) + + # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512)) + self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) + + # 20: Conv [512, 3, 2] + self.model.append(Conv(gw(512), gw(512), 3, 2)) + + # 21: Concat [-1, 10] (P5) + self.model.append(Concat(dimension=1)) + + # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024)) + self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True)) + + # 23: Detect [nc] + self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)])) + + # --- 初始化权重 --- + self.initialize_weights() + + def initialize_weights(self): + """初始化模型权重,特别是 Detect 头的 Bias""" + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + # 使用 Kaiming 初始化或其他合适的初始化 + pass + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + detect_layer = self.model[-1] + if isinstance(detect_layer, Detect): + detect_layer.bias_init() + + def forward(self, x): + # Backbone + x = self.model[0](x) + x = self.model[1](x) + x = self.model[2](x) + x = self.model[3](x) + p3 = self.model[4](x) # 保存 P3 (layer 4) + x = self.model[5](p3) + p4 = self.model[6](x) # 保存 P4 (layer 6) + x = self.model[7](p4) + x = self.model[8](x) + x = self.model[9](x) + p5 = self.model[10](x) # 保存 P5 (layer 10) + + # Head + x = self.model[11](p5) # Upsample + x = self.model[12]([x, p4]) # Concat P4 + h1 = self.model[13](x) # Head P4 (layer 13) + + x = self.model[14](h1) # Upsample + x = self.model[15]([x, p3]) # Concat P3 + h2 = self.model[16](x) # Output P3 (layer 16) + + x = self.model[17](h2) # Conv + x = self.model[18]([x, h1]) # Concat Head P4 + h3 = self.model[19](x) # Output P4 (layer 19) + + x = self.model[20](h3) # Conv + x = self.model[21]([x, p5]) # Concat P5 + h4 = self.model[22](x) # Output P5 (layer 22) + + return self.model[23]([h2, h3, h4]) # Detect + + def load_weights(self, pth_file): + state_dict = torch.load(pth_file, map_location='cpu', weights_only=False) + # 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方) + # 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv... + # 这里做一个简单的兼容性处理 + new_state_dict = {} + for k, v in state_dict.items(): + # 处理 ultralytics 权重字典中的 'model' 键 + if k == 'model': + # 如果是完整的 checkpoint,权重在 'model' 键下 + # 且通常是 model.state_dict() + if hasattr(v, 'state_dict'): + v = v.state_dict() + elif isinstance(v, dict): + pass # v 就是 state_dict + else: + # 可能是 model 对象本身 + try: + v = v.float().state_dict() + except: + continue + + for sub_k, sub_v in v.items(): + new_state_dict[sub_k] = sub_v + break + else: + new_state_dict[k] = v + + if not new_state_dict: + new_state_dict = state_dict + + # 尝试加载 + try: + self.load_state_dict(new_state_dict, strict=True) + print(f"Successfully loaded weights from {pth_file}") + except Exception as e: + print(f"Error loading weights: {e}") + print("Trying to load with strict=False...") + self.load_state_dict(new_state_dict, strict=False) + +class YOLO11E(nn.Module): + def __init__(self, nc=80, scale='n'): + super().__init__() + self.nc = nc + self.pe = None + + # Scales: [depth, width, max_channels] + # 对应 yolo11.yaml 中的 scales 参数 + scales = { + 'n': [0.50, 0.25, 1024], + 's': [0.50, 0.50, 1024], + 'm': [0.50, 1.00, 512], + 'l': [1.00, 1.00, 512], + 'x': [1.00, 1.50, 512], + } + + if scale not in scales: + raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}") + + depth, width, max_channels = scales[scale] + + if scale in ['n', 's']: + c3k_override = False + else: + c3k_override = True + + # 辅助函数:计算通道数 (Width Scaling) + def gw(channels): + return make_divisible(min(channels, max_channels) * width, 8) + + # 辅助函数:计算层重复次数 (Depth Scaling) + def gd(n): + return max(round(n * depth), 1) if n > 1 else n + + self.model = nn.ModuleList() + + # --- Backbone --- + # 0: Conv [64, 3, 2] + self.model.append(Conv(3, gw(64), 3, 2)) + + # 1: Conv [128, 3, 2] + self.model.append(Conv(gw(64), gw(128), 3, 2)) + + # 2: C3k2 [256, False, 0.25] -> n=2 + self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25)) + + # 3: Conv [256, 3, 2] + self.model.append(Conv(gw(256), gw(256), 3, 2)) + + # 4: C3k2 [512, False, 0.25] -> n=2 + self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25)) + + # 5: Conv [512, 3, 2] + self.model.append(Conv(gw(512), gw(512), 3, 2)) + + # 6: C3k2 [512, True] -> n=2 + self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True)) + + # 7: Conv [1024, 3, 2] + self.model.append(Conv(gw(512), gw(1024), 3, 2)) + + # 8: C3k2 [1024, True] -> n=2 + self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True)) + + # 9: SPPF [1024, 5] + self.model.append(SPPF(gw(1024), gw(1024), 5)) + + # 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2) + self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2))) + + # --- Head --- + + # 11: Upsample + self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) + + # 12: Concat [-1, 6] (P4) + self.model.append(Concat(dimension=1)) + + # 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512)) + self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) + + # 14: Upsample + self.model.append(nn.Upsample(scale_factor=2, mode='nearest')) + + # 15: Concat [-1, 4] (P3) + self.model.append(Concat(dimension=1)) + + # 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512)) + # 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512) + self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override)) + + # 17: Conv [256, 3, 2] + self.model.append(Conv(gw(256), gw(256), 3, 2)) + + # 18: Concat [-1, 13] (Head P4) + self.model.append(Concat(dimension=1)) + + # 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512)) + self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override)) + + # 20: Conv [512, 3, 2] + self.model.append(Conv(gw(512), gw(512), 3, 2)) + + # 21: Concat [-1, 10] (P5) + self.model.append(Concat(dimension=1)) + + # 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024)) + self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True)) + + # 23: Detect [nc] + self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)])) + + # --- 初始化权重 --- + self.initialize_weights() + + def initialize_weights(self): + """初始化模型权重,特别是 Detect 头的 Bias""" + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + # 使用 Kaiming 初始化或其他合适的初始化 + pass + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + detect_layer = self.model[-1] + if isinstance(detect_layer, Detect): + detect_layer.bias_init() + + def set_classes(self, names: List[str], embeddings: torch.Tensor): + assert embeddings.ndim == 3, "Embeddings must be (1, N, D)" + self.pe = embeddings + self.model[-1].nc = len(names) # type: ignore + self.nc = len(names) + + def forward(self, x, tpe=None, vpe=None): + # Backbone + x = self.model[0](x) + x = self.model[1](x) + x = self.model[2](x) + x = self.model[3](x) + p3 = self.model[4](x) # 保存 P3 (layer 4) + x = self.model[5](p3) + p4 = self.model[6](x) # 保存 P4 (layer 6) + x = self.model[7](p4) + x = self.model[8](x) + x = self.model[9](x) + p5 = self.model[10](x) # 保存 P5 (layer 10) + + # Head + x = self.model[11](p5) # Upsample + x = self.model[12]([x, p4]) # Concat P4 + h1 = self.model[13](x) # Head P4 (layer 13) + + x = self.model[14](h1) # Upsample + x = self.model[15]([x, p3]) # Concat P3 + h2 = self.model[16](x) # Output P3 (layer 16) + + x = self.model[17](h2) # Conv + x = self.model[18]([x, h1]) # Concat Head P4 + h3 = self.model[19](x) # Output P4 (layer 19) + + x = self.model[20](h3) # Conv + x = self.model[21]([x, p5]) # Concat P5 + h4 = self.model[22](x) # Output P5 (layer 22) + + head = self.model[23] + feats = [h2, h3, h4] + + processed_tpe = head.get_tpe(tpe) # type: ignore + + processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None # type: ignore + + all_pe = [] + if processed_tpe is not None: + all_pe.append(processed_tpe) + if processed_vpe is not None: + all_pe.append(processed_vpe) + + if not all_pe: + if self.pe is not None: + all_pe.append(self.pe.to(device=x.device, dtype=x.dtype)) + else: + all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype)) + + cls_pe = torch.cat(all_pe, dim=1) + + b = x.shape[0] + if cls_pe.shape[0] != b: + cls_pe = cls_pe.expand(b, -1, -1) + + return head(feats, cls_pe) + + def load_weights(self, pth_file): + state_dict = torch.load(pth_file, map_location='cpu', weights_only=False) + # 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方) + # 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv... + # 这里做一个简单的兼容性处理 + new_state_dict = {} + for k, v in state_dict.items(): + # 处理 ultralytics 权重字典中的 'model' 键 + if k == 'model': + # 如果是完整的 checkpoint,权重在 'model' 键下 + # 且通常是 model.state_dict() + if hasattr(v, 'state_dict'): + v = v.state_dict() + elif isinstance(v, dict): + pass # v 就是 state_dict + else: + # 可能是 model 对象本身 + try: + v = v.float().state_dict() + except: + continue + + for sub_k, sub_v in v.items(): + new_state_dict[sub_k] = sub_v + break + else: + new_state_dict[k] = v + + if not new_state_dict: + new_state_dict = state_dict + + # 尝试加载 + try: + self.load_state_dict(new_state_dict, strict=True) + print(f"Successfully loaded weights from {pth_file}") + except Exception as e: + print(f"Error loading weights: {e}") + print("Trying to load with strict=False...") + self.load_state_dict(new_state_dict, strict=False) + +if __name__ == "__main__": + model = YOLO11E(nc=80, scale='l') + model.load_weights("yoloe-11l-seg.pth") + + # 模拟 set_classes + # 假设我们有2个类,embedding维度是512 + fake_embeddings = torch.randn(1, 2, 512) + model.set_classes(["class1", "class2"], fake_embeddings) + + # 推理 + dummy_input = torch.randn(1, 3, 640, 640) + model.eval() + output = model(dummy_input) + print("Output shape:", output[0].shape) # 应该是 (1, 4+mask_coeffs+num_classes, anchors) \ No newline at end of file diff --git a/yolo11n.pt b/yolo11n.pt new file mode 100644 index 0000000..45b273b Binary files /dev/null and b/yolo11n.pt differ diff --git a/yolo11s.pth b/yolo11s.pth new file mode 100644 index 0000000..539ef44 Binary files /dev/null and b/yolo11s.pth differ