# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

from __future__ import annotations

import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

from mobileclip.modules.common.mobileone import MobileOneBlock


class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        context_size: int,
        hidden_channels: int | None = None,
        out_channels: int | None = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ) -> None:
        """Build convolutional FFN module.

        Args:
            in_channels: Number of input channels.
            context_size: Context size for 1D signals.
            hidden_channels: Number of channels after expansion. Default: None
            out_channels: Number of output channels. Default: None
            act_layer: Activation layer. Default: ``GELU``
            drop: Dropout rate. Default: ``0.0``.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = nn.Sequential()
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, int(context_size)),
                padding=(0, int(context_size // 2)),
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize Conv2d weights with a truncated normal distribution and zero biases."""
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise convolutional mixing followed by the pointwise feed-forward layers."""
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


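# NOTE: The following usage sketch is illustrative and not part of the upstream MobileCLIP
# module. Sizes (batch=2, dim=64, context=77) are assumptions; it only shows the
# (B, dim, 1, context) layout that ConvFFN operates on.
def _convffn_usage_sketch() -> torch.Tensor:
    """Run ConvFFN on a dummy (B, dim, 1, context) tensor and return the output."""
    batch, dim, context = 2, 64, 77  # assumed sizes, e.g. a CLIP-style text context length
    ffn = ConvFFN(in_channels=dim, context_size=11, hidden_channels=4 * dim)
    x = torch.randn(batch, dim, 1, context)
    out = ffn(x)
    # The depthwise (1, 11) kernel with padding (0, 5) preserves the context width,
    # so the output shape matches the input: (2, 64, 1, 77).
    assert out.shape == x.shape
    return out

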
class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural
    Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
        use_layer_scale=True,
        layer_scale_init_value=1e-5,
        inference_mode: bool = False,
    ):
        """Build RepMixer Module.

        Args:
            dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
            kernel_size: Kernel size for spatial mixing. Default: 3
            use_layer_scale: If True, learnable layer scale is used. Default: ``True``
            layer_scale_init_value: Initial value for layer scale. Default: 1e-5
            inference_mode: If True, instantiates model in inference mode. Default: ``False``
        """
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.inference_mode = inference_mode

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=(1, self.kernel_size),
                stride=1,
                padding=(0, self.kernel_size // 2),
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                dim,
                dim,
                (1, kernel_size),
                padding=(0, kernel_size // 2),
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens, using the single fused convolution when the module is reparameterized."""
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """Reparameterize mixer and norm into a single convolutional layer for efficient inference."""
        if self.inference_mode:
            return

        self.mixer.reparameterize()
        self.norm.reparameterize()

        # Fold the residual and the layer scale into one depthwise kernel:
        # W = I + gamma * (W_mixer - W_norm), b = gamma * (b_mixer - b_norm),
        # so that reparam_conv(x) == x + gamma * (mixer(x) - norm(x)).
        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias)
        else:
            w = self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=(1, self.kernel_size),
            stride=1,
            padding=(0, self.kernel_size // 2),
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")
        if self.use_layer_scale:
            self.__delattr__("layer_scale")


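# NOTE: Illustrative sketch only (not part of the upstream module): in eval mode the fused
# kernel produced by reparameterize() should reproduce the training-time branch
# x + layer_scale * (mixer(x) - norm(x)) up to numerical tolerance. Sizes are assumptions.
def _repmixer_reparam_sketch() -> None:
    """Check that RepMixer.reparameterize() preserves outputs on a dummy input."""
    mixer = RepMixer(dim=64, kernel_size=3)
    mixer.eval()  # freeze BatchNorm statistics so the fused conv can match exactly
    x = torch.randn(2, 64, 1, 77)
    with torch.no_grad():
        y_multi_branch = mixer(x)
        mixer.reparameterize()  # folds mixer/norm into a single depthwise conv
        y_fused = mixer(x)
    assert torch.allclose(y_multi_branch, y_fused, atol=1e-5)

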
class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision
    <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 11,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
        *args,
        **kwargs,
    ):
        """Build RepMixer Block.

        Args:
            dim: Number of embedding dimensions.
            kernel_size: Kernel size for the RepMixer token mixer. Default: 11
            mlp_ratio: MLP expansion ratio. Default: 4.0
            act_layer: Activation layer. Default: ``nn.GELU``
            drop: Dropout rate. Default: 0.0
            drop_path: Drop path rate. Default: 0.0
            use_layer_scale: Flag to turn on layer scale. Default: ``True``
            layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
            inference_mode: Flag to instantiate block in inference mode. Default: ``False``
        """
        super().__init__()

        self.token_mixer = RepMixer(
            dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, f"MLP ratio should be greater than 0, found: {mlp_ratio}"
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            context_size=kernel_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # Drop Path
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Layer Scale
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True)

    def forward(self, x, *args, **kwargs):
        """Apply token mixing and the convolutional FFN to a (B, context, dim) sequence tensor."""
        if x.dim() == 3:
            # B, C, D --- where C is the context length
            # Convert to B, D, C --- to match RepMixer impl.
            x = x.permute(0, 2, 1)
            x = torch.unsqueeze(x, dim=2)
        else:
            raise ValueError(f"Expected tensor of dim=3, obtained tensor of dim={x.dim()}")

        if self.use_layer_scale:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.layer_scale * self.convffn(x))
        else:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.convffn(x))

        # Convert tensors back
        x = x.squeeze(dim=2).permute(0, 2, 1)
        return x


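# NOTE: Illustrative end-to-end sketch (not part of the upstream module). It assumes a
# hypothetical (batch=2, context=77, dim=64) token sequence and shows that RepMixerBlock
# consumes and returns tensors of shape (B, context, dim).
def _repmixerblock_usage_sketch() -> torch.Tensor:
    """Run RepMixerBlock on a dummy token sequence and return the output."""
    block = RepMixerBlock(dim=64, kernel_size=11, mlp_ratio=4.0).eval()
    tokens = torch.randn(2, 77, 64)  # (B, context, dim)
    with torch.no_grad():
        out = block(tokens)
        # For deployment, the token mixer can be folded into a single conv first:
        block.token_mixer.reparameterize()
        out_fused = block(tokens)
    assert out.shape == tokens.shape == out_fused.shape
    return out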