Browse Source
update the operator.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
3 years ago
2 changed files with
15 additions and
6 deletions
-
README.md
-
data2vec_text.py
|
|
@ -38,10 +38,19 @@ towhee.dc(["Hello, world."]) \ |
|
|
|
|
|
|
|
Create the operator via the following factory method |
|
|
|
|
|
|
|
***data2vec()*** |
|
|
|
***data2vec(model_name='facebook/data2vec-text-base')*** |
|
|
|
|
|
|
|
<br /> |
|
|
|
**Parameters:** |
|
|
|
|
|
|
|
***model_name***: *str* |
|
|
|
|
|
|
|
The model name in string. |
|
|
|
The default value is "facebook/data2vec-text-base". |
|
|
|
|
|
|
|
Supported model name: |
|
|
|
- facebook/data2vec-text-base |
|
|
|
|
|
|
|
<br /> |
|
|
|
|
|
|
|
|
|
|
|
## Interface |
|
|
|
|
|
@ -18,10 +18,10 @@ from towhee.operator.base import NNOperator |
|
|
|
|
|
|
|
class Data2VecText(NNOperator): |
|
|
|
def __init__(self, model_name): |
|
|
|
self.model = model_name |
|
|
|
self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") |
|
|
|
self.model = Data2VecTextModel.from_pretrained(model_name) |
|
|
|
self.tokenizer = RobertaTokenizer.from_pretrained(model_name) |
|
|
|
|
|
|
|
def __call__(self, text: str) -> numpy.ndarray: |
|
|
|
inputs = self.tokenizer(data, return_tensors="pt") |
|
|
|
inputs = self.tokenizer(text, return_tensors="pt") |
|
|
|
outputs = self.model(**inputs) |
|
|
|
return outputs.pooler_output.detach().cpu().numpy() |
|
|
|
return outputs.pooler_output.detach().cpu().numpy().flatten() |
|
|
|