timm
copied
Jael Gu
2 years ago
2 changed files with 105 additions and 99 deletions
@ -1,110 +1,97 @@ |
|||||
from towhee import ops |
from towhee import ops |
||||
from timm_image import TimmImage |
from timm_image import TimmImage |
||||
|
import torch |
||||
|
import numpy |
||||
import onnx |
import onnx |
||||
|
import onnxruntime |
||||
|
|
||||
f = open('onnx.csv', 'a+') |
|
||||
f.write('model_name, run_op, save_onnx, check_onnx\n') |
|
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
import logging |
||||
|
import platform |
||||
|
import psutil |
||||
|
|
||||
# models = TimmImage.supported_model_names()[:1] |
# models = TimmImage.supported_model_names()[:1] |
||||
models = [ |
|
||||
'vgg11', |
|
||||
'resnet18', |
|
||||
'resnetv2_50', |
|
||||
'seresnet33ts', |
|
||||
'skresnet18', |
|
||||
'resnext26ts', |
|
||||
'seresnext26d_32x4d', |
|
||||
'skresnext50_32x4d', |
|
||||
'convit_base', |
|
||||
'inception_v4', |
|
||||
'efficientnet_b0', |
|
||||
'tf_efficientnet_b0', |
|
||||
'swin_base_patch4_window7_224', |
|
||||
'vit_base_patch8_224', |
|
||||
'beit_base_patch16_224', |
|
||||
'convnext_base', |
|
||||
'crossvit_9_240', |
|
||||
'convmixer_768_32', |
|
||||
'coat_lite_mini', |
|
||||
'inception_v3', |
|
||||
'cait_m36_384', |
|
||||
'cspdarknet53', |
|
||||
'deit_base_distilled_patch16_224', |
|
||||
'densenet121', |
|
||||
'dla34', |
|
||||
'dm_nfnet_f0', |
|
||||
'nf_regnet_b1', |
|
||||
'nf_resnet50', |
|
||||
'dpn68', |
|
||||
'ese_vovnet19b_dw', |
|
||||
'fbnetc_100', |
|
||||
'fbnetv3_b', |
|
||||
'halonet26t', |
|
||||
'eca_halonext26ts', |
|
||||
'sehalonet33ts', |
|
||||
'hardcorenas_a', |
|
||||
'hrnet_w18', |
|
||||
'jx_nest_base', |
|
||||
'lcnet_050', |
|
||||
'levit_128', |
|
||||
'mixer_b16_224', |
|
||||
'mixnet_s', |
|
||||
'mnasnet_100', |
|
||||
'mobilenetv2_050', |
|
||||
'mobilenetv3_large_100', |
|
||||
'nasnetalarge', |
|
||||
'pit_b_224', |
|
||||
'pnasnet5large', |
|
||||
'regnetx_002', |
|
||||
'repvgg_a2', |
|
||||
'res2net50_14w_8s', |
|
||||
'res2next50', |
|
||||
'resmlp_12_224', |
|
||||
'resnest14d', |
|
||||
'rexnet_100', |
|
||||
'selecsls42b', |
|
||||
'semnasnet_075', |
|
||||
'tinynet_a', |
|
||||
'tnt_s_patch16_224', |
|
||||
'tresnet_l', |
|
||||
'twins_pcpvt_base', |
|
||||
'visformer_small', |
|
||||
'xception', |
|
||||
'xcit_large_24_p8_224', |
|
||||
'ghostnet_100', |
|
||||
'gmlp_s16_224', |
|
||||
'lambda_resnet26rpt_256', |
|
||||
'spnasnet_100', |
|
||||
] |
|
||||
|
models = ['resnet50'] |
||||
|
|
||||
decoder = ops.image_decode() |
|
||||
data = decoder('./towhee.jpeg') |
|
||||
|
atol = 1e-3 |
||||
|
log_path = 'timm_onnx.log' |
||||
|
f = open('onnx.csv', 'w+') |
||||
|
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') |
||||
|
|
||||
|
logger = logging.getLogger('timm_onnx') |
||||
|
logger.setLevel(logging.DEBUG) |
||||
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
||||
|
fh = logging.FileHandler(log_path) |
||||
|
fh.setLevel(logging.DEBUG) |
||||
|
fh.setFormatter(formatter) |
||||
|
logger.addHandler(fh) |
||||
|
ch = logging.StreamHandler() |
||||
|
ch.setLevel(logging.ERROR) |
||||
|
ch.setFormatter(formatter) |
||||
|
logger.addHandler(ch) |
||||
|
|
||||
|
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') |
||||
|
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' |
||||
|
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' |
||||
|
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') |
||||
|
logger.debug(f'cpu: {psutil.cpu_count()}') |
||||
|
|
||||
|
|
||||
|
status = None |
||||
for name in models: |
for name in models: |
||||
f.write(f'{name},') |
|
||||
|
logger.info(f'***{name}***') |
||||
|
saved_name = name.replace('/', '-') |
||||
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
||||
|
if status: |
||||
|
f.write(','.join(status) + '\n') |
||||
|
status = [name] + ['fail'] * 5 |
||||
|
|
||||
|
op = TimmImage(model_name=name, device='cpu') |
||||
|
data = torch.rand((1,) + op.config['input_size']) |
||||
try: |
try: |
||||
op = TimmImage(model_name=name) |
|
||||
out1 = op(data) |
|
||||
f.write('success,') |
|
||||
|
out1 = op.model(data).detach().numpy() |
||||
|
logger.info('OP LOADED.') |
||||
|
status[1] = 'success' |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail\n') |
|
||||
print(f'Fail to load op for {name}: {e}') |
|
||||
|
logger.error(f'FAIL TO LOAD OP: {e}') |
||||
continue |
continue |
||||
try: |
try: |
||||
op.save_model(format='onnx') |
op.save_model(format='onnx') |
||||
f.write('success,') |
|
||||
|
logger.info('ONNX SAVED.') |
||||
|
status[2] = 'success' |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail\n') |
|
||||
print(f'Fail to save onnx for {name}: {e}') |
|
||||
|
logger.error(f'FAIL TO SAVE ONNX: {e}') |
||||
continue |
continue |
||||
try: |
try: |
||||
saved_name = name.replace('/', '-') |
|
||||
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
|
||||
|
try: |
||||
|
onnx_model = onnx.load(onnx_path) |
||||
onnx.checker.check_model(onnx_model) |
onnx.checker.check_model(onnx_model) |
||||
f.write('success') |
|
||||
|
except Exception: |
||||
|
saved_onnx = onnx.load(onnx_path, load_external_data=False) |
||||
|
onnx.checker.check_model(saved_onnx) |
||||
|
logger.info('ONNX CHECKED.') |
||||
|
status[3] = 'success' |
||||
except Exception as e: |
except Exception as e: |
||||
f.write('fail\n') |
|
||||
print(f'Fail to check onnx for {name}: {e}') |
|
||||
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
||||
continue |
continue |
||||
f.write('\n') |
|
||||
|
try: |
||||
|
sess = onnxruntime.InferenceSession(onnx_path, |
||||
|
providers=onnxruntime.get_available_providers()) |
||||
|
|
||||
|
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) |
||||
|
logger.info('ONNX WORKED.') |
||||
|
status[4] = 'success' |
||||
|
if numpy.allclose(out1, out2, atol=atol): |
||||
|
logger.info('Check accuracy: OK') |
||||
|
status[5] = 'success' |
||||
|
else: |
||||
|
logger.info(f'Check accuracy: atol is larger than {atol}.') |
||||
|
except Exception as e: |
||||
|
logger.error(f'FAIL TO RUN ONNX: {e}') |
||||
|
continue |
||||
|
|
||||
|
if status: |
||||
|
f.write(','.join(status) + '\n') |
||||
|
|
||||
print('Finished.') |
print('Finished.') |
Loading…
Reference in new issue