logo
Browse Source

Update Op

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
232277e4ea
  1. 11
      nn_fingerprint.py

11
nn_fingerprint.py

@ -20,7 +20,6 @@ from pathlib import Path
from typing import List, Union
import torch
from torch import nn
import torchaudio
import numpy
@ -36,9 +35,8 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op')
class Model(nn.Module):
class Model:
def __init__(self, model_name, device='cpu', model_path=None):
super().__init__()
self.device = device
if model_name == 'nnfp_default':
self.params = default_params
@ -98,15 +96,16 @@ class NNFingerprint(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.params = self.model.params
self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.model = self.accelerate_model.model
self.params = self.accelerate_model.params
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape)
features = self.model(audio_tensors)
features = self.accelerate_model(audio_tensors)
outs = features.detach().cpu().numpy()
return outs

Loading…
Cancel
Save