transformers
copied
Jael Gu
2 years ago
3 changed files with 212 additions and 0 deletions
@ -0,0 +1,21 @@ |
|||||
|
# Evaluate with Similarity Search |
||||
|
|
||||
|
## Introduction |
||||
|
|
||||
|
Build a classification system based on similarity search across embeddings. |
||||
|
The core ideas in `evaluate.py`: |
||||
|
1. create a new Milvus collection each time |
||||
|
2. extract embeddings using a pretrained model with model name specified by `--model` |
||||
|
3. specify inference method with `--format` in value of `pytorch` or `onnx` |
||||
|
4. insert & search embeddings with Milvus collection without index |
||||
|
5. measure performance with accuracy at top 1, 5, 10 |
||||
|
1. vote for the prediction from topk search results (most frequent one) |
||||
|
2. compare final prediction with ground truth |
||||
|
3. calculate percent of correct predictions over all queries |
||||
|
|
||||
|
## Example Usage |
||||
|
|
||||
|
```bash |
||||
|
python evaluate.py --model MODEL_NAME --format pytorch |
||||
|
python evaluate.py --model MODEL_NAME --format onnx |
||||
|
``` |
@ -0,0 +1,184 @@ |
|||||
|
import os |
||||
|
import onnxruntime |
||||
|
|
||||
|
import towhee |
||||
|
from towhee import ops |
||||
|
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility |
||||
|
from datasets import load_dataset |
||||
|
|
||||
|
from statistics import mode |
||||
|
import argparse |
||||
|
|
||||
|
import transformers |
||||
|
|
||||
|
transformers.logging.set_verbosity_error() |
||||
|
|
||||
|
parser = argparse.ArgumentParser() |
||||
|
parser.add_argument('--model', required=True, type=str) |
||||
|
parser.add_argument('--dataset', type=str, default='imdb') |
||||
|
parser.add_argument('--insert_size', type=int, default=1000) |
||||
|
parser.add_argument('--query_size', type=int, default=100) |
||||
|
parser.add_argument('--topk', type=int, default=10) |
||||
|
parser.add_argument('--collection_name', type=str, default=None) |
||||
|
parser.add_argument('--format', type=str, required=True) |
||||
|
|
||||
|
args = parser.parse_args() |
||||
|
model_name = args.model |
||||
|
dataset_name = args.dataset |
||||
|
insert_size = args.insert_size |
||||
|
query_size = args.query_size |
||||
|
topk = args.topk |
||||
|
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_') |
||||
|
|
||||
|
device = 'cpu' |
||||
|
host = 'localhost' |
||||
|
port = '19530' |
||||
|
index_type = 'FLAT' |
||||
|
metric_type = 'L2' |
||||
|
|
||||
|
data = load_dataset(dataset_name).shuffle(seed=32) |
||||
|
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.' |
||||
|
assert insert_size <= len(data['test']), 'There is no enough data. Please decrease query size.' |
||||
|
|
||||
|
insert_data = data['train'] |
||||
|
insert_data = insert_data[:insert_size] if insert_size else insert_data |
||||
|
query_data = data['test'] |
||||
|
query_data = query_data[:query_size] if query_size else query_data |
||||
|
|
||||
|
# Warm up |
||||
|
print('Warming up...') |
||||
|
op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op() |
||||
|
dim = op('This is test.').shape[-1] |
||||
|
print(f'output dim: {dim}') |
||||
|
|
||||
|
# Prepare Milvus |
||||
|
print('Connecting milvus ...') |
||||
|
connections.connect(host=host, port=port) |
||||
|
|
||||
|
|
||||
|
def create_milvus(collection_name): |
||||
|
print('Creating collection ...') |
||||
|
fields = [ |
||||
|
FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True), |
||||
|
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim), |
||||
|
FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500) |
||||
|
] |
||||
|
schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}') |
||||
|
if utility.has_collection(collection_name): |
||||
|
print(f'drop old collection: {collection_name}') |
||||
|
collection = Collection(collection_name) |
||||
|
collection.drop() |
||||
|
collection = Collection(name=collection_name, schema=schema) |
||||
|
print(f'A new collection is created: {collection_name}.') |
||||
|
return collection |
||||
|
|
||||
|
|
||||
|
if args.format == 'pytorch': |
||||
|
collection_name = collection_name + '_pytorch' |
||||
|
|
||||
|
|
||||
|
def insert(model_name, collection_name): |
||||
|
( |
||||
|
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() |
||||
|
.runas_op['text', 'text'](lambda s: s[:1024]) |
||||
|
.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) |
||||
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
||||
|
.runas_op['label', 'label'](lambda y: str(y)) |
||||
|
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}' |
||||
|
) |
||||
|
.show(3) |
||||
|
) |
||||
|
collection = Collection(collection_name) |
||||
|
return collection.num_entities |
||||
|
|
||||
|
|
||||
|
def query(model_name, collection_name): |
||||
|
benchmark = ( |
||||
|
towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream() |
||||
|
.runas_op['text', 'text'](lambda s: s[:1024]) |
||||
|
.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) |
||||
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
||||
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
||||
|
.ann_search.milvus['emb', 'milvus_res']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}', |
||||
|
metric_type=metric_type, |
||||
|
limit=topk, |
||||
|
output_fields=['label'] |
||||
|
) |
||||
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
||||
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
||||
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
||||
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
||||
|
.with_metrics(['accuracy']) |
||||
|
.evaluate['gt', 'pred1']('pred1') |
||||
|
.evaluate['gt', 'pred5']('pred5') |
||||
|
.evaluate['gt', 'pred10']('pred10') |
||||
|
.report() |
||||
|
) |
||||
|
return benchmark |
||||
|
elif args.format == 'onnx': |
||||
|
collection_name = collection_name + '_onnx' |
||||
|
saved_name = model_name.replace('/', '-') |
||||
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
||||
|
if not os.path.exists(onnx_path): |
||||
|
op.save_model(format='onnx') |
||||
|
sess = onnxruntime.InferenceSession(onnx_path, |
||||
|
providers=onnxruntime.get_available_providers()) |
||||
|
|
||||
|
@towhee.register |
||||
|
def run_onnx(txt): |
||||
|
inputs = op.tokenizer(txt, return_tensors='np') |
||||
|
onnx_inputs = [x.name for x in sess.get_inputs()] |
||||
|
new_inputs = {} |
||||
|
for k in onnx_inputs: |
||||
|
new_inputs[k] = inputs[k] |
||||
|
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) |
||||
|
return outs[0].squeeze(0) |
||||
|
|
||||
|
def insert(model_name, collection_name): |
||||
|
( |
||||
|
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() |
||||
|
.runas_op['text', 'text'](lambda s: s[:1024]) |
||||
|
.run_onnx['text', 'emb']() |
||||
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
||||
|
.runas_op['label', 'label'](lambda y: str(y)) |
||||
|
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}' |
||||
|
) |
||||
|
.show(3) |
||||
|
) |
||||
|
collection = Collection(collection_name) |
||||
|
return collection.num_entities |
||||
|
|
||||
|
def query(model_name, collection_name): |
||||
|
benchmark = ( |
||||
|
towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream() |
||||
|
.runas_op['text', 'text'](lambda s: s[:1024]) |
||||
|
.run_onnx['text', 'emb']() |
||||
|
.runas_op['emb', 'emb'](lambda x: x[0]) |
||||
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
||||
|
.ann_search.milvus['emb', 'milvus_res']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}', |
||||
|
metric_type=metric_type, |
||||
|
limit=topk, |
||||
|
output_fields=['label'] |
||||
|
) |
||||
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
||||
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
||||
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
||||
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
||||
|
.with_metrics(['accuracy']) |
||||
|
.evaluate['gt', 'pred1']('pred1') |
||||
|
.evaluate['gt', 'pred5']('pred5') |
||||
|
.evaluate['gt', 'pred10']('pred10') |
||||
|
.report() |
||||
|
) |
||||
|
return benchmark |
||||
|
else: |
||||
|
raise AttributeError('Only support "pytorch" and "onnx" as format.') |
||||
|
|
||||
|
collection = create_milvus(collection_name) |
||||
|
insert_count = insert(model_name, collection_name) |
||||
|
print('Total data inserted:', insert_count) |
||||
|
benchmark = query(model_name, collection_name) |
@ -0,0 +1,7 @@ |
|||||
|
#!/bin/bash |
||||
|
|
||||
|
for name in allenai/led-base-16384 flaubert/flaubert_base_cased flaubert/flaubert_base_uncased flaubert/flaubert_large_cased flaubert/flaubert_small_cased funnel-transformer/intermediate-base funnel-transformer/large-base funnel-transformer/medium-base funnel-transformer/small-base funnel-transformer/xlarge-base google/mobilebert-uncased tau/splinter-base tau/splinter-base-qass tau/splinter-large |
||||
|
do |
||||
|
python evaluate.py --model ${name} --format pytorch |
||||
|
python evaluate.py --model ${name} --format onnx |
||||
|
done |
Loading…
Reference in new issue