第一次提交Yolo项目
This commit is contained in:
117
mobile_clip_standalone.py
Normal file
117
mobile_clip_standalone.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
import mobileclip
|
||||
|
||||
class TextModel(nn.Module):
    """Base class for text encoders; defines the interface subclasses implement.

    Both ``tokenize`` and ``encode_text`` raise ``NotImplementedError`` here
    and are meant to be overridden.
    """

    def __init__(self):
        """Initialize the underlying ``nn.Module``."""
        super().__init__()

    def tokenize(self, texts):
        """Convert raw texts into model-ready tokens.

        Args:
            texts: The texts to tokenize.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def encode_text(self, texts, dtype):
        """Encode tokenized texts into feature vectors.

        Args:
            texts: The tokenized texts to encode.
            dtype: The dtype requested for the returned features.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|
||||
|
||||
class MobileCLIP(TextModel):
    """MobileCLIP text encoder.

    Pairs a MobileCLIP tokenizer with a model loaded either from a
    TorchScript export (``.ts``) or from a regular PyTorch checkpoint, and
    exposes both through the ``TextModel`` interface.

    Attributes:
        config_size_map: Maps each accepted ``size`` identifier to the suffix
            used to build the ``mobileclip_<suffix>`` configuration name
            (``"blt"`` reuses the ``"b"`` configuration).
        tokenizer: Tokenizer returned by ``mobileclip.get_tokenizer``.
        model: The loaded TorchScript module or MobileCLIP model.
        is_torchscript: True when ``model`` was loaded via ``torch.jit.load``.
        device: The ``torch.device`` the encoder was moved to.
    """

    config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}

    def __init__(self, checkpoint: str, size: str = "s0", device: Union[str, torch.device] = "cpu") -> None:
        """Load the tokenizer and the model weights.

        Args:
            checkpoint: Path to the model weights file (``.pt`` or ``.ts``).
            size: Model size identifier ('s0', 's1', 's2', 'b', 'blt').
            device: Device to load the model on, given either as a string
                or as a ``torch.device``.

        Raises:
            FileNotFoundError: If ``checkpoint`` does not exist.
            ValueError: If ``size`` is not a key of ``config_size_map``.
        """
        super().__init__()

        if isinstance(device, str):
            device = torch.device(device)

        if not Path(checkpoint).exists():
            raise FileNotFoundError(f"找不到权重文件: {checkpoint}")

        if size not in self.config_size_map:
            raise ValueError(f"不支持的大小: {size}. 可选: {list(self.config_size_map.keys())}")

        config = self.config_size_map[size]

        # 1. Load the tokenizer.
        self.tokenizer = mobileclip.get_tokenizer(f"mobileclip_{config}")

        # 2. Load the model. The suffix is compared case-insensitively so
        # that e.g. "model.TS" is still treated as a TorchScript export.
        if str(checkpoint).lower().endswith(".ts"):
            # TorchScript format (e.g. mobileclip_blt.ts).
            print(f"Loading TorchScript model from {checkpoint}...")
            self.model = torch.jit.load(checkpoint, map_location=device)
            self.is_torchscript = True
        else:
            # PyTorch format (.pt); only the first returned element (the
            # model) is kept.
            print(f"Loading PyTorch model from {checkpoint}...")
            self.model = mobileclip.create_model_and_transforms(
                f"mobileclip_{config}",
                pretrained=checkpoint,
                device=device,
            )[0]
            self.is_torchscript = False

        self.to(device)
        self.device = device
        self.eval()

    def tokenize(self, texts: List[str]) -> torch.Tensor:
        """Convert texts into tokens placed on ``self.device``.

        Args:
            texts: The texts to tokenize.

        Returns:
            The tokenizer output moved to ``self.device``.
        """
        return self.tokenizer(texts).to(self.device)

    def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Encode tokenized text into feature vectors without tracking gradients.

        Args:
            texts: Tokenized texts, e.g. the output of ``tokenize``.
            dtype: The dtype the features are cast to before being returned.

        Returns:
            The text features cast to ``dtype``. For PyTorch checkpoints the
            features are L2-normalized along the last dimension.
        """
        with torch.no_grad():
            if self.is_torchscript:
                # NOTE(review): the TorchScript output is returned as-is,
                # without the explicit normalization applied below --
                # presumably the exported graph already normalizes; confirm.
                return self.model(texts).to(dtype)

            text_features = self.model.encode_text(texts).to(dtype)  # type: ignore
            # In-place L2 normalization along the last dimension.
            text_features /= text_features.norm(p=2, dim=-1, keepdim=True)
            return text_features
|
||||
|
||||
# --- Usage example ---
if __name__ == "__main__":
    # 1. Pick the device: CUDA when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Path to the local model weights.
    checkpoint_path = "mobileclip_blt.ts"

    # The whole demo runs inside one try block so that any failure is
    # reported as a single message instead of a traceback.
    try:
        if Path(checkpoint_path).exists():
            # 2. Build the model from the local path and matching size.
            # Note: the blt weights correspond to size="blt".
            model = MobileCLIP(checkpoint=checkpoint_path, size="blt", device=device)

            # 3. Prepare the texts.
            input_texts = ["a photo of a cat", "a photo of a dog"]

            # 4. Tokenize.
            tokens = model.tokenize(input_texts)
            print(f"Tokens shape: {tokens.shape}")

            # 5. Encode.
            features = model.encode_text(tokens)
            print(f"Features shape: {features.shape}")
            print("运行成功!")
        else:
            print(f"权重文件不存在: {checkpoint_path}")

    except Exception as e:
        print(f"发生错误: {e}")
|
||||
Reference in New Issue
Block a user