Browse Source
Update Readme for save_model
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
12 additions and
4 deletions
-
README.md
-
auto_transformers.py
|
|
@ -362,6 +362,14 @@ Save model to local with specified format. |
|
|
|
|
|
|
|
The path where model is saved to. By default, it will save model to the operator directory. |
|
|
|
|
|
|
|
```python |
|
|
|
from towhee import ops |
|
|
|
|
|
|
|
op = ops.text_embedding.transformers(model_name='distilbert-base-cased').get_op() |
|
|
|
op.save_model('onnx', 'test.onnx') |
|
|
|
``` |
|
|
|
PosixPath('/Home/.towhee/operators/text-embedding/transformers/main/test.onnx') |
|
|
|
|
|
|
|
<br /> |
|
|
|
|
|
|
|
***supported_model_names(format=None)*** |
|
|
|
|
|
@ -140,14 +140,14 @@ class AutoTransformers(NNOperator): |
|
|
|
dummy_input = '[CLS]' |
|
|
|
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary |
|
|
|
if model_type == 'pytorch': |
|
|
|
torch.save(self.model, output_file) |
|
|
|
torch.save(self._model, output_file) |
|
|
|
elif model_type == 'torchscript': |
|
|
|
inputs = list(inputs.values()) |
|
|
|
try: |
|
|
|
try: |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
|
jit_model = torch.jit.script(self._model) |
|
|
|
except Exception: |
|
|
|
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
|
|
|
jit_model = torch.jit.trace(self._model, inputs, strict=False) |
|
|
|
torch.jit.save(jit_model, output_file) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
@ -171,7 +171,7 @@ class AutoTransformers(NNOperator): |
|
|
|
# todo: elif format == 'tensorrt': |
|
|
|
else: |
|
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|
return True |
|
|
|
return Path(output_file).resolve() |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_formats(self): |
|
|
|