import numpy as np from typing import List from sentence_transformers import CrossEncoder from towhee.operator import NNOperator class ReRank(NNOperator): def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'): super().__init__() self._model_name = model_name self._model = CrossEncoder(self._model_name, max_length=1000) def __call__(self, query: str, docs: List): scores = self._model.predict([(query, doc) for doc in docs]) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) re_docs = [docs[i] for i in re_ids] scores.sort(reverse=True) return re_docs, scores