logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 3 years ago
parent
commit
fbd79d484e
  1. 13
      README.md
  2. 8
      data2vec_text.py

13
README.md

@ -38,10 +38,19 @@ towhee.dc(["Hello, world."]) \
Create the operator via the following factory method
***data2vec()***
***data2vec(model_name='facebook/data2vec-text-base')***
<br />
**Parameters:**
***model_name***: *str*
The model name in string.
The default value is "facebook/data2vec-text-base".
Supported model name:
- facebook/data2vec-text-base
<br />
## Interface

8
data2vec_text.py

@ -18,10 +18,10 @@ from towhee.operator.base import NNOperator
class Data2VecText(NNOperator):
def __init__(self, model_name):
self.model = model_name
self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base")
self.model = Data2VecTextModel.from_pretrained(model_name)
self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
def __call__(self, text: str) -> numpy.ndarray:
inputs = self.tokenizer(data, return_tensors="pt")
inputs = self.tokenizer(text, return_tensors="pt")
outputs = self.model(**inputs)
return outputs.pooler_output.detach().cpu().numpy()
return outputs.pooler_output.detach().cpu().numpy().flatten()

Loading…
Cancel
Save