logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

96 lines
3.0 KiB

# coding=utf-8
# Copyright 2022 rinna Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import json
import torch
from torchvision import transforms as T
from huggingface_hub import hf_hub_url, cached_download
import os
from .clip import CLIPModel
from .cloob import CLOOBModel
# Registry of pretrained checkpoints understood by `load`, keyed by the public
# model name. Each entry records the Hugging Face Hub repository id and the
# model class used to instantiate that checkpoint.
MODELS = {
    'rinna/japanese-clip-vit-b-16': dict(
        repo_id='rinna/japanese-clip-vit-b-16',
        model_class=CLIPModel,
    ),
    'rinna/japanese-cloob-vit-b-16': dict(
        repo_id='rinna/japanese-cloob-vit-b-16',
        model_class=CLOOBModel,
    ),
}

# Maps the "model_type" value found in a local checkpoint's config file to
# the class able to load it.
MODEL_CLASSES = dict(
    cloob=CLOOBModel,
    clip=CLIPModel,
)

# File names of the weights and the config inside a checkpoint.
MODEL_FILE = "pytorch_model.bin"
CONFIG_FILE = "config.json"
def available_models():
    """Return the registered pretrained model names as a new list.

    The names are the keys of ``MODELS`` in insertion order; each one can be
    passed to ``load``.
    """
    return [name for name in MODELS]
def _convert_to_rgb(image):
return image.convert('RGB')
def _transform(image_size):
    """Build the image preprocessing pipeline for a model with the given input size.

    Steps, in order: bilinear ``Resize`` to ``image_size``, ``CenterCrop`` to
    ``image_size``, RGB conversion, ``ToTensor``, and per-channel ``Normalize``
    with the fixed mean/std below.

    Args:
        image_size: size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
        T.CenterCrop(image_size),
        _convert_to_rgb,
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)
def _download(repo_id: str, cache_dir: str):
    """Download the config file, then the weights file, of ``repo_id`` into ``cache_dir``.

    For each file, the Hub URL is built with ``hf_hub_url`` and fetched with
    ``cached_download``. The local paths returned by ``cached_download`` are
    discarded and nothing is returned.

    NOTE(review): this helper is not called from ``load`` in this file, and
    ``cached_download`` is deprecated in newer huggingface_hub releases in favour
    of ``hf_hub_download``. Confirm against the pinned huggingface_hub version.
    """
    for filename in (CONFIG_FILE, MODEL_FILE):
        file_url = hf_hub_url(repo_id=repo_id, filename=filename)
        cached_download(file_url, cache_dir=cache_dir)
def _resolve_model_source(model_name: str):
    """Return ``(ModelClass, name_or_path)`` to pass to ``from_pretrained`` for *model_name*.

    Registered names are looked up in ``MODELS``. Existing paths are resolved
    through the ``model_type`` field of their config file.

    Raises:
        FileNotFoundError: if *model_name* is an existing path without a config file.
        RuntimeError: if the config's ``model_type`` is unsupported, or
            *model_name* is neither registered nor an existing path.
    """
    if model_name in MODELS:
        # Trust the registry instead of guessing the class from the name.
        entry = MODELS[model_name]
        return entry['model_class'], entry['repo_id']

    if os.path.exists(model_name):
        config_path = os.path.join(model_name, CONFIG_FILE)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.exists(config_path):
            raise FileNotFoundError(
                f"{config_path} not found; cannot determine the model type of {model_name}"
            )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config.get("model_type")
        if model_type not in MODEL_CLASSES:
            raise RuntimeError(
                f"Unsupported model_type {model_type!r} in {config_path}; "
                f"supported types = {list(MODEL_CLASSES)}"
            )
        return MODEL_CLASSES[model_type], model_name

    # Previously this exception was constructed but never raised, so execution
    # fell through to an UnboundLocalError on the model class.
    raise RuntimeError(f"Model {model_name} not found; available models = {available_models()}")


def load(
    model_name: str,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    **kwargs
):
    """Load a pretrained Japanese CLIP/CLOOB model and its preprocessing transform.

    Args:
        model_name: a registered model name (see ``available_models()``) or the
            path to a pre-downloaded model directory containing the config file.
        device: device to put the loaded model on. The default is evaluated once,
            when this module is imported.
        kwargs: extra keyword arguments forwarded to the model class's
            ``from_pretrained``.

    Returns:
        ``(model, transform)``: the model in eval mode with gradients disabled,
        moved to *device*, and the torchvision transform built for
        ``model.config.vision_config.image_size``.

    Raises:
        RuntimeError: if *model_name* is neither registered nor an existing path,
            or the local config's ``model_type`` is not supported.
        FileNotFoundError: if *model_name* is a path without a config file.
    """
    ModelClass, pretrained_source = _resolve_model_source(model_name)
    model = ModelClass.from_pretrained(pretrained_source, **kwargs)
    # Inference-only: freeze parameters and disable dropout etc. before moving.
    model = model.eval().requires_grad_(False).to(device)
    return model, _transform(model.config.vision_config.image_size)