|
@@ -20,6 +20,7 @@ from functools import partial
 import torch
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoConfig
 
 from towhee.operator import NNOperator
 
 try:
@@ -105,21 +106,17 @@ class STransformers(NNOperator):
         return ['onnx']
 
     def get_tokenizer(self):
-        if hasattr(self._model, "tokenize"):
-            return self._model.tokenize
-        else:
-            from transformers import AutoTokenizer, AutoConfig
-            try:
-                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
-                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
-            except Exception:
-                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-                conf = AutoConfig.from_pretrained(self.model_name)
-            return partial(tokenizer,
-                           padding=True,
-                           truncation='longest_first',
-                           max_length=conf.max_position_embeddings,
-                           return_tensors='pt')
+        try:
+            tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+            conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            conf = AutoConfig.from_pretrained(self.model_name)
+        return partial(tokenizer,
+                       padding=True,
+                       truncation='longest_first',
+                       max_length=conf.max_position_embeddings,
+                       return_tensors='pt')
     # @property
     # def max_seq_length(self):
     #     import json
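
Note (not part of the patch): a minimal usage sketch of what the refactored get_tokenizer() returns, namely a functools.partial over a Hugging Face tokenizer that pads, truncates with the 'longest_first' strategy, caps inputs at the model's max_position_embeddings, and returns PyTorch tensors. The model name 'all-MiniLM-L6-v2' is an illustrative assumption, not taken from the diff.

# usage_sketch.py -- illustrative only
from functools import partial

from transformers import AutoTokenizer, AutoConfig

model_name = 'all-MiniLM-L6-v2'  # hypothetical example model, not from the diff

# Mirror the operator's fallback: prefer the 'sentence-transformers/' hub
# namespace, then fall back to treating model_name as a full hub path.
try:
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + model_name)
    conf = AutoConfig.from_pretrained('sentence-transformers/' + model_name)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    conf = AutoConfig.from_pretrained(model_name)

tokenize = partial(tokenizer,
                   padding=True,
                   truncation='longest_first',
                   max_length=conf.max_position_embeddings,
                   return_tensors='pt')

batch = tokenize(['Hello, world!', 'A slightly longer second sentence.'])
print(batch['input_ids'].shape)  # e.g. torch.Size([2, <padded_seq_len>])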
|
|