ru-clip
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
83 lines
2.6 KiB
83 lines
2.6 KiB
2 years ago
|
# -*- coding: utf-8 -*-
|
||
|
import os
|
||
|
|
||
|
from huggingface_hub import hf_hub_url, cached_download
|
||
|
|
||
|
from . import model, processor, predictor
|
||
|
from .model import CLIP
|
||
|
from .processor import RuCLIPProcessor
|
||
|
from .predictor import Predictor
|
||
|
|
||
|
# Registry of the pretrained ruCLIP checkpoints that ``load`` can fetch.
# Each key is a model name; its value holds the Hugging Face Hub repository
# (``sberbank-ai/<name>``) and the files downloaded from it. Every checkpoint
# ships the same three files: the BPE tokenizer model, the config, and the
# PyTorch weights. The comprehension builds a separate dict and list per entry,
# and the tuple order fixes the key order.
MODELS = {
    model_name: dict(
        repo_id=f'sberbank-ai/{model_name}',
        filenames=['bpe.model', 'config.json', 'pytorch_model.bin'],
    )
    for model_name in (
        'ruclip-vit-base-patch32-224',
        'ruclip-vit-base-patch16-224',
        'ruclip-vit-large-patch14-224',
        'ruclip-vit-large-patch14-336',
        'ruclip-vit-base-patch32-384',
        'ruclip-vit-base-patch16-384',
    )
}
|
||
|
|
||
|
|
||
|
def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
    """Load a ruCLIP model and its processor, downloading files if needed.

    Parameters
    ----------
    name : str
        A model name listed in ruclip.MODELS.keys()
    device : Union[str, torch.device]
        The device to put the loaded model on
    cache_dir : str
        Root directory for downloaded model files; each model is stored in
        its own ``<cache_dir>/<name>`` sub-directory. Defaults to "/tmp/ruclip".
    use_auth_token : optional
        Forwarded unchanged to ``cached_download`` for every file
        (presumably a Hub token or flag for gated/private repos — confirm).

    Returns
    -------
    clip : torch.nn.Module
        The ruCLIP model, switched to eval mode and moved to ``device``
    clip_processor : ruclip.processor.RuCLIPProcessor
        A ruCLIP processor which performs tokenization and image preprocessing

    Raises
    ------
    ValueError
        If ``name`` is not a key of ``MODELS``.
    """
    # Explicit check instead of ``assert``: asserts are stripped when Python
    # runs with ``-O``, which would turn a bad name into an opaque KeyError.
    if name not in MODELS:
        raise ValueError(f'Unknown model {name!r}. All models: {list(MODELS)}')

    config = MODELS[name]
    repo_id = config['repo_id']
    # Per-model sub-directory, so checkpoints sharing file names
    # (bpe.model, config.json, ...) do not overwrite each other.
    model_dir = os.path.join(cache_dir, name)

    for filename in config['filenames']:
        file_url = hf_hub_url(repo_id=repo_id, filename=filename)
        # NOTE(review): cached_download is deprecated in recent huggingface_hub
        # releases (hf_hub_download replaces it); confirm against the pinned
        # version before upgrading the dependency.
        cached_download(
            file_url,
            cache_dir=model_dir,
            force_filename=filename,
            use_auth_token=use_auth_token,
        )

    # Both loaders read from the same directory the files were downloaded into.
    clip = CLIP.from_pretrained(model_dir).eval().to(device)
    clip_processor = RuCLIPProcessor.from_pretrained(model_dir)
    return clip, clip_processor
|
||
|
|
||
|
|
||
|
# Public API exposed by ``from ruclip import *``: the submodules, their main
# classes, the model registry, and the ``load`` helper.
__all__ = [
    'processor',
    'model',
    'predictor',
    'CLIP',
    'RuCLIPProcessor',
    'Predictor',
    'MODELS',
    'load',
]

__version__ = '0.0.2'
|