Compare commits
2 Commits
2b8f25f318
...
58b24abbf0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58b24abbf0 | ||
|
|
604951f9c2 |
288
dataset.py
Normal file
288
dataset.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import os
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
from torch.utils.data import Dataset
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
class YOLODataset(Dataset):
|
||||
def __init__(self, img_dir, label_dir, img_size=640, is_train=True):
|
||||
self.img_dir = img_dir
|
||||
self.label_dir = label_dir
|
||||
self.img_size = img_size
|
||||
self.is_train = is_train
|
||||
self.use_mosaic = is_train # 新增标志位,默认训练时开启
|
||||
|
||||
# 支持多种图片格式
|
||||
self.img_files = sorted(
|
||||
glob.glob(os.path.join(img_dir, "*.jpg")) +
|
||||
glob.glob(os.path.join(img_dir, "*.png")) +
|
||||
glob.glob(os.path.join(img_dir, "*.jpeg"))
|
||||
)
|
||||
|
||||
# --- 1. 对齐 Ultralytics 的 Albumentations 配置 ---
|
||||
# default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4
|
||||
# OpenCV Hue range is [0, 179], Sat/Val is [0, 255]
|
||||
h_limit = int(0.015 * 179) # ~2
|
||||
s_limit = int(0.7 * 255) # ~178
|
||||
v_limit = int(0.4 * 255) # ~102
|
||||
|
||||
# default.yaml: translate: 0.1, scale: 0.5 (0.5~1.5), degrees: 0.0
|
||||
|
||||
if is_train:
|
||||
self.transform = A.Compose([
|
||||
# 几何增强 (Mosaic 之后再做一次微调,或者处理非 Mosaic 的情况)
|
||||
# 注意:Mosaic 输出已经是大图,这里主要负责最后的随机扰动
|
||||
A.Affine(
|
||||
scale=(0.5, 1.5), # scale: 0.5
|
||||
translate_percent=(0.1, 0.1), # translate: 0.1
|
||||
rotate=(-0, 0), # degrees: 0.0 (COCO default)
|
||||
shear=(-0, 0), # shear: 0.0
|
||||
p=0.5
|
||||
),
|
||||
|
||||
# 色彩增强 (严格对齐 default.yaml)
|
||||
A.HueSaturationValue(
|
||||
hue_shift_limit=h_limit,
|
||||
sat_shift_limit=s_limit,
|
||||
val_shift_limit=v_limit,
|
||||
p=0.5
|
||||
),
|
||||
|
||||
A.Blur(p=0.01),
|
||||
A.MedianBlur(p=0.01),
|
||||
A.ToGray(p=0.01),
|
||||
A.CLAHE(p=0.01),
|
||||
|
||||
# 翻转
|
||||
A.HorizontalFlip(p=0.5), # fliplr: 0.5
|
||||
|
||||
# 最终处理
|
||||
A.Resize(img_size, img_size), # 确保最后尺寸一致
|
||||
A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
|
||||
ToTensorV2()
|
||||
], bbox_params=A.BboxParams(format='yolo', min_visibility=0.0, label_fields=['class_labels']))
|
||||
else:
|
||||
# 验证集:Letterbox (保持比例填充)
|
||||
self.transform = A.Compose([
|
||||
A.LongestMaxSize(max_size=img_size),
|
||||
A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT),
|
||||
A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
|
||||
ToTensorV2()
|
||||
], bbox_params=A.BboxParams(format='yolo', min_visibility=0.1, label_fields=['class_labels']))
|
||||
|
||||
def close_mosaic(self):
|
||||
"""关闭 Mosaic 增强"""
|
||||
self.use_mosaic = False
|
||||
print("Mosaic augmentation disabled.")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_files)
|
||||
|
||||
def load_image(self, index):
|
||||
"""加载单张图片并调整长边到 img_size"""
|
||||
img_path = self.img_files[index]
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
raise FileNotFoundError(f"Image not found: {img_path}")
|
||||
|
||||
h, w = img.shape[:2]
|
||||
r = self.img_size / max(h, w)
|
||||
if r != 1:
|
||||
img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
return img, (h, w), img.shape[:2] # img, original_hw, resized_hw
|
||||
|
||||
def load_label(self, index, img_shape):
|
||||
"""加载标签并归一化"""
|
||||
img_path = self.img_files[index]
|
||||
label_path = self._get_label_path(img_path)
|
||||
h, w = img_shape
|
||||
|
||||
labels = []
|
||||
if os.path.exists(label_path):
|
||||
with open(label_path, 'r') as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 5:
|
||||
cls = int(parts[0])
|
||||
bx, by, bw, bh = map(float, parts[1:5])
|
||||
labels.append([cls, bx, by, bw, bh])
|
||||
|
||||
return np.array(labels, dtype=np.float32) if labels else np.zeros((0, 5), dtype=np.float32)
|
||||
|
||||
def load_mosaic(self, index):
|
||||
"""
|
||||
实现 YOLO 的 Mosaic 增强 (4张图拼成一张)
|
||||
"""
|
||||
labels4 = []
|
||||
s = self.img_size
|
||||
# 修复: 列表推导式需要遍历两次以生成 yc 和 xc
|
||||
yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in [-s // 2, -s // 2]]
|
||||
|
||||
# 随机选3个额外索引
|
||||
indices = [index] + [random.randint(0, len(self.img_files) - 1) for _ in range(3)]
|
||||
random.shuffle(indices)
|
||||
|
||||
# 初始化大图 (2x size)
|
||||
img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8)
|
||||
|
||||
for i, idx in enumerate(indices):
|
||||
# 加载图片
|
||||
img, _, (h, w) = self.load_image(idx)
|
||||
|
||||
# 放置位置: top-left, top-right, bottom-left, bottom-right
|
||||
if i == 0: # top left
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
|
||||
elif i == 1: # top right
|
||||
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
||||
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
||||
elif i == 2: # bottom left
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, s * 2)
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
||||
elif i == 3: # bottom right
|
||||
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(yc + h, s * 2)
|
||||
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
||||
|
||||
# 贴图
|
||||
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # type: ignore
|
||||
padw = x1a - x1b # type: ignore
|
||||
padh = y1a - y1b # type: ignore
|
||||
|
||||
# 处理标签
|
||||
labels = self.load_label(idx, (h, w))
|
||||
if labels.size > 0:
|
||||
# Normalized xywh -> Pixel xywh
|
||||
labels[:, 1] = labels[:, 1] * w
|
||||
labels[:, 2] = labels[:, 2] * h
|
||||
labels[:, 3] = labels[:, 3] * w
|
||||
labels[:, 4] = labels[:, 4] * h
|
||||
|
||||
# xywh -> xyxy (Pixel)
|
||||
xyxy = np.copy(labels)
|
||||
xyxy[:, 1] = labels[:, 1] - labels[:, 3] / 2 + padw
|
||||
xyxy[:, 2] = labels[:, 2] - labels[:, 4] / 2 + padh
|
||||
xyxy[:, 3] = labels[:, 1] + labels[:, 3] / 2 + padw
|
||||
xyxy[:, 4] = labels[:, 2] + labels[:, 4] / 2 + padh
|
||||
|
||||
labels4.append(xyxy)
|
||||
|
||||
# Concat labels
|
||||
if len(labels4):
|
||||
labels4 = np.concatenate(labels4, 0)
|
||||
# Clip to mosaic image border
|
||||
np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])
|
||||
|
||||
# 转换回 Normalized xywh (相对于 2*s 的大图)
|
||||
# Albumentations 需要 normalized xywh
|
||||
new_labels = np.copy(labels4)
|
||||
w_mosaic, h_mosaic = s * 2, s * 2
|
||||
|
||||
# xyxy -> xywh
|
||||
new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / w_mosaic
|
||||
new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / h_mosaic
|
||||
new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / w_mosaic
|
||||
new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / h_mosaic
|
||||
|
||||
return img4, new_labels
|
||||
else:
|
||||
return img4, np.zeros((0, 5))
|
||||
|
||||
def __getitem__(self, index):
|
||||
try:
|
||||
if self.is_train:
|
||||
# 修改判断逻辑:同时检查 is_train 和 use_mosaic
|
||||
if self.use_mosaic and random.random() < 1.0:
|
||||
img, labels = self.load_mosaic(index)
|
||||
else:
|
||||
img, _, _ = self.load_image(index)
|
||||
h, w = img.shape[:2]
|
||||
labels = self.load_label(index, (h, w))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # type: ignore
|
||||
else:
|
||||
img = cv2.imread(self.img_files[index])
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # type: ignore
|
||||
h, w = img.shape[:2]
|
||||
labels = self.load_label(index, (h, w))
|
||||
|
||||
# --- 修复开始: 清洗边界框 ---
|
||||
# labels 格式: [cls, x, y, w, h] (Normalized)
|
||||
|
||||
valid_bboxes = []
|
||||
valid_labels = []
|
||||
|
||||
h_img, w_img = img.shape[:2]
|
||||
|
||||
for i in range(len(labels)):
|
||||
cls = labels[i, 0]
|
||||
x, y, w, h = labels[i, 1:]
|
||||
|
||||
# 1. 限制在 [0, 1] 范围内 (处理 Mosaic 裁剪产生的越界)
|
||||
x1 = np.clip(x - w / 2, 0, 1)
|
||||
y1 = np.clip(y - h / 2, 0, 1)
|
||||
x2 = np.clip(x + w / 2, 0, 1)
|
||||
y2 = np.clip(y + h / 2, 0, 1)
|
||||
|
||||
# 2. 重新计算宽高
|
||||
w_new = x2 - x1
|
||||
h_new = y2 - y1
|
||||
|
||||
# 3. 重新计算中心点
|
||||
x_new = x1 + w_new / 2
|
||||
y_new = y1 + h_new / 2
|
||||
|
||||
# 4. 过滤掉极小的框 (例如小于 2 个像素)
|
||||
if w_new * w_img > 2 and h_new * h_img > 2:
|
||||
valid_bboxes.append([x_new, y_new, w_new, h_new])
|
||||
valid_labels.append(cls)
|
||||
|
||||
if len(valid_bboxes) == 0:
|
||||
# 如果这张图的所有框都被过滤掉了,尝试下一张
|
||||
return self.__getitem__((index + 1) % len(self))
|
||||
|
||||
bboxes = valid_bboxes
|
||||
class_labels = valid_labels
|
||||
# --- 修复结束 ---
|
||||
|
||||
# 应用增强
|
||||
transformed = self.transform(image=img, bboxes=bboxes, class_labels=class_labels)
|
||||
image = transformed['image']
|
||||
bboxes = transformed['bboxes']
|
||||
class_labels = transformed['class_labels']
|
||||
|
||||
# 构建 Target Tensor
|
||||
n = len(bboxes)
|
||||
targets = torch.zeros((n, 6))
|
||||
if n > 0:
|
||||
targets[:, 1] = torch.tensor(class_labels)
|
||||
targets[:, 2:] = torch.tensor(bboxes)
|
||||
|
||||
return image, targets, self.img_files[index]
|
||||
|
||||
except Exception as e:
|
||||
# 打印更详细的错误信息以便调试,但不要中断训练
|
||||
# print(f"Error loading data {index}: {e}")
|
||||
return self.__getitem__((index + 1) % len(self))
|
||||
|
||||
def _get_label_path(self, img_path):
|
||||
filename = os.path.basename(img_path).rsplit('.', 1)[0] + ".txt"
|
||||
return os.path.join(self.label_dir, filename)
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
imgs, targets, paths = zip(*batch)
|
||||
imgs = torch.stack(imgs, 0)
|
||||
new_targets = []
|
||||
for i, t in enumerate(targets):
|
||||
if t.shape[0] > 0:
|
||||
t[:, 0] = i
|
||||
new_targets.append(t)
|
||||
if new_targets:
|
||||
targets = torch.cat(new_targets, 0)
|
||||
else:
|
||||
targets = torch.zeros((0, 6))
|
||||
return imgs, targets, paths
|
||||
189
inference.py
Normal file
189
inference.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torchvision
|
||||
from yolo11_standalone import YOLO11
|
||||
|
||||
# COCO 80类 类别名称
|
||||
CLASSES = [
|
||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
||||
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
||||
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
|
||||
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
|
||||
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
|
||||
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
|
||||
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
|
||||
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
|
||||
"hair drier", "toothbrush"
|
||||
]
|
||||
|
||||
# 生成随机颜色用于绘图
|
||||
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
|
||||
|
||||
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
|
||||
"""
|
||||
将图像缩放并填充到指定大小 (保持纵横比)
|
||||
"""
|
||||
shape = im.shape[:2] # current shape [height, width]
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
# 计算缩放比例
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
|
||||
# 计算padding
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
dw, dh = dw / 2, dh / 2 # divide padding into 2 sides
|
||||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
|
||||
# 添加边框
|
||||
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
|
||||
return im, r, (dw, dh)
|
||||
|
||||
def xywh2xyxy(x):
|
||||
"""Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
return y
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
|
||||
"""
|
||||
非极大值抑制 (NMS)
|
||||
prediction: [Batch, 84, Anchors]
|
||||
"""
|
||||
# 1. 转置: [Batch, 84, Anchors] -> [Batch, Anchors, 84]
|
||||
prediction = prediction.transpose(1, 2)
|
||||
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = prediction.shape[2] - 4 # number of classes
|
||||
|
||||
# 修复: 使用 max(-1) 在最后一个维度(类别)上寻找最大置信度
|
||||
# 之前的 max(1) 错误地在 Anchors 维度上操作了
|
||||
xc = prediction[..., 4:].max(-1)[0] > conf_thres # candidates
|
||||
|
||||
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
||||
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
x = x[xc[xi]] # confidence filtering
|
||||
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||
box = xywh2xyxy(x[:, :4])
|
||||
|
||||
# Confidence and Class
|
||||
conf, j = x[:, 4:].max(1, keepdim=True)
|
||||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
||||
|
||||
# Check shape
|
||||
n = x.shape[0]
|
||||
if not n:
|
||||
continue
|
||||
elif n > max_det:
|
||||
x = x[x[:, 4].argsort(descending=True)[:max_det]]
|
||||
|
||||
# Batched NMS
|
||||
c = x[:, 5:6] * 7680 # classes
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres)
|
||||
output[xi] = x[i]
|
||||
|
||||
return output
|
||||
|
||||
def main():
|
||||
# 1. 初始化模型
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
model = YOLO11(nc=80, scale='s')
|
||||
# 加载你之前转换好的纯净权重
|
||||
model.load_weights("yolo11s.pth")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
# model.train()
|
||||
|
||||
# 2. 读取图片
|
||||
img_path = "1.jpg" # 请替换为你本地的图片路径
|
||||
|
||||
img0 = cv2.imread(img_path)
|
||||
assert img0 is not None, f"Image Not Found {img_path}"
|
||||
|
||||
# 3. 预处理
|
||||
# Letterbox resize
|
||||
img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
|
||||
|
||||
# BGR to RGB, HWC to CHW
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
img_tensor = torch.from_numpy(img).to(device)
|
||||
img_tensor = img_tensor.float()
|
||||
img_tensor /= 255.0 # 0 - 255 to 0.0 - 1.0
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
|
||||
# 4. 推理
|
||||
print("开始推理...")
|
||||
with torch.no_grad():
|
||||
pred = model(img_tensor)
|
||||
|
||||
# 5. 后处理 (NMS)
|
||||
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
|
||||
|
||||
# 6. 绘制结果
|
||||
det = pred[0] # 仅处理第一张图片
|
||||
|
||||
if len(det):
|
||||
# 将坐标映射回原图尺寸
|
||||
# det[:, :4] 是 x1, y1, x2, y2
|
||||
det[:, [0, 2]] -= dw # x padding
|
||||
det[:, [1, 3]] -= dh # y padding
|
||||
det[:, :4] /= ratio
|
||||
|
||||
# 裁剪坐标防止越界
|
||||
det[:, 0].clamp_(0, img0.shape[1])
|
||||
det[:, 1].clamp_(0, img0.shape[0])
|
||||
det[:, 2].clamp_(0, img0.shape[1])
|
||||
det[:, 3].clamp_(0, img0.shape[0])
|
||||
|
||||
print(f"检测到 {len(det)} 个目标")
|
||||
|
||||
for *xyxy, conf, cls in det:
|
||||
c = int(cls)
|
||||
label = f'{CLASSES[c]} {conf:.2f}'
|
||||
p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
|
||||
|
||||
# 画框
|
||||
color = COLORS[c]
|
||||
cv2.rectangle(img0, p1, p2, color, 2, lineType=cv2.LINE_AA)
|
||||
|
||||
# 画标签背景
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
|
||||
p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
|
||||
cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
|
||||
|
||||
# 画文字
|
||||
cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)
|
||||
|
||||
print(f" - {label} at {p1}-{p2}")
|
||||
|
||||
# 7. 显示/保存结果
|
||||
cv2.imwrite("result.jpg", img0)
|
||||
print("结果已保存至 result.jpg")
|
||||
|
||||
def import_os_exists(path):
|
||||
import os
|
||||
return os.path.exists(path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
168
inference_yoloe.py
Normal file
168
inference_yoloe.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torchvision
|
||||
from pathlib import Path
|
||||
|
||||
# 导入你的模块
|
||||
from yolo11_standalone import YOLO11E
|
||||
from mobile_clip_standalone import MobileCLIP
|
||||
|
||||
from ultralytics import YOLOE
|
||||
|
||||
|
||||
# --- 配置 ---
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
YOLO_WEIGHTS = "yoloe-11l-seg.pth" # 替换为你的 YOLO 权重路径
|
||||
CLIP_WEIGHTS = "mobileclip_blt.ts" # 替换为你的 MobileCLIP 权重路径
|
||||
CLIP_SIZE = "blt" # 对应 MobileCLIP 的 size
|
||||
IMAGE_PATH = "1.jpg" # 待检测图片
|
||||
|
||||
# 自定义检测类别 (Open Vocabulary)
|
||||
CUSTOM_CLASSES = ["girl", "red balloon"]
|
||||
|
||||
# 绘图颜色
|
||||
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))
|
||||
|
||||
# --- 辅助函数 (Letterbox, NMS 等) ---
|
||||
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
|
||||
shape = im.shape[:2]
|
||||
if isinstance(new_shape, int): new_shape = (new_shape, new_shape)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
|
||||
dw, dh = dw / 2, dh / 2
|
||||
if shape[::-1] != new_unpad:
|
||||
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
|
||||
return im, r, (dw, dh)
|
||||
|
||||
def xywh2xyxy(x):
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2
|
||||
return y
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300):
|
||||
prediction = prediction.transpose(1, 2)
|
||||
bs = prediction.shape[0]
|
||||
xc = prediction[..., 4:].max(-1)[0] > conf_thres
|
||||
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
||||
for xi, x in enumerate(prediction):
|
||||
x = x[xc[xi]]
|
||||
if not x.shape[0]: continue
|
||||
box = xywh2xyxy(x[:, :4])
|
||||
conf, j = x[:, 4:].max(1, keepdim=True)
|
||||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
||||
n = x.shape[0]
|
||||
if not n: continue
|
||||
elif n > max_det: x = x[x[:, 4].argsort(descending=True)[:max_det]]
|
||||
c = x[:, 5:6] * 7680
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres)
|
||||
output[xi] = x[i]
|
||||
return output
|
||||
|
||||
def main():
|
||||
print(f"Using device: {DEVICE}")
|
||||
|
||||
# 1. 加载 MobileCLIP 文本编码器
|
||||
print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
|
||||
if not Path(CLIP_WEIGHTS).exists():
|
||||
raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")
|
||||
|
||||
clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)
|
||||
|
||||
# 2. 生成文本 Embeddings
|
||||
print(f"Encoding classes: {CUSTOM_CLASSES}")
|
||||
prompts = [f"{c}" for c in CUSTOM_CLASSES]
|
||||
|
||||
tokens = clip_model.tokenize(prompts)
|
||||
text_embeddings = clip_model.encode_text(tokens) # Shape: (N, 512)
|
||||
|
||||
# 调整维度为 (1, N, 512) 以匹配 YOLO11E 输入
|
||||
text_embeddings = text_embeddings.unsqueeze(0)
|
||||
|
||||
# 3. 加载 YOLO11E 检测模型
|
||||
print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
|
||||
if not Path(YOLO_WEIGHTS).exists():
|
||||
raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")
|
||||
|
||||
# 注意:scale='l' 必须与你的权重文件匹配 (s, m, l, x)
|
||||
yolo_model = YOLO11E(nc=80, scale='l')
|
||||
yolo_model.load_weights(YOLO_WEIGHTS)
|
||||
yolo_model.to(DEVICE) # 使用半精度to(DEVICE)
|
||||
yolo_model.eval()
|
||||
|
||||
head = yolo_model.model[-1]
|
||||
|
||||
with torch.no_grad():
|
||||
text_pe = head.get_tpe(text_embeddings) # type: ignore
|
||||
|
||||
yolo_model.set_classes(CUSTOM_CLASSES, text_pe)
|
||||
|
||||
# 5. 图像预处理
|
||||
img0 = cv2.imread(IMAGE_PATH)
|
||||
assert img0 is not None, f"Image Not Found {IMAGE_PATH}"
|
||||
img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1)
|
||||
img = np.ascontiguousarray(img)
|
||||
img_tensor = torch.from_numpy(img).to(DEVICE)
|
||||
img_tensor = img_tensor.float()
|
||||
img_tensor /= 255.0
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
|
||||
# 6. 推理
|
||||
print("Running inference...")
|
||||
with torch.no_grad():
|
||||
pred = yolo_model(img_tensor)
|
||||
if isinstance(pred, tuple):
|
||||
pred = pred[0]
|
||||
nc = len(CUSTOM_CLASSES)
|
||||
pred = pred[:, :4+nc, :]
|
||||
|
||||
# 7. 后处理 (NMS)
|
||||
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)
|
||||
|
||||
print(pred)
|
||||
|
||||
# 8. 可视化
|
||||
det = pred[0]
|
||||
if len(det):
|
||||
det[:, [0, 2]] -= dw
|
||||
det[:, [1, 3]] -= dh
|
||||
det[:, :4] /= ratio
|
||||
det[:, 0].clamp_(0, img0.shape[1])
|
||||
det[:, 1].clamp_(0, img0.shape[0])
|
||||
det[:, 2].clamp_(0, img0.shape[1])
|
||||
det[:, 3].clamp_(0, img0.shape[0])
|
||||
|
||||
print(f"Detected {len(det)} objects:")
|
||||
for *xyxy, conf, cls in det:
|
||||
c = int(cls)
|
||||
class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c)
|
||||
label = f'{class_name} {conf:.2f}'
|
||||
|
||||
p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
|
||||
color = COLORS[c]
|
||||
|
||||
cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
|
||||
t_size = cv2.getTextSize(label, 0, 0.5, 1)[0]
|
||||
p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
|
||||
cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
|
||||
cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA)
|
||||
print(f" - {label}")
|
||||
else:
|
||||
print("No objects detected.")
|
||||
|
||||
output_path = "result_full.jpg"
|
||||
cv2.imwrite(output_path, img0)
|
||||
print(f"Result saved to {output_path}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
448
loss.py
Normal file
448
loss.py
Normal file
@@ -0,0 +1,448 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
# ==============================================================================
|
||||
# 1. 基础工具函数 (Utils)
|
||||
# ==============================================================================
|
||||
|
||||
def xywh2xyxy(x):
|
||||
"""Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
return y
|
||||
|
||||
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
|
||||
"""
|
||||
Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
|
||||
Args:
|
||||
box1 (torch.Tensor): predicted, shape (N, 4)
|
||||
box2 (torch.Tensor): target, shape (N, 4)
|
||||
xywh (bool): If True, input boxes are (x, y, w, h). If False, (x1, y1, x2, y2).
|
||||
CIoU (bool): If True, calculate Complete IoU.
|
||||
"""
|
||||
# Get the coordinates of bounding boxes
|
||||
if xywh: # transform from xywh to xyxy
|
||||
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
|
||||
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
|
||||
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
|
||||
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
|
||||
else: # x1, y1, x2, y2
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
# Intersection area
|
||||
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
||||
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
||||
|
||||
# Union Area
|
||||
union = w1 * h1 + w2 * h2 - inter + eps
|
||||
|
||||
# IoU
|
||||
iou = inter / union
|
||||
if CIoU or DIoU or GIoU:
|
||||
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
|
||||
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
||||
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
||||
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
|
||||
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
|
||||
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
||||
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - iou + (1 + eps))
|
||||
return iou - (rho2 / c2 + v * alpha) # CIoU
|
||||
return iou - rho2 / c2 # DIoU
|
||||
c_area = cw * ch + eps # convex area
|
||||
return iou - (c_area - union) / c_area # GIoU
|
||||
return iou
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
"""Generate anchors from features."""
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
dtype, device = feats[0].dtype, feats[0].device
|
||||
for i, stride in enumerate(strides):
|
||||
h, w = feats[i].shape[2:]
|
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
|
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
||||
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
"""Transform distance(ltrb) to box(xywh or xyxy)."""
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return torch.cat([c_xy, wh], dim)
|
||||
return torch.cat((x1y1, x2y2), dim)
|
||||
|
||||
def bbox2dist(anchor_points, bbox, reg_max):
|
||||
"""Transform bbox(xyxy) to dist(ltrb)."""
|
||||
x1y1, x2y2 = bbox.chunk(2, -1)
|
||||
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)
|
||||
|
||||
|
||||
class TaskAlignedAssigner(nn.Module):
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.num_classes = num_classes
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
||||
self.bs = pd_scores.shape[0]
|
||||
self.n_max_boxes = gt_bboxes.shape[1]
|
||||
|
||||
if self.n_max_boxes == 0:
|
||||
return (
|
||||
torch.full_like(pd_scores[..., 0], self.num_classes),
|
||||
torch.zeros_like(pd_bboxes),
|
||||
torch.zeros_like(pd_scores),
|
||||
torch.zeros_like(pd_scores[..., 0]),
|
||||
torch.zeros_like(pd_scores[..., 0]),
|
||||
)
|
||||
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
||||
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
||||
)
|
||||
|
||||
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
|
||||
|
||||
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
|
||||
|
||||
# Normalize
|
||||
align_metric *= mask_pos
|
||||
pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
|
||||
pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
|
||||
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
|
||||
target_scores = target_scores * norm_align_metric
|
||||
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
||||
|
||||
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
||||
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
||||
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
|
||||
mask_pos = mask_topk * mask_in_gts * mask_gt
|
||||
return mask_pos, align_metric, overlaps
|
||||
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
||||
na = pd_bboxes.shape[-2]
|
||||
mask_gt = mask_gt.bool()
|
||||
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
||||
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
|
||||
|
||||
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)
|
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)
|
||||
ind[1] = gt_labels.squeeze(-1)
|
||||
|
||||
# Clamp labels to handle case where no GT exists or background index issues
|
||||
ind[1] = ind[1].clamp(max=self.num_classes - 1)
|
||||
|
||||
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]
|
||||
|
||||
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
|
||||
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
|
||||
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
||||
|
||||
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
|
||||
return align_metric, overlaps
|
||||
|
||||
def select_topk_candidates(self, metrics, topk_mask=None):
|
||||
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
|
||||
if topk_mask is None:
|
||||
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
|
||||
topk_idxs.masked_fill_(~topk_mask, 0)
|
||||
|
||||
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
|
||||
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
|
||||
for k in range(self.topk):
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
|
||||
count_tensor.masked_fill_(count_tensor > 1, 0)
|
||||
return count_tensor.to(metrics.dtype)
|
||||
|
||||
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
||||
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
||||
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
|
||||
target_labels = gt_labels.long().flatten()[target_gt_idx]
|
||||
|
||||
target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
|
||||
target_labels.clamp_(0)
|
||||
|
||||
target_scores = torch.zeros(
|
||||
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
|
||||
dtype=torch.int64,
|
||||
device=target_labels.device,
|
||||
)
|
||||
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
||||
|
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
|
||||
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
|
||||
|
||||
return target_labels, target_bboxes, target_scores
|
||||
|
||||
@staticmethod
|
||||
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
||||
n_anchors = xy_centers.shape[0]
|
||||
bs, n_boxes, _ = gt_bboxes.shape
|
||||
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
|
||||
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
||||
return bbox_deltas.amin(3).gt_(eps)
|
||||
|
||||
@staticmethod
|
||||
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
if fg_mask.max() > 1:
|
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
|
||||
max_overlaps_idx = overlaps.argmax(1)
|
||||
|
||||
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
|
||||
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
|
||||
|
||||
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
target_gt_idx = mask_pos.argmax(-2)
|
||||
return target_gt_idx, fg_mask, mask_pos
|
||||
|
||||
# ==============================================================================
|
||||
# 3. Loss 模块 (Loss Modules - from loss.py)
|
||||
# ==============================================================================
|
||||
|
||||
class DFLoss(nn.Module):
|
||||
"""Distribution Focal Loss (DFL)."""
|
||||
def __init__(self, reg_max: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
|
||||
def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
target = target.clamp_(0, self.reg_max - 1 - 0.01)
|
||||
tl = target.long() # target left
|
||||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
wr = 1 - wl # weight right
|
||||
return (
|
||||
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
|
||||
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
|
||||
).mean(-1, keepdim=True)
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
"""Criterion for computing training losses for bounding boxes (IoU + DFL)."""
|
||||
def __init__(self, reg_max: int = 16):
|
||||
super().__init__()
|
||||
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
if self.dfl_loss:
|
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
|
||||
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
class YOLOv8DetectionLoss:
|
||||
"""
|
||||
独立版的 YOLOv8/v11 Detection Loss。
|
||||
Refactored to assume specific inputs and remove reliance on full Model object.
|
||||
"""
|
||||
def __init__(self, nc=80, reg_max=16, stride=None, hyp=None):
|
||||
"""
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
reg_max (int): DFL channels (default 16).
|
||||
stride (list/torch.Tensor): Model strides (e.g., [8, 16, 32]).
|
||||
hyp (dict): Hyperparameters (box, cls, dfl gains).
|
||||
"""
|
||||
if stride is None:
|
||||
stride = [8, 16, 32] # Default strides
|
||||
if hyp is None:
|
||||
# Default YOLOv8 hyperparameters
|
||||
hyp = {'box': 7.5, 'cls': 0.5, 'dfl': 1.5}
|
||||
|
||||
self.nc = nc
|
||||
self.reg_max = reg_max
|
||||
self.stride = stride
|
||||
self.hyp = hyp
|
||||
self.no = nc + reg_max * 4 # Output channels per anchor
|
||||
self.use_dfl = reg_max > 1
|
||||
|
||||
# Components
|
||||
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(reg_max)
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.proj = torch.arange(reg_max, dtype=torch.float) # Will move to device in method
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses targets: converts [idx, cls, xywh] to internal format."""
|
||||
if targets.shape[0] == 0:
|
||||
out = torch.zeros(batch_size, 0, 5, device=targets.device)
|
||||
else:
|
||||
i = targets[:, 0] # image index
|
||||
_, counts = i.unique(return_counts=True)
|
||||
counts = counts.to(dtype=torch.int32)
|
||||
out = torch.zeros(batch_size, counts.max(), 5, device=targets.device)
|
||||
for j in range(batch_size):
|
||||
matches = i == j
|
||||
n = matches.sum()
|
||||
if n:
|
||||
out[j, :n] = targets[matches, 1:] # cls, x, y, w, h
|
||||
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
|
||||
return out
|
||||
|
||||
def bbox_decode(self, anchor_points, pred_dist):
|
||||
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
|
||||
if self.use_dfl:
|
||||
b, a, c = pred_dist.shape # batch, anchors, channels
|
||||
# Project distributions to scalars using self.proj (0, 1, 2, ..., reg_max-1)
|
||||
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.to(pred_dist.device).type(pred_dist.dtype))
|
||||
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""
|
||||
Args:
|
||||
preds: List of prediction tensors [B, C, H, W] for each stride level.
|
||||
C should be reg_max * 4 + nc.
|
||||
batch: Dict containing:
|
||||
'batch_idx': Tensor [N], image index for each target
|
||||
'cls': Tensor [N], class index for each target
|
||||
'bboxes': Tensor [N, 4], normalized xywh format
|
||||
Returns:
|
||||
total_loss: scalar tensor (sum of box, cls, dfl), multiplied by batch_size as per original logic if needed,
|
||||
but here we return mean or sum based on reduction.
|
||||
WARNING: Ultralytics returns (loss * bs) usually.
|
||||
loss_items: Tensor used for logging [box, cls, dfl]
|
||||
"""
|
||||
# Ensure proj is on correct device
|
||||
self.device = preds[0].device
|
||||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
|
||||
feats = preds
|
||||
|
||||
# Concat features from different strides: (B, no, Total_Anchors)
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1
|
||||
)
|
||||
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous() # (B, Anchors, nc)
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous() # (B, Anchors, reg_max * 4)
|
||||
|
||||
dtype = pred_scores.dtype
|
||||
batch_size = pred_scores.shape[0]
|
||||
|
||||
# Calculate image size from features and first stride (assuming square output implies inputs)
|
||||
# Ultralytics uses provided imgsz, but we can infer roughly or pass it.
|
||||
# Here assuming feature map H*stride[0] = img H
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]
|
||||
|
||||
# Generate anchors
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# Process Targets
|
||||
if "batch_idx" not in batch:
|
||||
# Fallback if simplified batch input passed (targets assumed [N, 6] -> idx, cls, x, y, w, h)
|
||||
# This handles cases where user passes raw target tensor instead of dict
|
||||
raise ValueError("Batch dict must contain 'batch_idx', 'cls', and 'bboxes'")
|
||||
|
||||
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||
# Check if Normalized (xywh)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
|
||||
|
||||
# Decode Boxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
# Task Aligned Assignment
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(),
|
||||
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor,
|
||||
gt_labels,
|
||||
gt_bboxes,
|
||||
mask_gt,
|
||||
)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# CLS Loss
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum
|
||||
|
||||
# BBox Loss (IoU + DFL)
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(
|
||||
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||
)
|
||||
|
||||
# Apply Gains
|
||||
loss[0] *= self.hyp['box']
|
||||
loss[1] *= self.hyp['cls']
|
||||
loss[2] *= self.hyp['dfl']
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # scaling by batch_size to match Ultralytics behavior
|
||||
|
||||
# ==============================================================================
|
||||
# Usage Example
|
||||
# ==============================================================================
|
||||
if __name__ == "__main__":
|
||||
# 配置
|
||||
num_classes = 80
|
||||
reg_max = 16
|
||||
strides = [8, 16, 32]
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# 初始化 Loss
|
||||
criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides)
|
||||
|
||||
# 模拟输入 (Batch Size = 2)
|
||||
bs = 2
|
||||
# 模拟模型输出: 3 个层级的特征图
|
||||
# [BS, Channels, H, W], Channels = reg_max * 4 + num_classes
|
||||
channels = reg_max * 4 + num_classes
|
||||
preds = [
|
||||
torch.randn(bs, channels, 80, 80, device=device, requires_grad=True), # stride 8
|
||||
torch.randn(bs, channels, 40, 40, device=device, requires_grad=True), # stride 16
|
||||
torch.randn(bs, channels, 20, 20, device=device, requires_grad=True), # stride 32
|
||||
]
|
||||
|
||||
# 模拟 Ground Truth
|
||||
# batch: idx, cls, x, y, w, h (normalized 0-1)
|
||||
target_batch = {
|
||||
"batch_idx": torch.tensor([0, 0, 1], device=device), # Image indices
|
||||
"cls": torch.tensor([1, 5, 10], device=device), # Class indices
|
||||
"bboxes": torch.tensor([ # Normalized xywh
|
||||
[0.5, 0.5, 0.2, 0.3],
|
||||
[0.1, 0.1, 0.1, 0.1],
|
||||
[0.8, 0.8, 0.2, 0.2]
|
||||
], device=device)
|
||||
}
|
||||
|
||||
# 计算 Loss
|
||||
total_loss, loss_items = criterion(preds, target_batch)
|
||||
|
||||
print(f"Total Loss: {total_loss.item()}")
|
||||
print(f"Loss Items (Box, Cls, DFL): {loss_items}")
|
||||
|
||||
# 反向传播测试
|
||||
total_loss.backward()
|
||||
print("Backward pass successful.")
|
||||
148
metrics.py
Normal file
148
metrics.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision
|
||||
|
||||
def xywh2xyxy(x):
|
||||
"""Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
return y
|
||||
|
||||
def box_iou(box1, box2):
|
||||
"""
|
||||
Return intersection-over-union (Jaccard index) of boxes.
|
||||
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||
box1: [N, 4]
|
||||
box2: [M, 4]
|
||||
Returns: [N, M]
|
||||
"""
|
||||
def box_area(box):
|
||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||
|
||||
area1 = box_area(box1)
|
||||
area2 = box_area(box2)
|
||||
|
||||
lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
return inter / (union + 1e-6)
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
|
||||
"""
|
||||
Non-Maximum Suppression (NMS) on inference results
|
||||
prediction: [Batch, 84, 8400] (for YOLOv8/11)
|
||||
"""
|
||||
# [Batch, 84, Anchors] -> [Batch, Anchors, 84]
|
||||
prediction = prediction.transpose(1, 2)
|
||||
|
||||
bs = prediction.shape[0]
|
||||
nc = prediction.shape[2] - 4
|
||||
xc = prediction[..., 4:].max(-1)[0] > conf_thres
|
||||
|
||||
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
||||
|
||||
for xi, x in enumerate(prediction):
|
||||
x = x[xc[xi]]
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Box decoding
|
||||
box = xywh2xyxy(x[:, :4])
|
||||
|
||||
# Confidence and Class
|
||||
conf, j = x[:, 4:].max(1, keepdim=True)
|
||||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
|
||||
|
||||
n = x.shape[0]
|
||||
if not n:
|
||||
continue
|
||||
elif n > max_det:
|
||||
x = x[x[:, 4].argsort(descending=True)[:max_det]]
|
||||
|
||||
# Batched NMS
|
||||
c = x[:, 5:6] * 7680
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres)
|
||||
output[xi] = x[i]
|
||||
|
||||
return output
|
||||
|
||||
def compute_ap(recall, precision):
|
||||
""" Compute the average precision, given the recall and precision curves """
|
||||
# Append sentinel values to beginning and end
|
||||
mrec = np.concatenate(([0.0], recall, [1.0]))
|
||||
mpre = np.concatenate(([1.0], precision, [0.0]))
|
||||
|
||||
# Compute the precision envelope
|
||||
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
|
||||
|
||||
# Integrate area under curve
|
||||
method = 'interp' # methods: 'continuous', 'interp'
|
||||
if method == 'interp':
|
||||
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
|
||||
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
|
||||
else: # 'continuous'
|
||||
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
|
||||
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
|
||||
|
||||
return ap, mpre, mrec
|
||||
|
||||
def smooth(y, f=0.05):
|
||||
"""Box filter of fraction f"""
|
||||
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
||||
p = np.ones(nf // 2) # ones padding
|
||||
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
|
||||
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
|
||||
|
||||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
|
||||
""" Compute the average precision, given the recall and precision curves. """
|
||||
# Sort by objectness
|
||||
i = np.argsort(-conf)
|
||||
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
|
||||
|
||||
# Find unique classes
|
||||
unique_classes, nt = np.unique(target_cls, return_counts=True)
|
||||
nc = unique_classes.shape[0] # number of classes, number of detections
|
||||
|
||||
# Create Precision-Recall curve and compute AP for each class
|
||||
px, py = np.linspace(0, 1, 1000), [] # for plotting
|
||||
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
|
||||
for ci, c in enumerate(unique_classes):
|
||||
i = pred_cls == c
|
||||
n_l = (target_cls == c).sum() # number of labels
|
||||
n_p = i.sum() # number of predictions
|
||||
|
||||
if n_p == 0 or n_l == 0:
|
||||
continue
|
||||
|
||||
# Accumulate FPs and TPs
|
||||
fpc = (1 - tp[i]).cumsum(0)
|
||||
tpc = tp[i].cumsum(0)
|
||||
|
||||
# Recall
|
||||
recall = tpc / (n_l + eps) # recall curve
|
||||
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
|
||||
|
||||
# Precision
|
||||
precision = tpc / (tpc + fpc) # precision curve
|
||||
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
|
||||
|
||||
# AP from recall-precision curve
|
||||
for j in range(tp.shape[1]):
|
||||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
||||
|
||||
# Compute F1 (harmonic mean of precision and recall)
|
||||
f1 = 2 * p * r / (p + r + eps)
|
||||
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
|
||||
p, r, f1 = p[:, i], r[:, i], f1[:, i]
|
||||
tp = (r * nt).round().astype(int)
|
||||
fp = (tp / (p + eps) - tp).astype(int)
|
||||
|
||||
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
||||
117
mobile_clip_standalone.py
Normal file
117
mobile_clip_standalone.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
import mobileclip
|
||||
|
||||
class TextModel(nn.Module):
|
||||
"""TextModel 基类,定义接口"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def tokenize(self, texts):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_text(self, texts, dtype):
|
||||
raise NotImplementedError
|
||||
|
||||
class MobileCLIP(TextModel):
|
||||
"""
|
||||
MobileCLIP 文本编码器。
|
||||
"""
|
||||
config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}
|
||||
|
||||
def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
|
||||
"""
|
||||
初始化 MobileCLIP 文本编码器。
|
||||
|
||||
Args:
|
||||
checkpoint (str): 模型权重文件路径 (.pt 或 .ts).
|
||||
size (str): 模型大小标识符 ('s0', 's1', 's2', 'b', 'blt').
|
||||
device (torch.device): 加载模型的设备.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
|
||||
if not Path(checkpoint).exists():
|
||||
raise FileNotFoundError(f"找不到权重文件: {checkpoint}")
|
||||
|
||||
if size not in self.config_size_map:
|
||||
raise ValueError(f"不支持的大小: {size}. 可选: {list(self.config_size_map.keys())}")
|
||||
|
||||
config = self.config_size_map[size]
|
||||
|
||||
# 1. 加载 Tokenizer
|
||||
self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")
|
||||
|
||||
# 2. 加载模型
|
||||
if str(checkpoint).endswith('.ts'):
|
||||
# TorchScript 格式 (例如 mobileclip_blt.ts)
|
||||
print(f"Loading TorchScript model from {checkpoint}...")
|
||||
self.model = torch.jit.load(checkpoint, map_location=device)
|
||||
self.is_torchscript = True
|
||||
else:
|
||||
# PyTorch 格式 (.pt)
|
||||
print(f"Loading PyTorch model from {checkpoint}...")
|
||||
self.model = mobileclip.create_model_and_transforms(
|
||||
f"mobileclip_{config}",
|
||||
pretrained=checkpoint,
|
||||
device=device
|
||||
)[0]
|
||||
self.is_torchscript = False
|
||||
|
||||
self.to(device)
|
||||
self.device = device
|
||||
self.eval()
|
||||
|
||||
def tokenize(self, texts: List[str]) -> torch.Tensor:
|
||||
"""
|
||||
将文本转换为 token。
|
||||
"""
|
||||
return self.tokenizer(texts).to(self.device)
|
||||
|
||||
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""
|
||||
编码 tokenized 文本并进行归一化。
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if self.is_torchscript:
|
||||
return self.model(texts).to(dtype)
|
||||
|
||||
text_features = self.model.encode_text(texts).to(dtype) # type: ignore
|
||||
text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
|
||||
return text_features
|
||||
|
||||
# --- 使用示例 ---
|
||||
if __name__ == "__main__":
|
||||
# 1. 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# 指定本地模型路径
|
||||
checkpoint_path = "mobileclip_blt.ts"
|
||||
|
||||
try:
|
||||
if Path(checkpoint_path).exists():
|
||||
# 2. 初始化模型 (指定本地路径和对应大小)
|
||||
# 注意:blt 对应 size="blt"
|
||||
model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)
|
||||
|
||||
# 3. 准备文本
|
||||
input_texts = ["a photo of a cat", "a photo of a dog"]
|
||||
|
||||
# 4. Tokenize
|
||||
tokens = model.tokenize(input_texts)
|
||||
print(f"Tokens shape: {tokens.shape}")
|
||||
|
||||
# 5. Encode
|
||||
features = model.encode_text(tokens)
|
||||
print(f"Features shape: {features.shape}")
|
||||
print("运行成功!")
|
||||
else:
|
||||
print(f"权重文件不存在: {checkpoint_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {e}")
|
||||
98
mobileclip/__init__.py
Normal file
98
mobileclip/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
InterpolationMode,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
|
||||
from mobileclip.clip import CLIP
|
||||
from mobileclip.modules.common.mobileone import reparameterize_model
|
||||
from mobileclip.modules.text.tokenizer import (
|
||||
ClipTokenizer,
|
||||
)
|
||||
|
||||
|
||||
def create_model_and_transforms(
|
||||
model_name: str,
|
||||
pretrained: str | None = None,
|
||||
reparameterize: bool | None = True,
|
||||
device: str | torch.device = "cpu",
|
||||
) -> tuple[nn.Module, Any, Any]:
|
||||
"""Method to instantiate model and pre-processing transforms necessary for inference.
|
||||
|
||||
Args:
|
||||
model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
|
||||
pretrained: Location of pretrained checkpoint.
|
||||
reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
|
||||
device: Device identifier for model placement.
|
||||
|
||||
Returns:
|
||||
Tuple of instantiated model, and preprocessing transforms for inference.
|
||||
"""
|
||||
# Config files
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
configs_dir = os.path.join(root_dir, "configs")
|
||||
model_cfg_file = os.path.join(configs_dir, model_name + ".json")
|
||||
|
||||
# Get config from yaml file
|
||||
if not os.path.exists(model_cfg_file):
|
||||
raise ValueError(f"Unsupported model name: {model_name}")
|
||||
model_cfg = json.load(open(model_cfg_file))
|
||||
|
||||
# Build preprocessing transforms for inference
|
||||
resolution = model_cfg["image_cfg"]["image_size"]
|
||||
resize_size = resolution
|
||||
centercrop_size = resolution
|
||||
aug_list = [
|
||||
Resize(
|
||||
resize_size,
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
),
|
||||
CenterCrop(centercrop_size),
|
||||
ToTensor(),
|
||||
]
|
||||
preprocess = Compose(aug_list)
|
||||
|
||||
# Build model
|
||||
model = CLIP(cfg=model_cfg)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Load checkpoint
|
||||
if pretrained is not None:
|
||||
chkpt = torch.load(pretrained)
|
||||
model.load_state_dict(chkpt)
|
||||
|
||||
# Reparameterize model for inference (if specified)
|
||||
if reparameterize:
|
||||
model = reparameterize_model(model)
|
||||
|
||||
return model, None, preprocess
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> nn.Module:
|
||||
# Config files
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
configs_dir = os.path.join(root_dir, "configs")
|
||||
model_cfg_file = os.path.join(configs_dir, model_name + ".json")
|
||||
|
||||
# Get config from yaml file
|
||||
model_cfg = json.load(open(model_cfg_file))
|
||||
|
||||
# Build tokenizer
|
||||
text_tokenizer = ClipTokenizer(model_cfg)
|
||||
return text_tokenizer
|
||||
69
mobileclip/clip.py
Normal file
69
mobileclip/clip.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""Model schema in open_clip format for inference only."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from mobileclip.text_encoder import (
|
||||
TextTransformer,
|
||||
)
|
||||
|
||||
from .image_encoder import MCi
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
"""Base class for multi-modal image-text data."""
|
||||
|
||||
def __init__(self, cfg: dict, output_dict: bool = False, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.output_dict = output_dict
|
||||
self.projection_dim = cfg["embed_dim"]
|
||||
if self.projection_dim is None:
|
||||
raise ValueError("Please specify `embed_dim` in model config.")
|
||||
|
||||
self.image_encoder = MCi(
|
||||
model_name=cfg["image_cfg"]["model_name"],
|
||||
projection_dim=self.projection_dim,
|
||||
)
|
||||
self.text_encoder = TextTransformer(cfg=cfg["text_cfg"], projection_dim=self.projection_dim)
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))
|
||||
|
||||
def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
|
||||
scale = self.logit_scale.exp()
|
||||
scale = torch.clamp(scale, 0, max_scale)
|
||||
return scale
|
||||
|
||||
def encode_image(self, image: torch.Tensor, normalize: bool = False):
|
||||
image_encoder_out = self.image_encoder(image)
|
||||
if isinstance(image_encoder_out, dict):
|
||||
features = image_encoder_out["logits"]
|
||||
else:
|
||||
features = image_encoder_out
|
||||
return F.normalize(features, dim=-1) if normalize else features
|
||||
|
||||
def encode_text(self, text: torch.Tensor, normalize: bool = False):
|
||||
text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
|
||||
return F.normalize(text_features, dim=-1) if normalize else text_features
|
||||
|
||||
def forward(self, image: torch.Tensor | None = None, text: torch.Tensor | None = None, *args, **kwargs) -> Any:
|
||||
image_embeddings = self.encode_image(image, normalize=True) if image is not None else None
|
||||
text_embeddings = self.encode_text(text, normalize=True) if text is not None else None
|
||||
|
||||
if self.output_dict:
|
||||
return {
|
||||
"image_features": image_embeddings,
|
||||
"text_features": text_embeddings,
|
||||
"logit_scale": self._exponentiate_and_clip_logits(),
|
||||
}
|
||||
return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
|
||||
18
mobileclip/configs/mobileclip_b.json
Normal file
18
mobileclip/configs/mobileclip_b.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"image_cfg": {
|
||||
"image_size": 224,
|
||||
"model_name": "vit_b16"
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"dim": 512,
|
||||
"ffn_multiplier_per_layer": 4.0,
|
||||
"n_heads_per_layer": 8,
|
||||
"n_transformer_layers": 12,
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"causal_masking": true,
|
||||
"model_name": "base"
|
||||
}
|
||||
}
|
||||
18
mobileclip/configs/mobileclip_s0.json
Normal file
18
mobileclip/configs/mobileclip_s0.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"image_cfg": {
|
||||
"image_size": 256,
|
||||
"model_name": "mci0"
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"dim": 512,
|
||||
"ffn_multiplier_per_layer": 4.0,
|
||||
"n_heads_per_layer": 8,
|
||||
"n_transformer_layers": 4,
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"causal_masking": false,
|
||||
"model_name": "mct"
|
||||
}
|
||||
}
|
||||
18
mobileclip/configs/mobileclip_s1.json
Normal file
18
mobileclip/configs/mobileclip_s1.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"image_cfg": {
|
||||
"image_size": 256,
|
||||
"model_name": "mci1"
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"dim": 512,
|
||||
"ffn_multiplier_per_layer": 4.0,
|
||||
"n_heads_per_layer": 8,
|
||||
"n_transformer_layers": 12,
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"causal_masking": false,
|
||||
"model_name": "base"
|
||||
}
|
||||
}
|
||||
18
mobileclip/configs/mobileclip_s2.json
Normal file
18
mobileclip/configs/mobileclip_s2.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"embed_dim": 512,
|
||||
"image_cfg": {
|
||||
"image_size": 256,
|
||||
"model_name": "mci2"
|
||||
},
|
||||
"text_cfg": {
|
||||
"context_length": 77,
|
||||
"vocab_size": 49408,
|
||||
"dim": 512,
|
||||
"ffn_multiplier_per_layer": 4.0,
|
||||
"n_heads_per_layer": 8,
|
||||
"n_transformer_layers": 12,
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"causal_masking": false,
|
||||
"model_name": "base"
|
||||
}
|
||||
}
|
||||
63
mobileclip/image_encoder.py
Normal file
63
mobileclip/image_encoder.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
from timm.models import create_model
|
||||
|
||||
from mobileclip import models # noqa added to register models
|
||||
from mobileclip.modules.image.image_projection import GlobalPool2D
|
||||
|
||||
|
||||
class MCi(nn.Module):
|
||||
"""This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_."""
|
||||
|
||||
def __init__(self, model_name: str, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.projection_dim = None
|
||||
if "projection_dim" in kwargs:
|
||||
self.projection_dim = kwargs.get("projection_dim")
|
||||
|
||||
# Create model
|
||||
self.model = create_model(model_name, projection_dim=self.projection_dim)
|
||||
|
||||
# Build out projection head.
|
||||
if self.projection_dim is not None:
|
||||
if hasattr(self.model, "head"):
|
||||
self.model.head = MCi._update_image_classifier(
|
||||
image_classifier=self.model.head, projection_dim=self.projection_dim
|
||||
)
|
||||
|
||||
def forward(self, x: Any, *args, **kwargs) -> Any:
|
||||
"""A forward function of the model."""
|
||||
x = self.model(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
|
||||
"""Return the input feature dimension to the image classification head."""
|
||||
in_features = None
|
||||
if isinstance(image_classifier, nn.Sequential):
|
||||
# Classifier that uses nn.Sequential usually has global pooling and
|
||||
# multiple linear layers. Find the first linear layer and get its
|
||||
# in_features
|
||||
for layer in image_classifier:
|
||||
if isinstance(layer, nn.Linear):
|
||||
in_features = layer.in_features
|
||||
break
|
||||
elif isinstance(image_classifier, nn.Linear):
|
||||
in_features = image_classifier.in_features
|
||||
|
||||
if in_features is None:
|
||||
raise NotImplementedError(f"Cannot get input feature dimension of {image_classifier}.")
|
||||
return in_features
|
||||
|
||||
@staticmethod
|
||||
def _update_image_classifier(image_classifier: nn.Module, projection_dim: int, *args, **kwargs) -> nn.Module:
|
||||
in_features = MCi._get_in_feature_dimension(image_classifier)
|
||||
new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
|
||||
return new_img_classifier
|
||||
120
mobileclip/logger.py
Normal file
120
mobileclip/logger.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
text_colors = {
|
||||
"logs": "\033[34m", # 033 is the escape code and 34 is the color code
|
||||
"info": "\033[32m",
|
||||
"warning": "\033[33m",
|
||||
"debug": "\033[93m",
|
||||
"error": "\033[31m",
|
||||
"bold": "\033[1m",
|
||||
"end_color": "\033[0m",
|
||||
"light_red": "\033[36m",
|
||||
}
|
||||
|
||||
|
||||
def get_curr_time_stamp() -> str:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def error(message: str) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
error_str = text_colors["error"] + text_colors["bold"] + "ERROR " + text_colors["end_color"]
|
||||
|
||||
# exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss).
|
||||
# For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE).
|
||||
# This allows us to handle specific exceptions in the tests.
|
||||
|
||||
# print("{} - {} - {}".format(time_stamp, error_str, message), flush=True)
|
||||
# print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True)
|
||||
# exit(-1)
|
||||
|
||||
if sys.exc_info()[0] is None:
|
||||
traceback.print_stack()
|
||||
else:
|
||||
traceback.print_exc()
|
||||
sys.exit(f"{time_stamp} - {error_str} - {message}. Exiting!!!")
|
||||
|
||||
|
||||
def color_text(in_text: str) -> str:
|
||||
return text_colors["light_red"] + in_text + text_colors["end_color"]
|
||||
|
||||
|
||||
def log(message: str, end="\n") -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
log_str = text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {log_str} - {message}", end=end)
|
||||
|
||||
|
||||
def warning(message: str | Warning) -> None:
|
||||
if isinstance(message, Warning):
|
||||
message = f"{type(message).__name__}({','.join(map(repr, message.args))}"
|
||||
|
||||
time_stamp = get_curr_time_stamp()
|
||||
warn_str = text_colors["warning"] + text_colors["bold"] + "WARNING" + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {warn_str} - {message}")
|
||||
|
||||
|
||||
def ignore_exception_with_warning(message: str) -> None:
|
||||
"""After catching a tolerable exception E1 (e.g. when Model.forward() fails during profiling with try-catch, it'll
|
||||
be helpful to log the exception for future investigation. But printing the error stack trace, as is, could be
|
||||
confusing when an uncaught (non-tolerable) exception "E2" raises down the road. Then, the log will contain two
|
||||
stack traces for E1, E2. When looking for errors in logs, users should look for E2, but they may find E1.
|
||||
|
||||
This function appends "(WARNING)" at the end of all lines of the E1 traceback, so that the user can distinguish E1
|
||||
from uncaught exception E2.
|
||||
|
||||
Args:
|
||||
message: Extra explanation and context for debugging. (Note: the exception obj will be automatically fetched
|
||||
from python. No need to pass it as an argument or as message)
|
||||
"""
|
||||
warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)"))
|
||||
|
||||
|
||||
def info(message: str, print_line: bool | None = False) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
info_str = text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {info_str} - {message}")
|
||||
if print_line:
|
||||
double_dash_line(dashes=150)
|
||||
|
||||
|
||||
def debug(message: str) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
log_str = text_colors["debug"] + text_colors["bold"] + "DEBUG " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {log_str} - {message}")
|
||||
|
||||
|
||||
def double_dash_line(dashes: int | None = 75) -> None:
|
||||
print(text_colors["error"] + "=" * dashes + text_colors["end_color"])
|
||||
|
||||
|
||||
def singe_dash_line(dashes: int | None = 67) -> None:
|
||||
print("-" * dashes)
|
||||
|
||||
|
||||
def print_header(header: str) -> None:
|
||||
double_dash_line()
|
||||
print(text_colors["info"] + text_colors["bold"] + "=" * 50 + str(header) + text_colors["end_color"])
|
||||
double_dash_line()
|
||||
|
||||
|
||||
def print_header_minor(header: str) -> None:
|
||||
print(text_colors["warning"] + text_colors["bold"] + "=" * 25 + str(header) + text_colors["end_color"])
|
||||
|
||||
|
||||
def disable_printing():
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
|
||||
def enable_printing():
|
||||
sys.stdout = sys.__stdout__
|
||||
12
mobileclip/models/__init__.py
Normal file
12
mobileclip/models/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
from .mci import (
|
||||
mci0,
|
||||
mci1,
|
||||
mci2,
|
||||
)
|
||||
from .vit import vit_b16
|
||||
888
mobileclip/models/mci.py
Normal file
888
mobileclip/models/mci.py
Normal file
@@ -0,0 +1,888 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models import register_model
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from mobileclip.modules.common.mobileone import MobileOneBlock
|
||||
from mobileclip.modules.image.replknet import ReparamLargeKernelConv
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 256, 256),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.95,
|
||||
"interpolation": "bicubic",
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
"classifier": "head",
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
"fastvit_t": _cfg(crop_pct=0.9),
|
||||
"fastvit_s": _cfg(crop_pct=0.9),
|
||||
"fastvit_m": _cfg(crop_pct=0.95),
|
||||
}
|
||||
|
||||
|
||||
def convolutional_stem(in_channels: int, out_channels: int, inference_mode: bool = False) -> nn.Sequential:
|
||||
"""Build convolutional stem with MobileOne blocks.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
|
||||
Returns:
|
||||
nn.Sequential object with stem elements.
|
||||
"""
|
||||
return nn.Sequential(
|
||||
MobileOneBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
"""Multi-headed Self Attention module.
|
||||
|
||||
Source modified from:
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build MHSA module that can handle 3D or 4D input tensors.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
head_dim: Number of hidden dimensions per head. Default: ``32``
|
||||
qkv_bias: Use bias or not. Default: ``False``
|
||||
attn_drop: Dropout rate for attention tensor.
|
||||
proj_drop: Dropout rate for projection tensor.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shape = x.shape
|
||||
B, C, H, W = shape
|
||||
N = H * W
|
||||
if len(shape) == 4:
|
||||
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# trick here to make q@k.t more stable
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
if len(shape) == 4:
|
||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Convolutional patch embedding layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
stride: int,
|
||||
in_channels: int,
|
||||
embed_dim: int,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
) -> None:
|
||||
"""Build patch embedding layer.
|
||||
|
||||
Args:
|
||||
patch_size: Patch size for embedding computation.
|
||||
stride: Stride for convolutional embedding layer.
|
||||
in_channels: Number of channels of input tensor.
|
||||
embed_dim: Number of embedding dimensions.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
use_se: If ``True`` SE block will be used.
|
||||
"""
|
||||
super().__init__()
|
||||
block = list()
|
||||
block.append(
|
||||
ReparamLargeKernelConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
groups=in_channels,
|
||||
small_kernel=3,
|
||||
inference_mode=inference_mode,
|
||||
use_se=use_se,
|
||||
)
|
||||
)
|
||||
block.append(
|
||||
MobileOneBlock(
|
||||
in_channels=embed_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
)
|
||||
self.proj = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepMixer(nn.Module):
|
||||
"""Reparameterizable token mixer.
|
||||
|
||||
For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
|
||||
Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Module.
|
||||
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.kernel_size = kernel_size
|
||||
self.inference_mode = inference_mode
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
use_scale_branch=False,
|
||||
num_conv_branches=0,
|
||||
)
|
||||
self.mixer = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
|
||||
if self.inference_mode:
|
||||
return
|
||||
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
|
||||
else:
|
||||
w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w
|
||||
self.reparam_conv.bias.data = b
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("mixer")
|
||||
self.__delattr__("norm")
|
||||
if self.use_layer_scale:
|
||||
self.__delattr__("layer_scale")
|
||||
|
||||
|
||||
class ConvFFN(nn.Module):
|
||||
"""Convolutional FFN Module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int | None = None,
|
||||
out_channels: int | None = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build convolutional FFN module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
hidden_channels: Number of channels after expansion. Default: None
|
||||
out_channels: Number of output channels. Default: None
|
||||
act_layer: Activation layer. Default: ``GELU``
|
||||
drop: Dropout rate. Default: ``0.0``.
|
||||
"""
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
if isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepCPE(nn.Module):
|
||||
"""Implementation of conditional positional encoding.
|
||||
|
||||
For more details refer to paper: `Conditional Positional Encodings for Vision Transformers
|
||||
<https://arxiv.org/pdf/2102.10882.pdf>`_
|
||||
|
||||
In our implementation, we can reparameterize this module to eliminate a skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
embed_dim: int = 768,
|
||||
spatial_shape: int | tuple[int, int] = (7, 7),
|
||||
inference_mode=False,
|
||||
) -> None:
|
||||
"""Build reparameterizable conditional positional encoding.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
embed_dim: Number of embedding dimensions. Default: 768
|
||||
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
if isinstance(spatial_shape, int):
|
||||
spatial_shape = tuple([spatial_shape] * 2)
|
||||
assert isinstance(spatial_shape, tuple), (
|
||||
f'"spatial_shape" must by a sequence or int, get {type(spatial_shape)} instead.'
|
||||
)
|
||||
assert len(spatial_shape) == 2, f'Length of "spatial_shape" should be 2, got {len(spatial_shape)} instead.'
|
||||
|
||||
self.spatial_shape = spatial_shape
|
||||
self.embed_dim = embed_dim
|
||||
self.in_channels = in_channels
|
||||
self.groups = embed_dim
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.pe = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
spatial_shape,
|
||||
1,
|
||||
int(spatial_shape[0] // 2),
|
||||
bias=True,
|
||||
groups=embed_dim,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
x = self.pe(x) + x
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
# Build equivalent Id tensor
|
||||
input_dim = self.in_channels // self.groups
|
||||
kernel_value = torch.zeros(
|
||||
(
|
||||
self.in_channels,
|
||||
input_dim,
|
||||
self.spatial_shape[0],
|
||||
self.spatial_shape[1],
|
||||
),
|
||||
dtype=self.pe.weight.dtype,
|
||||
device=self.pe.weight.device,
|
||||
)
|
||||
for i in range(self.in_channels):
|
||||
kernel_value[
|
||||
i,
|
||||
i % input_dim,
|
||||
self.spatial_shape[0] // 2,
|
||||
self.spatial_shape[1] // 2,
|
||||
] = 1
|
||||
id_tensor = kernel_value
|
||||
|
||||
# Reparameterize Id tensor and conv
|
||||
w_final = id_tensor + self.pe.weight
|
||||
b_final = self.pe.bias
|
||||
|
||||
# Introduce reparam conv
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w_final
|
||||
self.reparam_conv.bias.data = b_final
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("pe")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
"""Implementation of Metaformer block with RepMixer as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
kernel_size: Kernel size for repmixer. Default: 3
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Implementation of metaformer block with MHSA as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
):
|
||||
"""Build Attention Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.norm = norm_layer(dim)
|
||||
self.token_mixer = MHSA(dim=dim)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
|
||||
else:
|
||||
x = x + self.drop_path(self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
def basic_blocks(
|
||||
dim: int,
|
||||
block_index: int,
|
||||
num_blocks: list[int],
|
||||
token_mixer_type: str,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode=False,
|
||||
) -> nn.Sequential:
|
||||
"""Build FastViT blocks within a stage.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
block_index: block index.
|
||||
num_blocks: List containing number of blocks per stage.
|
||||
token_mixer_type: Token mixer type.
|
||||
kernel_size: Kernel size for repmixer.
|
||||
mlp_ratio: MLP expansion ratio.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Drop path rate.
|
||||
use_layer_scale: Flag to turn on layer scale regularization.
|
||||
layer_scale_init_value: Layer scale value at initialization.
|
||||
inference_mode: Flag to instantiate block in inference mode.
|
||||
|
||||
Returns:
|
||||
nn.Sequential object of all the blocks within the stage.
|
||||
"""
|
||||
blocks = []
|
||||
for block_idx in range(num_blocks[block_index]):
|
||||
block_dpr = drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1)
|
||||
if token_mixer_type == "repmixer":
|
||||
blocks.append(
|
||||
RepMixerBlock(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
elif token_mixer_type == "attention":
|
||||
blocks.append(
|
||||
AttentionBlock(
|
||||
dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Token mixer type: {token_mixer_type} not supported")
|
||||
blocks = nn.Sequential(*blocks)
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
class FastViT(nn.Module):
|
||||
"""This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layers,
|
||||
token_mixers: tuple[str, ...],
|
||||
embed_dims=None,
|
||||
mlp_ratios=None,
|
||||
downsamples=None,
|
||||
se_downsamples=None,
|
||||
repmixer_kernel_size=3,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
num_classes=1000,
|
||||
pos_embs=None,
|
||||
down_patch_size=7,
|
||||
down_stride=2,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
init_cfg=None,
|
||||
pretrained=None,
|
||||
cls_ratio=2.0,
|
||||
inference_mode=False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
if pos_embs is None:
|
||||
pos_embs = [None] * len(layers)
|
||||
|
||||
if se_downsamples is None:
|
||||
se_downsamples = [False] * len(layers)
|
||||
|
||||
# Convolutional stem
|
||||
self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
|
||||
|
||||
# Build the main stages of the network architecture
|
||||
network = []
|
||||
for i in range(len(layers)):
|
||||
# Add position embeddings if requested
|
||||
if pos_embs[i] is not None:
|
||||
network.append(pos_embs[i](embed_dims[i], embed_dims[i], inference_mode=inference_mode))
|
||||
stage = basic_blocks(
|
||||
embed_dims[i],
|
||||
i,
|
||||
layers,
|
||||
token_mixer_type=token_mixers[i],
|
||||
kernel_size=repmixer_kernel_size,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
network.append(stage)
|
||||
if i >= len(layers) - 1:
|
||||
break
|
||||
|
||||
# Patch merging/downsampling between stages.
|
||||
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
||||
network.append(
|
||||
PatchEmbed(
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
in_channels=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=se_downsamples[i + 1],
|
||||
)
|
||||
)
|
||||
self.network = nn.ModuleList(network)
|
||||
|
||||
# Classifier head
|
||||
self.conv_exp = MobileOneBlock(
|
||||
in_channels=embed_dims[-1],
|
||||
out_channels=int(embed_dims[-1] * cls_ratio),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=embed_dims[-1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=True,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
self.head = nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.apply(self.cls_init_weights)
|
||||
self.init_cfg = copy.deepcopy(init_cfg)
|
||||
|
||||
def cls_init_weights(self, m: nn.Module) -> None:
|
||||
"""Init.
|
||||
|
||||
for classification.
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
return x
|
||||
|
||||
def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
# output only the features of last layer for image classification
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# input embedding
|
||||
x = self.forward_embeddings(x)
|
||||
# through backbone
|
||||
x = self.forward_tokens(x)
|
||||
# for image classification
|
||||
x = self.conv_exp(x)
|
||||
cls_out = self.head(x)
|
||||
return cls_out
|
||||
|
||||
|
||||
@register_model
|
||||
def mci0(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi0 model variant."""
|
||||
layers = [2, 6, 10, 2]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci1(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi1 model variant."""
|
||||
layers = [4, 12, 20, 4]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci2(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi2 model variant."""
|
||||
layers = [4, 12, 24, 4]
|
||||
embed_dims = [80, 160, 320, 640]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_m"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
389
mobileclip/models/vit.py
Normal file
389
mobileclip/models/vit.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""
|
||||
Implementation of the following modules is borrowed from ml-cvnets repo:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py.
|
||||
|
||||
Please see ACKNOWLEDGMENTS for license details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from timm.models import register_model
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mobileclip import logger
|
||||
from mobileclip.modules.common.transformer import (
|
||||
PositionalEmbedding,
|
||||
TransformerEncoder,
|
||||
get_normalization_layer,
|
||||
)
|
||||
from mobileclip.modules.image.image_projection import SimpleImageProjectionHead
|
||||
|
||||
|
||||
class ConvNormAct(nn.Module):
|
||||
"""Applies an N-dimensional convolution over an input.
|
||||
|
||||
Args:
|
||||
cfg: Model configuration.
|
||||
in_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
out_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
|
||||
stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
|
||||
dilation: Dilation rate for convolution. An integer, or tuple of length ``N``. Default: ``1``.
|
||||
padding: Padding for convolution. An integer, or tuple of length ``N``. If not specified, padding is
|
||||
automatically computed based on kernel size and dilation range. Default : ``None`` (equivalent to ``[
|
||||
int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
|
||||
groups: Number of groups in convolution. Default: ``1``.
|
||||
bias: Use bias. Default: ``False``.
|
||||
padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular'). Default: ``zeros``.
|
||||
use_norm: Use normalization layer after convolution. Default: ``True``.
|
||||
use_act: Use activation layer after convolution (or convolution and normalization). Default: ``True``.
|
||||
norm_layer: If not None, the provided normalization layer object will be used. Otherwise, a normalization object
|
||||
will be created based on config ``model.normalization.*`` opts.
|
||||
act_layer: If not None, the provided activation function will be used. Otherwise, an activation function will be
|
||||
created based on config ``model.activation.*`` opts.
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
- Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
- For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: dict,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] | None = None,
|
||||
groups: int = 1,
|
||||
bias: bool = False,
|
||||
padding_mode: str = "zeros",
|
||||
use_norm: bool = True,
|
||||
use_act: bool = True,
|
||||
norm_layer: nn.Module | None = None,
|
||||
act_layer: nn.Module | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ndim = 2
|
||||
|
||||
if norm_layer is None and use_norm:
|
||||
norm_type = cfg.get("normalization", "batch_norm")
|
||||
if norm_type == "batch_norm":
|
||||
norm_layer = nn.BatchNorm2d(
|
||||
num_features=out_channels,
|
||||
momentum=cfg.get("momentum", 0.1),
|
||||
)
|
||||
else:
|
||||
norm_layer = get_normalization_layer(num_features=out_channels, norm_type=norm_type)
|
||||
elif norm_layer is not None and use_norm:
|
||||
logger.error(f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided.")
|
||||
|
||||
if act_layer is None and use_act:
|
||||
act_layer = nn.GELU() # Default to GELU
|
||||
elif act_layer is not None and use_act:
|
||||
logger.error(f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided.")
|
||||
|
||||
if use_norm and any(param[0] == "bias" for param in norm_layer.named_parameters()) and bias:
|
||||
assert not bias, "Do not use bias when using normalization layers with bias."
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * self.ndim
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride,) * self.ndim
|
||||
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation,) * self.ndim
|
||||
|
||||
assert isinstance(kernel_size, tuple)
|
||||
assert isinstance(stride, tuple)
|
||||
assert isinstance(dilation, tuple)
|
||||
|
||||
if padding is None:
|
||||
padding = (int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim))
|
||||
|
||||
if in_channels % groups != 0:
|
||||
logger.error(f"Input channels are not divisible by groups. {in_channels}%{groups} != 0 ")
|
||||
if out_channels % groups != 0:
|
||||
logger.error(f"Output channels are not divisible by groups. {out_channels}%{groups} != 0 ")
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, # type: ignore
|
||||
stride=stride, # type: ignore
|
||||
padding=padding,
|
||||
dilation=dilation, # type: ignore
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
self.norm_name = None
|
||||
if use_norm:
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
self.norm_name = norm_layer.__class__.__name__
|
||||
|
||||
self.act_name = None
|
||||
if use_act:
|
||||
block.add_module(name="act", module=act_layer)
|
||||
self.act_name = act_layer.__class__.__name__
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.kernel_size = conv_layer.kernel_size
|
||||
self.bias = bias
|
||||
self.dilation = dilation
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model
|
||||
implementation is inspired from `Early Convolutions Help Transformers See
|
||||
Better <https://arxiv.org/abs/2106.14881>`_.
|
||||
|
||||
.. note::
|
||||
Our implementation is different from the original implementation in two ways:
|
||||
1. Kernel size is odd.
|
||||
2. Our positional encoding implementation allows us to use ViT with any multiple input scales
|
||||
3. We do not use StochasticDepth
|
||||
4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
image_channels = 3
|
||||
num_classes = cfg.get("n_classes", 1000)
|
||||
|
||||
self.projection_dim = None
|
||||
if "projection_dim" in kwargs:
|
||||
self.projection_dim = kwargs.get("projection_dim")
|
||||
|
||||
kernel_sizes_conv_stem = [4, 2, 2]
|
||||
strides_conv_stem = [4, 2, 2]
|
||||
|
||||
# Typically, in the ImageNet dataset, we use 224x224 as a resolution.
|
||||
# For out ViT implementation, patch size is 16 (16 = 4 * 2 * 2)
|
||||
# Therefore, total number of embeddings along width and height are (224 / 16)^2
|
||||
num_embeddings = (224 // 16) ** 2
|
||||
|
||||
embed_dim = cfg["embed_dim"]
|
||||
ffn_dim = cfg["embed_dim"] * 4
|
||||
pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
|
||||
n_transformer_layers = cfg["n_transformer_layers"]
|
||||
num_heads = cfg["n_attn_heads"]
|
||||
attn_dropout = cfg.get("attn_dropout", 0.0)
|
||||
dropout = cfg.get("dropout", 0.0)
|
||||
ffn_dropout = cfg.get("ffn_dropout", 0.0)
|
||||
norm_layer = cfg.get("norm_layer", "layer_norm")
|
||||
|
||||
conv_stem_proj_dim = max(32, embed_dim // 4)
|
||||
patch_emb = [
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=image_channels,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[0],
|
||||
stride=strides_conv_stem[0],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[1],
|
||||
stride=strides_conv_stem[1],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[2],
|
||||
stride=strides_conv_stem[2],
|
||||
bias=True,
|
||||
use_norm=False,
|
||||
use_act=False,
|
||||
),
|
||||
]
|
||||
|
||||
self.patch_emb = nn.Sequential(*patch_emb)
|
||||
|
||||
use_cls_token = not cfg.get("no_cls_token", False)
|
||||
stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
|
||||
per_layer_stochastic_drop_rate = [round(x, 3) for x in np.linspace(0, stochastic_dropout, n_transformer_layers)]
|
||||
transformer_blocks = [
|
||||
TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout,
|
||||
transformer_norm_layer=norm_layer,
|
||||
stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
|
||||
self.post_transformer_norm = get_normalization_layer(num_features=embed_dim, norm_type=norm_layer)
|
||||
|
||||
self.transformer = nn.Sequential(*transformer_blocks)
|
||||
|
||||
if self.projection_dim is None:
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
else:
|
||||
self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
|
||||
|
||||
if use_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
|
||||
torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
|
||||
else:
|
||||
self.cls_token = None
|
||||
|
||||
self.pos_embed = PositionalEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embed_dim,
|
||||
padding_idx=None,
|
||||
interpolation_mode="bilinear",
|
||||
)
|
||||
self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)
|
||||
|
||||
def extract_patch_embeddings(self, x: Tensor) -> tuple[Tensor, tuple[int, int]]:
|
||||
# input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
|
||||
patch_emb = self.patch_emb(x)
|
||||
n_h, n_w = patch_emb.shape[-2:]
|
||||
|
||||
# [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
|
||||
patch_emb = patch_emb.flatten(2)
|
||||
# [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
|
||||
patch_emb = patch_emb.transpose(1, 2).contiguous()
|
||||
|
||||
n_patches = patch_emb.shape[1]
|
||||
# we resize the positional encodings dynamically.
|
||||
pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)
|
||||
|
||||
# add positional encodings
|
||||
patch_emb = pos_emb + patch_emb
|
||||
|
||||
# add classification token
|
||||
if self.cls_token is not None:
|
||||
# [1, 1, emb_dim] --> [Batch, 1, emb_dim]
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
# Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
|
||||
patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)
|
||||
|
||||
# dropout
|
||||
patch_emb = self.emb_dropout(patch_emb)
|
||||
return patch_emb, (n_h, n_w)
|
||||
|
||||
def _features_from_transformer(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, tuple[int, int]]:
|
||||
# this function extract patch embeddings and then apply transformer module to learn
|
||||
# inter-patch representations
|
||||
|
||||
# [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
|
||||
# and embed_dim is feature dim
|
||||
x, (n_h, n_w) = self.extract_patch_embeddings(x)
|
||||
|
||||
for layer in self.transformer:
|
||||
x = layer(x)
|
||||
x = self.post_transformer_norm(x)
|
||||
|
||||
return x, (n_h, n_w)
|
||||
|
||||
def extract_features(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor | None]:
|
||||
# The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
|
||||
# and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
|
||||
return_image_embeddings = kwargs.get("return_image_embeddings", False)
|
||||
|
||||
# [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
|
||||
# here, B is batch size, C is input channels
|
||||
# H and W are input height and width
|
||||
# N is the number of pixels (or tokens) after processing input with conv stem and reshaping
|
||||
# We add +1 for cls token (if applicable)
|
||||
# embed_dim --> embedding dimension
|
||||
x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)
|
||||
|
||||
if self.cls_token is not None:
|
||||
# [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
|
||||
cls_embedding, image_embedding = torch.split(x, split_size_or_sections=[1, x.shape[1] - 1], dim=1)
|
||||
cls_embedding = cls_embedding.squeeze(1)
|
||||
else:
|
||||
# [B, N, embed_dim] -> [B, embed_dim]
|
||||
cls_embedding = torch.mean(x, dim=1)
|
||||
# [B, N, embed_dim]
|
||||
image_embedding = x
|
||||
|
||||
if return_image_embeddings:
|
||||
# reshape image embedding to 4-D tensor
|
||||
# [B, N, C] --> [B, C, N]
|
||||
image_embedding = image_embedding.transpose(1, 2).contiguous()
|
||||
image_embedding = image_embedding.reshape(image_embedding.shape[0], -1, n_h, n_w)
|
||||
|
||||
return cls_embedding, image_embedding
|
||||
else:
|
||||
return cls_embedding, None
|
||||
|
||||
def forward_classifier(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor]:
|
||||
cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
|
||||
# classify based on CLS token
|
||||
cls_embedding = self.classifier(cls_embedding)
|
||||
return cls_embedding, image_embedding
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor | dict[str, Tensor]:
|
||||
# In ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
|
||||
# To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
|
||||
if kwargs.get("return_image_embeddings", False):
|
||||
out_dict = dict()
|
||||
prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
|
||||
out_dict.update({"logits": prediction})
|
||||
if image_embedding is not None:
|
||||
out_dict.update({"image_embeddings": image_embedding})
|
||||
return out_dict
|
||||
else:
|
||||
prediction, _ = self.forward_classifier(x, *args, **kwargs)
|
||||
return prediction
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_b16(pretrained=False, **kwargs):
|
||||
# Vision transformer config
|
||||
cfg = {
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"act_layer": "gelu",
|
||||
"embed_dim": 768,
|
||||
"n_transformer_layers": 12,
|
||||
"n_attn_heads": 12,
|
||||
}
|
||||
model = VisionTransformer(cfg=cfg, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
6
mobileclip/modules/__init__.py
Normal file
6
mobileclip/modules/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
6
mobileclip/modules/common/__init__.py
Normal file
6
mobileclip/modules/common/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
330
mobileclip/modules/common/mobileone.py
Normal file
330
mobileclip/modules/common/mobileone.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ["MobileOneBlock", "reparameterize_model"]
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""Squeeze and Excite module.
|
||||
|
||||
Pytorch implementation of `Squeeze-and-Excitation Networks` - https://arxiv.org/pdf/1709.01507.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
|
||||
"""Construct a Squeeze and Excite Module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
rd_ratio: Input channel reduction ratio.
|
||||
"""
|
||||
super().__init__()
|
||||
self.reduce = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=int(in_channels * rd_ratio),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True,
|
||||
)
|
||||
self.expand = nn.Conv2d(
|
||||
in_channels=int(in_channels * rd_ratio),
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
_b, c, h, w = inputs.size()
|
||||
x = F.avg_pool2d(inputs, kernel_size=[h, w])
|
||||
x = self.reduce(x)
|
||||
x = F.relu(x)
|
||||
x = self.expand(x)
|
||||
x = torch.sigmoid(x)
|
||||
x = x.view(-1, c, 1, 1)
|
||||
return inputs * x
|
||||
|
||||
|
||||
class MobileOneBlock(nn.Module):
|
||||
"""MobileOne building block.
|
||||
|
||||
This block has a multi-branched architecture at train-time and plain-CNN style architecture at inference time For
|
||||
more details, please refer to our paper: `An Improved One millisecond Mobile Backbone` -
|
||||
https://arxiv.org/pdf/2206.04040.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
use_act: bool = True,
|
||||
use_scale_branch: bool = True,
|
||||
num_conv_branches: int = 1,
|
||||
activation: nn.Module = nn.GELU(),
|
||||
) -> None:
|
||||
"""Construct a MobileOneBlock module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of channels in the input.
|
||||
out_channels: Number of channels produced by the block.
|
||||
kernel_size: Size of the convolution kernel.
|
||||
stride: Stride size.
|
||||
padding: Zero-padding size.
|
||||
dilation: Kernel dilation factor.
|
||||
groups: Group number.
|
||||
inference_mode: If True, instantiates model in inference mode.
|
||||
use_se: Whether to use SE-ReLU activations.
|
||||
use_act: Whether to use activation. Default: ``True``
|
||||
use_scale_branch: Whether to use scale branch. Default: ``True``
|
||||
num_conv_branches: Number of linear conv branches.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_mode = inference_mode
|
||||
self.groups = groups
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_conv_branches = num_conv_branches
|
||||
|
||||
# Check if SE-ReLU is requested
|
||||
if use_se:
|
||||
self.se = SEBlock(out_channels)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
if use_act:
|
||||
self.activation = activation
|
||||
else:
|
||||
self.activation = nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
# Re-parameterizable skip connection
|
||||
self.rbr_skip = (
|
||||
nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
|
||||
)
|
||||
|
||||
# Re-parameterizable conv branches
|
||||
if num_conv_branches > 0:
|
||||
rbr_conv = list()
|
||||
for _ in range(self.num_conv_branches):
|
||||
rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))
|
||||
self.rbr_conv = nn.ModuleList(rbr_conv)
|
||||
else:
|
||||
self.rbr_conv = None
|
||||
|
||||
# Re-parameterizable scale branch
|
||||
self.rbr_scale = None
|
||||
if not isinstance(kernel_size, int):
|
||||
kernel_size = kernel_size[0]
|
||||
if (kernel_size > 1) and use_scale_branch:
|
||||
self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
# Inference mode forward pass.
|
||||
if self.inference_mode:
|
||||
return self.activation(self.se(self.reparam_conv(x)))
|
||||
|
||||
# Multi-branched train-time forward pass.
|
||||
# Skip branch output
|
||||
identity_out = 0
|
||||
if self.rbr_skip is not None:
|
||||
identity_out = self.rbr_skip(x)
|
||||
|
||||
# Scale branch output
|
||||
scale_out = 0
|
||||
if self.rbr_scale is not None:
|
||||
scale_out = self.rbr_scale(x)
|
||||
|
||||
# Other branches
|
||||
out = scale_out + identity_out
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
out += self.rbr_conv[ix](x)
|
||||
|
||||
return self.activation(self.se(out))
|
||||
|
||||
def reparameterize(self):
|
||||
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf.
|
||||
We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like
|
||||
structure for inference.
|
||||
"""
|
||||
if self.inference_mode:
|
||||
return
|
||||
kernel, bias = self._get_kernel_bias()
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = kernel
|
||||
self.reparam_conv.bias.data = bias
|
||||
|
||||
# Delete un-used branches
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("rbr_conv")
|
||||
self.__delattr__("rbr_scale")
|
||||
if hasattr(self, "rbr_skip"):
|
||||
self.__delattr__("rbr_skip")
|
||||
|
||||
self.inference_mode = True
|
||||
|
||||
def _get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to obtain re-parameterized kernel and bias. Reference:
|
||||
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing branches.
|
||||
"""
|
||||
# get weights and bias of scale branch
|
||||
kernel_scale = 0
|
||||
bias_scale = 0
|
||||
if self.rbr_scale is not None:
|
||||
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
|
||||
# Pad scale branch kernel to match conv branch kernel size.
|
||||
pad = self.kernel_size // 2
|
||||
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
|
||||
|
||||
# get weights and bias of skip branch
|
||||
kernel_identity = 0
|
||||
bias_identity = 0
|
||||
if self.rbr_skip is not None:
|
||||
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
|
||||
|
||||
# get weights and bias of conv branches
|
||||
kernel_conv = 0
|
||||
bias_conv = 0
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
|
||||
kernel_conv += _kernel
|
||||
bias_conv += _bias
|
||||
|
||||
kernel_final = kernel_conv + kernel_scale + kernel_identity
|
||||
bias_final = bias_conv + bias_scale + bias_identity
|
||||
return kernel_final, bias_final
|
||||
|
||||
def _fuse_bn_tensor(self, branch: nn.Sequential | nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with preceeding conv layer. Reference:
|
||||
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95.
|
||||
|
||||
Args:
|
||||
branch: Sequence of ops to be fused.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing batchnorm.
|
||||
"""
|
||||
if isinstance(branch, nn.Sequential):
|
||||
kernel = branch.conv.weight
|
||||
running_mean = branch.bn.running_mean
|
||||
running_var = branch.bn.running_var
|
||||
gamma = branch.bn.weight
|
||||
beta = branch.bn.bias
|
||||
eps = branch.bn.eps
|
||||
else:
|
||||
assert isinstance(branch, nn.BatchNorm2d)
|
||||
if not hasattr(self, "id_tensor"):
|
||||
input_dim = self.in_channels // self.groups
|
||||
|
||||
kernel_size = self.kernel_size
|
||||
if isinstance(self.kernel_size, int):
|
||||
kernel_size = (self.kernel_size, self.kernel_size)
|
||||
|
||||
kernel_value = torch.zeros(
|
||||
(self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
|
||||
dtype=branch.weight.dtype,
|
||||
device=branch.weight.device,
|
||||
)
|
||||
for i in range(self.in_channels):
|
||||
kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
|
||||
self.id_tensor = kernel_value
|
||||
kernel = self.id_tensor
|
||||
running_mean = branch.running_mean
|
||||
running_var = branch.running_var
|
||||
gamma = branch.weight
|
||||
beta = branch.bias
|
||||
eps = branch.eps
|
||||
std = (running_var + eps).sqrt()
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
|
||||
"""Helper method to construct conv-batchnorm layers.
|
||||
|
||||
Args:
|
||||
kernel_size: Size of the convolution kernel.
|
||||
padding: Zero-padding size.
|
||||
|
||||
Returns:
|
||||
Conv-BN module.
|
||||
"""
|
||||
mod_list = nn.Sequential()
|
||||
mod_list.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
padding=padding,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
|
||||
return mod_list
|
||||
|
||||
|
||||
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
|
||||
"""Method returns a model where a multi-branched structure used in training is re-parameterized into a single branch
|
||||
for inference.
|
||||
|
||||
Args:
|
||||
model: MobileOne model in train mode.
|
||||
|
||||
Returns:
|
||||
MobileOne model in inference mode.
|
||||
"""
|
||||
# Avoid editing original graph
|
||||
model = copy.deepcopy(model)
|
||||
for module in model.modules():
|
||||
if hasattr(module, "reparameterize"):
|
||||
module.reparameterize()
|
||||
return model
|
||||
410
mobileclip/modules/common/transformer.py
Normal file
410
mobileclip/modules/common/transformer.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""
|
||||
Implementation of the following modules is borrowed from ml-cvnets repo:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py.
|
||||
|
||||
Please see ACKNOWLEDGMENTS for license details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import Size, Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from mobileclip import logger
|
||||
|
||||
|
||||
class LayerNormFP32(nn.LayerNorm):
|
||||
"""Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a input tensor with FP32 precision."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: int | list[int] | Size,
|
||||
eps: float | None = 1e-5,
|
||||
elementwise_affine: bool | None = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
normalized_shape=normalized_shape,
|
||||
eps=eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Convert input from dtype X to FP32 and perform normalization operation.
|
||||
# This may help with underflow/overflow issues that we typically see with normalization layers
|
||||
inp_dtype = x.dtype
|
||||
return super().forward(x.to(torch.float32)).to(inp_dtype)
|
||||
|
||||
|
||||
def get_normalization_layer(norm_type, num_features):
|
||||
if norm_type == "layer_norm":
|
||||
return nn.LayerNorm(num_features)
|
||||
elif norm_type == "layer_norm_fp32":
|
||||
return LayerNormFP32(num_features)
|
||||
else:
|
||||
raise NotImplementedError(f"Option: {norm_type} not supported.")
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int | None = None,
|
||||
is_learnable: bool | None = False,
|
||||
interpolation_mode: str | None = "bilinear",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# Add other pos embedding here and logic to choose between them
|
||||
module = LearnablePositionalEmbedding
|
||||
|
||||
self.pos_embed = module(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
interpolation_mode=interpolation_mode,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
|
||||
return self.pos_embed(seq_len, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return self.pos_embed.__repr__()
|
||||
|
||||
|
||||
class LearnablePositionalEmbedding(nn.Module):
|
||||
"""Learnable Positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int | None = None,
|
||||
interpolation_mode: str | None = "bilinear",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_idx = padding_idx
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.pos_embed[:, :, self.padding_idx, ...] = 0.0
|
||||
|
||||
def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
|
||||
# scale pos embedding
|
||||
pos_embed = self.pos_embed
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
pos_embed[:, :, self.padding_idx, ...] = 0.0
|
||||
|
||||
if seq_len != self.num_embeddings:
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed,
|
||||
size=(seq_len, self.embedding_dim),
|
||||
mode=self.interpolation_mode,
|
||||
)
|
||||
|
||||
# Input is of the form [Batch, Seq_len, Embedding_dim]
|
||||
return pos_embed.reshape(1, seq_len, self.embedding_dim)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, padding_idx={self.padding_idx})"
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""This layer applies a multi-head self- or cross-attention as described in `Attention is all you need
|
||||
<https://arxiv.org/abs/1706.03762>`_ paper.
|
||||
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
|
||||
num_heads (int): Number of heads in multi-head attention
|
||||
attn_dropout (Optional[float]): Attention dropout. Default: 0.0
|
||||
bias (Optional[bool]): Use bias or not. Default: ``True``
|
||||
|
||||
Notes:
|
||||
- Input:
|
||||
- Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
|
||||
and: math:`C_{in}` is input embedding dim
|
||||
- Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
attn_dropout: float | None = 0.0,
|
||||
bias: bool | None = True,
|
||||
output_dim: int | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if output_dim is None:
|
||||
output_dim = embed_dim
|
||||
super().__init__()
|
||||
if embed_dim % num_heads != 0:
|
||||
logger.error(
|
||||
f"Embedding dim must be divisible by number of heads in {self.__class__.__name__}. Got: embed_dim={embed_dim} and num_heads={num_heads}"
|
||||
)
|
||||
|
||||
self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=attn_dropout)
|
||||
self.out_proj = nn.Linear(in_features=embed_dim, out_features=output_dim, bias=bias)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
self.use_separate_proj_weight = embed_dim != output_dim
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(head_dim={self.head_dim}, num_heads={self.num_heads}, attn_dropout={self.attn_dropout.p})"
|
||||
|
||||
def _forward_impl(
|
||||
self,
|
||||
x_q: Tensor,
|
||||
x_kv: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# [N, S, C]
|
||||
b_sz, S_len, _in_channels = x_q.shape
|
||||
|
||||
if x_kv is None:
|
||||
# self-attention
|
||||
# [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
|
||||
qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
|
||||
# [N, S, 3, h, c] --> [N, h, 3, S, C]
|
||||
qkv = qkv.transpose(1, 3).contiguous()
|
||||
|
||||
# [N, h, 3, S, C] --> [N, h, S, C] x 3
|
||||
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
else:
|
||||
T_len = x_kv.shape[1]
|
||||
|
||||
# cross-attention
|
||||
# [N, S, C]
|
||||
query = F.linear(
|
||||
x_q,
|
||||
weight=self.qkv_proj.weight[: self.embed_dim, ...],
|
||||
bias=self.qkv_proj.bias[: self.embed_dim] if self.qkv_proj.bias is not None else None,
|
||||
)
|
||||
# [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
|
||||
query = query.reshape(b_sz, S_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
# [N, T, C] --> [N, T, 2C]
|
||||
kv = F.linear(
|
||||
x_kv,
|
||||
weight=self.qkv_proj.weight[self.embed_dim :, ...],
|
||||
bias=self.qkv_proj.bias[self.embed_dim :] if self.qkv_proj.bias is not None else None,
|
||||
)
|
||||
# [N, T, 2C] --> [N, T, 2, h, c]
|
||||
kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
|
||||
# [N, T, 2, h, c] --> [N, h, 2, T, c]
|
||||
kv = kv.transpose(1, 3).contiguous()
|
||||
key, value = kv[:, :, 0], kv[:, :, 1]
|
||||
|
||||
query = query * self.scaling
|
||||
|
||||
# [N h, T, c] --> [N, h, c, T]
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
# QK^T
|
||||
# [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
|
||||
attn = torch.matmul(query, key)
|
||||
|
||||
batch_size, _num_heads, num_src_tokens, num_tgt_tokens = attn.shape
|
||||
if attn_mask is not None:
|
||||
# attn_mask shape should be the same as attn
|
||||
assert list(attn_mask.shape) == [
|
||||
batch_size,
|
||||
num_src_tokens,
|
||||
num_tgt_tokens,
|
||||
], (
|
||||
f"Shape of attention mask should be [{batch_size}, {num_src_tokens}, {num_tgt_tokens}]. Got: {attn_mask.shape}"
|
||||
)
|
||||
# [N, S, T] --> [N, 1, S, T]
|
||||
attn_mask = attn_mask.unsqueeze(1)
|
||||
attn = attn + attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# Do not attend to padding positions
|
||||
# key padding mask size is [N, T]
|
||||
assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
|
||||
batch_size,
|
||||
num_tgt_tokens,
|
||||
], (
|
||||
f"Key_padding_mask should be 2-dimension with shape [{batch_size}, {num_tgt_tokens}]. Got: {key_padding_mask.shape}"
|
||||
)
|
||||
attn = attn.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), # [N, T] --> [N, 1, 1, T]
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
attn_dtype = attn.dtype
|
||||
attn_as_float = self.softmax(attn.float())
|
||||
attn = attn_as_float.to(attn_dtype)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# weighted sum
|
||||
# [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
|
||||
out = torch.matmul(attn, value)
|
||||
|
||||
# [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
|
||||
out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_q: Tensor,
|
||||
x_kv: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# [Batch , Sequence, Hidden_dim]
|
||||
return self._forward_impl(
|
||||
x_q=x_q,
|
||||
x_kv=x_kv,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Args:
|
||||
embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
|
||||
ffn_latent_dim: Inner dimension of the FFN.
|
||||
num_heads: Number of heads in multi-head attention. Default: 8.
|
||||
attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
|
||||
dropout: Dropout rate. Default: 0.0.
|
||||
ffn_dropout: Dropout between FFN layers. Default: 0.0.
|
||||
transformer_norm_layer: Normalization layer. Default: layer_norm.
|
||||
stochastic_dropout: Stochastic dropout setting. Default: 0.0.
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and: math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_latent_dim: int,
|
||||
num_heads: int | None = 8,
|
||||
attn_dropout: float | None = 0.0,
|
||||
dropout: float | None = 0.0,
|
||||
ffn_dropout: float | None = 0.0,
|
||||
transformer_norm_layer: str | None = "layer_norm",
|
||||
stochastic_dropout: float | None = 0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Build attention layer
|
||||
attn_unit = MultiHeadAttention(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.pre_norm_mha = nn.Sequential(
|
||||
get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim),
|
||||
attn_unit,
|
||||
nn.Dropout(p=dropout),
|
||||
)
|
||||
|
||||
act_name = nn.GELU()
|
||||
self.pre_norm_ffn = nn.Sequential(
|
||||
get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim),
|
||||
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
|
||||
act_name,
|
||||
nn.Dropout(p=ffn_dropout),
|
||||
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
|
||||
nn.Dropout(p=dropout),
|
||||
)
|
||||
|
||||
self.drop_path = nn.Identity()
|
||||
if stochastic_dropout > 0.0:
|
||||
if dropout > 0.0:
|
||||
logger.error(
|
||||
"Stochastic dropout and dropout are mutually exclusive. "
|
||||
"Use either of them, but not both."
|
||||
f"Got: {stochastic_dropout} and {dropout}"
|
||||
)
|
||||
self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_dim = ffn_latent_dim
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.stochastic_dropout = stochastic_dropout
|
||||
self.std_dropout = dropout
|
||||
self.attn_fn_name = attn_unit.__class__.__name__
|
||||
self.act_fn_name = act_name.__class__.__name__
|
||||
self.norm_type = transformer_norm_layer
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(embed_dim={self.embed_dim}, ffn_dim={self.ffn_dim}, dropout={self.std_dropout}, ffn_dropout={self.ffn_dropout}, stochastic_dropout={self.stochastic_dropout}, attn_fn={self.attn_fn_name}, act_fn={self.act_fn_name}, norm_fn={self.norm_type})"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
x_prev: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# Multi-head attention
|
||||
res = x
|
||||
x = self.pre_norm_mha[0](x) # norm
|
||||
x = self.pre_norm_mha[1](
|
||||
x_q=x,
|
||||
x_kv=x_prev,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
*args,
|
||||
**kwargs,
|
||||
) # mha
|
||||
|
||||
x = self.drop_path(self.pre_norm_mha[2](x)) # applying stochastic depth
|
||||
x = x + res
|
||||
|
||||
# Feed forward network
|
||||
x = x + self.drop_path(self.pre_norm_ffn(x))
|
||||
return x
|
||||
6
mobileclip/modules/image/__init__.py
Normal file
6
mobileclip/modules/image/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
97
mobileclip/modules/image/image_projection.py
Normal file
97
mobileclip/modules/image/image_projection.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mobileclip import logger
|
||||
|
||||
|
||||
class GlobalPool(nn.Module):
|
||||
"""This layers applies global pooling over a 4D or 5D input tensor.
|
||||
|
||||
Args:
|
||||
pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
|
||||
keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
|
||||
- Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`
|
||||
"""
|
||||
|
||||
pool_types = ["mean", "rms", "abs"]
|
||||
|
||||
def __init__(self, pool_type: str | None = "mean", keep_dim: bool | None = False, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
if pool_type not in self.pool_types:
|
||||
logger.error(f"Supported pool types are: {self.pool_types}. Got {pool_type}")
|
||||
self.pool_type = pool_type
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
def _global_pool(self, x: Tensor, dims: list):
|
||||
if self.pool_type == "rms": # root mean square
|
||||
x = x**2
|
||||
x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
|
||||
x = x**-0.5
|
||||
elif self.pool_type == "abs": # absolute
|
||||
x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
|
||||
else:
|
||||
# default is mean
|
||||
# same as AdaptiveAvgPool
|
||||
x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if x.dim() == 4:
|
||||
dims = [-2, -1]
|
||||
elif x.dim() == 5:
|
||||
dims = [-3, -2, -1]
|
||||
else:
|
||||
raise NotImplementedError("Currently 2D and 3D global pooling supported")
|
||||
return self._global_pool(x, dims=dims)
|
||||
|
||||
|
||||
class GlobalPool2D(nn.Module):
|
||||
"""This class implements global pooling with linear projection."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
scale = in_dim**-0.5
|
||||
self.pool = GlobalPool(pool_type="mean", keep_dim=False)
|
||||
self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
# x is of shape [batch, in_dim]
|
||||
assert x.dim() == 4, f"Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {x.shape}"
|
||||
|
||||
# [batch, in_dim, in_height, in_width] --> [batch, in_dim]
|
||||
x = self.pool(x)
|
||||
# [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageProjectionHead(nn.Module):
|
||||
"""This class implements linear projection head."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
scale = in_dim**-0.5
|
||||
self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
# x is of shape [batch, in_dim]
|
||||
assert x.dim() == 2, f"Input should be 2-dimensional (Batch x in_dim). Got: {x.shape}"
|
||||
|
||||
# [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
|
||||
x = x @ self.proj
|
||||
return x
|
||||
177
mobileclip/modules/image/replknet.py
Normal file
177
mobileclip/modules/image/replknet.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For acknowledgment see accompanying ACKNOWLEDGMENTS file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.layers import SqueezeExcite
|
||||
|
||||
__all__ = ["ReparamLargeKernelConv"]
|
||||
|
||||
|
||||
class ReparamLargeKernelConv(nn.Module):
|
||||
"""Building Block of RepLKNet.
|
||||
|
||||
This class defines overparameterized large kernel conv block introduced in `RepLKNet
|
||||
<https://arxiv.org/abs/2203.06717>`_
|
||||
|
||||
Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
groups: int,
|
||||
small_kernel: int,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
activation: nn.Module = nn.GELU(),
|
||||
) -> None:
|
||||
"""Construct a ReparamLargeKernelConv module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
kernel_size: Kernel size of the large kernel conv branch.
|
||||
stride: Stride size. Default: 1
|
||||
groups: Group number. Default: 1
|
||||
small_kernel: Kernel size of small kernel conv branch.
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
activation: Activation module. Default: ``nn.GELU``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.activation = activation
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.small_kernel = small_kernel
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
# Check if SE is requested
|
||||
if use_se:
|
||||
self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
dilation=1,
|
||||
groups=groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.lkb_origin = self._conv_bn(kernel_size=kernel_size, padding=self.padding)
|
||||
if small_kernel is not None:
|
||||
assert small_kernel <= kernel_size, (
|
||||
"The kernel size for re-param cannot be larger than the large kernel!"
|
||||
)
|
||||
self.small_conv = self._conv_bn(kernel_size=small_kernel, padding=small_kernel // 2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
if hasattr(self, "lkb_reparam"):
|
||||
out = self.lkb_reparam(x)
|
||||
else:
|
||||
out = self.lkb_origin(x)
|
||||
if hasattr(self, "small_conv"):
|
||||
out += self.small_conv(x)
|
||||
|
||||
return self.activation(self.se(out))
|
||||
|
||||
def get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepLKNet-pytorch.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing branches.
|
||||
"""
|
||||
eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
|
||||
if hasattr(self, "small_conv"):
|
||||
small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
|
||||
eq_b += small_b
|
||||
eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
|
||||
return eq_k, eq_b
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf.
|
||||
We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like
|
||||
structure for inference.
|
||||
"""
|
||||
eq_k, eq_b = self.get_kernel_bias()
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.lkb_origin.conv.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.lkb_reparam.weight.data = eq_k
|
||||
self.lkb_reparam.bias.data = eq_b
|
||||
self.__delattr__("lkb_origin")
|
||||
if hasattr(self, "small_conv"):
|
||||
self.__delattr__("small_conv")
|
||||
|
||||
@staticmethod
|
||||
def _fuse_bn(conv: torch.Tensor, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with conv layer.
|
||||
|
||||
Args:
|
||||
conv: Convolutional kernel weights.
|
||||
bn: Batchnorm 2d layer.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing batchnorm.
|
||||
"""
|
||||
kernel = conv.weight
|
||||
running_mean = bn.running_mean
|
||||
running_var = bn.running_var
|
||||
gamma = bn.weight
|
||||
beta = bn.bias
|
||||
eps = bn.eps
|
||||
std = (running_var + eps).sqrt()
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
|
||||
"""Helper method to construct conv-batchnorm layers.
|
||||
|
||||
Args:
|
||||
kernel_size: Size of the convolution kernel.
|
||||
padding: Zero-padding size.
|
||||
|
||||
Returns:
|
||||
A nn.Sequential Conv-BN module.
|
||||
"""
|
||||
mod_list = nn.Sequential()
|
||||
mod_list.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
padding=padding,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
|
||||
return mod_list
|
||||
6
mobileclip/modules/text/__init__.py
Normal file
6
mobileclip/modules/text/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
265
mobileclip/modules/text/repmixer.py
Normal file
265
mobileclip/modules/text/repmixer.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from mobileclip.modules.common.mobileone import MobileOneBlock
|
||||
|
||||
|
||||
class ConvFFN(nn.Module):
|
||||
"""Convolutional FFN Module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
context_size: int,
|
||||
hidden_channels: int | None = None,
|
||||
out_channels: int | None = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build convolutional FFN module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
context_size: Context size for 1D signals.
|
||||
hidden_channels: Number of channels after expansion. Default: None
|
||||
out_channels: Number of output channels. Default: None
|
||||
act_layer: Activation layer. Default: ``GELU``
|
||||
drop: Dropout rate. Default: ``0.0``.
|
||||
"""
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(1, int(context_size)),
|
||||
padding=(0, int(context_size // 2)),
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
if isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepMixer(nn.Module):
|
||||
"""Reparameterizable token mixer.
|
||||
|
||||
For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
|
||||
Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Module.
|
||||
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.kernel_size = kernel_size
|
||||
self.inference_mode = inference_mode
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=(1, self.kernel_size),
|
||||
stride=1,
|
||||
padding=(0, self.kernel_size // 2),
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
(1, kernel_size),
|
||||
padding=(0, kernel_size // 2),
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
use_scale_branch=False,
|
||||
num_conv_branches=0,
|
||||
)
|
||||
self.mixer = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
(1, kernel_size),
|
||||
padding=(0, kernel_size // 2),
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
|
||||
if self.inference_mode:
|
||||
return
|
||||
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
|
||||
else:
|
||||
w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=(1, self.kernel_size),
|
||||
stride=1,
|
||||
padding=(0, self.kernel_size // 2),
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w
|
||||
self.reparam_conv.bias.data = b
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("mixer")
|
||||
self.__delattr__("norm")
|
||||
if self.use_layer_scale:
|
||||
self.__delattr__("layer_scale")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
"""Implementation of Metaformer block with RepMixer as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 11,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""Build RepMixer Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
kernel_size: Kernel size for repmixer. Default: 3
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
context_size=kernel_size,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if x.dim() == 3:
|
||||
# B, C, D --- where C is the context length
|
||||
# Convert to B, D, C --- to match RepMixer impl.
|
||||
x = x.permute(0, 2, 1)
|
||||
x = torch.unsqueeze(x, dim=2)
|
||||
else:
|
||||
raise ValueError(f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}")
|
||||
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
|
||||
# Convert tensors back
|
||||
x = x.squeeze(dim=2).permute(0, 2, 1)
|
||||
return x
|
||||
39
mobileclip/modules/text/tokenizer.py
Normal file
39
mobileclip/modules/text/tokenizer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
|
||||
import open_clip
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ClipTokenizer(nn.Module):
|
||||
def __init__(self, cfg, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.context_length = cfg["text_cfg"]["context_length"]
|
||||
model_name = getattr(cfg["text_cfg"], "open_clip_tokenizer", "ViT-B-16")
|
||||
self.tokenizer = open_clip.get_tokenizer(model_name)
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
return len(self.tokenizer.encoder)
|
||||
|
||||
def get_encodings(self) -> dict[str, int]:
|
||||
return self.tokenizer.encoder
|
||||
|
||||
def get_eot_token(self) -> int:
|
||||
# Tokenizing an empty string returns a list [sot_id, eot_id]
|
||||
return self.tokenizer("")[1]
|
||||
|
||||
def get_sot_token(self) -> int:
|
||||
# Tokenizing an empty string returns a list [sot_id, eot_id]
|
||||
return self.tokenizer("")[0]
|
||||
|
||||
def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
|
||||
# tokenizer returns indices as a string
|
||||
tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
|
||||
assert tokenized_sentence.shape[-1] == self.context_length, (
|
||||
"Tokenized tensor should be exactly `context_length` long."
|
||||
)
|
||||
return tokenized_sentence
|
||||
218
mobileclip/text_encoder.py
Normal file
218
mobileclip/text_encoder.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mobileclip import logger
|
||||
from mobileclip.modules.common.transformer import (
|
||||
PositionalEmbedding,
|
||||
TransformerEncoder,
|
||||
get_normalization_layer,
|
||||
)
|
||||
from mobileclip.modules.text.repmixer import RepMixerBlock
|
||||
|
||||
|
||||
class TextTransformer(nn.Module):
|
||||
def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
model_dim = cfg["dim"]
|
||||
no_scale_embedding = cfg.get("no_scale_embedding", False)
|
||||
no_pos_embedding = cfg.get("no_pos_embedding", False)
|
||||
embed_dropout = cfg.get("embed_dropout", 0.0)
|
||||
norm_layer = cfg["norm_layer"]
|
||||
variant = cfg["model_name"]
|
||||
self.vocab_size = cfg["vocab_size"]
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
# Token embedding layer
|
||||
self.embedding_layer = nn.Embedding(embedding_dim=model_dim, num_embeddings=self.vocab_size)
|
||||
self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5
|
||||
|
||||
# Context length
|
||||
context_length = cfg["context_length"]
|
||||
assert context_length is not None, "Context length can't be None. Please set value accordingly."
|
||||
|
||||
self.positional_embedding = (
|
||||
None if no_pos_embedding else PositionalEmbedding(num_embeddings=context_length, embedding_dim=model_dim)
|
||||
)
|
||||
|
||||
self.embedding_dropout = nn.Dropout(p=embed_dropout)
|
||||
|
||||
# Transformer layer
|
||||
n_transformer_layers = cfg["n_transformer_layers"]
|
||||
|
||||
# FFN multipliers for transformer layer
|
||||
ffn_multipliers = cfg["ffn_multiplier_per_layer"]
|
||||
if isinstance(ffn_multipliers, (float, int)):
|
||||
ffn_multipliers = [ffn_multipliers] * n_transformer_layers
|
||||
|
||||
if not isinstance(ffn_multipliers, Sequence):
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} expects FFN multipliers as a list, whose length is the same as"
|
||||
f" number of transformer layers. Got: {type(ffn_multipliers)}"
|
||||
)
|
||||
elif isinstance(ffn_multipliers, Sequence) and len(ffn_multipliers) != n_transformer_layers:
|
||||
logger.error(
|
||||
f"We need FFN multiplier for each transformer layer. Got {len(ffn_multipliers)} ffn"
|
||||
f" multipliers while number of transformer layers = {n_transformer_layers}"
|
||||
)
|
||||
ffn_dims = [int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) for ffn_mult in ffn_multipliers]
|
||||
|
||||
# Heads for transformer layers
|
||||
mha_heads = cfg["n_heads_per_layer"]
|
||||
if isinstance(mha_heads, int):
|
||||
mha_heads = [mha_heads] * n_transformer_layers
|
||||
|
||||
if not isinstance(mha_heads, Sequence):
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} expects MHA heads as a list, whose length is the same as number of "
|
||||
f"transformer layers. Got: {type(mha_heads)}"
|
||||
)
|
||||
elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} needs MHA heads for each transformer layer. Got {len(mha_heads)} mha heads while"
|
||||
f" number of transformer layers = {n_transformer_layers}"
|
||||
)
|
||||
|
||||
if variant == "base":
|
||||
self.transformer = nn.ModuleList(
|
||||
[
|
||||
TransformerEncoder(
|
||||
embed_dim=model_dim,
|
||||
num_heads=mha_heads[layer_idx],
|
||||
ffn_latent_dim=ffn_dims[layer_idx],
|
||||
transformer_norm_layer=norm_layer,
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
)
|
||||
elif variant == "mct":
|
||||
self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
|
||||
self.transformer.extend(
|
||||
[
|
||||
TransformerEncoder(
|
||||
embed_dim=model_dim,
|
||||
num_heads=mha_heads[layer_idx],
|
||||
ffn_latent_dim=ffn_dims[layer_idx],
|
||||
transformer_norm_layer=norm_layer,
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
)
|
||||
self.transformer.extend([RepMixerBlock(dim=model_dim)])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized text encoder variant {variant}")
|
||||
|
||||
self.final_layer_norm = get_normalization_layer(num_features=model_dim, norm_type=norm_layer)
|
||||
|
||||
self.projection_layer = nn.Parameter(torch.empty(model_dim, self.projection_dim))
|
||||
self.model_dim = model_dim
|
||||
self.causal_masking = cfg["causal_masking"]
|
||||
|
||||
def forward_embedding(self, text_tokens: Tensor) -> Tensor:
|
||||
"""Return text embedding for all tokens.
|
||||
|
||||
Args:
|
||||
text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
|
||||
|
||||
Returns:
|
||||
A tensor of [batch_size, context_length, hidden_dim].
|
||||
"""
|
||||
# [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
|
||||
token_emb = self.embedding_layer(text_tokens)
|
||||
seq_len = token_emb.shape[1]
|
||||
if self.positional_embedding is not None:
|
||||
token_emb = token_emb + self.positional_embedding(seq_len).to(token_emb.dtype)
|
||||
token_emb = self.embedding_dropout(token_emb)
|
||||
return token_emb
|
||||
|
||||
@staticmethod
|
||||
@torch.jit.script # use scripting to avoid device constant
|
||||
def build_attention_mask(text_tokens: torch.Tensor) -> Tensor:
|
||||
"""Build causal attention mask [batch_size, context_length, context_length]."""
|
||||
# Build mask with full attention between the tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
batch_size, context_length = text_tokens.shape
|
||||
mask = torch.empty(context_length, context_length, device=text_tokens.device, dtype=torch.float32)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(0) # add dummy batch dimension
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text_tokens: Tensor,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
return_all_tokens: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Return text token embeddings.
|
||||
|
||||
Args:
|
||||
text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
|
||||
key_padding_mask: a tensor of boolean values as the padding mask of shape [batch_size, context_length]
|
||||
return_all_tokens: a boolean flag to return all tokens, defaults to False to return only EOT token
|
||||
embedding.
|
||||
|
||||
Returns:
|
||||
A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
|
||||
True, otherwise a tensor of [batch_size, hidden_dim].
|
||||
"""
|
||||
# Discrete tokens to continuous embeddings
|
||||
# [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
|
||||
token_emb = self.forward_embedding(text_tokens)
|
||||
|
||||
# [1, context_length, context_length]
|
||||
attn_mask = None
|
||||
if self.causal_masking:
|
||||
attn_mask = self.build_attention_mask(text_tokens=text_tokens)
|
||||
key_padding_mask = None
|
||||
|
||||
for layer in self.transformer:
|
||||
token_emb = layer(
|
||||
token_emb,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Apply layer norm
|
||||
token_emb = self.final_layer_norm(token_emb)
|
||||
|
||||
if return_all_tokens:
|
||||
return token_emb
|
||||
|
||||
# Take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
token_emb = token_emb[torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)]
|
||||
|
||||
token_emb = token_emb @ self.projection_layer
|
||||
return token_emb
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: Tensor,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
return_all_tokens: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# Image-text pair data with single caption
|
||||
# [B, CL] --> [B, d]
|
||||
text_tokens = self.encode_text(
|
||||
text_tokens=text_tokens,
|
||||
key_padding_mask=key_padding_mask,
|
||||
return_all_tokens=return_all_tokens,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
return text_tokens
|
||||
64
train.py
Normal file
64
train.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# --- 引入你的模块 ---
|
||||
from dataset import YOLODataset
|
||||
from yolo11_standalone import YOLO11
|
||||
from loss import YOLOv8DetectionLoss
|
||||
|
||||
# --- 引入刚刚写的 Trainer ---
|
||||
from trainer import YOLO11Trainer
|
||||
|
||||
def run_training():
|
||||
# --- 1.全局配置 ---
|
||||
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
||||
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
||||
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
||||
label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
|
||||
|
||||
epochs = 50
|
||||
batch_size = 36
|
||||
img_size = 640
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# --- 2. 准备数据 ---
|
||||
print("Loading Data...")
|
||||
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
||||
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
|
||||
# --- 3. 初始化模型 ---
|
||||
print("Initializing Model...")
|
||||
model = YOLO11(nc=80, scale='s')
|
||||
# model.load_weights("yolo11s.pth")
|
||||
# model.to(device)
|
||||
|
||||
strides = [8, 16, 32]
|
||||
hyp = {
|
||||
'box': 7.5,
|
||||
'cls': 0.5,
|
||||
'dfl': 1.5
|
||||
}
|
||||
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
||||
|
||||
# --- 5. 初始化 Trainer 并开始训练 ---
|
||||
trainer = YOLO11Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
loss_fn=loss_fn,
|
||||
epochs=epochs,
|
||||
lr0=0.01,
|
||||
device=device,
|
||||
save_dir='./my_yolo_result'
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Windows下多进程dataloader需要这个保护
|
||||
run_training()
|
||||
275
trainer.py
Normal file
275
trainer.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import math
|
||||
import copy
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(format="%(message)s", level=logging.INFO)
|
||||
LOGGER = logging.getLogger("YOLO_Trainer")
|
||||
|
||||
# ==============================================================================
|
||||
# Helper Class: Model EMA (Exponential Moving Average)
|
||||
# ==============================================================================
|
||||
class ModelEMA:
|
||||
""" Updated Exponential Moving Average (EMA) from Ultralytics """
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
self.ema = copy.deepcopy(model).eval() # FP32 EMA
|
||||
self.updates = updates
|
||||
# decay exponential ramp (to help early epochs)
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
msd = model.state_dict()
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if k in msd:
|
||||
tmp = msd[k].to(v.device)
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1 - d) * tmp
|
||||
|
||||
# ==============================================================================
|
||||
# Main Trainer Class
|
||||
# ==============================================================================
|
||||
class YOLO11Trainer:
|
||||
def __init__(self,
|
||||
model,
|
||||
train_loader,
|
||||
val_loader,
|
||||
loss_fn,
|
||||
epochs=100,
|
||||
lr0=0.01,
|
||||
lrf=0.01,
|
||||
device='cuda',
|
||||
save_dir='./runs/train',
|
||||
warmup_epochs=3.0):
|
||||
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
self.model = model.to(self.device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.loss_fn = loss_fn
|
||||
self.epochs = epochs
|
||||
self.save_dir = Path(save_dir)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.start_epoch = 0
|
||||
|
||||
# --- Optimizer Building ---
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
|
||||
|
||||
for v in self.model.modules():
|
||||
for p_name, p in v.named_parameters(recurse=False):
|
||||
if p_name == 'bias':
|
||||
g[2].append(p) # biases
|
||||
elif isinstance(v, bn):
|
||||
g[1].append(p) # bn weights (no decay)
|
||||
else:
|
||||
g[0].append(p) # weights (decay)
|
||||
|
||||
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
|
||||
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
|
||||
self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
|
||||
|
||||
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
|
||||
|
||||
# --- Scheduler ---
|
||||
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
# --- AMP & EMA ---
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
|
||||
self.ema = ModelEMA(self.model)
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.model.train()
|
||||
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
|
||||
desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)
|
||||
|
||||
mloss = torch.zeros(3, device=self.device)
|
||||
self.optimizer.zero_grad()
|
||||
nb = len(self.train_loader)
|
||||
|
||||
for i, batch in pbar:
|
||||
ni = i + nb * epoch
|
||||
|
||||
imgs, targets, paths = batch
|
||||
imgs = imgs.to(self.device, non_blocking=True)
|
||||
targets = targets.to(self.device)
|
||||
|
||||
# --- Warmup ---
|
||||
if ni <= nb * self.warmup_epochs:
|
||||
xp = [0, nb * self.warmup_epochs]
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
lr_target = x['initial_lr'] * self.lf(epoch)
|
||||
x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
|
||||
if 'momentum' in x:
|
||||
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
|
||||
|
||||
# --- Forward ---
|
||||
with torch.amp.autocast('cuda', enabled=True):
|
||||
preds = self.model(imgs)
|
||||
|
||||
target_batch = {
|
||||
"batch_idx": targets[:, 0],
|
||||
"cls": targets[:, 1],
|
||||
"bboxes": targets[:, 2:],
|
||||
}
|
||||
|
||||
loss, loss_items = self.loss_fn(preds, target_batch)
|
||||
|
||||
# --- Backward ---
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# --- EMA ---
|
||||
self.ema.update(self.model)
|
||||
|
||||
# --- Logging ---
|
||||
loss_items = loss_items.detach()
|
||||
mloss = (mloss * i + loss_items) / (i + 1)
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
|
||||
|
||||
pbar.set_postfix({
|
||||
'mem': mem,
|
||||
'box': f"{mloss[0]:.4f}",
|
||||
'cls': f"{mloss[1]:.4f}",
|
||||
'dfl': f"{mloss[2]:.4f}",
|
||||
'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
model = self.ema.ema
|
||||
device = self.device
|
||||
|
||||
# --- Metrics Config ---
|
||||
conf_thres = 0.001 # Low threshold for mAP calculation
|
||||
iou_thres = 0.7 # NMS IoU threshold
|
||||
iouv = torch.linspace(0.5, 0.95, 10, device=device) # IoU vector for mAP@0.5:0.95
|
||||
|
||||
loss_sum = torch.zeros(3, device=device)
|
||||
stats = [] # [(correct, conf, pred_cls, target_cls)]
|
||||
|
||||
LOGGER.info("\nValidating...")
|
||||
|
||||
model.eval()
|
||||
|
||||
pbar = tqdm(self.val_loader, desc="Calc Metrics")
|
||||
with torch.no_grad():
|
||||
for batch in pbar:
|
||||
imgs, targets, _ = batch
|
||||
imgs = imgs.to(device, non_blocking=True)
|
||||
targets = targets.to(device)
|
||||
_, _, height, width = imgs.shape
|
||||
|
||||
# Inference
|
||||
preds = model(imgs)
|
||||
|
||||
# NMS
|
||||
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
|
||||
|
||||
# Metrics Processing
|
||||
for si, pred in enumerate(preds):
|
||||
labels = targets[targets[:, 0] == si]
|
||||
nl = len(labels)
|
||||
tcls = labels[:, 1].tolist() if nl else []
|
||||
|
||||
if len(pred) == 0:
|
||||
if nl:
|
||||
stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
|
||||
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
predn = pred.clone()
|
||||
|
||||
# Ground Truth
|
||||
if nl:
|
||||
tbox = xywh2xyxy(labels[:, 2:6])
|
||||
tbox[:, [0, 2]] *= width
|
||||
tbox[:, [1, 3]] *= height
|
||||
labelsn = torch.cat((labels[:, 1:2], tbox), 1) # [cls, x1, y1, x2, y2]
|
||||
|
||||
# Match predictions to GT
|
||||
correct = self._process_batch(predn, labelsn, iouv)
|
||||
else:
|
||||
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
|
||||
|
||||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))
|
||||
|
||||
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
|
||||
if len(stats) and stats[0][0].any():
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
|
||||
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
|
||||
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
||||
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
||||
|
||||
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
|
||||
|
||||
return map50
|
||||
|
||||
def _process_batch(self, detections, labels, iouv):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
detections: [N, 6] (x1, y1, x2, y2, conf, cls)
|
||||
labels: [M, 5] (cls, x1, y1, x2, y2)
|
||||
"""
|
||||
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
|
||||
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
|
||||
matches = torch.from_numpy(matches).to(iouv.device)
|
||||
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
|
||||
|
||||
return correct
|
||||
|
||||
def train(self):
|
||||
LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
|
||||
start_time = time.time()
|
||||
best_fitness = 0.0
|
||||
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.train_one_epoch(epoch)
|
||||
self.scheduler.step()
|
||||
map50 = self.validate()
|
||||
|
||||
ckpt = {
|
||||
'epoch': epoch,
|
||||
'model': self.model.state_dict(),
|
||||
'ema': self.ema.ema.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
}
|
||||
|
||||
torch.save(ckpt, self.save_dir / 'last.pt')
|
||||
|
||||
if map50 > best_fitness:
|
||||
best_fitness = map50
|
||||
torch.save(ckpt, self.save_dir / 'best.pt')
|
||||
LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")
|
||||
|
||||
LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")
|
||||
923
yolo11_standalone.py
Normal file
923
yolo11_standalone.py
Normal file
@@ -0,0 +1,923 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
def autopad(k, p=None, d=1):
|
||||
if d > 1:
|
||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
return p
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
dtype, device = feats[0].dtype, feats[0].device
|
||||
for i, stride in enumerate(strides):
|
||||
_, _, h, w = feats[i].shape
|
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
|
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
||||
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return torch.cat((c_xy, wh), dim)
|
||||
return torch.cat((x1y1, x2y2), dim)
|
||||
|
||||
class Concat(nn.Module):
|
||||
def __init__(self, dimension=1):
|
||||
super().__init__()
|
||||
self.d = dimension
|
||||
|
||||
def forward(self, x: List[torch.Tensor]):
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class Conv(nn.Module):
|
||||
default_act = nn.SiLU(inplace=True)
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) # type: ignore
|
||||
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
||||
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
class DWConv(Conv):
|
||||
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
|
||||
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
||||
|
||||
class DFL(nn.Module):
|
||||
def __init__(self, c1: int = 16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
||||
x = torch.arange(c1, dtype=torch.float)
|
||||
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
||||
self.c1 = c1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, _, a = x.shape
|
||||
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(
|
||||
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
c_ = int(c2 * e)
|
||||
self.cv1 = Conv(c1, c_, k[0], 1)
|
||||
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
||||
self.add = shortcut and c1 == c2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||
|
||||
class C2f(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
self.c = int(c2 * e)
|
||||
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
||||
self.cv2 = Conv((2 + n) * self.c, c2, 1)
|
||||
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) # type: ignore
|
||||
|
||||
def forward(self, x):
|
||||
chunk_result = self.cv1(x).chunk(2, 1)
|
||||
y = [chunk_result[0], chunk_result[1]]
|
||||
|
||||
for m_module in self.m:
|
||||
y.append(m_module(y[-1]))
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
def forward_split(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = self.cv1(x).split((self.c, self.c), 1)
|
||||
y = [y[0], y[1]]
|
||||
y.extend(m(y[-1]) for m in self.m)
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
class C3(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
c_ = int(c2 * e)
|
||||
self.cv1 = Conv(c1, c_, 1, 1)
|
||||
self.cv2 = Conv(c1, c_, 1, 1)
|
||||
self.cv3 = Conv(2 * c_, c2, 1)
|
||||
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) # type: ignore
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
||||
|
||||
class C3k(C3):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
c_ = int(c2 * e)
|
||||
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
|
||||
|
||||
class C3k2(C2f):
|
||||
def __init__(
|
||||
self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
|
||||
):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
self.m = nn.ModuleList(
|
||||
C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
|
||||
)
|
||||
|
||||
class SPPF(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, k: int = 5):
|
||||
super().__init__()
|
||||
c_ = c1 // 2
|
||||
self.cv1 = Conv(c1, c_, 1, 1)
|
||||
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
||||
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = [self.cv1(x)]
|
||||
y.extend(self.m(y[-1]) for _ in range(3))
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.key_dim = int(self.head_dim * attn_ratio)
|
||||
self.scale = self.key_dim**-0.5
|
||||
nh_kd = self.key_dim * num_heads
|
||||
h = dim + nh_kd * 2
|
||||
self.qkv = Conv(dim, h, 1, act=False)
|
||||
self.proj = Conv(dim, dim, 1, act=False)
|
||||
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
N = H * W
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
|
||||
[self.key_dim, self.key_dim, self.head_dim], dim=2
|
||||
)
|
||||
|
||||
attn = (q.transpose(-2, -1) @ k) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
class PSABlock(nn.Module):
|
||||
def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
|
||||
self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
|
||||
self.add = shortcut
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(x) if self.add else self.attn(x)
|
||||
x = x + self.ffn(x) if self.add else self.ffn(x)
|
||||
return x
|
||||
|
||||
class C2PSA(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
assert c1 == c2
|
||||
self.c = int(c1 * e)
|
||||
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
||||
self.cv2 = Conv(2 * self.c, c1, 1)
|
||||
|
||||
self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
a, b = self.cv1(x).split((self.c, self.c), dim=1)
|
||||
b = self.m(b)
|
||||
return self.cv2(torch.cat((a, b), 1))
|
||||
|
||||
class Proto(nn.Module):
|
||||
def __init__(self, c1: int, c_: int = 256, c2: int = 32):
|
||||
super().__init__()
|
||||
self.cv1 = Conv(c1, c_, k=3)
|
||||
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
|
||||
self.cv2 = Conv(c_, c_, k=3)
|
||||
self.cv3 = Conv(c_, c2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
|
||||
|
||||
class BNContrastiveHead(nn.Module):
|
||||
def __init__(self, embed_dims: int):
|
||||
super().__init__()
|
||||
self.norm = nn.BatchNorm2d(embed_dims)
|
||||
self.bias = nn.Parameter(torch.tensor([-10.0]))
|
||||
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
|
||||
|
||||
def fuse(self):
|
||||
del self.norm
|
||||
del self.bias
|
||||
del self.logit_scale
|
||||
self.forward = self.forward_fuse
|
||||
|
||||
def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
x = self.norm(x)
|
||||
w = F.normalize(w, dim=-1, p=2)
|
||||
|
||||
x = torch.einsum("bchw,bkc->bkhw", x, w)
|
||||
return x * self.logit_scale.exp() + self.bias
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, gc: int, ec: int, e: int = 4) -> None:
|
||||
super().__init__()
|
||||
self.w12 = nn.Linear(gc, e * ec)
|
||||
self.w3 = nn.Linear(e * ec // 2, ec)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
hidden = F.silu(x1) * x2
|
||||
return self.w3(hidden)
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, m: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.m = m
|
||||
nn.init.zeros_(self.m.w3.bias)
|
||||
# For models with l scale, please change the initialization to
|
||||
# nn.init.constant_(self.m.w3.weight, 1e-6)
|
||||
nn.init.zeros_(self.m.w3.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.m(x)
|
||||
|
||||
class SAVPE(nn.Module):
|
||||
def __init__(self, ch: List[int], c3: int, embed: int):
|
||||
super().__init__()
|
||||
self.cv1 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
|
||||
)
|
||||
for i, x in enumerate(ch)
|
||||
)
|
||||
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
|
||||
for i, x in enumerate(ch)
|
||||
)
|
||||
|
||||
self.c = 16
|
||||
self.cv3 = nn.Conv2d(3 * c3, embed, 1)
|
||||
self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
|
||||
self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
|
||||
self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))
|
||||
|
||||
def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
|
||||
y = [self.cv2[i](xi) for i, xi in enumerate(x)]
|
||||
y = self.cv4(torch.cat(y, dim=1))
|
||||
|
||||
x = [self.cv1[i](xi) for i, xi in enumerate(x)]
|
||||
x = self.cv3(torch.cat(x, dim=1))
|
||||
|
||||
B, C, H, W = x.shape # type: ignore
|
||||
|
||||
Q = vp.shape[1]
|
||||
|
||||
x = x.view(B, C, -1) # type: ignore
|
||||
|
||||
y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
|
||||
vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
|
||||
|
||||
y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
|
||||
|
||||
y = y.reshape(B, Q, self.c, -1)
|
||||
vp = vp.reshape(B, Q, 1, -1)
|
||||
|
||||
score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
|
||||
score = F.softmax(score, dim=-1).to(y.dtype)
|
||||
aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
|
||||
|
||||
return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
|
||||
|
||||
class Detect(nn.Module):
|
||||
dynamic = False
|
||||
export = False
|
||||
shape = None
|
||||
anchors = torch.empty(0)
|
||||
strides = torch.empty(0)
|
||||
|
||||
def __init__(self, nc=80, ch=()):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
self.nl = len(ch)
|
||||
self.reg_max = 16
|
||||
self.no = nc + self.reg_max * 4
|
||||
self.stride = torch.tensor([8., 16., 32.])
|
||||
|
||||
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))
|
||||
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||
)
|
||||
self.cv3 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, self.nc, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
||||
|
||||
if self.training:
|
||||
return x
|
||||
|
||||
# Inference path
|
||||
shape = x[0].shape
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
|
||||
return torch.cat((dbox, cls.sigmoid()), 1)
|
||||
|
||||
def bias_init(self):
|
||||
m = self
|
||||
for a, b, s in zip(m.cv2, m.cv3, m.stride):
|
||||
a[-1].bias.data[:] = 1.0 # type: ignore
|
||||
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # type: ignore
|
||||
|
||||
def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
|
||||
return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)
|
||||
|
||||
@staticmethod
|
||||
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
|
||||
batch_size, anchors, _ = preds.shape
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
|
||||
boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
|
||||
scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
|
||||
scores, index = scores.flatten(1).topk(min(max_det, anchors))
|
||||
i = torch.arange(batch_size)[..., None]
|
||||
return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
|
||||
|
||||
class YOLOEDetect(Detect):
|
||||
def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
|
||||
super().__init__(nc, ch)
|
||||
c3 = max(ch[0], min(self.nc, 100))
|
||||
assert c3 <= embed
|
||||
self.cv3 = (
|
||||
nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, embed, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
)
|
||||
|
||||
self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)
|
||||
|
||||
self.reprta = Residual(SwiGLUFFN(embed, embed))
|
||||
self.savpe = SAVPE(ch, c3, embed) # type: ignore
|
||||
self.embed = embed
|
||||
|
||||
def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)
|
||||
|
||||
def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
|
||||
if vpe.shape[1] == 0: # no visual prompt embeddings
|
||||
return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
|
||||
if vpe.ndim == 4: # (B, N, H, W)
|
||||
vpe = self.savpe(x, vpe)
|
||||
assert vpe.ndim == 3 # (B, N, D)
|
||||
return vpe
|
||||
|
||||
def forward( # type: ignore
|
||||
self, x: List[torch.Tensor], cls_pe: torch.Tensor
|
||||
) -> Union[torch.Tensor, Tuple]:
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
|
||||
if self.training:
|
||||
return x # type: ignore
|
||||
self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
|
||||
# Inference path
|
||||
shape = x[0].shape
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
|
||||
return torch.cat((dbox, cls.sigmoid()), 1)
|
||||
|
||||
def bias_init(self):
|
||||
m = self
|
||||
for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):
|
||||
a[-1].bias.data[:] = 1.0 # box
|
||||
b[-1].bias.data[:] = 0.0
|
||||
c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
|
||||
|
||||
class YOLOESegment(YOLOEDetect):
|
||||
def __init__(
|
||||
self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = ()
|
||||
):
|
||||
super().__init__(nc, embed, ch)
|
||||
self.nm = nm
|
||||
self.npr = npr
|
||||
self.proto = Proto(ch[0], self.npr, self.nm)
|
||||
|
||||
c5 = max(ch[0] // 4, self.nm)
|
||||
self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
|
||||
|
||||
def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:
|
||||
p = self.proto(x[0]) # mask protos
|
||||
bs = p.shape[0] # batch size
|
||||
|
||||
mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||
x = YOLOEDetect.forward(self, x, text)
|
||||
|
||||
if self.training:
|
||||
return x, mc, p
|
||||
|
||||
return (torch.cat([x, mc], 1), p)
|
||||
|
||||
|
||||
|
||||
class YOLO11(nn.Module):
|
||||
def __init__(self, nc=80, scale='n'):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
|
||||
# Scales: [depth, width, max_channels]
|
||||
# 对应 yolo11.yaml 中的 scales 参数
|
||||
scales = {
|
||||
'n': [0.50, 0.25, 1024],
|
||||
's': [0.50, 0.50, 1024],
|
||||
'm': [0.50, 1.00, 512],
|
||||
'l': [1.00, 1.00, 512],
|
||||
'x': [1.00, 1.50, 512],
|
||||
}
|
||||
|
||||
if scale not in scales:
|
||||
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
|
||||
|
||||
depth, width, max_channels = scales[scale]
|
||||
|
||||
if scale in ['n', 's']:
|
||||
c3k_override = False
|
||||
else:
|
||||
c3k_override = True
|
||||
|
||||
# 辅助函数:计算通道数 (Width Scaling)
|
||||
def gw(channels):
|
||||
return make_divisible(min(channels, max_channels) * width, 8)
|
||||
|
||||
# 辅助函数:计算层重复次数 (Depth Scaling)
|
||||
def gd(n):
|
||||
return max(round(n * depth), 1) if n > 1 else n
|
||||
|
||||
self.model = nn.ModuleList()
|
||||
|
||||
# --- Backbone ---
|
||||
# 0: Conv [64, 3, 2]
|
||||
self.model.append(Conv(3, gw(64), 3, 2))
|
||||
|
||||
# 1: Conv [128, 3, 2]
|
||||
self.model.append(Conv(gw(64), gw(128), 3, 2))
|
||||
|
||||
# 2: C3k2 [256, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 3: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 4: C3k2 [512, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 5: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 6: C3k2 [512, True] -> n=2
|
||||
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
|
||||
|
||||
# 7: Conv [1024, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(1024), 3, 2))
|
||||
|
||||
# 8: C3k2 [1024, True] -> n=2
|
||||
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 9: SPPF [1024, 5]
|
||||
self.model.append(SPPF(gw(1024), gw(1024), 5))
|
||||
|
||||
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
|
||||
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
|
||||
|
||||
# --- Head ---
|
||||
|
||||
# 11: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 12: Concat [-1, 6] (P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
|
||||
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 14: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 15: Concat [-1, 4] (P3)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
|
||||
# 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512)
|
||||
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 17: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 18: Concat [-1, 13] (Head P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
|
||||
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 20: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 21: Concat [-1, 10] (P5)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
|
||||
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 23: Detect [nc]
|
||||
self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))
|
||||
|
||||
# --- 初始化权重 ---
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""初始化模型权重,特别是 Detect 头的 Bias"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# 使用 Kaiming 初始化或其他合适的初始化
|
||||
pass
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
detect_layer = self.model[-1]
|
||||
if isinstance(detect_layer, Detect):
|
||||
detect_layer.bias_init()
|
||||
|
||||
def forward(self, x):
|
||||
# Backbone
|
||||
x = self.model[0](x)
|
||||
x = self.model[1](x)
|
||||
x = self.model[2](x)
|
||||
x = self.model[3](x)
|
||||
p3 = self.model[4](x) # 保存 P3 (layer 4)
|
||||
x = self.model[5](p3)
|
||||
p4 = self.model[6](x) # 保存 P4 (layer 6)
|
||||
x = self.model[7](p4)
|
||||
x = self.model[8](x)
|
||||
x = self.model[9](x)
|
||||
p5 = self.model[10](x) # 保存 P5 (layer 10)
|
||||
|
||||
# Head
|
||||
x = self.model[11](p5) # Upsample
|
||||
x = self.model[12]([x, p4]) # Concat P4
|
||||
h1 = self.model[13](x) # Head P4 (layer 13)
|
||||
|
||||
x = self.model[14](h1) # Upsample
|
||||
x = self.model[15]([x, p3]) # Concat P3
|
||||
h2 = self.model[16](x) # Output P3 (layer 16)
|
||||
|
||||
x = self.model[17](h2) # Conv
|
||||
x = self.model[18]([x, h1]) # Concat Head P4
|
||||
h3 = self.model[19](x) # Output P4 (layer 19)
|
||||
|
||||
x = self.model[20](h3) # Conv
|
||||
x = self.model[21]([x, p5]) # Concat P5
|
||||
h4 = self.model[22](x) # Output P5 (layer 22)
|
||||
|
||||
return self.model[23]([h2, h3, h4]) # Detect
|
||||
|
||||
def load_weights(self, pth_file):
|
||||
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
|
||||
# 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方)
|
||||
# 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv...
|
||||
# 这里做一个简单的兼容性处理
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
# 处理 ultralytics 权重字典中的 'model' 键
|
||||
if k == 'model':
|
||||
# 如果是完整的 checkpoint,权重在 'model' 键下
|
||||
# 且通常是 model.state_dict()
|
||||
if hasattr(v, 'state_dict'):
|
||||
v = v.state_dict()
|
||||
elif isinstance(v, dict):
|
||||
pass # v 就是 state_dict
|
||||
else:
|
||||
# 可能是 model 对象本身
|
||||
try:
|
||||
v = v.float().state_dict()
|
||||
except:
|
||||
continue
|
||||
|
||||
for sub_k, sub_v in v.items():
|
||||
new_state_dict[sub_k] = sub_v
|
||||
break
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
|
||||
if not new_state_dict:
|
||||
new_state_dict = state_dict
|
||||
|
||||
# 尝试加载
|
||||
try:
|
||||
self.load_state_dict(new_state_dict, strict=True)
|
||||
print(f"Successfully loaded weights from {pth_file}")
|
||||
except Exception as e:
|
||||
print(f"Error loading weights: {e}")
|
||||
print("Trying to load with strict=False...")
|
||||
self.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
class YOLO11E(nn.Module):
|
||||
def __init__(self, nc=80, scale='n'):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
self.pe = None
|
||||
|
||||
# Scales: [depth, width, max_channels]
|
||||
# 对应 yolo11.yaml 中的 scales 参数
|
||||
scales = {
|
||||
'n': [0.50, 0.25, 1024],
|
||||
's': [0.50, 0.50, 1024],
|
||||
'm': [0.50, 1.00, 512],
|
||||
'l': [1.00, 1.00, 512],
|
||||
'x': [1.00, 1.50, 512],
|
||||
}
|
||||
|
||||
if scale not in scales:
|
||||
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
|
||||
|
||||
depth, width, max_channels = scales[scale]
|
||||
|
||||
if scale in ['n', 's']:
|
||||
c3k_override = False
|
||||
else:
|
||||
c3k_override = True
|
||||
|
||||
# 辅助函数:计算通道数 (Width Scaling)
|
||||
def gw(channels):
|
||||
return make_divisible(min(channels, max_channels) * width, 8)
|
||||
|
||||
# 辅助函数:计算层重复次数 (Depth Scaling)
|
||||
def gd(n):
|
||||
return max(round(n * depth), 1) if n > 1 else n
|
||||
|
||||
self.model = nn.ModuleList()
|
||||
|
||||
# --- Backbone ---
|
||||
# 0: Conv [64, 3, 2]
|
||||
self.model.append(Conv(3, gw(64), 3, 2))
|
||||
|
||||
# 1: Conv [128, 3, 2]
|
||||
self.model.append(Conv(gw(64), gw(128), 3, 2))
|
||||
|
||||
# 2: C3k2 [256, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 3: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 4: C3k2 [512, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 5: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 6: C3k2 [512, True] -> n=2
|
||||
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
|
||||
|
||||
# 7: Conv [1024, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(1024), 3, 2))
|
||||
|
||||
# 8: C3k2 [1024, True] -> n=2
|
||||
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 9: SPPF [1024, 5]
|
||||
self.model.append(SPPF(gw(1024), gw(1024), 5))
|
||||
|
||||
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
|
||||
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
|
||||
|
||||
# --- Head ---
|
||||
|
||||
# 11: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 12: Concat [-1, 6] (P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
|
||||
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 14: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 15: Concat [-1, 4] (P3)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
|
||||
# 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512)
|
||||
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 17: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 18: Concat [-1, 13] (Head P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
|
||||
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 20: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 21: Concat [-1, 10] (P5)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
|
||||
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 23: Detect [nc]
|
||||
self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)]))
|
||||
|
||||
# --- 初始化权重 ---
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""初始化模型权重,特别是 Detect 头的 Bias"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# 使用 Kaiming 初始化或其他合适的初始化
|
||||
pass
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
detect_layer = self.model[-1]
|
||||
if isinstance(detect_layer, Detect):
|
||||
detect_layer.bias_init()
|
||||
|
||||
def set_classes(self, names: List[str], embeddings: torch.Tensor):
|
||||
assert embeddings.ndim == 3, "Embeddings must be (1, N, D)"
|
||||
self.pe = embeddings
|
||||
self.model[-1].nc = len(names) # type: ignore
|
||||
self.nc = len(names)
|
||||
|
||||
def forward(self, x, tpe=None, vpe=None):
|
||||
# Backbone
|
||||
x = self.model[0](x)
|
||||
x = self.model[1](x)
|
||||
x = self.model[2](x)
|
||||
x = self.model[3](x)
|
||||
p3 = self.model[4](x) # 保存 P3 (layer 4)
|
||||
x = self.model[5](p3)
|
||||
p4 = self.model[6](x) # 保存 P4 (layer 6)
|
||||
x = self.model[7](p4)
|
||||
x = self.model[8](x)
|
||||
x = self.model[9](x)
|
||||
p5 = self.model[10](x) # 保存 P5 (layer 10)
|
||||
|
||||
# Head
|
||||
x = self.model[11](p5) # Upsample
|
||||
x = self.model[12]([x, p4]) # Concat P4
|
||||
h1 = self.model[13](x) # Head P4 (layer 13)
|
||||
|
||||
x = self.model[14](h1) # Upsample
|
||||
x = self.model[15]([x, p3]) # Concat P3
|
||||
h2 = self.model[16](x) # Output P3 (layer 16)
|
||||
|
||||
x = self.model[17](h2) # Conv
|
||||
x = self.model[18]([x, h1]) # Concat Head P4
|
||||
h3 = self.model[19](x) # Output P4 (layer 19)
|
||||
|
||||
x = self.model[20](h3) # Conv
|
||||
x = self.model[21]([x, p5]) # Concat P5
|
||||
h4 = self.model[22](x) # Output P5 (layer 22)
|
||||
|
||||
head = self.model[23]
|
||||
feats = [h2, h3, h4]
|
||||
|
||||
processed_tpe = head.get_tpe(tpe) # type: ignore
|
||||
|
||||
processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None # type: ignore
|
||||
|
||||
all_pe = []
|
||||
if processed_tpe is not None:
|
||||
all_pe.append(processed_tpe)
|
||||
if processed_vpe is not None:
|
||||
all_pe.append(processed_vpe)
|
||||
|
||||
if not all_pe:
|
||||
if self.pe is not None:
|
||||
all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
|
||||
else:
|
||||
all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))
|
||||
|
||||
cls_pe = torch.cat(all_pe, dim=1)
|
||||
|
||||
b = x.shape[0]
|
||||
if cls_pe.shape[0] != b:
|
||||
cls_pe = cls_pe.expand(b, -1, -1)
|
||||
|
||||
return head(feats, cls_pe)
|
||||
|
||||
def load_weights(self, pth_file):
|
||||
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
|
||||
# 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方)
|
||||
# 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv...
|
||||
# 这里做一个简单的兼容性处理
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
# 处理 ultralytics 权重字典中的 'model' 键
|
||||
if k == 'model':
|
||||
# 如果是完整的 checkpoint,权重在 'model' 键下
|
||||
# 且通常是 model.state_dict()
|
||||
if hasattr(v, 'state_dict'):
|
||||
v = v.state_dict()
|
||||
elif isinstance(v, dict):
|
||||
pass # v 就是 state_dict
|
||||
else:
|
||||
# 可能是 model 对象本身
|
||||
try:
|
||||
v = v.float().state_dict()
|
||||
except:
|
||||
continue
|
||||
|
||||
for sub_k, sub_v in v.items():
|
||||
new_state_dict[sub_k] = sub_v
|
||||
break
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
|
||||
if not new_state_dict:
|
||||
new_state_dict = state_dict
|
||||
|
||||
# 尝试加载
|
||||
try:
|
||||
self.load_state_dict(new_state_dict, strict=True)
|
||||
print(f"Successfully loaded weights from {pth_file}")
|
||||
except Exception as e:
|
||||
print(f"Error loading weights: {e}")
|
||||
print("Trying to load with strict=False...")
|
||||
self.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = YOLO11E(nc=80, scale='l')
|
||||
model.load_weights("yoloe-11l-seg.pth")
|
||||
|
||||
# 模拟 set_classes
|
||||
# 假设我们有2个类,embedding维度是512
|
||||
fake_embeddings = torch.randn(1, 2, 512)
|
||||
model.set_classes(["class1", "class2"], fake_embeddings)
|
||||
|
||||
# 推理
|
||||
dummy_input = torch.randn(1, 3, 640, 640)
|
||||
model.eval()
|
||||
output = model(dummy_input)
|
||||
print("Output shape:", output[0].shape) # 应该是 (1, 4+mask_coeffs+num_classes, anchors)
|
||||
BIN
yolo11n.pt
Normal file
BIN
yolo11n.pt
Normal file
Binary file not shown.
BIN
yolo11s.pth
Normal file
BIN
yolo11s.pth
Normal file
Binary file not shown.
Reference in New Issue
Block a user