diff --git a/README.md b/README.md
index 32f4559..3288454 100644
--- a/README.md
+++ b/README.md
@@ -38,10 +38,19 @@ towhee.dc(["Hello, world."]) \
Create the operator via the following factory method
-***data2vec()***
+***data2vec(model_name='facebook/data2vec-text-base')***
-
+**Parameters:**
+
+ ***model_name***: *str*
+The model name in string.
+The default value is "facebook/data2vec-text-base".
+
+Supported model name:
+- facebook/data2vec-text-base
+
+
## Interface
diff --git a/data2vec_text.py b/data2vec_text.py
index 8a14c51..25e0b7d 100644
--- a/data2vec_text.py
+++ b/data2vec_text.py
@@ -18,10 +18,10 @@ from towhee.operator.base import NNOperator
class Data2VecText(NNOperator):
def __init__(self, model_name):
- self.model = model_name
- self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base")
+ self.model = Data2VecTextModel.from_pretrained(model_name)
+ self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
def __call__(self, text: str) -> numpy.ndarray:
- inputs = self.tokenizer(data, return_tensors="pt")
+ inputs = self.tokenizer(text, return_tensors="pt")
outputs = self.model(**inputs)
- return outputs.pooler_output.detach().cpu().numpy()
+ return outputs.pooler_output.detach().cpu().numpy().flatten()