diff --git a/README.md b/README.md index 32f4559..3288454 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,19 @@ towhee.dc(["Hello, world."]) \ Create the operator via the following factory method -***data2vec()*** +***data2vec(model_name='facebook/data2vec-text-base')*** -
+**Parameters:** + +​ ***model_name***: *str* +The name of the model, given as a string. +The default value is "facebook/data2vec-text-base". + +Supported model names: +- facebook/data2vec-text-base + +
## Interface diff --git a/data2vec_text.py b/data2vec_text.py index 8a14c51..25e0b7d 100644 --- a/data2vec_text.py +++ b/data2vec_text.py @@ -18,10 +18,10 @@ from towhee.operator.base import NNOperator class Data2VecText(NNOperator): def __init__(self, model_name): - self.model = model_name - self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") + self.model = Data2VecTextModel.from_pretrained(model_name) + self.tokenizer = RobertaTokenizer.from_pretrained(model_name) def __call__(self, text: str) -> numpy.ndarray: - inputs = self.tokenizer(data, return_tensors="pt") + inputs = self.tokenizer(text, return_tensors="pt") outputs = self.model(**inputs) - return outputs.pooler_output.detach().cpu().numpy() + return outputs.pooler_output.detach().cpu().numpy().flatten()