logo
Browse Source

change data2vec_text to data2vec.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 3 years ago
parent
commit
2faa93290e
  1. 4
      README.md
  2. 4
      __init__.py
  3. 4
      data2vec_text.py
  4. 1
      requirements.txt

4
README.md

@ -24,7 +24,7 @@ Use the pre-trained model to generate a text embedding for the sentence "Hello,
import towhee
towhee.dc(["Hello, world."]) \
.text_embedding.data2vec_text() \
.text_embedding.data2vec() \
.show()
```
@ -38,7 +38,7 @@ towhee.dc(["Hello, world."]) \
Create the operator via the following factory method
***data2vec_text()***
***data2vec()***
<br />

4
__init__.py

@ -15,5 +15,5 @@
from .data2vec_text import Data2VecText
def data2vec_text(model_name='facebook/data2vec-vision-base'):
return Data2Text(model_name)
def data2vec(model_name='facebook/data2vec-text-base'):
return Data2VecText(model_name)

4
data2vec_text.py

@ -17,8 +17,8 @@ from transformers import RobertaTokenizer, Data2VecTextModel
from towhee.operator.base import NNOperator
class Data2VecText(NNOperator):
def __init__(self):
self.model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base")
def __init__(self, model_name):
self.model = model_name
self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base")
def __call__(self, text: str) -> numpy.ndarray:

1
requirements.txt

@ -1,5 +1,4 @@
numpy
transformers>4.19.0
torch
transformers
towhee

Loading…
Cancel
Save