Browse Source
Fix Model with nn.Module
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
6 additions and
4 deletions
-
nn_fingerprint.py
|
|
@ -20,6 +20,7 @@ from pathlib import Path |
|
|
|
from typing import List, Union |
|
|
|
|
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
import torchaudio |
|
|
|
import numpy |
|
|
|
|
|
|
@ -35,8 +36,9 @@ warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger('nnfp_op') |
|
|
|
|
|
|
|
|
|
|
|
class Model: |
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self, model_name, device='cpu', model_path=None): |
|
|
|
super().__init__() |
|
|
|
self.device = device |
|
|
|
if model_name == 'nnfp_default': |
|
|
|
self.params = default_params |
|
|
@ -76,8 +78,8 @@ class Model: |
|
|
|
self.model.eval() |
|
|
|
log.info('Model is loaded.') |
|
|
|
|
|
|
|
def __call__(self, data: 'Tensor'): |
|
|
|
return self.model(data) |
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
return self.model(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vecs']) |
|
|
@ -205,7 +207,7 @@ class NNFingerprint(NNOperator): |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|
try: |
|
|
|
torch.onnx.export(self.model.model, |
|
|
|
torch.onnx.export(self.model, |
|
|
|
dummy_input, |
|
|
|
path, |
|
|
|
input_names=['input'], |
|
|
|