logo
Browse Source

Speed up

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
72d6eab617
  1. 39
      auto_transformers.py

39
auto_transformers.py

@ -20,7 +20,7 @@ from pathlib import Path
from typing import Union
from collections import OrderedDict
from transformers import AutoModel
from transformers import AutoTokenizer, AutoConfig, AutoModel
from towhee.operator import NNOperator
from towhee import register
@ -67,9 +67,17 @@ class AutoTransformers(NNOperator):
norm: bool = False
):
super().__init__()
self._device = device
if device:
self.device = device
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model_name = model_name
self.user_tokenizer = tokenizer
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
self.norm = norm
self.checkpoint_path = checkpoint_path
@ -120,18 +128,8 @@ class AutoTransformers(NNOperator):
model.eval()
return model
@property
def device(self):
if self._device is None:
if self._device_id < 0:
self._device = torch.device('cpu')
else:
self._device = torch.device(self._device_id)
return self._device
@property
def model_config(self):
from transformers import AutoConfig
configs = AutoConfig.from_pretrained(self.model_name)
return configs
@ -147,21 +145,6 @@ class AutoTransformers(NNOperator):
}
return onnx_config
@property
def tokenizer(self):
from transformers import AutoTokenizer
try:
if self.user_tokenizer:
t = tokenizer
else:
t = AutoTokenizer.from_pretrained(self.model_name)
if not t.pad_token:
t.pad_token = '[PAD]'
except Exception as e:
log.error(f'Fail to load tokenizer.')
raise e
return t
def post_proc(self, token_embeddings, inputs):
token_embeddings = token_embeddings.to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)

Loading…
Cancel
Save