logo
rerank
repo-copy-icon

copied

Browse Source

Add threshold

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 2 years ago
parent
commit
8713010957
  1. 15
      rerank.py

15
rerank.py

@@ -1,18 +1,23 @@
import numpy as np
from typing import List from typing import List
from xml.dom.expatbuilder import theDOMImplementation
from sentence_transformers import CrossEncoder from sentence_transformers import CrossEncoder
from towhee.operator import NNOperator from towhee.operator import NNOperator
class ReRank(NNOperator): class ReRank(NNOperator):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
super().__init__() super().__init__()
self._model_name = model_name self._model_name = model_name
self._model = CrossEncoder(self._model_name, max_length=1000) self._model = CrossEncoder(self._model_name, max_length=1000)
def __call__(self, query: str, docs: List):
def __call__(self, query: str, docs: List, threshold: float = None):
scores = self._model.predict([(query, doc) for doc in docs]) scores = self._model.predict([(query, doc) for doc in docs])
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
re_docs = [docs[i] for i in re_ids]
re_scores = [scores[i] for i in re_ids]
if threshold is None:
re_docs = [docs[i] for i in re_ids]
re_scores = [scores[i] for i in re_ids]
else:
re_docs = [docs[i] for i in re_ids if scores[i] >= threshold]
re_scores = [scores[i] for i in re_ids if scores[i] >= threshold]
return re_docs, re_scores return re_docs, re_scores
Loading…
Cancel
Save