timm
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
184 lines
7.7 KiB
184 lines
7.7 KiB
import os
|
|
import torch
|
|
from torch import nn
|
|
import numpy
|
|
import onnxruntime
|
|
|
|
import towhee
|
|
from towhee import ops
|
|
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
|
|
from datasets import load_dataset
|
|
|
|
from statistics import mode
|
|
import argparse
|
|
|
|
# Command-line interface. ``--format`` is restricted with ``choices`` so an
# unsupported value is rejected at parse time, before any dataset, model or
# Milvus work is started.
parser = argparse.ArgumentParser(
    description='Insert image embeddings into Milvus and report retrieval accuracy.'
)
parser.add_argument('--model', required=True, type=str,
                    help='model name passed to the timm image-embedding operator')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset name passed to datasets.load_dataset')
parser.add_argument('--insert_size', type=int, default=1000,
                    help='number of train-split samples to insert (0 uses the whole split)')
parser.add_argument('--query_size', type=int, default=100,
                    help='number of test-split samples to query (0 uses the whole split)')
parser.add_argument('--topk', type=int, default=10,
                    help='number of search results requested per query')
parser.add_argument('--collection_name', type=str, default=None,
                    help='base Milvus collection name; derived from the model name when omitted')
parser.add_argument('--format', type=str, required=True, choices=('pytorch', 'onnx'),
                    help='inference backend used to compute embeddings')
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx',
                    help='directory searched for an exported ONNX model')

args = parser.parse_args()
|
|
# Run settings taken from the parsed command-line arguments.
model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
# Without an explicit name, derive one from the model name with '-' and '/'
# replaced by '_'.
collection_name = args.collection_name or model_name.replace('-', '_').replace('/', '_')
# Single definition of the ONNX file location; '/' in the model name is
# replaced so the name maps to one file inside ``onnx_dir``.
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')

# Fixed runtime settings.
device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'
|
|
|
|
# Load the dataset and shuffle it with a fixed seed so repeated runs see the
# same samples.
data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
# The query set is drawn from the test split, so it is the query size that
# must fit there.
assert query_size <= len(data['test']), 'There is no enough data. Please decrease query size.'

# A size of 0 (falsy) means "use the whole split".
insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data
|
|
|
|
# Warm up: build the operator once and run a single image through the
# embedding pipeline to learn the embedding width.
print('Warming up...')
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
_warmup_images = ['https://raw.githubusercontent.com/towhee-io/towhee/main/towhee_logo.png']
_warmup_embeddings = (
    towhee.dc(_warmup_images)
    .image_decode()
    .image_embedding.timm(model_name=model_name, device=device)
)
dim = _warmup_embeddings[0].shape[0]
print(f'output dim: {dim}')

# Prepare Milvus
print('Connecting milvus ...')
connections.connect(host=host, port=port)
|
|
|
|
|
|
def create_milvus(collection_name):
    """Create an empty Milvus collection for the embeddings and return it.

    If a collection called ``collection_name`` already exists it is dropped
    first. The vector field width is read from the module-level ``dim``; the
    schema description uses the module-level ``model_name`` and
    ``dataset_name``.
    """
    print('Creating collection ...')
    id_field = FieldSchema(
        name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True
    )
    embedding_field = FieldSchema(
        name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim
    )
    label_field = FieldSchema(
        name='label', dtype=DataType.VARCHAR, description='label', max_length=500
    )
    schema = CollectionSchema(
        fields=[id_field, embedding_field, label_field],
        description=f'image embeddings for {model_name} on {dataset_name}',
    )

    # Start from a clean slate: remove any collection left by an earlier run.
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        Collection(collection_name).drop()

    collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return collection
|
|
|
|
|
|
# Select the embedding backend. Each branch suffixes the collection name and
# defines ``_embed(dc, model_name)``, which appends the image -> 'emb' stage to
# a towhee DataCollection; ``insert`` and ``query`` below share everything else.
if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def _embed(dc, model_name):
        """Append the timm (PyTorch) embedding stage to ``dc``."""
        return dc.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    # Export the model only when no ONNX file is found at the expected path.
    if not os.path.exists(onnx_path):
        onnx_path = str(op.save_model(format='onnx'))
    sess = onnxruntime.InferenceSession(onnx_path,
                                        providers=onnxruntime.get_available_providers())

    @towhee.register
    def run_onnx(img):
        """Embed one image with the ONNX session and return a NumPy array.

        ``img`` must offer ``convert('RGB')`` (presumably a PIL image - confirm
        against the dataset). It is preprocessed with ``op.tfms`` before
        inference.
        """
        img = img.convert('RGB')
        img = op.tfms(img).unsqueeze(0).cpu().detach().numpy()
        features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
        if len(features.shape) == 4:
            # 4-D output is average-pooled with torch before use.
            features = torch.from_numpy(features)
            global_pool = nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
        outs = features.squeeze(0)
        assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
        # Only the pooled path produces a tensor; the session output itself is
        # already a NumPy array and has no ``.cpu()``.
        if isinstance(outs, torch.Tensor):
            outs = outs.cpu().detach().numpy()
        return outs

    def _embed(dc, model_name):
        """Append the registered ONNX embedding stage to ``dc``."""
        return dc.run_onnx['image', 'emb']()

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')


def insert(model_name, collection_name):
    """Embed the insert split, write it to Milvus and return the entity count.

    Args:
        model_name: model name forwarded to the embedding stage (unused by the
            ONNX backend).
        collection_name: target Milvus collection.

    Returns:
        ``num_entities`` of the collection after the pipeline has run.
    """
    source = towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
    (
        _embed(source, model_name)
        .runas_op['label', 'label'](lambda y: str(y))
        .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
            uri=f'tcp://{host}:{port}/{collection_name}'
        )
        .show(3)
    )
    collection = Collection(collection_name)
    return collection.num_entities


def query(model_name, collection_name):
    """Search Milvus with the query split and report accuracy.

    Each prediction is the most common label among the first 1, 5 and 10
    search results; every prediction column is evaluated against the ground
    truth with the 'accuracy' metric.

    Args:
        model_name: model name forwarded to the embedding stage (unused by the
            ONNX backend).
        collection_name: Milvus collection to search.

    Returns:
        Whatever the towhee ``report()`` call returns.
    """
    source = towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
    benchmark = (
        _embed(source, model_name)
        .runas_op['gt', 'gt'](lambda y: str(y))
        .ann_search.milvus_client['emb', 'milvus_res'](
            host=host,
            port=port,
            collection_name=collection_name,
            metric_type=metric_type,
            limit=topk,
            output_fields=['label']
        )
        # Keep only the last element of each hit.
        # NOTE(review): presumably the requested 'label' output field - confirm.
        .runas_op['milvus_res', 'preds'](lambda x: [y[-1] for y in x]).unstream()
        .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
        .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
        .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
        .with_metrics(['accuracy'])
        .evaluate['gt', 'pred1']('pred1')
        .evaluate['gt', 'pred5']('pred5')
        .evaluate['gt', 'pred10']('pred10')
        .report()
    )
    return benchmark
|
|
|
|
# Run the benchmark: create the collection, fill it, then query it.
collection = create_milvus(collection_name)

insert_count = insert(model_name, collection_name)
print(f'Total data inserted: {insert_count}')

benchmark = query(model_name, collection_name)
|