From f4e0714d140b443cfd6ef1a379d423bcea1c3391 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 12 Jul 2023 15:34:28 +0800 Subject: [PATCH] add device and sigmoid Signed-off-by: junjie.jiang --- README.md | 5 ++--- rerank.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8a43106..3cdd209 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,6 @@ DataCollection(p('What is Towhee?', ).show() ``` -
@@ -56,8 +55,8 @@ Create the operator via the following factory method ​ ***threshold***: float -​ The threshold for filtering with score, defaults to none, i.e., no filtering. - +​ The threshold for filtering with score + ***device***: str
diff --git a/rerank.py b/rerank.py index 2842e80..b5039ef 100644 --- a/rerank.py +++ b/rerank.py @@ -1,14 +1,16 @@ from typing import List +from torch import nn from sentence_transformers import CrossEncoder + from towhee.operator import NNOperator class ReRank(NNOperator): - def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None): + def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): super().__init__() self._model_name = model_name - self._model = CrossEncoder(self._model_name, max_length=1000) + self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) self._threshold = threshold def __call__(self, query: str, docs: List):