
init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 3 years ago
parent
commit
ab320230d8
  1. README.md (63)
  2. __init__.py (19)
  3. data2vec_text.py (27)
  4. main.py (0)
  5. requirements.txt (2)

63
README.md

@ -1,2 +1,63 @@
# data2vec-text
# Text Embedding with data2vec
*author: David Wang*
<br />
## Description
This operator extracts features for text with [data2vec](https://arxiv.org/abs/2202.03555). The core idea is to predict latent representations of the full input data based on a masked view of the input in a self-distillation setup using a standard Transformer architecture.
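The self-distillation idea can be illustrated with a toy sketch in plain Python. This is not the operator's code or the reference implementation: the helper names (`masked_view`, `ema_update`, `masked_mse`) are hypothetical, and plain squared error stands in for the regression loss used in the paper. It assumes a teacher whose weights track the student as an exponential moving average, and a loss computed only at masked positions.

```python
def masked_view(tokens, masked_positions, mask_token="<mask>"):
    """Return a copy of `tokens` with the given positions replaced by a mask token."""
    return [mask_token if i in masked_positions else tok
            for i, tok in enumerate(tokens)]


def ema_update(teacher, student, tau=0.999):
    """One exponential-moving-average step: tau * teacher + (1 - tau) * student."""
    return [tau * t + (1.0 - tau) * s for t, s in zip(teacher, student)]


def masked_mse(pred, target, masked_positions):
    """Mean squared error between latent vectors, over masked positions only."""
    total, count = 0.0, 0
    for i in masked_positions:
        for p, t in zip(pred[i], target[i]):
            total += (p - t) ** 2
            count += 1
    return total / count if count else 0.0


# The student sees the masked view; the teacher sees the full input.
tokens = ["Hello", ",", "world", "."]
print(masked_view(tokens, {2}))  # ['Hello', ',', '<mask>', '.']
```

In this sketch the student would be trained to minimize `masked_mse` against the teacher's latents, and `ema_update` would be applied to the teacher's weights after each student step.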
<br />
## Code Example
Use the pre-trained model to generate a text embedding for the sentence "Hello, world.".
*Write the pipeline in simplified style*:
```python
import towhee
towhee.dc(["Hello, world."]) \
.text_embedding.data2vec_text() \
.show()
```
<br />
## Factory Constructor
Create the operator via the following factory method:
***data2vec_text()***
<br />
## Interface
**Parameters:**
***text:*** *str*
The text to embed, as a string.
**Returns:** *numpy.ndarray*
The text embedding extracted by the model.
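
Embeddings of this kind are often compared with cosine similarity. The following is a minimal sketch using only the standard library; `cosine_similarity` is a hypothetical helper, not part of this operator, and it assumes each embedding has first been flattened to a plain sequence of floats (for example with `embedding.flatten().tolist()`).

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors of floats."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors standing in for flattened embeddings.
print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0
```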

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data2vec_text import Data2VecText


def data2vec_text(model_name='facebook/data2vec-text-base'):
    # Factory entry point: build the operator for the given text checkpoint.
    return Data2VecText(model_name)

27
data2vec_text.py

@ -0,0 +1,27 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import torch
from transformers import RobertaTokenizer, Data2VecTextModel
from towhee.operator.base import NNOperator


class Data2VecText(NNOperator):
    def __init__(self, model_name: str = "facebook/data2vec-text-base"):
        super().__init__()
        # Load the model and its tokenizer from the same checkpoint.
        self.model = Data2VecTextModel.from_pretrained(model_name)
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)

    def __call__(self, text: str) -> numpy.ndarray:
        inputs = self.tokenizer(text, return_tensors="pt")
        # Inference only, so skip gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.pooler_output.detach().cpu().numpy()

0
main.py

2
requirements.txt

@ -0,0 +1,2 @@
numpy
transformers>4.19.0