from typing import List

from sentence_transformers import CrossEncoder
from towhee.operator import NNOperator


class ReRank(NNOperator):
    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
        super().__init__()
        self._model_name = model_name
        self._model = CrossEncoder(self._model_name, max_length=1000)

    def __call__(self, query: str, docs: List, threshold: float = None):
        if len(docs) == 0:
            return [], []
        scores = self._model.predict([(query, doc) for doc in docs])
        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        if threshold is None:
            re_docs = [docs[i] for i in re_ids]
            re_scores = [scores[i] for i in re_ids]
        else:
            re_docs = [docs[i] for i in re_ids if scores[i] >= threshold]
            re_scores = [scores[i] for i in re_ids if scores[i] >= threshold]
        return re_docs, re_scores