diff --git a/README.md b/README.md
index aa69c20..455c064 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,95 @@
-# nnfp
+# Audio Embedding with Neural Network Fingerprint
+*Author: [Jael Gu](https://github.com/jaelgu)*
+
+<br />
+
+## Description
+
+The audio embedding operator converts an input audio clip into a dense vector that represents the clip's semantics.
+Each vector corresponds to an audio clip with a fixed length of around 1 s.
+This operator generates audio embeddings with the fingerprinting method introduced in [Neural Audio Fingerprint](https://arxiv.org/abs/2010.11910).
+The model is implemented in PyTorch.
+We have also trained the nnfp model on the [FMA dataset](https://github.com/mdeff/fma) (plus some noise audio) and shipped the weights with this operator.
+The nnfp operator is well suited to generating audio fingerprints.
+
+<br />
+
+## Code Example
+
+Generate embeddings for the audio "test.wav".
+
+*Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+(
+    towhee.glob('test.wav')
+          .audio_decode.ffmpeg()
+          .runas_op(func=lambda x: [y[0] for y in x])
+          .audio_embedding.nnfp()  # use default model
+          .show()
+)
+```
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+(
+    towhee.glob['path']('test.wav')
+          .audio_decode.ffmpeg['path', 'frames']()
+          .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+          .audio_embedding.nnfp['frames', 'vecs']()
+          .select['path', 'vecs']()
+          .show()
+)
+```
+
+<br />
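+
+To work with the raw vectors instead of a rendered table, the results can be collected into a list.
+This is a minimal sketch, assuming towhee's `to_list` API; the `.vecs` attribute follows the named-column pipeline above:
+
+```python
+import towhee
+
+res = (
+    towhee.glob['path']('test.wav')
+          .audio_decode.ffmpeg['path', 'frames']()
+          .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+          .audio_embedding.nnfp['frames', 'vecs']()
+          .to_list()
+)
+print(res[0].vecs.shape)  # expected: (num_clips, 128)
+```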
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***audio_embedding.nnfp(params=None, checkpoint_path=None, framework='pytorch')***
+
+**Parameters:**
+
+*params: dict*
+
+A dictionary of model parameters. If None, the default parameters are used to create the model.
+
+*checkpoint_path: str*
+
+The path to model weights. If None, the default model weights are loaded.
+
+*framework: str*
+
+The framework of the model implementation.
+The default value is "pytorch" since the model is implemented in PyTorch.
+
+<br />
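+
+As a minimal sketch, the operator can also be created with custom parameters and called directly
+(its call signature is documented in the Interface section below). The `hop_size` override and the
+silent input here are hypothetical, for illustration only, and the `AudioFrame` constructor may
+differ slightly across towhee versions:
+
+```python
+import numpy
+import towhee
+from towhee.types.audio_frame import AudioFrame
+
+from configs import default_params  # assumes this repository is on the import path
+
+# Copy the shipped defaults and override selected fields.
+params = dict(default_params)
+params['hop_size'] = 0.5  # hypothetical override: one fingerprint every 0.5 s
+
+op = towhee.ops.audio_embedding.nnfp(params=params)
+
+# Hypothetical input: about 2 s of silence as mono 16-bit PCM frames at 8 kHz.
+sr = 8000
+frame = AudioFrame(numpy.zeros(1024, dtype=numpy.int16), sample_rate=sr, layout='mono')
+frames = [frame] * (2 * sr // 1024)
+
+vecs = op(frames)
+print(vecs.shape)  # expected: (num_clips, 128)
+```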
+
+## Interface
+
+An audio embedding operator generates vectors in numpy.ndarray format given towhee audio frames.
+
+**Parameters:**
+
+*data: List[towhee.types.audio_frame.AudioFrame]*
+
+Input audio data is a list of towhee audio frames.
+The input data should represent an audio clip longer than 1 s.
+
+**Returns**:
+
+*numpy.ndarray*
+
+Audio embeddings in shape (num_clips, 128).
+Each embedding represents the features of an audio clip with a length of 1 s.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e7dc4ec
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .nn_fingerprint import NNFingerprint
+
+
+def nnfp(params: dict = None, checkpoint_path: str = None, framework: str = 'pytorch'):
+    # Forward the factory arguments documented in the README to the operator.
+    return NNFingerprint(params=params, checkpoint_path=checkpoint_path, framework=framework)
diff --git a/configs.py b/configs.py
new file mode 100644
index 0000000..2388120
--- /dev/null
+++ b/configs.py
@@ -0,0 +1,36 @@
+# Parameter configs for nnfp, inspired by https://github.com/stdio2016/pfann
+#
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+default_params = {
+    # model architecture
+    "dim": 128,  # embedding dimension
+    "h": 1024,
+    "u": 32,
+    "fuller": True,
+    "activation": "relu",
+    # audio & mel-spectrogram front end
+    "sample_rate": 8000,
+    "window_length": 1024,
+    "hop_length": 256,
+    "n_mels": 256,
+    "f_min": 300,
+    "f_max": 4000,
+    # segmentation (in seconds)
+    "segment_size": 1,
+    "hop_size": 1,
+    "frame_shift_mul": 1,
+    # preprocessing options
+    "naf_mode": False,
+    "mel_log": "log",
+    "spec_norm": "l2"
+}
diff --git a/nn_fingerprint.py b/nn_fingerprint.py
new file mode 100644
index 0000000..fd6e651
--- /dev/null
+++ b/nn_fingerprint.py
@@ -0,0 +1,135 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
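+
+# NNFingerprint operator: converts a list of towhee AudioFrame objects into a
+# (num_clips, 128) numpy array of audio fingerprints.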
+
+import logging
+import warnings
+
+import os
+from pathlib import Path
+from typing import List
+
+import torch
+import numpy
+import resampy
+
+from towhee.operator.base import NNOperator
+from towhee import register
+from towhee.types.audio_frame import AudioFrame
+from towhee.models.nnfp import NNFp
+from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
+
+from .configs import default_params
+
+warnings.filterwarnings('ignore')
+log = logging.getLogger()
+
+
+@register(output_schema=['vecs'])
+class NNFingerprint(NNOperator):
+    """
+    Audio embedding operator using Neural Network Fingerprint
+    """
+
+    def __init__(self,
+                 params: dict = None,
+                 checkpoint_path: str = None,
+                 framework: str = 'pytorch'):
+        super().__init__(framework=framework)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if params is None:
+            self.params = default_params
+        else:
+            self.params = params
+
+        dim = self.params['dim']
+        h = self.params['h']
+        u = self.params['u']
+        f_bin = self.params['n_mels']
+        n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
+        t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
+
+        log.info('Creating model...')
+        self.model = NNFp(
+            dim=dim, h=h, u=u,
+            in_f=f_bin, in_t=t,
+            fuller=self.params['fuller'],
+            activation=self.params['activation']
+        ).to(self.device)
+
+        log.info('Loading weights...')
+        if checkpoint_path is None:
+            path = str(Path(__file__).parent)
+            checkpoint_path = os.path.join(path, 'saved_model/pfann_fma_m.pt')
+        state_dict = torch.load(checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(state_dict)
+        self.model.eval()
+        log.info('Model is loaded.')
+
+    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
+        audio_tensors = self.preprocess(data).to(self.device)
+        features = self.model(audio_tensors)
+        return features.detach().cpu().numpy()
+
+    def preprocess(self, frames: List[AudioFrame]):
+        sr = frames[0].sample_rate
+        layout = frames[0].layout
+        if layout == 'stereo':
+            # Interleaved stereo frames: split into 2 channels, shape (2, num_samples).
+            frames = [frame.reshape(-1, 2) for frame in frames]
+            audio = numpy.vstack(frames).transpose()
+        else:
+            # Mono: concatenate frames into a single (1, num_samples) array.
+            audio = numpy.hstack(frames)
+            audio = audio[None, :]
+
+        audio = self.int2float(audio)
+
+        # Resample if the input rate differs from the model's expected rate.
+        if sr != self.params['sample_rate']:
+            audio = resampy.resample(audio, sr, self.params['sample_rate'])
+
+        # Cut the waveform into fixed-length segments (1 s by default).
+        wav = preprocess_wav(audio,
+                             segment_size=int(self.params['sample_rate'] * self.params['segment_size']),
+                             hop_size=int(self.params['sample_rate'] * self.params['hop_size']),
+                             frame_shift_mul=self.params['frame_shift_mul']).to(self.device)
+        wav = wav.to(torch.float32)
+        # Convert each segment into a normalized log-mel spectrogram.
+        mel = MelSpec(sample_rate=self.params['sample_rate'],
+                      window_length=self.params['window_length'],
+                      hop_length=self.params['hop_length'],
+                      f_min=self.params['f_min'],
+                      f_max=self.params['f_max'],
+                      n_mels=self.params['n_mels'],
+                      naf_mode=self.params['naf_mode'],
+                      mel_log=self.params['mel_log'],
+                      spec_norm=self.params['spec_norm']).to(self.device)
+        wav = mel(wav)
+        return wav
+
+    @staticmethod
+    def int2float(wav: numpy.ndarray, dtype: str = 'float64'):
+        """
+        Convert audio samples from int to float.
+        The input dtype is expected to be an integer type.
+        The output dtype is controlled by the parameter `dtype`, which defaults to 'float64'.
+
+        The code is inspired by https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
+        """
+        dtype = numpy.dtype(dtype)
+        assert dtype.kind == 'f'
+
+        if wav.dtype.kind in 'iu':
+            # Map the signed/unsigned integer range onto [-1.0, 1.0).
+            ii = numpy.iinfo(wav.dtype)
+            abs_max = 2 ** (ii.bits - 1)
+            offset = ii.min + abs_max
+            return (wav.astype(dtype) - offset) / abs_max
+        else:
+            log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
+            return wav.astype(dtype)
diff --git a/result1.png b/result1.png
new file mode 100644
index 0000000..5bbbf1a
Binary files /dev/null and b/result1.png differ
diff --git a/result2.png b/result2.png
new file mode 100644
index 0000000..d64cc0f
Binary files /dev/null and b/result2.png differ
diff --git a/saved_model/pfann_fma_m.pt b/saved_model/pfann_fma_m.pt
new file mode 100644
index 0000000..b990f16
--- /dev/null
+++ b/saved_model/pfann_fma_m.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62641652af2ed46b6146403db310c0c4725319f4525aa9c2060cd751329d7309
+size 67677593
diff --git a/saved_model/pfann_fma_s.pt b/saved_model/pfann_fma_s.pt
new file mode 100644
index 0000000..b9ab647
--- /dev/null
+++ b/saved_model/pfann_fma_s.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:470214d98a67b2d9d54d3f73ea389ae922dd41a10e2f1578e0d3014974996ec0
+size 67677593