diff --git a/README.md b/README.md index b76441f..8a43106 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,9 @@ The Rerank operator is used to reorder the list of relevant documents for a quer ```Python from towhee import ops -op = ops.rerank() +op = ops.rerank(threshold=0) res = op('What is Towhee?', - ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ], - 0) + ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ]) ``` - Run a pipeline @@ -26,15 +25,14 @@ res = op('What is Towhee?', ```python from towhee import ops, pipe, DataCollection -p = (pipe.input('query', 'doc', 'threshold') - .map(('query', 'doc', 'threshold'), ('doc', 'score'), ops.rerank()) +p = (pipe.input('query', 'doc') + .map(('query', 'doc'), ('doc', 'score'), ops.rerank(threshold=0)) .flat_map(('doc', 'score'), ('doc', 'score'), lambda x, y: [(i, j) for i, j in zip(x, y)]) .output('query', 'doc', 'score') ) DataCollection(p('What is Towhee?', - ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ], - 0) + ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ]) ).show() ``` @@ -56,6 +54,10 @@ Create the operator via the following factory method ​ The model name of CrossEncoder, you can set it according to the [Model List](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#models-performance). +​ ***threshold***: float + +​ The threshold for filtering by score; defaults to None, i.e., no filtering. +
@@ -74,10 +76,6 @@ This operator is used to sort the documents of the query content and return the A list of sentences to check the correlation with the query content. -​ ***threshold***: float - -​ The threshold for filtering with score, defaults to none, i.e., no filtering. -
diff --git a/rerank.py b/rerank.py index 07ea589..2842e80 100644 --- a/rerank.py +++ b/rerank.py @@ -5,20 +5,21 @@ from towhee.operator import NNOperator class ReRank(NNOperator): - def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'): + def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None): super().__init__() self._model_name = model_name self._model = CrossEncoder(self._model_name, max_length=1000) + self._threshold = threshold - def __call__(self, query: str, docs: List, threshold: float = None): + def __call__(self, query: str, docs: List): if len(docs) == 0: return [], [] scores = self._model.predict([(query, doc) for doc in docs]) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) - if threshold is None: + if self._threshold is None: re_docs = [docs[i] for i in re_ids] re_scores = [scores[i] for i in re_ids] else: - re_docs = [docs[i] for i in re_ids if scores[i] >= threshold] - re_scores = [scores[i] for i in re_ids if scores[i] >= threshold] + re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] + re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] return re_docs, re_scores