logo
Browse Source

Adapt get_pretrained_cfg with towhee of lower version

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
42ea1659f5
  1. 21
      isc.py

21
isc.py

@ -30,7 +30,14 @@ import torch
from torch import nn
from PIL import Image as PILImage
from timm.data import create_transform
from timm.models import create_model, get_pretrained_cfg
from timm import create_model
try:
from timm.models import get_pretrained_cfg
except ImportError:
from timm.models.registry import _model_default_cfgs
def get_pretrained_cfg(model_name):
return _model_default_cfgs[model_name]
warnings.filterwarnings('ignore')
log = logging.getLogger('isc_op')
@ -80,12 +87,12 @@ class Isc(NNOperator):
self.model = Model(self.timm_backbone, checkpoint_path, self.device)
self.tfms = create_transform(
input_size=img_size,
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
input_size=img_size,
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']):
if not isinstance(data, list):

Loading…
Cancel
Save