timm
copied
Jael Gu
2 years ago
2 changed files with 176 additions and 3 deletions
@ -0,0 +1,172 @@ |
|||
import os |
|||
import torch |
|||
import onnxruntime |
|||
|
|||
import towhee |
|||
from towhee import ops |
|||
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility |
|||
from datasets import load_dataset |
|||
|
|||
from statistics import mode |
|||
import argparse |
|||
|
|||
# ---------------------------------------------------------------------------
# Command-line interface and script-wide configuration.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Insert image embeddings into Milvus and report search accuracy.'
)
parser.add_argument('--model', required=True, type=str,
                    help='model name passed to the timm image-embedding operator')
parser.add_argument('--dataset', type=str, default='mnist',
                    help='dataset name passed to datasets.load_dataset')
parser.add_argument('--insert_size', type=int, default=1000,
                    help='number of "train" samples to insert (0 keeps the whole split)')
parser.add_argument('--query_size', type=int, default=100,
                    help='number of "test" samples to query with (0 keeps the whole split)')
parser.add_argument('--topk', type=int, default=10,
                    help='number of neighbours requested from each Milvus search')
parser.add_argument('--collection_name', type=str, default=None,
                    help='base Milvus collection name (default: derived from --model)')
# Restricting the choices makes an unsupported format fail immediately,
# before any dataset is downloaded or any connection is opened.
parser.add_argument('--format', type=str, required=True, choices=['pytorch', 'onnx'],
                    help='model format to benchmark')
parser.add_argument('--onnx_dir', type=str, default='./saved/onnx',
                    help='directory searched for an already exported ONNX model')

args = parser.parse_args()

model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk

# Without an explicit name, derive one from the model name by replacing
# '-' and '/' with '_'.
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')

# Expected location of the exported ONNX file.
# NOTE(review): this path used to be assigned twice; an earlier assignment that
# replaced '/' with '-' was always overwritten by this one, so only this
# (effective) form is kept. Confirm which naming the exporter actually uses.
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '_') + '.onnx')

# Fixed runtime settings: inference device, Milvus endpoint, index and metric.
device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'
|||
|
|||
# Load the dataset and shuffle it with a fixed seed so repeated runs select
# the same samples.
data = load_dataset(dataset_name).shuffle(seed=32)

# Both requested sample counts must fit into their respective splits.
assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.'
# The query size (not the insert size) is what must fit into the test split.
assert query_size <= len(data['test']), 'There is no enough data. Please decrease query size.'

# A size of 0 is falsy and therefore keeps the whole split.
insert_data = data['train']
if insert_size:
    insert_data = insert_data[:insert_size]

query_data = data['test']
if query_size:
    query_data = query_data[:query_size]
|||
|
|||
# Warm up: build the operator once and run it on a random input so that the
# embedding dimension is known before the Milvus schema is created.
print('Warming up...')
op = ops.image_embedding.timm(model_name=model_name, device=device).get_op()
# Prepend a leading dimension of 1 to the operator's configured input size.
input_shape = (1,) + op.config['input_size']
dummy_input = torch.rand(input_shape)
warmup_output = op(dummy_input)
dim = warmup_output.shape[0]
print(f'output dim: {dim}')

# Prepare Milvus: open the connection used by every later collection call.
print('Connecting milvus ...')
connections.connect(host=host, port=port)
|||
|
|||
|
|||
def create_milvus(collection_name):
    """Create a fresh Milvus collection for the embeddings and return it.

    If a collection called ``collection_name`` already exists it is dropped
    first, so the returned collection always starts empty.

    The schema has three fields: an auto-generated INT64 primary key ``id``,
    a FLOAT_VECTOR ``embedding`` whose dimension is the module-level ``dim``,
    and a VARCHAR ``label`` (max length 500).

    Args:
        collection_name: name of the collection to (re)create.

    Returns:
        The newly created ``Collection``.
    """
    print('Creating collection ...')

    id_field = FieldSchema(
        name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True
    )
    embedding_field = FieldSchema(
        name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding', dim=dim
    )
    label_field = FieldSchema(
        name='label', dtype=DataType.VARCHAR, description='label', max_length=500
    )
    schema = CollectionSchema(
        fields=[id_field, embedding_field, label_field],
        description=f'image embeddings for {model_name} on {dataset_name}',
    )

    # Remove any stale collection left over from a previous run.
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        Collection(collection_name).drop()

    new_collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return new_collection
|||
|
|||
|
|||
# ---------------------------------------------------------------------------
# Format-specific setup.
#
# Each supported format suffixes the collection name and defines ``_embed``,
# the single pipeline step that turns the 'image' column into an 'emb' column.
# The insert and query pipelines below are otherwise identical for both
# formats, so they are defined only once.
# ---------------------------------------------------------------------------
if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'

    def _embed(dc, model_name):
        """Add an 'emb' column computed by the PyTorch timm operator."""
        return dc.image_embedding.timm['image', 'emb'](model_name=model_name, device=device)

elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    if not os.path.exists(onnx_path):
        # No exported model at the expected location: export one and use the
        # path reported by the operator.
        onnx_path = op.save_model(format='onnx')

    # The inference session has to exist before ``run_onnx`` is first called.
    # ``str()`` guards against the export returning a path-like object, and
    # the CPU provider matches the hard-coded ``device = 'cpu'``.
    sess = onnxruntime.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])

    @towhee.register
    def run_onnx(img):
        """Embed one image with the ONNX model and return a 1-D numpy vector.

        Args:
            img: an image accepted by the operator's transform ``op.tfms``.

        Returns:
            numpy array whose first dimension equals the module-level ``dim``.

        Raises:
            AssertionError: if the output size differs from ``dim``.
        """
        img = op.tfms(img).cpu().detach().numpy()
        # NOTE(review): presumably the exported model expects a batched input;
        # add the batch axis when the transform yields a single 3-D image.
        # TODO confirm against the exported model's input signature.
        if img.ndim == 3:
            img = img[None, ...]
        features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0]
        # ``features`` is a numpy array here, so pool with numpy: averaging the
        # last two axes of a 4-D output is global average pooling.
        if features.ndim == 4:
            features = features.mean(axis=(2, 3))
        outs = features.squeeze(0)
        assert outs.shape[0] == dim, 'Output dimensions are not consistent.'
        return outs

    def _embed(dc, model_name):
        """Add an 'emb' column computed by ``run_onnx`` (``model_name`` is unused)."""
        return dc.run_onnx['image', 'emb']()

else:
    raise AttributeError('Only support "pytorch" and "onnx" as format.')


def insert(model_name, collection_name):
    """Embed the insert samples, store them in Milvus, and return the entity count.

    Args:
        model_name: model name forwarded to the embedding step.
        collection_name: Milvus collection receiving the embeddings and labels.

    Returns:
        ``num_entities`` of the collection after the pipeline has run.
    """
    samples = towhee.dc['image', 'label'](zip(insert_data['image'], insert_data['label'])).stream()
    (
        _embed(samples, model_name)
        # Labels are stored in a VARCHAR field, so convert them to strings.
        .runas_op['label', 'label'](lambda y: str(y))
        .ann_insert.milvus[('emb', 'label'), 'miluvs_insert'](
            uri=f'tcp://{host}:{port}/{collection_name}'
        )
        .show(3)
    )
    collection = Collection(collection_name)
    return collection.num_entities


def query(model_name, collection_name):
    """Search Milvus with the query samples and report accuracy.

    For every query image the labels of the ``topk`` nearest stored embeddings
    are fetched; the most common label among the first 1, 5 and 10 results is
    taken as the prediction and compared with the ground truth.

    Args:
        model_name: model name forwarded to the embedding step.
        collection_name: Milvus collection to search.

    Returns:
        Whatever the pipeline's ``report()`` call returns.
    """
    samples = towhee.dc['image', 'gt'](zip(query_data['image'], query_data['label'])).stream()
    benchmark = (
        _embed(samples, model_name)
        # Stored labels are strings, so compare against string ground truth.
        .runas_op['gt', 'gt'](lambda y: str(y))
        .ann_search.milvus['emb', 'milvus_res'](
            uri=f'tcp://{host}:{port}/{collection_name}',
            metric_type=metric_type,
            limit=topk,
            output_fields=['label']
        )
        .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
        .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
        .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
        .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
        .with_metrics(['accuracy'])
        .evaluate['gt', 'pred1']('pred1')
        .evaluate['gt', 'pred5']('pred5')
        .evaluate['gt', 'pred10']('pred10')
        .report()
    )
    return benchmark
|||
|
|||
# Run the benchmark end to end: recreate the collection, fill it with the
# insert samples, then evaluate search accuracy with the query samples.
collection = create_milvus(collection_name=collection_name)
insert_count = insert(model_name=model_name, collection_name=collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name=model_name, collection_name=collection_name)
Loading…
Reference in new issue