Browse Source
Support timm>=0.9
Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
main
1 changed files with
2 additions and
1 deletions
-
isc.py
|
@ -32,6 +32,7 @@ except: |
|
|
return func |
|
|
return func |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
|
|
|
import timm |
|
|
from torch import nn |
|
|
from torch import nn |
|
|
from PIL import Image as PILImage |
|
|
from PIL import Image as PILImage |
|
|
from timm.data import create_transform |
|
|
from timm.data import create_transform |
|
@ -83,7 +84,7 @@ class Isc(NNOperator): |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device) |
|
|
self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device) |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.timm_backbone = timm_backbone |
|
|
|
|
|
|
|
|
self.timm_backbone = timm_backbone if timm.__version__ < '0.9.0' else 'tf_efficientnetv2_m.in21k_ft_in1k' |
|
|
|
|
|
|
|
|
if checkpoint_path is None: |
|
|
if checkpoint_path is None: |
|
|
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') |
|
|
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') |
|
|