diff --git a/benchmark/run.py b/benchmark/run.py
new file mode 100644
index 0000000..b96ac1c
--- /dev/null
+++ b/benchmark/run.py
@@ -0,0 +1,175 @@
+import os
+import torch
+from torch import nn
+import onnxruntime
+
+import towhee
+from towhee import ops
+from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
+from datasets import load_dataset
+
+from statistics import mode
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', required=True, type=str)
+parser.add_argument('--dataset', type=str, default='mnist')
+parser.add_argument('--insert_size', type=int, default=1000)
+parser.add_argument('--query_size', type=int, default=100)
+parser.add_argument('--topk', type=int, default=10)
+parser.add_argument('--collection_name', type=str, default=None)
+parser.add_argument('--format', type=str, required=True)
+parser.add_argument('--onnx_dir', type=str, default='./saved/onnx')
+
+args = parser.parse_args()
+model_name = args.model
+dataset_name = args.dataset
+insert_size = args.insert_size
+query_size = args.query_size
+topk = args.topk
+collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
+onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')
+
+device = 'cpu'
+host = 'localhost'
+port = '19530'
+index_type = 'FLAT'
+metric_type = 'L2'
+
+data = load_dataset(dataset_name).shuffle(seed=32)
+assert insert_size <= len(data['train']), 'There is not enough data. Please decrease insert size.'
+assert query_size <= len(data['test']), 'There is not enough data. Please decrease query size.'
+
+insert_data = data['train']
+insert_data = insert_data[:insert_size] if insert_size else insert_data
+query_data = data['test']
+query_data = query_data[:query_size] if query_size else query_data
+
+# Warm up
+print('Warming up...')
+op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
+dummy_input = torch.rand((1,) + op.config['input_size'])
+dim = op(dummy_input).shape[0]
+print(f'output dim: {dim}')
+
+# Prepare Milvus
+print('Connecting milvus ...')
+connections.connect(host=host, port=port)
+
+
+def create_milvus(collection_name):
+    print('Creating collection ...')
+    fields = [
+        FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
+        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim),
+        FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
+    ]
+    schema = CollectionSchema(fields=fields, description=f'image embeddings for {model_name} on {dataset_name}')
+    if utility.has_collection(collection_name):
+        print(f'drop old collection: {collection_name}')
+        collection = Collection(collection_name)
+        collection.drop()
+    collection = Collection(name=collection_name, schema=schema)
+    print(f'A new collection is created: {collection_name}.')
+    return collection
+
+
+if args.format == 'pytorch':
+    collection_name = collection_name + '_pytorch'
+
+    def insert(model_name, collection_name):
+        (
+            towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
+            .image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
+            .runas_op['label', 'label'](lambda y: str(y))
+            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
+                uri=f'tcp://{host}:{port}/{collection_name}'
+            )
+            .show(3)
+        )
+        collection = Collection(collection_name)
+        return collection.num_entities
+
+    def query(model_name, collection_name):
+        benchmark = (
+            towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
+            .image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
+            .runas_op['gt', 'gt'](lambda y: str(y))
+            .ann_search.milvus['emb', 'milvus_res'](
+                uri=f'tcp://{host}:{port}/{collection_name}',
+                metric_type=metric_type,
+                limit=topk,
+                output_fields=['label']
+            )
+            .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
+            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
+            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
+            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
+            .with_metrics(['accuracy'])
+            .evaluate['gt', 'pred1']('pred1')
+            .evaluate['gt', 'pred5']('pred5')
+            .evaluate['gt', 'pred10']('pred10')
+            .report()
+        )
+        return benchmark
+elif args.format == 'onnx':
+    collection_name = collection_name + '_onnx'
+    if not os.path.exists(onnx_path):
+        onnx_path = op.save_model(format='onnx')
+    sess = onnxruntime.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
+
+    @towhee.register
+    def run_onnx(img):
+        # Preprocess a single image and add a batch dimension for the exported graph.
+        img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
+        features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
+        features = torch.from_numpy(features)
+        if len(features.shape) == 4:
+            global_pool = nn.AdaptiveAvgPool2d(1)
+            features = global_pool(features)
+        outs = features.squeeze(0).flatten()
+        assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
+        return outs.cpu().detach().numpy()
+
+    def insert(model_name, collection_name):
+        (
+            towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
+            .run_onnx['image', 'emb']()
+            .runas_op['label', 'label'](lambda y: str(y))
+            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
+                uri=f'tcp://{host}:{port}/{collection_name}'
+            )
+            .show(3)
+        )
+        collection = Collection(collection_name)
+        return collection.num_entities
+
+    def query(model_name, collection_name):
+        benchmark = (
+            towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
+            .run_onnx['image', 'emb']()
+            .runas_op['gt', 'gt'](lambda y: str(y))
+            .ann_search.milvus['emb', 'milvus_res'](
+                uri=f'tcp://{host}:{port}/{collection_name}',
+                metric_type=metric_type,
+                limit=topk,
+                output_fields=['label']
+            )
+            .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
+            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
+            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
+            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
+            .with_metrics(['accuracy'])
+            .evaluate['gt', 'pred1']('pred1')
+            .evaluate['gt', 'pred5']('pred5')
+            .evaluate['gt', 'pred10']('pred10')
+            .report()
+        )
+        return benchmark
+else:
+    raise AttributeError('Only support "pytorch" and "onnx" as format.')
+
+collection = create_milvus(collection_name)
+insert_count = insert(model_name, collection_name)
+print('Total data inserted:', insert_count)
+benchmark = query(model_name, collection_name)
diff --git a/timm_image.py b/timm_image.py
index f21f784..b0188a7 100644
--- a/timm_image.py
+++ b/timm_image.py
@@ -83,14 +83,14 @@ class TimmImage(NNOperator):
             imgs = data
         img_list = []
         for img in imgs:
-            img = self.convert_img(img)
+            img = self.convert_img(img) if isinstance(img, numpy.ndarray) else img
             img = img if self.skip_tfms else self.tfms(img)
             img_list.append(img)
         inputs = torch.stack(img_list)
         inputs = inputs.to(self.device)
         features = self.model.forward_features(inputs)
         if features.dim() == 4:
-            global_pool = nn.AdaptiveAvgPool2d(1)
+            global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
             features = global_pool(features)
         features = features.to('cpu').flatten(1)
         if isinstance(data, list):
@@ -135,7 +135,7 @@ class TimmImage(NNOperator):
                 path,
                 input_names=['input_0'],
                 output_names=['output_0'],
-                opset_version=14,
+                opset_version=12,
                 dynamic_axes={
                     'input_0': {0: 'batch_size'},
                     'output_0': {0: 'batch_size'}
@@ -148,6 +148,7 @@ class TimmImage(NNOperator):
         # todo: elif format == 'tensorrt':
         else:
             log.error(f'Unsupported format "{format}".')
+        return Path(path).resolve()
 
     @staticmethod
     def supported_model_names(format: str = None):
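
Note on the Milvus setup: run.py defines index_type = 'FLAT' but never builds an index or loads the collection, so ann_search may fail or behave inconsistently depending on the Milvus deployment. Below is a minimal sketch of that missing step with pymilvus, assuming the same host, port, and field names as run.py; the collection name is hypothetical and this snippet is not part of the diff.

    from pymilvus import Collection, connections

    connections.connect(host='localhost', port='19530')
    collection = Collection('resnet50_pytorch')  # hypothetical collection name

    # Build the FLAT/L2 index on the embedding field, then load it into memory for search.
    collection.create_index(
        field_name='embedding',
        index_params={'index_type': 'FLAT', 'metric_type': 'L2', 'params': {}},
    )
    collection.load()
    print('indexed entities:', collection.num_entities)

Running this once between the insert and query steps would also make the currently unused index_type setting take effect.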
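
Since the script's purpose is to compare the PyTorch and ONNX paths, a quick parity check right after export can catch opset or pooling mismatches before anything is inserted into Milvus. A rough sketch, assuming the op and onnx_path variables from run.py; check_onnx_parity and the 1e-3 tolerance are illustrative choices, not part of the diff.

    import numpy
    import torch
    import onnxruntime

    def check_onnx_parity(op, onnx_path, atol=1e-3):
        # Run one random input through both the timm operator and the exported graph.
        dummy = torch.rand((1,) + op.config['input_size'])
        expected = numpy.asarray(op(dummy)).flatten()
        sess = onnxruntime.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
        out = torch.from_numpy(sess.run(['output_0'], {'input_0': dummy.numpy()})[0])
        if out.dim() == 4:
            # CNN backbones export raw feature maps, so pool them the same way run_onnx does.
            out = torch.nn.AdaptiveAvgPool2d(1)(out)
        return numpy.allclose(expected, out.flatten().numpy(), atol=atol)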