timm
copied
Jael Gu
2 years ago
2 changed files with 176 additions and 3 deletions
@ -0,0 +1,172 @@ |
|||||
|
import os |
||||
|
import torch |
||||
|
import onnxruntime |
||||
|
|
||||
|
import towhee |
||||
|
from towhee import ops |
||||
|
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility |
||||
|
from datasets import load_dataset |
||||
|
|
||||
|
from statistics import mode |
||||
|
import argparse |
||||
|
|
||||
|
parser = argparse.ArgumentParser() |
||||
|
parser.add_argument('--model', required=True, type=str) |
||||
|
parser.add_argument('--dataset', type=str, default='mnist') |
||||
|
parser.add_argument('--insert_size', type=int, default=1000) |
||||
|
parser.add_argument('--query_size', type=int, default=100) |
||||
|
parser.add_argument('--topk', type=int, default=10) |
||||
|
parser.add_argument('--collection_name', type=str, default=None) |
||||
|
parser.add_argument('--format', type=str, required=True) |
||||
|
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx') |
||||
|
|
||||
|
args = parser.parse_args() |
||||
|
model_name = args.model |
||||
|
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-') + '.onnx') |
||||
|
dataset_name = args.dataset |
||||
|
insert_size = args.insert_size |
||||
|
query_size = args.query_size |
||||
|
topk = args.topk |
||||
|
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_') |
||||
|
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx') |
||||
|
|
||||
|
device = 'cpu' |
||||
|
host = 'localhost' |
||||
|
port = '19530' |
||||
|
index_type = 'FLAT' |
||||
|
metric_type = 'L2' |
||||
|
|
||||
|
data = load_dataset(dataset_name).shuffle(seed=32) |
||||
|
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.' |
||||
|
assert insert_size <= len(data['test']), 'There is no enough data. Please decrease query size.' |
||||
|
|
||||
|
insert_data = data['train'] |
||||
|
insert_data = insert_data[:insert_size] if insert_size else insert_data |
||||
|
query_data = data['test'] |
||||
|
query_data = query_data[:query_size] if query_size else query_data |
||||
|
|
||||
|
# Warm up |
||||
|
print('Warming up...') |
||||
|
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op() |
||||
|
dummy_input = torch.rand((1,) + op.config['input_size']) |
||||
|
dim = op(dummy_input).shape[0] |
||||
|
print(f'output dim: {dim}') |
||||
|
|
||||
|
# Prepare Milvus |
||||
|
print('Connecting milvus ...') |
||||
|
connections.connect(host=host, port=port) |
||||
|
|
||||
|
|
||||
|
def create_milvus(collection_name): |
||||
|
print('Creating collection ...') |
||||
|
fields = [ |
||||
|
FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True), |
||||
|
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim), |
||||
|
FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500) |
||||
|
] |
||||
|
schema = CollectionSchema(fields=fields, description=f'image embeddings for {model_name} on {dataset_name}') |
||||
|
if utility.has_collection(collection_name): |
||||
|
print(f'drop old collection: {collection_name}') |
||||
|
collection = Collection(collection_name) |
||||
|
collection.drop() |
||||
|
collection = Collection(name=collection_name, schema=schema) |
||||
|
print(f'A new collection is created: {collection_name}.') |
||||
|
return collection |
||||
|
|
||||
|
|
||||
|
if args.format == 'pytorch': |
||||
|
collection_name = collection_name + '_pytorch' |
||||
|
|
||||
|
def insert(model_name, collection_name): |
||||
|
( |
||||
|
towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream() |
||||
|
.image_embedding.timm['image', 'emb'](model_name=model_name, device=device) |
||||
|
.runas_op['label', 'label'](lambda y: str(y)) |
||||
|
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}' |
||||
|
) |
||||
|
.show(3) |
||||
|
) |
||||
|
collection = Collection(collection_name) |
||||
|
return collection.num_entities |
||||
|
|
||||
|
def query(model_name, collection_name): |
||||
|
benchmark = ( |
||||
|
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream() |
||||
|
.image_embedding.timm['image', 'emb'](model_name=model_name, device=device) |
||||
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
||||
|
.ann_search.milvus['emb', 'milvus_res']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}', |
||||
|
metric_type=metric_type, |
||||
|
limit=topk, |
||||
|
output_fields=['label'] |
||||
|
) |
||||
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
||||
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
||||
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
||||
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
||||
|
.with_metrics(['accuracy']) |
||||
|
.evaluate['gt', 'pred1']('pred1') |
||||
|
.evaluate['gt', 'pred5']('pred5') |
||||
|
.evaluate['gt', 'pred10']('pred10') |
||||
|
.report() |
||||
|
) |
||||
|
return benchmark |
||||
|
elif args.format == 'onnx': |
||||
|
collection_name = collection_name + '_onnx' |
||||
|
if not os.path.exists(onnx_path): |
||||
|
onnx_path = op.save_model(format='onnx') |
||||
|
|
||||
|
@towhee.register |
||||
|
def run_onnx(img): |
||||
|
img = op.tfms(img).cpu().detach().numpy() |
||||
|
features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] |
||||
|
if len(features.shape) == 4: |
||||
|
global_pool = nn.AdaptiveAvgPool2d(1) |
||||
|
features = global_pool(features) |
||||
|
outs = features.squeeze(0) |
||||
|
assert outs.shape[0] == dim, 'Output dimensions are not consistent.' |
||||
|
return outs.cpu().detach().numpy() |
||||
|
|
||||
|
def insert(model_name, collection_name): |
||||
|
( |
||||
|
towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream() |
||||
|
.run_onnx['image', 'emb']() |
||||
|
.runas_op['label', 'label'](lambda y: str(y)) |
||||
|
.ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}' |
||||
|
) |
||||
|
.show(3) |
||||
|
) |
||||
|
collection = Collection(collection_name) |
||||
|
return collection.num_entities |
||||
|
|
||||
|
def query(model_name, collection_name): |
||||
|
benchmark = ( |
||||
|
towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream() |
||||
|
.run_onnx['image', 'emb']() |
||||
|
.runas_op['gt', 'gt'](lambda y: str(y)) |
||||
|
.ann_search.milvus['emb', 'milvus_res']( |
||||
|
uri=f'tcp://{host}:{port}/{collection_name}', |
||||
|
metric_type=metric_type, |
||||
|
limit=topk, |
||||
|
output_fields=['label'] |
||||
|
) |
||||
|
.runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() |
||||
|
.runas_op['preds', 'pred1'](lambda x: mode(x[:1])) |
||||
|
.runas_op['preds', 'pred5'](lambda x: mode(x[:5])) |
||||
|
.runas_op['preds', 'pred10'](lambda x: mode(x[:10])) |
||||
|
.with_metrics(['accuracy']) |
||||
|
.evaluate['gt', 'pred1']('pred1') |
||||
|
.evaluate['gt', 'pred5']('pred5') |
||||
|
.evaluate['gt', 'pred10']('pred10') |
||||
|
.report() |
||||
|
) |
||||
|
return benchmark |
||||
|
else: |
||||
|
raise AttributeError('Only support "pytorch" and "onnx" as format.') |
||||
|
|
||||
|
collection = create_milvus(collection_name) |
||||
|
insert_count = insert(model_name, collection_name) |
||||
|
print('Total data inserted:', insert_count) |
||||
|
benchmark = query(model_name, collection_name) |
Loading…
Reference in new issue