transformers
copied
Jael Gu
2 years ago
3 changed files with 212 additions and 0 deletions
@ -0,0 +1,21 @@ |
|||
# Evaluate with Similarity Search |
|||
|
|||
## Introduction |
|||
|
|||
Build a classification system based on similarity search across embeddings. |
|||
The core ideas in `evaluate.py`: |
|||
1. create a new Milvus collection each time |
|||
2. extract embeddings using a pretrained model with model name specified by `--model` |
|||
3. specify the inference method with `--format`, which accepts either `pytorch` or `onnx` |
|||
4. insert embeddings into a Milvus collection and search them without building an index |
|||
5. measure performance with accuracy at top 1, 5, 10 |
|||
1. vote for the prediction from topk search results (most frequent one) |
|||
2. compare final prediction with ground truth |
|||
3. calculate percent of correct predictions over all queries |
|||
|
|||
## Example Usage |
|||
|
|||
```bash |
|||
python evaluate.py --model MODEL_NAME --format pytorch |
|||
python evaluate.py --model MODEL_NAME --format onnx |
|||
``` |
@ -0,0 +1,184 @@ |
|||
import os |
|||
import onnxruntime |
|||
|
|||
import towhee |
|||
from towhee import ops |
|||
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility |
|||
from datasets import load_dataset |
|||
|
|||
from statistics import mode |
|||
import argparse |
|||
|
|||
import transformers |
|||
|
|||
transformers.logging.set_verbosity_error() |
|||
|
|||
# Command-line interface. ``--model`` and ``--format`` are mandatory; every
# other option falls back to the default declared here.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
# Restricting the choices rejects a mistyped format at parse time, before the
# dataset is loaded, the model is warmed up and Milvus is contacted.
parser.add_argument('--format', type=str, required=True, choices=['pytorch', 'onnx'])

args = parser.parse_args()
model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
# Without an explicit name, derive one from the model name with '-' and '/'
# replaced by '_' (presumably to keep the collection name valid for Milvus —
# TODO confirm).
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')
|||
|
|||
# Fixed runtime settings shared by the rest of the script.
device = 'cpu'  # handed to the embedding operator as its `device`
host, port = 'localhost', '19530'  # Milvus endpoint (connection and tcp:// URIs)
# metric_type is forwarded to the Milvus search; index_type is not referenced
# anywhere else in the visible script.
index_type, metric_type = 'FLAT', 'L2'
|||
|
|||
# Load the dataset and shuffle it with a fixed seed (32).
data = load_dataset(dataset_name).shuffle(seed=32)
# Inserted rows come from the train split, queries from the test split, so each
# requested size is validated against the split it is actually sliced from.
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
assert query_size <= len(data['test']), 'There is no enough data. Please decrease query size.'

# A size of 0 is falsy and therefore selects the whole split.
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data
|||
|
|||
# Warm up: build the embedding operator once and embed a throwaway sentence;
# the last axis of the result gives the vector dimension used by the schema.
print('Warming up...')
embedding_handle = ops.text_embedding.transformers(model_name=model_name, device=device)
op = embedding_handle.get_op()
sample_embedding = op('This is test.')
dim = sample_embedding.shape[-1]
print(f'output dim: {dim}')

# Prepare Milvus: open the connection used by all later pymilvus calls.
print('Connecting milvus ...')
connections.connect(host=host, port=port)
|||
|
|||
|
|||
def create_milvus(collection_name):
    """Create a fresh Milvus collection named ``collection_name``.

    An existing collection with the same name is dropped first. The schema
    holds an auto-generated INT64 primary key ``id``, a FLOAT_VECTOR
    ``embedding`` whose dimension is the module-level ``dim``, and a VARCHAR
    ``label`` (max length 500).

    Returns:
        The newly created ``Collection``.
    """
    print('Creating collection ...')
    id_field = FieldSchema(
        name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True)
    embedding_field = FieldSchema(
        name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim)
    label_field = FieldSchema(
        name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
    schema = CollectionSchema(
        fields=[id_field, embedding_field, label_field],
        description=f'text embeddings for {model_name} on {dataset_name}',
    )

    # Start from a clean slate so results of an earlier run cannot leak in.
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        Collection(collection_name).drop()

    fresh_collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return fresh_collection
|||
|
|||
|
|||
# Select the embedding step for the requested inference backend. The insert and
# search pipelines are identical for both backends; only the step that turns
# the 'text' column into the 'emb' column differs, so it is isolated in _embed.
if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def _embed(dc, model_name):
        """Append the transformers text-embedding operator to ``dc``."""
        return dc.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    saved_name = model_name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    # Export the model only when no ONNX file exists at the expected path yet.
    if not os.path.exists(onnx_path):
        op.save_model(format='onnx')
    sess = onnxruntime.InferenceSession(onnx_path,
                                        providers=onnxruntime.get_available_providers())
    # The session's input names do not change between calls, so read them once
    # here instead of on every run_onnx invocation.
    onnx_input_names = [x.name for x in sess.get_inputs()]

    @towhee.register
    def run_onnx(txt):
        """Embed ``txt`` with the ONNX session.

        Tokenizes with the warm-up operator's tokenizer, feeds only the inputs
        the ONNX session declares, and returns the ``last_hidden_state`` output
        with its leading axis squeezed away.
        """
        inputs = op.tokenizer(txt, return_tensors='np')
        input_feed = {k: inputs[k] for k in onnx_input_names}
        outs = sess.run(output_names=['last_hidden_state'], input_feed=input_feed)
        return outs[0].squeeze(0)

    def _embed(dc, model_name):
        """Append the registered ``run_onnx`` operator to ``dc``.

        ``model_name`` is accepted for signature parity with the PyTorch
        variant and is unused here: the ONNX session is already bound to the
        model chosen on the command line.
        """
        return dc.run_onnx['text', 'emb']()

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')


def insert(model_name, collection_name):
    """Embed ``insert_data`` and insert it into the Milvus collection.

    Args:
        model_name: Model name handed to the embedding step.
        collection_name: Target Milvus collection; it must already exist.

    Returns:
        ``num_entities`` reported by the collection after the insert.
    """
    texts = (
        towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
        .runas_op['text', 'text'](lambda s: s[:1024])  # cap input at 1024 characters
    )
    (
        _embed(texts, model_name)
        # Keep only the first row of the embedding output (presumably the
        # first token's vector — TODO confirm).
        .runas_op['emb', 'emb'](lambda x: x[0])
        # The schema stores labels as VARCHAR, so stringify them.
        .runas_op['label', 'label'](lambda y: str(y))
        .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
            uri=f'tcp://{host}:{port}/{collection_name}'
        )
        .show(3)
    )
    collection = Collection(collection_name)
    return collection.num_entities


def query(model_name, collection_name):
    """Search the collection with ``query_data`` and report accuracy.

    For each query the labels of the ``topk`` nearest neighbours are fetched.
    The prediction at cut-off k (1, 5, 10) is the most frequent label among the
    first k results, and each prediction column is evaluated against the ground
    truth with the ``accuracy`` metric.

    Args:
        model_name: Model name handed to the embedding step.
        collection_name: Milvus collection to search.

    Returns:
        Whatever the pipeline's ``report()`` call returns.
    """
    texts = (
        towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
        .runas_op['text', 'text'](lambda s: s[:1024])  # cap input at 1024 characters
    )
    benchmark = (
        _embed(texts, model_name)
        .runas_op['emb', 'emb'](lambda x: x[0])
        # Stored labels are strings, so compare against stringified ground truth.
        .runas_op['gt', 'gt'](lambda y: str(y))
        .ann_search.milvus['emb', 'milvus_res'](
            uri=f'tcp://{host}:{port}/{collection_name}',
            metric_type=metric_type,
            limit=topk,
            output_fields=['label']
        )
        .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
        .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
        .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
        .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
        .with_metrics(['accuracy'])
        .evaluate['gt', 'pred1']('pred1')
        .evaluate['gt', 'pred5']('pred5')
        .evaluate['gt', 'pred10']('pred10')
        .report()
    )
    return benchmark


# Run the evaluation end to end: recreate the collection, fill it, query it.
collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)
@ -0,0 +1,7 @@ |
|||
#!/bin/bash

# Evaluate every listed model twice: first with PyTorch inference, then ONNX.
models=(
    allenai/led-base-16384
    flaubert/flaubert_base_cased
    flaubert/flaubert_base_uncased
    flaubert/flaubert_large_cased
    flaubert/flaubert_small_cased
    funnel-transformer/intermediate-base
    funnel-transformer/large-base
    funnel-transformer/medium-base
    funnel-transformer/small-base
    funnel-transformer/xlarge-base
    google/mobilebert-uncased
    tau/splinter-base
    tau/splinter-base-qass
    tau/splinter-large
)

for name in "${models[@]}"
do
    python evaluate.py --model "${name}" --format pytorch
    python evaluate.py --model "${name}" --format onnx
done
Loading…
Reference in new issue