logo
Browse Source

Fix op for vit models

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
0fcdf6d962
  1. 4
      README.md
  2. 73
      timm_image.py

4
README.md

@ -125,8 +125,8 @@ full_list = op.supported_model_names()
onnx_list = op.supported_model_names(format='onnx')
print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
```
2022-12-13 16:50:57,012 - 140704500614336 - timm_image.py-timm_image:76 - WARNING: The operator is initialized without specified model.
Onnx-support/Total Models: 736/770
2022-12-19 16:32:37,933 - 140704422594752 - timm_image.py-timm_image:88 - WARNING: The operator is initialized without specified model.
Onnx-support/Total Models: 715/759
<br />

73
timm_image.py

@ -101,14 +101,17 @@ class TimmImage(NNOperator):
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
features = self.model(inputs)
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features)
features = features.to('cpu').flatten(1)
if isinstance(features, list):
features = [self.post_proc(x) for x in features]
else:
features = self.post_proc(features)
if isinstance(data, list):
vecs = list(features.detach().numpy())
vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \
else list(features.detach().numpy())
else:
vecs = features.squeeze(0).detach().numpy()
vecs = [x.squeeze(0).detach().numpy()] if instance(features, list) \
else features.squeeze(0).detach().numpy()
return vecs
@property
@ -124,6 +127,16 @@ class TimmImage(NNOperator):
img = PILImage.fromarray(img.astype('uint8'), 'RGB')
return img
def post_proc(self, features):
    """
    Reduce a raw model output tensor to a 2-D tensor located on the CPU.

    Args:
        features (`torch.Tensor`):
            Output of the model. 2-D, 3-D and 4-D tensors are handled:
            - 3-D, only when `self.model_name` contains 'vit': index 0 along
              the second axis is kept (presumably the class token - confirm
              against the models in use).
            - 4-D: the last two axes are averaged down to 1x1 and flattened away.
            - 2-D: passed through unchanged.

    Returns:
        (`torch.Tensor`):
            A 2-D tensor on the CPU.

    Raises:
        AssertionError: if the tensor is not 2-D after the reductions above.
    """
    if 'vit' in self.model_name and features.dim() == 3:
        features = features[:, 0]
    if features.dim() == 4:
        global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
        features = global_pool(features)
        # Pooling leaves two trailing axes of size 1, so the tensor is still 4-D;
        # collapse them, otherwise the dimension check below can never pass.
        features = features.flatten(1)
    assert features.dim() == 2, f'Invalid output dim {features.dim()}'
    features = features.to('cpu')
    return features
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
@ -176,34 +189,32 @@ class TimmImage(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.'
full_list = timm.list_models(pretrained=True)
full_list = list(set(timm.list_models(pretrained=True)) - set([
'coat_mini',
'coat_tiny',
'crossvit_9_240',
'crossvit_9_dagger_240',
'crossvit_15_240',
'crossvit_15_dagger_240',
'crossvit_15_dagger_408',
'crossvit_18_240',
'crossvit_18_dagger_240',
'crossvit_18_dagger_408',
'crossvit_base_240',
'crossvit_small_240',
'crossvit_tiny_240',
]))
full_list.sort()
if format is None:
if format in [None, 'pytorch']:
model_list = full_list
elif format == 'pytorch':
to_remove = [
'coat_mini',
'coat_tiny',
'crossvit_9_240',
'crossvit_9_dagger_240',
'crossvit_15_240',
'crossvit_15_dagger_240',
'crossvit_15_dagger_408',
'crossvit_18_240',
'crossvit_18_dagger_240',
'crossvit_18_dagger_408',
'crossvit_base_240',
'crossvit_small_240',
'crossvit_tiny_240',
]
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = [
'bat_resnext26ts',
'convmixer_768_32',
'coat_mini',
'coat_tiny',
'convmixer_1024_20_ks9_p14',
'convmixer_1536_20',
'convmixer_768_32',
'eca_halonext26ts',
'efficientformer_l1',
'efficientformer_l3',
@ -223,6 +234,14 @@ class TimmImage(NNOperator):
'tf_efficientnet_cc_b0_4e',
'tf_efficientnet_cc_b0_8e',
'tf_efficientnet_cc_b1_8e',
'tresnet_l',
'tresnet_l_448',
'tresnet_m',
'tresnet_m_448',
'tresnet_m_miil_in21k',
'tresnet_v2_l',
'tresnet_xl',
'tresnet_xl_448',
'volo_d1_224',
'volo_d1_384',
'volo_d2_224',

Loading…
Cancel
Save