diff --git a/nn_fingerprint.py b/nn_fingerprint.py index e673f30..0245b58 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import List, Union import torch -from torch import nn import torchaudio import numpy @@ -36,9 +35,8 @@ warnings.filterwarnings('ignore') log = logging.getLogger('nnfp_op') -class Model(nn.Module): +class Model: def __init__(self, model_name, device='cpu', model_path=None): - super().__init__() self.device = device if model_name == 'nnfp_default': self.params = default_params @@ -98,15 +96,16 @@ class NNFingerprint(NNOperator): device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name - self.model = Model(model_name=model_name, device=self.device, model_path=model_path) - self.params = self.model.params + self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path) + self.model = self.accelerate_model.model + self.params = self.accelerate_model.params def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: audio_tensors = self.preprocess(data) if audio_tensors.device != self.device: audio_tensors = audio_tensors.to(self.device) # print(audio_tensors.shape) - features = self.model(audio_tensors) + features = self.accelerate_model(audio_tensors) outs = features.detach().cpu().numpy() return outs