From 43a21ce0b00fc9454d788e1794f5f9a5dbafec87 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 28 Dec 2022 10:46:56 +0800 Subject: [PATCH] Update Readme for save_model Signed-off-by: Jael Gu --- README.md | 8 ++++++++ auto_transformers.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 236fe0d..41ce102 100644 --- a/README.md +++ b/README.md @@ -362,6 +362,14 @@ Save model to local with specified format. ​ The path where model is saved to. By default, it will save model to the operator directory. +```python +from towhee import ops + +op = ops.text_embedding.transformers(model_name='distilbert-base-cased').get_op() +op.save_model('onnx', 'test.onnx') +``` +PosixPath('/Home/.towhee/operators/text-embedding/transformers/main/test.onnx') +
***supported_model_names(format=None)*** diff --git a/auto_transformers.py b/auto_transformers.py index 4250a40..a495c36 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -140,14 +140,14 @@ class AutoTransformers(NNOperator): dummy_input = '[CLS]' inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary if model_type == 'pytorch': - torch.save(self.model, output_file) + torch.save(self._model, output_file) elif model_type == 'torchscript': inputs = list(inputs.values()) try: try: - jit_model = torch.jit.script(self.model) + jit_model = torch.jit.script(self._model) except Exception: - jit_model = torch.jit.trace(self.model, inputs, strict=False) + jit_model = torch.jit.trace(self._model, inputs, strict=False) torch.jit.save(jit_model, output_file) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') @@ -171,7 +171,7 @@ class AutoTransformers(NNOperator): # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') - return True + return Path(output_file).resolve() @property def supported_formats(self):