diff --git a/benchmark/run.py b/benchmark/run.py index 9e7d783..b3c92c6 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -104,13 +104,15 @@ if args.format == 'pytorch': .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) .runas_op['emb', 'emb'](lambda x: x[0]) .runas_op['gt', 'gt'](lambda y: str(y)) - .ann_search.milvus['emb', 'milvus_res']( - uri=f'tcp://{host}:{port}/{collection_name}', + .ann_search.milvus_client['emb', 'milvus_res']( + host=host, + port=port, + collection_name=collection_name, metric_type=metric_type, limit=topk, output_fields=['label'] ) - .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() + .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) .runas_op['preds', 'pred10'](lambda x: mode(x[:10])) @@ -178,13 +180,15 @@ elif args.format == 'onnx': .run_onnx['text', 'emb']() .runas_op['emb', 'emb'](lambda x: x[0]) .runas_op['gt', 'gt'](lambda y: str(y)) - .ann_search.milvus['emb', 'milvus_res']( - uri=f'tcp://{host}:{port}/{collection_name}', + .ann_search.milvus_client['emb', 'milvus_res']( + host=host, + port=port, + collection_name=collection_name, metric_type=metric_type, limit=topk, output_fields=['label'] ) - .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() + .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) .runas_op['preds', 'pred10'](lambda x: mode(x[:10])) @@ -198,7 +202,8 @@ elif args.format == 'onnx': else: raise AttributeError('Only support "pytorch" and "onnx" as format.') -collection = create_milvus(collection_name) -insert_count = insert(model_name, collection_name) -print('Total data inserted:', insert_count) +# collection = create_milvus(collection_name) +# insert_count = insert(model_name, collection_name) +# print('Total data inserted:', insert_count) +collection = Collection(collection_name) benchmark = query(model_name, collection_name)