From 8aa80249c09d39aaa568d38c0554d7c8bd399c2c Mon Sep 17 00:00:00 2001 From: ChengZi Date: Tue, 18 Jul 2023 15:28:43 +0800 Subject: [PATCH] adapt when model output label num is 2 Signed-off-by: ChengZi --- rerank.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rerank.py b/rerank.py index b5039ef..c00ee3f 100644 --- a/rerank.py +++ b/rerank.py @@ -10,13 +10,18 @@ class ReRank(NNOperator): def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): super().__init__() self._model_name = model_name - self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) + self._model = CrossEncoder(self._model_name, device=device) + if self._model.config.num_labels == 1: + self._model.default_activation_function = nn.Sigmoid() self._threshold = threshold def __call__(self, query: str, docs: List): if len(docs) == 0: return [], [] - scores = self._model.predict([(query, doc) for doc in docs]) + if self._model.config.num_labels > 1: + scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1] + else: + scores = self._model.predict([(query, doc) for doc in docs]) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) if self._threshold is None: re_docs = [docs[i] for i in re_ids]