Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
3 additions and
1 deletions
-
auto_transformers.py
|
@ -101,7 +101,6 @@ class AutoTransformers(NNOperator): |
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
elif format == 'onnx': |
|
|
elif format == 'onnx': |
|
|
path = path + '.onnx' |
|
|
path = path + '.onnx' |
|
|
|
|
|
|
|
|
try: |
|
|
try: |
|
|
torch.onnx.export(self.model, |
|
|
torch.onnx.export(self.model, |
|
|
tuple(inputs.values()), |
|
|
tuple(inputs.values()), |
|
@ -129,6 +128,9 @@ class AutoTransformers(NNOperator): |
|
|
"last_hidden_state": {0: "batch_size"}, |
|
|
"last_hidden_state": {0: "batch_size"}, |
|
|
"pooler_outputs": {0: "batch_size"} |
|
|
"pooler_outputs": {0: "batch_size"} |
|
|
}) |
|
|
}) |
|
|
|
|
|
elif format == 'tensorrt': |
|
|
|
|
|
# os.system('pip install "git+https://github.com/grimoire/torch2trt_dynamic.git"') |
|
|
|
|
|
pass |
|
|
else: |
|
|
else: |
|
|
log.error(f'Unsupported format "{format}".') |
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|
|
|
|
|
|