logo
Browse Source

Update model config with new method

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
9ce147c9a9
  1. 15
      timm_image.py

15
timm_image.py

@ -31,9 +31,8 @@ from torch import nn
from PIL import Image as PILImage
import timm
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config
from timm.models.factory import create_model
from timm.data import create_transform
from timm.models import create_model, get_pretrained_cfg
import warnings
@ -89,7 +88,12 @@ class TimmImage(NNOperator):
device=self.device,
num_classes=num_classes
)
self.tfms = create_transform(**self.config)
self.tfms = create_transform(
input_size=self.config['input_size'],
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std']
)
self.skip_tfms = skip_preprocess
else:
log.warning('The operator is initialized without specified model.')
@ -128,8 +132,7 @@ class TimmImage(NNOperator):
@property
def config(self):
m = create_model(self.model_name, pretrained=False)
config = resolve_data_config({}, model=m)
config = get_pretrained_cfg(self.model_name)
return config
@arg(1, to_image_color('RGB'))

Loading…
Cancel
Save