diff --git a/timm_image.py b/timm_image.py index 45dca20..f59c7bd 100644 --- a/timm_image.py +++ b/timm_image.py @@ -166,7 +166,7 @@ class TimmImage(NNOperator): path = path + '.onnx' else: raise AttributeError(f'Invalid format {format}.') - dummy_input = torch.rand((1,) + self.config['input_size']) + dummy_input = torch.rand((1,) + self.config['input_size']).to(self.device) if format == 'pytorch': torch.save(self._model, path) elif format == 'torchscript':