diff --git a/README.md b/README.md index f6eef38..32f4559 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Use the pre-trained model to generate a text embedding for the sentence "Hello, import towhee towhee.dc(["Hello, world."]) \ - .text_embedding.data2vec_text() \ + .text_embedding.data2vec() \ .show() ``` @@ -38,7 +38,7 @@ towhee.dc(["Hello, world."]) \ Create the operator via the following factory method -***data2vec_text()*** +***data2vec()***
diff --git a/__init__.py b/__init__.py index 4e4f668..c79ee5d 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .data2vec_text import Data2VecText -def data2vec_text(model_name='facebook/data2vec-vision-base'): - return Data2Text(model_name) +def data2vec(model_name='facebook/data2vec-text-base'): + return Data2VecText(model_name) diff --git a/data2vec_text.py b/data2vec_text.py index 360cef1..8a14c51 100644 --- a/data2vec_text.py +++ b/data2vec_text.py @@ -17,8 +17,8 @@ from transformers import RobertaTokenizer, Data2VecTextModel from towhee.operator.base import NNOperator class Data2VecText(NNOperator): - def __init__(self): - self.model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base") + def __init__(self, model_name): + self.model = model_name self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") def __call__(self, text: str) -> numpy.ndarray: diff --git a/requirements.txt b/requirements.txt index 0805d19..7a53fc5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ numpy transformers>4.19.0 torch -transformers towhee