logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

111 lines
2.5 KiB

from towhee import ops
from timm_image import TimmImage
import torch
# Smoke test: for each timm model, (1) build and run the towhee TimmImage op,
# (2) export it to TorchScript, and (3) reload the exported file and compare
# its output with the original run. One CSV row per model is appended to
# torchscript.csv. Each row holds the model name followed by 'success'/'fail'
# for each attempted step; the row stops at the first failing step.

# Alternative model list: models = TimmImage.supported_model_names()[:1]
models = [
    'vgg11',
    'resnet18',
    'resnetv2_50',
    'seresnet33ts',
    'skresnet18',
    'resnext26ts',
    'seresnext26d_32x4d',
    'skresnext50_32x4d',
    'convit_base',
    'inception_v4',
    'efficientnet_b0',
    'tf_efficientnet_b0',
    'swin_base_patch4_window7_224',
    'vit_base_patch8_224',
    'beit_base_patch16_224',
    'convnext_base',
    'crossvit_9_240',
    'convmixer_768_32',
    'coat_lite_mini',
    'inception_v3',
    'cait_m36_384',
    'cspdarknet53',
    'deit_base_distilled_patch16_224',
    'densenet121',
    'dla34',
    'dm_nfnet_f0',
    'nf_regnet_b1',
    'nf_resnet50',
    'dpn68',
    'ese_vovnet19b_dw',
    'fbnetc_100',
    'fbnetv3_b',
    'halonet26t',
    'eca_halonext26ts',
    'sehalonet33ts',
    'hardcorenas_a',
    'hrnet_w18',
    'jx_nest_base',
    'lcnet_050',
    'levit_128',
    'mixer_b16_224',
    'mixnet_s',
    'mnasnet_100',
    'mobilenetv2_050',
    'mobilenetv3_large_100',
    'nasnetalarge',
    'pit_b_224',
    'pnasnet5large',
    'regnetx_002',
    'repvgg_a2',
    'res2net50_14w_8s',
    'res2next50',
    'resmlp_12_224',
    'resnest14d',
    'rexnet_100',
    'selecsls42b',
    'semnasnet_075',
    'tinynet_a',
    'tnt_s_patch16_224',
    'tresnet_l',
    'twins_pcpvt_base',
    'visformer_small',
    'xception',
    'xcit_large_24_p8_224',
    'ghostnet_100',
    'gmlp_s16_224',
    'lambda_resnet26rpt_256',
    'spnasnet_100',
]


def _check_model(name, image):
    """Run the TorchScript round-trip check for a single model.

    Args:
        name: timm model name passed to ``TimmImage(model_name=...)``.
        image: decoded image fed to the op.

    Returns:
        A list of 'success'/'fail' strings, one per attempted step
        (run_op, save_torchscript, check_result). The list ends at the
        first failing step, so it has 1 to 3 entries.
    """
    # Step 1: build the op and run it once to get a reference output.
    try:
        op = TimmImage(model_name=name)
        out1 = op(image)
    except Exception as e:  # harness boundary: record and move on
        print(f'Fail to load op for {name}: {e}')
        return ['fail']

    # Step 2: export the model to TorchScript.
    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        return ['success', 'fail']

    # Step 3: reload the exported model and compare outputs.
    try:
        # NOTE(review): this path mirrors the naming that save_model is
        # presumed to use ('/' -> '-'); confirm against TimmImage.save_model.
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(image)
        # Explicit check instead of `assert`, which is stripped under -O and
        # would otherwise report every model as a success.
        if not (out1 == out2).all():
            raise ValueError('outputs of the original and reloaded model differ')
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        return ['success', 'success', 'fail']

    return ['success', 'success', 'success']


decoder = ops.image_decode()
data = decoder('./towhee.jpeg')

# `with` guarantees the CSV is closed even if a model run crashes the script.
with open('torchscript.csv', 'a+') as f:
    # Position at the end (2 == io.SEEK_END) so tell() reveals whether the
    # file is new/empty; write the header only then, instead of appending a
    # duplicate header on every run.
    f.seek(0, 2)
    if f.tell() == 0:
        f.write('model_name,run_op,save_torchscript,check_result\n')
    for name in models:
        statuses = _check_model(name, data)
        f.write(','.join([name, *statuses]) + '\n')
        # Persist each row immediately: model runs are slow and a later
        # crash should not lose completed results.
        f.flush()
print('Finished.')