diff --git a/evaluate/evaluate.py b/evaluate/evaluate.py index 2ccdd5b..2162d4f 100644 --- a/evaluate/evaluate.py +++ b/evaluate/evaluate.py @@ -21,9 +21,11 @@ parser.add_argument('--query_size', type=int, default=100) parser.add_argument('--topk', type=int, default=10) parser.add_argument('--collection_name', type=str, default=None) parser.add_argument('--format', type=str, required=True) +parser.add_argument('--onnx_dir', type=str, default='../saved/onnx') args = parser.parse_args() model_name = args.model +onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-'), '.onnx') dataset_name = args.dataset insert_size = args.insert_size query_size = args.query_size @@ -76,7 +78,6 @@ def create_milvus(collection_name): if args.format == 'pytorch': collection_name = collection_name + '_pytorch' - def insert(model_name, collection_name): ( towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() @@ -120,9 +121,8 @@ if args.format == 'pytorch': elif args.format == 'onnx': collection_name = collection_name + '_onnx' saved_name = model_name.replace('/', '-') - onnx_path = f'saved/onnx/{saved_name}.onnx' if not os.path.exists(onnx_path): - op.save_model(format='onnx') + op.save_model(format='onnx', path=onnx_path[:-5]) sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())