Compare commits
5 Commits
2b8f25f318...main

| Author | SHA1 | Date |
|---|---|---|
|  | d9f2e8b496 |  |
|  | 11e874a3e5 |  |
|  | 01802e8beb |  |
|  | 58b24abbf0 |  |
|  | 604951f9c2 |  |
3  .gitattributes  vendored  Normal file
@@ -0,0 +1,3 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.ts filter=lfs diff=lfs merge=lfs -text
288  dataset.py  Normal file
@@ -0,0 +1,288 @@
import os
import glob
import cv2
import numpy as np
import torch
import random
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2


class YOLODataset(Dataset):
    def __init__(self, img_dir, label_dir, img_size=640, is_train=True):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.is_train = is_train
        self.use_mosaic = is_train  # new flag; enabled by default during training

        # Support multiple image formats
        self.img_files = sorted(
            glob.glob(os.path.join(img_dir, "*.jpg")) +
            glob.glob(os.path.join(img_dir, "*.png")) +
            glob.glob(os.path.join(img_dir, "*.jpeg"))
        )

        # --- 1. Align the Albumentations settings with Ultralytics ---
        # default.yaml: hsv_h: 0.015, hsv_s: 0.7, hsv_v: 0.4
        # OpenCV Hue range is [0, 179], Sat/Val is [0, 255]
        h_limit = int(0.015 * 179)  # ~2
        s_limit = int(0.7 * 255)    # ~178
        v_limit = int(0.4 * 255)    # ~102

        # default.yaml: translate: 0.1, scale: 0.5 (0.5~1.5), degrees: 0.0

        if is_train:
            self.transform = A.Compose([
                # Geometric augmentation (a final tweak after Mosaic, or handles the non-Mosaic case)
                # Note: the Mosaic output is already a large canvas; this mainly adds the last random perturbation
                A.Affine(
                    scale=(0.5, 1.5),              # scale: 0.5
                    translate_percent=(0.1, 0.1),  # translate: 0.1
                    rotate=(0, 0),                 # degrees: 0.0 (COCO default)
                    shear=(0, 0),                  # shear: 0.0
                    p=0.5
                ),

                # Color augmentation (strictly aligned with default.yaml)
                A.HueSaturationValue(
                    hue_shift_limit=h_limit,
                    sat_shift_limit=s_limit,
                    val_shift_limit=v_limit,
                    p=0.5
                ),

                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),

                # Flip
                A.HorizontalFlip(p=0.5),  # fliplr: 0.5

                # Final processing
                A.Resize(img_size, img_size),  # make sure the final size is consistent
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.0, label_fields=['class_labels']))
        else:
            # Validation: letterbox (keep aspect ratio, pad)
            self.transform = A.Compose([
                A.LongestMaxSize(max_size=img_size),
                A.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT),
                A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255.0),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.1, label_fields=['class_labels']))

    def close_mosaic(self):
        """Disable Mosaic augmentation."""
        self.use_mosaic = False
        print("Mosaic augmentation disabled.")

    def __len__(self):
        return len(self.img_files)

    def load_image(self, index):
        """Load a single image and resize its longer side to img_size."""
        img_path = self.img_files[index]
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        h, w = img.shape[:2]
        r = self.img_size / max(h, w)
        if r != 1:
            img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)

        return img, (h, w), img.shape[:2]  # img, original_hw, resized_hw

    def load_label(self, index, img_shape):
        """Load labels (normalized YOLO format)."""
        img_path = self.img_files[index]
        label_path = self._get_label_path(img_path)
        h, w = img_shape

        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        cls = int(parts[0])
                        bx, by, bw, bh = map(float, parts[1:5])
                        labels.append([cls, bx, by, bw, bh])

        return np.array(labels, dtype=np.float32) if labels else np.zeros((0, 5), dtype=np.float32)

    def load_mosaic(self, index):
        """
        YOLO Mosaic augmentation (stitch 4 images into one).
        """
        labels4 = []
        s = self.img_size
        # Fix: the list comprehension must iterate twice to produce both yc and xc
        yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in [-s // 2, -s // 2]]

        # Pick 3 additional random indices
        indices = [index] + [random.randint(0, len(self.img_files) - 1) for _ in range(3)]
        random.shuffle(indices)

        # Initialize the large canvas (2x size)
        img4 = np.full((s * 2, s * 2, 3), 114, dtype=np.uint8)

        for i, idx in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(idx)

            # Placement: top-left, top-right, bottom-left, bottom-right
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(yc + h, s * 2)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            # Paste the tile
            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # type: ignore
            padw = x1a - x1b  # type: ignore
            padh = y1a - y1b  # type: ignore

            # Process labels
            labels = self.load_label(idx, (h, w))
            if labels.size > 0:
                # Normalized xywh -> Pixel xywh
                labels[:, 1] = labels[:, 1] * w
                labels[:, 2] = labels[:, 2] * h
                labels[:, 3] = labels[:, 3] * w
                labels[:, 4] = labels[:, 4] * h

                # xywh -> xyxy (Pixel)
                xyxy = np.copy(labels)
                xyxy[:, 1] = labels[:, 1] - labels[:, 3] / 2 + padw
                xyxy[:, 2] = labels[:, 2] - labels[:, 4] / 2 + padh
                xyxy[:, 3] = labels[:, 1] + labels[:, 3] / 2 + padw
                xyxy[:, 4] = labels[:, 2] + labels[:, 4] / 2 + padh

                labels4.append(xyxy)

        # Concat labels
        if len(labels4):
            labels4 = np.concatenate(labels4, 0)
            # Clip to mosaic image border
            np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])

            # Convert back to normalized xywh (relative to the 2*s canvas)
            # Albumentations expects normalized xywh
            new_labels = np.copy(labels4)
            w_mosaic, h_mosaic = s * 2, s * 2

            # xyxy -> xywh
            new_labels[:, 1] = (labels4[:, 1] + labels4[:, 3]) / 2 / w_mosaic
            new_labels[:, 2] = (labels4[:, 2] + labels4[:, 4]) / 2 / h_mosaic
            new_labels[:, 3] = (labels4[:, 3] - labels4[:, 1]) / w_mosaic
            new_labels[:, 4] = (labels4[:, 4] - labels4[:, 2]) / h_mosaic

            return img4, new_labels
        else:
            return img4, np.zeros((0, 5))

    def __getitem__(self, index):
        try:
            if self.is_train:
                # Check both is_train and use_mosaic
                if self.use_mosaic and random.random() < 1.0:
                    img, labels = self.load_mosaic(index)
                else:
                    img, _, _ = self.load_image(index)
                    h, w = img.shape[:2]
                    labels = self.load_label(index, (h, w))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # type: ignore
            else:
                img = cv2.imread(self.img_files[index])
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # type: ignore
                h, w = img.shape[:2]
                labels = self.load_label(index, (h, w))

            # --- Fix start: sanitize bounding boxes ---
            # labels format: [cls, x, y, w, h] (normalized)

            valid_bboxes = []
            valid_labels = []

            h_img, w_img = img.shape[:2]

            for i in range(len(labels)):
                cls = labels[i, 0]
                x, y, w, h = labels[i, 1:]

                # 1. Clamp to [0, 1] (handles out-of-bounds boxes produced by Mosaic cropping)
                x1 = np.clip(x - w / 2, 0, 1)
                y1 = np.clip(y - h / 2, 0, 1)
                x2 = np.clip(x + w / 2, 0, 1)
                y2 = np.clip(y + h / 2, 0, 1)

                # 2. Recompute width/height
                w_new = x2 - x1
                h_new = y2 - y1

                # 3. Recompute the center point
                x_new = x1 + w_new / 2
                y_new = y1 + h_new / 2

                # 4. Drop tiny boxes (e.g. smaller than 2 pixels)
                if w_new * w_img > 2 and h_new * h_img > 2:
                    valid_bboxes.append([x_new, y_new, w_new, h_new])
                    valid_labels.append(cls)

            if len(valid_bboxes) == 0:
                # If every box in this image was filtered out, try the next image
                return self.__getitem__((index + 1) % len(self))

            bboxes = valid_bboxes
            class_labels = valid_labels
            # --- Fix end ---

            # Apply augmentations
            transformed = self.transform(image=img, bboxes=bboxes, class_labels=class_labels)
            image = transformed['image']
            bboxes = transformed['bboxes']
            class_labels = transformed['class_labels']

            # Build the target tensor
            n = len(bboxes)
            targets = torch.zeros((n, 6))
            if n > 0:
                targets[:, 1] = torch.tensor(class_labels)
                targets[:, 2:] = torch.tensor(bboxes)

            return image, targets, self.img_files[index]

        except Exception as e:
            # Print more detailed error info for debugging, but do not interrupt training
            # print(f"Error loading data {index}: {e}")
            return self.__getitem__((index + 1) % len(self))

    def _get_label_path(self, img_path):
        filename = os.path.basename(img_path).rsplit('.', 1)[0] + ".txt"
        return os.path.join(self.label_dir, filename)

    @staticmethod
    def collate_fn(batch):
        imgs, targets, paths = zip(*batch)
        imgs = torch.stack(imgs, 0)
        new_targets = []
        for i, t in enumerate(targets):
            if t.shape[0] > 0:
                t[:, 0] = i
                new_targets.append(t)
        if new_targets:
            targets = torch.cat(new_targets, 0)
        else:
            targets = torch.zeros((0, 6))
        return imgs, targets, paths
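A minimal sketch of how this dataset class might be wired into a `DataLoader`; the directory paths and batch size below are placeholders, not part of the diff:

```python
# Hypothetical usage of YOLODataset (paths and batch size are placeholders).
from torch.utils.data import DataLoader

from dataset import YOLODataset

train_set = YOLODataset("data/images/train", "data/labels/train", img_size=640, is_train=True)
train_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    collate_fn=YOLODataset.collate_fn,  # writes the batch index into column 0 and concatenates targets
)

for imgs, targets, paths in train_loader:
    # imgs: (B, 3, 640, 640); targets: (N, 6) rows of [batch_idx, cls, x, y, w, h], normalized
    break

train_set.close_mosaic()  # typically called for the last few epochs of training
```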
189  inference.py  Normal file
@@ -0,0 +1,189 @@
import torch
import cv2
import numpy as np
import torchvision
from yolo11_standalone import YOLO11

# COCO 80-class names
CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
]

# Random colors for drawing
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    """
    Resize and pad an image to the target size while keeping the aspect ratio.
    """
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

    # Compute padding
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    dw, dh = dw / 2, dh / 2  # divide padding into 2 sides

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    # Add border
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, r, (dw, dh)

def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2]"""
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
    """
    Non-maximum suppression (NMS).
    prediction: [Batch, 84, Anchors]
    """
    # 1. Transpose: [Batch, 84, Anchors] -> [Batch, Anchors, 84]
    prediction = prediction.transpose(1, 2)

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 4  # number of classes

    # Fix: use max(-1) to take the max confidence over the last (class) dimension;
    # the previous max(1) wrongly operated over the anchors dimension
    xc = prediction[..., 4:].max(-1)[0] > conf_thres  # candidates

    output = [torch.zeros((0, 6), device=prediction.device)] * bs

    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence filtering

        if not x.shape[0]:
            continue

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Confidence and Class
        conf, j = x[:, 4:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]
        if not n:
            continue
        elif n > max_det:
            x = x[x[:, 4].argsort(descending=True)[:max_det]]

        # Batched NMS
        c = x[:, 5:6] * 7680  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        output[xi] = x[i]

    return output

def main():
    # 1. Initialize the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = YOLO11(nc=80, scale='s')
    # Load the previously converted plain weights
    model.load_weights("yolo11s.pth")
    model.to(device)
    model.eval()
    # model.train()

    # 2. Read the image
    img_path = "1.jpg"  # replace with a local image path

    img0 = cv2.imread(img_path)
    assert img0 is not None, f"Image Not Found {img_path}"

    # 3. Preprocessing
    # Letterbox resize
    img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))

    # BGR to RGB, HWC to CHW
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)

    img_tensor = torch.from_numpy(img).to(device)
    img_tensor = img_tensor.float()
    img_tensor /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)

    # 4. Inference
    print("Running inference...")
    with torch.no_grad():
        pred = model(img_tensor)

    # 5. Post-processing (NMS)
    pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)

    # 6. Draw results
    det = pred[0]  # only the first image

    if len(det):
        # Map coordinates back to the original image size
        # det[:, :4] is x1, y1, x2, y2
        det[:, [0, 2]] -= dw  # x padding
        det[:, [1, 3]] -= dh  # y padding
        det[:, :4] /= ratio

        # Clip coordinates so they stay inside the image
        det[:, 0].clamp_(0, img0.shape[1])
        det[:, 1].clamp_(0, img0.shape[0])
        det[:, 2].clamp_(0, img0.shape[1])
        det[:, 3].clamp_(0, img0.shape[0])

        print(f"Detected {len(det)} objects")

        for *xyxy, conf, cls in det:
            c = int(cls)
            label = f'{CLASSES[c]} {conf:.2f}'
            p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))

            # Draw the box
            color = COLORS[c]
            cv2.rectangle(img0, p1, p2, color, 2, lineType=cv2.LINE_AA)

            # Draw the label background
            t_size = cv2.getTextSize(label, 0, fontScale=0.5, thickness=1)[0]
            p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
            cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)

            # Draw the text
            cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)

            print(f" - {label} at {p1}-{p2}")

    # 7. Show/save the result
    cv2.imwrite("result.jpg", img0)
    print("Result saved to result.jpg")

def import_os_exists(path):
    import os
    return os.path.exists(path)

if __name__ == "__main__":
    main()
168  inference_yoloe.py  Normal file
@@ -0,0 +1,168 @@
import torch
import cv2
import numpy as np
import torchvision
from pathlib import Path

# Import your own modules
from yolo11_standalone import YOLO11E
from mobile_clip_standalone import MobileCLIP

from ultralytics import YOLOE


# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
YOLO_WEIGHTS = "yoloe-11l-seg.pth"  # replace with your YOLO weights path
CLIP_WEIGHTS = "mobileclip_blt.ts"  # replace with your MobileCLIP weights path
CLIP_SIZE = "blt"  # matching MobileCLIP size
IMAGE_PATH = "1.jpg"  # image to run detection on

# Custom detection classes (open vocabulary)
CUSTOM_CLASSES = ["girl", "red balloon"]

# Drawing colors
COLORS = np.random.uniform(0, 255, size=(len(CUSTOM_CLASSES), 3))

# --- Helper functions (letterbox, NMS, etc.) ---
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114)):
    shape = im.shape[:2]
    if isinstance(new_shape, int): new_shape = (new_shape, new_shape)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
    dw, dh = dw / 2, dh / 2
    if shape[::-1] != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, r, (dw, dh)

def xywh2xyxy(x):
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.7, max_det=300):
    prediction = prediction.transpose(1, 2)
    bs = prediction.shape[0]
    xc = prediction[..., 4:].max(-1)[0] > conf_thres
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):
        x = x[xc[xi]]
        if not x.shape[0]: continue
        box = xywh2xyxy(x[:, :4])
        conf, j = x[:, 4:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
        n = x.shape[0]
        if not n: continue
        elif n > max_det: x = x[x[:, 4].argsort(descending=True)[:max_det]]
        c = x[:, 5:6] * 7680
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        output[xi] = x[i]
    return output

def main():
    print(f"Using device: {DEVICE}")

    # 1. Load the MobileCLIP text encoder
    print(f"Loading MobileCLIP from {CLIP_WEIGHTS}...")
    if not Path(CLIP_WEIGHTS).exists():
        raise FileNotFoundError(f"MobileCLIP weights not found: {CLIP_WEIGHTS}")

    clip_model = MobileCLIP(checkpoint=CLIP_WEIGHTS, size=CLIP_SIZE, device=DEVICE)

    # 2. Generate text embeddings
    print(f"Encoding classes: {CUSTOM_CLASSES}")
    prompts = [f"{c}" for c in CUSTOM_CLASSES]

    tokens = clip_model.tokenize(prompts)
    text_embeddings = clip_model.encode_text(tokens)  # Shape: (N, 512)

    # Reshape to (1, N, 512) to match the YOLO11E input
    text_embeddings = text_embeddings.unsqueeze(0)

    # 3. Load the YOLO11E detection model
    print(f"Loading YOLO11E from {YOLO_WEIGHTS}...")
    if not Path(YOLO_WEIGHTS).exists():
        raise FileNotFoundError(f"YOLO weights not found: {YOLO_WEIGHTS}")

    # Note: scale='l' must match your weight file (s, m, l, x)
    yolo_model = YOLO11E(nc=80, scale='l')
    yolo_model.load_weights(YOLO_WEIGHTS)
    yolo_model.to(DEVICE)  # move to device (half precision could be applied here)
    yolo_model.eval()

    head = yolo_model.model[-1]

    with torch.no_grad():
        text_pe = head.get_tpe(text_embeddings)  # type: ignore

    # 4. Bind the custom classes to the detection head
    yolo_model.set_classes(CUSTOM_CLASSES, text_pe)

    # 5. Image preprocessing
    img0 = cv2.imread(IMAGE_PATH)
    assert img0 is not None, f"Image Not Found {IMAGE_PATH}"
    img, ratio, (dw, dh) = letterbox(img0, new_shape=(640, 640))
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img)
    img_tensor = torch.from_numpy(img).to(DEVICE)
    img_tensor = img_tensor.float()
    img_tensor /= 255.0
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)

    # 6. Inference
    print("Running inference...")
    with torch.no_grad():
        pred = yolo_model(img_tensor)
        if isinstance(pred, tuple):
            pred = pred[0]
        nc = len(CUSTOM_CLASSES)
        pred = pred[:, :4 + nc, :]

    # 7. Post-processing (NMS)
    pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.7)

    print(pred)

    # 8. Visualization
    det = pred[0]
    if len(det):
        det[:, [0, 2]] -= dw
        det[:, [1, 3]] -= dh
        det[:, :4] /= ratio
        det[:, 0].clamp_(0, img0.shape[1])
        det[:, 1].clamp_(0, img0.shape[0])
        det[:, 2].clamp_(0, img0.shape[1])
        det[:, 3].clamp_(0, img0.shape[0])

        print(f"Detected {len(det)} objects:")
        for *xyxy, conf, cls in det:
            c = int(cls)
            class_name = CUSTOM_CLASSES[c] if c < len(CUSTOM_CLASSES) else str(c)
            label = f'{class_name} {conf:.2f}'

            p1, p2 = (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3]))
            color = COLORS[c]

            cv2.rectangle(img0, p1, p2, color, 2, cv2.LINE_AA)
            t_size = cv2.getTextSize(label, 0, 0.5, 1)[0]
            p2_label = p1[0] + t_size[0], p1[1] - t_size[1] - 3
            cv2.rectangle(img0, p1, p2_label, color, -1, cv2.LINE_AA)
            cv2.putText(img0, label, (p1[0], p1[1] - 2), 0, 0.5, [255, 255, 255], 1, cv2.LINE_AA)
            print(f" - {label}")
    else:
        print("No objects detected.")

    output_path = "result_full.jpg"
    cv2.imwrite(output_path, img0)
    print(f"Result saved to {output_path}")

if __name__ == "__main__":
    main()
448  loss.py  Normal file
@@ -0,0 +1,448 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# ==============================================================================
# 1. Basic utility functions (Utils)
# ==============================================================================

def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
    Args:
        box1 (torch.Tensor): predicted, shape (N, 4)
        box2 (torch.Tensor): target, shape (N, 4)
        xywh (bool): If True, input boxes are (x, y, w, h). If False, (x1, y1, x2, y2).
        CIoU (bool): If True, calculate Complete IoU.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU
    return iou

def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        h, w = feats[i].shape[2:]
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat([c_xy, wh], dim)
    return torch.cat((x1y1, x2y2), dim)

def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)


# ==============================================================================
# 2. Task-Aligned Assigner
# ==============================================================================

class TaskAlignedAssigner(nn.Module):
    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]

        if self.n_max_boxes == 0:
            return (
                torch.full_like(pd_scores[..., 0], self.num_classes),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0]),
                torch.zeros_like(pd_scores[..., 0]),
            )

        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        mask_pos = mask_topk * mask_in_gts * mask_gt
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)
        ind[1] = gt_labels.squeeze(-1)

        # Clamp labels to handle case where no GT exists or background index issues
        ind[1] = ind[1].clamp(max=self.num_classes - 1)

        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def select_topk_candidates(self, metrics, topk_mask=None):
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        topk_idxs.masked_fill_(~topk_mask, 0)

        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        count_tensor.masked_fill_(count_tensor > 1, 0)
        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes
        target_labels = gt_labels.long().flatten()[target_gt_idx]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
        target_labels.clamp_(0)

        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)
        return target_gt_idx, fg_mask, mask_pos

# ==============================================================================
# 3. Loss modules (from loss.py)
# ==============================================================================

class DFLoss(nn.Module):
    """Distribution Focal Loss (DFL)."""
    def __init__(self, reg_max: int = 16) -> None:
        super().__init__()
        self.reg_max = reg_max

    def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = target.clamp_(0, self.reg_max - 1 - 0.01)
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)

class BboxLoss(nn.Module):
    """Criterion for computing training losses for bounding boxes (IoU + DFL)."""
    def __init__(self, reg_max: int = 16):
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

class YOLOv8DetectionLoss:
    """
    Standalone YOLOv8/v11 detection loss.
    Refactored to assume specific inputs and remove the reliance on a full Model object.
    """
    def __init__(self, nc=80, reg_max=16, stride=None, hyp=None):
        """
        Args:
            nc (int): Number of classes.
            reg_max (int): DFL channels (default 16).
            stride (list/torch.Tensor): Model strides (e.g., [8, 16, 32]).
            hyp (dict): Hyperparameters (box, cls, dfl gains).
        """
        if stride is None:
            stride = [8, 16, 32]  # Default strides
        if hyp is None:
            # Default YOLOv8 hyperparameters
            hyp = {'box': 7.5, 'cls': 0.5, 'dfl': 1.5}

        self.nc = nc
        self.reg_max = reg_max
        self.stride = stride
        self.hyp = hyp
        self.no = nc + reg_max * 4  # Output channels per anchor
        self.use_dfl = reg_max > 1

        # Components
        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(reg_max)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.proj = torch.arange(reg_max, dtype=torch.float)  # Will move to device in method

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses targets: converts [idx, cls, xywh] to internal format."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 5, device=targets.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 5, device=targets.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    out[j, :n] = targets[matches, 1:]  # cls, x, y, w, h
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Project distributions to scalars using self.proj (0, 1, 2, ..., reg_max-1)
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.to(pred_dist.device).type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """
        Args:
            preds: List of prediction tensors [B, C, H, W] for each stride level.
                C should be reg_max * 4 + nc.
            batch: Dict containing:
                'batch_idx': Tensor [N], image index for each target
                'cls': Tensor [N], class index for each target
                'bboxes': Tensor [N, 4], normalized xywh format
        Returns:
            total_loss: scalar tensor (sum of box, cls and dfl losses), scaled by batch_size
                to match the Ultralytics convention (loss * bs).
            loss_items: Tensor used for logging [box, cls, dfl]
        """
        # Ensure proj is on correct device
        self.device = preds[0].device
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl

        feats = preds

        # Concat features from different strides: (B, no, Total_Anchors)
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()  # (B, Anchors, nc)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()  # (B, Anchors, reg_max * 4)

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]

        # Calculate image size from features and first stride (assuming square output implies inputs)
        # Ultralytics uses the provided imgsz, but we can infer it roughly or pass it.
        # Here assuming feature map H * stride[0] = image H
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]

        # Generate anchors
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Process Targets
        if "batch_idx" not in batch:
            # Fallback if simplified batch input passed (targets assumed [N, 6] -> idx, cls, x, y, w, h)
            # This handles cases where the user passes a raw target tensor instead of a dict
            raise ValueError("Batch dict must contain 'batch_idx', 'cls', and 'bboxes'")

        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        # Check if Normalized (xywh)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Decode Boxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        # Task Aligned Assignment
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # CLS Loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

        # BBox Loss (IoU + DFL)
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        # Apply Gains
        loss[0] *= self.hyp['box']
        loss[1] *= self.hyp['cls']
        loss[2] *= self.hyp['dfl']

        return loss.sum() * batch_size, loss.detach()  # scaling by batch_size to match Ultralytics behavior

# ==============================================================================
# Usage Example
# ==============================================================================
if __name__ == "__main__":
    # Configuration
    num_classes = 80
    reg_max = 16
    strides = [8, 16, 32]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the loss
    criterion = YOLOv8DetectionLoss(nc=num_classes, reg_max=reg_max, stride=strides)

    # Simulated input (batch size = 2)
    bs = 2
    # Simulated model output: feature maps at 3 levels
    # [BS, Channels, H, W], Channels = reg_max * 4 + num_classes
    channels = reg_max * 4 + num_classes
    preds = [
        torch.randn(bs, channels, 80, 80, device=device, requires_grad=True),  # stride 8
        torch.randn(bs, channels, 40, 40, device=device, requires_grad=True),  # stride 16
        torch.randn(bs, channels, 20, 20, device=device, requires_grad=True),  # stride 32
    ]

    # Simulated ground truth
    # batch: idx, cls, x, y, w, h (normalized 0-1)
    target_batch = {
        "batch_idx": torch.tensor([0, 0, 1], device=device),  # Image indices
        "cls": torch.tensor([1, 5, 10], device=device),  # Class indices
        "bboxes": torch.tensor([  # Normalized xywh
            [0.5, 0.5, 0.2, 0.3],
            [0.1, 0.1, 0.1, 0.1],
            [0.8, 0.8, 0.2, 0.2]
        ], device=device)
    }

    # Compute the loss
    total_loss, loss_items = criterion(preds, target_batch)

    print(f"Total Loss: {total_loss.item()}")
    print(f"Loss Items (Box, Cls, DFL): {loss_items}")

    # Backward pass test
    total_loss.backward()
    print("Backward pass successful.")
148  metrics.py  Normal file
@@ -0,0 +1,148 @@
import torch
import numpy as np
import torchvision

def xywh2xyxy(x):
    """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right"""
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

def box_iou(box1, box2):
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    box1: [N, 4]
    box2: [M, 4]
    Returns: [N, M]
    """
    def box_area(box):
        return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])

    area1 = box_area(box1)
    area2 = box_area(box2)

    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter
    return inter / (union + 1e-6)

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
    """
    Non-Maximum Suppression (NMS) on inference results
    prediction: [Batch, 84, 8400] (for YOLOv8/11)
    """
    # [Batch, 84, Anchors] -> [Batch, Anchors, 84]
    prediction = prediction.transpose(1, 2)

    bs = prediction.shape[0]
    nc = prediction.shape[2] - 4
    xc = prediction[..., 4:].max(-1)[0] > conf_thres

    output = [torch.zeros((0, 6), device=prediction.device)] * bs

    for xi, x in enumerate(prediction):
        x = x[xc[xi]]
        if not x.shape[0]:
            continue

        # Box decoding
        box = xywh2xyxy(x[:, :4])

        # Confidence and Class
        conf, j = x[:, 4:].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        n = x.shape[0]
        if not n:
            continue
        elif n > max_det:
            x = x[x[:, 4].argsort(descending=True)[:max_det]]

        # Batched NMS
        c = x[:, 5:6] * 7680
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        output[xi] = x[i]

    return output

def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves """
    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec

def smooth(y, f=0.05):
    """Box filter of fraction f"""
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = np.ones(nf // 2)  # ones padding
    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed

def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
    """ Compute the average precision, given the recall and precision curves. """
    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round().astype(int)
    fp = (tp / (p + eps) - tp).astype(int)

    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
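A minimal sketch of how `ap_per_class` above is typically fed; the per-image matching that fills the `tp` matrix is assumed (not part of this diff), and the numbers below are placeholders:

```python
# Hypothetical evaluation sketch; the GT-to-prediction matching that fills `tp` is assumed.
import numpy as np

from metrics import ap_per_class

iou_thresholds = np.linspace(0.5, 0.95, 10)  # mAP@0.5:0.95

# One row per prediction: True at every IoU threshold where it matched a ground-truth box.
tp = np.zeros((5, len(iou_thresholds)), dtype=bool)
tp[:3, :5] = True                              # e.g. 3 predictions correct up to IoU 0.7
conf = np.array([0.9, 0.8, 0.7, 0.4, 0.2])     # prediction confidences
pred_cls = np.array([0, 0, 1, 1, 0])           # predicted class indices
target_cls = np.array([0, 0, 1, 1])            # ground-truth class indices

tp_c, fp_c, p, r, f1, ap, classes = ap_per_class(tp, conf, pred_cls, target_cls)
map50, map5095 = ap[:, 0].mean(), ap.mean()
print(f"mAP@0.5={map50:.3f}  mAP@0.5:0.95={map5095:.3f}")
```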
117  mobile_clip_standalone.py  Normal file
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
from pathlib import Path
from typing import List, Union
import mobileclip

class TextModel(nn.Module):
    """TextModel base class defining the interface."""
    def __init__(self):
        super().__init__()

    def tokenize(self, texts):
        raise NotImplementedError

    def encode_text(self, texts, dtype):
        raise NotImplementedError

class MobileCLIP(TextModel):
    """
    MobileCLIP text encoder.
    """
    config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}

    def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
        """
        Initialize the MobileCLIP text encoder.

        Args:
            checkpoint (str): Path to the model weights (.pt or .ts).
            size (str): Model size identifier ('s0', 's1', 's2', 'b', 'blt').
            device (torch.device): Device on which to load the model.
        """
        super().__init__()

        if isinstance(device, str):
            device = torch.device(device)

        if not Path(checkpoint).exists():
            raise FileNotFoundError(f"Weight file not found: {checkpoint}")

        if size not in self.config_size_map:
            raise ValueError(f"Unsupported size: {size}. Options: {list(self.config_size_map.keys())}")

        config = self.config_size_map[size]

        # 1. Load the tokenizer
        self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")

        # 2. Load the model
        if str(checkpoint).endswith('.ts'):
            # TorchScript format (e.g. mobileclip_blt.ts)
            print(f"Loading TorchScript model from {checkpoint}...")
            self.model = torch.jit.load(checkpoint, map_location=device)
            self.is_torchscript = True
        else:
            # PyTorch format (.pt)
            print(f"Loading PyTorch model from {checkpoint}...")
            self.model = mobileclip.create_model_and_transforms(
                f"mobileclip_{config}",
                pretrained=checkpoint,
                device=device
            )[0]
            self.is_torchscript = False

        self.to(device)
        self.device = device
        self.eval()

    def tokenize(self, texts: List[str]) -> torch.Tensor:
        """
        Convert texts to tokens.
        """
        return self.tokenizer(texts).to(self.device)

    def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Encode tokenized text and L2-normalize the features.
        """
        with torch.no_grad():
            if self.is_torchscript:
                return self.model(texts).to(dtype)

            text_features = self.model.encode_text(texts).to(dtype)  # type: ignore
            text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
            return text_features

# --- Usage example ---
if __name__ == "__main__":
    # 1. Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Local model path
    checkpoint_path = "mobileclip_blt.ts"

    try:
        if Path(checkpoint_path).exists():
            # 2. Initialize the model (local path and matching size)
            # Note: blt corresponds to size="blt"
            model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)

            # 3. Prepare the texts
            input_texts = ["a photo of a cat", "a photo of a dog"]

            # 4. Tokenize
            tokens = model.tokenize(input_texts)
            print(f"Tokens shape: {tokens.shape}")

            # 5. Encode
            features = model.encode_text(tokens)
            print(f"Features shape: {features.shape}")
            print("Success!")
        else:
            print(f"Weight file does not exist: {checkpoint_path}")

    except Exception as e:
        print(f"An error occurred: {e}")
98  mobileclip/__init__.py  Normal file
@@ -0,0 +1,98 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

from __future__ import annotations

import json
import os
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Resize,
    ToTensor,
)

from mobileclip.clip import CLIP
from mobileclip.modules.common.mobileone import reparameterize_model
from mobileclip.modules.text.tokenizer import (
    ClipTokenizer,
)


def create_model_and_transforms(
    model_name: str,
    pretrained: str | None = None,
    reparameterize: bool | None = True,
    device: str | torch.device = "cpu",
) -> tuple[nn.Module, Any, Any]:
    """Method to instantiate model and pre-processing transforms necessary for inference.

    Args:
        model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
        pretrained: Location of pretrained checkpoint.
        reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
        device: Device identifier for model placement.

    Returns:
        Tuple of instantiated model, and preprocessing transforms for inference.
    """
    # Config files
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Get config from yaml file
    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    model_cfg = json.load(open(model_cfg_file))

    # Build preprocessing transforms for inference
    resolution = model_cfg["image_cfg"]["image_size"]
    resize_size = resolution
    centercrop_size = resolution
    aug_list = [
        Resize(
            resize_size,
            interpolation=InterpolationMode.BILINEAR,
        ),
        CenterCrop(centercrop_size),
        ToTensor(),
    ]
    preprocess = Compose(aug_list)

    # Build model
    model = CLIP(cfg=model_cfg)
    model.to(device)
    model.eval()

    # Load checkpoint
    if pretrained is not None:
        chkpt = torch.load(pretrained)
        model.load_state_dict(chkpt)

    # Reparameterize model for inference (if specified)
    if reparameterize:
        model = reparameterize_model(model)

    return model, None, preprocess


def get_tokenizer(model_name: str) -> nn.Module:
    # Config files
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Get config from yaml file
    model_cfg = json.load(open(model_cfg_file))

    # Build tokenizer
    text_tokenizer = ClipTokenizer(model_cfg)
    return text_tokenizer
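A short sketch of the package-level API defined above, loading a non-TorchScript checkpoint; the checkpoint path is a placeholder and not part of the diff:

```python
# Hypothetical usage of the mobileclip package helpers (checkpoint path is a placeholder).
import torch
import mobileclip

tokenizer = mobileclip.get_tokenizer("mobileclip_b")
model, _, preprocess = mobileclip.create_model_and_transforms(
    "mobileclip_b", pretrained="mobileclip_b.pt", device="cpu"
)

tokens = tokenizer(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    text_features = model.encode_text(tokens, normalize=True)  # (2, embed_dim)
print(text_features.shape)
```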
69
mobileclip/clip.py
Normal file
69
mobileclip/clip.py
Normal file
@@ -0,0 +1,69 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
"""Model schema in open_clip format for inference only."""

from __future__ import annotations

import math
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

from mobileclip.text_encoder import (
    TextTransformer,
)

from .image_encoder import MCi


class CLIP(nn.Module):
    """Base class for multi-modal image-text data."""

    def __init__(self, cfg: dict, output_dict: bool = False, *args, **kwargs) -> None:
        super().__init__()
        self.output_dict = output_dict
        self.projection_dim = cfg["embed_dim"]
        if self.projection_dim is None:
            raise ValueError("Please specify `embed_dim` in model config.")

        self.image_encoder = MCi(
            model_name=cfg["image_cfg"]["model_name"],
            projection_dim=self.projection_dim,
        )
        self.text_encoder = TextTransformer(cfg=cfg["text_cfg"], projection_dim=self.projection_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))

    def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
        scale = self.logit_scale.exp()
        scale = torch.clamp(scale, 0, max_scale)
        return scale

    def encode_image(self, image: torch.Tensor, normalize: bool = False):
        image_encoder_out = self.image_encoder(image)
        if isinstance(image_encoder_out, dict):
            features = image_encoder_out["logits"]
        else:
            features = image_encoder_out
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text: torch.Tensor, normalize: bool = False):
        text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
        return F.normalize(text_features, dim=-1) if normalize else text_features

    def forward(self, image: torch.Tensor | None = None, text: torch.Tensor | None = None, *args, **kwargs) -> Any:
        image_embeddings = self.encode_image(image, normalize=True) if image is not None else None
        text_embeddings = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            return {
                "image_features": image_embeddings,
                "text_features": text_embeddings,
                "logit_scale": self._exponentiate_and_clip_logits(),
            }
        return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
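Illustrative sketch of the two return formats of CLIP.forward() (assumes one of the config JSONs added below is available at the given path):

import json
import torch
from mobileclip.clip import CLIP

with open("mobileclip/configs/mobileclip_s0.json") as f:
    cfg = json.load(f)

model = CLIP(cfg=cfg, output_dict=True).eval()
image = torch.randn(1, 3, cfg["image_cfg"]["image_size"], cfg["image_cfg"]["image_size"])
text = torch.randint(0, cfg["text_cfg"]["vocab_size"], (1, cfg["text_cfg"]["context_length"]))
with torch.no_grad():
    out = model(image=image, text=text)
# out["image_features"] and out["text_features"] are L2-normalized (1, 512) embeddings;
# out["logit_scale"] is exp(logit_scale) clamped to at most 100. With output_dict=False the same
# three values come back as a tuple.
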
18  mobileclip/configs/mobileclip_b.json  Normal file
@@ -0,0 +1,18 @@
{
  "embed_dim": 512,
  "image_cfg": {
    "image_size": 224,
    "model_name": "vit_b16"
  },
  "text_cfg": {
    "context_length": 77,
    "vocab_size": 49408,
    "dim": 512,
    "ffn_multiplier_per_layer": 4.0,
    "n_heads_per_layer": 8,
    "n_transformer_layers": 12,
    "norm_layer": "layer_norm_fp32",
    "causal_masking": true,
    "model_name": "base"
  }
}
18  mobileclip/configs/mobileclip_s0.json  Normal file
@@ -0,0 +1,18 @@
{
  "embed_dim": 512,
  "image_cfg": {
    "image_size": 256,
    "model_name": "mci0"
  },
  "text_cfg": {
    "context_length": 77,
    "vocab_size": 49408,
    "dim": 512,
    "ffn_multiplier_per_layer": 4.0,
    "n_heads_per_layer": 8,
    "n_transformer_layers": 4,
    "norm_layer": "layer_norm_fp32",
    "causal_masking": false,
    "model_name": "mct"
  }
}
18  mobileclip/configs/mobileclip_s1.json  Normal file
@@ -0,0 +1,18 @@
{
  "embed_dim": 512,
  "image_cfg": {
    "image_size": 256,
    "model_name": "mci1"
  },
  "text_cfg": {
    "context_length": 77,
    "vocab_size": 49408,
    "dim": 512,
    "ffn_multiplier_per_layer": 4.0,
    "n_heads_per_layer": 8,
    "n_transformer_layers": 12,
    "norm_layer": "layer_norm_fp32",
    "causal_masking": false,
    "model_name": "base"
  }
}
18  mobileclip/configs/mobileclip_s2.json  Normal file
@@ -0,0 +1,18 @@
{
  "embed_dim": 512,
  "image_cfg": {
    "image_size": 256,
    "model_name": "mci2"
  },
  "text_cfg": {
    "context_length": 77,
    "vocab_size": 49408,
    "dim": 512,
    "ffn_multiplier_per_layer": 4.0,
    "n_heads_per_layer": 8,
    "n_transformer_layers": 12,
    "norm_layer": "layer_norm_fp32",
    "causal_masking": false,
    "model_name": "base"
  }
}
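The four configs differ mainly in the image backbone, input resolution, and text-transformer depth. A quick way to compare them (illustrative; assumes the configs directory sits at the path shown):

import json
import pathlib

cfg_dir = pathlib.Path("mobileclip/configs")
for name in ("mobileclip_s0", "mobileclip_s1", "mobileclip_s2", "mobileclip_b"):
    cfg = json.loads((cfg_dir / f"{name}.json").read_text())
    print(name, cfg["image_cfg"]["model_name"], cfg["image_cfg"]["image_size"], cfg["text_cfg"]["n_transformer_layers"])
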
63  mobileclip/image_encoder.py  Normal file
@@ -0,0 +1,63 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Any

import torch.nn as nn
from timm.models import create_model

from mobileclip import models  # noqa added to register models
from mobileclip.modules.image.image_projection import GlobalPool2D


class MCi(nn.Module):
    """This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_."""

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        super().__init__()
        self.projection_dim = None
        if "projection_dim" in kwargs:
            self.projection_dim = kwargs.get("projection_dim")

        # Create model
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        # Build out projection head.
        if self.projection_dim is not None:
            if hasattr(self.model, "head"):
                self.model.head = MCi._update_image_classifier(
                    image_classifier=self.model.head, projection_dim=self.projection_dim
                )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """A forward function of the model."""
        x = self.model(x)
        return x

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input feature dimension to the image classification head."""
        in_features = None
        if isinstance(image_classifier, nn.Sequential):
            # Classifier that uses nn.Sequential usually has global pooling and
            # multiple linear layers. Find the first linear layer and get its
            # in_features
            for layer in image_classifier:
                if isinstance(layer, nn.Linear):
                    in_features = layer.in_features
                    break
        elif isinstance(image_classifier, nn.Linear):
            in_features = image_classifier.in_features

        if in_features is None:
            raise NotImplementedError(f"Cannot get input feature dimension of {image_classifier}.")
        return in_features

    @staticmethod
    def _update_image_classifier(image_classifier: nn.Module, projection_dim: int, *args, **kwargs) -> nn.Module:
        in_features = MCi._get_in_feature_dimension(image_classifier)
        new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
        return new_img_classifier
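Illustrative sketch of instantiating the image tower on its own (assumes the mci0 variant registered in mobileclip/models/mci.py below and a 256x256 input):

import torch
from mobileclip.image_encoder import MCi

encoder = MCi(model_name="mci0", projection_dim=512).eval()
with torch.no_grad():
    features = encoder(torch.randn(1, 3, 256, 256))
print(features.shape)  # expected (1, 512) after the GlobalPool2D projection head
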
120  mobileclip/logger.py  Normal file
@@ -0,0 +1,120 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
text_colors = {
|
||||
"logs": "\033[34m", # 033 is the escape code and 34 is the color code
|
||||
"info": "\033[32m",
|
||||
"warning": "\033[33m",
|
||||
"debug": "\033[93m",
|
||||
"error": "\033[31m",
|
||||
"bold": "\033[1m",
|
||||
"end_color": "\033[0m",
|
||||
"light_red": "\033[36m",
|
||||
}
|
||||
|
||||
|
||||
def get_curr_time_stamp() -> str:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def error(message: str) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
error_str = text_colors["error"] + text_colors["bold"] + "ERROR " + text_colors["end_color"]
|
||||
|
||||
# exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss).
|
||||
# For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE).
|
||||
# This allows us to handle specific exceptions in the tests.
|
||||
|
||||
# print("{} - {} - {}".format(time_stamp, error_str, message), flush=True)
|
||||
# print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True)
|
||||
# exit(-1)
|
||||
|
||||
if sys.exc_info()[0] is None:
|
||||
traceback.print_stack()
|
||||
else:
|
||||
traceback.print_exc()
|
||||
sys.exit(f"{time_stamp} - {error_str} - {message}. Exiting!!!")
|
||||
|
||||
|
||||
def color_text(in_text: str) -> str:
|
||||
return text_colors["light_red"] + in_text + text_colors["end_color"]
|
||||
|
||||
|
||||
def log(message: str, end="\n") -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
log_str = text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {log_str} - {message}", end=end)
|
||||
|
||||
|
||||
def warning(message: str | Warning) -> None:
|
||||
if isinstance(message, Warning):
|
||||
message = f"{type(message).__name__}({','.join(map(repr, message.args))})"
|
||||
|
||||
time_stamp = get_curr_time_stamp()
|
||||
warn_str = text_colors["warning"] + text_colors["bold"] + "WARNING" + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {warn_str} - {message}")
|
||||
|
||||
|
||||
def ignore_exception_with_warning(message: str) -> None:
|
"""After catching a tolerable exception E1 (e.g. when Model.forward() fails during profiling with try-catch), it'll
|
||||
be helpful to log the exception for future investigation. But printing the error stack trace, as is, could be
|
||||
confusing when an uncaught (non-tolerable) exception "E2" raises down the road. Then, the log will contain two
|
||||
stack traces for E1, E2. When looking for errors in logs, users should look for E2, but they may find E1.
|
||||
|
||||
This function appends "(WARNING)" at the end of all lines of the E1 traceback, so that the user can distinguish E1
|
||||
from uncaught exception E2.
|
||||
|
||||
Args:
|
||||
message: Extra explanation and context for debugging. (Note: the exception obj will be automatically fetched
|
||||
from python. No need to pass it as an argument or as message)
|
||||
"""
|
||||
warning(f"{message}:\n{traceback.format_exc()}".replace("\n", "\n(WARNING)"))
|
||||
|
||||
|
||||
def info(message: str, print_line: bool | None = False) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
info_str = text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {info_str} - {message}")
|
||||
if print_line:
|
||||
double_dash_line(dashes=150)
|
||||
|
||||
|
||||
def debug(message: str) -> None:
|
||||
time_stamp = get_curr_time_stamp()
|
||||
log_str = text_colors["debug"] + text_colors["bold"] + "DEBUG " + text_colors["end_color"]
|
||||
print(f"{time_stamp} - {log_str} - {message}")
|
||||
|
||||
|
||||
def double_dash_line(dashes: int | None = 75) -> None:
|
||||
print(text_colors["error"] + "=" * dashes + text_colors["end_color"])
|
||||
|
||||
|
||||
def singe_dash_line(dashes: int | None = 67) -> None:
|
||||
print("-" * dashes)
|
||||
|
||||
|
||||
def print_header(header: str) -> None:
|
||||
double_dash_line()
|
||||
print(text_colors["info"] + text_colors["bold"] + "=" * 50 + str(header) + text_colors["end_color"])
|
||||
double_dash_line()
|
||||
|
||||
|
||||
def print_header_minor(header: str) -> None:
|
||||
print(text_colors["warning"] + text_colors["bold"] + "=" * 25 + str(header) + text_colors["end_color"])
|
||||
|
||||
|
||||
def disable_printing():
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
|
||||
def enable_printing():
|
||||
sys.stdout = sys.__stdout__
|
||||
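Small usage sketch of the logging helpers above (illustrative):

from mobileclip import logger

logger.info("Loading checkpoint", print_line=True)
logger.warning("Falling back to CPU inference")
try:
    risky_profiling_step()  # hypothetical callable that may raise
except Exception:
    # Logs the current traceback with every line prefixed by "(WARNING)" instead of re-raising.
    logger.ignore_exception_with_warning("Profiling step failed")
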
12  mobileclip/models/__init__.py  Normal file
@@ -0,0 +1,12 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
#
from .mci import (
    mci0,
    mci1,
    mci2,
)
from .vit import vit_b16
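Because mci0/mci1/mci2 and vit_b16 are registered with timm (via the @register_model decorators in mci.py and vit.py), importing this package also makes them available through timm directly. Illustrative sketch:

import timm
import mobileclip.models  # noqa: F401  # importing runs the timm registrations

backbone = timm.create_model("mci2")  # FastViT trunk with its default 1000-class head
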
888  mobileclip/models/mci.py  Normal file
@@ -0,0 +1,888 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models import register_model
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from mobileclip.modules.common.mobileone import MobileOneBlock
|
||||
from mobileclip.modules.image.replknet import ReparamLargeKernelConv
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 256, 256),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.95,
|
||||
"interpolation": "bicubic",
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
"classifier": "head",
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
"fastvit_t": _cfg(crop_pct=0.9),
|
||||
"fastvit_s": _cfg(crop_pct=0.9),
|
||||
"fastvit_m": _cfg(crop_pct=0.95),
|
||||
}
|
||||
|
||||
|
||||
def convolutional_stem(in_channels: int, out_channels: int, inference_mode: bool = False) -> nn.Sequential:
|
||||
"""Build convolutional stem with MobileOne blocks.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
|
||||
Returns:
|
||||
nn.Sequential object with stem elements.
|
||||
"""
|
||||
return nn.Sequential(
|
||||
MobileOneBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
"""Multi-headed Self Attention module.
|
||||
|
||||
Source modified from:
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build MHSA module that can handle 3D or 4D input tensors.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
head_dim: Number of hidden dimensions per head. Default: ``32``
|
||||
qkv_bias: Use bias or not. Default: ``False``
|
||||
attn_drop: Dropout rate for attention tensor.
|
||||
proj_drop: Dropout rate for projection tensor.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shape = x.shape
|
||||
B, C, H, W = shape
|
||||
N = H * W
|
||||
if len(shape) == 4:
|
||||
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# trick here to make q@k.t more stable
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
if len(shape) == 4:
|
||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Convolutional patch embedding layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
stride: int,
|
||||
in_channels: int,
|
||||
embed_dim: int,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
) -> None:
|
||||
"""Build patch embedding layer.
|
||||
|
||||
Args:
|
||||
patch_size: Patch size for embedding computation.
|
||||
stride: Stride for convolutional embedding layer.
|
||||
in_channels: Number of channels of input tensor.
|
||||
embed_dim: Number of embedding dimensions.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
use_se: If ``True`` SE block will be used.
|
||||
"""
|
||||
super().__init__()
|
||||
block = list()
|
||||
block.append(
|
||||
ReparamLargeKernelConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
groups=in_channels,
|
||||
small_kernel=3,
|
||||
inference_mode=inference_mode,
|
||||
use_se=use_se,
|
||||
)
|
||||
)
|
||||
block.append(
|
||||
MobileOneBlock(
|
||||
in_channels=embed_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
)
|
||||
self.proj = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepMixer(nn.Module):
|
||||
"""Reparameterizable token mixer.
|
||||
|
||||
For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
|
||||
Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Module.
|
||||
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.kernel_size = kernel_size
|
||||
self.inference_mode = inference_mode
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
use_scale_branch=False,
|
||||
num_conv_branches=0,
|
||||
)
|
||||
self.mixer = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
|
||||
if self.inference_mode:
|
||||
return
|
||||
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
|
||||
else:
|
||||
w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w
|
||||
self.reparam_conv.bias.data = b
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("mixer")
|
||||
self.__delattr__("norm")
|
||||
if self.use_layer_scale:
|
||||
self.__delattr__("layer_scale")
|
||||
|
||||
|
||||
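# Illustrative note (comment only, with assumptions): after training, RepMixer.reparameterize()
# folds the mixer/norm branches into a single depthwise conv, and in eval mode the two forms
# produce the same output. A quick equivalence check, assuming this module and torch are imported:
#
#     m = RepMixer(dim=64, kernel_size=3).eval()
#     x = torch.randn(1, 64, 32, 32)
#     y_train = m(x)
#     m.reparameterize()
#     y_infer = m(x)
#     assert torch.allclose(y_train, y_infer, atol=1e-4)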
class ConvFFN(nn.Module):
|
||||
"""Convolutional FFN Module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int | None = None,
|
||||
out_channels: int | None = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build convolutional FFN module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
hidden_channels: Number of channels after expansion. Default: None
|
||||
out_channels: Number of output channels. Default: None
|
||||
act_layer: Activation layer. Default: ``GELU``
|
||||
drop: Dropout rate. Default: ``0.0``.
|
||||
"""
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
if isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepCPE(nn.Module):
|
||||
"""Implementation of conditional positional encoding.
|
||||
|
||||
For more details refer to paper: `Conditional Positional Encodings for Vision Transformers
|
||||
<https://arxiv.org/pdf/2102.10882.pdf>`_
|
||||
|
||||
In our implementation, we can reparameterize this module to eliminate a skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
embed_dim: int = 768,
|
||||
spatial_shape: int | tuple[int, int] = (7, 7),
|
||||
inference_mode=False,
|
||||
) -> None:
|
||||
"""Build reparameterizable conditional positional encoding.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
embed_dim: Number of embedding dimensions. Default: 768
|
||||
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
if isinstance(spatial_shape, int):
|
||||
spatial_shape = tuple([spatial_shape] * 2)
|
||||
assert isinstance(spatial_shape, tuple), (
|
||||
f'"spatial_shape" must be a sequence or int, got {type(spatial_shape)} instead.'
|
||||
)
|
||||
assert len(spatial_shape) == 2, f'Length of "spatial_shape" should be 2, got {len(spatial_shape)} instead.'
|
||||
|
||||
self.spatial_shape = spatial_shape
|
||||
self.embed_dim = embed_dim
|
||||
self.in_channels = in_channels
|
||||
self.groups = embed_dim
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.pe = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
spatial_shape,
|
||||
1,
|
||||
int(spatial_shape[0] // 2),
|
||||
bias=True,
|
||||
groups=embed_dim,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
x = self.pe(x) + x
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
# Build equivalent Id tensor
|
||||
input_dim = self.in_channels // self.groups
|
||||
kernel_value = torch.zeros(
|
||||
(
|
||||
self.in_channels,
|
||||
input_dim,
|
||||
self.spatial_shape[0],
|
||||
self.spatial_shape[1],
|
||||
),
|
||||
dtype=self.pe.weight.dtype,
|
||||
device=self.pe.weight.device,
|
||||
)
|
||||
for i in range(self.in_channels):
|
||||
kernel_value[
|
||||
i,
|
||||
i % input_dim,
|
||||
self.spatial_shape[0] // 2,
|
||||
self.spatial_shape[1] // 2,
|
||||
] = 1
|
||||
id_tensor = kernel_value
|
||||
|
||||
# Reparameterize Id tensor and conv
|
||||
w_final = id_tensor + self.pe.weight
|
||||
b_final = self.pe.bias
|
||||
|
||||
# Introduce reparam conv
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w_final
|
||||
self.reparam_conv.bias.data = b_final
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("pe")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
"""Implementation of Metaformer block with RepMixer as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
kernel_size: Kernel size for repmixer. Default: 3
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Implementation of metaformer block with MHSA as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
):
|
||||
"""Build Attention Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.norm = norm_layer(dim)
|
||||
self.token_mixer = MHSA(dim=dim)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
|
||||
else:
|
||||
x = x + self.drop_path(self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
def basic_blocks(
|
||||
dim: int,
|
||||
block_index: int,
|
||||
num_blocks: list[int],
|
||||
token_mixer_type: str,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode=False,
|
||||
) -> nn.Sequential:
|
||||
"""Build FastViT blocks within a stage.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
block_index: block index.
|
||||
num_blocks: List containing number of blocks per stage.
|
||||
token_mixer_type: Token mixer type.
|
||||
kernel_size: Kernel size for repmixer.
|
||||
mlp_ratio: MLP expansion ratio.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Drop path rate.
|
||||
use_layer_scale: Flag to turn on layer scale regularization.
|
||||
layer_scale_init_value: Layer scale value at initialization.
|
||||
inference_mode: Flag to instantiate block in inference mode.
|
||||
|
||||
Returns:
|
||||
nn.Sequential object of all the blocks within the stage.
|
||||
"""
|
||||
blocks = []
|
||||
for block_idx in range(num_blocks[block_index]):
|
||||
block_dpr = drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1)
|
||||
if token_mixer_type == "repmixer":
|
||||
blocks.append(
|
||||
RepMixerBlock(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
elif token_mixer_type == "attention":
|
||||
blocks.append(
|
||||
AttentionBlock(
|
||||
dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Token mixer type: {token_mixer_type} not supported")
|
||||
blocks = nn.Sequential(*blocks)
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
class FastViT(nn.Module):
|
||||
"""This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layers,
|
||||
token_mixers: tuple[str, ...],
|
||||
embed_dims=None,
|
||||
mlp_ratios=None,
|
||||
downsamples=None,
|
||||
se_downsamples=None,
|
||||
repmixer_kernel_size=3,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
num_classes=1000,
|
||||
pos_embs=None,
|
||||
down_patch_size=7,
|
||||
down_stride=2,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
init_cfg=None,
|
||||
pretrained=None,
|
||||
cls_ratio=2.0,
|
||||
inference_mode=False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
if pos_embs is None:
|
||||
pos_embs = [None] * len(layers)
|
||||
|
||||
if se_downsamples is None:
|
||||
se_downsamples = [False] * len(layers)
|
||||
|
||||
# Convolutional stem
|
||||
self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
|
||||
|
||||
# Build the main stages of the network architecture
|
||||
network = []
|
||||
for i in range(len(layers)):
|
||||
# Add position embeddings if requested
|
||||
if pos_embs[i] is not None:
|
||||
network.append(pos_embs[i](embed_dims[i], embed_dims[i], inference_mode=inference_mode))
|
||||
stage = basic_blocks(
|
||||
embed_dims[i],
|
||||
i,
|
||||
layers,
|
||||
token_mixer_type=token_mixers[i],
|
||||
kernel_size=repmixer_kernel_size,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
network.append(stage)
|
||||
if i >= len(layers) - 1:
|
||||
break
|
||||
|
||||
# Patch merging/downsampling between stages.
|
||||
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
||||
network.append(
|
||||
PatchEmbed(
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
in_channels=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=se_downsamples[i + 1],
|
||||
)
|
||||
)
|
||||
self.network = nn.ModuleList(network)
|
||||
|
||||
# Classifier head
|
||||
self.conv_exp = MobileOneBlock(
|
||||
in_channels=embed_dims[-1],
|
||||
out_channels=int(embed_dims[-1] * cls_ratio),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=embed_dims[-1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=True,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
self.head = nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.apply(self.cls_init_weights)
|
||||
self.init_cfg = copy.deepcopy(init_cfg)
|
||||
|
||||
def cls_init_weights(self, m: nn.Module) -> None:
|
"""Initialize weights for the classification head."""
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
return x
|
||||
|
||||
def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
# output only the features of last layer for image classification
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# input embedding
|
||||
x = self.forward_embeddings(x)
|
||||
# through backbone
|
||||
x = self.forward_tokens(x)
|
||||
# for image classification
|
||||
x = self.conv_exp(x)
|
||||
cls_out = self.head(x)
|
||||
return cls_out
|
||||
|
||||
|
||||
@register_model
|
||||
def mci0(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi0 model variant."""
|
||||
layers = [2, 6, 10, 2]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci1(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi1 model variant."""
|
||||
layers = [4, 12, 20, 4]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci2(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi2 model variant."""
|
||||
layers = [4, 12, 24, 4]
|
||||
embed_dims = [80, 160, 320, 640]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_m"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
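Quick shape check for the three MCi variants registered above (illustrative; weights are random since pretrained loading is not implemented in this file):

import torch
from mobileclip.models import mci0, mci1, mci2

for build in (mci0, mci1, mci2):
    model = build().eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 256, 256))
    print(build.__name__, logits.shape)  # (1, 1000): default classification head
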
389  mobileclip/models/vit.py  Normal file
@@ -0,0 +1,389 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""
|
||||
Implementation of the following modules is borrowed from ml-cvnets repo:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py.
|
||||
|
||||
Please see ACKNOWLEDGMENTS for license details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from timm.models import register_model
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mobileclip import logger
|
||||
from mobileclip.modules.common.transformer import (
|
||||
PositionalEmbedding,
|
||||
TransformerEncoder,
|
||||
get_normalization_layer,
|
||||
)
|
||||
from mobileclip.modules.image.image_projection import SimpleImageProjectionHead
|
||||
|
||||
|
||||
class ConvNormAct(nn.Module):
|
||||
"""Applies an N-dimensional convolution over an input.
|
||||
|
||||
Args:
|
||||
cfg: Model configuration.
|
||||
in_channels: :math:`C_{in}` from an expected input of size :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
out_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
|
||||
stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
|
||||
dilation: Dilation rate for convolution. An integer, or tuple of length ``N``. Default: ``1``.
|
||||
padding: Padding for convolution. An integer, or tuple of length ``N``. If not specified, padding is
|
||||
automatically computed based on kernel size and dilation range. Default : ``None`` (equivalent to ``[
|
||||
int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
|
||||
groups: Number of groups in convolution. Default: ``1``.
|
||||
bias: Use bias. Default: ``False``.
|
||||
padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular'). Default: ``zeros``.
|
||||
use_norm: Use normalization layer after convolution. Default: ``True``.
|
||||
use_act: Use activation layer after convolution (or convolution and normalization). Default: ``True``.
|
||||
norm_layer: If not None, the provided normalization layer object will be used. Otherwise, a normalization object
|
||||
will be created based on config ``model.normalization.*`` opts.
|
||||
act_layer: If not None, the provided activation function will be used. Otherwise, an activation function will be
|
||||
created based on config ``model.activation.*`` opts.
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
- Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
- For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: dict,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] | None = None,
|
||||
groups: int = 1,
|
||||
bias: bool = False,
|
||||
padding_mode: str = "zeros",
|
||||
use_norm: bool = True,
|
||||
use_act: bool = True,
|
||||
norm_layer: nn.Module | None = None,
|
||||
act_layer: nn.Module | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ndim = 2
|
||||
|
||||
if norm_layer is None and use_norm:
|
||||
norm_type = cfg.get("normalization", "batch_norm")
|
||||
if norm_type == "batch_norm":
|
||||
norm_layer = nn.BatchNorm2d(
|
||||
num_features=out_channels,
|
||||
momentum=cfg.get("momentum", 0.1),
|
||||
)
|
||||
else:
|
||||
norm_layer = get_normalization_layer(num_features=out_channels, norm_type=norm_type)
|
||||
elif norm_layer is not None and use_norm:
|
||||
logger.error(f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided.")
|
||||
|
||||
if act_layer is None and use_act:
|
||||
act_layer = nn.GELU() # Default to GELU
|
||||
elif act_layer is not None and use_act:
|
||||
logger.error(f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided.")
|
||||
|
||||
if use_norm and any(param[0] == "bias" for param in norm_layer.named_parameters()) and bias:
|
||||
assert not bias, "Do not use bias when using normalization layers with bias."
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * self.ndim
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride,) * self.ndim
|
||||
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation,) * self.ndim
|
||||
|
||||
assert isinstance(kernel_size, tuple)
|
||||
assert isinstance(stride, tuple)
|
||||
assert isinstance(dilation, tuple)
|
||||
|
||||
if padding is None:
|
||||
padding = (int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim))
|
||||
|
||||
if in_channels % groups != 0:
|
||||
logger.error(f"Input channels are not divisible by groups. {in_channels}%{groups} != 0 ")
|
||||
if out_channels % groups != 0:
|
||||
logger.error(f"Output channels are not divisible by groups. {out_channels}%{groups} != 0 ")
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, # type: ignore
|
||||
stride=stride, # type: ignore
|
||||
padding=padding,
|
||||
dilation=dilation, # type: ignore
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
self.norm_name = None
|
||||
if use_norm:
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
self.norm_name = norm_layer.__class__.__name__
|
||||
|
||||
self.act_name = None
|
||||
if use_act:
|
||||
block.add_module(name="act", module=act_layer)
|
||||
self.act_name = act_layer.__class__.__name__
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.kernel_size = conv_layer.kernel_size
|
||||
self.bias = bias
|
||||
self.dilation = dilation
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model
|
||||
implementation is inspired from `Early Convolutions Help Transformers See
|
||||
Better <https://arxiv.org/abs/2106.14881>`_.
|
||||
|
||||
.. note::
|
||||
Our implementation is different from the original implementation in the following ways:
|
||||
1. Kernel size is odd.
|
||||
2. Our positional encoding implementation allows us to use ViT with any multiple input scales
|
||||
3. We do not use StochasticDepth
|
||||
4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
image_channels = 3
|
||||
num_classes = cfg.get("n_classes", 1000)
|
||||
|
||||
self.projection_dim = None
|
||||
if "projection_dim" in kwargs:
|
||||
self.projection_dim = kwargs.get("projection_dim")
|
||||
|
||||
kernel_sizes_conv_stem = [4, 2, 2]
|
||||
strides_conv_stem = [4, 2, 2]
|
||||
|
||||
# Typically, in the ImageNet dataset, we use 224x224 as a resolution.
|
||||
# For our ViT implementation, the patch size is 16 (16 = 4 * 2 * 2)
|
||||
# Therefore, total number of embeddings along width and height are (224 / 16)^2
|
||||
num_embeddings = (224 // 16) ** 2
|
||||
|
||||
embed_dim = cfg["embed_dim"]
|
||||
ffn_dim = cfg["embed_dim"] * 4
|
||||
pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
|
||||
n_transformer_layers = cfg["n_transformer_layers"]
|
||||
num_heads = cfg["n_attn_heads"]
|
||||
attn_dropout = cfg.get("attn_dropout", 0.0)
|
||||
dropout = cfg.get("dropout", 0.0)
|
||||
ffn_dropout = cfg.get("ffn_dropout", 0.0)
|
||||
norm_layer = cfg.get("norm_layer", "layer_norm")
|
||||
|
||||
conv_stem_proj_dim = max(32, embed_dim // 4)
|
||||
patch_emb = [
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=image_channels,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[0],
|
||||
stride=strides_conv_stem[0],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[1],
|
||||
stride=strides_conv_stem[1],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[2],
|
||||
stride=strides_conv_stem[2],
|
||||
bias=True,
|
||||
use_norm=False,
|
||||
use_act=False,
|
||||
),
|
||||
]
|
||||
|
||||
self.patch_emb = nn.Sequential(*patch_emb)
|
||||
|
||||
use_cls_token = not cfg.get("no_cls_token", False)
|
||||
stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
|
||||
per_layer_stochastic_drop_rate = [round(x, 3) for x in np.linspace(0, stochastic_dropout, n_transformer_layers)]
|
||||
transformer_blocks = [
|
||||
TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout,
|
||||
transformer_norm_layer=norm_layer,
|
||||
stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
|
||||
self.post_transformer_norm = get_normalization_layer(num_features=embed_dim, norm_type=norm_layer)
|
||||
|
||||
self.transformer = nn.Sequential(*transformer_blocks)
|
||||
|
||||
if self.projection_dim is None:
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
else:
|
||||
self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
|
||||
|
||||
if use_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
|
||||
torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
|
||||
else:
|
||||
self.cls_token = None
|
||||
|
||||
self.pos_embed = PositionalEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embed_dim,
|
||||
padding_idx=None,
|
||||
interpolation_mode="bilinear",
|
||||
)
|
||||
self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)
|
||||
|
||||
def extract_patch_embeddings(self, x: Tensor) -> tuple[Tensor, tuple[int, int]]:
|
||||
# input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
|
||||
patch_emb = self.patch_emb(x)
|
||||
n_h, n_w = patch_emb.shape[-2:]
|
||||
|
||||
# [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
|
||||
patch_emb = patch_emb.flatten(2)
|
||||
# [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
|
||||
patch_emb = patch_emb.transpose(1, 2).contiguous()
|
||||
|
||||
n_patches = patch_emb.shape[1]
|
||||
# we resize the positional encodings dynamically.
|
||||
pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)
|
||||
|
||||
# add positional encodings
|
||||
patch_emb = pos_emb + patch_emb
|
||||
|
||||
# add classification token
|
||||
if self.cls_token is not None:
|
||||
# [1, 1, emb_dim] --> [Batch, 1, emb_dim]
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
# Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
|
||||
patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)
|
||||
|
||||
# dropout
|
||||
patch_emb = self.emb_dropout(patch_emb)
|
||||
return patch_emb, (n_h, n_w)
|
||||
|
||||
def _features_from_transformer(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, tuple[int, int]]:
|
||||
# this function extract patch embeddings and then apply transformer module to learn
|
||||
# inter-patch representations
|
||||
|
||||
# [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
|
||||
# and embed_dim is feature dim
|
||||
x, (n_h, n_w) = self.extract_patch_embeddings(x)
|
||||
|
||||
for layer in self.transformer:
|
||||
x = layer(x)
|
||||
x = self.post_transformer_norm(x)
|
||||
|
||||
return x, (n_h, n_w)
|
||||
|
||||
def extract_features(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor | None]:
|
||||
# The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
|
||||
# and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
|
||||
return_image_embeddings = kwargs.get("return_image_embeddings", False)
|
||||
|
||||
# [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
|
||||
# here, B is batch size, C is input channels
|
||||
# H and W are input height and width
|
||||
# N is the number of pixels (or tokens) after processing input with conv stem and reshaping
|
||||
# We add +1 for cls token (if applicable)
|
||||
# embed_dim --> embedding dimension
|
||||
x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)
|
||||
|
||||
if self.cls_token is not None:
|
||||
# [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
|
||||
cls_embedding, image_embedding = torch.split(x, split_size_or_sections=[1, x.shape[1] - 1], dim=1)
|
||||
cls_embedding = cls_embedding.squeeze(1)
|
||||
else:
|
||||
# [B, N, embed_dim] -> [B, embed_dim]
|
||||
cls_embedding = torch.mean(x, dim=1)
|
||||
# [B, N, embed_dim]
|
||||
image_embedding = x
|
||||
|
||||
if return_image_embeddings:
|
||||
# reshape image embedding to 4-D tensor
|
||||
# [B, N, C] --> [B, C, N]
|
||||
image_embedding = image_embedding.transpose(1, 2).contiguous()
|
||||
image_embedding = image_embedding.reshape(image_embedding.shape[0], -1, n_h, n_w)
|
||||
|
||||
return cls_embedding, image_embedding
|
||||
else:
|
||||
return cls_embedding, None
|
||||
|
||||
def forward_classifier(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor]:
|
||||
cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
|
||||
# classify based on CLS token
|
||||
cls_embedding = self.classifier(cls_embedding)
|
||||
return cls_embedding, image_embedding
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor | dict[str, Tensor]:
|
||||
# In ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
|
||||
# To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
|
||||
if kwargs.get("return_image_embeddings", False):
|
||||
out_dict = dict()
|
||||
prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
|
||||
out_dict.update({"logits": prediction})
|
||||
if image_embedding is not None:
|
||||
out_dict.update({"image_embeddings": image_embedding})
|
||||
return out_dict
|
||||
else:
|
||||
prediction, _ = self.forward_classifier(x, *args, **kwargs)
|
||||
return prediction
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_b16(pretrained=False, **kwargs):
|
||||
# Vision transformer config
|
||||
cfg = {
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"act_layer": "gelu",
|
||||
"embed_dim": 768,
|
||||
"n_transformer_layers": 12,
|
||||
"n_attn_heads": 12,
|
||||
}
|
||||
model = VisionTransformer(cfg=cfg, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
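# Usage sketch (illustrative, not part of the original file; `num_classes` and the other
# constructor kwargs are assumptions, since the full VisionTransformer config is defined
# earlier in this file). `images` is assumed to be a [B, 3, H, W] tensor.
#   model = vit_b16(num_classes=1000)
#   logits = model(images)                                # Tensor [B, num_classes]
#   out = model(images, return_image_embeddings=True)     # dict with both outputs
#   out["logits"], out["image_embeddings"]                # [B, num_classes], [B, C, H/16, W/16]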
|
||||
6
mobileclip/modules/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
6
mobileclip/modules/common/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
330
mobileclip/modules/common/mobileone.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ["MobileOneBlock", "reparameterize_model"]
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""Squeeze and Excite module.
|
||||
|
||||
Pytorch implementation of `Squeeze-and-Excitation Networks` - https://arxiv.org/pdf/1709.01507.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
|
||||
"""Construct a Squeeze and Excite Module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
rd_ratio: Input channel reduction ratio.
|
||||
"""
|
||||
super().__init__()
|
||||
self.reduce = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=int(in_channels * rd_ratio),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True,
|
||||
)
|
||||
self.expand = nn.Conv2d(
|
||||
in_channels=int(in_channels * rd_ratio),
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
_b, c, h, w = inputs.size()
|
||||
x = F.avg_pool2d(inputs, kernel_size=[h, w])
|
||||
x = self.reduce(x)
|
||||
x = F.relu(x)
|
||||
x = self.expand(x)
|
||||
x = torch.sigmoid(x)
|
||||
x = x.view(-1, c, 1, 1)
|
||||
return inputs * x
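# Usage note (illustrative, not part of the original file): the SE block re-weights
# channels with a learned per-channel gate and preserves the input shape, e.g.
#   se = SEBlock(in_channels=64)
#   se(torch.randn(2, 64, 32, 32)).shape  # torch.Size([2, 64, 32, 32])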
|
||||
|
||||
|
||||
class MobileOneBlock(nn.Module):
|
||||
"""MobileOne building block.
|
||||
|
||||
This block has a multi-branched architecture at train-time and a plain-CNN style architecture at inference time. For
|
||||
more details, please refer to our paper: `An Improved One millisecond Mobile Backbone` -
|
||||
https://arxiv.org/pdf/2206.04040.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
use_act: bool = True,
|
||||
use_scale_branch: bool = True,
|
||||
num_conv_branches: int = 1,
|
||||
activation: nn.Module = nn.GELU(),
|
||||
) -> None:
|
||||
"""Construct a MobileOneBlock module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of channels in the input.
|
||||
out_channels: Number of channels produced by the block.
|
||||
kernel_size: Size of the convolution kernel.
|
||||
stride: Stride size.
|
||||
padding: Zero-padding size.
|
||||
dilation: Kernel dilation factor.
|
||||
groups: Group number.
|
||||
inference_mode: If True, instantiates model in inference mode.
|
||||
use_se: Whether to use SE-ReLU activations.
|
||||
use_act: Whether to use activation. Default: ``True``
|
||||
use_scale_branch: Whether to use scale branch. Default: ``True``
|
||||
num_conv_branches: Number of linear conv branches.
|
||||
"""
|
||||
super().__init__()
|
||||
self.inference_mode = inference_mode
|
||||
self.groups = groups
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_conv_branches = num_conv_branches
|
||||
|
||||
# Check if SE-ReLU is requested
|
||||
if use_se:
|
||||
self.se = SEBlock(out_channels)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
if use_act:
|
||||
self.activation = activation
|
||||
else:
|
||||
self.activation = nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
# Re-parameterizable skip connection
|
||||
self.rbr_skip = (
|
||||
nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
|
||||
)
|
||||
|
||||
# Re-parameterizable conv branches
|
||||
if num_conv_branches > 0:
|
||||
rbr_conv = list()
|
||||
for _ in range(self.num_conv_branches):
|
||||
rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))
|
||||
self.rbr_conv = nn.ModuleList(rbr_conv)
|
||||
else:
|
||||
self.rbr_conv = None
|
||||
|
||||
# Re-parameterizable scale branch
|
||||
self.rbr_scale = None
|
||||
if not isinstance(kernel_size, int):
|
||||
kernel_size = kernel_size[0]
|
||||
if (kernel_size > 1) and use_scale_branch:
|
||||
self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
# Inference mode forward pass.
|
||||
if self.inference_mode:
|
||||
return self.activation(self.se(self.reparam_conv(x)))
|
||||
|
||||
# Multi-branched train-time forward pass.
|
||||
# Skip branch output
|
||||
identity_out = 0
|
||||
if self.rbr_skip is not None:
|
||||
identity_out = self.rbr_skip(x)
|
||||
|
||||
# Scale branch output
|
||||
scale_out = 0
|
||||
if self.rbr_scale is not None:
|
||||
scale_out = self.rbr_scale(x)
|
||||
|
||||
# Other branches
|
||||
out = scale_out + identity_out
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
out += self.rbr_conv[ix](x)
|
||||
|
||||
return self.activation(self.se(out))
|
||||
|
||||
def reparameterize(self):
|
||||
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf.
|
||||
We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like
|
||||
structure for inference.
|
||||
"""
|
||||
if self.inference_mode:
|
||||
return
|
||||
kernel, bias = self._get_kernel_bias()
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = kernel
|
||||
self.reparam_conv.bias.data = bias
|
||||
|
||||
# Delete un-used branches
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("rbr_conv")
|
||||
self.__delattr__("rbr_scale")
|
||||
if hasattr(self, "rbr_skip"):
|
||||
self.__delattr__("rbr_skip")
|
||||
|
||||
self.inference_mode = True
|
||||
|
||||
def _get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to obtain re-parameterized kernel and bias. Reference:
|
||||
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing branches.
|
||||
"""
|
||||
# get weights and bias of scale branch
|
||||
kernel_scale = 0
|
||||
bias_scale = 0
|
||||
if self.rbr_scale is not None:
|
||||
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
|
||||
# Pad scale branch kernel to match conv branch kernel size.
|
||||
pad = self.kernel_size // 2
|
||||
kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
|
||||
|
||||
# get weights and bias of skip branch
|
||||
kernel_identity = 0
|
||||
bias_identity = 0
|
||||
if self.rbr_skip is not None:
|
||||
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
|
||||
|
||||
# get weights and bias of conv branches
|
||||
kernel_conv = 0
|
||||
bias_conv = 0
|
||||
if self.rbr_conv is not None:
|
||||
for ix in range(self.num_conv_branches):
|
||||
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
|
||||
kernel_conv += _kernel
|
||||
bias_conv += _bias
|
||||
|
||||
kernel_final = kernel_conv + kernel_scale + kernel_identity
|
||||
bias_final = bias_conv + bias_scale + bias_identity
|
||||
return kernel_final, bias_final
|
||||
|
||||
def _fuse_bn_tensor(self, branch: nn.Sequential | nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with preceeding conv layer. Reference:
|
||||
https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95.
|
||||
|
||||
Args:
|
||||
branch: Sequence of ops to be fused.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing batchnorm.
|
||||
"""
|
||||
if isinstance(branch, nn.Sequential):
|
||||
kernel = branch.conv.weight
|
||||
running_mean = branch.bn.running_mean
|
||||
running_var = branch.bn.running_var
|
||||
gamma = branch.bn.weight
|
||||
beta = branch.bn.bias
|
||||
eps = branch.bn.eps
|
||||
else:
|
||||
assert isinstance(branch, nn.BatchNorm2d)
|
||||
if not hasattr(self, "id_tensor"):
|
||||
input_dim = self.in_channels // self.groups
|
||||
|
||||
kernel_size = self.kernel_size
|
||||
if isinstance(self.kernel_size, int):
|
||||
kernel_size = (self.kernel_size, self.kernel_size)
|
||||
|
||||
kernel_value = torch.zeros(
|
||||
(self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
|
||||
dtype=branch.weight.dtype,
|
||||
device=branch.weight.device,
|
||||
)
|
||||
for i in range(self.in_channels):
|
||||
kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
|
||||
self.id_tensor = kernel_value
|
||||
kernel = self.id_tensor
|
||||
running_mean = branch.running_mean
|
||||
running_var = branch.running_var
|
||||
gamma = branch.weight
|
||||
beta = branch.bias
|
||||
eps = branch.eps
|
||||
std = (running_var + eps).sqrt()
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
|
||||
"""Helper method to construct conv-batchnorm layers.
|
||||
|
||||
Args:
|
||||
kernel_size: Size of the convolution kernel.
|
||||
padding: Zero-padding size.
|
||||
|
||||
Returns:
|
||||
Conv-BN module.
|
||||
"""
|
||||
mod_list = nn.Sequential()
|
||||
mod_list.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
padding=padding,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
|
||||
return mod_list
|
||||
|
||||
|
||||
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
|
||||
"""Method returns a model where a multi-branched structure used in training is re-parameterized into a single branch
|
||||
for inference.
|
||||
|
||||
Args:
|
||||
model: MobileOne model in train mode.
|
||||
|
||||
Returns:
|
||||
MobileOne model in inference mode.
|
||||
"""
|
||||
# Avoid editing original graph
|
||||
model = copy.deepcopy(model)
|
||||
for module in model.modules():
|
||||
if hasattr(module, "reparameterize"):
|
||||
module.reparameterize()
|
||||
return model
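# Minimal sanity check (illustrative, not part of the original file): after structural
# re-parameterization, the fused single-branch block should match the multi-branch
# training-time block numerically (in eval mode, so BatchNorm uses running statistics).
if __name__ == "__main__":
    torch.manual_seed(0)
    block = MobileOneBlock(
        in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1, num_conv_branches=2
    ).eval()
    x = torch.randn(1, 8, 16, 16)
    with torch.no_grad():
        y_train = block(x)
        y_fused = reparameterize_model(block)(x)
    print("max abs diff:", (y_train - y_fused).abs().max().item())  # expected ~1e-6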
|
||||
410
mobileclip/modules/common/transformer.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""
|
||||
Implementation of the following modules is borrowed from ml-cvnets repo:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py.
|
||||
|
||||
Please see ACKNOWLEDGMENTS for license details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import Size, Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from mobileclip import logger
|
||||
|
||||
|
||||
class LayerNormFP32(nn.LayerNorm):
|
||||
"""Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a input tensor with FP32 precision."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: int | list[int] | Size,
|
||||
eps: float | None = 1e-5,
|
||||
elementwise_affine: bool | None = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
normalized_shape=normalized_shape,
|
||||
eps=eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Convert input from dtype X to FP32 and perform normalization operation.
|
||||
# This may help with underflow/overflow issues that we typically see with normalization layers
|
||||
inp_dtype = x.dtype
|
||||
return super().forward(x.to(torch.float32)).to(inp_dtype)
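# Usage note (illustrative, not part of the original file): normalization runs in float32
# but the output keeps the input dtype, which helps avoid underflow/overflow with half
# precision activations:
#   ln = LayerNormFP32(256)
#   y = ln(torch.randn(2, 16, 256).half())
#   y.dtype  # torch.float16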
|
||||
|
||||
|
||||
def get_normalization_layer(norm_type, num_features):
|
||||
if norm_type == "layer_norm":
|
||||
return nn.LayerNorm(num_features)
|
||||
elif norm_type == "layer_norm_fp32":
|
||||
return LayerNormFP32(num_features)
|
||||
else:
|
||||
raise NotImplementedError(f"Option: {norm_type} not supported.")
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int | None = None,
|
||||
is_learnable: bool | None = False,
|
||||
interpolation_mode: str | None = "bilinear",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# Add other pos embedding here and logic to choose between them
|
||||
module = LearnablePositionalEmbedding
|
||||
|
||||
self.pos_embed = module(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
interpolation_mode=interpolation_mode,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
|
||||
return self.pos_embed(seq_len, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return self.pos_embed.__repr__()
|
||||
|
||||
|
||||
class LearnablePositionalEmbedding(nn.Module):
|
||||
"""Learnable Positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int | None = None,
|
||||
interpolation_mode: str | None = "bilinear",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_idx = padding_idx
|
||||
self.interpolation_mode = interpolation_mode
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.pos_embed[:, :, self.padding_idx, ...] = 0.0
|
||||
|
||||
def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
|
||||
# scale pos embedding
|
||||
pos_embed = self.pos_embed
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
pos_embed[:, :, self.padding_idx, ...] = 0.0
|
||||
|
||||
if seq_len != self.num_embeddings:
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed,
|
||||
size=(seq_len, self.embedding_dim),
|
||||
mode=self.interpolation_mode,
|
||||
)
|
||||
|
||||
# Input is of the form [Batch, Seq_len, Embedding_dim]
|
||||
return pos_embed.reshape(1, seq_len, self.embedding_dim)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, padding_idx={self.padding_idx})"
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""This layer applies a multi-head self- or cross-attention as described in `Attention is all you need
|
||||
<https://arxiv.org/abs/1706.03762>`_ paper.
|
||||
|
||||
Args:
|
||||
embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
|
||||
num_heads (int): Number of heads in multi-head attention
|
||||
attn_dropout (Optional[float]): Attention dropout. Default: 0.0
|
||||
bias (Optional[bool]): Use bias or not. Default: ``True``
|
||||
|
||||
Notes:
|
||||
- Input:
|
||||
- Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
attn_dropout: float | None = 0.0,
|
||||
bias: bool | None = True,
|
||||
output_dim: int | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if output_dim is None:
|
||||
output_dim = embed_dim
|
||||
super().__init__()
|
||||
if embed_dim % num_heads != 0:
|
||||
logger.error(
|
||||
f"Embedding dim must be divisible by number of heads in {self.__class__.__name__}. Got: embed_dim={embed_dim} and num_heads={num_heads}"
|
||||
)
|
||||
|
||||
self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
|
||||
|
||||
self.attn_dropout = nn.Dropout(p=attn_dropout)
|
||||
self.out_proj = nn.Linear(in_features=embed_dim, out_features=output_dim, bias=bias)
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_heads = num_heads
|
||||
self.embed_dim = embed_dim
|
||||
self.use_separate_proj_weight = embed_dim != output_dim
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(head_dim={self.head_dim}, num_heads={self.num_heads}, attn_dropout={self.attn_dropout.p})"
|
||||
|
||||
def _forward_impl(
|
||||
self,
|
||||
x_q: Tensor,
|
||||
x_kv: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# [N, S, C]
|
||||
b_sz, S_len, _in_channels = x_q.shape
|
||||
|
||||
if x_kv is None:
|
||||
# self-attention
|
||||
# [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
|
||||
qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
|
||||
# [N, S, 3, h, c] --> [N, h, 3, S, c]
|
||||
qkv = qkv.transpose(1, 3).contiguous()
|
||||
|
||||
# [N, h, 3, S, c] --> [N, h, S, c] x 3
|
||||
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
||||
else:
|
||||
T_len = x_kv.shape[1]
|
||||
|
||||
# cross-attention
|
||||
# [N, S, C]
|
||||
query = F.linear(
|
||||
x_q,
|
||||
weight=self.qkv_proj.weight[: self.embed_dim, ...],
|
||||
bias=self.qkv_proj.bias[: self.embed_dim] if self.qkv_proj.bias is not None else None,
|
||||
)
|
||||
# [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
|
||||
query = query.reshape(b_sz, S_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
# [N, T, C] --> [N, T, 2C]
|
||||
kv = F.linear(
|
||||
x_kv,
|
||||
weight=self.qkv_proj.weight[self.embed_dim :, ...],
|
||||
bias=self.qkv_proj.bias[self.embed_dim :] if self.qkv_proj.bias is not None else None,
|
||||
)
|
||||
# [N, T, 2C] --> [N, T, 2, h, c]
|
||||
kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
|
||||
# [N, T, 2, h, c] --> [N, h, 2, T, c]
|
||||
kv = kv.transpose(1, 3).contiguous()
|
||||
key, value = kv[:, :, 0], kv[:, :, 1]
|
||||
|
||||
query = query * self.scaling
|
||||
|
||||
# [N, h, T, c] --> [N, h, c, T]
|
||||
key = key.transpose(-1, -2)
|
||||
|
||||
# QK^T
|
||||
# [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
|
||||
attn = torch.matmul(query, key)
|
||||
|
||||
batch_size, _num_heads, num_src_tokens, num_tgt_tokens = attn.shape
|
||||
if attn_mask is not None:
|
||||
# attn_mask shape should be the same as attn
|
||||
assert list(attn_mask.shape) == [
|
||||
batch_size,
|
||||
num_src_tokens,
|
||||
num_tgt_tokens,
|
||||
], (
|
||||
f"Shape of attention mask should be [{batch_size}, {num_src_tokens}, {num_tgt_tokens}]. Got: {attn_mask.shape}"
|
||||
)
|
||||
# [N, S, T] --> [N, 1, S, T]
|
||||
attn_mask = attn_mask.unsqueeze(1)
|
||||
attn = attn + attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# Do not attend to padding positions
|
||||
# key padding mask size is [N, T]
|
||||
assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
|
||||
batch_size,
|
||||
num_tgt_tokens,
|
||||
], (
|
||||
f"Key_padding_mask should be 2-dimension with shape [{batch_size}, {num_tgt_tokens}]. Got: {key_padding_mask.shape}"
|
||||
)
|
||||
attn = attn.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), # [N, T] --> [N, 1, 1, T]
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
attn_dtype = attn.dtype
|
||||
attn_as_float = self.softmax(attn.float())
|
||||
attn = attn_as_float.to(attn_dtype)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# weighted sum
|
||||
# [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
|
||||
out = torch.matmul(attn, value)
|
||||
|
||||
# [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
|
||||
out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_q: Tensor,
|
||||
x_kv: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# [Batch , Sequence, Hidden_dim]
|
||||
return self._forward_impl(
|
||||
x_q=x_q,
|
||||
x_kv=x_kv,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
)
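# Usage note (illustrative, not part of the original file): pass only `x_q` for
# self-attention, or both `x_q` and `x_kv` for cross-attention; the output keeps the
# query shape, e.g.
#   mha = MultiHeadAttention(embed_dim=64, num_heads=8)
#   mha(torch.randn(2, 10, 64)).shape                         # [2, 10, 64], self-attention
#   mha(torch.randn(2, 10, 64), torch.randn(2, 7, 64)).shape  # [2, 10, 64], cross-attention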
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
"""This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Args:
|
||||
embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
|
||||
ffn_latent_dim: Inner dimension of the FFN.
|
||||
num_heads: Number of heads in multi-head attention. Default: 8.
|
||||
attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
|
||||
dropout: Dropout rate. Default: 0.0.
|
||||
ffn_dropout: Dropout between FFN layers. Default: 0.0.
|
||||
transformer_norm_layer: Normalization layer. Default: layer_norm.
|
||||
stochastic_dropout: Stochastic dropout setting. Default: 0.0.
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
|
||||
and :math:`C_{in}` is input embedding dim
|
||||
- Output: same shape as the input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
ffn_latent_dim: int,
|
||||
num_heads: int | None = 8,
|
||||
attn_dropout: float | None = 0.0,
|
||||
dropout: float | None = 0.0,
|
||||
ffn_dropout: float | None = 0.0,
|
||||
transformer_norm_layer: str | None = "layer_norm",
|
||||
stochastic_dropout: float | None = 0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Build attention layer
|
||||
attn_unit = MultiHeadAttention(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.pre_norm_mha = nn.Sequential(
|
||||
get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim),
|
||||
attn_unit,
|
||||
nn.Dropout(p=dropout),
|
||||
)
|
||||
|
||||
act_name = nn.GELU()
|
||||
self.pre_norm_ffn = nn.Sequential(
|
||||
get_normalization_layer(norm_type=transformer_norm_layer, num_features=embed_dim),
|
||||
nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
|
||||
act_name,
|
||||
nn.Dropout(p=ffn_dropout),
|
||||
nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
|
||||
nn.Dropout(p=dropout),
|
||||
)
|
||||
|
||||
self.drop_path = nn.Identity()
|
||||
if stochastic_dropout > 0.0:
|
||||
if dropout > 0.0:
|
||||
logger.error(
|
||||
"Stochastic dropout and dropout are mutually exclusive. "
|
||||
"Use either of them, but not both."
|
||||
f"Got: {stochastic_dropout} and {dropout}"
|
||||
)
|
||||
self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.ffn_dim = ffn_latent_dim
|
||||
self.ffn_dropout = ffn_dropout
|
||||
self.stochastic_dropout = stochastic_dropout
|
||||
self.std_dropout = dropout
|
||||
self.attn_fn_name = attn_unit.__class__.__name__
|
||||
self.act_fn_name = act_name.__class__.__name__
|
||||
self.norm_type = transformer_norm_layer
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(embed_dim={self.embed_dim}, ffn_dim={self.ffn_dim}, dropout={self.std_dropout}, ffn_dropout={self.ffn_dropout}, stochastic_dropout={self.stochastic_dropout}, attn_fn={self.attn_fn_name}, act_fn={self.act_fn_name}, norm_fn={self.norm_type})"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
x_prev: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
attn_mask: Tensor | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# Multi-head attention
|
||||
res = x
|
||||
x = self.pre_norm_mha[0](x) # norm
|
||||
x = self.pre_norm_mha[1](
|
||||
x_q=x,
|
||||
x_kv=x_prev,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
*args,
|
||||
**kwargs,
|
||||
) # mha
|
||||
|
||||
x = self.drop_path(self.pre_norm_mha[2](x)) # applying stochastic depth
|
||||
x = x + res
|
||||
|
||||
# Feed forward network
|
||||
x = x + self.drop_path(self.pre_norm_ffn(x))
|
||||
return x
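# Minimal shape check (illustrative, not part of the original file; running it assumes the
# `mobileclip` package is importable so the logger import above resolves): the pre-norm
# encoder block preserves the [batch, tokens, embed_dim] shape.
if __name__ == "__main__":
    enc = TransformerEncoder(embed_dim=64, ffn_latent_dim=256, num_heads=8)
    tokens = torch.randn(2, 10, 64)
    assert enc(tokens).shape == tokens.shape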
|
||||
6
mobileclip/modules/image/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
97
mobileclip/modules/image/image_projection.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mobileclip import logger
|
||||
|
||||
|
||||
class GlobalPool(nn.Module):
|
||||
"""This layers applies global pooling over a 4D or 5D input tensor.
|
||||
|
||||
Args:
|
||||
pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
|
||||
keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
|
||||
- Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`
|
||||
"""
|
||||
|
||||
pool_types = ["mean", "rms", "abs"]
|
||||
|
||||
def __init__(self, pool_type: str | None = "mean", keep_dim: bool | None = False, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
if pool_type not in self.pool_types:
|
||||
logger.error(f"Supported pool types are: {self.pool_types}. Got {pool_type}")
|
||||
self.pool_type = pool_type
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
def _global_pool(self, x: Tensor, dims: list):
|
||||
if self.pool_type == "rms": # root mean square
|
||||
x = x**2
|
||||
x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
|
||||
x = x**-0.5
|
||||
elif self.pool_type == "abs": # absolute
|
||||
x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
|
||||
else:
|
||||
# default is mean
|
||||
# same as AdaptiveAvgPool
|
||||
x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if x.dim() == 4:
|
||||
dims = [-2, -1]
|
||||
elif x.dim() == 5:
|
||||
dims = [-3, -2, -1]
|
||||
else:
|
||||
raise NotImplementedError("Currently 2D and 3D global pooling supported")
|
||||
return self._global_pool(x, dims=dims)
|
||||
|
||||
|
||||
class GlobalPool2D(nn.Module):
|
||||
"""This class implements global pooling with linear projection."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
scale = in_dim**-0.5
|
||||
self.pool = GlobalPool(pool_type="mean", keep_dim=False)
|
||||
self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
# x is of shape [batch, in_dim, in_height, in_width]
|
||||
assert x.dim() == 4, f"Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {x.shape}"
|
||||
|
||||
# [batch, in_dim, in_height, in_width] --> [batch, in_dim]
|
||||
x = self.pool(x)
|
||||
# [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
|
||||
x = x @ self.proj
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageProjectionHead(nn.Module):
|
||||
"""This class implements linear projection head."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
scale = in_dim**-0.5
|
||||
self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
|
||||
# x is of shape [batch, in_dim]
|
||||
assert x.dim() == 2, f"Input should be 2-dimensional (Batch x in_dim). Got: {x.shape}"
|
||||
|
||||
# [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
|
||||
x = x @ self.proj
|
||||
return x
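# Minimal shape check (illustrative, not part of the original file; assumes the
# `mobileclip` package is importable for the logger import above): global pooling plus a
# learned projection maps a [B, C, H, W] feature map to a fixed-size embedding.
if __name__ == "__main__":
    head = GlobalPool2D(in_dim=512, out_dim=256)
    emb = head(torch.randn(2, 512, 7, 7))
    assert emb.shape == (2, 256)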
|
||||
177
mobileclip/modules/image/replknet.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For acknowledgment see accompanying ACKNOWLEDGMENTS file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.layers import SqueezeExcite
|
||||
|
||||
__all__ = ["ReparamLargeKernelConv"]
|
||||
|
||||
|
||||
class ReparamLargeKernelConv(nn.Module):
|
||||
"""Building Block of RepLKNet.
|
||||
|
||||
This class defines the overparameterized large-kernel conv block introduced in `RepLKNet
|
||||
<https://arxiv.org/abs/2203.06717>`_
|
||||
|
||||
Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
groups: int,
|
||||
small_kernel: int,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
activation: nn.Module = nn.GELU(),
|
||||
) -> None:
|
||||
"""Construct a ReparamLargeKernelConv module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
kernel_size: Kernel size of the large kernel conv branch.
|
||||
stride: Stride size.
|
||||
groups: Group number.
|
||||
small_kernel: Kernel size of small kernel conv branch.
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
activation: Activation module. Default: ``nn.GELU``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.activation = activation
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.small_kernel = small_kernel
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
# Check if SE is requested
|
||||
if use_se:
|
||||
self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
|
||||
else:
|
||||
self.se = nn.Identity()
|
||||
|
||||
if inference_mode:
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
dilation=1,
|
||||
groups=groups,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.lkb_origin = self._conv_bn(kernel_size=kernel_size, padding=self.padding)
|
||||
if small_kernel is not None:
|
||||
assert small_kernel <= kernel_size, (
|
||||
"The kernel size for re-param cannot be larger than the large kernel!"
|
||||
)
|
||||
self.small_conv = self._conv_bn(kernel_size=small_kernel, padding=small_kernel // 2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply forward pass."""
|
||||
if hasattr(self, "lkb_reparam"):
|
||||
out = self.lkb_reparam(x)
|
||||
else:
|
||||
out = self.lkb_origin(x)
|
||||
if hasattr(self, "small_conv"):
|
||||
out += self.small_conv(x)
|
||||
|
||||
return self.activation(self.se(out))
|
||||
|
||||
def get_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to obtain re-parameterized kernel and bias. Reference: https://github.com/DingXiaoH/RepLKNet-pytorch.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing branches.
|
||||
"""
|
||||
eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
|
||||
if hasattr(self, "small_conv"):
|
||||
small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
|
||||
eq_b += small_b
|
||||
eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
|
||||
return eq_k, eq_b
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf.
|
||||
We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like
|
||||
structure for inference.
|
||||
"""
|
||||
eq_k, eq_b = self.get_kernel_bias()
|
||||
self.lkb_reparam = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.lkb_origin.conv.dilation,
|
||||
groups=self.groups,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.lkb_reparam.weight.data = eq_k
|
||||
self.lkb_reparam.bias.data = eq_b
|
||||
self.__delattr__("lkb_origin")
|
||||
if hasattr(self, "small_conv"):
|
||||
self.__delattr__("small_conv")
|
||||
|
||||
@staticmethod
|
||||
def _fuse_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Method to fuse batchnorm layer with conv layer.
|
||||
|
||||
Args:
|
||||
conv: Convolutional layer to be fused with the batchnorm.
|
||||
bn: Batchnorm 2d layer.
|
||||
|
||||
Returns:
|
||||
Tuple of (kernel, bias) after fusing batchnorm.
|
||||
"""
|
||||
kernel = conv.weight
|
||||
running_mean = bn.running_mean
|
||||
running_var = bn.running_var
|
||||
gamma = bn.weight
|
||||
beta = bn.bias
|
||||
eps = bn.eps
|
||||
std = (running_var + eps).sqrt()
|
||||
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||
return kernel * t, beta - running_mean * gamma / std
|
||||
|
||||
def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
|
||||
"""Helper method to construct conv-batchnorm layers.
|
||||
|
||||
Args:
|
||||
kernel_size: Size of the convolution kernel.
|
||||
padding: Zero-padding size.
|
||||
|
||||
Returns:
|
||||
A nn.Sequential Conv-BN module.
|
||||
"""
|
||||
mod_list = nn.Sequential()
|
||||
mod_list.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=self.stride,
|
||||
padding=padding,
|
||||
groups=self.groups,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
|
||||
return mod_list
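# Minimal sanity check (illustrative, not part of the original file; requires timm for the
# SqueezeExcite import above): the fused large-kernel conv should match the
# large-kernel + small-kernel branch pair in eval mode.
if __name__ == "__main__":
    torch.manual_seed(0)
    blk = ReparamLargeKernelConv(
        in_channels=16, out_channels=16, kernel_size=7, stride=1, groups=16, small_kernel=3
    ).eval()
    x = torch.randn(1, 16, 32, 32)
    with torch.no_grad():
        y_ref = blk(x)
        blk.reparameterize()
        y_fused = blk(x)
    print("max abs diff:", (y_ref - y_fused).abs().max().item())  # expected ~1e-6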
|
||||
6
mobileclip/modules/text/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
265
mobileclip/modules/text/repmixer.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from mobileclip.modules.common.mobileone import MobileOneBlock
|
||||
|
||||
|
||||
class ConvFFN(nn.Module):
|
||||
"""Convolutional FFN Module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
context_size: int,
|
||||
hidden_channels: int | None = None,
|
||||
out_channels: int | None = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build convolutional FFN module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
context_size: Context size for 1D signals.
|
||||
hidden_channels: Number of channels after expansion. Default: None
|
||||
out_channels: Number of output channels. Default: None
|
||||
act_layer: Activation layer. Default: ``GELU``
|
||||
drop: Dropout rate. Default: ``0.0``.
|
||||
"""
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(1, int(context_size)),
|
||||
padding=(0, int(context_size // 2)),
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
if isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepMixer(nn.Module):
|
||||
"""Reparameterizable token mixer.
|
||||
|
||||
For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
|
||||
Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Module.
|
||||
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.kernel_size = kernel_size
|
||||
self.inference_mode = inference_mode
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=(1, self.kernel_size),
|
||||
stride=1,
|
||||
padding=(0, self.kernel_size // 2),
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
(1, kernel_size),
|
||||
padding=(0, kernel_size // 2),
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
use_scale_branch=False,
|
||||
num_conv_branches=0,
|
||||
)
|
||||
self.mixer = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
(1, kernel_size),
|
||||
padding=(0, kernel_size // 2),
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
|
||||
if self.inference_mode:
|
||||
return
|
||||
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
|
||||
else:
|
||||
w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=(1, self.kernel_size),
|
||||
stride=1,
|
||||
padding=(0, self.kernel_size // 2),
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w
|
||||
self.reparam_conv.bias.data = b
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("mixer")
|
||||
self.__delattr__("norm")
|
||||
if self.use_layer_scale:
|
||||
self.__delattr__("layer_scale")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
"""Implementation of Metaformer block with RepMixer as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 11,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""Build RepMixer Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
kernel_size: Kernel size for repmixer. Default: 11
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
context_size=kernel_size,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if x.dim() == 3:
|
||||
# B, C, D --- where C is the context length
|
||||
# Convert to B, D, C --- to match RepMixer impl.
|
||||
x = x.permute(0, 2, 1)
|
||||
x = torch.unsqueeze(x, dim=2)
|
||||
else:
|
||||
raise ValueError(f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}")
|
||||
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
|
||||
# Convert tensors back
|
||||
x = x.squeeze(dim=2).permute(0, 2, 1)
|
||||
return x
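# Minimal shape check (illustrative, not part of the original file; requires timm for the
# DropPath/trunc_normal_ imports above): RepMixerBlock mixes 1-D token sequences of shape
# [batch, context_length, dim] and preserves that shape.
if __name__ == "__main__":
    block = RepMixerBlock(dim=512, kernel_size=11)
    tokens = torch.randn(2, 77, 512)
    assert block(tokens).shape == tokens.shape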
|
||||
39
mobileclip/modules/text/tokenizer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
|
||||
import open_clip
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ClipTokenizer(nn.Module):
|
||||
def __init__(self, cfg, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.context_length = cfg["text_cfg"]["context_length"]
|
||||
model_name = getattr(cfg["text_cfg"], "open_clip_tokenizer", "ViT-B-16")
|
||||
self.tokenizer = open_clip.get_tokenizer(model_name)
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
return len(self.tokenizer.encoder)
|
||||
|
||||
def get_encodings(self) -> dict[str, int]:
|
||||
return self.tokenizer.encoder
|
||||
|
||||
def get_eot_token(self) -> int:
|
||||
# Tokenizing an empty string returns a list [sot_id, eot_id]
|
||||
return self.tokenizer("")[1]
|
||||
|
||||
def get_sot_token(self) -> int:
|
||||
# Tokenizing an empty string returns a list [sot_id, eot_id]
|
||||
return self.tokenizer("")[0]
|
||||
|
||||
def forward(self, input_sentence: str, *args, **kwargs) -> Tensor:
|
||||
# the tokenizer returns a [batch, context_length] tensor of token indices
|
||||
tokenized_sentence = self.tokenizer(input_sentence, self.context_length)
|
||||
assert tokenized_sentence.shape[-1] == self.context_length, (
|
||||
"Tokenized tensor should be exactly `context_length` long."
|
||||
)
|
||||
return tokenized_sentence
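# Usage sketch (illustrative, not part of the original file; requires open_clip to be
# installed and a config containing text_cfg.context_length):
#   tokenizer = ClipTokenizer({"text_cfg": {"context_length": 77}})
#   ids = tokenizer(["a photo of a cat", "a photo of a dog"])
#   ids.shape  # torch.Size([2, 77])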
|
||||
218
mobileclip/text_encoder.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mobileclip import logger
|
||||
from mobileclip.modules.common.transformer import (
|
||||
PositionalEmbedding,
|
||||
TransformerEncoder,
|
||||
get_normalization_layer,
|
||||
)
|
||||
from mobileclip.modules.text.repmixer import RepMixerBlock
|
||||
|
||||
|
||||
class TextTransformer(nn.Module):
|
||||
def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
model_dim = cfg["dim"]
|
||||
no_scale_embedding = cfg.get("no_scale_embedding", False)
|
||||
no_pos_embedding = cfg.get("no_pos_embedding", False)
|
||||
embed_dropout = cfg.get("embed_dropout", 0.0)
|
||||
norm_layer = cfg["norm_layer"]
|
||||
variant = cfg["model_name"]
|
||||
self.vocab_size = cfg["vocab_size"]
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
# Token embedding layer
|
||||
self.embedding_layer = nn.Embedding(embedding_dim=model_dim, num_embeddings=self.vocab_size)
|
||||
self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5
|
||||
|
||||
# Context length
|
||||
context_length = cfg["context_length"]
|
||||
assert context_length is not None, "Context length can't be None. Please set value accordingly."
|
||||
|
||||
self.positional_embedding = (
|
||||
None if no_pos_embedding else PositionalEmbedding(num_embeddings=context_length, embedding_dim=model_dim)
|
||||
)
|
||||
|
||||
self.embedding_dropout = nn.Dropout(p=embed_dropout)
|
||||
|
||||
# Transformer layer
|
||||
n_transformer_layers = cfg["n_transformer_layers"]
|
||||
|
||||
# FFN multipliers for transformer layer
|
||||
ffn_multipliers = cfg["ffn_multiplier_per_layer"]
|
||||
if isinstance(ffn_multipliers, (float, int)):
|
||||
ffn_multipliers = [ffn_multipliers] * n_transformer_layers
|
||||
|
||||
if not isinstance(ffn_multipliers, Sequence):
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} expects FFN multipliers as a list, whose length is the same as"
|
||||
f" number of transformer layers. Got: {type(ffn_multipliers)}"
|
||||
)
|
||||
elif isinstance(ffn_multipliers, Sequence) and len(ffn_multipliers) != n_transformer_layers:
|
||||
logger.error(
|
||||
f"We need FFN multiplier for each transformer layer. Got {len(ffn_multipliers)} ffn"
|
||||
f" multipliers while number of transformer layers = {n_transformer_layers}"
|
||||
)
|
||||
ffn_dims = [int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) for ffn_mult in ffn_multipliers]
|
||||
|
||||
# Heads for transformer layers
|
||||
mha_heads = cfg["n_heads_per_layer"]
|
||||
if isinstance(mha_heads, int):
|
||||
mha_heads = [mha_heads] * n_transformer_layers
|
||||
|
||||
if not isinstance(mha_heads, Sequence):
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} expects MHA heads as a list, whose length is the same as number of "
|
||||
f"transformer layers. Got: {type(mha_heads)}"
|
||||
)
|
||||
elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
|
||||
logger.error(
|
||||
f"{self.__class__.__name__} needs MHA heads for each transformer layer. Got {len(mha_heads)} mha heads while"
|
||||
f" number of transformer layers = {n_transformer_layers}"
|
||||
)
|
||||
|
||||
if variant == "base":
|
||||
self.transformer = nn.ModuleList(
|
||||
[
|
||||
TransformerEncoder(
|
||||
embed_dim=model_dim,
|
||||
num_heads=mha_heads[layer_idx],
|
||||
ffn_latent_dim=ffn_dims[layer_idx],
|
||||
transformer_norm_layer=norm_layer,
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
)
|
||||
elif variant == "mct":
|
||||
self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
|
||||
self.transformer.extend(
|
||||
[
|
||||
TransformerEncoder(
|
||||
embed_dim=model_dim,
|
||||
num_heads=mha_heads[layer_idx],
|
||||
ffn_latent_dim=ffn_dims[layer_idx],
|
||||
transformer_norm_layer=norm_layer,
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
)
|
||||
self.transformer.extend([RepMixerBlock(dim=model_dim)])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized text encoder variant {variant}")
|
||||
|
||||
self.final_layer_norm = get_normalization_layer(num_features=model_dim, norm_type=norm_layer)
|
||||
|
||||
self.projection_layer = nn.Parameter(torch.empty(model_dim, self.projection_dim))
|
||||
self.model_dim = model_dim
|
||||
self.causal_masking = cfg["causal_masking"]
|
||||
|
||||
def forward_embedding(self, text_tokens: Tensor) -> Tensor:
|
||||
"""Return text embedding for all tokens.
|
||||
|
||||
Args:
|
||||
text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
|
||||
|
||||
Returns:
|
||||
A tensor of [batch_size, context_length, hidden_dim].
|
||||
"""
|
||||
# [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
|
||||
token_emb = self.embedding_layer(text_tokens)
|
||||
seq_len = token_emb.shape[1]
|
||||
if self.positional_embedding is not None:
|
||||
token_emb = token_emb + self.positional_embedding(seq_len).to(token_emb.dtype)
|
||||
token_emb = self.embedding_dropout(token_emb)
|
||||
return token_emb
|
||||
|
||||
@staticmethod
|
||||
@torch.jit.script # use scripting to avoid device constant
|
||||
def build_attention_mask(text_tokens: torch.Tensor) -> Tensor:
|
||||
"""Build causal attention mask [batch_size, context_length, context_length]."""
|
||||
# Build mask with full attention between the tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
batch_size, context_length = text_tokens.shape
|
||||
mask = torch.empty(context_length, context_length, device=text_tokens.device, dtype=torch.float32)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(0) # add dummy batch dimension
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text_tokens: Tensor,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
return_all_tokens: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Return text token embeddings.
|
||||
|
||||
Args:
|
||||
text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
|
||||
key_padding_mask: a tensor of boolean values as the padding mask of shape [batch_size, context_length]
|
||||
return_all_tokens: a boolean flag to return all tokens, defaults to False to return only EOT token
|
||||
embedding.
|
||||
|
||||
Returns:
|
||||
A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
|
||||
True, otherwise a tensor of [batch_size, hidden_dim].
|
||||
"""
|
||||
# Discrete tokens to continuous embeddings
|
||||
# [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
|
||||
token_emb = self.forward_embedding(text_tokens)
|
||||
|
||||
# [1, context_length, context_length]
|
||||
attn_mask = None
|
||||
if self.causal_masking:
|
||||
attn_mask = self.build_attention_mask(text_tokens=text_tokens)
|
||||
key_padding_mask = None
|
||||
|
||||
for layer in self.transformer:
|
||||
token_emb = layer(
|
||||
token_emb,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Apply layer norm
|
||||
token_emb = self.final_layer_norm(token_emb)
|
||||
|
||||
if return_all_tokens:
|
||||
return token_emb
|
||||
|
||||
# Take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
token_emb = token_emb[torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)]
|
||||
|
||||
token_emb = token_emb @ self.projection_layer
|
||||
return token_emb
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: Tensor,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
return_all_tokens: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
# Image-text pair data with single caption
|
||||
# [B, CL] --> [B, d]
|
||||
text_tokens = self.encode_text(
|
||||
text_tokens=text_tokens,
|
||||
key_padding_mask=key_padding_mask,
|
||||
return_all_tokens=return_all_tokens,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
return text_tokens
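# Configuration sketch (illustrative, not part of the original file; the key names follow
# the constructor above, but the concrete values are assumptions):
#   cfg = {
#       "dim": 512, "norm_layer": "layer_norm_fp32", "model_name": "base",
#       "vocab_size": 49408, "context_length": 77, "n_transformer_layers": 12,
#       "ffn_multiplier_per_layer": 4.0, "n_heads_per_layer": 8, "causal_masking": True,
#   }
#   text_encoder = TextTransformer(cfg, projection_dim=512)
#   emb = text_encoder(torch.randint(0, 49408, (2, 77)))  # [2, 512] text embeddings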
|
||||
BIN
mobileclip_blt.ts
(Stored with Git LFS)
Normal file
Binary file not shown.
64
train.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# --- Import your modules ---
|
||||
from dataset import YOLODataset
|
||||
from yolo11_standalone import YOLO11
|
||||
from loss import YOLOv8DetectionLoss
|
||||
|
||||
# --- Import the Trainer we just wrote ---
|
||||
from trainer import YOLO11Trainer
|
||||
|
||||
def run_training():
|
||||
# --- 1.全局配置 ---
|
||||
img_dir_train = "E:\\Datasets\\coco\\images\\train2017"
|
||||
label_dir_train = "E:\\Datasets\\coco\\labels\\train2017"
|
||||
img_dir_val = "E:\\Datasets\\coco\\images\\val2017"
|
||||
label_dir_val = "E:\\Datasets\\coco\\labels\\val2017"
|
||||
|
||||
epochs = 50
|
||||
batch_size = 36
|
||||
img_size = 640
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# --- 2. 准备数据 ---
|
||||
print("Loading Data...")
|
||||
train_dataset = YOLODataset(img_dir_train, label_dir_train, img_size=img_size, is_train=True)
|
||||
val_dataset = YOLODataset(img_dir_val, label_dir_val, img_size=img_size, is_train=False)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
|
||||
collate_fn=YOLODataset.collate_fn, num_workers=8, pin_memory=True)
|
||||
|
||||
# --- 3. 初始化模型 ---
|
||||
print("Initializing Model...")
|
||||
model = YOLO11(nc=80, scale='s')
|
||||
# model.load_weights("yolo11s.pth")
|
||||
# model.to(device)
|
||||
|
||||
strides = [8, 16, 32]
|
||||
hyp = {
|
||||
'box': 7.5,
|
||||
'cls': 0.5,
|
||||
'dfl': 1.5
|
||||
}
|
||||
loss_fn = YOLOv8DetectionLoss(nc=80, reg_max=16, stride=strides, hyp=hyp)
|
||||
|
||||
# --- 5. 初始化 Trainer 并开始训练 ---
|
||||
trainer = YOLO11Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
loss_fn=loss_fn,
|
||||
epochs=epochs,
|
||||
lr0=0.01,
|
||||
device=device,
|
||||
save_dir='./my_yolo_result'
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Windows下多进程dataloader需要这个保护
|
||||
run_training()
|
||||
275
trainer.py
Normal file
275
trainer.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import math
|
||||
import copy
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from metrics import ap_per_class, box_iou, non_max_suppression, xywh2xyxy
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(format="%(message)s", level=logging.INFO)
|
||||
LOGGER = logging.getLogger("YOLO_Trainer")
|
||||
|
||||
# ==============================================================================
|
||||
# Helper Class: Model EMA (Exponential Moving Average)
|
||||
# ==============================================================================
|
||||
class ModelEMA:
|
||||
""" Updated Exponential Moving Average (EMA) from Ultralytics """
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
self.ema = copy.deepcopy(model).eval() # FP32 EMA
|
||||
self.updates = updates
|
||||
# decay exponential ramp (to help early epochs)
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
msd = model.state_dict()
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if k in msd:
|
||||
tmp = msd[k].to(v.device)
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1 - d) * tmp
|
||||
|
||||
# ==============================================================================
|
||||
# Main Trainer Class
|
||||
# ==============================================================================
|
||||
class YOLO11Trainer:
|
||||
def __init__(self,
|
||||
model,
|
||||
train_loader,
|
||||
val_loader,
|
||||
loss_fn,
|
||||
epochs=100,
|
||||
lr0=0.01,
|
||||
lrf=0.01,
|
||||
device='cuda',
|
||||
save_dir='./runs/train',
|
||||
warmup_epochs=3.0):
|
||||
|
||||
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
||||
self.model = model.to(self.device)
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.loss_fn = loss_fn
|
||||
self.epochs = epochs
|
||||
self.save_dir = Path(save_dir)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.start_epoch = 0
|
||||
|
||||
# --- Optimizer Building ---
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
|
||||
|
||||
for v in self.model.modules():
|
||||
for p_name, p in v.named_parameters(recurse=False):
|
||||
if p_name == 'bias':
|
||||
g[2].append(p) # biases
|
||||
elif isinstance(v, bn):
|
||||
g[1].append(p) # bn weights (no decay)
|
||||
else:
|
||||
g[0].append(p) # weights (decay)
|
||||
|
||||
self.optimizer = optim.SGD(g[2], lr=lr0, momentum=0.937, nesterov=True)
|
||||
self.optimizer.add_param_group({'params': g[0], 'weight_decay': 0.0005})
|
||||
self.optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
|
||||
|
||||
LOGGER.info(f"Optimizer: weights={len(g[0])}, bn={len(g[1])}, biases={len(g[2])}")
|
||||
|
||||
# --- Scheduler ---
|
||||
self.lf = lambda x: ((1 - math.cos(x * math.pi / self.epochs)) / 2) * (lrf - 1) + 1
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
# --- AMP & EMA ---
|
||||
self.scaler = torch.amp.GradScaler('cuda', enabled=True)
|
||||
self.ema = ModelEMA(self.model)
|
||||
|
||||
def train_one_epoch(self, epoch):
|
||||
self.model.train()
|
||||
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader),
|
||||
desc=f'Epoch {epoch+1}/{self.epochs}', leave=True)
|
||||
|
||||
mloss = torch.zeros(3, device=self.device)
|
||||
self.optimizer.zero_grad()
|
||||
nb = len(self.train_loader)
|
||||
|
||||
for i, batch in pbar:
|
||||
ni = i + nb * epoch
|
||||
|
||||
imgs, targets, paths = batch
|
||||
imgs = imgs.to(self.device, non_blocking=True)
|
||||
targets = targets.to(self.device)
|
||||
|
||||
# --- Warmup ---
|
||||
if ni <= nb * self.warmup_epochs:
|
||||
xp = [0, nb * self.warmup_epochs]
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
lr_target = x['initial_lr'] * self.lf(epoch)
|
||||
x['lr'] = np.interp(ni, xp, [0.1 if j == 0 else 0.0, lr_target])
|
||||
if 'momentum' in x:
|
||||
x['momentum'] = np.interp(ni, xp, [0.8, 0.937])
|
||||
|
||||
# --- Forward ---
|
||||
with torch.amp.autocast('cuda', enabled=True):
|
||||
preds = self.model(imgs)
|
||||
|
||||
target_batch = {
|
||||
"batch_idx": targets[:, 0],
|
||||
"cls": targets[:, 1],
|
||||
"bboxes": targets[:, 2:],
|
||||
}
|
||||
|
||||
loss, loss_items = self.loss_fn(preds, target_batch)
|
||||
|
||||
# --- Backward ---
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# --- EMA ---
|
||||
self.ema.update(self.model)
|
||||
|
||||
# --- Logging ---
|
||||
loss_items = loss_items.detach()
|
||||
mloss = (mloss * i + loss_items) / (i + 1)
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'
|
||||
|
||||
pbar.set_postfix({
|
||||
'mem': mem,
|
||||
'box': f"{mloss[0]:.4f}",
|
||||
'cls': f"{mloss[1]:.4f}",
|
||||
'dfl': f"{mloss[2]:.4f}",
|
||||
'lr': f"{self.optimizer.param_groups[0]['lr']:.5f}"
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
model = self.ema.ema
|
||||
device = self.device
|
||||
|
||||
# --- Metrics Config ---
|
||||
conf_thres = 0.001 # Low threshold for mAP calculation
|
||||
iou_thres = 0.7 # NMS IoU threshold
|
||||
iouv = torch.linspace(0.5, 0.95, 10, device=device) # IoU vector for mAP@0.5:0.95
|
||||
|
||||
loss_sum = torch.zeros(3, device=device)
|
||||
stats = [] # [(correct, conf, pred_cls, target_cls)]
|
||||
|
||||
LOGGER.info("\nValidating...")
|
||||
|
||||
model.eval()
|
||||
|
||||
pbar = tqdm(self.val_loader, desc="Calc Metrics")
|
||||
with torch.no_grad():
|
||||
for batch in pbar:
|
||||
imgs, targets, _ = batch
|
||||
imgs = imgs.to(device, non_blocking=True)
|
||||
targets = targets.to(device)
|
||||
_, _, height, width = imgs.shape
|
||||
|
||||
# Inference
|
||||
preds = model(imgs)
|
||||
|
||||
# NMS
|
||||
preds = non_max_suppression(preds, conf_thres=conf_thres, iou_thres=iou_thres)
|
||||
|
||||
# Metrics Processing
|
||||
for si, pred in enumerate(preds):
|
||||
labels = targets[targets[:, 0] == si]
|
||||
nl = len(labels)
|
||||
tcls = labels[:, 1].tolist() if nl else []
|
||||
|
||||
if len(pred) == 0:
|
||||
if nl:
|
||||
stats.append((torch.zeros(0, iouv.numel(), dtype=torch.bool),
|
||||
torch.Tensor(), torch.Tensor(), torch.tensor(tcls)))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
predn = pred.clone()
|
||||
|
||||
# Ground Truth
|
||||
if nl:
|
||||
tbox = xywh2xyxy(labels[:, 2:6])
|
||||
tbox[:, [0, 2]] *= width
|
||||
tbox[:, [1, 3]] *= height
|
||||
labelsn = torch.cat((labels[:, 1:2], tbox), 1) # [cls, x1, y1, x2, y2]
|
||||
|
||||
# Match predictions to GT
|
||||
correct = self._process_batch(predn, labelsn, iouv)
|
||||
else:
|
||||
correct = torch.zeros(pred.shape[0], iouv.numel(), dtype=torch.bool)
|
||||
|
||||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), torch.tensor(tcls)))
|
||||
|
||||
mp, mr, map50, map5095 = 0.0, 0.0, 0.0, 0.0
|
||||
if len(stats) and stats[0][0].any():
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
|
||||
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats)
|
||||
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
|
||||
mp, mr, map50, map5095 = p.mean(), r.mean(), ap50.mean(), ap.mean()
|
||||
|
||||
LOGGER.info(f"Val Results: Prec={mp:.3f}, Recall={mr:.3f} mAP50={map50:.3f} mAP50-95={map5095:.3f}")
|
||||
|
||||
return map50
|
||||
|
||||
def _process_batch(self, detections, labels, iouv):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
detections: [N, 6] (x1, y1, x2, y2, conf, cls)
|
||||
labels: [M, 5] (cls, x1, y1, x2, y2)
|
||||
"""
|
||||
correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device)
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5]))
|
||||
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
|
||||
matches = torch.from_numpy(matches).to(iouv.device)
|
||||
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv
|
||||
|
||||
return correct
|
||||
|
||||
def train(self):
|
||||
LOGGER.info(f"Starting training on {self.device} for {self.epochs} epochs...")
|
||||
start_time = time.time()
|
||||
best_fitness = 0.0
|
||||
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.train_one_epoch(epoch)
|
||||
self.scheduler.step()
|
||||
map50 = self.validate()
|
||||
|
||||
ckpt = {
|
||||
'epoch': epoch,
|
||||
'model': self.model.state_dict(),
|
||||
'ema': self.ema.ema.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
}
|
||||
|
||||
torch.save(ckpt, self.save_dir / 'last.pt')
|
||||
|
||||
if map50 > best_fitness:
|
||||
best_fitness = map50
|
||||
torch.save(ckpt, self.save_dir / 'best.pt')
|
||||
LOGGER.info(f"--> Saved best model with Recall/mAP: {best_fitness:.4f}")
|
||||
|
||||
LOGGER.info(f"\nTraining completed in {(time.time() - start_time) / 3600:.3f} hours.")
|
||||
923
yolo11_standalone.py
Normal file
923
yolo11_standalone.py
Normal file
@@ -0,0 +1,923 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
def autopad(k, p=None, d=1):
|
||||
if d > 1:
|
||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
||||
return p
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
dtype, device = feats[0].dtype, feats[0].device
|
||||
for i, stride in enumerate(strides):
|
||||
_, _, h, w = feats[i].shape
|
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
|
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
||||
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return torch.cat((c_xy, wh), dim)
|
||||
return torch.cat((x1y1, x2y2), dim)
|
||||
|
||||
class Concat(nn.Module):
|
||||
def __init__(self, dimension=1):
|
||||
super().__init__()
|
||||
self.d = dimension
|
||||
|
||||
def forward(self, x: List[torch.Tensor]):
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class Conv(nn.Module):
|
||||
default_act = nn.SiLU(inplace=True)
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) # type: ignore
|
||||
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
||||
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def forward_fuse(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
class DWConv(Conv):
|
||||
def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
|
||||
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
||||
|
||||
class DFL(nn.Module):
|
||||
def __init__(self, c1: int = 16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
||||
x = torch.arange(c1, dtype=torch.float)
|
||||
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
||||
self.c1 = c1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, _, a = x.shape
|
||||
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(
|
||||
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
c_ = int(c2 * e)
|
||||
self.cv1 = Conv(c1, c_, k[0], 1)
|
||||
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
||||
self.add = shortcut and c1 == c2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
||||
|
||||
class C2f(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
self.c = int(c2 * e)
|
||||
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
||||
self.cv2 = Conv((2 + n) * self.c, c2, 1)
|
||||
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) # type: ignore
|
||||
|
||||
def forward(self, x):
|
||||
chunk_result = self.cv1(x).chunk(2, 1)
|
||||
y = [chunk_result[0], chunk_result[1]]
|
||||
|
||||
for m_module in self.m:
|
||||
y.append(m_module(y[-1]))
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
def forward_split(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = self.cv1(x).split((self.c, self.c), 1)
|
||||
y = [y[0], y[1]]
|
||||
y.extend(m(y[-1]) for m in self.m)
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
class C3(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
c_ = int(c2 * e)
|
||||
self.cv1 = Conv(c1, c_, 1, 1)
|
||||
self.cv2 = Conv(c1, c_, 1, 1)
|
||||
self.cv3 = Conv(2 * c_, c2, 1)
|
||||
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) # type: ignore
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
||||
|
||||
class C3k(C3):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
c_ = int(c2 * e)
|
||||
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
|
||||
|
||||
class C3k2(C2f):
|
||||
def __init__(
|
||||
self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True
|
||||
):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
self.m = nn.ModuleList(
|
||||
C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
|
||||
)
|
||||
|
||||
class SPPF(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, k: int = 5):
|
||||
super().__init__()
|
||||
c_ = c1 // 2
|
||||
self.cv1 = Conv(c1, c_, 1, 1)
|
||||
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
||||
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = [self.cv1(x)]
|
||||
y.extend(self.m(y[-1]) for _ in range(3))
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.key_dim = int(self.head_dim * attn_ratio)
|
||||
self.scale = self.key_dim**-0.5
|
||||
nh_kd = self.key_dim * num_heads
|
||||
h = dim + nh_kd * 2
|
||||
self.qkv = Conv(dim, h, 1, act=False)
|
||||
self.proj = Conv(dim, dim, 1, act=False)
|
||||
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
N = H * W
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
|
||||
[self.key_dim, self.key_dim, self.head_dim], dim=2
|
||||
)
|
||||
|
||||
attn = (q.transpose(-2, -1) @ k) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
class PSABlock(nn.Module):
|
||||
def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
|
||||
self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
|
||||
self.add = shortcut
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attn(x) if self.add else self.attn(x)
|
||||
x = x + self.ffn(x) if self.add else self.ffn(x)
|
||||
return x
|
||||
|
||||
class C2PSA(nn.Module):
|
||||
def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
|
||||
super().__init__()
|
||||
assert c1 == c2
|
||||
self.c = int(c1 * e)
|
||||
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
||||
self.cv2 = Conv(2 * self.c, c1, 1)
|
||||
|
||||
self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
a, b = self.cv1(x).split((self.c, self.c), dim=1)
|
||||
b = self.m(b)
|
||||
return self.cv2(torch.cat((a, b), 1))
|
||||
|
||||
class Proto(nn.Module):
|
||||
def __init__(self, c1: int, c_: int = 256, c2: int = 32):
|
||||
super().__init__()
|
||||
self.cv1 = Conv(c1, c_, k=3)
|
||||
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)
|
||||
self.cv2 = Conv(c_, c_, k=3)
|
||||
self.cv3 = Conv(c_, c2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
|
||||
|
||||
class BNContrastiveHead(nn.Module):
|
||||
def __init__(self, embed_dims: int):
|
||||
super().__init__()
|
||||
self.norm = nn.BatchNorm2d(embed_dims)
|
||||
self.bias = nn.Parameter(torch.tensor([-10.0]))
|
||||
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
|
||||
|
||||
def fuse(self):
|
||||
del self.norm
|
||||
del self.bias
|
||||
del self.logit_scale
|
||||
self.forward = self.forward_fuse
|
||||
|
||||
def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
|
||||
x = self.norm(x)
|
||||
w = F.normalize(w, dim=-1, p=2)
|
||||
|
||||
x = torch.einsum("bchw,bkc->bkhw", x, w)
|
||||
return x * self.logit_scale.exp() + self.bias
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, gc: int, ec: int, e: int = 4) -> None:
|
||||
super().__init__()
|
||||
self.w12 = nn.Linear(gc, e * ec)
|
||||
self.w3 = nn.Linear(e * ec // 2, ec)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
hidden = F.silu(x1) * x2
|
||||
return self.w3(hidden)
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, m: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.m = m
|
||||
nn.init.zeros_(self.m.w3.bias)
|
||||
# For models with l scale, please change the initialization to
|
||||
# nn.init.constant_(self.m.w3.weight, 1e-6)
|
||||
nn.init.zeros_(self.m.w3.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.m(x)
|
||||
|
||||
class SAVPE(nn.Module):
|
||||
def __init__(self, ch: List[int], c3: int, embed: int):
|
||||
super().__init__()
|
||||
self.cv1 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
|
||||
)
|
||||
for i, x in enumerate(ch)
|
||||
)
|
||||
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
|
||||
for i, x in enumerate(ch)
|
||||
)
|
||||
|
||||
self.c = 16
|
||||
self.cv3 = nn.Conv2d(3 * c3, embed, 1)
|
||||
self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
|
||||
self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
|
||||
self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))
|
||||
|
||||
def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
|
||||
y = [self.cv2[i](xi) for i, xi in enumerate(x)]
|
||||
y = self.cv4(torch.cat(y, dim=1))
|
||||
|
||||
x = [self.cv1[i](xi) for i, xi in enumerate(x)]
|
||||
x = self.cv3(torch.cat(x, dim=1))
|
||||
|
||||
B, C, H, W = x.shape # type: ignore
|
||||
|
||||
Q = vp.shape[1]
|
||||
|
||||
x = x.view(B, C, -1) # type: ignore
|
||||
|
||||
y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
|
||||
vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
|
||||
|
||||
y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
|
||||
|
||||
y = y.reshape(B, Q, self.c, -1)
|
||||
vp = vp.reshape(B, Q, 1, -1)
|
||||
|
||||
score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
|
||||
score = F.softmax(score, dim=-1).to(y.dtype)
|
||||
aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
|
||||
|
||||
return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
|
||||
|
||||
class Detect(nn.Module):
|
||||
dynamic = False
|
||||
export = False
|
||||
shape = None
|
||||
anchors = torch.empty(0)
|
||||
strides = torch.empty(0)
|
||||
|
||||
def __init__(self, nc=80, ch=()):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
self.nl = len(ch)
|
||||
self.reg_max = 16
|
||||
self.no = nc + self.reg_max * 4
|
||||
self.stride = torch.tensor([8., 16., 32.])
|
||||
|
||||
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))
|
||||
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||
)
|
||||
self.cv3 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, self.nc, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
||||
|
||||
if self.training:
|
||||
return x
|
||||
|
||||
# Inference path
|
||||
shape = x[0].shape
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
|
||||
return torch.cat((dbox, cls.sigmoid()), 1)
|
||||
|
||||
def bias_init(self):
|
||||
m = self
|
||||
for a, b, s in zip(m.cv2, m.cv3, m.stride):
|
||||
a[-1].bias.data[:] = 1.0 # type: ignore
|
||||
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # type: ignore
|
||||
|
||||
def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
|
||||
return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)
|
||||
|
||||
@staticmethod
|
||||
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
|
||||
batch_size, anchors, _ = preds.shape
|
||||
boxes, scores = preds.split([4, nc], dim=-1)
|
||||
index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
|
||||
boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
|
||||
scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
|
||||
scores, index = scores.flatten(1).topk(min(max_det, anchors))
|
||||
i = torch.arange(batch_size)[..., None]
|
||||
return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
|
||||
|
||||
class YOLOEDetect(Detect):
|
||||
def __init__(self, nc: int = 80, embed: int = 512, ch: Tuple = ()):
|
||||
super().__init__(nc, ch)
|
||||
c3 = max(ch[0], min(self.nc, 100))
|
||||
assert c3 <= embed
|
||||
self.cv3 = (
|
||||
nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, embed, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
)
|
||||
|
||||
self.cv4 = nn.ModuleList(BNContrastiveHead(embed) for _ in ch)
|
||||
|
||||
self.reprta = Residual(SwiGLUFFN(embed, embed))
|
||||
self.savpe = SAVPE(ch, c3, embed) # type: ignore
|
||||
self.embed = embed
|
||||
|
||||
def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)
|
||||
|
||||
def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:
|
||||
if vpe.shape[1] == 0: # no visual prompt embeddings
|
||||
return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
|
||||
if vpe.ndim == 4: # (B, N, H, W)
|
||||
vpe = self.savpe(x, vpe)
|
||||
assert vpe.ndim == 3 # (B, N, D)
|
||||
return vpe
|
||||
|
||||
def forward( # type: ignore
|
||||
self, x: List[torch.Tensor], cls_pe: torch.Tensor
|
||||
) -> Union[torch.Tensor, Tuple]:
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
|
||||
if self.training:
|
||||
return x # type: ignore
|
||||
self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
|
||||
# Inference path
|
||||
shape = x[0].shape
|
||||
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
||||
|
||||
if self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
|
||||
return torch.cat((dbox, cls.sigmoid()), 1)
|
||||
|
||||
def bias_init(self):
|
||||
m = self
|
||||
for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):
|
||||
a[-1].bias.data[:] = 1.0 # box
|
||||
b[-1].bias.data[:] = 0.0
|
||||
c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
|
||||
|
||||
class YOLOESegment(YOLOEDetect):
|
||||
def __init__(
|
||||
self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, ch: Tuple = ()
|
||||
):
|
||||
super().__init__(nc, embed, ch)
|
||||
self.nm = nm
|
||||
self.npr = npr
|
||||
self.proto = Proto(ch[0], self.npr, self.nm)
|
||||
|
||||
c5 = max(ch[0] // 4, self.nm)
|
||||
self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
|
||||
|
||||
def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:
|
||||
p = self.proto(x[0]) # mask protos
|
||||
bs = p.shape[0] # batch size
|
||||
|
||||
mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||
x = YOLOEDetect.forward(self, x, text)
|
||||
|
||||
if self.training:
|
||||
return x, mc, p
|
||||
|
||||
return (torch.cat([x, mc], 1), p)
|
||||
|
||||
|
||||
|
||||
class YOLO11(nn.Module):
|
||||
def __init__(self, nc=80, scale='n'):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
|
||||
# Scales: [depth, width, max_channels]
|
||||
# 对应 yolo11.yaml 中的 scales 参数
|
||||
scales = {
|
||||
'n': [0.50, 0.25, 1024],
|
||||
's': [0.50, 0.50, 1024],
|
||||
'm': [0.50, 1.00, 512],
|
||||
'l': [1.00, 1.00, 512],
|
||||
'x': [1.00, 1.50, 512],
|
||||
}
|
||||
|
||||
if scale not in scales:
|
||||
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
|
||||
|
||||
depth, width, max_channels = scales[scale]
|
||||
|
||||
if scale in ['n', 's']:
|
||||
c3k_override = False
|
||||
else:
|
||||
c3k_override = True
|
||||
|
||||
# 辅助函数:计算通道数 (Width Scaling)
|
||||
def gw(channels):
|
||||
return make_divisible(min(channels, max_channels) * width, 8)
|
||||
|
||||
# 辅助函数:计算层重复次数 (Depth Scaling)
|
||||
def gd(n):
|
||||
return max(round(n * depth), 1) if n > 1 else n
|
||||
|
||||
self.model = nn.ModuleList()
|
||||
|
||||
# --- Backbone ---
|
||||
# 0: Conv [64, 3, 2]
|
||||
self.model.append(Conv(3, gw(64), 3, 2))
|
||||
|
||||
# 1: Conv [128, 3, 2]
|
||||
self.model.append(Conv(gw(64), gw(128), 3, 2))
|
||||
|
||||
# 2: C3k2 [256, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 3: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 4: C3k2 [512, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 5: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 6: C3k2 [512, True] -> n=2
|
||||
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
|
||||
|
||||
# 7: Conv [1024, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(1024), 3, 2))
|
||||
|
||||
# 8: C3k2 [1024, True] -> n=2
|
||||
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 9: SPPF [1024, 5]
|
||||
self.model.append(SPPF(gw(1024), gw(1024), 5))
|
||||
|
||||
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
|
||||
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
|
||||
|
||||
# --- Head ---
|
||||
|
||||
# 11: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 12: Concat [-1, 6] (P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
|
||||
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 14: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 15: Concat [-1, 4] (P3)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
|
||||
# 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512)
|
||||
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 17: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 18: Concat [-1, 13] (Head P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
|
||||
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 20: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 21: Concat [-1, 10] (P5)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
|
||||
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 23: Detect [nc]
|
||||
self.model.append(Detect(nc, ch=[gw(256), gw(512), gw(1024)]))
|
||||
|
||||
# --- 初始化权重 ---
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""初始化模型权重,特别是 Detect 头的 Bias"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# 使用 Kaiming 初始化或其他合适的初始化
|
||||
pass
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
detect_layer = self.model[-1]
|
||||
if isinstance(detect_layer, Detect):
|
||||
detect_layer.bias_init()
|
||||
|
||||
def forward(self, x):
|
||||
# Backbone
|
||||
x = self.model[0](x)
|
||||
x = self.model[1](x)
|
||||
x = self.model[2](x)
|
||||
x = self.model[3](x)
|
||||
p3 = self.model[4](x) # 保存 P3 (layer 4)
|
||||
x = self.model[5](p3)
|
||||
p4 = self.model[6](x) # 保存 P4 (layer 6)
|
||||
x = self.model[7](p4)
|
||||
x = self.model[8](x)
|
||||
x = self.model[9](x)
|
||||
p5 = self.model[10](x) # 保存 P5 (layer 10)
|
||||
|
||||
# Head
|
||||
x = self.model[11](p5) # Upsample
|
||||
x = self.model[12]([x, p4]) # Concat P4
|
||||
h1 = self.model[13](x) # Head P4 (layer 13)
|
||||
|
||||
x = self.model[14](h1) # Upsample
|
||||
x = self.model[15]([x, p3]) # Concat P3
|
||||
h2 = self.model[16](x) # Output P3 (layer 16)
|
||||
|
||||
x = self.model[17](h2) # Conv
|
||||
x = self.model[18]([x, h1]) # Concat Head P4
|
||||
h3 = self.model[19](x) # Output P4 (layer 19)
|
||||
|
||||
x = self.model[20](h3) # Conv
|
||||
x = self.model[21]([x, p5]) # Concat P5
|
||||
h4 = self.model[22](x) # Output P5 (layer 22)
|
||||
|
||||
return self.model[23]([h2, h3, h4]) # Detect
|
||||
|
||||
def load_weights(self, pth_file):
|
||||
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
|
||||
# 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方)
|
||||
# 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv...
|
||||
# 这里做一个简单的兼容性处理
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
# 处理 ultralytics 权重字典中的 'model' 键
|
||||
if k == 'model':
|
||||
# 如果是完整的 checkpoint,权重在 'model' 键下
|
||||
# 且通常是 model.state_dict()
|
||||
if hasattr(v, 'state_dict'):
|
||||
v = v.state_dict()
|
||||
elif isinstance(v, dict):
|
||||
pass # v 就是 state_dict
|
||||
else:
|
||||
# 可能是 model 对象本身
|
||||
try:
|
||||
v = v.float().state_dict()
|
||||
except:
|
||||
continue
|
||||
|
||||
for sub_k, sub_v in v.items():
|
||||
new_state_dict[sub_k] = sub_v
|
||||
break
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
|
||||
if not new_state_dict:
|
||||
new_state_dict = state_dict
|
||||
|
||||
# 尝试加载
|
||||
try:
|
||||
self.load_state_dict(new_state_dict, strict=True)
|
||||
print(f"Successfully loaded weights from {pth_file}")
|
||||
except Exception as e:
|
||||
print(f"Error loading weights: {e}")
|
||||
print("Trying to load with strict=False...")
|
||||
self.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
class YOLO11E(nn.Module):
|
||||
def __init__(self, nc=80, scale='n'):
|
||||
super().__init__()
|
||||
self.nc = nc
|
||||
self.pe = None
|
||||
|
||||
# Scales: [depth, width, max_channels]
|
||||
# 对应 yolo11.yaml 中的 scales 参数
|
||||
scales = {
|
||||
'n': [0.50, 0.25, 1024],
|
||||
's': [0.50, 0.50, 1024],
|
||||
'm': [0.50, 1.00, 512],
|
||||
'l': [1.00, 1.00, 512],
|
||||
'x': [1.00, 1.50, 512],
|
||||
}
|
||||
|
||||
if scale not in scales:
|
||||
raise ValueError(f"Invalid scale '{scale}'. Available scales: {list(scales.keys())}")
|
||||
|
||||
depth, width, max_channels = scales[scale]
|
||||
|
||||
if scale in ['n', 's']:
|
||||
c3k_override = False
|
||||
else:
|
||||
c3k_override = True
|
||||
|
||||
# 辅助函数:计算通道数 (Width Scaling)
|
||||
def gw(channels):
|
||||
return make_divisible(min(channels, max_channels) * width, 8)
|
||||
|
||||
# 辅助函数:计算层重复次数 (Depth Scaling)
|
||||
def gd(n):
|
||||
return max(round(n * depth), 1) if n > 1 else n
|
||||
|
||||
self.model = nn.ModuleList()
|
||||
|
||||
# --- Backbone ---
|
||||
# 0: Conv [64, 3, 2]
|
||||
self.model.append(Conv(3, gw(64), 3, 2))
|
||||
|
||||
# 1: Conv [128, 3, 2]
|
||||
self.model.append(Conv(gw(64), gw(128), 3, 2))
|
||||
|
||||
# 2: C3k2 [256, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(128), gw(256), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 3: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 4: C3k2 [512, False, 0.25] -> n=2
|
||||
self.model.append(C3k2(gw(256), gw(512), n=gd(2), c3k=False or c3k_override, e=0.25))
|
||||
|
||||
# 5: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 6: C3k2 [512, True] -> n=2
|
||||
self.model.append(C3k2(gw(512), gw(512), n=gd(2), c3k=True))
|
||||
|
||||
# 7: Conv [1024, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(1024), 3, 2))
|
||||
|
||||
# 8: C3k2 [1024, True] -> n=2
|
||||
self.model.append(C3k2(gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 9: SPPF [1024, 5]
|
||||
self.model.append(SPPF(gw(1024), gw(1024), 5))
|
||||
|
||||
# 10: C2PSA [1024] -> n=2 (YAML args=[1024], repeats=2)
|
||||
self.model.append(C2PSA(gw(1024), gw(1024), n=gd(2)))
|
||||
|
||||
# --- Head ---
|
||||
|
||||
# 11: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 12: Concat [-1, 6] (P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 13: C3k2 [512, False] -> n=2. Input: P5_up(gw(1024)) + P4(gw(512))
|
||||
self.model.append(C3k2(gw(1024) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 14: Upsample
|
||||
self.model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
||||
|
||||
# 15: Concat [-1, 4] (P3)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 16: C3k2 [256, False] -> n=2. Input: P4_up(gw(512)) + P3(gw(512))
|
||||
# 注意:Layer 4 输出是 gw(512),Layer 13 输出是 gw(512)
|
||||
self.model.append(C3k2(gw(512) + gw(512), gw(256), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 17: Conv [256, 3, 2]
|
||||
self.model.append(Conv(gw(256), gw(256), 3, 2))
|
||||
|
||||
# 18: Concat [-1, 13] (Head P4)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 19: C3k2 [512, False] -> n=2. Input: P3_down(gw(256)) + Head_P4(gw(512))
|
||||
self.model.append(C3k2(gw(256) + gw(512), gw(512), n=gd(2), c3k=False or c3k_override))
|
||||
|
||||
# 20: Conv [512, 3, 2]
|
||||
self.model.append(Conv(gw(512), gw(512), 3, 2))
|
||||
|
||||
# 21: Concat [-1, 10] (P5)
|
||||
self.model.append(Concat(dimension=1))
|
||||
|
||||
# 22: C3k2 [1024, True] -> n=2. Input: P4_down(gw(512)) + P5(gw(1024))
|
||||
self.model.append(C3k2(gw(512) + gw(1024), gw(1024), n=gd(2), c3k=True))
|
||||
|
||||
# 23: Detect [nc]
|
||||
self.model.append(YOLOESegment(nc, ch=[gw(256), gw(512), gw(1024)]))
|
||||
|
||||
# --- 初始化权重 ---
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
"""初始化模型权重,特别是 Detect 头的 Bias"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
# 使用 Kaiming 初始化或其他合适的初始化
|
||||
pass
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
detect_layer = self.model[-1]
|
||||
if isinstance(detect_layer, Detect):
|
||||
detect_layer.bias_init()
|
||||
|
||||
def set_classes(self, names: List[str], embeddings: torch.Tensor):
|
||||
assert embeddings.ndim == 3, "Embeddings must be (1, N, D)"
|
||||
self.pe = embeddings
|
||||
self.model[-1].nc = len(names) # type: ignore
|
||||
self.nc = len(names)
|
||||
|
||||
def forward(self, x, tpe=None, vpe=None):
|
||||
# Backbone
|
||||
x = self.model[0](x)
|
||||
x = self.model[1](x)
|
||||
x = self.model[2](x)
|
||||
x = self.model[3](x)
|
||||
p3 = self.model[4](x) # 保存 P3 (layer 4)
|
||||
x = self.model[5](p3)
|
||||
p4 = self.model[6](x) # 保存 P4 (layer 6)
|
||||
x = self.model[7](p4)
|
||||
x = self.model[8](x)
|
||||
x = self.model[9](x)
|
||||
p5 = self.model[10](x) # 保存 P5 (layer 10)
|
||||
|
||||
# Head
|
||||
x = self.model[11](p5) # Upsample
|
||||
x = self.model[12]([x, p4]) # Concat P4
|
||||
h1 = self.model[13](x) # Head P4 (layer 13)
|
||||
|
||||
x = self.model[14](h1) # Upsample
|
||||
x = self.model[15]([x, p3]) # Concat P3
|
||||
h2 = self.model[16](x) # Output P3 (layer 16)
|
||||
|
||||
x = self.model[17](h2) # Conv
|
||||
x = self.model[18]([x, h1]) # Concat Head P4
|
||||
h3 = self.model[19](x) # Output P4 (layer 19)
|
||||
|
||||
x = self.model[20](h3) # Conv
|
||||
x = self.model[21]([x, p5]) # Concat P5
|
||||
h4 = self.model[22](x) # Output P5 (layer 22)
|
||||
|
||||
head = self.model[23]
|
||||
feats = [h2, h3, h4]
|
||||
|
||||
processed_tpe = head.get_tpe(tpe) # type: ignore
|
||||
|
||||
processed_vpe = head.get_vpe(feats, vpe) if vpe is not None else None # type: ignore
|
||||
|
||||
all_pe = []
|
||||
if processed_tpe is not None:
|
||||
all_pe.append(processed_tpe)
|
||||
if processed_vpe is not None:
|
||||
all_pe.append(processed_vpe)
|
||||
|
||||
if not all_pe:
|
||||
if self.pe is not None:
|
||||
all_pe.append(self.pe.to(device=x.device, dtype=x.dtype))
|
||||
else:
|
||||
all_pe.append(torch.zeros(1, self.nc, head.embed, device=x.device, dtype=x.dtype))
|
||||
|
||||
cls_pe = torch.cat(all_pe, dim=1)
|
||||
|
||||
b = x.shape[0]
|
||||
if cls_pe.shape[0] != b:
|
||||
cls_pe = cls_pe.expand(b, -1, -1)
|
||||
|
||||
return head(feats, cls_pe)
|
||||
|
||||
def load_weights(self, pth_file):
|
||||
state_dict = torch.load(pth_file, map_location='cpu', weights_only=False)
|
||||
# 移除可能存在的 'model.' 前缀 (如果权重来自 ultralytics 官方)
|
||||
# 官方权重通常是 model.model.0.conv... 这种格式,或者直接是 model.0.conv...
|
||||
# 这里做一个简单的兼容性处理
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
# 处理 ultralytics 权重字典中的 'model' 键
|
||||
if k == 'model':
|
||||
# 如果是完整的 checkpoint,权重在 'model' 键下
|
||||
# 且通常是 model.state_dict()
|
||||
if hasattr(v, 'state_dict'):
|
||||
v = v.state_dict()
|
||||
elif isinstance(v, dict):
|
||||
pass # v 就是 state_dict
|
||||
else:
|
||||
# 可能是 model 对象本身
|
||||
try:
|
||||
v = v.float().state_dict()
|
||||
except:
|
||||
continue
|
||||
|
||||
for sub_k, sub_v in v.items():
|
||||
new_state_dict[sub_k] = sub_v
|
||||
break
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
|
||||
if not new_state_dict:
|
||||
new_state_dict = state_dict
|
||||
|
||||
# 尝试加载
|
||||
try:
|
||||
self.load_state_dict(new_state_dict, strict=True)
|
||||
print(f"Successfully loaded weights from {pth_file}")
|
||||
except Exception as e:
|
||||
print(f"Error loading weights: {e}")
|
||||
print("Trying to load with strict=False...")
|
||||
self.load_state_dict(new_state_dict, strict=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = YOLO11E(nc=80, scale='l')
|
||||
model.load_weights("yoloe-11l-seg.pth")
|
||||
|
||||
# 模拟 set_classes
|
||||
# 假设我们有2个类,embedding维度是512
|
||||
fake_embeddings = torch.randn(1, 2, 512)
|
||||
model.set_classes(["class1", "class2"], fake_embeddings)
|
||||
|
||||
# 推理
|
||||
dummy_input = torch.randn(1, 3, 640, 640)
|
||||
model.eval()
|
||||
output = model(dummy_input)
|
||||
print("Output shape:", output[0].shape) # 应该是 (1, 4+mask_coeffs+num_classes, anchors)
|
||||
BIN
yolo11n.pth
(Stored with Git LFS)
Normal file
BIN
yolo11n.pth
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
yolo11s.pth
(Stored with Git LFS)
Normal file
BIN
yolo11s.pth
(Stored with Git LFS)
Normal file
Binary file not shown.
Reference in New Issue
Block a user