import argparse
import os
from statistics import mode

import onnxruntime
import torch
import towhee
import transformers
from datasets import load_dataset
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from towhee import ops

transformers.logging.set_verbosity_error()

parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)
parser.add_argument('--onnx_dir', type=str, default='../saved/onnx')
args = parser.parse_args()

model_name = args.model
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-') + '.onnx')
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')

device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'

# Load and shuffle the dataset, then take subsets for insert and query.
data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'Not enough data. Please decrease insert size.'
assert query_size <= len(data['test']), 'Not enough data. Please decrease query size.'
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data

# Warm up the embedding operator and probe its output dimension.
print('Warming up...')
op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op()
dim = op('This is test.').shape[-1]
print(f'output dim: {dim}')

# Prepare Milvus
print('Connecting to Milvus ...')
connections.connect(host=host, port=port)


def create_milvus(collection_name):
    print('Creating collection ...')
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim),
        FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
    ]
    schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}')
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        collection = Collection(collection_name)
        collection.drop()
    collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return collection


if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def insert(model_name, collection_name):
        (
            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
            .runas_op['text', 'text'](lambda s: s[:1024])
            .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
            .runas_op['emb', 'emb'](lambda x: x[0])
            .runas_op['label', 'label'](lambda y: str(y))
            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
                uri=f'tcp://{host}:{port}/{collection_name}'
            )
            .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities

    def query(model_name, collection_name):
        benchmark = (
            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
            .runas_op['text', 'text'](lambda s: s[:1024])
            .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
            .runas_op['emb', 'emb'](lambda x: x[0])
            .runas_op['gt', 'gt'](lambda y: str(y))
            .ann_search.milvus_client['emb', 'milvus_res'](
                host=host, port=port, collection_name=collection_name,
                metric_type=metric_type, limit=topk, output_fields=['label']
            )
            # Collect the labels of the top-k neighbors, then take the mode of
            # the top 1/5/10 as predictions.
            .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
            .with_metrics(['accuracy'])
            .evaluate['gt', 'pred1']('pred1')
            .evaluate['gt', 'pred5']('pred5')
            .evaluate['gt', 'pred10']('pred10')
            .report()
        )
        return benchmark

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'

    if not os.path.exists(onnx_path):
        # Export the underlying transformer to ONNX with dynamic
        # batch-size and sequence-length axes.
        inputs = op.tokenizer('This is test.', return_tensors='pt')
        input_names = list(inputs.keys())
        output_names = ['last_hidden_state']
        dynamic_axes = {}
        for name in input_names + output_names:
            dynamic_axes[name] = {0: 'batch_size', 1: 'sequence_length'}
        torch.onnx.export(
            op.model.model,
            tuple(inputs.values()),
            onnx_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=14,
            do_constant_folding=True,
        )

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    @towhee.register
    def run_onnx(txt):
        # Tokenize and keep only the inputs the ONNX graph expects.
        inputs = op.tokenizer(txt, return_tensors='np')
        onnx_inputs = [x.name for x in sess.get_inputs()]
        new_inputs = {k: inputs[k] for k in onnx_inputs}
        outs = sess.run(output_names=['last_hidden_state'], input_feed=new_inputs)
        return outs[0].squeeze(0)

    def insert(model_name, collection_name):
        (
            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
            .runas_op['text', 'text'](lambda s: s[:1024])
            .run_onnx['text', 'emb']()
            .runas_op['emb', 'emb'](lambda x: x[0])
            .runas_op['label', 'label'](lambda y: str(y))
            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
                uri=f'tcp://{host}:{port}/{collection_name}'
            )
            .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities

    def query(model_name, collection_name):
        benchmark = (
            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
            .runas_op['text', 'text'](lambda s: s[:1024])
            .run_onnx['text', 'emb']()
            .runas_op['emb', 'emb'](lambda x: x[0])
            .runas_op['gt', 'gt'](lambda y: str(y))
            .ann_search.milvus_client['emb', 'milvus_res'](
                host=host, port=port, collection_name=collection_name,
                metric_type=metric_type, limit=topk, output_fields=['label']
            )
            .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
            .with_metrics(['accuracy'])
            .evaluate['gt', 'pred1']('pred1')
            .evaluate['gt', 'pred5']('pred5')
            .evaluate['gt', 'pred10']('pred10')
            .report()
        )
        return benchmark

else:
    raise AttributeError('Only "pytorch" and "onnx" formats are supported.')

collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)
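
# Example invocations (the script filename and model id below are placeholders;
# any Hugging Face model supported by towhee's text_embedding.transformers
# operator should work, and a Milvus server must be reachable at host:port):
#   python benchmark.py --model bert-base-uncased --format pytorch
#   python benchmark.py --model bert-base-uncased --format onnx --insert_size 1000 --query_size 100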