From 26ae59a5d98fc33316bd4d8ab8688932d536aa5a Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 15 Dec 2022 14:23:27 +0800 Subject: [PATCH] Fix Model with nn.Module Signed-off-by: Jael Gu --- nn_fingerprint.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 61b9506..e673f30 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import List, Union import torch +from torch import nn import torchaudio import numpy @@ -35,8 +36,9 @@ warnings.filterwarnings('ignore') log = logging.getLogger('nnfp_op') -class Model: +class Model(nn.Module): def __init__(self, model_name, device='cpu', model_path=None): + super().__init__() self.device = device if model_name == 'nnfp_default': self.params = default_params @@ -76,8 +78,8 @@ class Model: self.model.eval() log.info('Model is loaded.') - def __call__(self, data: 'Tensor'): - return self.model(data) + def __call__(self, *args, **kwargs): + return self.model(*args, **kwargs) @register(output_schema=['vecs']) @@ -205,7 +207,7 @@ class NNFingerprint(NNOperator): elif format == 'onnx': path = path + '.onnx' try: - torch.onnx.export(self.model.model, + torch.onnx.export(self.model, dummy_input, path, input_names=['input'],