logo
rerank
repo-copy-icon

copied

You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

18 lines
670 B

import numpy as np
from typing import List
from sentence_transformers import CrossEncoder
from towhee.operator import NNOperator
class ReRank(NNOperator):
    """Reorder candidate documents by their cross-encoder score against a query.

    Each ``(query, doc)`` pair is scored with a ``CrossEncoder`` and the
    documents are returned from highest to lowest score.
    """

    def __init__(
        self,
        model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2',
        max_length: int = 1000,
    ):
        """Load the cross-encoder used for scoring.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
            max_length: ``max_length`` value handed to ``CrossEncoder``.
                Defaults to 1000, the value that used to be hard-coded.
        """
        super().__init__()
        self._model_name = model_name
        # NOTE(review): the default of 1000 may exceed the positional limit of
        # the default checkpoint (BERT-style models are commonly capped at
        # 512 tokens) -- confirm against the model config before relying on
        # very long inputs.
        self._model = CrossEncoder(self._model_name, max_length=max_length)

    def __call__(self, query: str, docs: List) -> List:
        """Return ``docs`` sorted by descending score for ``query``.

        Args:
            query: The query every document is scored against.
            docs: Candidate documents; each one is paired with ``query``.

        Returns:
            A new list holding the elements of ``docs`` in ranked order.
            An empty ``docs`` yields an empty list without calling the model.
        """
        # Nothing to rank: skip the model call, which is not guaranteed to
        # accept an empty batch. ``len`` is used instead of truthiness so that
        # array-like inputs are not rejected by an ambiguous bool().
        if len(docs) == 0:
            return []

        scores = self._model.predict([(query, doc) for doc in docs])
        # ``sorted`` is stable, so documents with equal scores keep their
        # original relative order.
        ranked_ids = sorted(
            range(len(scores)), key=lambda idx: scores[idx], reverse=True
        )
        return [docs[idx] for idx in ranked_ids]