diff --git a/rerank.py b/rerank.py index 37a2374..3074313 100644 --- a/rerank.py +++ b/rerank.py @@ -81,10 +81,10 @@ class ReRank(NNOperator): def post_proc(self, logits): scores = self.activation_fct(logits).detach().cpu().numpy() if self.config.num_labels == 1: - scores = [score[0] for score in scores] + scores = [float(score[0]) for score in scores] else: scores = scores[:, 1] - scores = [score for score in scores] + scores = [float(score) for score in scores] return scores @@ -92,6 +92,9 @@ class ReRank(NNOperator): def _model(self): return self.model.model + @property + def supported_formats(self): + return ['onnx'] def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': @@ -150,4 +153,4 @@ if __name__ == '__main__': op = ReRank(model_name, threshold=0) res = op('abc', ['123', 'ABC', 'ABCabc']) print(res) - op.save_model('onnx') \ No newline at end of file + op.save_model('onnx')