import torch
import torch.nn as nn
from pathlib import Path
from typing import List, Union

import mobileclip


class TextModel(nn.Module):
    """Base class for text models; defines the tokenize/encode interface."""

    def __init__(self):
        super().__init__()

    def tokenize(self, texts):
        raise NotImplementedError

    def encode_text(self, texts, dtype):
        raise NotImplementedError


class MobileCLIP(TextModel):
    """MobileCLIP text encoder."""

    config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}

    def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
        """
        Initialize the MobileCLIP text encoder.

        Args:
            checkpoint (str): Path to the model weights file (.pt or .ts).
            size (str): Model size identifier ('s0', 's1', 's2', 'b', 'blt').
            device (str | torch.device): Device to load the model on.
        """
        super().__init__()

        if isinstance(device, str):
            device = torch.device(device)

        if not Path(checkpoint).exists():
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")

        if size not in self.config_size_map:
            raise ValueError(f"Unsupported size: {size}. Options: {list(self.config_size_map.keys())}")

        config = self.config_size_map[size]

        # 1. Load the tokenizer
        self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")

        # 2. Load the model
        if str(checkpoint).endswith(".ts"):
            # TorchScript format (e.g. mobileclip_blt.ts)
            print(f"Loading TorchScript model from {checkpoint}...")
            self.model = torch.jit.load(checkpoint, map_location=device)
            self.is_torchscript = True
        else:
            # PyTorch format (.pt)
            print(f"Loading PyTorch model from {checkpoint}...")
            self.model = mobileclip.create_model_and_transforms(
                f"mobileclip_{config}", pretrained=checkpoint, device=device
            )[0]
            self.is_torchscript = False

        self.to(device)
        self.device = device
        self.eval()

    def tokenize(self, texts: List[str]) -> torch.Tensor:
        """Convert a list of strings into a token tensor on the model's device."""
        return self.tokenizer(texts).to(self.device)

    def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Encode tokenized text and L2-normalize the resulting features."""
        with torch.no_grad():
            if self.is_torchscript:
                # The traced model is returned as-is; normalization is assumed to be
                # baked into the TorchScript export.
                return self.model(texts).to(dtype)
            text_features = self.model.encode_text(texts).to(dtype)  # type: ignore
            text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
            return text_features


# --- Usage example ---
if __name__ == "__main__":
    # 1. Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Path to the local model weights
    checkpoint_path = "mobileclip_blt.ts"

    try:
        if Path(checkpoint_path).exists():
            # 2. Initialize the model (local checkpoint path and matching size)
            # Note: mobileclip_blt.ts corresponds to size="blt"
            model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)

            # 3. Prepare the input texts
            input_texts = ["a photo of a cat", "a photo of a dog"]

            # 4. Tokenize
            tokens = model.tokenize(input_texts)
            print(f"Tokens shape: {tokens.shape}")

            # 5. Encode
            features = model.encode_text(tokens)
            print(f"Features shape: {features.shape}")
            print("Success!")
        else:
            print(f"Checkpoint file does not exist: {checkpoint_path}")
    except Exception as e:
        print(f"An error occurred: {e}")