|
@ -13,18 +13,20 @@ |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
import sys |
|
|
import sys |
|
|
|
|
|
import os |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
from urllib.parse import urlparse |
|
|
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from torchvision import transforms |
|
|
|
|
|
from timm.models.hub import download_cached_file |
|
|
|
|
|
|
|
|
from towhee import register |
|
|
from towhee import register |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee.types.image_utils import from_pil, to_pil |
|
|
from towhee.types.image_utils import from_pil, to_pil |
|
|
|
|
|
|
|
|
from tokenizer import SimpleTokenizer |
|
|
|
|
|
|
|
|
|
|
|
def get_model(model): |
|
|
def get_model(model): |
|
|
if isinstance(model, torch.nn.DataParallel) \ |
|
|
if isinstance(model, torch.nn.DataParallel) \ |
|
|
or isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|
|
or isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|
@ -32,16 +34,52 @@ def get_model(model): |
|
|
else: |
|
|
else: |
|
|
return model |
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def is_url(url_or_filename): |
|
|
|
|
|
parsed = urlparse(url_or_filename) |
|
|
|
|
|
return parsed.scheme in ("http", "https") |
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(url_or_filename, models, device): |
|
|
|
|
|
if is_url(url_or_filename): |
|
|
|
|
|
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) |
|
|
|
|
|
checkpoint = torch.load(cached_file, map_location='cpu') |
|
|
|
|
|
elif os.path.isfile(url_or_filename): |
|
|
|
|
|
checkpoint = torch.load(url_or_filename, map_location='cpu') |
|
|
|
|
|
else: |
|
|
|
|
|
raise RuntimeError('checkpoint url or path is invalid') |
|
|
|
|
|
|
|
|
|
|
|
if is_url(url_or_filename): |
|
|
|
|
|
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) |
|
|
|
|
|
checkpoint = torch.load(cached_file, map_location='cpu') |
|
|
|
|
|
elif os.path.isfile(url_or_filename): |
|
|
|
|
|
checkpoint = torch.load(url_or_filename, map_location='cpu') |
|
|
|
|
|
else: |
|
|
|
|
|
raise RuntimeError('checkpoint url or path is invalid') |
|
|
|
|
|
|
|
|
|
|
|
state_dict = OrderedDict() |
|
|
|
|
|
for k, v in checkpoint['state_dict'].items(): |
|
|
|
|
|
state_dict[k.replace('module.', '')] = v |
|
|
|
|
|
old_args = checkpoint['args'] |
|
|
|
|
|
|
|
|
|
|
|
model = getattr(models, old_args.model)(rand_embed=False, |
|
|
|
|
|
ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim) |
|
|
|
|
|
model.to(device) |
|
|
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@register(output_schema=['vec']) |
|
|
class Slip(NNOperator) |
|
|
|
|
|
|
|
|
class Slip(NNOperator): |
|
|
""" |
|
|
""" |
|
|
SLIP multi-modal embedding operator |
|
|
SLIP multi-modal embedding operator |
|
|
""" |
|
|
""" |
|
|
def __init__(self, model_name: str, modality: str): |
|
|
def __init__(self, model_name: str, modality: str): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
|
|
|
import models |
|
|
|
|
|
from tokenizer import SimpleTokenizer |
|
|
self.tokenizer = SimpleTokenizer() |
|
|
self.tokenizer = SimpleTokenizer() |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
self._modality = modality |
|
|
|
|
|
self.model = load_checkpoint(self._configs()[model_name]['weights'], models, self.device) |
|
|
self.model.to(self.device) |
|
|
self.model.to(self.device) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
|
|
|
|
|
@ -61,13 +99,14 @@ class Slip(NNOperator) |
|
|
vec = self._inference_from_text(data) |
|
|
vec = self._inference_from_text(data) |
|
|
else: |
|
|
else: |
|
|
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|
|
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|
|
|
|
|
vec = vec / vec.norm(dim=-1, keepdim=True) |
|
|
return vec.detach().cpu().numpy().flatten() |
|
|
return vec.detach().cpu().numpy().flatten() |
|
|
|
|
|
|
|
|
def _inference_from_text(self, text): |
|
|
def _inference_from_text(self, text): |
|
|
texts = tokenizer(texts).cuda(non_blocking=True) |
|
|
|
|
|
texts = texts.view(-1, 77).contiguous() |
|
|
|
|
|
embedding = get_model(self.model).encode_text(texts) |
|
|
|
|
|
embedding = embedding / embedding.norm(dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
text = self.tokenizer(text).to(self.device) |
|
|
|
|
|
text = text.view(-1, 77).contiguous() |
|
|
|
|
|
embedding = get_model(self.model).encode_text(text) |
|
|
|
|
|
return embedding |
|
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
@arg(1, to_image_color('RGB')) |
|
|
def _inference_from_image(self, img): |
|
|
def _inference_from_image(self, img): |
|
|