|
|
@ -54,10 +54,9 @@ class TimmImage(NNOperator): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
model_name: str = 'resnet50', |
|
|
|
model_name: str = None, |
|
|
|
num_classes: int = 1000, |
|
|
|
skip_preprocess: bool = False, |
|
|
|
pretrained: bool = True, |
|
|
|
device: str = None |
|
|
|
) -> None: |
|
|
|
super().__init__() |
|
|
@ -65,13 +64,17 @@ class TimmImage(NNOperator): |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.device = device |
|
|
|
self.model_name = model_name |
|
|
|
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) |
|
|
|
self.model.eval() |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
self.config = resolve_data_config({}, model=self.model) |
|
|
|
self.tfms = create_transform(**self.config) |
|
|
|
self.skip_tfms = skip_preprocess |
|
|
|
if self.model_name: |
|
|
|
self.model = create_model(self.model_name, pretrained=pretrained, num_classes=num_classes) |
|
|
|
self.model.eval() |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
self.config = resolve_data_config({}, model=self.model) |
|
|
|
self.tfms = create_transform(**self.config) |
|
|
|
self.skip_tfms = skip_preprocess |
|
|
|
else: |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
|
|
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
|
if not isinstance(data, list): |
|
|
@ -148,16 +151,73 @@ class TimmImage(NNOperator): |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
assert timm.__version__ == '0.6.12', 'The model lists are tested with timm==0.6.12.' |
|
|
|
full_list = timm.list_models(pretrained=True) |
|
|
|
full_list.sort() |
|
|
|
if format is None: |
|
|
|
model_list = full_list |
|
|
|
elif format == 'pytorch': |
|
|
|
to_remove = [] |
|
|
|
to_remove = [ |
|
|
|
'coat_mini', |
|
|
|
'coat_tiny', |
|
|
|
'crossvit_9_240', |
|
|
|
'crossvit_9_dagger_240', |
|
|
|
'crossvit_15_240', |
|
|
|
'crossvit_15_dagger_240', |
|
|
|
'crossvit_15_dagger_408', |
|
|
|
'crossvit_18_240', |
|
|
|
'crossvit_18_dagger_240', |
|
|
|
'crossvit_18_dagger_408', |
|
|
|
'crossvit_base_240', |
|
|
|
'crossvit_small_240', |
|
|
|
'crossvit_tiny_240', |
|
|
|
'tresnet_l', |
|
|
|
'tresnet_l_448', |
|
|
|
'tresnet_m', |
|
|
|
'tresnet_m_448', |
|
|
|
'tresnet_m_miil_in21k', |
|
|
|
'tresnet_v2_l', |
|
|
|
'tresnet_xl', |
|
|
|
'tresnet_xl_448'] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'onnx': |
|
|
|
to_remove = [] |
|
|
|
to_remove = [ |
|
|
|
'bat_resnext26ts', |
|
|
|
'convmixer_768_32', |
|
|
|
'convmixer_1024_20_ks9_p14', |
|
|
|
'convmixer_1536_20', |
|
|
|
'eca_halonext26ts', |
|
|
|
'efficientformer_l1', |
|
|
|
'efficientformer_l3', |
|
|
|
'efficientformer_l7', |
|
|
|
'halo2botnet50ts_256', |
|
|
|
'halonet26t', |
|
|
|
'halonet50ts', |
|
|
|
'haloregnetz_b', |
|
|
|
'lamhalobotnet50ts_256', |
|
|
|
'levit_128', |
|
|
|
'levit_128s', |
|
|
|
'levit_192', |
|
|
|
'levit_256', |
|
|
|
'levit_384', |
|
|
|
'pvt_v2_b2_li', |
|
|
|
'sehalonet33ts', |
|
|
|
'tf_efficientnet_cc_b0_4e', |
|
|
|
'tf_efficientnet_cc_b0_8e', |
|
|
|
'tf_efficientnet_cc_b1_8e', |
|
|
|
'volo_d1_224', |
|
|
|
'volo_d1_384', |
|
|
|
'volo_d2_224', |
|
|
|
'volo_d2_384', |
|
|
|
'volo_d3_224', |
|
|
|
'volo_d3_448', |
|
|
|
'volo_d4_224', |
|
|
|
'volo_d4_448', |
|
|
|
'volo_d5_224', |
|
|
|
'volo_d5_448', |
|
|
|
'volo_d5_512' |
|
|
|
] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
# todo: elif format == 'torchscript': |
|
|
|