diff --git a/auto_transformers.py b/auto_transformers.py index eb0a46b..797d8ed 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -101,7 +101,6 @@ class AutoTransformers(NNOperator): raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': path = path + '.onnx' - try: torch.onnx.export(self.model, tuple(inputs.values()), @@ -129,6 +128,9 @@ class AutoTransformers(NNOperator): "last_hidden_state": {0: "batch_size"}, "pooler_outputs": {0: "batch_size"} }) + elif format == 'tensorrt': + # os.system('pip install "git+https://github.com/grimoire/torch2trt_dynamic.git"') + pass else: log.error(f'Unsupported format "{format}".')