From 48e9c77a32b17498302ab129a2c6d857946ac008 Mon Sep 17 00:00:00 2001
From: "junjie.jiang"
Date: Tue, 25 Jun 2024 13:22:14 +0800
Subject: [PATCH] Fix token

Signed-off-by: junjie.jiang
---
 s_bert.py | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)

diff --git a/s_bert.py b/s_bert.py
index 09efa5f..18688f3 100644
--- a/s_bert.py
+++ b/s_bert.py
@@ -20,6 +20,7 @@ from functools import partial
 
 import torch
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoConfig
 from towhee.operator import NNOperator
 
 try:
@@ -105,21 +106,17 @@ class STransformers(NNOperator):
         return ['onnx']
 
     def get_tokenizer(self):
-        if hasattr(self._model, "tokenize"):
-            return self._model.tokenize
-        else:
-            from transformers import AutoTokenizer, AutoConfig
-            try:
-                tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
-                conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
-            except Exception:
-                tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-                conf = AutoConfig.from_pretrained(self.model_name)
-            return partial(tokenizer,
-                           padding=True,
-                           truncation='longest_first',
-                           max_length=conf.max_position_embeddings,
-                           return_tensors='pt')
+        try:
+            tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + self.model_name)
+            conf = AutoConfig.from_pretrained('sentence-transformers/' + self.model_name)
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            conf = AutoConfig.from_pretrained(self.model_name)
+        return partial(tokenizer,
+                       padding=True,
+                       truncation='longest_first',
+                       max_length=conf.max_position_embeddings,
+                       return_tensors='pt')
     # @property
     # def max_seq_length(self):
     #     import json
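
For context, below is a minimal standalone sketch of what the patched get_tokenizer() now returns: a partial over a Hugging Face tokenizer that pads, truncates, and emits PyTorch tensors. The script and the model name 'all-MiniLM-L6-v2' are illustrative assumptions and are not part of the patch itself.

# Standalone sketch of the tokenizer behavior after this patch.
# Assumptions: transformers is installed and the example model name
# 'all-MiniLM-L6-v2' is reachable on the Hugging Face Hub.
from functools import partial

from transformers import AutoTokenizer, AutoConfig

model_name = 'all-MiniLM-L6-v2'  # hypothetical example model

# Same lookup order as the patched get_tokenizer(): try the
# 'sentence-transformers/' namespace first, then fall back to the bare name.
try:
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/' + model_name)
    conf = AutoConfig.from_pretrained('sentence-transformers/' + model_name)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    conf = AutoConfig.from_pretrained(model_name)

tokenize = partial(tokenizer,
                   padding=True,
                   truncation='longest_first',
                   max_length=conf.max_position_embeddings,
                   return_tensors='pt')

# The partial maps a list of sentences to a dict of padded/truncated
# PyTorch tensors (input_ids, attention_mask, ...), which is what the
# downstream ONNX export path consumes.
inputs = tokenize(['hello world', 'sentence embeddings with s-bert'])
print(inputs['input_ids'].shape)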