|
|
|
from typing import List
|
|
|
|
|
|
|
|
from sentence_transformers import CrossEncoder
|
|
|
|
from towhee.operator import NNOperator
|
|
|
|
|
|
|
|
|
|
|
|
class ReRank(NNOperator):
    """Re-rank documents against a query using a ``CrossEncoder`` model.

    Each document is paired with the query and scored by the model; the
    documents are then returned ordered from highest to lowest score,
    optionally dropping those that score below a threshold.

    Args:
        model_name: Model name or path passed to ``CrossEncoder``.
        threshold: Minimum score a document must reach to be kept.
            ``None`` (the default) keeps every document.
        max_length: Value passed to ``CrossEncoder`` as ``max_length``.
            Defaults to 1000.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None,
                 max_length: int = 1000):
        super().__init__()
        self._model_name = model_name
        # NOTE(review): max_length presumably must not exceed the sequence
        # length the chosen model supports -- confirm that the default of
        # 1000 is valid for the default model.
        self._model = CrossEncoder(self._model_name, max_length=max_length)
        self._threshold = threshold

    def __call__(self, query: str, docs: List):
        """Score ``docs`` against ``query`` and return them best-first.

        Args:
            query: The query each document is paired with.
            docs: Candidate documents to score.

        Returns:
            A pair ``(re_docs, re_scores)``: the kept documents ordered by
            descending score, and their scores in the same order. Both are
            empty lists when ``docs`` is empty.
        """
        # Explicit length check (rather than truthiness) so any sized
        # container is accepted; also avoids calling the model with no pairs.
        if len(docs) == 0:
            return [], []

        scores = self._model.predict([(query, doc) for doc in docs])

        # Indices of docs from highest to lowest score; sorted() is stable,
        # so equally scored documents keep their input order.
        order = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)

        # Apply the threshold once to the ordering so the documents and the
        # scores are always derived from the same index list.
        if self._threshold is not None:
            order = [i for i in order if scores[i] >= self._threshold]

        re_docs = [docs[i] for i in order]
        re_scores = [scores[i] for i in order]
        return re_docs, re_scores
|