logo
Browse Source

Fix token

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 10 months ago
parent
commit
48e9c77a32
  1. 27
      s_bert.py

27
s_bert.py

@@ -20,6 +20,7 @@ from functools import partial
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoConfig
from towhee.operator import NNOperator
try:
@@ -105,21 +106,17 @@ class STransformers(NNOperator):
return ['onnx']
def get_tokenizer(self):
if hasattr(self._model, "tokenize"):
return self._model.tokenize
else:
from transformers import AutoTokenizer, AutoConfig
try:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
except Exception:
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
conf = AutoConfig.from_pretrained(self.model_name)
return partial(tokenizer,
padding=True,
truncation='longest_first',
max_length=conf.max_position_embeddings,
return_tensors='pt')
try:
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
except Exception:
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
conf = AutoConfig.from_pretrained(self.model_name)
return partial(tokenizer,
padding=True,
truncation='longest_first',
max_length=conf.max_position_embeddings,
return_tensors='pt')
# @property
# def max_seq_length(self):
# import json

Loading…
Cancel
Save