logo
Browse Source

slip operator update.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
5b1a1a02e9
  1. 6
      __init__.py
  2. BIN
      bpe_simple_vocab_16e6.txt.gz
  3. 36
      models.py
  4. 53
      slip.py
  5. 8
      utils.py

6
__init__.py

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .blip import Blip
from .slip import Slip
def blip(model_name: str, modality: str):
    """Create a ``Blip`` operator.

    Args:
        model_name: Forwarded unchanged as the first argument of ``Blip``.
        modality: Forwarded unchanged as the second argument of ``Blip``.

    Returns:
        The ``Blip`` instance constructed from the two arguments.
    """
    return Blip(model_name, modality)
def slip(model_name: str, modality: str):
    """Create a ``Slip`` operator.

    Args:
        model_name: Forwarded unchanged as the first argument of ``Slip``.
        modality: Forwarded unchanged as the second argument of ``Slip``.

    Returns:
        The ``Slip`` instance constructed from the two arguments.
    """
    return Slip(model_name, modality)

BIN
bpe_simple_vocab_16e6.txt.gz (Stored with Git LFS)

Binary file not shown.

36
models.py

@ -12,7 +12,7 @@ import timm
import torch import torch
from torch import nn from torch import nn
import losses
#import losses
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
@ -235,23 +235,23 @@ class SLIP(CLIP):
'aug2_embed': aug2_embed} 'aug2_embed': aug2_embed}
def get_loss(model, ssl_temp, ssl_scale):
    """Select the loss object matching a model-name prefix.

    Args:
        model: Model name; only its prefix (``SLIP``, ``CLIP`` or ``SIMCLR``)
            is inspected.
        ssl_temp: Temperature handed to ``losses.SIMCLRLoss``.
        ssl_scale: Scale handed to ``losses.SLIPLoss`` together with the
            SimCLR loss (used for ``SLIP`` models only).

    Returns:
        A ``losses.SLIPLoss``, ``losses.CLIPLoss`` or ``losses.SIMCLRLoss``
        instance, or ``None`` when the name matches none of the prefixes.
    """
    if model.startswith('SLIP'):
        ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp)
        return losses.SLIPLoss(ssl_loss, ssl_scale)
    if model.startswith('CLIP'):
        return losses.CLIPLoss()
    if model.startswith('SIMCLR'):
        return losses.SIMCLRLoss(temperature=ssl_temp)
    # Unrecognised prefix: the original fell off the end of the function,
    # so keep returning None rather than raising.
    return None
def get_metric_names(model):
    """Return the metric names reported for a model-name prefix.

    Args:
        model: Model name; only its prefix is inspected.

    Returns:
        A new list of metric names: the combined CLIP + SSL set for names
        starting with ``SLIP``, the CLIP set for names starting with
        ``CLIP``, and the SSL set for every other name.
    """
    if model.startswith('SLIP'):
        return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc']
    if model.startswith('CLIP'):
        return ['loss', 'clip_loss', 'clip_acc']
    # Anything else is treated as an SSL-only model.
    return ['loss', 'ssl_loss', 'ssl_acc']
#def get_loss(model, ssl_temp, ssl_scale):
# if model.startswith('SLIP'):
# ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp)
# return losses.SLIPLoss(ssl_loss, ssl_scale)
# if model.startswith('CLIP'):
# return losses.CLIPLoss()
# if model.startswith('SIMCLR'):
# return losses.SIMCLRLoss(temperature=ssl_temp)
#
#
#def get_metric_names(model):
# if model.startswith('SLIP'):
# return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc']
# elif model.startswith('CLIP'):
# return ['loss', 'clip_loss', 'clip_acc']
# else:
# return ['loss', 'ssl_loss', 'ssl_acc']
@timm.models.registry.register_model @timm.models.registry.register_model

53
slip.py

@ -13,18 +13,20 @@
# limitations under the License. # limitations under the License.
import sys import sys
import os
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse
from collections import OrderedDict
import torch import torch
from torchvision import transforms from torchvision import transforms
from timm.models.hub import download_cached_file
from towhee import register from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import from_pil, to_pil from towhee.types.image_utils import from_pil, to_pil
from tokenizer import SimpleTokenizer
def get_model(model): def get_model(model):
if isinstance(model, torch.nn.DataParallel) \ if isinstance(model, torch.nn.DataParallel) \
or isinstance(model, torch.nn.parallel.DistributedDataParallel): or isinstance(model, torch.nn.parallel.DistributedDataParallel):
@ -32,16 +34,52 @@ def get_model(model):
else: else:
return model return model
def is_url(url_or_filename):
    """Tell whether the argument is an http(s) URL.

    Args:
        url_or_filename: A URL string or a filesystem path. Path-like
            objects (e.g. ``pathlib.Path``) are accepted and are never
            reported as URLs unless their string form has an http(s) scheme.

    Returns:
        ``True`` when the parsed scheme is ``http`` or ``https``,
        ``False`` otherwise.
    """
    # urlparse only understands str/bytes; convert path-like objects first so
    # that callers may pass the same value they would give to os.path.isfile.
    if isinstance(url_or_filename, os.PathLike):
        url_or_filename = os.fspath(url_or_filename)
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")
def load_checkpoint(url_or_filename, models, device):
    """Build a model from a checkpoint file and load its weights.

    Args:
        url_or_filename: http(s) URL or local file path of the checkpoint.
        models: Object exposing model constructors as attributes; the
            constructor is looked up by ``checkpoint['args'].model``.
        device: Device the freshly built model is moved to.

    Returns:
        The constructed model with the checkpoint weights loaded
        (``strict=True``).

    Raises:
        RuntimeError: If ``url_or_filename`` is neither an http(s) URL nor an
            existing file.
    """
    # Resolve the checkpoint location once; URLs go through timm's download
    # cache, local files are used as they are.
    if is_url(url_or_filename):
        checkpoint_path = download_cached_file(url_or_filename, check_hash=False, progress=True)
    elif os.path.isfile(url_or_filename):
        checkpoint_path = url_or_filename
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    # NOTE(review): torch.load unpickles the file (check_hash is disabled for
    # downloads) -- only point this at trusted checkpoint sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Drop only a *leading* 'module.' (the prefix added when a model is saved
    # through a DataParallel/DistributedDataParallel wrapper); a plain
    # str.replace would also mangle keys that contain 'module.' further in.
    prefix = 'module.'
    state_dict = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        state_dict[key] = value

    # Rebuild the architecture from the arguments stored in the checkpoint.
    old_args = checkpoint['args']
    model_factory = getattr(models, old_args.model)
    model = model_factory(
        rand_embed=False,
        ssl_mlp_dim=old_args.ssl_mlp_dim,
        ssl_emb_dim=old_args.ssl_emb_dim,
    )
    model.to(device)
    model.load_state_dict(state_dict, strict=True)
    return model
@register(output_schema=['vec']) @register(output_schema=['vec'])
class Slip(NNOperator)
class Slip(NNOperator):
""" """
SLIP multi-modal embedding operator SLIP multi-modal embedding operator
""" """
def __init__(self, model_name: str, modality: str): def __init__(self, model_name: str, modality: str):
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
import models
from tokenizer import SimpleTokenizer
self.tokenizer = SimpleTokenizer() self.tokenizer = SimpleTokenizer()
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
self._modality = modality
self.model = load_checkpoint(self._configs()[model_name]['weights'], models, self.device)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
@ -61,13 +99,14 @@ class Slip(NNOperator)
vec = self._inference_from_text(data) vec = self._inference_from_text(data)
else: else:
raise ValueError("modality[{}] not implemented.".format(self._modality)) raise ValueError("modality[{}] not implemented.".format(self._modality))
vec = vec / vec.norm(dim=-1, keepdim=True)
return vec.detach().cpu().numpy().flatten() return vec.detach().cpu().numpy().flatten()
def _inference_from_text(self, text): def _inference_from_text(self, text):
texts = tokenizer(texts).cuda(non_blocking=True)
texts = texts.view(-1, 77).contiguous()
embedding = get_model(self.model).encode_text(texts)
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
text = self.tokenizer(text).to(self.device)
text = text.view(-1, 77).contiguous()
embedding = get_model(self.model).encode_text(text)
return embedding
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def _inference_from_image(self, img): def _inference_from_image(self, img):

8
utils.py

@ -1,8 +0,0 @@
import torch
def get_model(model):
    """Return the bare module behind a data-parallel wrapper.

    Args:
        model: A module, possibly wrapped in ``torch.nn.DataParallel`` or
            ``torch.nn.parallel.DistributedDataParallel``.

    Returns:
        ``model.module`` when ``model`` is one of the two wrappers, otherwise
        ``model`` itself.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(model, parallel_wrappers):
        return model.module
    return model
Loading…
Cancel
Save