64 lines
2.4 KiB
Python
64 lines
2.4 KiB
Python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
|
|
from typing import Any
|
|
|
|
import torch.nn as nn
|
|
from timm.models import create_model
|
|
|
|
from mobileclip import models # noqa added to register models
|
|
from mobileclip.modules.image.image_projection import GlobalPool2D
|
|
|
|
|
|
class MCi(nn.Module):
    """This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_."""

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        """Create the backbone via ``create_model`` and optionally swap its head.

        Args:
            model_name: Name forwarded to ``create_model``.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Only ``projection_dim`` is read; every other keyword is
                ignored. When ``projection_dim`` is absent it is treated as
                ``None``.
        """
        super().__init__()
        # A missing key and an explicit ``projection_dim=None`` are equivalent.
        self.projection_dim = kwargs.get("projection_dim")

        # Create model
        self.model = create_model(model_name, projection_dim=self.projection_dim)

        # Build out projection head: only when a projection size was requested
        # and the created model actually exposes a ``head`` attribute.
        if self.projection_dim is not None and hasattr(self.model, "head"):
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Run ``x`` through the wrapped model; extra arguments are ignored."""
        return self.model(x)

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input feature dimension to the image classification head.

        Args:
            image_classifier: Either an ``nn.Linear`` or an ``nn.Sequential``
                that contains an ``nn.Linear`` among its direct children.

        Returns:
            ``in_features`` of the linear layer (the first one found when a
            sequential container is given).

        Raises:
            NotImplementedError: If no linear layer can be located.
        """
        in_features = None
        if isinstance(image_classifier, nn.Sequential):
            # Classifier that uses nn.Sequential usually has global pooling and
            # multiple linear layers. Find the first linear layer and get its
            # in_features. Iterating a Sequential visits direct children only.
            in_features = next(
                (layer.in_features for layer in image_classifier if isinstance(layer, nn.Linear)),
                None,
            )
        elif isinstance(image_classifier, nn.Linear):
            in_features = image_classifier.in_features

        if in_features is None:
            raise NotImplementedError(f"Cannot get input feature dimension of {image_classifier}.")
        return in_features

    @staticmethod
    def _update_image_classifier(image_classifier: nn.Module, projection_dim: int, *args, **kwargs) -> nn.Module:
        """Build a ``GlobalPool2D`` head sized from the existing classifier.

        Args:
            image_classifier: Current head; only its input feature dimension is
                read (see ``_get_in_feature_dimension``).
            projection_dim: Passed to ``GlobalPool2D`` as ``out_dim``.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            A new ``GlobalPool2D(in_dim=<classifier in_features>, out_dim=projection_dim)``.

        Raises:
            NotImplementedError: Propagated when the input feature dimension of
                ``image_classifier`` cannot be determined.
        """
        in_features = MCi._get_in_feature_dimension(image_classifier)
        return GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
|