logo
Browse Source

Refactor

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
ba01395f82
  1. 60
      README.md
  2. 4
      __init__.py
  3. 20
      dpr.py

60
README.md

@ -23,11 +23,20 @@ where embeddings are learned from a small number of questions and passages by a
[2].https://arxiv.org/abs/2004.04906
## Code Example
Use the pretrained model "facebook/dpr-ctx_encoder-single-nq-base"
to generate a text embedding for the sentence "Hello, world.".
*Write the pipeline*:
```python
from towhee import ops
from towhee import dc
text_encoder = ops.text_embedding.dpr(model_name="allenai/longformer-base-4096")
text_embedding = text_encoder("Hello, world.")
dc.stream(["Hello, world."])
.text_embedding.dpr("facebook/dpr-ctx_encoder-single-nq-base")
.show()
```
## Factory Constructor
@ -38,48 +47,37 @@ Create the operator via the following factory method
## Interface
## Factory Constructor
A text embedding operator takes a sentence, paragraph, or document in string as an input
and output an embedding vector in ndarray which captures the input's core semantic elements.
Create the operator via the following factory method
***text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base")***
**Parameters:**
***text***: *str*
​ The text in string.
***model_name***: *str*
​ The model name in string.
The default value is "facebook/dpr-ctx_encoder-single-nq-base".
You can get the list of supported model names by calling `get_model_list` from [auto_transformers.py](https://towhee.io/text-embedding/transformers/src/branch/main/auto_transformers.py).
**Returns**: *numpy.ndarray*
​ The text embedding extracted by model.
## Interface
The operator takes a text in string as input.
It loads tokenizer and pre-trained model using model name.
and then return text embedding in ndarray.
**Parameters:**
## Code Example
***text***: *str*
Use the pretrained model ('allenai/longformer-base-4096')
to generate a text embedding for the sentence "Hello, world.".
​ The text in string.
*Write the pipeline in simplified style*:
```python
import towhee.DataCollection as dc
dc.glob("Hello, world.")
.text_embedding.dpr('longformer-base-4096')
.show()
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
**Returns**:
```python
from towhee import DataCollection as dc
​ *numpy.ndarray*
dc.glob['text']('Hello, world.')
.text_embedding.dpr['text', 'vec']('longformer-base-4096')
.select('vec')
.show()
```
​ The text embedding extracted by model.

4
__init__.py

@ -15,5 +15,5 @@
from .dpr import Dpr
def dpr(model_name: str):
return Dpr(model_name)
def dpr(**kwargs):
return Dpr(**kwargs)

20
dpr.py

@ -1,12 +1,15 @@
import numpy
import logging
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from towhee import register
from towhee.operator import NNOperator
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)
log = logging.getLogger()
@ -23,7 +26,7 @@ class Dpr(NNOperator):
model_name (`str`):
Which model to use for the embeddings.
"""
def __init__(self, model_name: str) -> None:
def __init__(self, model_name: str = "facebook/dpr-ctx_encoder-single-nq-base") -> None:
self.model_name = model_name
try:
self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name)
@ -47,5 +50,14 @@ class Dpr(NNOperator):
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
feature_vector = embeddings.detach().numpy()
return feature_vector
vec = embeddings.detach().numpy()
return vec
def get_model_list():
full_list = [
"facebook/dpr-ctx_encoder-single-nq-base",
"facebook/dpr-ctx_encoder-multiset-base",
]
full_list.sort()
return full_list

Loading…
Cancel
Save