"""Check which timm models can be exported to TorchScript via the TimmImage op.

For each model name in ``models`` the script appends one row to
``torchscript.csv`` with the outcome of three consecutive stages:

* ``run_op``           - the TimmImage op can be built and run on a sample image;
* ``save_torchscript`` - ``op.save_model(format='torchscript')`` succeeds;
* ``check_result``     - the model reloaded with ``torch.jit.load`` from
  ``saved/torchscript/<name>.pt`` produces output equal to the eager op.

A stage is only attempted when the previous one succeeded. Skipped stages are
left empty so that every row has the same number of columns as the header.
"""
import csv

import torch
from towhee import ops

from timm_image import TimmImage

CSV_PATH = 'torchscript.csv'
CSV_HEADER = ['model_name', 'run_op', 'save_torchscript', 'check_result']

# To sweep the op's full model registry instead of the curated list below:
# models = TimmImage.supported_model_names()[:1]
models = [
    'vgg11',
    'resnet18',
    'resnetv2_50',
    'seresnet33ts',
    'skresnet18',
    'resnext26ts',
    'seresnext26d_32x4d',
    'skresnext50_32x4d',
    'convit_base',
    'inception_v4',
    'efficientnet_b0',
    'tf_efficientnet_b0',
    'swin_base_patch4_window7_224',
    'vit_base_patch8_224',
    'beit_base_patch16_224',
    'convnext_base',
    'crossvit_9_240',
    'convmixer_768_32',
    'coat_lite_mini',
    'inception_v3',
    'cait_m36_384',
    'cspdarknet53',
    'deit_base_distilled_patch16_224',
    'densenet121',
    'dla34',
    'dm_nfnet_f0',
    'nf_regnet_b1',
    'nf_resnet50',
    'dpn68',
    'ese_vovnet19b_dw',
    'fbnetc_100',
    'fbnetv3_b',
    'halonet26t',
    'eca_halonext26ts',
    'sehalonet33ts',
    'hardcorenas_a',
    'hrnet_w18',
    'jx_nest_base',
    'lcnet_050',
    'levit_128',
    'mixer_b16_224',
    'mixnet_s',
    'mnasnet_100',
    'mobilenetv2_050',
    'mobilenetv3_large_100',
    'nasnetalarge',
    'pit_b_224',
    'pnasnet5large',
    'regnetx_002',
    'repvgg_a2',
    'res2net50_14w_8s',
    'res2next50',
    'resmlp_12_224',
    'resnest14d',
    'rexnet_100',
    'selecsls42b',
    'semnasnet_075',
    'tinynet_a',
    'tnt_s_patch16_224',
    'tresnet_l',
    'twins_pcpvt_base',
    'visformer_small',
    'xception',
    'xcit_large_24_p8_224',
    'ghostnet_100',
    'gmlp_s16_224',
    'lambda_resnet26rpt_256',
    'spnasnet_100',
]


def _check_model(name, data):
    """Run the load / save / reload-and-compare stages for one model.

    Args:
        name: timm model name passed to ``TimmImage``.
        data: decoded sample image fed to the op.

    Returns:
        A list of one to three ``'success'``/``'fail'`` strings, one per stage
        attempted. Evaluation stops after the first failing stage.
    """
    statuses = []

    # Broad excepts are deliberate: any failure of a model is recorded as a
    # data point rather than aborting the whole sweep.
    try:
        op = TimmImage(model_name=name)
        out1 = op(data)
    except Exception as e:
        print(f'Fail to load op for {name}: {e}')
        return statuses + ['fail']
    statuses.append('success')

    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        return statuses + ['fail']
    statuses.append('success')

    try:
        # save_model stores names containing '/' with '-' instead.
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(data)
        # Exact element-wise equality, as in the original check.
        # NOTE(review): for float outputs, tiny numeric drift counts as a
        # failure; consider an allclose-style comparison if that is too strict.
        same = bool((out1 == out2).all())
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        return statuses + ['fail']
    if not same:
        print(f'Fail to check torchscript for {name}: outputs differ')
        return statuses + ['fail']
    statuses.append('success')
    return statuses


decoder = ops.image_decode()
data = decoder('./towhee.jpeg')

with open(CSV_PATH, 'a', newline='') as f:
    writer = csv.writer(f, lineterminator='\n')
    # Append mode positions the stream at end of file, so tell() == 0 means the
    # file is new/empty; only then write the header (avoids duplicates on re-runs).
    if f.tell() == 0:
        writer.writerow(CSV_HEADER)
    n_stages = len(CSV_HEADER) - 1
    for name in models:
        statuses = _check_model(name, data)
        # Pad skipped stages with empty cells so every row matches the header.
        statuses += [''] * (n_stages - len(statuses))
        writer.writerow([name, *statuses])
        # Flush per row so results survive if a later model crashes the process.
        f.flush()

print('Finished.')