Yolo-standalone/mobile_clip_standalone.py

import torch
import torch.nn as nn
from pathlib import Path
from typing import List, Union

import mobileclip


class TextModel(nn.Module):
    """Base class for text models; defines the interface."""

    def __init__(self):
        super().__init__()

    def tokenize(self, texts):
        raise NotImplementedError

    def encode_text(self, texts, dtype):
        raise NotImplementedError


class MobileCLIP(TextModel):
    """
    MobileCLIP text encoder.
    """

    config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}

    def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
        """
        Initialize the MobileCLIP text encoder.

        Args:
            checkpoint (str): Path to the model weights file (.pt or .ts).
            size (str): Model size identifier ('s0', 's1', 's2', 'b', 'blt').
            device (torch.device): Device to load the model on.
        """
        super().__init__()
        if isinstance(device, str):
            device = torch.device(device)
        if not Path(checkpoint).exists():
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")
        if size not in self.config_size_map:
            raise ValueError(f"Unsupported size: {size}. Options: {list(self.config_size_map.keys())}")
        config = self.config_size_map[size]

        # 1. Load the tokenizer
        self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")

        # 2. Load the model
        if str(checkpoint).endswith('.ts'):
            # TorchScript format (e.g. mobileclip_blt.ts)
            print(f"Loading TorchScript model from {checkpoint}...")
            self.model = torch.jit.load(checkpoint, map_location=device)
            self.is_torchscript = True
        else:
            # PyTorch format (.pt)
            print(f"Loading PyTorch model from {checkpoint}...")
            self.model = mobileclip.create_model_and_transforms(
                f"mobileclip_{config}",
                pretrained=checkpoint,
                device=device,
            )[0]
            self.is_torchscript = False

        self.to(device)
        self.device = device
        self.eval()

    def tokenize(self, texts: List[str]) -> torch.Tensor:
        """
        Convert input texts into token id tensors.
        """
        return self.tokenizer(texts).to(self.device)

    def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Encode tokenized texts and L2-normalize the resulting features.
        """
        with torch.no_grad():
            if self.is_torchscript:
                # The TorchScript model is called directly on the token tensor.
                return self.model(texts).to(dtype)
            text_features = self.model.encode_text(texts).to(dtype)  # type: ignore
            text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
            return text_features
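

# Illustrative helper (not part of the original file; the name `encode_prompts`
# is an assumption): composes tokenize() and encode_text() so callers can pass
# raw prompt strings and get back normalized embeddings in one call.
def encode_prompts(model: MobileCLIP, texts: List[str], dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Tokenize raw prompt strings and return their normalized text embeddings."""
    tokens = model.tokenize(texts)
    return model.encode_text(tokens, dtype=dtype)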


# --- Usage example ---
if __name__ == "__main__":
    # 1. Set up the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Path to the local model weights
    checkpoint_path = "mobileclip_blt.ts"

    try:
        if Path(checkpoint_path).exists():
            # 2. Initialize the model (local checkpoint path and matching size)
            # Note: the blt checkpoint corresponds to size="blt"
            model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)

            # 3. Prepare the texts
            input_texts = ["a photo of a cat", "a photo of a dog"]

            # 4. Tokenize
            tokens = model.tokenize(input_texts)
            print(f"Tokens shape: {tokens.shape}")

            # 5. Encode
            features = model.encode_text(tokens)
            print(f"Features shape: {features.shape}")

            print("Run succeeded!")
        else:
            print(f"Checkpoint file does not exist: {checkpoint_path}")
    except Exception as e:
        print(f"An error occurred: {e}")