logo
Browse Source

Fix Model with nn.Module

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
26ae59a5d9
  1. 10
      nn_fingerprint.py

10
nn_fingerprint.py

@ -20,6 +20,7 @@ from pathlib import Path
from typing import List, Union from typing import List, Union
import torch import torch
from torch import nn
import torchaudio import torchaudio
import numpy import numpy
@ -35,8 +36,9 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op') log = logging.getLogger('nnfp_op')
class Model:
class Model(nn.Module):
def __init__(self, model_name, device='cpu', model_path=None): def __init__(self, model_name, device='cpu', model_path=None):
super().__init__()
self.device = device self.device = device
if model_name == 'nnfp_default': if model_name == 'nnfp_default':
self.params = default_params self.params = default_params
@ -76,8 +78,8 @@ class Model:
self.model.eval() self.model.eval()
log.info('Model is loaded.') log.info('Model is loaded.')
def __call__(self, data: 'Tensor'):
return self.model(data)
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
@register(output_schema=['vecs']) @register(output_schema=['vecs'])
@ -205,7 +207,7 @@ class NNFingerprint(NNOperator):
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx' path = path + '.onnx'
try: try:
torch.onnx.export(self.model.model,
torch.onnx.export(self.model,
dummy_input, dummy_input,
path, path,
input_names=['input'], input_names=['input'],

Loading…
Cancel
Save