logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

82 lines
2.6 KiB

# -*- coding: utf-8 -*-
import os
from huggingface_hub import hf_hub_url, cached_download
from . import model, processor, predictor
from .model import CLIP
from .processor import RuCLIPProcessor
from .predictor import Predictor
MODELS = {
'ruclip-vit-base-patch32-224': dict(
repo_id='sberbank-ai/ruclip-vit-base-patch32-224',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
'ruclip-vit-base-patch16-224': dict(
repo_id='sberbank-ai/ruclip-vit-base-patch16-224',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
'ruclip-vit-large-patch14-224': dict(
repo_id='sberbank-ai/ruclip-vit-large-patch14-224',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
'ruclip-vit-large-patch14-336': dict(
repo_id='sberbank-ai/ruclip-vit-large-patch14-336',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
'ruclip-vit-base-patch32-384': dict(
repo_id='sberbank-ai/ruclip-vit-base-patch32-384',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
'ruclip-vit-base-patch16-384': dict(
repo_id='sberbank-ai/ruclip-vit-base-patch16-384',
filenames=[
'bpe.model', 'config.json', 'pytorch_model.bin'
]
),
}
def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
"""Load a ruCLIP model
Parameters
----------
name : str
A model name listed in ruclip.MODELS.keys()
device : Union[str, torch.device]
The device to put the loaded model
cache_dir: str
path to download the model files; by default, it uses "/tmp/ruclip"
Returns
-------
clip : torch.nn.Module
The ruCLIP model
clip_processor : ruclip.processor.RuCLIPProcessor
A ruCLIP processor which performs tokenization and image preprocessing
"""
assert name in MODELS, f'All models: {MODELS.keys()}'
config = MODELS[name]
repo_id = config['repo_id']
cache_dir = os.path.join(cache_dir, name)
for filename in config['filenames']:
config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{filename}')
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename, use_auth_token=use_auth_token)
clip = CLIP.from_pretrained(cache_dir).eval().to(device)
clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
return clip, clip_processor
__all__ = ['processor', 'model', 'predictor', 'CLIP', 'RuCLIPProcessor', 'Predictor', 'MODELS', 'load']
__version__ = '0.0.2'