From 2faa93290e95da7aae242df921232303b6d28024 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 24 Jun 2022 18:06:43 +0800 Subject: [PATCH] change data2vec_text to data2vec. Signed-off-by: wxywb --- README.md | 4 ++-- __init__.py | 4 ++-- data2vec_text.py | 4 ++-- requirements.txt | 1 - 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f6eef38..32f4559 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Use the pre-trained model to generate a text embedding for the sentence "Hello, import towhee towhee.dc(["Hello, world."]) \ - .text_embedding.data2vec_text() \ + .text_embedding.data2vec() \ .show() ``` @@ -38,7 +38,7 @@ towhee.dc(["Hello, world."]) \ Create the operator via the following factory method -***data2vec_text()*** +***data2vec()***
diff --git a/__init__.py b/__init__.py index 4e4f668..c79ee5d 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .data2vec_text import Data2VecText -def data2vec_text(model_name='facebook/data2vec-vision-base'): - return Data2Text(model_name) +def data2vec(model_name='facebook/data2vec-text-base'): + return Data2VecText(model_name) diff --git a/data2vec_text.py b/data2vec_text.py index 360cef1..8a14c51 100644 --- a/data2vec_text.py +++ b/data2vec_text.py @@ -17,8 +17,8 @@ from transformers import RobertaTokenizer, Data2VecTextModel from towhee.operator.base import NNOperator class Data2VecText(NNOperator): - def __init__(self): - self.model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base") + def __init__(self, model_name): + self.model = model_name self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") def __call__(self, text: str) -> numpy.ndarray: diff --git a/requirements.txt b/requirements.txt index 0805d19..7a53fc5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ numpy transformers>4.19.0 torch -transformers towhee