diff --git a/isc.py b/isc.py index 0ce6ded..a8d89a5 100644 --- a/isc.py +++ b/isc.py @@ -32,6 +32,7 @@ except: return func import torch +import timm from torch import nn from PIL import Image as PILImage from timm.data import create_transform @@ -83,7 +84,7 @@ class Isc(NNOperator): device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device) self.skip_tfms = skip_preprocess - self.timm_backbone = timm_backbone + self.timm_backbone = timm_backbone if timm.__version__ < '0.9.0' else 'tf_efficientnetv2_m.in21k_ft_in1k' if checkpoint_path is None: checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')