From 444c808ac0f2c8a461d278fa906a93150df16007 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 10 Oct 2022 10:26:23 +0800 Subject: [PATCH] Add param model_name Signed-off-by: Jael Gu --- configs.py | 6 ++++++ nn_fingerprint.py | 15 +++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/configs.py b/configs.py index 2388120..f915216 100644 --- a/configs.py +++ b/configs.py @@ -34,3 +34,9 @@ default_params = { "mel_log": "log", "spec_norm": "l2" } + +hop25_params = default_params.copy() +hop25_params.update(hop_size=0.25) + +distill_params = default_params.copy() +distill_params.update(h=1024) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 8c1c5c3..6eadb50 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -30,7 +30,7 @@ from towhee.types.audio_frame import AudioFrame from towhee.models.nnfp import NNFp from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec -from .configs import default_params +from .configs import default_params, hop25_params, distill_params warnings.filterwarnings('ignore') log = logging.getLogger() @@ -43,7 +43,7 @@ class NNFingerprint(NNOperator): """ def __init__(self, - params: dict = None, + model_name: str = 'nnfp_default', model_path: str = None, framework: str = 'pytorch', device: str = None @@ -52,10 +52,14 @@ class NNFingerprint(NNOperator): if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device - if params is None: + if model_name == 'nnfp_default': self.params = default_params + elif model_name == 'nnfp_hop25': + self.params = hop25_params + elif model_name == 'nnfp_distill': + self.params == distill_params else: - self.params = params + raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]') log.info('Loading model...') if model_path is None: @@ -127,6 +131,9 @@ class NNFingerprint(NNOperator): if sr != self.params['sample_rate']: resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype) audio = resampler(audio) + # import resampy + # audio = audio.detach().cpu().numpy() + # audio = resampy.resample(audio, sr, self.params['sample_rate']) wav = preprocess_wav(audio, segment_size=int(self.params['sample_rate'] * self.params['segment_size']),