From f5b2df224b10bac2a1529ed14d314d6afd150c25 Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Tue, 20 Jun 2023 16:16:07 +0800 Subject: [PATCH] Add rerank op Signed-off-by: shiyu22 --- __init__.py | 4 ++++ requirements.txt | 1 + rerank.py | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+) create mode 100644 __init__.py create mode 100644 requirements.txt create mode 100644 rerank.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..da15a82 --- /dev/null +++ b/__init__.py @@ -0,0 +1,4 @@ +from .rerank import ReRank + +def rerank(*args, **kwargs): + return ReRank(*args, **kwargs) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b344dc3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +sentence_transformers \ No newline at end of file diff --git a/rerank.py b/rerank.py new file mode 100644 index 0000000..3907113 --- /dev/null +++ b/rerank.py @@ -0,0 +1,18 @@ +import numpy as np +from typing import List + +from sentence_transformers import CrossEncoder +from towhee.operator import NNOperator + +class ReRank(NNOperator): + def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'): + super().__init__() + self._model_name = model_name + self._model = CrossEncoder(self._model_name, max_length=1000) + + def __call__(self, query: str, docs: List): + scores = self._model.predict([(query, doc) for doc in docs]) + re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) + print(re_ids, docs) + re_docs = [docs[i] for i in re_ids] + return re_docs \ No newline at end of file