logo
Browse Source

Add benchmark script

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
7ddc9b0f86
  1. 172
      benchmark/run.py
  2. 7
      timm_image.py

172
benchmark/run.py

@ -0,0 +1,172 @@
import os
import torch
import onnxruntime
import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset
from statistics import mode
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx')
args = parser.parse_args()
model_name = args.model
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-') + '.onnx')
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')
device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'
data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
assert insert_size <= len(data['test']), 'There is no enough data. Please decrease query size.'
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data
# Warm up
print('Warming up...')
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
dummy_input = torch.rand((1,) + op.config['input_size'])
dim = op(dummy_input).shape[0]
print(f'output dim: {dim}')
# Prepare Milvus
print('Connecting milvus ...')
connections.connect(host=host, port=port)
def create_milvus(collection_name):
print('Creating collection ...')
fields = [
FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim),
FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
]
schema = CollectionSchema(fields=fields, description=f'image embeddings for {model_name} on {dataset_name}')
if utility.has_collection(collection_name):
print(f'drop old collection: {collection_name}')
collection = Collection(collection_name)
collection.drop()
collection = Collection(name=collection_name, schema=schema)
print(f'A new collection is created: {collection_name}.')
return collection
if args.format == 'pytorch':
collection_name = collection_name + '_pytorch'
def insert(model_name, collection_name):
(
towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
.runas_op['label', 'label'](lambda y: str(y))
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
uri=f'tcp://{host}:{port}/{collection_name}'
)
.show(3)
)
collection = Collection(collection_name)
return collection.num_entities
def query(model_name, collection_name):
benchmark = (
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
.with_metrics(['accuracy'])
.evaluate['gt', 'pred1']('pred1')
.evaluate['gt', 'pred5']('pred5')
.evaluate['gt', 'pred10']('pred10')
.report()
)
return benchmark
elif args.format == 'onnx':
collection_name = collection_name + '_onnx'
if not os.path.exists(onnx_path):
onnx_path = op.save_model(format='onnx')
@towhee.register
def run_onnx(img):
img = op.tfms(img).cpu().detach().numpy()
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
if len(features.shape) == 4:
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features)
outs = features.squeeze(0)
assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
return outs.cpu().detach().numpy()
def insert(model_name, collection_name):
(
towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
.run_onnx['image', 'emb']()
.runas_op['label', 'label'](lambda y: str(y))
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
uri=f'tcp://{host}:{port}/{collection_name}'
)
.show(3)
)
collection = Collection(collection_name)
return collection.num_entities
def query(model_name, collection_name):
benchmark = (
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
.run_onnx['image', 'emb']()
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
.with_metrics(['accuracy'])
.evaluate['gt', 'pred1']('pred1')
.evaluate['gt', 'pred5']('pred5')
.evaluate['gt', 'pred10']('pred10')
.report()
)
return benchmark
else:
raise AttributeError('Only support "pytorch" and "onnx" as format.')
collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)

7
timm_image.py

@ -83,14 +83,14 @@ class TimmImage(NNOperator):
imgs = data
img_list = []
for img in imgs:
img = self.convert_img(img)
img = self.convert_img(img) if isinstance(img, numpy.ndarray) else img
img = img if self.skip_tfms else self.tfms(img)
img_list.append(img)
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
features = self.model.forward_features(inputs)
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1)
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features)
features = features.to('cpu').flatten(1)
if isinstance(data, list):
@ -135,7 +135,7 @@ class TimmImage(NNOperator):
path,
input_names=['input_0'],
output_names=['output_0'],
opset_version=14,
opset_version=12,
dynamic_axes={
'input_0': {0: 'batch_size'},
'output_0': {0: 'batch_size'}
@ -148,6 +148,7 @@ class TimmImage(NNOperator):
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
return Path(path).resolve()
@staticmethod
def supported_model_names(format: str = None):

Loading…
Cancel
Save