logo
rerank
repo-copy-icon

copied

Browse Source

Adapt when the model's output label count is 2

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
main
ChengZi 2 years ago
parent
commit
8aa80249c0
  1. 9
      rerank.py

9
rerank.py

@@ -10,13 +10,18 @@ class ReRank(NNOperator):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None):
super().__init__()
self._model_name = model_name
- self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid())
+ self._model = CrossEncoder(self._model_name, device=device)
+ if self._model.config.num_labels == 1:
+ self._model.default_activation_function = nn.Sigmoid()
self._threshold = threshold
def __call__(self, query: str, docs: List):
if len(docs) == 0:
return [], []
- scores = self._model.predict([(query, doc) for doc in docs])
+ if self._model.config.num_labels > 1:
+ scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1]
+ else:
+ scores = self._model.predict([(query, doc) for doc in docs])
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
if self._threshold is None:
re_docs = [docs[i] for i in re_ids]

Loading…
Cancel
Save