|
|
@ -30,7 +30,7 @@ from towhee.types.audio_frame import AudioFrame |
|
|
|
from towhee.models.nnfp import NNFp |
|
|
|
from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec |
|
|
|
|
|
|
|
from .configs import default_params |
|
|
|
from .configs import default_params, hop25_params, distill_params |
|
|
|
|
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger() |
|
|
@ -43,7 +43,7 @@ class NNFingerprint(NNOperator): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
params: dict = None, |
|
|
|
model_name: str = 'nnfp_default', |
|
|
|
model_path: str = None, |
|
|
|
framework: str = 'pytorch', |
|
|
|
device: str = None |
|
|
@ -52,10 +52,14 @@ class NNFingerprint(NNOperator): |
|
|
|
if device is None: |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.device = device |
|
|
|
if params is None: |
|
|
|
if model_name == 'nnfp_default': |
|
|
|
self.params = default_params |
|
|
|
elif model_name == 'nnfp_hop25': |
|
|
|
self.params = hop25_params |
|
|
|
elif model_name == 'nnfp_distill': |
|
|
|
self.params == distill_params |
|
|
|
else: |
|
|
|
self.params = params |
|
|
|
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]') |
|
|
|
|
|
|
|
log.info('Loading model...') |
|
|
|
if model_path is None: |
|
|
@ -127,6 +131,9 @@ class NNFingerprint(NNOperator): |
|
|
|
if sr != self.params['sample_rate']: |
|
|
|
resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype) |
|
|
|
audio = resampler(audio) |
|
|
|
# import resampy |
|
|
|
# audio = audio.detach().cpu().numpy() |
|
|
|
# audio = resampy.resample(audio, sr, self.params['sample_rate']) |
|
|
|
|
|
|
|
wav = preprocess_wav(audio, |
|
|
|
segment_size=int(self.params['sample_rate'] * self.params['segment_size']), |
|
|
|