logo
Browse Source

Add save_model & suppported_model_names

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
107529e868
  1. 110
      test_onnx.py
  2. 111
      test_torchscript.py
  3. 74
      timm_image.py
  4. BIN
      towhee.jpeg

110
test_onnx.py

@ -0,0 +1,110 @@
from towhee import ops
from timm_image import TimmImage
import onnx
f = open('onnx.csv', 'a+')
f.write('model_name, run_op, save_onnx, check_onnx\n')
# models = TimmImage.supported_model_names()[:1]
models = [
# 'vgg11',
'resnet18',
# 'resnetv2_50',
# 'seresnet33ts',
# 'skresnet18',
# 'resnext26ts',
# 'seresnext26d_32x4d',
# 'skresnext50_32x4d',
# 'convit_base',
# 'inception_v4',
# 'efficientnet_b0',
'tf_efficientnet_b0',
# 'swin_base_patch4_window7_224',
# 'vit_base_patch8_224',
# 'beit_base_patch16_224',
# 'convnext_base',
# 'crossvit_9_240',
# 'convmixer_768_32',
# 'coat_lite_mini',
# 'inception_v3',
# 'cait_m36_384',
# 'cspdarknet53',
# 'deit_base_distilled_patch16_224',
# 'densenet121',
# 'dla34',
# 'dm_nfnet_f0',
# 'nf_regnet_b1',
# 'nf_resnet50',
# 'dpn68',
# 'ese_vovnet19b_dw',
# 'fbnetc_100',
# 'fbnetv3_b',
# 'halonet26t',
# 'eca_halonext26ts',
# 'sehalonet33ts',
# 'hardcorenas_a',
# 'hrnet_w18',
# 'jx_nest_base',
# 'lcnet_050',
# 'levit_128',
# 'mixer_b16_224',
# 'mixnet_s',
# 'mnasnet_100',
# 'mobilenetv2_050',
# 'mobilenetv3_large_100',
# 'nasnetalarge',
# 'pit_b_224',
# 'pnasnet5large',
# 'regnetx_002',
# 'repvgg_a2',
# 'res2net50_14w_8s',
# 'res2next50',
# 'resmlp_12_224',
# 'resnest14d',
# 'rexnet_100',
# 'selecsls42b',
# 'semnasnet_075',
# 'tinynet_a',
# 'tnt_s_patch16_224',
# 'tresnet_l',
# 'twins_pcpvt_base',
# 'visformer_small',
# 'xception',
# 'xcit_large_24_p8_224',
# 'ghostnet_100',
# 'gmlp_s16_224',
# 'lambda_resnet26rpt_256',
# 'spnasnet_100',
]
decoder = ops.image_decode()
data = decoder('./towhee.jpeg')
for name in models:
f.write(f'{name},')
try:
op = TimmImage(model_name=name)
out1 = op(data)
f.write('success,')
except Exception as e:
f.write('fail')
print(f'Fail to load op for {name}: {e}')
continue
try:
op.save_model(format='onnx')
f.write('success,')
except Exception as e:
f.write('fail')
print(f'Fail to save onnx for {name}: {e}')
continue
try:
saved_name = name.replace('/', '-')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx.checker.check_model(onnx_model)
f.write('success')
except Exception as e:
f.write('fail')
print(f'Fail to check onnx for {name}: {e}')
continue
f.write('\n')
print('Finished.')

111
test_torchscript.py

@ -0,0 +1,111 @@
from towhee import ops
from timm_image import TimmImage
import torch
f = open('torchscript.csv', 'a+')
f.write('model_name,run_op,save_torchscript,check_result\n')
# models = TimmImage.supported_model_names()[:1]
models = [
# 'vgg11',
'resnet18',
# 'resnetv2_50',
# 'seresnet33ts',
# 'skresnet18',
# 'resnext26ts',
# 'seresnext26d_32x4d',
# 'skresnext50_32x4d',
# 'convit_base',
# 'inception_v4',
# 'efficientnet_b0',
'tf_efficientnet_b0',
# 'swin_base_patch4_window7_224',
# 'vit_base_patch8_224',
# 'beit_base_patch16_224',
# 'convnext_base',
# 'crossvit_9_240',
# 'convmixer_768_32',
# 'coat_lite_mini',
# 'inception_v3',
# 'cait_m36_384',
# 'cspdarknet53',
# 'deit_base_distilled_patch16_224',
# 'densenet121',
# 'dla34',
# 'dm_nfnet_f0',
# 'nf_regnet_b1',
# 'nf_resnet50',
# 'dpn68',
# 'ese_vovnet19b_dw',
# 'fbnetc_100',
# 'fbnetv3_b',
# 'halonet26t',
# 'eca_halonext26ts',
# 'sehalonet33ts',
# 'hardcorenas_a',
# 'hrnet_w18',
# 'jx_nest_base',
# 'lcnet_050',
# 'levit_128',
# 'mixer_b16_224',
# 'mixnet_s',
# 'mnasnet_100',
# 'mobilenetv2_050',
# 'mobilenetv3_large_100',
# 'nasnetalarge',
# 'pit_b_224',
# 'pnasnet5large',
# 'regnetx_002',
# 'repvgg_a2',
# 'res2net50_14w_8s',
# 'res2next50',
# 'resmlp_12_224',
# 'resnest14d',
# 'rexnet_100',
# 'selecsls42b',
# 'semnasnet_075',
# 'tinynet_a',
# 'tnt_s_patch16_224',
# 'tresnet_l',
# 'twins_pcpvt_base',
# 'visformer_small',
# 'xception',
# 'xcit_large_24_p8_224',
# 'ghostnet_100',
# 'gmlp_s16_224',
# 'lambda_resnet26rpt_256',
# 'spnasnet_100',
]
decoder = ops.image_decode()
data = decoder('./towhee.jpeg')
for name in models:
f.write(f'{name},')
try:
op = TimmImage(model_name=name)
out1 = op(data)
f.write('success,')
except Exception as e:
f.write('fail')
print(f'Fail to load op for {name}: {e}')
continue
try:
op.save_model(format='torchscript')
f.write('success,')
except Exception as e:
f.write('fail')
print(f'Fail to save onnx for {name}: {e}')
continue
try:
saved_name = name.replace('/', '-')
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
out2 = op(data)
assert (out1 == out2).all()
f.write('success')
except Exception as e:
f.write('fail')
print(f'Fail to check onnx for {name}: {e}')
continue
f.write('\n')
print('Finished.')

74
timm_image.py

@ -27,6 +27,7 @@ from torch import nn
from PIL import Image as PILImage from PIL import Image as PILImage
import timm
from timm.data.transforms_factory import create_transform from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.models.factory import create_model from timm.models.factory import create_model
@ -57,8 +58,8 @@ class TimmImage(NNOperator):
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**config)
self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
@ -76,31 +77,56 @@ class TimmImage(NNOperator):
vec = features.flatten().detach().numpy() vec = features.flatten().detach().numpy()
return vec return vec
def save_model(self, jit: bool = True, destination: str = 'default'):
if destination == 'default':
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
destination = os.path.join(path, self.model_name + '.pt')
if jit:
path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
try: try:
traced_model = torch.jit.script(self.model)
torch.jit.save(traced_model, destination)
try:
jit_model = torch.jit.script(self.model)
except Exception:
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
dummy_input,
path,
input_names=["input"],
output_names=["output"],
opset_version=12)
except Exception as e:
log.error(f'Fail to save as onnx: {e}.')
raise RuntimeError(f'Fail to save as onnx: {e}.')
# todo: elif format == 'tensorrt':
else: else:
torch.save(self.model, destination)
log.error(f'Unsupported format "{format}".')
# if __name__ == '__main__':
# from towhee import ops
#
# path = '/image/path/or/link'
#
# decoder = ops.image_decode.cv2()
# img = decoder(path)
#
# op = TimmImage('resnet50')
# out = op(img)
# print(out)
#
# op.model = torch.jit.load('resnet50.pt')
# out2 = op(img)
# print(out2)
@staticmethod
def supported_model_names(format: str = None):
full_list = timm.list_models(pretrained=True)
full_list.sort()
if format is None:
model_list = full_list
elif format == 'pytorch':
to_remove = []
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'torchscript':
# todo: elif format == 'tensorrt'
else:
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
return model_list

BIN
towhee.jpeg

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Loading…
Cancel
Save