from typing import List, Optional, Tuple

from torch import nn
from sentence_transformers import CrossEncoder
from towhee.operator import NNOperator


class ReRank(NNOperator):
    """Re-rank candidate documents against a query with a cross-encoder.

    The operator scores every ``(query, doc)`` pair with a
    ``sentence_transformers.CrossEncoder``, orders the documents by
    descending score and optionally drops those scoring below a threshold.

    Args:
        model_name: Name or path handed to ``CrossEncoder``.
        threshold: Minimum score a document must reach to be kept. Pass
            ``None`` to keep every document (still sorted by score).
        device: Device argument forwarded to ``CrossEncoder``; ``None``
            leaves the choice to the library.
    """

    def __init__(self,
                 model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
                 threshold: Optional[float] = 0.6,
                 device: Optional[str] = None):
        super().__init__()
        self._model_name = model_name
        self._model = CrossEncoder(self._model_name, device=device)
        # Single-logit models get a sigmoid as their default activation.
        # NOTE(review): presumably so the raw logit becomes a (0, 1) score
        # comparable with ``threshold`` -- confirm that ``predict`` honours
        # this attribute in the installed sentence-transformers version.
        if self._model.config.num_labels == 1:
            self._model.default_activation_function = nn.Sigmoid()
        self._threshold = threshold

    def __call__(self, query: str, docs: List) -> Tuple[List, List]:
        """Score ``docs`` against ``query`` and return them best-first.

        Args:
            query: The query text.
            docs: Candidate documents; each is paired with ``query``.

        Returns:
            A ``(re_docs, re_scores)`` pair of equally long lists ordered by
            descending score. When a threshold is configured, only entries
            with ``score >= threshold`` are included. Score elements are
            whatever ``CrossEncoder.predict`` yields per pair. Empty ``docs``
            gives ``([], [])`` without calling the model.
        """
        # ``len`` (not truthiness) so array-like inputs keep working.
        if len(docs) == 0:
            return [], []

        pairs = [(query, doc) for doc in docs]
        if self._model.config.num_labels > 1:
            # Multi-label head: softmax the outputs and keep column 1.
            # NOTE(review): assumes index 1 is the "relevant" class -- verify
            # against the model's label mapping.
            scores = self._model.predict(pairs, apply_softmax=True)[:, 1]
        else:
            scores = self._model.predict(pairs)

        # ``sorted`` is stable, so equally scored docs keep their input order.
        order = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        if self._threshold is not None:
            # Filter the index list once; docs and scores both derive from it.
            order = [i for i in order if scores[i] >= self._threshold]

        re_docs = [docs[i] for i in order]
        re_scores = [scores[i] for i in order]
        return re_docs, re_scores