timm
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
111 lines
2.5 KiB
111 lines
2.5 KiB
from towhee import ops
|
|
from timm_image import TimmImage
|
|
import torch
|
|
|
|
# Results file for the TorchScript sweep. It is opened in append mode so that
# results from repeated runs accumulate in one CSV. The header row is written
# only when the file is still empty; otherwise every re-run would insert a
# duplicate header in the middle of the data. The handle stays open because
# the per-model loop appends one row per model to it.
f = open('torchscript.csv', 'a+')
if f.tell() == 0:  # append mode positions the stream at end-of-file
    f.write('model_name,run_op,save_torchscript,check_result\n')
|
|
|
|
# Fixed set of timm model names to run through the TorchScript round trip,
# kept in the order they are checked (and hence the order rows are written).
# For a quick smoke test, TimmImage.supported_model_names()[:1] can be used
# instead; it presumably yields just the first model the op supports.
models = '''
    vgg11 resnet18 resnetv2_50 seresnet33ts skresnet18
    resnext26ts seresnext26d_32x4d skresnext50_32x4d convit_base inception_v4
    efficientnet_b0 tf_efficientnet_b0 swin_base_patch4_window7_224 vit_base_patch8_224
    beit_base_patch16_224 convnext_base crossvit_9_240 convmixer_768_32 coat_lite_mini
    inception_v3 cait_m36_384 cspdarknet53 deit_base_distilled_patch16_224 densenet121
    dla34 dm_nfnet_f0 nf_regnet_b1 nf_resnet50 dpn68
    ese_vovnet19b_dw fbnetc_100 fbnetv3_b halonet26t eca_halonext26ts
    sehalonet33ts hardcorenas_a hrnet_w18 jx_nest_base lcnet_050
    levit_128 mixer_b16_224 mixnet_s mnasnet_100 mobilenetv2_050
    mobilenetv3_large_100 nasnetalarge pit_b_224 pnasnet5large regnetx_002
    repvgg_a2 res2net50_14w_8s res2next50 resmlp_12_224 resnest14d
    rexnet_100 selecsls42b semnasnet_075 tinynet_a tnt_s_patch16_224
    tresnet_l twins_pcpvt_base visformer_small xception xcit_large_24_p8_224
    ghostnet_100 gmlp_s16_224 lambda_resnet26rpt_256 spnasnet_100
'''.split()
|
|
|
|
# Decode the sample image once with towhee's image_decode operator; the same
# decoded `data` is fed to every model in the sweep (its exact type is whatever
# the towhee decoder returns -- presumably an image array; verify if needed).
_SAMPLE_IMAGE = './towhee.jpeg'
decoder = ops.image_decode()
data = decoder(_SAMPLE_IMAGE)
|
|
|
|
def _check_torchscript(name, image):
    """Run the eager -> TorchScript -> reload round trip for one timm model.

    Stages, in order:
      1. build ``TimmImage(model_name=name)`` and run it on *image*;
      2. export it with ``op.save_model(format='torchscript')``;
      3. load ``saved/torchscript/<name with '/' replaced by '-'>.pt`` back
         into ``op.model``, rerun it on *image*, and require the output to be
         element-wise equal to the stage-1 output.

    Args:
        name: timm model name passed to ``TimmImage``.
        image: decoded input fed to the op in stages 1 and 3.

    Returns:
        The status column values for this model's CSV row, stopping at the
        first failing stage, e.g. ``['success', 'fail']`` when the export
        fails or ``['success', 'success', 'success']`` when everything passes.
        Any exception raised inside a stage is printed and marks that stage
        as failed; nothing is re-raised.
    """
    try:
        op = TimmImage(model_name=name)
        out1 = op(image)
    except Exception as e:
        print(f'Fail to load op for {name}: {e}')
        return ['fail']

    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        return ['success', 'fail']

    try:
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(image)
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently report every model as a success.
        # Assumes the op returns array-like outputs (numpy array or tensor)
        # supporting element-wise `==` and `.all()` -- confirm against TimmImage.
        if not (out1 == out2).all():
            raise AssertionError('TorchScript output differs from eager output')
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        return ['success', 'success', 'fail']

    return ['success', 'success', 'success']


# Check every model and append one CSV row per model to the results file `f`,
# e.g. "vgg11,success,success,success" or "dpn68,success,fail".
# Each row is written in one call and flushed right away, so rows for models
# already checked survive even if a later model crashes the process; the
# `with` statement closes the file once the sweep finishes.
with f:
    for name in models:
        statuses = _check_torchscript(name, data)
        f.write(','.join([name, *statuses]) + '\n')
        f.flush()

print('Finished.')
|