diff --git a/README.md b/README.md
index 587b355..c265688 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,103 @@
-# data2vec-audio
+# Audio Embedding with data2vec
+
+*author: David Wang*
+
+## Description
+
+This operator extracts features for audio with [data2vec](https://arxiv.org/abs/2202.03555). The core idea is to predict latent representations of the full input data based on a masked view of the input in a self-distillation setup using a standard Transformer architecture.
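+
+As a rough illustration of that objective, here is a minimal, self-contained sketch of the masked-prediction loss. `TinyEncoder`, the zero-masking scheme, and all sizes are hypothetical stand-ins for illustration, not the actual data2vec training code:
+
+```python
+import torch
+import torch.nn.functional as F
+
+class TinyEncoder(torch.nn.Module):
+    # hypothetical stand-in for the Transformer encoder
+    def __init__(self, dim=16, depth=4):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList(
+            torch.nn.Linear(dim, dim) for _ in range(depth))
+
+    def forward(self, x):
+        states = []
+        for block in self.blocks:
+            x = torch.relu(block(x))
+            states.append(x)
+        return states  # one hidden state per layer
+
+def data2vec_loss(student, teacher, x, mask, top_k=3):
+    # the teacher (an EMA copy of the student) sees the full input;
+    # its top-k layer outputs are averaged into regression targets
+    with torch.no_grad():
+        target = torch.stack(teacher(x)[-top_k:]).mean(dim=0)
+    # the student sees the masked view and predicts the targets
+    # at the masked time steps only
+    pred = student(x.masked_fill(mask.unsqueeze(-1), 0.0))[-1]
+    return F.smooth_l1_loss(pred[mask], target[mask])
+
+student, teacher = TinyEncoder(), TinyEncoder()  # teacher: EMA copy in practice
+x = torch.randn(2, 50, 16)                       # (batch, time, feature)
+mask = torch.rand(2, 50) < 0.5                   # random time-step mask
+loss = data2vec_loss(student, teacher, x, mask)
+```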
+
+## Code Example
+
+Generate embeddings for the audio "test.wav".
+
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+(
+ towhee.glob('test.wav')
+ .audio_decode.ffmpeg()
+ .runas_op(func=lambda x:[y[0] for y in x])
+ .towhee.data2vec_audio()
+ .show()
+)
+```
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+(
+ towhee.glob['path']('test.wav')
+ .audio_decode.ffmpeg['path', 'frames']()
+ .runas_op['frames', 'frames'](func=lambda x:[y[0] for y in x])
+ .towhee.data2vec_audio['frames', 'vecs'](model_name="facebook/data2vec-audio-base-960h")
+ .show()
+)
+```
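+
+*For reference, the following sketch approximates what the operator does internally with the transformers API (using `torchaudio` for decoding is an assumption for this example, not a dependency of the operator):*
+
+```python
+import torch
+import torchaudio
+from transformers import Data2VecAudioModel, Wav2Vec2Processor
+
+# decode the file to a waveform plus its sample rate
+# (data2vec-audio checkpoints expect 16 kHz input)
+waveform, sampling_rate = torchaudio.load('test.wav')
+
+processor = Wav2Vec2Processor.from_pretrained('facebook/data2vec-audio-base-960h')
+model = Data2VecAudioModel.from_pretrained('facebook/data2vec-audio-base-960h')
+
+inputs = processor(waveform.flatten().numpy(),
+                   sampling_rate=sampling_rate, return_tensors='pt')
+with torch.no_grad():
+    outputs = model(**inputs)
+# hidden state of the final time step, mirroring the operator code
+vec = outputs.last_hidden_state[:, -1, :].flatten().numpy()
+```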
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***data2vec_audio(model_name='facebook/data2vec-audio-base-960h')***
+
+**Parameters:**
+
+
+ ***model_name***: *str*
+
+The model name as a string.
+The default value is "facebook/data2vec-audio-base-960h".
+
+Supported model names:
+- facebook/data2vec-audio-base-960h
+- facebook/data2vec-audio-large-960h
+- facebook/data2vec-audio-base
+- facebook/data2vec-audio-base-100h
+- facebook/data2vec-audio-base-10m
+- facebook/data2vec-audio-large
+- facebook/data2vec-audio-large-100h
+- facebook/data2vec-audio-large-10m
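+
+A minimal sketch of constructing the operator directly (the hub path is assumed from the pipeline examples above):
+
+```python
+import towhee
+
+# pick a non-default checkpoint from the list above
+op = towhee.ops.towhee.data2vec_audio(model_name='facebook/data2vec-audio-large-960h')
+```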
+
+## Interface
+
+An audio embedding operator generates a vector in numpy.ndarray given a list of towhee audio frames (in the pipelines above, the audio file is first decoded into frames).
+
+
+**Parameters:**
+
+ ***data:*** *List[towhee.types.audio_frame.AudioFrame]*
+
+    The input audio data is a list of towhee AudioFrames. It should represent an audio clip longer than 0.9s.
+
+**Returns:** *numpy.ndarray*
+
+    The audio embedding extracted by the model.
+
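+
+For illustration, a direct call with decoded frames might look like the sketch below; the decoder's output format follows the `runas_op` lambda in the examples above, and the exact API is an assumption:
+
+```python
+import towhee
+
+# decode 'test.wav' into AudioFrames, taking the first element of each
+# decoded tuple as in the runas_op lambda above
+frames = [f[0] for f in towhee.ops.audio_decode.ffmpeg()('test.wav')]
+
+op = towhee.ops.towhee.data2vec_audio()
+vec = op(frames)  # numpy.ndarray embedding
+```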
+
diff --git a/data2vec_audio.py b/data2vec_audio.py
index fd1ca98..9afd92d 100644
--- a/data2vec_audio.py
+++ b/data2vec_audio.py
@@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import numpy
+import numpy as np
import torch
import towhee
-from PIL import Image as PILImage
-from transformers import BeitFeatureExtractor, Data2VecVisionForImageClassification
+from transformers import Data2VecAudioModel, Wav2Vec2Processor
from towhee.operator.base import NNOperator
class Data2VecAudio(NNOperator):
@@ -28,12 +28,12 @@ class Data2VecAudio(NNOperator):
def __call__(self, data):
audio = np.hstack(data).reshape(1, -1)
audio = audio.astype(np.float32, order='C') / 32768.0
- inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
+        # each AudioFrame carries its sample rate; forward it to the processor
+        sampling_rate = data[0]._sample_rate
+        inputs = self.processor(audio.flatten(), sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
- outputs = model(**inputs)
+ outputs = self.model(**inputs)
last_hidden_states = outputs.last_hidden_state
- return last_hidden_states
-
-
-
-
+        # use the hidden state of the final time step as the clip-level embedding
+        feat = last_hidden_states[:, -1, :].flatten().detach().cpu().numpy()
+        return feat