logo
Browse Source

Support save model with torchscript

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
fd966f51bd
  1. 22
      auto_transformers.py

22
auto_transformers.py

@ -13,6 +13,9 @@
# limitations under the License.
import numpy
import os
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from towhee.operator import NNOperator
@ -40,6 +43,7 @@ class AutoTransformers(NNOperator):
self.model_name = model_name
try:
self.model = AutoModel.from_pretrained(model_name)
self.model.eval()
except Exception as e:
model_list = get_model_list()
if model_name not in model_list:
@ -65,13 +69,29 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the model: {self.model_name}')
raise e
try:
features = outs.last_hidden_state.squeeze(0)
features = outs['last_hidden_state'].squeeze(0)
except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}')
raise e
vec = features.detach().numpy()
return vec
def save_model(self, jit: bool = True, destination: str = 'default'):
if destination == 'default':
path = str(Path(__file__).parent)
destination = os.path.join(path, self.model_name + '.pt')
inputs = self.tokenizer('[CLS]', return_tensors='pt')
inputs = list(inputs.values())
if jit:
try:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(traced_model, destination)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
else:
torch.save(self.model, destination)
def get_model_list():
full_list = [

Loading…
Cancel
Save