From f65c1a0683380d6888f1338f0b681f4a945d0686 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 15 Dec 2022 16:49:13 +0800 Subject: [PATCH] Update run.py Signed-off-by: Jael Gu --- benchmark/README.md | 22 ++++++++++++++++++++++ benchmark/run.py | 22 +++++++++++++--------- 2 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 benchmark/README.md diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000..ac4a3d1 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,22 @@ +# Evaluate with Similarity Search + +## Introduction + +Build an image classification system based on similarity search across embeddings. + +The core ideas in `run.py`: +1. create a new Milvus collection each time +2. extract embeddings using a pretrained model with model name specified by `--model` +3. specify the inference method with `--format`, set to either `pytorch` or `onnx` +4. insert & search embeddings with Milvus collection without index +5. measure performance with accuracy at top 1, 5, 10 + 1. vote for the prediction from topk search results (most frequent one) + 2. compare final prediction with ground truth + 3. 
calculate the percentage of correct predictions over all queries + +## Example Usage + +```bash +python run.py --model MODEL_NAME --format pytorch +python run.py --model MODEL_NAME --format onnx +``` \ No newline at end of file diff --git a/benchmark/run.py b/benchmark/run.py index 4b26e99..7052ad3 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -98,13 +98,15 @@ if args.format == 'pytorch': towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream() .image_embedding.timm['image', 'emb'](model_name=model_name, device=device) .runas_op['gt', 'gt'](lambda y: str(y)) - .ann_search.milvus['emb', 'milvus_res']( - uri=f'tcp://{host}:{port}/{collection_name}', + .ann_search.milvus_client['emb', 'milvus_res']( + host=host, + port=port, + collection_name=collection_name, metric_type=metric_type, limit=topk, output_fields=['label'] ) - .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() + .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) .runas_op['preds', 'pred10'](lambda x: mode(x[:10])) @@ -154,13 +156,15 @@ elif args.format == 'onnx': towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream() .run_onnx['image', 'emb']() .runas_op['gt', 'gt'](lambda y: str(y)) - .ann_search.milvus['emb', 'milvus_res']( - uri=f'tcp://{host}:{port}/{collection_name}', - metric_type=metric_type, - limit=topk, - output_fields=['label'] + .ann_search.milvus_client['emb', 'milvus_res']( + host=host, + port=port, + collection_name=collection_name, + metric_type=metric_type, + limit=topk, + output_fields=['label'] ) - .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() + .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream() .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) .runas_op['preds', 
'pred10'](lambda x: mode(x[:10]))