Files
Yolo-standalone/mobileclip/__init__.py
2025-12-27 02:14:11 +08:00

99 lines
2.8 KiB
Python

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
from __future__ import annotations
import json
import os
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn as nn
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Resize,
ToTensor,
)
from mobileclip.clip import CLIP
from mobileclip.modules.common.mobileone import reparameterize_model
from mobileclip.modules.text.tokenizer import (
ClipTokenizer,
)
def create_model_and_transforms(
    model_name: str,
    pretrained: str | None = None,
    reparameterize: bool | None = True,
    device: str | torch.device = "cpu",
) -> tuple[nn.Module, Any, Any]:
    """Instantiate a model and the pre-processing transforms necessary for inference.

    Args:
        model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b'].
            It is resolved to ``configs/<model_name>.json`` next to this module.
        pretrained: Location of pretrained checkpoint. When None, no weights are loaded.
        reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
        device: Device identifier for model placement.

    Returns:
        A 3-tuple ``(model, None, preprocess)``. The middle element is always None;
        ``preprocess`` is the composed Resize -> CenterCrop -> ToTensor transform.

    Raises:
        ValueError: If no config file exists for ``model_name``.
    """
    # Config files live in the "configs" directory next to this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    model_cfg_file = os.path.join(root_dir, "configs", model_name + ".json")

    # Get config from the JSON file.
    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    # Use a context manager so the file handle is closed deterministically.
    with open(model_cfg_file, encoding="utf-8") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build preprocessing transforms for inference; the same configured
    # resolution is used for both the resize and the center crop.
    resolution = model_cfg["image_cfg"]["image_size"]
    preprocess = Compose(
        [
            Resize(resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(resolution),
            ToTensor(),
        ]
    )

    # Build model
    model = CLIP(cfg=model_cfg)
    model.to(device)
    model.eval()

    # Load checkpoint
    if pretrained is not None:
        # Deserialize onto CPU so a checkpoint saved from GPU tensors still loads
        # on a machine without that device; load_state_dict then copies the
        # values into the parameters already placed on ``device``.
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from trusted sources.
        chkpt = torch.load(pretrained, map_location="cpu")
        model.load_state_dict(chkpt)

    # Reparameterize model for inference (if specified)
    if reparameterize:
        model = reparameterize_model(model)
    return model, None, preprocess
def get_tokenizer(model_name: str) -> nn.Module:
    """Build the text tokenizer for the given model from its JSON config.

    Args:
        model_name: Model name; resolved to ``configs/<model_name>.json`` next to this module.

    Returns:
        A ``ClipTokenizer`` constructed from the loaded model config.

    Raises:
        FileNotFoundError: If the config file for ``model_name`` does not exist.
    """
    # Config files live in the "configs" directory next to this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    model_cfg_file = os.path.join(root_dir, "configs", model_name + ".json")

    # Get config from the JSON file; the context manager closes the handle
    # instead of leaving it to the garbage collector.
    with open(model_cfg_file, encoding="utf-8") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build tokenizer
    return ClipTokenizer(model_cfg)