From c7e446550ba9a449788d14ba24aaef475d5419bb Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 13 Dec 2022 16:52:27 +0800 Subject: [PATCH] Update model list Signed-off-by: Jael Gu --- README.md | 39 ++++++++++++++++++++++++ timm_image.py | 82 ++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 975b15a..61ac558 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,45 @@ The decoded image data in towhee Image (a subset of numpy.ndarray). An image embedding generated by model, in shape of (feature_dim,). +
+ +***save_model(format='pytorch', path='default')*** + +Save model to local with specified format. + +**Parameters:** + +***format***: *str* + +​ The format of saved model, defaults to 'pytorch'. + +***path***: *str* + +​ The path where model is saved to. By default, it will save model to the operator directory. + +
+ +***supported_model_names(format=None)*** + +Get a list of all supported model names or supported model names for specified model format. + +**Parameters:** + +***format***: *str* + +​ The model format such as 'pytorch', 'torchscript'. + +```python +from towhee import ops + + +op = ops.image_embedding.timm().get_op() +full_list = op.supported_model_names() +onnx_list = op.supported_model_names(format='onnx') +print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}') +``` + 2022-12-13 16:50:57,012 - 140704500614336 - timm_image.py-timm_image:76 - WARNING: The operator is initialized without specified model. + Onnx-support/Total Models: 736/770
diff --git a/timm_image.py b/timm_image.py index 6c75833..b56f3b2 100644 --- a/timm_image.py +++ b/timm_image.py @@ -54,10 +54,9 @@ class TimmImage(NNOperator): """ def __init__(self, - model_name: str = 'resnet50', + model_name: str = None, num_classes: int = 1000, skip_preprocess: bool = False, - pretrained: bool = True, device: str = None ) -> None: super().__init__() @@ -65,13 +64,17 @@ class TimmImage(NNOperator): device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name - self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) - self.model.eval() - self.model.to(self.device) - - self.config = resolve_data_config({}, model=self.model) - self.tfms = create_transform(**self.config) - self.skip_tfms = skip_preprocess + if self.model_name: + self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) + self.model.eval() + self.model.to(self.device) + + self.config = resolve_data_config({}, model=self.model) + self.tfms = create_transform(**self.config) + self.skip_tfms = skip_preprocess + else: + log.warning('The operator is initialized without specified model.') + pass def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): if not isinstance(data, list): @@ -148,16 +151,73 @@ class TimmImage(NNOperator): @staticmethod def supported_model_names(format: str = None): + assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.' full_list = timm.list_models(pretrained=True) full_list.sort() if format is None: model_list = full_list elif format == 'pytorch': - to_remove = [] + to_remove = [ + 'coat_mini', + 'coat_tiny', + 'crossvit_9_240', + 'crossvit_9_dagger_240', + 'crossvit_15_240', + 'crossvit_15_dagger_240', + 'crossvit_15_dagger_408', + 'crossvit_18_240', + 'crossvit_18_dagger_240', + 'crossvit_18_dagger_408', + 'crossvit_base_240', + 'crossvit_small_240', + 'crossvit_tiny_240', + 'tresnet_l', + 'tresnet_l_448', + 'tresnet_m', + 'tresnet_m_448', + 'tresnet_m_miil_in21k', + 'tresnet_v2_l', + 'tresnet_xl', + 'tresnet_xl_448'] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) elif format == 'onnx': - to_remove = [] + to_remove = [ + 'bat_resnext26ts', + 'convmixer_768_32', + 'convmixer_1024_20_ks9_p14', + 'convmixer_1536_20', + 'eca_halonext26ts', + 'efficientformer_l1', + 'efficientformer_l3', + 'efficientformer_l7', + 'halo2botnet50ts_256', + 'halonet26t', + 'halonet50ts', + 'haloregnetz_b', + 'lamhalobotnet50ts_256', + 'levit_128', + 'levit_128s', + 'levit_192', + 'levit_256', + 'levit_384', + 'pvt_v2_b2_li', + 'sehalonet33ts', + 'tf_efficientnet_cc_b0_4e', + 'tf_efficientnet_cc_b0_8e', + 'tf_efficientnet_cc_b1_8e', + 'volo_d1_224', + 'volo_d1_384', + 'volo_d2_224', + 'volo_d2_384', + 'volo_d3_224', + 'volo_d3_448', + 'volo_d4_224', + 'volo_d4_448', + 'volo_d5_224', + 'volo_d5_448', + 'volo_d5_512' + ] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) # todo: elif format == 'torchscript':