第一次提交Yolo项目
This commit is contained in:
63
mobileclip/image_encoder.py
Normal file
63
mobileclip/image_encoder.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
#
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
#
|
||||
from typing import Any
|
||||
|
||||
import torch.nn as nn
|
||||
from timm.models import create_model
|
||||
|
||||
from mobileclip import models # noqa added to register models
|
||||
from mobileclip.modules.image.image_projection import GlobalPool2D
|
||||
|
||||
|
||||
class MCi(nn.Module):
|
||||
"""This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_."""
|
||||
|
||||
def __init__(self, model_name: str, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.projection_dim = None
|
||||
if "projection_dim" in kwargs:
|
||||
self.projection_dim = kwargs.get("projection_dim")
|
||||
|
||||
# Create model
|
||||
self.model = create_model(model_name, projection_dim=self.projection_dim)
|
||||
|
||||
# Build out projection head.
|
||||
if self.projection_dim is not None:
|
||||
if hasattr(self.model, "head"):
|
||||
self.model.head = MCi._update_image_classifier(
|
||||
image_classifier=self.model.head, projection_dim=self.projection_dim
|
||||
)
|
||||
|
||||
def forward(self, x: Any, *args, **kwargs) -> Any:
|
||||
"""A forward function of the model."""
|
||||
x = self.model(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
|
||||
"""Return the input feature dimension to the image classification head."""
|
||||
in_features = None
|
||||
if isinstance(image_classifier, nn.Sequential):
|
||||
# Classifier that uses nn.Sequential usually has global pooling and
|
||||
# multiple linear layers. Find the first linear layer and get its
|
||||
# in_features
|
||||
for layer in image_classifier:
|
||||
if isinstance(layer, nn.Linear):
|
||||
in_features = layer.in_features
|
||||
break
|
||||
elif isinstance(image_classifier, nn.Linear):
|
||||
in_features = image_classifier.in_features
|
||||
|
||||
if in_features is None:
|
||||
raise NotImplementedError(f"Cannot get input feature dimension of {image_classifier}.")
|
||||
return in_features
|
||||
|
||||
@staticmethod
|
||||
def _update_image_classifier(image_classifier: nn.Module, projection_dim: int, *args, **kwargs) -> nn.Module:
|
||||
in_features = MCi._get_in_feature_dimension(image_classifier)
|
||||
new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
|
||||
return new_img_classifier
|
||||
Reference in New Issue
Block a user