|
|
@ -101,14 +101,17 @@ class TimmImage(NNOperator): |
|
|
|
inputs = torch.stack(img_list) |
|
|
|
inputs = inputs.to(self.device) |
|
|
|
features = self.model(inputs) |
|
|
|
if features.dim() == 4: |
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) |
|
|
|
features = global_pool(features) |
|
|
|
features = features.to('cpu').flatten(1) |
|
|
|
if isinstance(features, list): |
|
|
|
features = [self.post_proc(x) for x in features] |
|
|
|
else: |
|
|
|
features = self.post_proc(features) |
|
|
|
|
|
|
|
if isinstance(data, list): |
|
|
|
vecs = list(features.detach().numpy()) |
|
|
|
vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \ |
|
|
|
else list(features.detach().numpy()) |
|
|
|
else: |
|
|
|
vecs = features.squeeze(0).detach().numpy() |
|
|
|
vecs = [x.squeeze(0).detach().numpy()] if instance(features, list) \ |
|
|
|
else features.squeeze(0).detach().numpy() |
|
|
|
return vecs |
|
|
|
|
|
|
|
@property |
|
|
@ -124,6 +127,16 @@ class TimmImage(NNOperator): |
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
|
return img |
|
|
|
|
|
|
|
def post_proc(self, features): |
|
|
|
if 'vit' in self.model_name and features.dim() == 3: |
|
|
|
features = features[:, 0] |
|
|
|
if features.dim() == 4: |
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) |
|
|
|
features = global_pool(features) |
|
|
|
assert features.dim() == 2, f'Invalid output dim {features.dim()}' |
|
|
|
features = features.to('cpu') |
|
|
|
return features |
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
@ -176,34 +189,32 @@ class TimmImage(NNOperator): |
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.' |
|
|
|
full_list = timm.list_models(pretrained=True) |
|
|
|
full_list = list(set(timm.list_models(pretrained=True)) - set([ |
|
|
|
'coat_mini', |
|
|
|
'coat_tiny', |
|
|
|
'crossvit_9_240', |
|
|
|
'crossvit_9_dagger_240', |
|
|
|
'crossvit_15_240', |
|
|
|
'crossvit_15_dagger_240', |
|
|
|
'crossvit_15_dagger_408', |
|
|
|
'crossvit_18_240', |
|
|
|
'crossvit_18_dagger_240', |
|
|
|
'crossvit_18_dagger_408', |
|
|
|
'crossvit_base_240', |
|
|
|
'crossvit_small_240', |
|
|
|
'crossvit_tiny_240', |
|
|
|
])) |
|
|
|
full_list.sort() |
|
|
|
if format is None: |
|
|
|
if format in [None, 'pytorch']: |
|
|
|
model_list = full_list |
|
|
|
elif format == 'pytorch': |
|
|
|
to_remove = [ |
|
|
|
'coat_mini', |
|
|
|
'coat_tiny', |
|
|
|
'crossvit_9_240', |
|
|
|
'crossvit_9_dagger_240', |
|
|
|
'crossvit_15_240', |
|
|
|
'crossvit_15_dagger_240', |
|
|
|
'crossvit_15_dagger_408', |
|
|
|
'crossvit_18_240', |
|
|
|
'crossvit_18_dagger_240', |
|
|
|
'crossvit_18_dagger_408', |
|
|
|
'crossvit_base_240', |
|
|
|
'crossvit_small_240', |
|
|
|
'crossvit_tiny_240', |
|
|
|
] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'onnx': |
|
|
|
to_remove = [ |
|
|
|
'bat_resnext26ts', |
|
|
|
'convmixer_768_32', |
|
|
|
'coat_mini', |
|
|
|
'coat_tiny', |
|
|
|
'convmixer_1024_20_ks9_p14', |
|
|
|
'convmixer_1536_20', |
|
|
|
'convmixer_768_32', |
|
|
|
'eca_halonext26ts', |
|
|
|
'efficientformer_l1', |
|
|
|
'efficientformer_l3', |
|
|
@ -223,6 +234,14 @@ class TimmImage(NNOperator): |
|
|
|
'tf_efficientnet_cc_b0_4e', |
|
|
|
'tf_efficientnet_cc_b0_8e', |
|
|
|
'tf_efficientnet_cc_b1_8e', |
|
|
|
'tresnet_l', |
|
|
|
'tresnet_l_448', |
|
|
|
'tresnet_m', |
|
|
|
'tresnet_m_448', |
|
|
|
'tresnet_m_miil_in21k', |
|
|
|
'tresnet_v2_l', |
|
|
|
'tresnet_xl', |
|
|
|
'tresnet_xl_448', |
|
|
|
'volo_d1_224', |
|
|
|
'volo_d1_384', |
|
|
|
'volo_d2_224', |
|
|
|