logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

111 lines
2.6 KiB

from towhee import ops
from timm_image import TimmImage
import torch
# Results log for the TorchScript export check. It is opened in append mode,
# so rows written by earlier runs are kept.
f = open('torchscript.csv', 'a+')
# Move to the end of the file so that tell() reports its current size.
f.seek(0, 2)
if f.tell() == 0:
    # Only an empty (new) file gets the header row. Writing it unconditionally
    # would add a duplicate header in the middle of the data on every re-run.
    f.write('model_name,run_op,save_torchscript,check_result\n')
# Model names to run through the TorchScript export check below.
# Entries that are commented out are candidates that are currently disabled;
# uncomment a line to include that model in the run.
# Alternative selection (first entry of the operator's own list):
#     models = TimmImage.supported_model_names()[:1]
models = [
    # 'vgg11',
    'resnet18',
    # 'resnetv2_50',
    # 'seresnet33ts',
    # 'skresnet18',
    # 'resnext26ts',
    # 'seresnext26d_32x4d',
    # 'skresnext50_32x4d',
    # 'convit_base',
    # 'inception_v4',
    # 'efficientnet_b0',
    'tf_efficientnet_b0',
    # 'swin_base_patch4_window7_224',
    # 'vit_base_patch8_224',
    # 'beit_base_patch16_224',
    # 'convnext_base',
    # 'crossvit_9_240',
    # 'convmixer_768_32',
    # 'coat_lite_mini',
    # 'inception_v3',
    # 'cait_m36_384',
    # 'cspdarknet53',
    # 'deit_base_distilled_patch16_224',
    # 'densenet121',
    # 'dla34',
    # 'dm_nfnet_f0',
    # 'nf_regnet_b1',
    # 'nf_resnet50',
    # 'dpn68',
    # 'ese_vovnet19b_dw',
    # 'fbnetc_100',
    # 'fbnetv3_b',
    # 'halonet26t',
    # 'eca_halonext26ts',
    # 'sehalonet33ts',
    # 'hardcorenas_a',
    # 'hrnet_w18',
    # 'jx_nest_base',
    # 'lcnet_050',
    # 'levit_128',
    # 'mixer_b16_224',
    # 'mixnet_s',
    # 'mnasnet_100',
    # 'mobilenetv2_050',
    # 'mobilenetv3_large_100',
    # 'nasnetalarge',
    # 'pit_b_224',
    # 'pnasnet5large',
    # 'regnetx_002',
    # 'repvgg_a2',
    # 'res2net50_14w_8s',
    # 'res2next50',
    # 'resmlp_12_224',
    # 'resnest14d',
    # 'rexnet_100',
    # 'selecsls42b',
    # 'semnasnet_075',
    # 'tinynet_a',
    # 'tnt_s_patch16_224',
    # 'tresnet_l',
    # 'twins_pcpvt_base',
    # 'visformer_small',
    # 'xception',
    # 'xcit_large_24_p8_224',
    # 'ghostnet_100',
    # 'gmlp_s16_224',
    # 'lambda_resnet26rpt_256',
    # 'spnasnet_100',
]
def _export_and_check(name, image):
    """Run, export and re-check one model; return its CSV status fields.

    The three stages (run the operator, save it as TorchScript, reload and
    compare outputs) are attempted in order. The first failing stage is
    reported on stdout and stops the sequence.

    Args:
        name: Model name passed to ``TimmImage(model_name=...)``.
        image: Decoded input handed to the operator.

    Returns:
        A list with one ``'success'`` per completed stage, followed by a
        single ``'fail'`` if a stage raised.
    """
    statuses = []

    # Stage 1: build the operator and compute the reference output.
    # Catching Exception is deliberate here and below: this is a batch
    # report, so one broken model must not abort the remaining ones.
    try:
        op = TimmImage(model_name=name)
        out1 = op(image)
    except Exception as e:
        print(f'Fail to load op for {name}: {e}')
        return statuses + ['fail']
    statuses.append('success')

    # Stage 2: export the model in TorchScript format.
    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        return statuses + ['fail']
    statuses.append('success')

    # Stage 3: swap in the exported model and compare against stage 1.
    try:
        # NOTE(review): presumably save_model writes to this path with '/'
        # replaced by '-' in the file name; confirm against TimmImage.
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(image)
        # Raise explicitly instead of using `assert`, which is removed under
        # `python -O` and would make this check silently pass.
        if not (out1 == out2).all():
            raise ValueError('output of the reloaded model differs')
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        return statuses + ['fail']
    statuses.append('success')
    return statuses


try:
    decoder = ops.image_decode()
    data = decoder('./towhee.jpeg')
    for name in models:
        # Write the name first so the model being processed is visible in
        # the log while its stages are still running.
        f.write(f'{name},')
        statuses = _export_and_check(name, data)
        # Terminate the row on every path. A missing newline after a failure
        # would glue the next model's row onto the failed one.
        f.write(','.join(statuses) + '\n')
        # Flush per row so completed results are on disk if a later model
        # crashes the process.
        f.flush()
finally:
    # The results file is not closed anywhere else in the visible script.
    f.close()
print('Finished.')