# File: __init__.py
#
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .data2vec_audio import Data2VecAudio


def data2vec_audio(model_name: str = "facebook/data2vec-audio-base-960h") -> Data2VecAudio:
    """Create a ``Data2VecAudio`` operator.

    Args:
        model_name: Checkpoint identifier forwarded as the sole positional
            argument to the ``Data2VecAudio`` constructor.

    Returns:
        The newly constructed ``Data2VecAudio`` instance.
    """
    operator = Data2VecAudio(model_name)
    return operator


# File: data2vec_audio.py
#
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import torch
import towhee

from PIL import Image as PILImage

# NOTE(review): the two vision classes below are not used by this module;
# they are kept only so that no existing import is dropped.
from transformers import BeitFeatureExtractor, Data2VecVisionForImageClassification
from transformers import Data2VecAudioModel, Wav2Vec2Processor
from towhee.operator.base import NNOperator


class Data2VecAudio(NNOperator):
    """Operator that runs audio through a pretrained data2vec audio model.

    The operator loads a ``Data2VecAudioModel`` together with its matching
    ``Wav2Vec2Processor`` and, when called, returns the model's
    ``last_hidden_state`` for the supplied audio.
    """

    def __init__(self, model_name="facebook/data2vec-audio-base-960h"):
        """Load the model and processor for ``model_name``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the model and the processor.
        """
        super().__init__()
        # Honour the caller's checkpoint instead of a hard-coded name.
        self.model = Data2VecAudioModel.from_pretrained(model_name)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)

    def __call__(self, data, sampling_rate=None):
        """Compute the last hidden states for the given audio.

        Args:
            data: Sequence of array-like audio chunks; they are stacked
                horizontally into a single row. Presumably 16-bit PCM
                samples, since values are scaled by 1/32768 -- TODO confirm
                against the caller.
            sampling_rate: Sampling rate of ``data`` in Hz. When ``None``,
                the rate configured on the processor's feature extractor
                is used.

        Returns:
            The ``last_hidden_state`` tensor produced by the model.
        """
        if sampling_rate is None:
            # Fall back to the rate the loaded checkpoint was configured for.
            sampling_rate = self.processor.feature_extractor.sampling_rate

        # Concatenate all chunks into one waveform shaped (1, num_samples).
        audio = numpy.hstack(data).reshape(1, -1)
        audio = audio.astype(numpy.float32, order='C') / 32768.0

        inputs = self.processor(
            audio, sampling_rate=sampling_rate, return_tensors="pt"
        )
        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state