diff --git a/data2vec_audio.py b/data2vec_audio.py index 9afd92d..c0d467a 100644 --- a/data2vec_audio.py +++ b/data2vec_audio.py @@ -22,8 +22,8 @@ from towhee.operator.base import NNOperator class Data2VecAudio(NNOperator): def __init__(self, model_name = "facebook/data2vec-audio-base-960h"): - self.model = Data2VecAudioModel.from_pretrained("facebook/data2vec-audio-base-960h") - self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") + self.model = Data2VecAudioModel.from_pretrained(model_name) + self.processor = Wav2Vec2Processor.from_pretrained(model_name) def __call__(self, data): audio = np.hstack(data).reshape(1, -1) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4a10baf --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +transformers>=4.18.0