# rerank.py -- based on the post-image of git patch
# 871301095768098a8d78aa6b484f09f18fe34db0 ("[PATCH] Add threshold",
# shiyu22, Tue, 20 Jun 2023 17:28:07 +0800; blob e3ab5ec -> 59acaa6).
#
# That patch removed the `import numpy as np` line, switched the default
# model from 'cross-encoder/ms-marco-MiniLM-L-12-v2' to
# 'cross-encoder/ms-marco-MiniLM-L-6-v2', and added the optional `threshold`
# argument to `ReRank.__call__`.
#
# Review changes relative to the patch as submitted:
#   * dropped the stray `from xml.dom.expatbuilder import theDOMImplementation`
#     import (nothing in this file uses it),
#   * empty `docs` short-circuits instead of being sent to the model,
#   * the threshold filter is computed once instead of once per output list,
#   * `threshold` is annotated as Optional[float],
#   * the file ends with a newline.

from typing import List, Optional, Tuple

from sentence_transformers import CrossEncoder

from towhee.operator import NNOperator


class ReRank(NNOperator):
    """Re-order documents by their cross-encoder score against a query."""

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
        """Load the cross-encoder identified by ``model_name``.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
        """
        super().__init__()
        self._model_name = model_name
        # NOTE(review): max_length=1000 is kept from the original code.
        # Presumably the default MiniLM checkpoint accepts fewer positions
        # than this -- TODO confirm against the model config before relying
        # on inputs that long.
        self._model = CrossEncoder(self._model_name, max_length=1000)

    def __call__(
        self,
        query: str,
        docs: List,
        threshold: Optional[float] = None,
    ) -> Tuple[List, List]:
        """Score every document against ``query`` and return them best-first.

        Args:
            query: Text each document is paired with for scoring.
            docs: Candidate documents; each is paired with ``query`` and
                passed to the model's ``predict``.
            threshold: When given, only documents whose score is
                ``>= threshold`` are returned. ``None`` keeps every document.

        Returns:
            ``(re_docs, re_scores)``: the documents in descending score order
            and their scores in the same order. Both lists are empty when
            ``docs`` is empty or nothing reaches ``threshold``.
        """
        # Nothing to score: skip the model call entirely rather than asking
        # it to predict on an empty batch.
        if len(docs) == 0:
            return [], []

        scores = self._model.predict([(query, doc) for doc in docs])

        # Indices of `docs`, highest score first (stable for equal scores).
        ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)

        # Apply the cut-off once so both outputs are built from the same
        # index list and cannot drift apart.
        if threshold is not None:
            ranked = [i for i in ranked if scores[i] >= threshold]

        re_docs = [docs[i] for i in ranked]
        re_scores = [scores[i] for i in ranked]
        return re_docs, re_scores