From 98d84e87f9e42c4e37ae15d89793052bcdda5f3e Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Wed, 14 Dec 2022 15:03:06 +0800
Subject: [PATCH] Add evaluation

Signed-off-by: Jael Gu
---
 evaluate/README.md   |  21 +++
 evaluate/evaluate.py | 184 +++++++++++++++++++++++++++++++++++++++++++
 evaluate/evaluate.sh |   7 ++
 3 files changed, 212 insertions(+)
 create mode 100644 evaluate/README.md
 create mode 100644 evaluate/evaluate.py
 create mode 100755 evaluate/evaluate.sh

diff --git a/evaluate/README.md b/evaluate/README.md
new file mode 100644
index 0000000..ec2bbda
--- /dev/null
+++ b/evaluate/README.md
@@ -0,0 +1,21 @@
+# Evaluate with Similarity Search
+
+## Introduction
+
+This example builds a text classification system based on similarity search over embeddings.
+The core ideas in `evaluate.py`:
+1. create a fresh Milvus collection for each run
+2. extract embeddings using the pretrained model specified by `--model`
+3. select the inference backend with `--format`, either `pytorch` or `onnx`
+4. insert & search embeddings in the Milvus collection without an index (brute-force search)
+5. measure performance as accuracy at top 1, 5, and 10:
+    1. vote for the prediction from the top-k search results (the most frequent label wins)
+    2. compare the final prediction with the ground truth
+    3. calculate the percentage of correct predictions over all queries
+
+## Example Usage
+
+```bash
+python evaluate.py --model MODEL_NAME --format pytorch
+python evaluate.py --model MODEL_NAME --format onnx
+```
diff --git a/evaluate/evaluate.py b/evaluate/evaluate.py
new file mode 100644
index 0000000..2ccdd5b
--- /dev/null
+++ b/evaluate/evaluate.py
@@ -0,0 +1,184 @@
+import os
+import onnxruntime
+
+import towhee
+from towhee import ops
+from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
+from datasets import load_dataset
+
+from statistics import mode
+import argparse
+
+import transformers
+
+transformers.logging.set_verbosity_error()
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', required=True, type=str)
+parser.add_argument('--dataset', type=str, default='imdb')
+parser.add_argument('--insert_size', type=int, default=1000)
+parser.add_argument('--query_size', type=int, default=100)
+parser.add_argument('--topk', type=int, default=10)
+parser.add_argument('--collection_name', type=str, default=None)
+parser.add_argument('--format', type=str, required=True)
+
+args = parser.parse_args()
+model_name = args.model
+dataset_name = args.dataset
+insert_size = args.insert_size
+query_size = args.query_size
+topk = args.topk
+collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
+
+device = 'cpu'
+host = 'localhost'
+port = '19530'
+index_type = 'FLAT'
+metric_type = 'L2'
+
+data = load_dataset(dataset_name).shuffle(seed=32)
+assert insert_size <= len(data['train']), 'There is not enough data. Please decrease insert size.'
+assert query_size <= len(data['test']), 'There is not enough data. Please decrease query size.'
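+# NOTE: slicing a Hugging Face dataset split with [:n] below returns a dict of
+# columns ({'text': [...], 'label': [...]}), which is what the zip() calls expect.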
+
+insert_data = data['train']
+insert_data = insert_data[:insert_size] if insert_size else insert_data
+query_data = data['test']
+query_data = query_data[:query_size] if query_size else query_data
+
+# Warm up
+print('Warming up...')
+op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op()
+dim = op('This is a test.').shape[-1]
+print(f'output dim: {dim}')
+
+# Prepare Milvus
+print('Connecting to Milvus ...')
+connections.connect(host=host, port=port)
+
+
+def create_milvus(collection_name):
+    print('Creating collection ...')
+    fields = [
+        FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
+        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim),
+        FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
+    ]
+    schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}')
+    if utility.has_collection(collection_name):
+        print(f'drop old collection: {collection_name}')
+        collection = Collection(collection_name)
+        collection.drop()
+    collection = Collection(name=collection_name, schema=schema)
+    print(f'A new collection is created: {collection_name}.')
+    return collection
+
+
+if args.format == 'pytorch':
+    collection_name = collection_name + '_pytorch'
+
+
+    def insert(model_name, collection_name):
+        (
+            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
+            .runas_op['text', 'text'](lambda s: s[:1024])
+            .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
+            .runas_op['emb', 'emb'](lambda x: x[0])
+            .runas_op['label', 'label'](lambda y: str(y))
+            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
+                uri=f'tcp://{host}:{port}/{collection_name}'
+            )
+            .show(3)
+        )
+        collection = Collection(collection_name)
+        return collection.num_entities
+
+
+    def query(model_name, collection_name):
+        benchmark = (
+            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
+            .runas_op['text', 'text'](lambda s: s[:1024])
+            .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
+            .runas_op['emb', 'emb'](lambda x: x[0])
+            .runas_op['gt', 'gt'](lambda y: str(y))
+            .ann_search.milvus['emb', 'milvus_res'](
+                uri=f'tcp://{host}:{port}/{collection_name}',
+                metric_type=metric_type,
+                limit=topk,
+                output_fields=['label']
+            )
+            .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
+            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
+            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
+            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
+            .with_metrics(['accuracy'])
+            .evaluate['gt', 'pred1']('pred1')
+            .evaluate['gt', 'pred5']('pred5')
+            .evaluate['gt', 'pred10']('pred10')
+            .report()
+        )
+        return benchmark
+elif args.format == 'onnx':
+    collection_name = collection_name + '_onnx'
+    saved_name = model_name.replace('/', '-')
+    onnx_path = f'saved/onnx/{saved_name}.onnx'
+    if not os.path.exists(onnx_path):
+        op.save_model(format='onnx')
+    sess = onnxruntime.InferenceSession(onnx_path,
+                                        providers=onnxruntime.get_available_providers())
+
+    @towhee.register
+    def run_onnx(txt):
+        inputs = op.tokenizer(txt, return_tensors='np')
+        onnx_inputs = [x.name for x in sess.get_inputs()]
+        new_inputs = {}
+        for k in onnx_inputs:
+            new_inputs[k] = inputs[k]
+        outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
+        return outs[0].squeeze(0)
+
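+    # NOTE: run_onnx embeds one text at a time and returns last_hidden_state of
+    # shape [seq_len, hidden] after squeeze(0); the `lambda x: x[0]` step below
+    # keeps the first token's vector, matching the pytorch branch.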
+    def insert(model_name, collection_name):
+        (
+            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
+            .runas_op['text', 'text'](lambda s: s[:1024])
+            .run_onnx['text', 'emb']()
+            .runas_op['emb', 'emb'](lambda x: x[0])
+            .runas_op['label', 'label'](lambda y: str(y))
+            .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
+                uri=f'tcp://{host}:{port}/{collection_name}'
+            )
+            .show(3)
+        )
+        collection = Collection(collection_name)
+        return collection.num_entities
+
+    def query(model_name, collection_name):
+        benchmark = (
+            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
+            .runas_op['text', 'text'](lambda s: s[:1024])
+            .run_onnx['text', 'emb']()
+            .runas_op['emb', 'emb'](lambda x: x[0])
+            .runas_op['gt', 'gt'](lambda y: str(y))
+            .ann_search.milvus['emb', 'milvus_res'](
+                uri=f'tcp://{host}:{port}/{collection_name}',
+                metric_type=metric_type,
+                limit=topk,
+                output_fields=['label']
+            )
+            .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
+            .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
+            .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
+            .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
+            .with_metrics(['accuracy'])
+            .evaluate['gt', 'pred1']('pred1')
+            .evaluate['gt', 'pred5']('pred5')
+            .evaluate['gt', 'pred10']('pred10')
+            .report()
+        )
+        return benchmark
+else:
+    raise ValueError('Only "pytorch" and "onnx" are supported as --format.')
+
+collection = create_milvus(collection_name)
+insert_count = insert(model_name, collection_name)
+print('Total data inserted:', insert_count)
+benchmark = query(model_name, collection_name)
diff --git a/evaluate/evaluate.sh b/evaluate/evaluate.sh
new file mode 100755
index 0000000..e8a4c5c
--- /dev/null
+++ b/evaluate/evaluate.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+for name in allenai/led-base-16384 flaubert/flaubert_base_cased flaubert/flaubert_base_uncased flaubert/flaubert_large_cased flaubert/flaubert_small_cased funnel-transformer/intermediate-base funnel-transformer/large-base funnel-transformer/medium-base funnel-transformer/small-base funnel-transformer/xlarge-base google/mobilebert-uncased tau/splinter-base tau/splinter-base-qass tau/splinter-large
+do
+    python evaluate.py --model ${name} --format pytorch
+    python evaluate.py --model ${name} --format onnx
+done
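For reference, the top-1/5/10 accuracy reported by the pipelines above reduces to a majority vote over nearest-neighbor labels. Below is a minimal plain-Python sketch of that metric, independent of towhee and Milvus; the names `topk_accuracy` and `preds_per_query` are illustrative, not part of the patch, and the sketch assumes each neighbor-label list is ordered by ascending L2 distance, as returned from the search.

```python
from statistics import mode

def topk_accuracy(preds_per_query, ground_truths, k):
    """Fraction of queries whose most frequent label among the first k
    neighbors matches the ground truth (ties resolved by first occurrence,
    as statistics.mode does on Python 3.8+)."""
    correct = sum(
        mode(preds[:k]) == gt
        for preds, gt in zip(preds_per_query, ground_truths)
    )
    return correct / len(ground_truths)

# Toy check: two queries with their top-3 neighbor labels.
print(topk_accuracy([['1', '1', '0'], ['0', '1', '0']], ['1', '1'], k=3))  # 0.5
```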