diff --git a/timm_image.py b/timm_image.py index 80ed6bb..a5690a3 100644 --- a/timm_image.py +++ b/timm_image.py @@ -31,7 +31,7 @@ from torch import nn from PIL import Image as PILImage import timm -from timm.data import create_transform +from timm.data import create_transform, resolve_data_config from timm.models import create_model, get_pretrained_cfg import warnings @@ -90,13 +90,16 @@ class TimmImage(NNOperator): device=self.device, num_classes=num_classes ) - self.tfms = create_transform( - input_size=self.config['input_size'], - interpolation=self.config['interpolation'], - mean=self.config['mean'], - std=self.config['std'], - crop_pct=self.config['crop_pct'] - ) + try: + self.tfms = create_transform( + input_size=self.config['input_size'], + interpolation=self.config['interpolation'], + mean=self.config['mean'], + std=self.config['std'], + crop_pct=self.config['crop_pct'] + ) + except: + self.tfms = create_transform(**resolve_data_config({}, model=self.model)) self.skip_tfms = skip_preprocess else: log.warning('The operator is initialized without specified model.')