From fbd79d484e89872474829c1cfc35757252d9f0a8 Mon Sep 17 00:00:00 2001 From: wxywb <xy.wang@zilliz.com> Date: Thu, 21 Jul 2022 19:37:45 +0800 Subject: [PATCH] update the operator. Signed-off-by: wxywb <xy.wang@zilliz.com> --- README.md | 13 +++++++++++-- data2vec_text.py | 8 ++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 32f4559..3288454 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,19 @@ towhee.dc(["Hello, world."]) \ Create the operator via the following factory method -***data2vec()*** +***data2vec(model_name='facebook/data2vec-text-base')*** -<br /> +**Parameters:** + + ***model_name***: *str* +The model name in string. +The default value is "facebook/data2vec-text-base". + +Supported model name: +- facebook/data2vec-text-base + +<br /> ## Interface diff --git a/data2vec_text.py b/data2vec_text.py index 8a14c51..25e0b7d 100644 --- a/data2vec_text.py +++ b/data2vec_text.py @@ -18,10 +18,10 @@ from towhee.operator.base import NNOperator class Data2VecText(NNOperator): def __init__(self, model_name): - self.model = model_name - self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") + self.model = Data2VecTextModel.from_pretrained(model_name) + self.tokenizer = RobertaTokenizer.from_pretrained(model_name) def __call__(self, text: str) -> numpy.ndarray: - inputs = self.tokenizer(data, return_tensors="pt") + inputs = self.tokenizer(text, return_tensors="pt") outputs = self.model(**inputs) - return outputs.pooler_output.detach().cpu().numpy() + return outputs.pooler_output.detach().cpu().numpy().flatten()