@@ -20,7 +20,7 @@ from pathlib import Path

from typing import Union

from collections import OrderedDict

from transformers import AutoModel

from transformers import AutoTokenizer, AutoConfig, AutoModel

from towhee.operator import NNOperator

from towhee import register
@@ -67,9 +67,17 @@ class AutoTransformers(NNOperator):
        norm: bool = False
    ):
        super().__init__()
        self._device = device
        if device:
            self.device = device
        else:
            # Default to GPU when one is visible, otherwise fall back to CPU.
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = model_name
        self.user_tokenizer = tokenizer
        if tokenizer:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = '[PAD]'
        self.norm = norm
        self.checkpoint_path = checkpoint_path
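
For orientation, a hedged usage sketch of the fallback behaviour above. The constructor's full signature is not visible in this hunk, so the keyword names are inferred from the body, and `bert-base-uncased` is purely illustrative:

```python
# Hypothetical instantiation; argument names inferred from the hunk above.
op = AutoTransformers(model_name='bert-base-uncased')

# No explicit device: resolves to 'cuda' when a GPU is visible, else 'cpu'.
print(op.device)

# No user tokenizer: AutoTokenizer.from_pretrained(model_name) is used, and a
# '[PAD]' token is injected when the pretrained tokenizer lacks one.
print(op.tokenizer.pad_token)
```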

@@ -120,18 +128,8 @@ class AutoTransformers(NNOperator):
        model.eval()
        return model

    @property
    def device(self):
        if self._device is None:
            # Lazy resolution: a negative id means CPU, otherwise a CUDA device index.
            if self._device_id < 0:
                self._device = torch.device('cpu')
            else:
                self._device = torch.device(self._device_id)
        return self._device
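
A minimal sketch of the same resolution rule, assuming `_device_id` follows the common convention of `-1` for CPU and a non-negative CUDA index (where it gets set is not shown in this excerpt; `resolve_device` is a hypothetical helper):

```python
import torch

def resolve_device(device_id: int) -> torch.device:
    # -1 selects CPU; torch.device(n) for n >= 0 is shorthand for 'cuda:n'.
    return torch.device('cpu') if device_id < 0 else torch.device(device_id)

assert resolve_device(-1) == torch.device('cpu')
assert resolve_device(1) == torch.device('cuda:1')  # building the handle needs no GPU
```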

    @property
    def model_config(self):
        from transformers import AutoConfig
        configs = AutoConfig.from_pretrained(self.model_name)
        return configs
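
`AutoConfig` is useful downstream for sizing buffers without loading any weights; a small example (the model name is for illustration only):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('bert-base-uncased')  # illustrative model
print(config.hidden_size)   # 768: width of the token embeddings
print(config.model_type)    # 'bert'
```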

@@ -147,21 +145,6 @@ class AutoTransformers(NNOperator):
        }
        return onnx_config
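
The body of the `onnx_config` mapping is elided in this hunk. Purely as an assumed illustration (not the author's actual dict), such properties typically describe named inputs and outputs with their dynamic axes for ONNX export:

```python
# Assumed shape only; the real keys are not visible in the hunk above.
onnx_config = {
    'inputs': {
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
    },
    'outputs': {
        'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
    },
}
```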

    @property
    def tokenizer(self):
        from transformers import AutoTokenizer
        try:
            if self.user_tokenizer:
                t = self.user_tokenizer
            else:
                t = AutoTokenizer.from_pretrained(self.model_name)
            if not t.pad_token:
                t.pad_token = '[PAD]'
        except Exception as e:
            log.error('Failed to load tokenizer.')
            raise e
        return t
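
The pad-token fallback matters for models such as GPT-2, whose tokenizers ship without one. A sketch of the problem it guards against; note that a common alternative to the literal `'[PAD]'` used above is to reuse the EOS token so padding ids stay in-vocabulary:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')  # gpt2 has no pad token out of the box
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # common recipe; differs from the '[PAD]' fallback above

batch = tok(['a short text', 'a slightly longer piece of text'],
            padding=True, return_tensors='pt')
print(batch['input_ids'].shape, batch['attention_mask'].shape)
```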

    def post_proc(self, token_embeddings, inputs):
        # Move the activations and the attention mask onto the operator's device.
        token_embeddings = token_embeddings.to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
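
`post_proc` is cut off at the end of this excerpt. Given that it aligns token embeddings with the attention mask, a plausible continuation (assumed, not confirmed by the diff) is attention-masked mean pooling; `mean_pool` below is a hypothetical standalone helper, not the operator's actual code:

```python
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then average over the real tokens only.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # avoid division by zero
    return summed / counts
```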