From 8d5520f5a7dc112b7be1d082c2adcbd6541109fe Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 21 Jun 2023 12:02:11 +0800 Subject: [PATCH] Update rerank Signed-off-by: shiyu22 --- rerank.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rerank.py b/rerank.py index ff8f233..07ea589 100644 --- a/rerank.py +++ b/rerank.py @@ -1,17 +1,18 @@ from typing import List -from xml.dom.expatbuilder import theDOMImplementation from sentence_transformers import CrossEncoder from towhee.operator import NNOperator class ReRank(NNOperator): - def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'): + def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'): super().__init__() self._model_name = model_name self._model = CrossEncoder(self._model_name, max_length=1000) def __call__(self, query: str, docs: List, threshold: float = None): + if len(docs) == 0: + return [], [] scores = self._model.predict([(query, doc) for doc in docs]) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) if threshold is None: