logo
Browse Source

Update Readme for save_model

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
43a21ce0b0
  1. 8
      README.md
  2. 8
      auto_transformers.py

8
README.md

@ -362,6 +362,14 @@ Save model to local with specified format.
​ The path where model is saved to. By default, it will save model to the operator directory.
```python
from towhee import ops
op = ops.text_embedding.transformers(model_name='distilbert-base-cased').get_op()
op.save_model('onnx', 'test.onnx')
```
PosixPath('/Home/.towhee/operators/text-embedding/transformers/main/test.onnx')
<br />
***supported_model_names(format=None)***

8
auto_transformers.py

@ -140,14 +140,14 @@ class AutoTransformers(NNOperator):
dummy_input = '[CLS]'
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
if model_type == 'pytorch':
torch.save(self.model, output_file)
torch.save(self._model, output_file)
elif model_type == 'torchscript':
inputs = list(inputs.values())
try:
try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception:
jit_model = torch.jit.trace(self.model, inputs, strict=False)
jit_model = torch.jit.trace(self._model, inputs, strict=False)
torch.jit.save(jit_model, output_file)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
@ -171,7 +171,7 @@ class AutoTransformers(NNOperator):
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
return True
return Path(output_file).resolve()
@property
def supported_formats(self):

Loading…
Cancel
Save