diff --git a/isc.py b/isc.py index 8cb869e..3948aa4 100644 --- a/isc.py +++ b/isc.py @@ -30,7 +30,14 @@ import torch from torch import nn from PIL import Image as PILImage from timm.data import create_transform -from timm.models import create_model, get_pretrained_cfg +from timm import create_model + +try: + from timm.models import get_pretrained_cfg +except ImportError: + from timm.models.registry import _model_default_cfgs + def get_pretrained_cfg(model_name): + return _model_default_cfgs[model_name] warnings.filterwarnings('ignore') log = logging.getLogger('isc_op') @@ -80,12 +87,12 @@ class Isc(NNOperator): self.model = Model(self.timm_backbone, checkpoint_path, self.device) self.tfms = create_transform( - input_size=img_size, - interpolation=self.config['interpolation'], - mean=self.config['mean'], - std=self.config['std'], - crop_pct=self.config['crop_pct'] - ) + input_size=img_size, + interpolation=self.config['interpolation'], + mean=self.config['mean'], + std=self.config['std'], + crop_pct=self.config['crop_pct'] + ) def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']): if not isinstance(data, list):