diff --git a/README.md b/README.md
index 587b355..c265688 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,103 @@
-# data2vec-audio
+# Audio Embedding with data2vec
+
+*author: David Wang*
+
+
+## Description
+
+This operator extracts features from audio with [data2vec](https://arxiv.org/abs/2202.03555). The core idea is to predict latent representations of the full input data based on a masked view of the input, in a self-distillation setup, using a standard Transformer architecture.
+
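+The snippet below is a minimal sketch of what this amounts to when the checkpoint is called directly through Hugging Face Transformers; the dummy waveform and the 16 kHz sampling rate are illustrative assumptions, not part of this operator's code:
+
+```python
+import numpy as np
+import torch
+from transformers import Data2VecAudioModel, Wav2Vec2Processor
+
+# Same default checkpoint as this operator.
+model_name = 'facebook/data2vec-audio-base-960h'
+processor = Wav2Vec2Processor.from_pretrained(model_name)
+model = Data2VecAudioModel.from_pretrained(model_name)
+
+# One second of dummy mono audio at the 16 kHz rate the checkpoint expects.
+waveform = np.zeros(16000, dtype=np.float32)
+inputs = processor(waveform, sampling_rate=16000, return_tensors='pt')
+
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# (batch, time, hidden): one 768-dim vector per ~20 ms frame for the base model.
+print(outputs.last_hidden_state.shape)
+```
+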
+
+## Code Example
+
+Generate embeddings for the audio "test.wav".
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+(
+    towhee.glob('test.wav')
+          .audio_decode.ffmpeg()
+          .runas_op(func=lambda x: [y[0] for y in x])
+          .towhee.data2vec_audio()
+          .show()
+)
+```
+
+result1
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+(
+    towhee.glob['path']('test.wav')
+          .audio_decode.ffmpeg['path', 'frames']()
+          .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+          .towhee.data2vec_audio['frames', 'vecs'](model_name='facebook/data2vec-audio-base-960h')
+          .show()
+)
+```
+
+result2
+
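+To consume the embeddings instead of displaying them, the results can be collected into a plain Python list. A sketch, assuming towhee's DataCollection API (`to_list()` and attribute access on entities):
+
+```python
+import towhee
+
+dc = (
+    towhee.glob['path']('test.wav')
+          .audio_decode.ffmpeg['path', 'frames']()
+          .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+          .towhee.data2vec_audio['frames', 'vecs']()
+)
+
+# Each entity carries the named fields; 'vecs' holds the numpy embedding.
+for entity in dc.to_list():
+    print(entity.vecs.shape)
+```
+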
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***data2vec_audio(model_name='facebook/data2vec-audio-base-960h')***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string. The default value is "facebook/data2vec-audio-base-960h".
+
+Supported model names:
+
+- facebook/data2vec-audio-base-960h
+- facebook/data2vec-audio-large-960h
+- facebook/data2vec-audio-base
+- facebook/data2vec-audio-base-100h
+- facebook/data2vec-audio-base-10m
+- facebook/data2vec-audio-large
+- facebook/data2vec-audio-large-100h
+- facebook/data2vec-audio-large-10m
+
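+For example, to pick one of the large checkpoints instead of the default, only the `model_name` argument changes (a sketch of the same pipeline as above):
+
+```python
+import towhee
+
+(
+    towhee.glob['path']('test.wav')
+          .audio_decode.ffmpeg['path', 'frames']()
+          .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
+          .towhee.data2vec_audio['frames', 'vecs'](model_name='facebook/data2vec-audio-large-960h')
+          .show()
+)
+```
+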
+
+## Interface
+
+An audio embedding operator generates vectors in numpy.ndarray given towhee audio frames.
+
+**Parameters:**
+
+***data***: *List[towhee.types.audio_frame.AudioFrame]*
+
+Input audio data is a list of towhee audio frames. The input should represent an audio clip longer than 0.9 s.
+
+**Returns:** *numpy.ndarray*
+
+The audio embedding extracted by the model.
+
diff --git a/data2vec_audio.py b/data2vec_audio.py
index fd1ca98..9afd92d 100644
--- a/data2vec_audio.py
+++ b/data2vec_audio.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy
+import numpy as np
 import torch
 import towhee
 
 from PIL import Image as PILImage
 
-from transformers import BeitFeatureExtractor, Data2VecVisionForImageClassification
+from transformers import Data2VecAudioModel, Wav2Vec2Processor
 from towhee.operator.base import NNOperator
 
 
 class Data2VecAudio(NNOperator):
@@ -28,12 +28,12 @@ class Data2VecAudio(NNOperator):
     def __call__(self, data):
         audio = np.hstack(data).reshape(1, -1)
         audio = audio.astype(np.float32, order='C') / 32768.0
-        inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
+        # The sampling rate comes from the first towhee AudioFrame in the list.
+        sampling_rate = data[0]._sample_rate
+        inputs = self.processor(audio.flatten(), sampling_rate=sampling_rate, return_tensors="pt")
         with torch.no_grad():
-            outputs = model(**inputs)
+            outputs = self.model(**inputs)
         last_hidden_states = outputs.last_hidden_state
-        return last_hidden_states
-
-
-
-
+        # Use the last time step's hidden state as the clip-level embedding.
+        feat = last_hidden_states[:, -1, :].flatten().detach().cpu().numpy()
+        return feat
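
Beyond the pipeline API, a minimal sketch of calling the operator class directly; the `Data2VecAudio` constructor signature and the `AudioFrame` arguments below are assumptions, inferred from the factory method and the code above:

```python
import numpy as np
from towhee.types.audio_frame import AudioFrame

from data2vec_audio import Data2VecAudio

# Assumed constructor signature, mirroring the factory default.
op = Data2VecAudio(model_name='facebook/data2vec-audio-base-960h')

# Two seconds of silent 16 kHz int16 audio (the Interface section requires
# clips longer than 0.9 s), split into one-second towhee AudioFrames.
# The AudioFrame arguments here are assumptions for illustration.
pcm = np.zeros(32000, dtype=np.int16)
frames = [
    AudioFrame(pcm[i:i + 16000], sample_rate=16000, timestamp=None, layout='mono')
    for i in range(0, len(pcm), 16000)
]

vec = op(frames)   # numpy.ndarray
print(vec.shape)   # hidden-size vector, e.g. (768,) for the base model
```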