logo
Browse Source

Update run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
f65c1a0683
  1. 22
      benchmark/README.md
  2. 22
      benchmark/run.py

22
benchmark/README.md

@ -0,0 +1,22 @@
# Evaluate with Similarity Search
## Introduction
Build a image classification system based on similarity search across embeddings.
The core ideas in `run.py`:
1. create a new Milvus collection each time
2. extract embeddings using a pretrained model with model name specified by `--model`
3. specify inference method with `--format` in value of `pytorch` or `onnx`
4. insert & search embeddings with Milvus collection without index
5. measure performance with accuracy at top 1, 5, 10
1. vote for the prediction from topk search results (most frequent one)
2. compare final prediction with ground truth
3. calculate percent of correct predictions over all queries
## Example Usage
```bash
python evaluate.py --model MODEL_NAME --format pytorch
python evaluate.py --model MODEL_NAME --format onnx
```

22
benchmark/run.py

@ -98,13 +98,15 @@ if args.format == 'pytorch':
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
.ann_search.milvus_client['emb', 'milvus_res'](
host=host,
port=port,
collection_name=collection_name,
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
@ -154,13 +156,15 @@ elif args.format == 'onnx':
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
.run_onnx['image', 'emb']()
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
metric_type=metric_type,
limit=topk,
output_fields=['label']
.ann_search.milvus_client['emb', 'milvus_res'](
host=host,
port=port,
collection_name=collection_name,
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))

Loading…
Cancel
Save