logo
rerank
repo-copy-icon

copied

Browse Source

Update rerank

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 2 years ago
parent
commit
4be43e5b9e
  1. 20
      README.md
  2. 11
      rerank.py

20
README.md

@@ -15,10 +15,9 @@ The Rerank operator is used to reorder the list of relevant documents for a quer
```Python ```Python
from towhee import ops from towhee import ops
op = ops.rerank()
op = ops.rerank(threshold=0)
res = op('What is Towhee?', res = op('What is Towhee?',
['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ],
0)
['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ])
``` ```
- Run a pipeline - Run a pipeline
@@ -26,15 +25,14 @@ res = op('What is Towhee?',
```python ```python
from towhee import ops, pipe, DataCollection from towhee import ops, pipe, DataCollection
p = (pipe.input('query', 'doc', 'threshold')
.map(('query', 'doc', 'threshold'), ('doc', 'score'), ops.rerank())
p = (pipe.input('query', 'doc')
.map(('query', 'doc'), ('doc', 'score'), ops.rerank(threshold=0))
.flat_map(('doc', 'score'), ('doc', 'score'), lambda x, y: [(i, j) for i, j in zip(x, y)]) .flat_map(('doc', 'score'), ('doc', 'score'), lambda x, y: [(i, j) for i, j in zip(x, y)])
.output('query', 'doc', 'score') .output('query', 'doc', 'score')
) )
DataCollection(p('What is Towhee?', DataCollection(p('What is Towhee?',
['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ],
0)
['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ])
).show() ).show()
``` ```
@@ -56,6 +54,10 @@ Create the operator via the following factory method
​ The model name of CrossEncoder, you can set it according to the [Model List](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#models-performance). ​ The model name of CrossEncoder, you can set it according to the [Model List](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#models-performance).
***threshold***: float
The score threshold used for filtering: documents scoring below it are dropped. Defaults to `None`, i.e., no filtering.
<br /> <br />
@@ -74,10 +76,6 @@ This operator is used to sort the documents of the query content and return the
A list of sentences to check the correlation with the query content. A list of sentences to check the correlation with the query content.
***threshold***: float
The score threshold used for filtering: documents scoring below it are dropped. Defaults to `None`, i.e., no filtering.
<br /> <br />

11
rerank.py

@@ -5,20 +5,21 @@ from towhee.operator import NNOperator
class ReRank(NNOperator): class ReRank(NNOperator):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None):
super().__init__() super().__init__()
self._model_name = model_name self._model_name = model_name
self._model = CrossEncoder(self._model_name, max_length=1000) self._model = CrossEncoder(self._model_name, max_length=1000)
self._threshold = threshold
def __call__(self, query: str, docs: List, threshold: float = None):
def __call__(self, query: str, docs: List):
if len(docs) == 0: if len(docs) == 0:
return [], [] return [], []
scores = self._model.predict([(query, doc) for doc in docs]) scores = self._model.predict([(query, doc) for doc in docs])
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
if threshold is None:
if self._threshold is None:
re_docs = [docs[i] for i in re_ids] re_docs = [docs[i] for i in re_ids]
re_scores = [scores[i] for i in re_ids] re_scores = [scores[i] for i in re_ids]
else: else:
re_docs = [docs[i] for i in re_ids if scores[i] >= threshold]
re_scores = [scores[i] for i in re_ids if scores[i] >= threshold]
re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
return re_docs, re_scores return re_docs, re_scores

Loading…
Cancel
Save