diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 61b9506..e673f30 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import List, Union import torch +from torch import nn import torchaudio import numpy @@ -35,8 +36,9 @@ warnings.filterwarnings('ignore') log = logging.getLogger('nnfp_op') -class Model: +class Model(nn.Module): def __init__(self, model_name, device='cpu', model_path=None): + super().__init__() self.device = device if model_name == 'nnfp_default': self.params = default_params @@ -76,8 +78,8 @@ class Model: self.model.eval() log.info('Model is loaded.') - def __call__(self, data: 'Tensor'): - return self.model(data) + def __call__(self, *args, **kwargs): + return self.model(*args, **kwargs) @register(output_schema=['vecs']) @@ -205,7 +207,7 @@ class NNFingerprint(NNOperator): elif format == 'onnx': path = path + '.onnx' try: - torch.onnx.export(self.model.model, + torch.onnx.export(self.model, dummy_input, path, input_names=['input'],