logo
Browse Source

Fix onnx export

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
e1a1abd518
  1. 2
      auto_transformers.py

2
auto_transformers.py

@ -188,7 +188,7 @@ class AutoTransformers(NNOperator):
raise AttributeError('Unsupported model_type.') raise AttributeError('Unsupported model_type.')
dummy_input = 'test sentence' dummy_input = 'test sentence'
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
if model_type == 'pytorch': if model_type == 'pytorch':
torch.save(self._model, output_file) torch.save(self._model, output_file)
elif model_type == 'torchscript': elif model_type == 'torchscript':

Loading…
Cancel
Save