logo
Browse Source

Support timm>=0.9

Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
main
Kaiyuan Hu 1 year ago
parent
commit
57a04cd00e
  1. 3
      isc.py

3
isc.py

@ -32,6 +32,7 @@ except:
return func
import torch
import timm
from torch import nn
from PIL import Image as PILImage
from timm.data import create_transform
@ -83,7 +84,7 @@ class Isc(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device)
self.skip_tfms = skip_preprocess
self.timm_backbone = timm_backbone
self.timm_backbone = timm_backbone if timm.__version__ < '0.9.0' else 'tf_efficientnetv2_m.in21k_ft_in1k'
if checkpoint_path is None:
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')

Loading…
Cancel
Save