From ba5de36512509d2991fb0df8d9faae2ac761592a Mon Sep 17 00:00:00 2001
From: Junxen
Date: Tue, 29 Mar 2022 17:39:39 +0800
Subject: [PATCH] refactor dpr

---
 README.md        | 85 +++++++++++++++++++++++++++++++++++++++++++++++-
 nlp_dpr.py       | 51 +++++++++++++++++++++++++++++
 requirements.txt |  4 +++
 3 files changed, 139 insertions(+), 1 deletion(-)
 create mode 100644 nlp_dpr.py
 create mode 100644 requirements.txt

diff --git a/README.md b/README.md
index 3cdc849..f524411 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# dpr
+# Text Embedding with dpr
+
+*author: Kyle He*
+
+## Description
+
+This operator uses Dense Passage Retrieval (DPR) to convert text to embeddings.
+
+Dense Passage Retrieval (DPR) is a set of tools and models for state-of-the-art open-domain Q&A research.
+It was introduced in "Dense Passage Retrieval for Open-Domain Question Answering" by Vladimir Karpukhin,
+Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih[1].
+
+**DPR** models were proposed in "[Dense Passage Retrieval for Open-Domain Question Answering][2]":
+"In this work, we show that retrieval can be practically implemented using dense representations alone,
+where embeddings are learned from a small number of questions and passages by a simple dual-encoder framework"[2].
+
+## Reference
+
+[1]. https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/dpr
+
+[2]. https://arxiv.org/abs/2004.04906
+
+```python
+from towhee import ops
+
+text_encoder = ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base")
+text_embedding = text_encoder("Hello, world.")
+```
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***ops.text_embedding.dpr(model_name)***
+
+## Interface
+
+A text embedding operator takes a sentence, paragraph, or document as an input string
+and outputs an embedding vector in ndarray form that captures the input's core semantic elements.
+
+**Parameters:**
+
+  ***text***: *str*
+
+  The input text.
+
+**Returns**: *numpy.ndarray*
+
+  The text embedding extracted by the model.
+
+## Code Example
+
+Use the pretrained model 'facebook/dpr-ctx_encoder-single-nq-base'
+to generate a text embedding for the sentence "Hello, world.".
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.dc(['Hello, world.']) \
+      .text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base') \
+      .show()
+```
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.dc['text'](['Hello, world.']) \
+      .text_embedding.dpr['text', 'vec'](model_name='facebook/dpr-ctx_encoder-single-nq-base') \
+      .select('vec') \
+      .show()
+```
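+
+Under the hood, the operator wraps the Hugging Face DPR context encoder and its tokenizer
+(see `nlp_dpr.py`). For reference, here is a minimal sketch of the equivalent raw
+`transformers` code, assuming the 'facebook/dpr-ctx_encoder-single-nq-base' checkpoint:
+
+```python
+from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
+
+# Load the pretrained DPR context encoder and its tokenizer.
+tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
+model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
+
+# Tokenize the text and take the pooled output as the embedding.
+input_ids = tokenizer('Hello, world.', return_tensors='pt')['input_ids']
+embedding = model(input_ids).pooler_output.detach().numpy()
+print(embedding.shape)  # (1, 768) for the base checkpoint
+```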

diff --git a/nlp_dpr.py b/nlp_dpr.py
new file mode 100644
index 0000000..faf39c4
--- /dev/null
+++ b/nlp_dpr.py
@@ -0,0 +1,51 @@
+import logging
+import warnings
+
+import numpy
+from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
+
+from towhee import register
+from towhee.operator import NNOperator
+
+warnings.filterwarnings('ignore')
+log = logging.getLogger()
+
+
+@register(output_schema=['vec'])
+class NlpDpr(NNOperator):
+    """
+    This class uses Dense Passage Retrieval (DPR) to generate embeddings.
+    DPR is a set of tools and models for state-of-the-art open-domain Q&A research.
+    It was introduced in "Dense Passage Retrieval for Open-Domain Question Answering" by Vladimir Karpukhin,
+    Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
+    Ref: https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/dpr
+
+    Args:
+        model_name (`str`):
+            Which model to use for the embeddings.
+    """
+    def __init__(self, model_name: str) -> None:
+        super().__init__()
+        self.model_name = model_name
+        try:
+            self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name)
+        except Exception as e:
+            log.error(f'Failed to load tokenizer by name: {model_name}')
+            raise e
+        try:
+            self.model = DPRContextEncoder.from_pretrained(model_name)
+        except Exception as e:
+            log.error(f'Failed to load model by name: {model_name}')
+            raise e
+
+    def __call__(self, txt: str) -> numpy.ndarray:
+        # Tokenize the input text into model-ready token ids.
+        try:
+            input_ids = self.tokenizer(txt, return_tensors='pt')['input_ids']
+        except Exception as e:
+            log.error(f'Invalid input for the tokenizer: {self.model_name}')
+            raise e
+        # Run the context encoder; the pooled output is the text embedding.
+        try:
+            embeddings = self.model(input_ids).pooler_output
+        except Exception as e:
+            log.error(f'Invalid input for the model: {self.model_name}')
+            raise e
+        feature_vector = embeddings.detach().numpy()
+        return feature_vector

diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7bd17fa
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+numpy
+transformers
+sentencepiece
+protobuf
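
A quick way to exercise the new operator directly (a sketch, assuming `nlp_dpr.py` is on the
import path and the 'facebook/dpr-ctx_encoder-single-nq-base' checkpoint is used):

```python
from nlp_dpr import NlpDpr

# Instantiate the operator with a DPR context-encoder checkpoint.
op = NlpDpr(model_name='facebook/dpr-ctx_encoder-single-nq-base')

# Encode a sentence; the result is a numpy.ndarray,
# shape (1, 768) for the base checkpoint.
vec = op('Hello, world.')
print(vec.shape)
```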