timm
copied
Jael Gu
2 years ago
4 changed files with 271 additions and 24 deletions
@ -0,0 +1,110 @@ |
|||||
|
from towhee import ops |
||||
|
from timm_image import TimmImage |
||||
|
import onnx |
||||
|
|
||||
|
f = open('onnx.csv', 'a+') |
||||
|
f.write('model_name, run_op, save_onnx, check_onnx\n') |
||||
|
|
||||
|
# models = TimmImage.supported_model_names()[:1] |
||||
|
models = [ |
||||
|
# 'vgg11', |
||||
|
'resnet18', |
||||
|
# 'resnetv2_50', |
||||
|
# 'seresnet33ts', |
||||
|
# 'skresnet18', |
||||
|
# 'resnext26ts', |
||||
|
# 'seresnext26d_32x4d', |
||||
|
# 'skresnext50_32x4d', |
||||
|
# 'convit_base', |
||||
|
# 'inception_v4', |
||||
|
# 'efficientnet_b0', |
||||
|
'tf_efficientnet_b0', |
||||
|
# 'swin_base_patch4_window7_224', |
||||
|
# 'vit_base_patch8_224', |
||||
|
# 'beit_base_patch16_224', |
||||
|
# 'convnext_base', |
||||
|
# 'crossvit_9_240', |
||||
|
# 'convmixer_768_32', |
||||
|
# 'coat_lite_mini', |
||||
|
# 'inception_v3', |
||||
|
# 'cait_m36_384', |
||||
|
# 'cspdarknet53', |
||||
|
# 'deit_base_distilled_patch16_224', |
||||
|
# 'densenet121', |
||||
|
# 'dla34', |
||||
|
# 'dm_nfnet_f0', |
||||
|
# 'nf_regnet_b1', |
||||
|
# 'nf_resnet50', |
||||
|
# 'dpn68', |
||||
|
# 'ese_vovnet19b_dw', |
||||
|
# 'fbnetc_100', |
||||
|
# 'fbnetv3_b', |
||||
|
# 'halonet26t', |
||||
|
# 'eca_halonext26ts', |
||||
|
# 'sehalonet33ts', |
||||
|
# 'hardcorenas_a', |
||||
|
# 'hrnet_w18', |
||||
|
# 'jx_nest_base', |
||||
|
# 'lcnet_050', |
||||
|
# 'levit_128', |
||||
|
# 'mixer_b16_224', |
||||
|
# 'mixnet_s', |
||||
|
# 'mnasnet_100', |
||||
|
# 'mobilenetv2_050', |
||||
|
# 'mobilenetv3_large_100', |
||||
|
# 'nasnetalarge', |
||||
|
# 'pit_b_224', |
||||
|
# 'pnasnet5large', |
||||
|
# 'regnetx_002', |
||||
|
# 'repvgg_a2', |
||||
|
# 'res2net50_14w_8s', |
||||
|
# 'res2next50', |
||||
|
# 'resmlp_12_224', |
||||
|
# 'resnest14d', |
||||
|
# 'rexnet_100', |
||||
|
# 'selecsls42b', |
||||
|
# 'semnasnet_075', |
||||
|
# 'tinynet_a', |
||||
|
# 'tnt_s_patch16_224', |
||||
|
# 'tresnet_l', |
||||
|
# 'twins_pcpvt_base', |
||||
|
# 'visformer_small', |
||||
|
# 'xception', |
||||
|
# 'xcit_large_24_p8_224', |
||||
|
# 'ghostnet_100', |
||||
|
# 'gmlp_s16_224', |
||||
|
# 'lambda_resnet26rpt_256', |
||||
|
# 'spnasnet_100', |
||||
|
] |
||||
|
|
||||
|
decoder = ops.image_decode() |
||||
|
data = decoder('./towhee.jpeg') |
||||
|
|
||||
|
for name in models: |
||||
|
f.write(f'{name},') |
||||
|
try: |
||||
|
op = TimmImage(model_name=name) |
||||
|
out1 = op(data) |
||||
|
f.write('success,') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to load op for {name}: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
op.save_model(format='onnx') |
||||
|
f.write('success,') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to save onnx for {name}: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
saved_name = name.replace('/', '-') |
||||
|
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
||||
|
onnx.checker.check_model(onnx_model) |
||||
|
f.write('success') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to check onnx for {name}: {e}') |
||||
|
continue |
||||
|
f.write('\n') |
||||
|
print('Finished.') |
@ -0,0 +1,111 @@ |
|||||
|
from towhee import ops |
||||
|
from timm_image import TimmImage |
||||
|
import torch |
||||
|
|
||||
|
f = open('torchscript.csv', 'a+') |
||||
|
f.write('model_name,run_op,save_torchscript,check_result\n') |
||||
|
|
||||
|
# models = TimmImage.supported_model_names()[:1] |
||||
|
models = [ |
||||
|
# 'vgg11', |
||||
|
'resnet18', |
||||
|
# 'resnetv2_50', |
||||
|
# 'seresnet33ts', |
||||
|
# 'skresnet18', |
||||
|
# 'resnext26ts', |
||||
|
# 'seresnext26d_32x4d', |
||||
|
# 'skresnext50_32x4d', |
||||
|
# 'convit_base', |
||||
|
# 'inception_v4', |
||||
|
# 'efficientnet_b0', |
||||
|
'tf_efficientnet_b0', |
||||
|
# 'swin_base_patch4_window7_224', |
||||
|
# 'vit_base_patch8_224', |
||||
|
# 'beit_base_patch16_224', |
||||
|
# 'convnext_base', |
||||
|
# 'crossvit_9_240', |
||||
|
# 'convmixer_768_32', |
||||
|
# 'coat_lite_mini', |
||||
|
# 'inception_v3', |
||||
|
# 'cait_m36_384', |
||||
|
# 'cspdarknet53', |
||||
|
# 'deit_base_distilled_patch16_224', |
||||
|
# 'densenet121', |
||||
|
# 'dla34', |
||||
|
# 'dm_nfnet_f0', |
||||
|
# 'nf_regnet_b1', |
||||
|
# 'nf_resnet50', |
||||
|
# 'dpn68', |
||||
|
# 'ese_vovnet19b_dw', |
||||
|
# 'fbnetc_100', |
||||
|
# 'fbnetv3_b', |
||||
|
# 'halonet26t', |
||||
|
# 'eca_halonext26ts', |
||||
|
# 'sehalonet33ts', |
||||
|
# 'hardcorenas_a', |
||||
|
# 'hrnet_w18', |
||||
|
# 'jx_nest_base', |
||||
|
# 'lcnet_050', |
||||
|
# 'levit_128', |
||||
|
# 'mixer_b16_224', |
||||
|
# 'mixnet_s', |
||||
|
# 'mnasnet_100', |
||||
|
# 'mobilenetv2_050', |
||||
|
# 'mobilenetv3_large_100', |
||||
|
# 'nasnetalarge', |
||||
|
# 'pit_b_224', |
||||
|
# 'pnasnet5large', |
||||
|
# 'regnetx_002', |
||||
|
# 'repvgg_a2', |
||||
|
# 'res2net50_14w_8s', |
||||
|
# 'res2next50', |
||||
|
# 'resmlp_12_224', |
||||
|
# 'resnest14d', |
||||
|
# 'rexnet_100', |
||||
|
# 'selecsls42b', |
||||
|
# 'semnasnet_075', |
||||
|
# 'tinynet_a', |
||||
|
# 'tnt_s_patch16_224', |
||||
|
# 'tresnet_l', |
||||
|
# 'twins_pcpvt_base', |
||||
|
# 'visformer_small', |
||||
|
# 'xception', |
||||
|
# 'xcit_large_24_p8_224', |
||||
|
# 'ghostnet_100', |
||||
|
# 'gmlp_s16_224', |
||||
|
# 'lambda_resnet26rpt_256', |
||||
|
# 'spnasnet_100', |
||||
|
] |
||||
|
|
||||
|
decoder = ops.image_decode() |
||||
|
data = decoder('./towhee.jpeg') |
||||
|
|
||||
|
for name in models: |
||||
|
f.write(f'{name},') |
||||
|
try: |
||||
|
op = TimmImage(model_name=name) |
||||
|
out1 = op(data) |
||||
|
f.write('success,') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to load op for {name}: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
op.save_model(format='torchscript') |
||||
|
f.write('success,') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to save onnx for {name}: {e}') |
||||
|
continue |
||||
|
try: |
||||
|
saved_name = name.replace('/', '-') |
||||
|
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |
||||
|
out2 = op(data) |
||||
|
assert (out1 == out2).all() |
||||
|
f.write('success') |
||||
|
except Exception as e: |
||||
|
f.write('fail') |
||||
|
print(f'Fail to check onnx for {name}: {e}') |
||||
|
continue |
||||
|
f.write('\n') |
||||
|
print('Finished.') |
After Width: | Height: | Size: 49 KiB |
Loading…
Reference in new issue