logo
Browse Source

Add param model_name

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
444c808ac0
  1. 6
      configs.py
  2. 15
      nn_fingerprint.py

6
configs.py

@ -34,3 +34,9 @@ default_params = {
"mel_log": "log", "mel_log": "log",
"spec_norm": "l2" "spec_norm": "l2"
} }
# Shallow copy of default_params with only `hop_size` overridden to 0.25.
hop25_params = dict(default_params, hop_size=0.25)
# Shallow copy of default_params with only `h` overridden to 1024.
distill_params = dict(default_params, h=1024)

15
nn_fingerprint.py

@ -30,7 +30,7 @@ from towhee.types.audio_frame import AudioFrame
from towhee.models.nnfp import NNFp from towhee.models.nnfp import NNFp
from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
from .configs import default_params
from .configs import default_params, hop25_params, distill_params
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
log = logging.getLogger() log = logging.getLogger()
@ -43,7 +43,7 @@ class NNFingerprint(NNOperator):
""" """
def __init__(self, def __init__(self,
params: dict = None,
model_name: str = 'nnfp_default',
model_path: str = None, model_path: str = None,
framework: str = 'pytorch', framework: str = 'pytorch',
device: str = None device: str = None
@ -52,10 +52,14 @@ class NNFingerprint(NNOperator):
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device self.device = device
if params is None:
if model_name == 'nnfp_default':
self.params = default_params self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params = distill_params
else: else:
self.params = params
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
log.info('Loading model...') log.info('Loading model...')
if model_path is None: if model_path is None:
@ -127,6 +131,9 @@ class NNFingerprint(NNOperator):
if sr != self.params['sample_rate']: if sr != self.params['sample_rate']:
resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype) resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype)
audio = resampler(audio) audio = resampler(audio)
# import resampy
# audio = audio.detach().cpu().numpy()
# audio = resampy.resample(audio, sr, self.params['sample_rate'])
wav = preprocess_wav(audio, wav = preprocess_wav(audio,
segment_size=int(self.params['sample_rate'] * self.params['segment_size']), segment_size=int(self.params['sample_rate'] * self.params['segment_size']),

Loading…
Cancel
Save