diff --git a/rerank.py b/rerank.py index 4840319..cb663c2 100644 --- a/rerank.py +++ b/rerank.py @@ -122,7 +122,7 @@ class ReRank(NNOperator): onnx_config = model_onnx_config(self._model.config) onnx_inputs, onnx_outputs = export( self.tokenizer, - self._model, + self._model.to('cpu'), config=onnx_config, opset=13, output=Path(path)