diff --git a/test_onnx.py b/test_onnx.py new file mode 100644 index 0000000..a1a5684 --- /dev/null +++ b/test_onnx.py @@ -0,0 +1,110 @@ +from towhee import ops +from timm_image import TimmImage +import onnx + +f = open('onnx.csv', 'a+') +f.write('model_name, run_op, save_onnx, check_onnx\n') + +# models = TimmImage.supported_model_names()[:1] +models = [ + # 'vgg11', + 'resnet18', + # 'resnetv2_50', + # 'seresnet33ts', + # 'skresnet18', + # 'resnext26ts', + # 'seresnext26d_32x4d', + # 'skresnext50_32x4d', + # 'convit_base', + # 'inception_v4', + # 'efficientnet_b0', + 'tf_efficientnet_b0', + # 'swin_base_patch4_window7_224', + # 'vit_base_patch8_224', + # 'beit_base_patch16_224', + # 'convnext_base', + # 'crossvit_9_240', + # 'convmixer_768_32', + # 'coat_lite_mini', + # 'inception_v3', + # 'cait_m36_384', + # 'cspdarknet53', + # 'deit_base_distilled_patch16_224', + # 'densenet121', + # 'dla34', + # 'dm_nfnet_f0', + # 'nf_regnet_b1', + # 'nf_resnet50', + # 'dpn68', + # 'ese_vovnet19b_dw', + # 'fbnetc_100', + # 'fbnetv3_b', + # 'halonet26t', + # 'eca_halonext26ts', + # 'sehalonet33ts', + # 'hardcorenas_a', + # 'hrnet_w18', + # 'jx_nest_base', + # 'lcnet_050', + # 'levit_128', + # 'mixer_b16_224', + # 'mixnet_s', + # 'mnasnet_100', + # 'mobilenetv2_050', + # 'mobilenetv3_large_100', + # 'nasnetalarge', + # 'pit_b_224', + # 'pnasnet5large', + # 'regnetx_002', + # 'repvgg_a2', + # 'res2net50_14w_8s', + # 'res2next50', + # 'resmlp_12_224', + # 'resnest14d', + # 'rexnet_100', + # 'selecsls42b', + # 'semnasnet_075', + # 'tinynet_a', + # 'tnt_s_patch16_224', + # 'tresnet_l', + # 'twins_pcpvt_base', + # 'visformer_small', + # 'xception', + # 'xcit_large_24_p8_224', + # 'ghostnet_100', + # 'gmlp_s16_224', + # 'lambda_resnet26rpt_256', + # 'spnasnet_100', +] + +decoder = ops.image_decode() +data = decoder('./towhee.jpeg') + +for name in models: + f.write(f'{name},') + try: + op = TimmImage(model_name=name) + out1 = op(data) + f.write('success,') + except Exception as e: + f.write('fail') + print(f'Fail to load op for {name}: {e}') + continue + try: + op.save_model(format='onnx') + f.write('success,') + except Exception as e: + f.write('fail') + print(f'Fail to save onnx for {name}: {e}') + continue + try: + saved_name = name.replace('/', '-') + onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') + onnx.checker.check_model(onnx_model) + f.write('success') + except Exception as e: + f.write('fail') + print(f'Fail to check onnx for {name}: {e}') + continue + f.write('\n') +print('Finished.') diff --git a/test_torchscript.py b/test_torchscript.py new file mode 100644 index 0000000..2461a4f --- /dev/null +++ b/test_torchscript.py @@ -0,0 +1,111 @@ +from towhee import ops +from timm_image import TimmImage +import torch + +f = open('torchscript.csv', 'a+') +f.write('model_name,run_op,save_torchscript,check_result\n') + +# models = TimmImage.supported_model_names()[:1] +models = [ + # 'vgg11', + 'resnet18', + # 'resnetv2_50', + # 'seresnet33ts', + # 'skresnet18', + # 'resnext26ts', + # 'seresnext26d_32x4d', + # 'skresnext50_32x4d', + # 'convit_base', + # 'inception_v4', + # 'efficientnet_b0', + 'tf_efficientnet_b0', + # 'swin_base_patch4_window7_224', + # 'vit_base_patch8_224', + # 'beit_base_patch16_224', + # 'convnext_base', + # 'crossvit_9_240', + # 'convmixer_768_32', + # 'coat_lite_mini', + # 'inception_v3', + # 'cait_m36_384', + # 'cspdarknet53', + # 'deit_base_distilled_patch16_224', + # 'densenet121', + # 'dla34', + # 'dm_nfnet_f0', + # 'nf_regnet_b1', + # 'nf_resnet50', + # 'dpn68', + # 'ese_vovnet19b_dw', + # 'fbnetc_100', + # 'fbnetv3_b', + # 'halonet26t', + # 'eca_halonext26ts', + # 'sehalonet33ts', + # 'hardcorenas_a', + # 'hrnet_w18', + # 'jx_nest_base', + # 'lcnet_050', + # 'levit_128', + # 'mixer_b16_224', + # 'mixnet_s', + # 'mnasnet_100', + # 'mobilenetv2_050', + # 'mobilenetv3_large_100', + # 'nasnetalarge', + # 'pit_b_224', + # 'pnasnet5large', + # 'regnetx_002', + # 'repvgg_a2', + # 'res2net50_14w_8s', + # 'res2next50', + # 'resmlp_12_224', + # 'resnest14d', + # 'rexnet_100', + # 'selecsls42b', + # 'semnasnet_075', + # 'tinynet_a', + # 'tnt_s_patch16_224', + # 'tresnet_l', + # 'twins_pcpvt_base', + # 'visformer_small', + # 'xception', + # 'xcit_large_24_p8_224', + # 'ghostnet_100', + # 'gmlp_s16_224', + # 'lambda_resnet26rpt_256', + # 'spnasnet_100', +] + +decoder = ops.image_decode() +data = decoder('./towhee.jpeg') + +for name in models: + f.write(f'{name},') + try: + op = TimmImage(model_name=name) + out1 = op(data) + f.write('success,') + except Exception as e: + f.write('fail') + print(f'Fail to load op for {name}: {e}') + continue + try: + op.save_model(format='torchscript') + f.write('success,') + except Exception as e: + f.write('fail') + print(f'Fail to save onnx for {name}: {e}') + continue + try: + saved_name = name.replace('/', '-') + op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') + out2 = op(data) + assert (out1 == out2).all() + f.write('success') + except Exception as e: + f.write('fail') + print(f'Fail to check onnx for {name}: {e}') + continue + f.write('\n') +print('Finished.') diff --git a/timm_image.py b/timm_image.py index 1b80e55..c7f3843 100644 --- a/timm_image.py +++ b/timm_image.py @@ -27,6 +27,7 @@ from torch import nn from PIL import Image as PILImage +import timm from timm.data.transforms_factory import create_transform from timm.data import resolve_data_config from timm.models.factory import create_model @@ -57,8 +58,8 @@ class TimmImage(NNOperator): self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) self.model.to(self.device) self.model.eval() - config = resolve_data_config({}, model=self.model) - self.tfms = create_transform(**config) + self.config = resolve_data_config({}, model=self.model) + self.tfms = create_transform(**self.config) self.skip_tfms = skip_preprocess @arg(1, to_image_color('RGB')) @@ -76,31 +77,56 @@ class TimmImage(NNOperator): vec = features.flatten().detach().numpy() return vec - def save_model(self, jit: bool = True, destination: str = 'default'): - if destination == 'default': + def save_model(self, format: str = 'pytorch', path: str = 'default'): + if path == 'default': path = str(Path(__file__).parent) - destination = os.path.join(path, self.model_name + '.pt') - if jit: + path = os.path.join(path, 'saved', format) + os.makedirs(path, exist_ok=True) + name = self.model_name.replace('/', '-') + path = os.path.join(path, name) + dummy_input = torch.rand((1,) + self.config['input_size']) + if format == 'pytorch': + path = path + '.pt' + torch.save(self.model, path) + elif format == 'torchscript': + path = path + '.pt' try: - traced_model = torch.jit.script(self.model) - torch.jit.save(traced_model, destination) + try: + jit_model = torch.jit.script(self.model) + except Exception: + jit_model = torch.jit.trace(self.model, dummy_input, strict=False) + torch.jit.save(jit_model, path) except Exception as e: + log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') + elif format == 'onnx': + path = path + '.onnx' + try: + torch.onnx.export(self.model, + dummy_input, + path, + input_names=["input"], + output_names=["output"], + opset_version=12) + except Exception as e: + log.error(f'Fail to save as onnx: {e}.') + raise RuntimeError(f'Fail to save as onnx: {e}.') + # todo: elif format == 'tensorrt': else: - torch.save(self.model, destination) + log.error(f'Unsupported format "{format}".') -# if __name__ == '__main__': -# from towhee import ops -# -# path = '/image/path/or/link' -# -# decoder = ops.image_decode.cv2() -# img = decoder(path) -# -# op = TimmImage('resnet50') -# out = op(img) -# print(out) -# -# op.model = torch.jit.load('resnet50.pt') -# out2 = op(img) -# print(out2) + @staticmethod + def supported_model_names(format: str = None): + full_list = timm.list_models(pretrained=True) + full_list.sort() + if format is None: + model_list = full_list + elif format == 'pytorch': + to_remove = [] + assert set(to_remove).issubset(set(full_list)) + model_list = list(set(full_list) - set(to_remove)) + # todo: elif format == 'torchscript': + # todo: elif format == 'tensorrt' + else: + log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') + return model_list diff --git a/towhee.jpeg b/towhee.jpeg new file mode 100644 index 0000000..caf63b3 Binary files /dev/null and b/towhee.jpeg differ