logo
rerank
repo-copy-icon

copied

Browse Source

op return float

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
9d597ec7a5
  1. 9
      rerank.py

9
rerank.py

@ -81,10 +81,10 @@ class ReRank(NNOperator):
def post_proc(self, logits):
scores = self.activation_fct(logits).detach().cpu().numpy()
if self.config.num_labels == 1:
scores = [score[0] for score in scores]
scores = [float(score[0]) for score in scores]
else:
scores = scores[:, 1]
scores = [score for score in scores]
scores = [float(score) for score in scores]
return scores
@ -92,6 +92,9 @@ class ReRank(NNOperator):
def _model(self):
return self.model.model
@property
def supported_formats(self):
return ['onnx']
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
@ -150,4 +153,4 @@ if __name__ == '__main__':
op = ReRank(model_name, threshold=0)
res = op('abc', ['123', 'ABC', 'ABCabc'])
print(res)
op.save_model('onnx')
op.save_model('onnx')

Loading…
Cancel
Save