diff --git a/README.md b/README.md index 61ac558..028175a 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,8 @@ full_list = op.supported_model_names() onnx_list = op.supported_model_names(format='onnx') print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}') ``` - 2022-12-13 16:50:57,012 - 140704500614336 - timm_image.py-timm_image:76 - WARNING: The operator is initialized without specified model. - Onnx-support/Total Models: 736/770 + 2022-12-19 16:32:37,933 - 140704422594752 - timm_image.py-timm_image:88 - WARNING: The operator is initialized without specified model. + Onnx-support/Total Models: 715/759
diff --git a/timm_image.py b/timm_image.py index d54676a..b8eb495 100644 --- a/timm_image.py +++ b/timm_image.py @@ -101,14 +101,17 @@ class TimmImage(NNOperator): inputs = torch.stack(img_list) inputs = inputs.to(self.device) features = self.model(inputs) - if features.dim() == 4: - global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) - features = global_pool(features) - features = features.to('cpu').flatten(1) + if isinstance(features, list): + features = [self.post_proc(x) for x in features] + else: + features = self.post_proc(features) + if isinstance(data, list): - vecs = list(features.detach().numpy()) + vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \ + else list(features.detach().numpy()) else: - vecs = features.squeeze(0).detach().numpy() + vecs = [x.squeeze(0).detach().numpy()] if instance(features, list) \ + else features.squeeze(0).detach().numpy() return vecs @property @@ -124,6 +127,16 @@ class TimmImage(NNOperator): img = PILImage.fromarray(img.astype('uint8'), 'RGB') return img + def post_proc(self, features): + if 'vit' in self.model_name and features.dim() == 3: + features = features[:, 0] + if features.dim() == 4: + global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) + features = global_pool(features) + assert features.dim() == 2, f'Invalid output dim {features.dim()}' + features = features.to('cpu') + return features + def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) @@ -176,34 +189,32 @@ class TimmImage(NNOperator): @staticmethod def supported_model_names(format: str = None): assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.' - full_list = timm.list_models(pretrained=True) + full_list = list(set(timm.list_models(pretrained=True)) - set([ + 'coat_mini', + 'coat_tiny', + 'crossvit_9_240', + 'crossvit_9_dagger_240', + 'crossvit_15_240', + 'crossvit_15_dagger_240', + 'crossvit_15_dagger_408', + 'crossvit_18_240', + 'crossvit_18_dagger_240', + 'crossvit_18_dagger_408', + 'crossvit_base_240', + 'crossvit_small_240', + 'crossvit_tiny_240', + ])) full_list.sort() - if format is None: + if format in [None, 'pytorch']: model_list = full_list - elif format == 'pytorch': - to_remove = [ - 'coat_mini', - 'coat_tiny', - 'crossvit_9_240', - 'crossvit_9_dagger_240', - 'crossvit_15_240', - 'crossvit_15_dagger_240', - 'crossvit_15_dagger_408', - 'crossvit_18_240', - 'crossvit_18_dagger_240', - 'crossvit_18_dagger_408', - 'crossvit_base_240', - 'crossvit_small_240', - 'crossvit_tiny_240', - ] - assert set(to_remove).issubset(set(full_list)) - model_list = list(set(full_list) - set(to_remove)) elif format == 'onnx': to_remove = [ 'bat_resnext26ts', - 'convmixer_768_32', + 'coat_mini', + 'coat_tiny', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', + 'convmixer_768_32', 'eca_halonext26ts', 'efficientformer_l1', 'efficientformer_l3', @@ -223,6 +234,14 @@ class TimmImage(NNOperator): 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tresnet_l', + 'tresnet_l_448', + 'tresnet_m', + 'tresnet_m_448', + 'tresnet_m_miil_in21k', + 'tresnet_v2_l', + 'tresnet_xl', + 'tresnet_xl_448', 'volo_d1_224', 'volo_d1_384', 'volo_d2_224',