diff --git a/timm_image.py b/timm_image.py index 5da4d2f..f74eb9a 100644 --- a/timm_image.py +++ b/timm_image.py @@ -31,9 +31,8 @@ from torch import nn from PIL import Image as PILImage import timm -from timm.data.transforms_factory import create_transform -from timm.data import resolve_data_config -from timm.models.factory import create_model +from timm.data import create_transform +from timm.models import create_model, get_pretrained_cfg import warnings @@ -89,7 +88,12 @@ class TimmImage(NNOperator): device=self.device, num_classes=num_classes ) - self.tfms = create_transform(**self.config) + self.tfms = create_transform( + input_size=self.config['input_size'], + interpolation=self.config['interpolation'], + mean=self.config['mean'], + std=self.config['std'] + ) self.skip_tfms = skip_preprocess else: log.warning('The operator is initialized without specified model.') @@ -128,8 +132,7 @@ class TimmImage(NNOperator): @property def config(self): - m = create_model(self.model_name, pretrained=False) - config = resolve_data_config({}, model=m) + config = get_pretrained_cfg(self.model_name) return config @arg(1, to_image_color('RGB'))