logo
rerank
repo-copy-icon

copied

Browse Source

add device and sigmoid

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
f4e0714d14
  1. 5
      README.md
  2. 6
      rerank.py

5
README.md

@ -36,7 +36,6 @@ DataCollection(p('What is Towhee?',
).show()
```
<img src="./result.png" height="100px"/>
<br />
@ -56,8 +55,8 @@ Create the operator via the following factory method
***threshold***: float
​ The threshold for filtering with score, defaults to none, i.e., no filtering.
​ The threshold for filtering with score
***device***: str
<br />

6
rerank.py

@ -1,14 +1,16 @@
from typing import List
from torch import nn
from sentence_transformers import CrossEncoder
from towhee.operator import NNOperator
class ReRank(NNOperator):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None):
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None):
super().__init__()
self._model_name = model_name
self._model = CrossEncoder(self._model_name, max_length=1000)
self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid())
self._threshold = threshold
def __call__(self, query: str, docs: List):

Loading…
Cancel
Save