timm
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
112 lines
2.6 KiB
112 lines
2.6 KiB
2 years ago
|
from towhee import ops
|
||
|
from timm_image import TimmImage
|
||
|
import torch
|
||
|
|
||
|
# Results CSV: one row per model with the status of each export stage.
# Opened in append mode so results from repeated runs accumulate; the handle
# is written to (and closed) by the model loop further down.
f = open('torchscript.csv', 'a+')
# Write the header only when the file is empty. Writing it unconditionally
# appended a duplicate header row on every re-run of the script.
f.seek(0, 2)  # whence=2: move to end of file so tell() reports its size
if f.tell() == 0:
    f.write('model_name,run_op,save_torchscript,check_result\n')
|
||
|
|
||
|
# timm model names run through the TorchScript export check below.
#
# Alternative previously used here:
#   models = TimmImage.supported_model_names()[:1]
#
# Other architectures that were listed but disabled (add a name to the list
# to include it in the check):
#   vgg11, resnetv2_50, seresnet33ts, skresnet18, resnext26ts,
#   seresnext26d_32x4d, skresnext50_32x4d, convit_base, inception_v4,
#   efficientnet_b0, swin_base_patch4_window7_224, vit_base_patch8_224,
#   beit_base_patch16_224, convnext_base, crossvit_9_240, convmixer_768_32,
#   coat_lite_mini, inception_v3, cait_m36_384, cspdarknet53,
#   deit_base_distilled_patch16_224, densenet121, dla34, dm_nfnet_f0,
#   nf_regnet_b1, nf_resnet50, dpn68, ese_vovnet19b_dw, fbnetc_100, fbnetv3_b,
#   halonet26t, eca_halonext26ts, sehalonet33ts, hardcorenas_a, hrnet_w18,
#   jx_nest_base, lcnet_050, levit_128, mixer_b16_224, mixnet_s, mnasnet_100,
#   mobilenetv2_050, mobilenetv3_large_100, nasnetalarge, pit_b_224,
#   pnasnet5large, regnetx_002, repvgg_a2, res2net50_14w_8s, res2next50,
#   resmlp_12_224, resnest14d, rexnet_100, selecsls42b, semnasnet_075,
#   tinynet_a, tnt_s_patch16_224, tresnet_l, twins_pcpvt_base,
#   visformer_small, xception, xcit_large_24_p8_224, ghostnet_100,
#   gmlp_s16_224, lambda_resnet26rpt_256, spnasnet_100
models = [
    'resnet18',
    'tf_efficientnet_b0',
]
|
||
|
|
||
|
# Decode the sample image once; the same decoded input is fed to every model.
decoder = ops.image_decode()
data = decoder('./towhee.jpeg')

# Per-model stages, in the column order of the CSV header after model_name.
_STAGES = ('run_op', 'save_torchscript', 'check_result')


def _check_model(name, image):
    """Run one timm model through the TorchScript export check.

    The stages run in order:
      1. Build the ``TimmImage`` op and run it on ``image``.
      2. Save the op's model in TorchScript format.
      3. Reload the saved module and compare its output with step 1's.
    A failing stage prints the error and skips the remaining stages.

    Args:
        name: timm model name passed to ``TimmImage(model_name=...)``.
        image: decoded input image fed to the op.

    Returns:
        A list of ``'success'`` / ``'fail'`` strings, one per stage attempted.
        It is shorter than ``_STAGES`` when a stage failed.
    """
    results = []

    try:
        op = TimmImage(model_name=name)
        out1 = op(image)
    except Exception as e:
        print(f'Fail to load op for {name}: {e}')
        results.append('fail')
        return results
    results.append('success')

    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        results.append('fail')
        return results
    results.append('success')

    try:
        # NOTE(review): assumes save_model writes to
        # saved/torchscript/<name>.pt with '/' replaced by '-' -- confirm
        # against TimmImage.save_model.
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(image)
        # Explicit check rather than `assert`, which `python -O` strips and
        # would then report every model as matching.
        if not (out1 == out2).all():
            raise AssertionError('torchscript output differs from original output')
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        results.append('fail')
        return results
    results.append('success')

    return results


# `with f` closes the results CSV (opened earlier in this script) once every
# model has been recorded, even if the loop is interrupted.
with f:
    for name in models:
        statuses = _check_model(name, data)
        # Pad skipped stages with empty fields so every row has one field per
        # header column and ends with a newline. Previously a failed stage left
        # the row unterminated, merging it with the next model's row.
        statuses += [''] * (len(_STAGES) - len(statuses))
        f.write(','.join([name, *statuses]) + '\n')

print('Finished.')
|