第一次提交Yolo项目
This commit is contained in:
98
mobileclip/__init__.py
Normal file
98
mobileclip/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# For licensing see accompanying LICENSE file.
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
InterpolationMode,
|
||||
Resize,
|
||||
ToTensor,
|
||||
)
|
||||
|
||||
from mobileclip.clip import CLIP
|
||||
from mobileclip.modules.common.mobileone import reparameterize_model
|
||||
from mobileclip.modules.text.tokenizer import (
|
||||
ClipTokenizer,
|
||||
)
|
||||
|
||||
|
||||
def create_model_and_transforms(
|
||||
model_name: str,
|
||||
pretrained: str | None = None,
|
||||
reparameterize: bool | None = True,
|
||||
device: str | torch.device = "cpu",
|
||||
) -> tuple[nn.Module, Any, Any]:
|
||||
"""Method to instantiate model and pre-processing transforms necessary for inference.
|
||||
|
||||
Args:
|
||||
model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
|
||||
pretrained: Location of pretrained checkpoint.
|
||||
reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
|
||||
device: Device identifier for model placement.
|
||||
|
||||
Returns:
|
||||
Tuple of instantiated model, and preprocessing transforms for inference.
|
||||
"""
|
||||
# Config files
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
configs_dir = os.path.join(root_dir, "configs")
|
||||
model_cfg_file = os.path.join(configs_dir, model_name + ".json")
|
||||
|
||||
# Get config from yaml file
|
||||
if not os.path.exists(model_cfg_file):
|
||||
raise ValueError(f"Unsupported model name: {model_name}")
|
||||
model_cfg = json.load(open(model_cfg_file))
|
||||
|
||||
# Build preprocessing transforms for inference
|
||||
resolution = model_cfg["image_cfg"]["image_size"]
|
||||
resize_size = resolution
|
||||
centercrop_size = resolution
|
||||
aug_list = [
|
||||
Resize(
|
||||
resize_size,
|
||||
interpolation=InterpolationMode.BILINEAR,
|
||||
),
|
||||
CenterCrop(centercrop_size),
|
||||
ToTensor(),
|
||||
]
|
||||
preprocess = Compose(aug_list)
|
||||
|
||||
# Build model
|
||||
model = CLIP(cfg=model_cfg)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Load checkpoint
|
||||
if pretrained is not None:
|
||||
chkpt = torch.load(pretrained)
|
||||
model.load_state_dict(chkpt)
|
||||
|
||||
# Reparameterize model for inference (if specified)
|
||||
if reparameterize:
|
||||
model = reparameterize_model(model)
|
||||
|
||||
return model, None, preprocess
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> nn.Module:
|
||||
# Config files
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
configs_dir = os.path.join(root_dir, "configs")
|
||||
model_cfg_file = os.path.join(configs_dir, model_name + ".json")
|
||||
|
||||
# Get config from yaml file
|
||||
model_cfg = json.load(open(model_cfg_file))
|
||||
|
||||
# Build tokenizer
|
||||
text_tokenizer = ClipTokenizer(model_cfg)
|
||||
return text_tokenizer
|
||||
Reference in New Issue
Block a user