diff --git a/README.md b/README.md
index b76441f..8a43106 100644
--- a/README.md
+++ b/README.md
@@ -15,10 +15,9 @@ The Rerank operator is used to reorder the list of relevant documents for a quer
```Python
from towhee import ops
-op = ops.rerank()
+op = ops.rerank(threshold=0)
res = op('What is Towhee?',
- ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ],
- 0)
+ ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ])
```
- Run a pipeline
@@ -26,15 +25,14 @@ res = op('What is Towhee?',
```python
from towhee import ops, pipe, DataCollection
-p = (pipe.input('query', 'doc', 'threshold')
- .map(('query', 'doc', 'threshold'), ('doc', 'score'), ops.rerank())
+p = (pipe.input('query', 'doc')
+ .map(('query', 'doc'), ('doc', 'score'), ops.rerank(threshold=0))
.flat_map(('doc', 'score'), ('doc', 'score'), lambda x, y: [(i, j) for i, j in zip(x, y)])
.output('query', 'doc', 'score')
)
DataCollection(p('What is Towhee?',
- ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ],
- 0)
+ ['Towhee is Towhee is a cutting-edge framework to deal with unstructure data.', 'I do not know about towhee', 'Towhee has many powerful operators.', 'The weather is good' ])
).show()
```
@@ -56,6 +54,10 @@ Create the operator via the following factory method
The model name of CrossEncoder, you can set it according to the [Model List](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html#models-performance).
+ ***threshold***: float
+
+ The threshold for filtering with score, defaults to none, i.e., no filtering.
+
@@ -74,10 +76,6 @@ This operator is used to sort the documents of the query content and return the
A list of sentences to check the correlation with the query content.
- ***threshold***: float
-
- The threshold for filtering with score, defaults to none, i.e., no filtering.
-
diff --git a/rerank.py b/rerank.py
index 07ea589..2842e80 100644
--- a/rerank.py
+++ b/rerank.py
@@ -5,20 +5,21 @@ from towhee.operator import NNOperator
class ReRank(NNOperator):
- def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'):
+ def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None):
super().__init__()
self._model_name = model_name
self._model = CrossEncoder(self._model_name, max_length=1000)
+ self._threshold = threshold
- def __call__(self, query: str, docs: List, threshold: float = None):
+ def __call__(self, query: str, docs: List):
if len(docs) == 0:
return [], []
scores = self._model.predict([(query, doc) for doc in docs])
re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
- if threshold is None:
+ if self._threshold is None:
re_docs = [docs[i] for i in re_ids]
re_scores = [scores[i] for i in re_ids]
else:
- re_docs = [docs[i] for i in re_ids if scores[i] >= threshold]
- re_scores = [scores[i] for i in re_ids if scores[i] >= threshold]
+ re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold]
+ re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold]
return re_docs, re_scores