logo
Browse Source

Refactor

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
ba01395f82
  1. 60
      README.md
  2. 4
      __init__.py
  3. 20
      dpr.py

60
README.md

@ -23,11 +23,20 @@ where embeddings are learned from a small number of questions and passages by a
[2].https://arxiv.org/abs/2004.04906 [2].https://arxiv.org/abs/2004.04906
## Code Example
Use the pretrained model "facebook/dpr-ctx_encoder-single-nq-base"
to generate a text embedding for the sentence "Hello, world.".
*Write the pipeline*:
```python ```python
from towhee import ops
from towhee import dc
text_encoder = ops.text_embedding.dpr(model_name="allenai/longformer-base-4096")
text_embedding = text_encoder("Hello, world.")
dc.stream(["Hello, world."])
.text_embedding.dpr("facebook/dpr-ctx_encoder-single-nq-base")
.show()
``` ```
## Factory Constructor ## Factory Constructor
@ -38,48 +47,37 @@ Create the operator via the following factory method
## Interface
## Factory Constructor
A text embedding operator takes a sentence, paragraph, or document in string as an input
and output an embedding vector in ndarray which captures the input's core semantic elements.
Create the operator via the following factory method
***text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base")***
**Parameters:** **Parameters:**
***text***: *str*
​ The text in string.
***model_name***: *str*
​ The model name in string.
The default value is "facebook/dpr-ctx_encoder-single-nq-base".
You can get the list of supported model names by calling `get_model_list` from [auto_transformers.py](https://towhee.io/text-embedding/transformers/src/branch/main/auto_transformers.py).
**Returns**: *numpy.ndarray*
​ The text embedding extracted by model.
## Interface
The operator takes a text in string as input.
It loads tokenizer and pre-trained model using model name.
and then return text embedding in ndarray.
**Parameters:**
## Code Example
***text***: *str*
Use the pretrained model ('allenai/longformer-base-4096')
to generate a text embedding for the sentence "Hello, world.".
​ The text in string.
*Write the pipeline in simplified style*:
```python
import towhee.DataCollection as dc
dc.glob("Hello, world.")
.text_embedding.dpr('longformer-base-4096')
.show()
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
**Returns**:
```python
from towhee import DataCollection as dc
​ *numpy.ndarray*
dc.glob['text']('Hello, world.')
.text_embedding.dpr['text', 'vec']('longformer-base-4096')
.select('vec')
.show()
```
​ The text embedding extracted by model.

4
__init__.py

@ -15,5 +15,5 @@
from .dpr import Dpr from .dpr import Dpr
def dpr(model_name: str):
return Dpr(model_name)
def dpr(**kwargs):
return Dpr(**kwargs)

20
dpr.py

@ -1,12 +1,15 @@
import numpy import numpy
import logging
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from towhee import register from towhee import register
from towhee.operator import NNOperator from towhee.operator import NNOperator
import warnings import warnings
import logging
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)
log = logging.getLogger() log = logging.getLogger()
@ -23,7 +26,7 @@ class Dpr(NNOperator):
model_name (`str`): model_name (`str`):
Which model to use for the embeddings. Which model to use for the embeddings.
""" """
def __init__(self, model_name: str) -> None:
def __init__(self, model_name: str = "facebook/dpr-ctx_encoder-single-nq-base") -> None:
self.model_name = model_name self.model_name = model_name
try: try:
self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name) self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name)
@ -47,5 +50,14 @@ class Dpr(NNOperator):
except Exception as e: except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
feature_vector = embeddings.detach().numpy()
return feature_vector
vec = embeddings.detach().numpy()
return vec
def get_model_list():
full_list = [
"facebook/dpr-ctx_encoder-single-nq-base",
"facebook/dpr-ctx_encoder-multiset-base",
]
full_list.sort()
return full_list

Loading…
Cancel
Save