diff --git a/README.md b/README.md
index 975b15a..61ac558 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,45 @@ The decoded image data in towhee Image (a subset of numpy.ndarray).
An image embedding generated by model, in shape of (feature_dim,).
+
+
+***save_model(format='pytorch', path='default')***
+
+Save model to local with specified format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The format of saved model, defaults to 'pytorch'.
+
+***path***: *str*
+
+ The path where model is saved to. By default, it will save model to the operator directory.
+
+
+
+***supported_model_names(format=None)***
+
+Get a list of all supported model names or supported model names for specified model format.
+
+**Parameters:**
+
+***format***: *str*
+
+ The model format such as 'pytorch', 'torchscript'.
+
+```python
+from towhee import ops
+
+
+op = ops.image_embedding.timm().get_op()
+full_list = op.supported_model_names()
+onnx_list = op.supported_model_names(format='onnx')
+print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
+```
+ 2022-12-13 16:50:57,012 - 140704500614336 - timm_image.py-timm_image:76 - WARNING: The operator is initialized without specified model.
+ Onnx-support/Total Models: 736/770
diff --git a/timm_image.py b/timm_image.py
index 6c75833..b56f3b2 100644
--- a/timm_image.py
+++ b/timm_image.py
@@ -54,10 +54,9 @@ class TimmImage(NNOperator):
"""
def __init__(self,
- model_name: str = 'resnet50',
+ model_name: str = None,
num_classes: int = 1000,
skip_preprocess: bool = False,
- pretrained: bool = True,
device: str = None
) -> None:
super().__init__()
@@ -65,13 +64,17 @@ class TimmImage(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
- self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
- self.model.eval()
- self.model.to(self.device)
-
- self.config = resolve_data_config({}, model=self.model)
- self.tfms = create_transform(**self.config)
- self.skip_tfms = skip_preprocess
+ if self.model_name:
+ self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes)
+ self.model.eval()
+ self.model.to(self.device)
+
+ self.config = resolve_data_config({}, model=self.model)
+ self.tfms = create_transform(**self.config)
+ self.skip_tfms = skip_preprocess
+ else:
+ log.warning('The operator is initialized without specified model.')
+ pass
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
if not isinstance(data, list):
@@ -148,16 +151,73 @@ class TimmImage(NNOperator):
@staticmethod
def supported_model_names(format: str = None):
+ assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.'
full_list = timm.list_models(pretrained=True)
full_list.sort()
if format is None:
model_list = full_list
elif format == 'pytorch':
- to_remove = []
+ to_remove = [
+ 'coat_mini',
+ 'coat_tiny',
+ 'crossvit_9_240',
+ 'crossvit_9_dagger_240',
+ 'crossvit_15_240',
+ 'crossvit_15_dagger_240',
+ 'crossvit_15_dagger_408',
+ 'crossvit_18_240',
+ 'crossvit_18_dagger_240',
+ 'crossvit_18_dagger_408',
+ 'crossvit_base_240',
+ 'crossvit_small_240',
+ 'crossvit_tiny_240',
+ 'tresnet_l',
+ 'tresnet_l_448',
+ 'tresnet_m',
+ 'tresnet_m_448',
+ 'tresnet_m_miil_in21k',
+ 'tresnet_v2_l',
+ 'tresnet_xl',
+ 'tresnet_xl_448']
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
- to_remove = []
+ to_remove = [
+ 'bat_resnext26ts',
+ 'convmixer_768_32',
+ 'convmixer_1024_20_ks9_p14',
+ 'convmixer_1536_20',
+ 'eca_halonext26ts',
+ 'efficientformer_l1',
+ 'efficientformer_l3',
+ 'efficientformer_l7',
+ 'halo2botnet50ts_256',
+ 'halonet26t',
+ 'halonet50ts',
+ 'haloregnetz_b',
+ 'lamhalobotnet50ts_256',
+ 'levit_128',
+ 'levit_128s',
+ 'levit_192',
+ 'levit_256',
+ 'levit_384',
+ 'pvt_v2_b2_li',
+ 'sehalonet33ts',
+ 'tf_efficientnet_cc_b0_4e',
+ 'tf_efficientnet_cc_b0_8e',
+ 'tf_efficientnet_cc_b1_8e',
+ 'volo_d1_224',
+ 'volo_d1_384',
+ 'volo_d2_224',
+ 'volo_d2_384',
+ 'volo_d3_224',
+ 'volo_d3_448',
+ 'volo_d4_224',
+ 'volo_d4_448',
+ 'volo_d5_224',
+ 'volo_d5_448',
+ 'volo_d5_512'
+ ]
assert set(to_remove).issubset(set(full_list))
model_list = list(set(full_list) - set(to_remove))
# todo: elif format == 'torchscript':