diff --git a/README.md b/README.md index f524411..209ffb5 100644 --- a/README.md +++ b/README.md @@ -23,11 +23,20 @@ where embeddings are learned from a small number of questions and passages by a [2].https://arxiv.org/abs/2004.04906 +## Code Example + +Use the pretrained model "facebook/dpr-ctx_encoder-single-nq-base" +to generate a text embedding for the sentence "Hello, world.". + + *Write the pipeline*: + ```python -from towhee import ops +from towhee import dc + -text_encoder = ops.text_embedding.dpr(model_name="allenai/longformer-base-4096") -text_embedding = text_encoder("Hello, world.") +dc.stream(["Hello, world."]) + .text_embedding.dpr("facebook/dpr-ctx_encoder-single-nq-base") + .show() ``` ## Factory Constructor @@ -38,48 +47,35 @@ Create the operator via the following factory method -## Interface -A text embedding operator takes a sentence, paragraph, or document in string as an input -and output an embedding vector in ndarray which captures the input's core semantic elements. +***text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base")*** **Parameters:** -​ ***text***: *str* - -​ The text in string. - +​ ***model_name***: *str* +​ The model name in string. +The default value is "facebook/dpr-ctx_encoder-single-nq-base". +You can get the list of supported model names by calling `get_model_list` from [auto_transformers.py](https://towhee.io/text-embedding/transformers/src/branch/main/auto_transformers.py). -**Returns**: *numpy.ndarray* -​ The text embedding extracted by model. +## Interface +The operator takes a text in string as input. +It loads the tokenizer and pre-trained model using the model name, +and then returns the text embedding in ndarray. +**Parameters:** -## Code Example ​ ***text***: *str* -Use the pretrained model ('allenai/longformer-base-4096') -to generate a text embedding for the sentence "Hello, world.". ​ The text in string. 
- *Write the pipeline in simplified style*: -```python -import towhee.DataCollection as dc -dc.glob("Hello, world.") - .text_embedding.dpr('longformer-base-4096') - .show() -``` - -*Write a same pipeline with explicit inputs/outputs name specifications:* +**Returns**: -```python -from towhee import DataCollection as dc +​ *numpy.ndarray* -dc.glob['text']('Hello, world.') - .text_embedding.dpr['text', 'vec']('longformer-base-4096') - .select('vec') - .show() -``` +​ The text embedding extracted by model. \ No newline at end of file diff --git a/__init__.py b/__init__.py index e6b9968..28e2af7 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .dpr import Dpr -def dpr(model_name: str): - return Dpr(model_name) +def dpr(**kwargs): + return Dpr(**kwargs) diff --git a/dpr.py b/dpr.py index 74ee34a..004626a 100644 --- a/dpr.py +++ b/dpr.py @@ -1,12 +1,15 @@ import numpy -import logging from transformers import DPRContextEncoder, DPRContextEncoderTokenizer from towhee import register from towhee.operator import NNOperator import warnings +import logging + + warnings.filterwarnings('ignore') +logging.getLogger("transformers").setLevel(logging.ERROR) log = logging.getLogger() @@ -23,7 +26,7 @@ class Dpr(NNOperator): model_name (`str`): Which model to use for the embeddings. """ - def __init__(self, model_name: str) -> None: + def __init__(self, model_name: str = "facebook/dpr-ctx_encoder-single-nq-base") -> None: self.model_name = model_name try: self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name) @@ -47,5 +50,14 @@ class Dpr(NNOperator): except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e - feature_vector = embeddings.detach().numpy() - return feature_vector + vec = embeddings.detach().numpy() + return vec + + +def get_model_list(): + full_list = [ + "facebook/dpr-ctx_encoder-single-nq-base", + "facebook/dpr-ctx_encoder-multiset-base", + ] + full_list.sort() + return full_list