diff --git a/isc.py b/isc.py index 8a5ee07..52dcfc5 100644 --- a/isc.py +++ b/isc.py @@ -130,7 +130,7 @@ class Isc(NNOperator): path = path + '.onnx' else: raise ValueError(f'Invalid format {format}.') - dummy_input = torch.rand(1, 3, 224, 224) + dummy_input = torch.rand(1, 3, 224, 224).to(self.device) if format == 'pytorch': torch.save(self._model, path) elif format == 'torchscript':