logo
Browse Source

Update model list

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
c7e446550b
  1. 39
      README.md
  2. 82
      timm_image.py

39
README.md

@ -88,6 +88,45 @@ The decoded image data in towhee Image (a subset of numpy.ndarray).
An image embedding generated by model, in shape of (feature_dim,).
<br />
***save_model(format='pytorch', path='default')***
Save model to local with specified format.
**Parameters:**
***format***: *str*
​ The format of saved model, defaults to 'pytorch'.
***path***: *str*
​ The path where model is saved to. By default, it will save model to the operator directory.
<br />
***supported_model_names(format=None)***
Get a list of all supported model names or supported model names for specified model format.
**Parameters:**
***format***: *str*
​ The model format such as 'pytorch', 'torchscript'.
```python
from towhee import ops
op = ops.image_embedding.timm().get_op()
full_list = op.supported_model_names()
onnx_list = op.supported_model_names(format='onnx')
print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
```
2022-12-13 16:50:57,012 - 140704500614336 - timm_image.py-timm_image:76 - WARNING: The operator is initialized without specified model.
Onnx-support/Total Models: 736/770
<br />

82
timm_image.py

@ -54,10 +54,9 @@ class TimmImage(NNOperator):
"""
def __init__(self,
model_name: str = 'resnet50',
model_name: str = None,
num_classes: int = 1000,
skip_preprocess: bool = False,
pretrained: bool = True,
device: str = None
) -> None:
super().__init__()
@ -65,13 +64,17 @@ class TimmImage(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
if self.model_name:
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
else:
log.warning('The operator is initialized without specified model.')
pass
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
if not isinstance(data, list):
@ -148,16 +151,73 @@ class TimmImage(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.'
full_list = timm.list_models(pretrained=True)
full_list.sort()
if format is None:
model_list = full_list
elif format == 'pytorch':
to_remove = []
to_remove = [
'coat_mini',
'coat_tiny',
'crossvit_9_240',
'crossvit_9_dagger_240',
'crossvit_15_240',
'crossvit_15_dagger_240',
'crossvit_15_dagger_408',
'crossvit_18_240',
'crossvit_18_dagger_240',
'crossvit_18_dagger_408',
'crossvit_base_240',
'crossvit_small_240',
'crossvit_tiny_240',
'tresnet_l',
'tresnet_l_448',
'tresnet_m',
'tresnet_m_448',
'tresnet_m_miil_in21k',
'tresnet_v2_l',
'tresnet_xl',
'tresnet_xl_448']
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = []
to_remove = [
'bat_resnext26ts',
'convmixer_768_32',
'convmixer_1024_20_ks9_p14',
'convmixer_1536_20',
'eca_halonext26ts',
'efficientformer_l1',
'efficientformer_l3',
'efficientformer_l7',
'halo2botnet50ts_256',
'halonet26t',
'halonet50ts',
'haloregnetz_b',
'lamhalobotnet50ts_256',
'levit_128',
'levit_128s',
'levit_192',
'levit_256',
'levit_384',
'pvt_v2_b2_li',
'sehalonet33ts',
'tf_efficientnet_cc_b0_4e',
'tf_efficientnet_cc_b0_8e',
'tf_efficientnet_cc_b1_8e',
'volo_d1_224',
'volo_d1_384',
'volo_d2_224',
'volo_d2_384',
'volo_d3_224',
'volo_d3_448',
'volo_d4_224',
'volo_d4_448',
'volo_d5_224',
'volo_d5_448',
'volo_d5_512'
]
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'torchscript':

Loading…
Cancel
Save