logo
Browse Source

Update evaluate with onnx path

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
4f343f9441
  1. 6
      evaluate/evaluate.py

6
evaluate/evaluate.py

@ -21,9 +21,11 @@ parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10) parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None) parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True) parser.add_argument('--format', type=str, required=True)
parser.add_argument('--onnx_dir', type=str, default='../saved/onnx')
args = parser.parse_args() args = parser.parse_args()
model_name = args.model model_name = args.model
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-'), '.onnx')
dataset_name = args.dataset dataset_name = args.dataset
insert_size = args.insert_size insert_size = args.insert_size
query_size = args.query_size query_size = args.query_size
@ -76,7 +78,6 @@ def create_milvus(collection_name):
if args.format == 'pytorch': if args.format == 'pytorch':
collection_name = collection_name + '_pytorch' collection_name = collection_name + '_pytorch'
def insert(model_name, collection_name): def insert(model_name, collection_name):
( (
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
@ -120,9 +121,8 @@ if args.format == 'pytorch':
elif args.format == 'onnx': elif args.format == 'onnx':
collection_name = collection_name + '_onnx' collection_name = collection_name + '_onnx'
saved_name = model_name.replace('/', '-') saved_name = model_name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
if not os.path.exists(onnx_path): if not os.path.exists(onnx_path):
op.save_model(format='onnx')
op.save_model(format='onnx', path=onnx_path[:-5])
sess = onnxruntime.InferenceSession(onnx_path, sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers()) providers=onnxruntime.get_available_providers())

Loading…
Cancel
Save