第一次提交Yolo项目
This commit is contained in:
12
mobileclip/models/__init__.py
Normal file
12
mobileclip/models/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All rights reserved.
|
||||
#
|
||||
from .mci import (
|
||||
mci0,
|
||||
mci1,
|
||||
mci2,
|
||||
)
|
||||
from .vit import vit_b16
|
||||
888
mobileclip/models/mci.py
Normal file
888
mobileclip/models/mci.py
Normal file
@@ -0,0 +1,888 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models import register_model
|
||||
from timm.models.layers import DropPath, trunc_normal_
|
||||
|
||||
from mobileclip.modules.common.mobileone import MobileOneBlock
|
||||
from mobileclip.modules.image.replknet import ReparamLargeKernelConv
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 256, 256),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.95,
|
||||
"interpolation": "bicubic",
|
||||
"mean": IMAGENET_DEFAULT_MEAN,
|
||||
"std": IMAGENET_DEFAULT_STD,
|
||||
"classifier": "head",
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
"fastvit_t": _cfg(crop_pct=0.9),
|
||||
"fastvit_s": _cfg(crop_pct=0.9),
|
||||
"fastvit_m": _cfg(crop_pct=0.95),
|
||||
}
|
||||
|
||||
|
||||
def convolutional_stem(in_channels: int, out_channels: int, inference_mode: bool = False) -> nn.Sequential:
|
||||
"""Build convolutional stem with MobileOne blocks.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
|
||||
Returns:
|
||||
nn.Sequential object with stem elements.
|
||||
"""
|
||||
return nn.Sequential(
|
||||
MobileOneBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
MobileOneBlock(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
"""Multi-headed Self Attention module.
|
||||
|
||||
Source modified from:
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
head_dim: int = 32,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build MHSA module that can handle 3D or 4D input tensors.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
head_dim: Number of hidden dimensions per head. Default: ``32``
|
||||
qkv_bias: Use bias or not. Default: ``False``
|
||||
attn_drop: Dropout rate for attention tensor.
|
||||
proj_drop: Dropout rate for projection tensor.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % head_dim == 0, "dim should be divisible by head_dim"
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = dim // head_dim
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shape = x.shape
|
||||
B, C, H, W = shape
|
||||
N = H * W
|
||||
if len(shape) == 4:
|
||||
x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# trick here to make q@k.t more stable
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
if len(shape) == 4:
|
||||
x = x.transpose(-2, -1).reshape(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Convolutional patch embedding layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
stride: int,
|
||||
in_channels: int,
|
||||
embed_dim: int,
|
||||
inference_mode: bool = False,
|
||||
use_se: bool = False,
|
||||
) -> None:
|
||||
"""Build patch embedding layer.
|
||||
|
||||
Args:
|
||||
patch_size: Patch size for embedding computation.
|
||||
stride: Stride for convolutional embedding layer.
|
||||
in_channels: Number of channels of input tensor.
|
||||
embed_dim: Number of embedding dimensions.
|
||||
inference_mode: Flag to instantiate model in inference mode. Default: ``False``
|
||||
use_se: If ``True`` SE block will be used.
|
||||
"""
|
||||
super().__init__()
|
||||
block = list()
|
||||
block.append(
|
||||
ReparamLargeKernelConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
groups=in_channels,
|
||||
small_kernel=3,
|
||||
inference_mode=inference_mode,
|
||||
use_se=use_se,
|
||||
)
|
||||
)
|
||||
block.append(
|
||||
MobileOneBlock(
|
||||
in_channels=embed_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
inference_mode=inference_mode,
|
||||
use_se=False,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
)
|
||||
self.proj = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepMixer(nn.Module):
|
||||
"""Reparameterizable token mixer.
|
||||
|
||||
For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
|
||||
Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=3,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Module.
|
||||
|
||||
Args:
|
||||
dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
|
||||
kernel_size: Kernel size for spatial mixing. Default: 3
|
||||
use_layer_scale: If True, learnable layer scale is used. Default: ``True``
|
||||
layer_scale_init_value: Initial value for layer scale. Default: 1e-5
|
||||
inference_mode: If True, instantiates model in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.kernel_size = kernel_size
|
||||
self.inference_mode = inference_mode
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.norm = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
use_scale_branch=False,
|
||||
num_conv_branches=0,
|
||||
)
|
||||
self.mixer = MobileOneBlock(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
groups=dim,
|
||||
use_act=False,
|
||||
)
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
if self.use_layer_scale:
|
||||
x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
|
||||
else:
|
||||
x = x + self.mixer(x) - self.norm(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
"""Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
|
||||
if self.inference_mode:
|
||||
return
|
||||
|
||||
self.mixer.reparameterize()
|
||||
self.norm.reparameterize()
|
||||
|
||||
if self.use_layer_scale:
|
||||
w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
|
||||
self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
)
|
||||
b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
|
||||
else:
|
||||
w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
|
||||
b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
|
||||
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.dim,
|
||||
out_channels=self.dim,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=1,
|
||||
padding=self.kernel_size // 2,
|
||||
groups=self.dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w
|
||||
self.reparam_conv.bias.data = b
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("mixer")
|
||||
self.__delattr__("norm")
|
||||
if self.use_layer_scale:
|
||||
self.__delattr__("layer_scale")
|
||||
|
||||
|
||||
class ConvFFN(nn.Module):
|
||||
"""Convolutional FFN Module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int | None = None,
|
||||
out_channels: int | None = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
"""Build convolutional FFN module.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
hidden_channels: Number of channels after expansion. Default: None
|
||||
out_channels: Number of output channels. Default: None
|
||||
act_layer: Activation layer. Default: ``GELU``
|
||||
drop: Dropout rate. Default: ``0.0``.
|
||||
"""
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
self.conv = nn.Sequential()
|
||||
self.conv.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
)
|
||||
self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
|
||||
self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: nn.Module) -> None:
|
||||
if isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RepCPE(nn.Module):
|
||||
"""Implementation of conditional positional encoding.
|
||||
|
||||
For more details refer to paper: `Conditional Positional Encodings for Vision Transformers
|
||||
<https://arxiv.org/pdf/2102.10882.pdf>`_
|
||||
|
||||
In our implementation, we can reparameterize this module to eliminate a skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
embed_dim: int = 768,
|
||||
spatial_shape: int | tuple[int, int] = (7, 7),
|
||||
inference_mode=False,
|
||||
) -> None:
|
||||
"""Build reparameterizable conditional positional encoding.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
embed_dim: Number of embedding dimensions. Default: 768
|
||||
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
if isinstance(spatial_shape, int):
|
||||
spatial_shape = tuple([spatial_shape] * 2)
|
||||
assert isinstance(spatial_shape, tuple), (
|
||||
f'"spatial_shape" must by a sequence or int, get {type(spatial_shape)} instead.'
|
||||
)
|
||||
assert len(spatial_shape) == 2, f'Length of "spatial_shape" should be 2, got {len(spatial_shape)} instead.'
|
||||
|
||||
self.spatial_shape = spatial_shape
|
||||
self.embed_dim = embed_dim
|
||||
self.in_channels = in_channels
|
||||
self.groups = embed_dim
|
||||
|
||||
if inference_mode:
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
else:
|
||||
self.pe = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
spatial_shape,
|
||||
1,
|
||||
int(spatial_shape[0] // 2),
|
||||
bias=True,
|
||||
groups=embed_dim,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self, "reparam_conv"):
|
||||
x = self.reparam_conv(x)
|
||||
return x
|
||||
else:
|
||||
x = self.pe(x) + x
|
||||
return x
|
||||
|
||||
def reparameterize(self) -> None:
|
||||
# Build equivalent Id tensor
|
||||
input_dim = self.in_channels // self.groups
|
||||
kernel_value = torch.zeros(
|
||||
(
|
||||
self.in_channels,
|
||||
input_dim,
|
||||
self.spatial_shape[0],
|
||||
self.spatial_shape[1],
|
||||
),
|
||||
dtype=self.pe.weight.dtype,
|
||||
device=self.pe.weight.device,
|
||||
)
|
||||
for i in range(self.in_channels):
|
||||
kernel_value[
|
||||
i,
|
||||
i % input_dim,
|
||||
self.spatial_shape[0] // 2,
|
||||
self.spatial_shape[1] // 2,
|
||||
] = 1
|
||||
id_tensor = kernel_value
|
||||
|
||||
# Reparameterize Id tensor and conv
|
||||
w_final = id_tensor + self.pe.weight
|
||||
b_final = self.pe.bias
|
||||
|
||||
# Introduce reparam conv
|
||||
self.reparam_conv = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.spatial_shape,
|
||||
stride=1,
|
||||
padding=int(self.spatial_shape[0] // 2),
|
||||
groups=self.embed_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.reparam_conv.weight.data = w_final
|
||||
self.reparam_conv.bias.data = b_final
|
||||
|
||||
for para in self.parameters():
|
||||
para.detach_()
|
||||
self.__delattr__("pe")
|
||||
|
||||
|
||||
class RepMixerBlock(nn.Module):
|
||||
"""Implementation of Metaformer block with RepMixer as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode: bool = False,
|
||||
):
|
||||
"""Build RepMixer Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
kernel_size: Kernel size for repmixer. Default: 3
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.token_mixer = RepMixer(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop Path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.layer_scale * self.convffn(x))
|
||||
else:
|
||||
x = self.token_mixer(x)
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Implementation of metaformer block with MHSA as token mixer.
|
||||
|
||||
For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
|
||||
<https://arxiv.org/pdf/2111.11418.pdf>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
):
|
||||
"""Build Attention Block.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
mlp_ratio: MLP expansion ratio. Default: 4.0
|
||||
act_layer: Activation layer. Default: ``nn.GELU``
|
||||
norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
|
||||
drop: Dropout rate. Default: 0.0
|
||||
drop_path: Drop path rate. Default: 0.0
|
||||
use_layer_scale: Flag to turn on layer scale. Default: ``True``
|
||||
layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.norm = norm_layer(dim)
|
||||
self.token_mixer = MHSA(dim=dim)
|
||||
|
||||
assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.convffn = ConvFFN(
|
||||
in_channels=dim,
|
||||
hidden_channels=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
# Drop path
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
# Layer Scale
|
||||
self.use_layer_scale = use_layer_scale
|
||||
if use_layer_scale:
|
||||
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
|
||||
else:
|
||||
x = x + self.drop_path(self.token_mixer(self.norm(x)))
|
||||
x = x + self.drop_path(self.convffn(x))
|
||||
return x
|
||||
|
||||
|
||||
def basic_blocks(
|
||||
dim: int,
|
||||
block_index: int,
|
||||
num_blocks: list[int],
|
||||
token_mixer_type: str,
|
||||
kernel_size: int = 3,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
use_layer_scale: bool = True,
|
||||
layer_scale_init_value: float = 1e-5,
|
||||
inference_mode=False,
|
||||
) -> nn.Sequential:
|
||||
"""Build FastViT blocks within a stage.
|
||||
|
||||
Args:
|
||||
dim: Number of embedding dimensions.
|
||||
block_index: block index.
|
||||
num_blocks: List containing number of blocks per stage.
|
||||
token_mixer_type: Token mixer type.
|
||||
kernel_size: Kernel size for repmixer.
|
||||
mlp_ratio: MLP expansion ratio.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Drop path rate.
|
||||
use_layer_scale: Flag to turn on layer scale regularization.
|
||||
layer_scale_init_value: Layer scale value at initialization.
|
||||
inference_mode: Flag to instantiate block in inference mode.
|
||||
|
||||
Returns:
|
||||
nn.Sequential object of all the blocks within the stage.
|
||||
"""
|
||||
blocks = []
|
||||
for block_idx in range(num_blocks[block_index]):
|
||||
block_dpr = drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1)
|
||||
if token_mixer_type == "repmixer":
|
||||
blocks.append(
|
||||
RepMixerBlock(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
)
|
||||
elif token_mixer_type == "attention":
|
||||
blocks.append(
|
||||
AttentionBlock(
|
||||
dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Token mixer type: {token_mixer_type} not supported")
|
||||
blocks = nn.Sequential(*blocks)
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
class FastViT(nn.Module):
|
||||
"""This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layers,
|
||||
token_mixers: tuple[str, ...],
|
||||
embed_dims=None,
|
||||
mlp_ratios=None,
|
||||
downsamples=None,
|
||||
se_downsamples=None,
|
||||
repmixer_kernel_size=3,
|
||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
num_classes=1000,
|
||||
pos_embs=None,
|
||||
down_patch_size=7,
|
||||
down_stride=2,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_layer_scale=True,
|
||||
layer_scale_init_value=1e-5,
|
||||
init_cfg=None,
|
||||
pretrained=None,
|
||||
cls_ratio=2.0,
|
||||
inference_mode=False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
if pos_embs is None:
|
||||
pos_embs = [None] * len(layers)
|
||||
|
||||
if se_downsamples is None:
|
||||
se_downsamples = [False] * len(layers)
|
||||
|
||||
# Convolutional stem
|
||||
self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode)
|
||||
|
||||
# Build the main stages of the network architecture
|
||||
network = []
|
||||
for i in range(len(layers)):
|
||||
# Add position embeddings if requested
|
||||
if pos_embs[i] is not None:
|
||||
network.append(pos_embs[i](embed_dims[i], embed_dims[i], inference_mode=inference_mode))
|
||||
stage = basic_blocks(
|
||||
embed_dims[i],
|
||||
i,
|
||||
layers,
|
||||
token_mixer_type=token_mixers[i],
|
||||
kernel_size=repmixer_kernel_size,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_layer_scale=use_layer_scale,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
inference_mode=inference_mode,
|
||||
)
|
||||
network.append(stage)
|
||||
if i >= len(layers) - 1:
|
||||
break
|
||||
|
||||
# Patch merging/downsampling between stages.
|
||||
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
|
||||
network.append(
|
||||
PatchEmbed(
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
in_channels=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=se_downsamples[i + 1],
|
||||
)
|
||||
)
|
||||
self.network = nn.ModuleList(network)
|
||||
|
||||
# Classifier head
|
||||
self.conv_exp = MobileOneBlock(
|
||||
in_channels=embed_dims[-1],
|
||||
out_channels=int(embed_dims[-1] * cls_ratio),
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=embed_dims[-1],
|
||||
inference_mode=inference_mode,
|
||||
use_se=True,
|
||||
num_conv_branches=1,
|
||||
)
|
||||
self.head = nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.apply(self.cls_init_weights)
|
||||
self.init_cfg = copy.deepcopy(init_cfg)
|
||||
|
||||
def cls_init_weights(self, m: nn.Module) -> None:
|
||||
"""Init.
|
||||
|
||||
for classification.
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
return x
|
||||
|
||||
def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
# output only the features of last layer for image classification
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# input embedding
|
||||
x = self.forward_embeddings(x)
|
||||
# through backbone
|
||||
x = self.forward_tokens(x)
|
||||
# for image classification
|
||||
x = self.conv_exp(x)
|
||||
cls_out = self.head(x)
|
||||
return cls_out
|
||||
|
||||
|
||||
@register_model
|
||||
def mci0(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi0 model variant."""
|
||||
layers = [2, 6, 10, 2]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci1(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi1 model variant."""
|
||||
layers = [4, 12, 20, 4]
|
||||
embed_dims = [64, 128, 256, 512]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_s"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mci2(pretrained=False, **kwargs):
|
||||
"""Instantiate MCi2 model variant."""
|
||||
layers = [4, 12, 24, 4]
|
||||
embed_dims = [80, 160, 320, 640]
|
||||
mlp_ratios = [3, 3, 3, 3]
|
||||
downsamples = [True, True, True, True]
|
||||
se_downsamples = [False, False, True, True]
|
||||
pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]
|
||||
token_mixers = ("repmixer", "repmixer", "repmixer", "attention")
|
||||
model = FastViT(
|
||||
layers,
|
||||
token_mixers=token_mixers,
|
||||
embed_dims=embed_dims,
|
||||
pos_embs=pos_embs,
|
||||
mlp_ratios=mlp_ratios,
|
||||
downsamples=downsamples,
|
||||
se_downsamples=se_downsamples,
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = default_cfgs["fastvit_m"]
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
389
mobileclip/models/vit.py
Normal file
389
mobileclip/models/vit.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
"""
|
||||
Implementation of the following modules is borrowed from ml-cvnets repo:
|
||||
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py.
|
||||
|
||||
Please see ACKNOWLEDGMENTS for license details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from timm.models import register_model
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mobileclip import logger
|
||||
from mobileclip.modules.common.transformer import (
|
||||
PositionalEmbedding,
|
||||
TransformerEncoder,
|
||||
get_normalization_layer,
|
||||
)
|
||||
from mobileclip.modules.image.image_projection import SimpleImageProjectionHead
|
||||
|
||||
|
||||
class ConvNormAct(nn.Module):
|
||||
"""Applies an N-dimensional convolution over an input.
|
||||
|
||||
Args:
|
||||
cfg: Model configuration.
|
||||
in_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
out_channels: :math:`C_{out}` from an expected output of size :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
|
||||
stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
|
||||
dilation: Dilation rate for convolution. An integer, or tuple of length ``N``. Default: ``1``.
|
||||
padding: Padding for convolution. An integer, or tuple of length ``N``. If not specified, padding is
|
||||
automatically computed based on kernel size and dilation range. Default : ``None`` (equivalent to ``[
|
||||
int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
|
||||
groups: Number of groups in convolution. Default: ``1``.
|
||||
bias: Use bias. Default: ``False``.
|
||||
padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular'). Default: ``zeros``.
|
||||
use_norm: Use normalization layer after convolution. Default: ``True``.
|
||||
use_act: Use activation layer after convolution (or convolution and normalization). Default: ``True``.
|
||||
norm_layer: If not None, the provided normalization layer object will be used. Otherwise, a normalization object
|
||||
will be created based on config ``model.normalization.*`` opts.
|
||||
act_layer: If not None, the provided activation function will be used. Otherwise, an activation function will be
|
||||
created based on config ``model.activation.*`` opts.
|
||||
|
||||
Notes:
|
||||
- Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
|
||||
- Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
|
||||
- For depth-wise convolution, `groups=C_{in}=C_{out}`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: dict,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] | None = None,
|
||||
groups: int = 1,
|
||||
bias: bool = False,
|
||||
padding_mode: str = "zeros",
|
||||
use_norm: bool = True,
|
||||
use_act: bool = True,
|
||||
norm_layer: nn.Module | None = None,
|
||||
act_layer: nn.Module | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ndim = 2
|
||||
|
||||
if norm_layer is None and use_norm:
|
||||
norm_type = cfg.get("normalization", "batch_norm")
|
||||
if norm_type == "batch_norm":
|
||||
norm_layer = nn.BatchNorm2d(
|
||||
num_features=out_channels,
|
||||
momentum=cfg.get("momentum", 0.1),
|
||||
)
|
||||
else:
|
||||
norm_layer = get_normalization_layer(num_features=out_channels, norm_type=norm_type)
|
||||
elif norm_layer is not None and use_norm:
|
||||
logger.error(f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided.")
|
||||
|
||||
if act_layer is None and use_act:
|
||||
act_layer = nn.GELU() # Default to GELU
|
||||
elif act_layer is not None and use_act:
|
||||
logger.error(f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided.")
|
||||
|
||||
if use_norm and any(param[0] == "bias" for param in norm_layer.named_parameters()) and bias:
|
||||
assert not bias, "Do not use bias when using normalization layers with bias."
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * self.ndim
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride,) * self.ndim
|
||||
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation,) * self.ndim
|
||||
|
||||
assert isinstance(kernel_size, tuple)
|
||||
assert isinstance(stride, tuple)
|
||||
assert isinstance(dilation, tuple)
|
||||
|
||||
if padding is None:
|
||||
padding = (int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim))
|
||||
|
||||
if in_channels % groups != 0:
|
||||
logger.error(f"Input channels are not divisible by groups. {in_channels}%{groups} != 0 ")
|
||||
if out_channels % groups != 0:
|
||||
logger.error(f"Output channels are not divisible by groups. {out_channels}%{groups} != 0 ")
|
||||
|
||||
block = nn.Sequential()
|
||||
|
||||
conv_layer = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, # type: ignore
|
||||
stride=stride, # type: ignore
|
||||
padding=padding,
|
||||
dilation=dilation, # type: ignore
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
block.add_module(name="conv", module=conv_layer)
|
||||
|
||||
self.norm_name = None
|
||||
if use_norm:
|
||||
block.add_module(name="norm", module=norm_layer)
|
||||
self.norm_name = norm_layer.__class__.__name__
|
||||
|
||||
self.act_name = None
|
||||
if use_act:
|
||||
block.add_module(name="act", module=act_layer)
|
||||
self.act_name = act_layer.__class__.__name__
|
||||
|
||||
self.block = block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
self.groups = groups
|
||||
self.kernel_size = conv_layer.kernel_size
|
||||
self.bias = bias
|
||||
self.dilation = dilation
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model
|
||||
implementation is inspired from `Early Convolutions Help Transformers See
|
||||
Better <https://arxiv.org/abs/2106.14881>`_.
|
||||
|
||||
.. note::
|
||||
Our implementation is different from the original implementation in two ways:
|
||||
1. Kernel size is odd.
|
||||
2. Our positional encoding implementation allows us to use ViT with any multiple input scales
|
||||
3. We do not use StochasticDepth
|
||||
4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
image_channels = 3
|
||||
num_classes = cfg.get("n_classes", 1000)
|
||||
|
||||
self.projection_dim = None
|
||||
if "projection_dim" in kwargs:
|
||||
self.projection_dim = kwargs.get("projection_dim")
|
||||
|
||||
kernel_sizes_conv_stem = [4, 2, 2]
|
||||
strides_conv_stem = [4, 2, 2]
|
||||
|
||||
# Typically, in the ImageNet dataset, we use 224x224 as a resolution.
|
||||
# For out ViT implementation, patch size is 16 (16 = 4 * 2 * 2)
|
||||
# Therefore, total number of embeddings along width and height are (224 / 16)^2
|
||||
num_embeddings = (224 // 16) ** 2
|
||||
|
||||
embed_dim = cfg["embed_dim"]
|
||||
ffn_dim = cfg["embed_dim"] * 4
|
||||
pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
|
||||
n_transformer_layers = cfg["n_transformer_layers"]
|
||||
num_heads = cfg["n_attn_heads"]
|
||||
attn_dropout = cfg.get("attn_dropout", 0.0)
|
||||
dropout = cfg.get("dropout", 0.0)
|
||||
ffn_dropout = cfg.get("ffn_dropout", 0.0)
|
||||
norm_layer = cfg.get("norm_layer", "layer_norm")
|
||||
|
||||
conv_stem_proj_dim = max(32, embed_dim // 4)
|
||||
patch_emb = [
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=image_channels,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[0],
|
||||
stride=strides_conv_stem[0],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=conv_stem_proj_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[1],
|
||||
stride=strides_conv_stem[1],
|
||||
bias=False,
|
||||
use_norm=True,
|
||||
use_act=True,
|
||||
),
|
||||
ConvNormAct(
|
||||
cfg=cfg,
|
||||
in_channels=conv_stem_proj_dim,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=kernel_sizes_conv_stem[2],
|
||||
stride=strides_conv_stem[2],
|
||||
bias=True,
|
||||
use_norm=False,
|
||||
use_act=False,
|
||||
),
|
||||
]
|
||||
|
||||
self.patch_emb = nn.Sequential(*patch_emb)
|
||||
|
||||
use_cls_token = not cfg.get("no_cls_token", False)
|
||||
stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
|
||||
per_layer_stochastic_drop_rate = [round(x, 3) for x in np.linspace(0, stochastic_dropout, n_transformer_layers)]
|
||||
transformer_blocks = [
|
||||
TransformerEncoder(
|
||||
embed_dim=embed_dim,
|
||||
ffn_latent_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
attn_dropout=attn_dropout,
|
||||
dropout=dropout,
|
||||
ffn_dropout=ffn_dropout,
|
||||
transformer_norm_layer=norm_layer,
|
||||
stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
|
||||
)
|
||||
for layer_idx in range(n_transformer_layers)
|
||||
]
|
||||
|
||||
self.post_transformer_norm = get_normalization_layer(num_features=embed_dim, norm_type=norm_layer)
|
||||
|
||||
self.transformer = nn.Sequential(*transformer_blocks)
|
||||
|
||||
if self.projection_dim is None:
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
else:
|
||||
self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)
|
||||
|
||||
if use_cls_token:
|
||||
self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
|
||||
torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
|
||||
else:
|
||||
self.cls_token = None
|
||||
|
||||
self.pos_embed = PositionalEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embed_dim,
|
||||
padding_idx=None,
|
||||
interpolation_mode="bilinear",
|
||||
)
|
||||
self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)
|
||||
|
||||
def extract_patch_embeddings(self, x: Tensor) -> tuple[Tensor, tuple[int, int]]:
|
||||
# input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
|
||||
patch_emb = self.patch_emb(x)
|
||||
n_h, n_w = patch_emb.shape[-2:]
|
||||
|
||||
# [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
|
||||
patch_emb = patch_emb.flatten(2)
|
||||
# [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
|
||||
patch_emb = patch_emb.transpose(1, 2).contiguous()
|
||||
|
||||
n_patches = patch_emb.shape[1]
|
||||
# we resize the positional encodings dynamically.
|
||||
pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)
|
||||
|
||||
# add positional encodings
|
||||
patch_emb = pos_emb + patch_emb
|
||||
|
||||
# add classification token
|
||||
if self.cls_token is not None:
|
||||
# [1, 1, emb_dim] --> [Batch, 1, emb_dim]
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
# Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
|
||||
patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)
|
||||
|
||||
# dropout
|
||||
patch_emb = self.emb_dropout(patch_emb)
|
||||
return patch_emb, (n_h, n_w)
|
||||
|
||||
def _features_from_transformer(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, tuple[int, int]]:
|
||||
# this function extract patch embeddings and then apply transformer module to learn
|
||||
# inter-patch representations
|
||||
|
||||
# [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
|
||||
# and embed_dim is feature dim
|
||||
x, (n_h, n_w) = self.extract_patch_embeddings(x)
|
||||
|
||||
for layer in self.transformer:
|
||||
x = layer(x)
|
||||
x = self.post_transformer_norm(x)
|
||||
|
||||
return x, (n_h, n_w)
|
||||
|
||||
def extract_features(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor | None]:
|
||||
# The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
|
||||
# and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
|
||||
return_image_embeddings = kwargs.get("return_image_embeddings", False)
|
||||
|
||||
# [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
|
||||
# here, B is batch size, C is input channels
|
||||
# H and W are input height and width
|
||||
# N is the number of pixels (or tokens) after processing input with conv stem and reshaping
|
||||
# We add +1 for cls token (if applicable)
|
||||
# embed_dim --> embedding dimension
|
||||
x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)
|
||||
|
||||
if self.cls_token is not None:
|
||||
# [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
|
||||
cls_embedding, image_embedding = torch.split(x, split_size_or_sections=[1, x.shape[1] - 1], dim=1)
|
||||
cls_embedding = cls_embedding.squeeze(1)
|
||||
else:
|
||||
# [B, N, embed_dim] -> [B, embed_dim]
|
||||
cls_embedding = torch.mean(x, dim=1)
|
||||
# [B, N, embed_dim]
|
||||
image_embedding = x
|
||||
|
||||
if return_image_embeddings:
|
||||
# reshape image embedding to 4-D tensor
|
||||
# [B, N, C] --> [B, C, N]
|
||||
image_embedding = image_embedding.transpose(1, 2).contiguous()
|
||||
image_embedding = image_embedding.reshape(image_embedding.shape[0], -1, n_h, n_w)
|
||||
|
||||
return cls_embedding, image_embedding
|
||||
else:
|
||||
return cls_embedding, None
|
||||
|
||||
def forward_classifier(self, x: Tensor, *args, **kwargs) -> tuple[Tensor, Tensor]:
|
||||
cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
|
||||
# classify based on CLS token
|
||||
cls_embedding = self.classifier(cls_embedding)
|
||||
return cls_embedding, image_embedding
|
||||
|
||||
def forward(self, x: Tensor, *args, **kwargs) -> Tensor | dict[str, Tensor]:
|
||||
# In ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
|
||||
# To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
|
||||
if kwargs.get("return_image_embeddings", False):
|
||||
out_dict = dict()
|
||||
prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
|
||||
out_dict.update({"logits": prediction})
|
||||
if image_embedding is not None:
|
||||
out_dict.update({"image_embeddings": image_embedding})
|
||||
return out_dict
|
||||
else:
|
||||
prediction, _ = self.forward_classifier(x, *args, **kwargs)
|
||||
return prediction
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_b16(pretrained=False, **kwargs):
|
||||
# Vision transformer config
|
||||
cfg = {
|
||||
"norm_layer": "layer_norm_fp32",
|
||||
"act_layer": "gelu",
|
||||
"embed_dim": 768,
|
||||
"n_transformer_layers": 12,
|
||||
"n_attn_heads": 12,
|
||||
}
|
||||
model = VisionTransformer(cfg=cfg, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError("Functionality not implemented.")
|
||||
return model
|
||||
Reference in New Issue
Block a user