"""Benchmark a timm image-embedding model through Milvus.

The script embeds images from a Hugging Face dataset with either the PyTorch
model or its ONNX export, inserts the embeddings into a Milvus collection,
then searches with the test split and reports label accuracy computed from
the most common label among the top 1, 5 and 10 search results.
"""

import argparse
import os
from statistics import mode

import torch
from torch import nn
import numpy
import onnxruntime

import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx')

args = parser.parse_args()

model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
# Default collection name is derived from the model name with '-' and '/'
# replaced by '_'.
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
# The original script assigned onnx_path twice; only the last assignment
# (with '/' replaced by '_') ever took effect, so that is the one kept.
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')

device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'

data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
# Fixed: this check used insert_size against the test split, so an oversized
# query_size was never detected.
assert query_size <= len(data['test']), 'There is no enough data. Please decrease query size.'

# A size of 0 means "use the whole split".
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data

# Warm up
print('Warming up...')
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
# Embed one sample image to discover the embedding dimension used by the
# Milvus schema below.
dim = (
    towhee.dc(['https://raw.githubusercontent.com/towhee-io/towhee/main/towhee_logo.png'])
    .image_decode()
    .image_embedding.timm(model_name=model_name, device=device)[0]
    .shape[0]
)
print(f'output dim: {dim}')

# Prepare Milvus
print('Connecting milvus ...')
connections.connect(host=host, port=port)


def create_milvus(collection_name):
    """Create a fresh Milvus collection for the embeddings and return it.

    The schema has an auto-generated INT64 primary key ``id``, a float-vector
    field ``embedding`` of dimension ``dim`` and a VARCHAR field ``label``.
    An existing collection with the same name is dropped first.
    """
    print('Creating collection ...')
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim),
        FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
    ]
    schema = CollectionSchema(fields=fields, description=f'image embeddings for {model_name} on {dataset_name}')
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        collection = Collection(collection_name)
        collection.drop()
    collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return collection


if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def insert(model_name, collection_name):
        """Embed the insert split with the timm operator and store it in Milvus.

        Returns the number of entities reported by the collection.
        """
        (
            towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
            .image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
            # Labels are stored as strings because the schema field is VARCHAR.
            .runas_op['label', 'label'](lambda y: str(y))
            .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
                uri=f'tcp://{host}:{port}/{collection_name}'
            )
            .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities

    def query(model_name, collection_name):
        """Search Milvus with the query split and report accuracy.

        Predictions are the most common label among the first 1, 5 and 10
        search results; each is evaluated against the ground-truth label.
        """
        benchmark = (
            towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
            .image_embedding.timm['image', 'emb'](model_name=model_name, device=device)
            .runas_op['gt', 'gt'](lambda y: str(y))
            .ann_search.milvus_client['emb', 'milvus_res'](
                host=host,
                port=port,
                collection_name=collection_name,
                metric_type=metric_type,
                limit=topk,
                output_fields=['label']
            )
            # Keep the last element of every hit (presumably the requested
            # 'label' output field -- confirm against the operator's output).
            .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
            .with_metrics(['accuracy'])
            .evaluate['gt', 'pred1']('pred1')
            .evaluate['gt', 'pred5']('pred5')
            .evaluate['gt', 'pred10']('pred10')
            .report()
        )
        return benchmark

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    # Export the model only when no ONNX file exists at the expected path.
    if not os.path.exists(onnx_path):
        onnx_path = str(op.save_model(format='onnx'))
    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    @towhee.register
    def run_onnx(img):
        """Embed one image with the ONNX session and return a numpy array.

        The image is converted to RGB, preprocessed with the timm operator's
        transforms and fed to the session as a single-item batch.
        """
        img = img.convert('RGB')
        img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
        features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
        # Always work on a tensor so the final .cpu().detach().numpy() is
        # valid even when the model output is not 4-D.
        features = torch.from_numpy(features)
        if len(features.shape) == 4:
            # 4-D outputs are average-pooled down to 1x1 spatially.
            global_pool = nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
        outs = features.squeeze(0)
        assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
        return outs.cpu().detach().numpy()

    def insert(model_name, collection_name):
        """Embed the insert split with the ONNX model and store it in Milvus.

        Returns the number of entities reported by the collection.
        """
        (
            towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
            .run_onnx['image', 'emb']()
            # Labels are stored as strings because the schema field is VARCHAR.
            .runas_op['label', 'label'](lambda y: str(y))
            .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
                uri=f'tcp://{host}:{port}/{collection_name}'
            )
            .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities

    def query(model_name, collection_name):
        """Search Milvus with the query split and report accuracy.

        Predictions are the most common label among the first 1, 5 and 10
        search results; each is evaluated against the ground-truth label.
        """
        benchmark = (
            towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
            .run_onnx['image', 'emb']()
            .runas_op['gt', 'gt'](lambda y: str(y))
            .ann_search.milvus_client['emb', 'milvus_res'](
                host=host,
                port=port,
                collection_name=collection_name,
                metric_type=metric_type,
                limit=topk,
                output_fields=['label']
            )
            # Keep the last element of every hit (presumably the requested
            # 'label' output field -- confirm against the operator's output).
            .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
            .with_metrics(['accuracy'])
            .evaluate['gt', 'pred1']('pred1')
            .evaluate['gt', 'pred5']('pred5')
            .evaluate['gt', 'pred10']('pred10')
            .report()
        )
        return benchmark

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')

collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)