logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

184 lines
7.7 KiB

import os
import torch
from torch import nn
import numpy
import onnxruntime
import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset
from statistics import mode
import argparse
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str,
                    help='model name passed to the timm image-embedding operator')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset name passed to datasets.load_dataset')
parser.add_argument('--insert_size', type=int, default=1000,
                    help='number of train samples to insert (0 uses the whole split)')
parser.add_argument('--query_size', type=int, default=100,
                    help='number of test samples to query with (0 uses the whole split)')
parser.add_argument('--topk', type=int, default=10,
                    help='number of results requested per Milvus search')
parser.add_argument('--collection_name', type=str, default=None,
                    help='Milvus collection name; derived from the model name when omitted')
# `choices` rejects an unsupported format at parse time, before the dataset is
# loaded and the model is warmed up.
parser.add_argument('--format', type=str, required=True, choices=['pytorch', 'onnx'],
                    help='inference backend to benchmark')
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx',
                    help='directory checked for an existing ONNX model file')
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
# Default collection name: the model name with '-' and '/' turned into '_'.
collection_name = args.collection_name or model_name.replace('-', '_').replace('/', '_')
# NOTE(review): this path used to be assigned twice, first with '/' -> '-' and
# then with '/' -> '_'; only the second assignment ever took effect, so that
# one is kept. Confirm which naming the ONNX export actually produces.
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')

device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
data = load_dataset(dataset_name).shuffle(seed=32)

# Explicit raises instead of `assert`, so the checks survive `python -O`.
if insert_size > len(data['train']):
    raise ValueError('There is no enough data. Please decrease insert size.')
# The query size (not the insert size) must fit into the test split.
if query_size > len(data['test']):
    raise ValueError('There is no enough data. Please decrease query size.')

# A size of 0 is falsy and therefore keeps the whole split.
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data
# Warm up: build the timm operator, then embed one reference image to read
# the size of the embedding vector.
print('Warming up...')
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
dim = (
    towhee.dc(['https://raw.githubusercontent.com/towhee-io/towhee/main/towhee_logo.png'])
    .image_decode()
    .image_embedding.timm(model_name=model_name, device=device)
)[0].shape[0]
print(f'output dim: {dim}')

# Prepare Milvus: open the connection used by the pymilvus calls below.
print('Connecting milvus ...')
connections.connect(host=host, port=port)
def create_milvus(collection_name):
    """Create a Milvus collection for the embeddings, replacing any old one.

    The schema has three fields: an auto-generated INT64 primary key ``id``,
    a FLOAT_VECTOR ``embedding`` whose dimension is the module-level ``dim``,
    and a VARCHAR ``label`` with a maximum length of 500.

    Args:
        collection_name: Name of the collection to (re)create.

    Returns:
        The newly created ``Collection``.
    """
    print('Creating collection ...')
    id_field = FieldSchema(
        name='id', dtype=DataType.INT64, description='embedding id',
        is_primary=True, auto_id=True,
    )
    embedding_field = FieldSchema(
        name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim,
    )
    label_field = FieldSchema(
        name='label', dtype=DataType.VARCHAR, description='label', max_length=500,
    )
    schema = CollectionSchema(
        fields=[id_field, embedding_field, label_field],
        description=f'image embeddings for {model_name} on {dataset_name}',
    )

    # A collection that already exists under this name is dropped before the
    # new one is created.
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        Collection(collection_name).drop()

    created = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return created
# ---------------------------------------------------------------------------
# Backend selection.
# Each branch suffixes the collection name with the backend and defines
# `_embed`, which appends the backend-specific embedding step (column 'image'
# -> column 'emb') to a towhee pipeline. `insert` and `query` below share
# everything else.
# ---------------------------------------------------------------------------
if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def _embed(pipeline, model_name):
        """Append the timm embedding operator to ``pipeline``."""
        return pipeline.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    # Export the model only when no file exists at the expected path.
    if not os.path.exists(onnx_path):
        onnx_path = str(op.save_model(format='onnx'))
    sess = onnxruntime.InferenceSession(onnx_path,
                                        providers=onnxruntime.get_available_providers())
    # Built once instead of on every call to `run_onnx`.
    _global_pool = nn.AdaptiveAvgPool2d(1)

    @towhee.register
    def run_onnx(img):
        """Embed one image with the ONNX session.

        Args:
            img: Image object offering ``convert('RGB')``; it is preprocessed
                with the operator's ``tfms`` transform.

        Returns:
            The embedding as a NumPy array with the batch axis squeezed out.

        Raises:
            AssertionError: If the first dimension of the output differs from
                the warm-up dimension ``dim``.
        """
        img = img.convert('RGB')
        img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
        features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
        # Convert unconditionally so the tensor-only calls below
        # (`.cpu().detach()`) also work when the session output is not 4-D.
        features = torch.from_numpy(features)
        if features.dim() == 4:
            # 4-D output: average-pool the last two axes down to 1x1.
            features = _global_pool(features)
        outs = features.squeeze(0)
        assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
        return outs.cpu().detach().numpy()

    def _embed(pipeline, model_name):
        """Append the registered ``run_onnx`` operator to ``pipeline``.

        ``model_name`` is accepted only to mirror the pytorch variant; the
        ONNX session was already built for the selected model.
        """
        return pipeline.run_onnx['image', 'emb']()

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')


def insert(model_name, collection_name):
    """Embed the insert split and store embeddings plus labels in Milvus.

    Args:
        model_name: Model name forwarded to the embedding step.
        collection_name: Target Milvus collection.

    Returns:
        ``num_entities`` of the collection after the insert pipeline ran.
    """
    pipeline = towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
    (
        _embed(pipeline, model_name)
        # Labels are converted to strings before insertion.
        .runas_op['label', 'label'](lambda y: str(y))
        .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
            uri=f'tcp://{host}:{port}/{collection_name}'
        )
        .show(3)
    )
    collection = Collection(collection_name)
    return collection.num_entities


def query(model_name, collection_name):
    """Search Milvus with the query split and report the accuracy.

    For every query the labels of the returned hits are collected; 'pred1',
    'pred5' and 'pred10' are the most common label among the first 1, 5 and
    10 hits, each evaluated against the ground-truth column 'gt'.

    Args:
        model_name: Model name forwarded to the embedding step.
        collection_name: Milvus collection to search.

    Returns:
        Whatever the pipeline's ``report()`` call returns.
    """
    pipeline = towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
    benchmark = (
        _embed(pipeline, model_name)
        .runas_op['gt', 'gt'](lambda y: str(y))
        .ann_search.milvus_client['emb', 'milvus_res'](
            host=host,
            port=port,
            collection_name=collection_name,
            metric_type=metric_type,
            limit=topk,
            output_fields=['label']
        )
        # Keep the last element of every hit -- presumably the 'label' output
        # field requested above; TODO confirm against the milvus_client op.
        .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
        # NOTE(review): the cut-offs 1/5/10 are fixed; the slices simply take
        # fewer items when fewer hits are available (e.g. topk below 10).
        .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
        .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
        .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
        .with_metrics(['accuracy'])
        .evaluate['gt', 'pred1']('pred1')
        .evaluate['gt', 'pred5']('pred5')
        .evaluate['gt', 'pred10']('pred10')
        .report()
    )
    return benchmark
# Run the benchmark: (re)create the collection, fill it with the insert split,
# then search it with the query split.
collection = create_milvus(collection_name=collection_name)
insert_count = insert(model_name=model_name, collection_name=collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name=model_name, collection_name=collection_name)