logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

184 lines
7.6 KiB

import os
import onnxruntime
import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset
from statistics import mode
import argparse
import transformers
# Only let transformers emit error-level log messages.
transformers.logging.set_verbosity_error()

# Command-line interface: --model and --format are required, everything else
# has a default.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)
args = parser.parse_args()

model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
# Without an explicit collection name, derive one from the model name with
# '-' and '/' replaced by '_'.
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')

# Fixed benchmark settings.
device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'

# Load the dataset and shuffle it with a fixed seed.
data = load_dataset(dataset_name).shuffle(seed=32)
# Validate each requested size against the split it is actually drawn from:
# inserts come from 'train', queries come from 'test'.
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
assert query_size <= len(data['test']), 'There is no enough data. Please decrease query size.'
# A size of 0 keeps the whole split instead of slicing it.
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data

# Warm up
print('Warming up...')
op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op()
# The last axis of the operator's output is used as the embedding dimension
# for the Milvus vector field.
dim = op('This is test.').shape[-1]
print(f'output dim: {dim}')

# Prepare Milvus
print('Connecting milvus ...')
connections.connect(host=host, port=port)
def create_milvus(collection_name):
    """Create a fresh Milvus collection for the embeddings.

    If a collection with the same name already exists it is dropped first.
    Reads the module-level ``dim``, ``model_name`` and ``dataset_name``.

    Args:
        collection_name: Name of the collection to (re)create.

    Returns:
        The newly created ``Collection``.
    """
    print('Creating collection ...')

    id_field = FieldSchema(
        name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True
    )
    embedding_field = FieldSchema(
        name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim
    )
    label_field = FieldSchema(
        name='label', dtype=DataType.VARCHAR, description='label', max_length=500
    )
    schema = CollectionSchema(
        fields=[id_field, embedding_field, label_field],
        description=f'text embeddings for {model_name} on {dataset_name}',
    )

    # Remove any previous collection of the same name before recreating it.
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        Collection(collection_name).drop()

    new_collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return new_collection
# Select the embedding backend. Each branch suffixes the collection name with
# the format and defines `_embed(dc, model_name)`, which adds the 'emb' column
# (computed from 'text') to a towhee DataCollection. The shared `insert` and
# `query` pipelines below are otherwise identical for both formats.
if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def _embed(dc, model_name):
        """Embed the 'text' column with the towhee transformers operator."""
        return dc.text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    saved_name = model_name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    # Export the model only when the expected ONNX file is not on disk yet.
    if not os.path.exists(onnx_path):
        op.save_model(format='onnx')
    sess = onnxruntime.InferenceSession(onnx_path,
                                        providers=onnxruntime.get_available_providers())
    # The session's input names do not change between calls, so read them once
    # here instead of on every run_onnx invocation.
    onnx_input_names = [x.name for x in sess.get_inputs()]

    @towhee.register
    def run_onnx(txt):
        """Tokenize `txt` and return the ONNX model's 'last_hidden_state' output.

        The leading axis of the output is squeezed away before returning.
        """
        # NOTE(review): the tokenizer is called without truncation; the text
        # is only capped at 1024 characters upstream. Confirm this cannot
        # exceed the model's maximum sequence length.
        inputs = op.tokenizer(txt, return_tensors='np')
        # Feed only the tensors the ONNX graph declares as inputs.
        feed = {k: inputs[k] for k in onnx_input_names}
        outs = sess.run(output_names=['last_hidden_state'], input_feed=feed)
        return outs[0].squeeze(0)

    def _embed(dc, model_name):
        """Embed the 'text' column with the registered ONNX operator."""
        return dc.run_onnx['text', 'emb']()

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')


def insert(model_name, collection_name):
    """Embed the insert split and write embeddings plus labels to Milvus.

    Args:
        model_name: Model name forwarded to the embedding step.
        collection_name: Target Milvus collection.

    Returns:
        ``num_entities`` of the collection after the pipeline has run.
    """
    texts = (
        towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
        # Cap each text at 1024 characters before embedding.
        .runas_op['text', 'text'](lambda s: s[:1024])
    )
    (
        _embed(texts, model_name)
        # Keep only the first row of the embedding output.
        .runas_op['emb', 'emb'](lambda x: x[0])
        # Labels are stored as strings (the 'label' field is VARCHAR).
        .runas_op['label', 'label'](lambda y: str(y))
        .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
            uri=f'tcp://{host}:{port}/{collection_name}'
        )
        .show(3)
    )
    collection = Collection(collection_name)
    return collection.num_entities


def query(model_name, collection_name):
    """Embed the query split, search Milvus and evaluate label accuracy.

    For every query the labels of the returned hits are collected; the
    prediction at k (k = 1, 5, 10) is the mode of the first k labels, and each
    prediction column is evaluated against the ground truth with the
    'accuracy' metric.

    Args:
        model_name: Model name forwarded to the embedding step.
        collection_name: Milvus collection to search.

    Returns:
        Whatever the pipeline's ``report()`` call returns.
    """
    texts = (
        towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
        # Cap each text at 1024 characters before embedding.
        .runas_op['text', 'text'](lambda s: s[:1024])
    )
    benchmark = (
        _embed(texts, model_name)
        # Keep only the first row of the embedding output.
        .runas_op['emb', 'emb'](lambda x: x[0])
        # Ground truth is compared as strings, matching the stored labels.
        .runas_op['gt', 'gt'](lambda y: str(y))
        .ann_search.milvus['emb', 'milvus_res'](
            uri=f'tcp://{host}:{port}/{collection_name}',
            metric_type=metric_type,
            limit=topk,
            output_fields=['label']
        )
        .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
        .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
        .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
        .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
        .with_metrics(['accuracy'])
        .evaluate['gt', 'pred1']('pred1')
        .evaluate['gt', 'pred5']('pred5')
        .evaluate['gt', 'pred10']('pred10')
        .report()
    )
    return benchmark
# Run the benchmark end to end: recreate the collection, fill it, then query it.
collection = create_milvus(collection_name)

insert_count = insert(model_name, collection_name)
print(f'Total data inserted: {insert_count}')

benchmark = query(model_name, collection_name)