logo
Browse Source

Speed up

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
72d6eab617
  1. 39
      auto_transformers.py

39
auto_transformers.py

@@ -20,7 +20,7 @@ from pathlib import Path
from typing import Union from typing import Union
from collections import OrderedDict from collections import OrderedDict
from transformers import AutoModel
from transformers import AutoTokenizer, AutoConfig, AutoModel
from towhee.operator import NNOperator from towhee.operator import NNOperator
from towhee import register from towhee import register
@@ -67,9 +67,17 @@ class AutoTransformers(NNOperator):
norm: bool = False norm: bool = False
): ):
super().__init__() super().__init__()
self._device = device
if device:
self.device = device
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = model_name self.model_name = model_name
self.user_tokenizer = tokenizer
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
self.norm = norm self.norm = norm
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
@@ -120,18 +128,8 @@ class AutoTransformers(NNOperator):
model.eval() model.eval()
return model return model
@property
def device(self):
if self._device is None:
if self._device_id < 0:
self._device = torch.device('cpu')
else:
self._device = torch.device(self._device_id)
return self._device
@property @property
def model_config(self): def model_config(self):
from transformers import AutoConfig
configs = AutoConfig.from_pretrained(self.model_name) configs = AutoConfig.from_pretrained(self.model_name)
return configs return configs
@@ -147,21 +145,6 @@ class AutoTransformers(NNOperator):
} }
return onnx_config return onnx_config
@property
def tokenizer(self):
from transformers import AutoTokenizer
try:
if self.user_tokenizer:
t = tokenizer
else:
t = AutoTokenizer.from_pretrained(self.model_name)
if not t.pad_token:
t.pad_token = '[PAD]'
except Exception as e:
log.error(f'Fail to load tokenizer.')
raise e
return t
def post_proc(self, token_embeddings, inputs): def post_proc(self, token_embeddings, inputs):
token_embeddings = token_embeddings.to(self.device) token_embeddings = token_embeddings.to(self.device)
attention_mask = inputs['attention_mask'].to(self.device) attention_mask = inputs['attention_mask'].to(self.device)

Loading…
Cancel
Save