diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 3623a7d..cef8c08 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -157,7 +157,7 @@ class NNFingerprint(NNOperator): path = os.path.join(path, name) dummy_input = torch.rand( (1,) + (self.params['n_mels'], self.params['u']) - ) + ).to(self.device) if format == 'pytorch': path = path + '.pt' torch.save(self.model, path)