diff --git a/s_bert.py b/s_bert.py
index dfa71af..09efa5f 100644
--- a/s_bert.py
+++ b/s_bert.py
@@ -16,6 +16,7 @@
 import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()
 
-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
 
 
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()
 
     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
            sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-#        for k, v in inputs.items():
-#            inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']
 
-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
             try:
                 tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
             except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-        outs = tokenizer(
-            x,
-            padding=True, truncation='longest_first', max_length=self.max_seq_length,
-            return_tensors='pt',
-        )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-        if k in cfg:
-            max_seq_len = cfg[k]
-        else:
-            max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #     if k in cfg:
+    #         max_seq_len = cfg[k]
+    #     else:
+    #         max_seq_len = None
+    #     return max_seq_len
 
     @property
     def _model(self):
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
         else:
             raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
             dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
         dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
         try:
-            torch.onnx.export(new_model,
+            torch.onnx.export(new_model.to('cpu'),
                               tuple(dummy_input.values()),
                               path,
                               input_names=input_names,
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
            'sentence-t5-xxl',
            'sentence-t5-xl',
            'sentence-t5-large',
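
For context, the fallback path added in get_tokenizer() amounts to binding the tokenization options once with functools.partial, so the returned callable can be used the same way as SentenceTransformer.tokenize. A minimal standalone sketch of that pattern (the model name and sample input below are illustrative, not part of the patch):

    # Sketch of the tokenizer fallback; assumes the HuggingFace hub is reachable.
    from functools import partial
    from transformers import AutoTokenizer, AutoConfig

    model_name = 'all-MiniLM-L6-v2'  # hypothetical example model
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + model_name)
    conf = AutoConfig.from_pretrained('sentence-transformers/' + model_name)

    # Bind padding/truncation/max_length once; the result is called like
    # SentenceTransformer.tokenize(sentences) inside the operator.
    tokenize = partial(tokenizer,
                       padding=True,
                       truncation='longest_first',
                       max_length=conf.max_position_embeddings,
                       return_tensors='pt')

    inputs = tokenize(['hello world'])  # dict of input_ids / attention_mask tensors
    print({k: v.shape for k, v in inputs.items()})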