Browse Source
op return float
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
1 changed files with
6 additions and
3 deletions
-
rerank.py
|
|
@ -81,10 +81,10 @@ class ReRank(NNOperator): |
|
|
|
def post_proc(self, logits): |
|
|
|
scores = self.activation_fct(logits).detach().cpu().numpy() |
|
|
|
if self.config.num_labels == 1: |
|
|
|
scores = [score[0] for score in scores] |
|
|
|
scores = [float(score[0]) for score in scores] |
|
|
|
else: |
|
|
|
scores = scores[:, 1] |
|
|
|
scores = [score for score in scores] |
|
|
|
scores = [float(score) for score in scores] |
|
|
|
return scores |
|
|
|
|
|
|
|
|
|
|
@ -92,6 +92,9 @@ class ReRank(NNOperator): |
|
|
|
def _model(self): |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_formats(self): |
|
|
|
return ['onnx'] |
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|