logo
Browse Source

Add evaluation

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
98d84e87f9
  1. 21
      evaluate/README.md
  2. 184
      evaluate/evaluate.py
  3. 7
      evaluate/evaluate.sh

21
evaluate/README.md

@ -0,0 +1,21 @@
# Evaluate with Similarity Search
## Introduction
Build a classification system based on similarity search across embeddings.
The core ideas in `evaluate.py`:
1. create a new Milvus collection each time
2. extract embeddings using a pretrained model with model name specified by `--model`
3. specify inference method with `--format` in value of `pytorch` or `onnx`
4. insert & search embeddings with Milvus collection without index
5. measure performance with accuracy at top 1, 5, 10
1. vote for the prediction from topk search results (most frequent one)
2. compare final prediction with ground truth
3. calculate percent of correct predictions over all queries
## Example Usage
```bash
python evaluate.py --model MODEL_NAME --format pytorch
python evaluate.py --model MODEL_NAME --format onnx
```

184
evaluate/evaluate.py

@ -0,0 +1,184 @@
import os
import onnxruntime
import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset
from statistics import mode
import argparse
import transformers
transformers.logging.set_verbosity_error()
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)
args = parser.parse_args()
model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'
data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
assert insert_size <= len(data['test']), 'There is no enough data. Please decrease query size.'
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data
# Warm up
print('Warming up...')
op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op()
dim = op('This is test.').shape[-1]
print(f'output dim: {dim}')
# Prepare Milvus
print('Connecting milvus ...')
connections.connect(host=host, port=port)
def create_milvus(collection_name):
print('Creating collection ...')
fields = [
FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim),
FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
]
schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}')
if utility.has_collection(collection_name):
print(f'drop old collection: {collection_name}')
collection = Collection(collection_name)
collection.drop()
collection = Collection(name=collection_name, schema=schema)
print(f'A new collection is created: {collection_name}.')
return collection
if args.format == 'pytorch':
collection_name = collection_name + '_pytorch'
def insert(model_name, collection_name):
(
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
.runas_op['text', 'text'](lambda s: s[:1024])
.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
.runas_op['emb', 'emb'](lambda x: x[0])
.runas_op['label', 'label'](lambda y: str(y))
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
uri=f'tcp://{host}:{port}/{collection_name}'
)
.show(3)
)
collection = Collection(collection_name)
return collection.num_entities
def query(model_name, collection_name):
benchmark = (
towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
.runas_op['text', 'text'](lambda s: s[:1024])
.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
.runas_op['emb', 'emb'](lambda x: x[0])
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
.with_metrics(['accuracy'])
.evaluate['gt', 'pred1']('pred1')
.evaluate['gt', 'pred5']('pred5')
.evaluate['gt', 'pred10']('pred10')
.report()
)
return benchmark
elif args.format == 'onnx':
collection_name = collection_name + '_onnx'
saved_name = model_name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}.onnx'
if not os.path.exists(onnx_path):
op.save_model(format='onnx')
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
@towhee.register
def run_onnx(txt):
inputs = op.tokenizer(txt, return_tensors='np')
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:
new_inputs[k] = inputs[k]
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
return outs[0].squeeze(0)
def insert(model_name, collection_name):
(
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
.runas_op['text', 'text'](lambda s: s[:1024])
.run_onnx['text', 'emb']()
.runas_op['emb', 'emb'](lambda x: x[0])
.runas_op['label', 'label'](lambda y: str(y))
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
uri=f'tcp://{host}:{port}/{collection_name}'
)
.show(3)
)
collection = Collection(collection_name)
return collection.num_entities
def query(model_name, collection_name):
benchmark = (
towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
.runas_op['text', 'text'](lambda s: s[:1024])
.run_onnx['text', 'emb']()
.runas_op['emb', 'emb'](lambda x: x[0])
.runas_op['gt', 'gt'](lambda y: str(y))
.ann_search.milvus['emb', 'milvus_res'](
uri=f'tcp://{host}:{port}/{collection_name}',
metric_type=metric_type,
limit=topk,
output_fields=['label']
)
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
.runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
.runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
.runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
.with_metrics(['accuracy'])
.evaluate['gt', 'pred1']('pred1')
.evaluate['gt', 'pred5']('pred5')
.evaluate['gt', 'pred10']('pred10')
.report()
)
return benchmark
else:
raise AttributeError('Only support "pytorch" and "onnx" as format.')
collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)

7
evaluate/evaluate.sh

@ -0,0 +1,7 @@
#!/bin/bash
for name in allenai/led-base-16384 flaubert/flaubert_base_cased flaubert/flaubert_base_uncased flaubert/flaubert_large_cased flaubert/flaubert_small_cased funnel-transformer/intermediate-base funnel-transformer/large-base funnel-transformer/medium-base funnel-transformer/small-base funnel-transformer/xlarge-base google/mobilebert-uncased tau/splinter-base tau/splinter-base-qass tau/splinter-large
do
python evaluate.py --model ${name} --format pytorch
python evaluate.py --model ${name} --format onnx
done
Loading…
Cancel
Save