|
|
@ -104,13 +104,15 @@ if args.format == 'pytorch': |
|
|
|
.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) |
|
|
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
|
|
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
|
|
|
.ann_search.milvus['emb', 'milvus_res']( |
|
|
|
uri=f'tcp://{host}:{port}/{collection_name}', |
|
|
|
.ann_search.milvus_client['emb', 'milvus_res']( |
|
|
|
host=host, |
|
|
|
port=port, |
|
|
|
collection_name=collection_name, |
|
|
|
metric_type=metric_type, |
|
|
|
limit=topk, |
|
|
|
output_fields=['label'] |
|
|
|
) |
|
|
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
|
|
|
.runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() |
|
|
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
|
|
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
|
|
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
|
|
@ -178,13 +180,15 @@ elif args.format == 'onnx': |
|
|
|
.run_onnx['text', 'emb']() |
|
|
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
|
|
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
|
|
|
.ann_search.milvus['emb', 'milvus_res']( |
|
|
|
uri=f'tcp://{host}:{port}/{collection_name}', |
|
|
|
.ann_search.milvus_client['emb', 'milvus_res']( |
|
|
|
host=host, |
|
|
|
port=port, |
|
|
|
collection_name=collection_name, |
|
|
|
metric_type=metric_type, |
|
|
|
limit=topk, |
|
|
|
output_fields=['label'] |
|
|
|
) |
|
|
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
|
|
|
.runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() |
|
|
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
|
|
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
|
|
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
|
|
@ -198,7 +202,8 @@ elif args.format == 'onnx': |
|
|
|
else: |
|
|
|
raise AttributeError('Only support "pytorch" and "onnx" as format.') |
|
|
|
|
|
|
|
collection = create_milvus(collection_name) |
|
|
|
insert_count = insert(model_name, collection_name) |
|
|
|
print('Total data inserted:', insert_count) |
|
|
|
# collection = create_milvus(collection_name) |
|
|
|
# insert_count = insert(model_name, collection_name) |
|
|
|
# print('Total data inserted:', insert_count) |
|
|
|
collection = Collection(collection_name) |
|
|
|
benchmark = query(model_name, collection_name) |
|
|
|