# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
from __future__ import annotations

import math
from collections.abc import Sequence

import torch
from torch import Tensor, nn

from mobileclip import logger
from mobileclip.modules.common.transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)
from mobileclip.modules.text.repmixer import RepMixerBlock


class TextTransformer(nn.Module):
    """Transformer text encoder that returns either all token embeddings or a projected EOT (end-of-text) embedding."""

    def __init__(self, cfg: dict, projection_dim: int, *args, **kwargs) -> None:
        super().__init__()
        model_dim = cfg["dim"]
        no_scale_embedding = cfg.get("no_scale_embedding", False)
        no_pos_embedding = cfg.get("no_pos_embedding", False)
        embed_dropout = cfg.get("embed_dropout", 0.0)
        norm_layer = cfg["norm_layer"]
        variant = cfg["model_name"]
        self.vocab_size = cfg["vocab_size"]
        self.projection_dim = projection_dim

        # Token embedding layer
        self.embedding_layer = nn.Embedding(embedding_dim=model_dim, num_embeddings=self.vocab_size)
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        # Context length
        context_length = cfg["context_length"]
        assert context_length is not None, "Context length can't be None. Please set the value accordingly."

        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(num_embeddings=context_length, embedding_dim=model_dim)
        )

        self.embedding_dropout = nn.Dropout(p=embed_dropout)

        # Transformer layers
        n_transformer_layers = cfg["n_transformer_layers"]

        # FFN multipliers for transformer layers
        ffn_multipliers = cfg["ffn_multiplier_per_layer"]
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects FFN multipliers as a list whose length is the same as the"
                f" number of transformer layers. Got: {type(ffn_multipliers)}"
            )
        elif len(ffn_multipliers) != n_transformer_layers:
            logger.error(
                f"We need an FFN multiplier for each transformer layer. Got {len(ffn_multipliers)} FFN"
                f" multipliers while number of transformer layers = {n_transformer_layers}"
            )
        ffn_dims = [int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) for ffn_mult in ffn_multipliers]

        # Attention heads for transformer layers
        mha_heads = cfg["n_heads_per_layer"]
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            logger.error(
                f"{self.__class__.__name__} expects MHA heads as a list whose length is the same as the number of"
                f" transformer layers. Got: {type(mha_heads)}"
            )
        elif len(mha_heads) != n_transformer_layers:
            logger.error(
                f"{self.__class__.__name__} needs MHA heads for each transformer layer. Got {len(mha_heads)} MHA"
                f" heads while number of transformer layers = {n_transformer_layers}"
            )

        if variant == "base":
            self.transformer = nn.ModuleList(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
        elif variant == "mct":
            self.transformer = nn.ModuleList([RepMixerBlock(dim=model_dim)])
            self.transformer.extend(
                [
                    TransformerEncoder(
                        embed_dim=model_dim,
                        num_heads=mha_heads[layer_idx],
                        ffn_latent_dim=ffn_dims[layer_idx],
                        transformer_norm_layer=norm_layer,
                    )
                    for layer_idx in range(n_transformer_layers)
                ]
            )
            self.transformer.extend([RepMixerBlock(dim=model_dim)])
        else:
            raise ValueError(f"Unrecognized text encoder variant {variant}")

        self.final_layer_norm = get_normalization_layer(num_features=model_dim, norm_type=norm_layer)

        self.projection_layer = nn.Parameter(torch.empty(model_dim, self.projection_dim))
        self.model_dim = model_dim
        self.causal_masking = cfg["causal_masking"]

    def forward_embedding(self, text_tokens: Tensor) -> Tensor:
        """Return text embeddings for all tokens.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim].
        """
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(token_emb.dtype)
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    @staticmethod
    @torch.jit.script  # use scripting to avoid baking in a device constant
    def build_attention_mask(text_tokens: torch.Tensor) -> Tensor:
        """Build a causal attention mask of shape [batch_size, context_length, context_length]."""
        # PyTorch uses an additive attention mask; fill masked (future) positions with -inf.
        batch_size, context_length = text_tokens.shape
        mask = torch.empty(context_length, context_length, device=text_tokens.device, dtype=torch.float32)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the diagonal and lower triangle so current and past tokens stay visible
        mask = mask.unsqueeze(0)  # add dummy batch dimension
        mask = mask.expand(batch_size, -1, -1)
        return mask

    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Tensor | None = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        """Return text token embeddings.

        Args:
            text_tokens: a tensor of token indices. Shape: [batch_size, context_length]
            key_padding_mask: a boolean padding mask of shape [batch_size, context_length]
            return_all_tokens: a boolean flag to return all tokens; defaults to False to return only the EOT token
                embedding.

        Returns:
            A tensor of shape [batch_size, context_length, hidden_dim] if return_all_tokens is True, otherwise a
            tensor of shape [batch_size, hidden_dim].
        """
        # Discrete tokens to continuous embeddings
        # [batch_size, context_length] --> [batch_size, context_length, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # Causal mask of shape [batch_size, context_length, context_length]
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(text_tokens=text_tokens)
            key_padding_mask = None

        for layer in self.transformer:
            token_emb = layer(
                token_emb,
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
            )

        # Apply final layer norm
        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            return token_emb

        # Take features from the EOT embedding (the EOT token has the highest id in each sequence)
        token_emb = token_emb[torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)]
        token_emb = token_emb @ self.projection_layer
        return token_emb

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Tensor | None = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs,
    ) -> Tensor:
        # Image-text pair data with a single caption
        # [B, CL] --> [B, d]
        text_tokens = self.encode_text(
            text_tokens=text_tokens,
            key_padding_mask=key_padding_mask,
            return_all_tokens=return_all_tokens,
            *args,
            **kwargs,
        )
        return text_tokens