logo
Browse Source

Support save model with torchscript

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
fd966f51bd
  1. 22
      auto_transformers.py

22
auto_transformers.py

@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
import numpy import numpy
import os
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
from towhee.operator import NNOperator from towhee.operator import NNOperator
@ -40,6 +43,7 @@ class AutoTransformers(NNOperator):
self.model_name = model_name self.model_name = model_name
try: try:
self.model = AutoModel.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name)
self.model.eval()
except Exception as e: except Exception as e:
model_list = get_model_list() model_list = get_model_list()
if model_name not in model_list: if model_name not in model_list:
@ -65,13 +69,29 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
try: try:
features = outs.last_hidden_state.squeeze(0)
features = outs['last_hidden_state'].squeeze(0)
except Exception as e: except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}') log.error(f'Fail to extract features by model: {self.model_name}')
raise e raise e
vec = features.detach().numpy() vec = features.detach().numpy()
return vec return vec
def save_model(self, jit: bool = True, destination: str = 'default'):
if destination == 'default':
path = str(Path(__file__).parent)
destination = os.path.join(path, self.model_name + '.pt')
inputs = self.tokenizer('[CLS]', return_tensors='pt')
inputs = list(inputs.values())
if jit:
try:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(traced_model, destination)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
else:
torch.save(self.model, destination)
def get_model_list(): def get_model_list():
full_list = [ full_list = [

Loading…
Cancel
Save