From 72d6eab61768d3542d287e1671660cded2c22648 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 12 Jan 2023 16:22:01 +0800
Subject: [PATCH] Speed up

Resolve the device and load the tokenizer once in __init__ instead of
lazily via properties, so each access no longer re-runs
AutoTokenizer.from_pretrained, and hoist the transformers imports to
module level.

Signed-off-by: Jael Gu
---
 auto_transformers.py | 39 +++++++++++----------------------------
 1 file changed, 11 insertions(+), 28 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index 1ce8b7d..6104820 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -20,7 +20,7 @@ from pathlib import Path
 from typing import Union
 from collections import OrderedDict
 
-from transformers import AutoModel
+from transformers import AutoTokenizer, AutoConfig, AutoModel
 
 from towhee.operator import NNOperator
 from towhee import register
@@ -67,9 +67,17 @@ class AutoTransformers(NNOperator):
                  norm: bool = False
                  ):
         super().__init__()
-        self._device = device
+        if device:
+            self.device = device
+        else:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.model_name = model_name
-        self.user_tokenizer = tokenizer
+        if tokenizer:
+            self.tokenizer = tokenizer
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = '[PAD]'
         self.norm = norm
         self.checkpoint_path = checkpoint_path
@@ -120,18 +128,8 @@ class AutoTransformers(NNOperator):
         model.eval()
         return model
 
-    @property
-    def device(self):
-        if self._device is None:
-            if self._device_id < 0:
-                self._device = torch.device('cpu')
-            else:
-                self._device = torch.device(self._device_id)
-        return self._device
-
     @property
     def model_config(self):
-        from transformers import AutoConfig
         configs = AutoConfig.from_pretrained(self.model_name)
         return configs
@@ -147,21 +145,6 @@ class AutoTransformers(NNOperator):
         }
         return onnx_config
 
-    @property
-    def tokenizer(self):
-        from transformers import AutoTokenizer
-        try:
-            if self.user_tokenizer:
-                t = tokenizer
-            else:
-                t = AutoTokenizer.from_pretrained(self.model_name)
-            if not t.pad_token:
-                t.pad_token = '[PAD]'
-        except Exception as e:
-            log.error(f'Fail to load tokenizer.')
-            raise e
-        return t
-
     def post_proc(self, token_embeddings, inputs):
         token_embeddings = token_embeddings.to(self.device)
         attention_mask = inputs['attention_mask'].to(self.device)
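
A note on the pattern for reviewers: the old code paid the cost of
AutoTokenizer.from_pretrained on every access of the tokenizer attribute,
because a bare @property memoizes nothing. Below is a minimal sketch of the
before/after, not part of the patch; LazyOp, EagerOp, and the placeholder
model name 'bert-base-cased' are illustrative assumptions only.

    import torch
    from transformers import AutoTokenizer

    MODEL_NAME = 'bert-base-cased'  # hypothetical placeholder, not from the patch

    # Before: a bare @property re-runs AutoTokenizer.from_pretrained on
    # every access, since nothing caches the returned tokenizer.
    class LazyOp:
        def __init__(self, model_name=MODEL_NAME):
            self.model_name = model_name

        @property
        def tokenizer(self):
            return AutoTokenizer.from_pretrained(self.model_name)

    # After: resolve the device and tokenizer once in __init__, mirroring
    # the patched operator, so later uses are plain attribute reads.
    class EagerOp:
        def __init__(self, model_name=MODEL_NAME, tokenizer=None, device=None):
            self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
            self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name)
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = '[PAD]'

A functools.cached_property would keep lazy loading with the same one-time
cost; the eager form used in the patch also lets a user-supplied tokenizer
go through the same pad_token check as a loaded one.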