Browse Source
add device and sigmoid
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
2 changed files with
6 additions and
5 deletions
-
README.md
-
rerank.py
|
|
@ -36,7 +36,6 @@ DataCollection(p('What is Towhee?', |
|
|
|
).show() |
|
|
|
``` |
|
|
|
|
|
|
|
<img src="./result.png" height="100px"/> |
|
|
|
|
|
|
|
<br /> |
|
|
|
|
|
|
@ -56,8 +55,8 @@ Create the operator via the following factory method |
|
|
|
|
|
|
|
***threshold***: float |
|
|
|
|
|
|
|
The threshold for filtering with score, defaults to none, i.e., no filtering. |
|
|
|
|
|
|
|
The threshold for filtering with score |
|
|
|
***device***: str |
|
|
|
<br /> |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1,14 +1,16 @@ |
|
|
|
from typing import List |
|
|
|
|
|
|
|
from torch import nn |
|
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
|
|
from towhee.operator import NNOperator |
|
|
|
|
|
|
|
|
|
|
|
class ReRank(NNOperator): |
|
|
|
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None): |
|
|
|
def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): |
|
|
|
super().__init__() |
|
|
|
self._model_name = model_name |
|
|
|
self._model = CrossEncoder(self._model_name, max_length=1000) |
|
|
|
self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) |
|
|
|
self._threshold = threshold |
|
|
|
|
|
|
|
def __call__(self, query: str, docs: List): |
|
|
|