First commit of the Yolo project
218 mobileclip/text_encoder.py Normal file
@@ -0,0 +1,218 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

from __future__ import annotations

import math
from collections.abc import Sequence

import torch
from torch import Tensor, nn

from mobileclip import logger
from mobileclip.modules.common.transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)
from mobileclip.modules.text.repmixer import RepMixerBlock


class TextTransformer(nn.Module):
    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        super().__init__()

        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim
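
        # Config keys read by this constructor: required -- "dim", "norm_layer", "model_name",
        # "vocab_size", "context_length", "n_transformer_layers", "ffn_multiplier_per_layer",
        # "n_heads_per_layer", "causal_masking"; optional with defaults -- "no_scale_embedding",
        # "no_pos_embedding", "embed_dropout".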

        # Token embedding layer
        self.embedding_layer = nn.Embedding(embedding_dim=model_dim, num_embeddings=self.vocab_size)
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        # Context length
        context_length = cfg["context_length"]
        assert context_length is not None, "Context length can't be None. Please set value accordingly."

        self.positional_embedding = (
            None if no_pos_embedding else PositionalEmbedding(num_embeddings=context_length, embedding_dim=model_dim)
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)

        # Transformer layer
        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers for transformer layer
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects FFN multipliers as a list, whose length is the same as"
                f" number of transformer layers. Got: {type(ffn_multipliers)}"
            )
        elif isinstance(ffn_multipliers, Sequence) and len(ffn_multipliers) != n_transformer_layers:
            logger.error(
                f"We need FFN multiplier for each transformer layer. Got {len(ffn_multipliers)} ffn"
                f" multipliers while number of transformer layers = {n_transformer_layers}"
            )
        ffn_dims = [int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) for ffn_mult in ffn_multipliers]
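        # ffn_dims rounds each FFN hidden size up to a multiple of 16; with illustrative numbers,
        # model_dim=512 and ffn_mult=3.1 give 512 * 3.1 = 1587.2, which rounds up to 1600.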

        # Heads for transformer layers
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects MHA heads as a list, whose length is the same as number of "
                f"transformer layers. Got: {type(mha_heads)}"
            )
        elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers:
            logger.error(
                f"{self.__class__.__name__} needs MHA heads for each transformer layer. Got {len(mha_heads)} mha heads while"
                f" number of transformer layers = {n_transformer_layers}"
            )
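
        # Two layouts: "base" stacks n_transformer_layers TransformerEncoder blocks; "mct" wraps the
        # same stack between two convolutional RepMixerBlock token mixers (1 + N + 1 modules in total).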
        if variant == "base":
            self.transformer = nn.ModuleList(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
        elif variant == "mct":
            self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
            self.transformer.extend(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
            self.transformer.extend([RepMixerBlock(dim=model_dim)])
        else:
            raise ValueError(f"Unrecognized text encoder variant {variant}")

        self.final_layer_norm = get_normalization_layer(num_features=model_dim, norm_type=norm_layer)

        self.projection_layer = nn.Parameter(torch.empty(model_dim, self.projection_dim))
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embedding for all tokens.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]

        Returns:
            A tensor of [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(token_emb.dtype)
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    @staticmethod
    @torch.jit.script  # use scripting to avoid device constant
    def build_attention_mask(text_tokens: torch.Tensor) -> Tensor:
        """Build causal attention mask [batch_size, context_length, context_length]."""
        # Build mask with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        batch_size, context_length = text_tokens.shape
        mask = torch.empty(context_length, context_length, device=text_tokens.device, dtype=torch.float32)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
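        # After triu_(1) the mask is 0 on and below the diagonal and -inf strictly above it,
        # e.g. for context_length = 4:
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]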
        mask = mask.unsqueeze(0)  # add dummy batch dimension
        mask = mask.expand(batch_size, -1, -1)
        return mask

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Tensor | None = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: a tensor of boolean values as the padding mask of shape [batch_size, context_length]
            return_all_tokens: a boolean flag to return all tokens, defaults to False to return only EOT token
                embedding.

        Returns:
            A tensor of [batch_size, context_length, hidden_dim] if return_all_tokens is
            True, otherwise a tensor of [batch_size, projection_dim].
        """
        # Discrete tokens to continuous embeddings
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # [batch_size, context_length, context_length]
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(text_tokens=text_tokens)
            key_padding_mask = None
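            # Note: with causal masking the padding mask is dropped; since padding follows the EOT
            # token in CLIP-style tokenization, causal attention already keeps it from affecting the
            # pooled EOT embedding.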

        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        # Apply layer norm
        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
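        # (Assumes a CLIP-style tokenizer where the EOT token has the largest id, so argmax over
        # token ids locates the EOT position in each sequence.)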
        token_emb = token_emb[torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)]

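        # Project the pooled embedding into the joint image-text space:
        # [batch_size, model_dim] --> [batch_size, projection_dim]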
        token_emb = token_emb @ self.projection_layer
        return token_emb

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Tensor | None = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        # Image-text pair data with single caption
        # [B, CL] --> [B, d]
        text_tokens = self.encode_text(
            text_tokens=text_tokens,
            key_padding_mask=key_padding_mask,
            return_all_tokens=return_all_tokens,
            *args,
            **kwargs,
        )
        return text_tokens
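

# Minimal usage sketch (not part of the upstream module; the config values below are illustrative
# guesses and must match a real MobileCLIP text config, e.g. accepted "norm_layer" names come from
# mobileclip.modules.common.transformer.get_normalization_layer):
#
#   cfg = {
#       "dim": 512, "vocab_size": 49408, "context_length": 77, "n_transformer_layers": 12,
#       "ffn_multiplier_per_layer": 4.0, "n_heads_per_layer": 8, "norm_layer": "layer_norm",
#       "model_name": "base", "causal_masking": True,
#   }
#   encoder = TextTransformer(cfg=cfg, projection_dim=512)
#   tokens = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]))
#   text_features = encoder(tokens)  # -> [2, 512]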