diff --git a/nn_fingerprint.py b/nn_fingerprint.py index e7f8bb7..44e8dc9 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -69,7 +69,7 @@ class Model: def __call__(self, *args, **kwargs): new_args = [] new_kwargs = {} - for x in new_args: + for x in args: new_args.append(x.to(self.device)) for k, v in kwargs.items(): new_kwargs[k] = v.to(self.device)