logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

63 lines
2.1 KiB

import numpy
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from towhee import register
from towhee.operator import NNOperator
import warnings
import logging
# Suppress every Python warning for the whole process from this point on
# (an 'ignore' filter with no category/module restriction matches all warnings).
warnings.filterwarnings('ignore')
# Only let ERROR-and-above records from the `transformers` logger through.
logging.getLogger("transformers").setLevel(logging.ERROR)
# No name is passed, so this is the root logger; it is used by the operator
# below to report tokenizer/model loading and inference failures.
log = logging.getLogger()
@register(output_schema=['vec'])
class Dpr(NNOperator):
    """
    Generate text embeddings with a Dense Passage Retrieval (DPR) context encoder.

    Dense Passage Retrieval (DPR) is a set of tools and models for state-of-the-art
    open-domain Q&A research. It was introduced in "Dense Passage Retrieval for
    Open-Domain Question Answering" by Vladimir Karpukhin, Barlas Oğuz, Sewon Min,
    Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, Wen-tau Yih.

    Ref: https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/dpr

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
    """

    def __init__(self, model_name: str = "facebook/dpr-ctx_encoder-single-nq-base") -> None:
        """
        Load the tokenizer and the context encoder identified by `model_name`.

        Raises:
            Exception: whatever `from_pretrained` raises when the tokenizer or
                the model cannot be loaded; the failure is logged and re-raised.
        """
        # Initialise the NNOperator base class so that any state it sets up
        # exists on this instance before the operator is used.
        super().__init__()
        self.model_name = model_name
        try:
            self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(model_name)
        except Exception:
            log.error('Fail to load tokenizer by name: %s', model_name)
            # Bare `raise` re-raises the active exception unchanged.
            raise
        try:
            self.model = DPRContextEncoder.from_pretrained(model_name)
        except Exception:
            log.error('Fail to load model by name: %s', model_name)
            raise

    def __call__(self, txt: str) -> numpy.ndarray:
        """
        Embed `txt` with the loaded DPR context encoder.

        Args:
            txt (`str`):
                Text to tokenize and encode.

        Returns:
            numpy.ndarray: the encoder's `pooler_output`, detached from the
            autograd graph and converted to a numpy array.
            NOTE(review): presumably shaped (1, hidden_size) for a single
            string — confirm against the model in use.

        Raises:
            Exception: whatever the tokenizer or the model raises for the
                given input; the failure is logged and re-raised.
        """
        try:
            input_ids = self.tokenizer(txt, return_tensors="pt")["input_ids"]
        except Exception:
            log.error('Invalid input for the tokenizer: %s', self.model_name)
            raise
        try:
            embeddings = self.model(input_ids).pooler_output
        except Exception:
            log.error('Invalid input for the model: %s', self.model_name)
            raise
        # detach() drops the autograd graph so the tensor can be turned into numpy.
        return embeddings.detach().numpy()
def get_model_list():
    """Return the supported DPR context-encoder model names, sorted ascending.

    A new list is built on every call, so callers may mutate the result freely.
    """
    supported = (
        "facebook/dpr-ctx_encoder-single-nq-base",
        "facebook/dpr-ctx_encoder-multiset-base",
    )
    return sorted(supported)