logo
Browse Source

Support timm>=0.9

Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
main
Kaiyuan Hu 1 year ago
parent
commit
57a04cd00e
  1. 3
      isc.py

3
isc.py

@ -32,6 +32,7 @@ except:
return func return func
import torch import torch
import timm
from torch import nn from torch import nn
from PIL import Image as PILImage from PIL import Image as PILImage
from timm.data import create_transform from timm.data import create_transform
@ -83,7 +84,7 @@ class Isc(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device) self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device)
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
self.timm_backbone = timm_backbone
self.timm_backbone = timm_backbone if timm.__version__ < '0.9.0' else 'tf_efficientnetv2_m.in21k_ft_in1k'
if checkpoint_path is None: if checkpoint_path is None:
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')

Loading…
Cancel
Save