@@ -16,6 +16,7 @@ import logging
 import numpy
 from typing import Union, List
 from pathlib import Path
+from functools import partial
 
 import torch
 from sentence_transformers import SentenceTransformer
@@ -62,8 +63,11 @@ class Model:
         self.model = SentenceTransformer(model_name_or_path=model_name, device=self.device)
         self.model.eval()
 
-    def __call__(self, **features):
-        outs = self.model(features)
+    def __call__(self, *_, **kwargs):
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(new_kwargs)
         return outs['sentence_embedding']
 
 
@@ -81,16 +85,14 @@ class STransformers(NNOperator):
             self.model = Model(model_name=self.model_name, device=self.device)
         else:
             log.warning('The operator is initialized without specified model.')
-            pass
+        self._tokenize = self.get_tokenizer()
 
     def __call__(self, txt: Union[List[str], str]):
         if isinstance(txt, str):
             sentences = [txt]
         else:
             sentences = txt
-        inputs = self.tokenize(sentences)
-        # for k, v in inputs.items():
-        #     inputs[k] = v.to(self.device)
+        inputs = self._tokenize(sentences)
         embs = self.model(**inputs).cpu().detach().numpy()
         if isinstance(txt, str):
             embs = embs.squeeze(0)
@@ -102,41 +104,41 @@ class STransformers(NNOperator):
     def supported_formats(self):
         return ['onnx']
 
-    def tokenize(self, x):
-        try:
-            outs = self._model.tokenize(x)
-        except Exception:
-            from transformers import AutoTokenizer
+    def get_tokenizer(self):
+        if hasattr(self._model, "tokenize"):
+            return self._model.tokenize
+        else:
+            from transformers import AutoTokenizer, AutoConfig
             try:
                 tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
             except Exception:
                 tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            outs = tokenizer(
-                x,
-                padding=True, truncation='longest_first', max_length=self.max_seq_length,
-                return_tensors='pt',
-            )
-        return outs
-
-    @property
-    def max_seq_length(self):
-        import json
-        from torch.hub import _get_torch_home
-        torch_cache = _get_torch_home()
-        sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
-        cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
-        if not os.path.exists(cfg_path):
-            cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
-            k = 'max_position_embeddings'
-        else:
-            k = 'max_seq_length'
-        with open(cfg_path) as f:
-            cfg = json.load(f)
-        if k in cfg:
-            max_seq_len = cfg[k]
-        else:
-            max_seq_len = None
-        return max_seq_len
+                conf = AutoConfig.from_pretrained(self.model_name)
+            return partial(tokenizer,
+                           padding=True,
+                           truncation='longest_first',
+                           max_length=conf.max_position_embeddings,
+                           return_tensors='pt')
+    # @property
+    # def max_seq_length(self):
+    #     import json
+    #     from torch.hub import _get_torch_home
+    #     torch_cache = _get_torch_home()
+    #     sbert_cache = os.path.join(torch_cache, 'sentence_transformers')
+    #     cfg_path = os.path.join(sbert_cache, 'sentence-transformers_' + self.model_name, 'sentence_bert_config.json')
+    #     if not os.path.exists(cfg_path):
+    #         cfg_path = os.path.join(sbert_cache, self.model_name, 'config.json')
+    #         k = 'max_position_embeddings'
+    #     else:
+    #         k = 'max_seq_length'
+    #     with open(cfg_path) as f:
+    #         cfg = json.load(f)
+    #     if k in cfg:
+    #         max_seq_len = cfg[k]
+    #     else:
+    #         max_seq_len = None
+    #     return max_seq_len
 
     @property
     def _model(self):
@@ -156,7 +158,7 @@ class STransformers(NNOperator):
         else:
             raise AttributeError(f'Invalid format {format}.')
         dummy_text = ['[CLS]']
-        dummy_input = self.tokenize(dummy_text)
+        dummy_input = self._tokenize(dummy_text)
         if format == 'pytorch':
             torch.save(self._model, path)
         elif format == 'torchscript':
@@ -180,7 +182,7 @@ class STransformers(NNOperator):
             dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
         dynamic_axes['output_0'] = {0: 'batch_size', 1: 'emb_dim'}
         try:
-            torch.onnx.export(new_model,
+            torch.onnx.export(new_model.to('cpu'),
                               tuple(dummy_input.values()),
                               path,
                               input_names=input_names,
@@ -200,6 +202,7 @@ class STransformers(NNOperator):
     @staticmethod
     def supported_model_names(format: str = None):
         full_list = [
+            'clip-ViT-B-32-multilingual-v1',
             'sentence-t5-xxl',
             'sentence-t5-xl',
             'sentence-t5-large',
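
For reference, a minimal usage sketch of the operator after this change. It is not part of the patch: the module path, constructor arguments, and model name below are assumptions for illustration; only the str/List[str] input handling and the numpy output follow from the diff above.

    # Hypothetical direct use of the operator class; in a real pipeline it
    # would normally be loaded through the towhee operator registry.
    from sbert import STransformers          # assumed module name

    # assumed model name and constructor signature
    op = STransformers(model_name='all-MiniLM-L6-v2', device='cpu')

    vec = op('Hello, towhee!')               # single str -> 1-D numpy embedding
    vecs = op(['first sentence', 'second'])  # list of str -> 2-D numpy array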
|