Browse Source
Update model config with new method
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
9 additions and
6 deletions
-
timm_image.py
|
|
@ -31,9 +31,8 @@ from torch import nn |
|
|
|
from PIL import Image as PILImage |
|
|
|
|
|
|
|
import timm |
|
|
|
from timm.data.transforms_factory import create_transform |
|
|
|
from timm.data import resolve_data_config |
|
|
|
from timm.models.factory import create_model |
|
|
|
from timm.data import create_transform |
|
|
|
from timm.models import create_model, get_pretrained_cfg |
|
|
|
|
|
|
|
import warnings |
|
|
|
|
|
|
@ -89,7 +88,12 @@ class TimmImage(NNOperator): |
|
|
|
device=self.device, |
|
|
|
num_classes=num_classes |
|
|
|
) |
|
|
|
self.tfms = create_transform(**self.config) |
|
|
|
self.tfms = create_transform( |
|
|
|
input_size=self.config['input_size'], |
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
mean=self.config['mean'], |
|
|
|
std=self.config['std'] |
|
|
|
) |
|
|
|
self.skip_tfms = skip_preprocess |
|
|
|
else: |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
@ -128,8 +132,7 @@ class TimmImage(NNOperator): |
|
|
|
|
|
|
|
@property |
|
|
|
def config(self): |
|
|
|
m = create_model(self.model_name, pretrained=False) |
|
|
|
config = resolve_data_config({}, model=m) |
|
|
|
config = get_pretrained_cfg(self.model_name) |
|
|
|
return config |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|