import torch
import torch.nn as nn
from pathlib import Path
from typing import List, Union

import mobileclip


class TextModel(nn.Module):
    """Base class for text encoders; defines the interface."""

    def __init__(self):
        super().__init__()

    def tokenize(self, texts):
        raise NotImplementedError

    def encode_text(self, texts, dtype):
        raise NotImplementedError


class MobileCLIP(TextModel):
    """
    MobileCLIP text encoder.
    """

    config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}

    def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
        """
        Initialize the MobileCLIP text encoder.

        Args:
            checkpoint (str): Path to the model weights file (.pt or .ts).
            size (str): Model size identifier ('s0', 's1', 's2', 'b', 'blt').
            device (torch.device): Device to load the model onto.
        """
        super().__init__()

        if isinstance(device, str):
            device = torch.device(device)

        if not Path(checkpoint).exists():
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")

        if size not in self.config_size_map:
            raise ValueError(f"Unsupported size: {size}. Options: {list(self.config_size_map.keys())}")

        config = self.config_size_map[size]

        # 1. Load the tokenizer
        self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")

        # 2. Load the model
        if str(checkpoint).endswith(".ts"):
            # TorchScript format (e.g. mobileclip_blt.ts)
            print(f"Loading TorchScript model from {checkpoint}...")
            self.model = torch.jit.load(checkpoint, map_location=device)
            self.is_torchscript = True
        else:
            # PyTorch format (.pt)
            print(f"Loading PyTorch model from {checkpoint}...")
            self.model = mobileclip.create_model_and_transforms(
                f"mobileclip_{config}",
                pretrained=checkpoint,
                device=device,
            )[0]
            self.is_torchscript = False

        self.to(device)
        self.device = device
        self.eval()

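
    # Hypothetical convenience wrapper, not part of the original class: a small
    # sketch that just chains the `tokenize` and `encode_text` methods defined
    # below, so callers can go from raw strings to features in one call.
    def encode_prompts(self, texts: List[str], dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Tokenize raw texts and return their encoded feature vectors."""
        return self.encode_text(self.tokenize(texts), dtype)
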
    def tokenize(self, texts: List[str]) -> torch.Tensor:
        """
        Convert a list of raw texts into a token tensor on the model's device.
        """
        return self.tokenizer(texts).to(self.device)

    def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Encode tokenized text and normalize the result.
        """
        with torch.no_grad():
            if self.is_torchscript:
                # TorchScript model: call it directly and return its output,
                # cast to the requested dtype.
                return self.model(texts).to(dtype)

            text_features = self.model.encode_text(texts).to(dtype)  # type: ignore
            text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
            return text_features

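
# Hypothetical helper, not part of the original file: a minimal sketch of how
# the features returned by `encode_text` are typically compared. For unit-norm
# rows (the non-TorchScript path normalizes them), the matrix product below is
# exactly the pairwise cosine-similarity matrix.
def text_similarity(features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
    """Pairwise dot-product (cosine, for unit-norm inputs) similarity matrix."""
    return features_a @ features_b.T
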
# --- Usage example ---
if __name__ == "__main__":
    # 1. Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Path to the local model weights
    checkpoint_path = "mobileclip_blt.ts"

    try:
        if Path(checkpoint_path).exists():
            # 2. Initialize the model (local checkpoint path and matching size)
            # Note: the blt checkpoint corresponds to size="blt"
            model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)

            # 3. Prepare the texts
            input_texts = ["a photo of a cat", "a photo of a dog"]

            # 4. Tokenize
            tokens = model.tokenize(input_texts)
            print(f"Tokens shape: {tokens.shape}")

            # 5. Encode
            features = model.encode_text(tokens)
            print(f"Features shape: {features.shape}")
            print("Run succeeded!")
        else:
            print(f"Checkpoint file does not exist: {checkpoint_path}")

    except Exception as e:
        print(f"An error occurred: {e}")