diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 1cf86b3..25586eb 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -42,7 +42,7 @@ log = logging.getLogger('nnfp_op') log.setLevel(logging.ERROR) -# @accelerate +@accelerate class Model: def __init__(self, params, device='cpu', model_path=None): self.device = device