diff --git a/README.md b/README.md
index 8a43106..3cdd209 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,6 @@ DataCollection(p('What is Towhee?',
).show()
```
-
@@ -56,8 +55,8 @@ Create the operator via the following factory method
***threshold***: float
- The threshold for filtering with score, defaults to none, i.e., no filtering.
-
+ The threshold for filtering with score
+ ***device***: str
diff --git a/rerank.py b/rerank.py
index 2842e80..b5039ef 100644
--- a/rerank.py
+++ b/rerank.py
@@ -1,14 +1,16 @@
from typing import List
+from torch import nn
from sentence_transformers import CrossEncoder
+
from towhee.operator import NNOperator
class ReRank(NNOperator):
- def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None):
+ def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None):
super().__init__()
self._model_name = model_name
- self._model = CrossEncoder(self._model_name, max_length=1000)
+ self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid())
self._threshold = threshold
def __call__(self, query: str, docs: List):