logo
rerank
repo-copy-icon

copied

Browse Source

op return float

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
9d597ec7a5
  1. 7
      rerank.py

7
rerank.py

@ -81,10 +81,10 @@ class ReRank(NNOperator):
def post_proc(self, logits):
scores = self.activation_fct(logits).detach().cpu().numpy()
if self.config.num_labels == 1:
scores = [score[0] for score in scores]
scores = [float(score[0]) for score in scores]
else:
scores = scores[:, 1]
scores = [score for score in scores]
scores = [float(score) for score in scores]
return scores
@ -92,6 +92,9 @@ class ReRank(NNOperator):
def _model(self):
return self.model.model
@property
def supported_formats(self):
return ['onnx']
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':

Loading…
Cancel
Save