diff --git a/nn_fingerprint.py b/nn_fingerprint.py index fd6e651..c891256 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -70,7 +70,7 @@ class NNFingerprint(NNOperator): log.info('Loading weights...') if checkpoint_path is None: path = str(Path(__file__).parent) - checkpoint_path = os.path.join(path, './checkpoints/pfann_fma_m.pt') + checkpoint_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') state_dict = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(state_dict) self.model.eval()