From fd966f51bd11f4cac51bff13eb87602ef4e98b64 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 1 Jun 2022 14:01:43 +0800 Subject: [PATCH] Support save model with torchscript Signed-off-by: Jael Gu --- auto_transformers.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/auto_transformers.py b/auto_transformers.py index 955ba6f..55d76e1 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -13,6 +13,9 @@ # limitations under the License. import numpy +import os +import torch +from pathlib import Path from transformers import AutoTokenizer, AutoModel from towhee.operator import NNOperator @@ -40,6 +43,7 @@ class AutoTransformers(NNOperator): self.model_name = model_name try: self.model = AutoModel.from_pretrained(model_name) + self.model.eval() except Exception as e: model_list = get_model_list() if model_name not in model_list: @@ -65,13 +69,29 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the model: {self.model_name}') raise e try: - features = outs.last_hidden_state.squeeze(0) + features = outs['last_hidden_state'].squeeze(0) except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e vec = features.detach().numpy() return vec + def save_model(self, jit: bool = True, destination: str = 'default'): + if destination == 'default': + path = str(Path(__file__).parent) + destination = os.path.join(path, self.model_name + '.pt') + inputs = self.tokenizer('[CLS]', return_tensors='pt') + inputs = list(inputs.values()) + if jit: + try: + traced_model = torch.jit.trace(self.model, inputs, strict=False) + torch.jit.save(traced_model, destination) + except Exception as e: + log.error(f'Fail to save as torchscript: {e}.') + raise RuntimeError(f'Fail to save as torchscript: {e}.') + else: + torch.save(self.model, destination) + def get_model_list(): full_list = [