diff --git a/README.md b/README.md index 345f7ce..c5f4637 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,63 @@ -# data2vec-text +# Text Embdding with data2vec + +*author: David Wang* + + +
+ + + +## Description + +This operator extracts features for text with [data2vec](https://arxiv.org/abs/2202.03555). The core idea is to predict latent representations of the full input data based on a masked view of the input in a self-distillation setup using a standard Transformer architecture. + +
+ + +## Code Example + +Use the pre-trained model to generate a text embedding for the sentence "Hello, world.". + + *Write the pipeline in simplified style*: + +```python +import towhee + +towhee.dc(["Hello, world."]) \ + .text_embedding.data2vec_text() \ + .show() + +``` + + +
+ + + +## Factory Constructor + +Create the operator via the following factory method + +***data2vec_text()*** + +
+ + + +## Interface + + +**Parameters:** + +​ ***text:*** *str* + +​ The text in string. + + + +**Returns:** *numpy.ndarray* + +​ The text embedding extracted by model. + + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..4e4f668 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data2vec_text import Data2VecText + + +def data2vec_text(model_name='facebook/data2vec-vision-base'): + return Data2Text(model_name) diff --git a/data2vec_text.py b/data2vec_text.py new file mode 100644 index 0000000..360cef1 --- /dev/null +++ b/data2vec_text.py @@ -0,0 +1,27 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy +import torch +from transformers import RobertaTokenizer, Data2VecTextModel +from towhee.operator.base import NNOperator + +class Data2VecText(NNOperator): + def __init__(self): + self.model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base") + self.tokenizer = RobertaTokenizer.from_pretrained("facebook/data2vec-text-base") + + def __call__(self, text: str) -> numpy.ndarray: + inputs = self.tokenizer(data, return_tensors="pt") + outputs = self.model(**inputs) + return outputs.pooler_output.detach().cpu().numpy() diff --git a/main.py b/main.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0c49ac4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +numpy +transformers>4.19.0