diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 46e6c7d..e7f8bb7 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -70,7 +70,7 @@ class Model: new_args = [] new_kwargs = {} for x in new_args: - x = x.to(self.device) + new_args.append(x.to(self.device)) for k, v in kwargs.items(): new_kwargs[k] = v.to(self.device) return self.model(*new_args, **new_kwargs)