timm
copied
Jael Gu
2 years ago
2 changed files with 105 additions and 99 deletions
@ -1,110 +1,97 @@ |
|||
from towhee import ops |
|||
from timm_image import TimmImage |
|||
import torch |
|||
import numpy |
|||
import onnx |
|||
import onnxruntime |
|||
|
|||
f = open('onnx.csv', 'a+') |
|||
f.write('model_name, run_op, save_onnx, check_onnx\n') |
|||
import os |
|||
from pathlib import Path |
|||
import logging |
|||
import platform |
|||
import psutil |
|||
|
|||
# models = TimmImage.supported_model_names()[:1] |
|||
models = [ |
|||
'vgg11', |
|||
'resnet18', |
|||
'resnetv2_50', |
|||
'seresnet33ts', |
|||
'skresnet18', |
|||
'resnext26ts', |
|||
'seresnext26d_32x4d', |
|||
'skresnext50_32x4d', |
|||
'convit_base', |
|||
'inception_v4', |
|||
'efficientnet_b0', |
|||
'tf_efficientnet_b0', |
|||
'swin_base_patch4_window7_224', |
|||
'vit_base_patch8_224', |
|||
'beit_base_patch16_224', |
|||
'convnext_base', |
|||
'crossvit_9_240', |
|||
'convmixer_768_32', |
|||
'coat_lite_mini', |
|||
'inception_v3', |
|||
'cait_m36_384', |
|||
'cspdarknet53', |
|||
'deit_base_distilled_patch16_224', |
|||
'densenet121', |
|||
'dla34', |
|||
'dm_nfnet_f0', |
|||
'nf_regnet_b1', |
|||
'nf_resnet50', |
|||
'dpn68', |
|||
'ese_vovnet19b_dw', |
|||
'fbnetc_100', |
|||
'fbnetv3_b', |
|||
'halonet26t', |
|||
'eca_halonext26ts', |
|||
'sehalonet33ts', |
|||
'hardcorenas_a', |
|||
'hrnet_w18', |
|||
'jx_nest_base', |
|||
'lcnet_050', |
|||
'levit_128', |
|||
'mixer_b16_224', |
|||
'mixnet_s', |
|||
'mnasnet_100', |
|||
'mobilenetv2_050', |
|||
'mobilenetv3_large_100', |
|||
'nasnetalarge', |
|||
'pit_b_224', |
|||
'pnasnet5large', |
|||
'regnetx_002', |
|||
'repvgg_a2', |
|||
'res2net50_14w_8s', |
|||
'res2next50', |
|||
'resmlp_12_224', |
|||
'resnest14d', |
|||
'rexnet_100', |
|||
'selecsls42b', |
|||
'semnasnet_075', |
|||
'tinynet_a', |
|||
'tnt_s_patch16_224', |
|||
'tresnet_l', |
|||
'twins_pcpvt_base', |
|||
'visformer_small', |
|||
'xception', |
|||
'xcit_large_24_p8_224', |
|||
'ghostnet_100', |
|||
'gmlp_s16_224', |
|||
'lambda_resnet26rpt_256', |
|||
'spnasnet_100', |
|||
] |
|||
models = ['resnet50'] |
|||
|
|||
decoder = ops.image_decode() |
|||
data = decoder('./towhee.jpeg') |
|||
atol = 1e-3 |
|||
log_path = 'timm_onnx.log' |
|||
f = open('onnx.csv', 'w+') |
|||
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') |
|||
|
|||
logger = logging.getLogger('timm_onnx') |
|||
logger.setLevel(logging.DEBUG) |
|||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|||
fh = logging.FileHandler(log_path) |
|||
fh.setLevel(logging.DEBUG) |
|||
fh.setFormatter(formatter) |
|||
logger.addHandler(fh) |
|||
ch = logging.StreamHandler() |
|||
ch.setLevel(logging.ERROR) |
|||
ch.setFormatter(formatter) |
|||
logger.addHandler(ch) |
|||
|
|||
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
|||
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' |
|||
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' |
|||
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') |
|||
logger.debug(f'cpu: {psutil.cpu_count()}') |
|||
|
|||
|
|||
status = None |
|||
for name in models: |
|||
f.write(f'{name},') |
|||
logger.info(f'***{name}***') |
|||
saved_name = name.replace('/', '-') |
|||
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|||
if status: |
|||
f.write(','.join(status) + '\n') |
|||
status = [name] + ['fail'] * 5 |
|||
|
|||
op = TimmImage(model_name=name, device='cpu') |
|||
data = torch.rand((1,) + op.config['input_size']) |
|||
try: |
|||
op = TimmImage(model_name=name) |
|||
out1 = op(data) |
|||
f.write('success,') |
|||
out1 = op.model(data).detach().numpy() |
|||
logger.info('OP LOADED.') |
|||
status[1] = 'success' |
|||
except Exception as e: |
|||
f.write('fail\n') |
|||
print(f'Fail to load op for {name}: {e}') |
|||
logger.error(f'FAIL TO LOAD OP: {e}') |
|||
continue |
|||
try: |
|||
op.save_model(format='onnx') |
|||
f.write('success,') |
|||
logger.info('ONNX SAVED.') |
|||
status[2] = 'success' |
|||
except Exception as e: |
|||
logger.error(f'FAIL TO SAVE ONNX: {e}') |
|||
continue |
|||
try: |
|||
try: |
|||
onnx_model = onnx.load(onnx_path) |
|||
onnx.checker.check_model(onnx_model) |
|||
except Exception: |
|||
saved_onnx = onnx.load(onnx_path, load_external_data=False) |
|||
onnx.checker.check_model(saved_onnx) |
|||
logger.info('ONNX CHECKED.') |
|||
status[3] = 'success' |
|||
except Exception as e: |
|||
f.write('fail\n') |
|||
print(f'Fail to save onnx for {name}: {e}') |
|||
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|||
continue |
|||
try: |
|||
saved_name = name.replace('/', '-') |
|||
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
|||
onnx.checker.check_model(onnx_model) |
|||
f.write('success') |
|||
sess = onnxruntime.InferenceSession(onnx_path, |
|||
providers=onnxruntime.get_available_providers()) |
|||
|
|||
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) |
|||
logger.info('ONNX WORKED.') |
|||
status[4] = 'success' |
|||
if numpy.allclose(out1, out2, atol=atol): |
|||
logger.info('Check accuracy: OK') |
|||
status[5] = 'success' |
|||
else: |
|||
logger.info(f'Check accuracy: atol is larger than {atol}.') |
|||
except Exception as e: |
|||
f.write('fail\n') |
|||
print(f'Fail to check onnx for {name}: {e}') |
|||
logger.error(f'FAIL TO RUN ONNX: {e}') |
|||
continue |
|||
f.write('\n') |
|||
print('Finished.') |
|||
|
|||
if status: |
|||
f.write(','.join(status) + '\n') |
|||
|
|||
print('Finished.') |
Loading…
Reference in new issue