|
@ -31,7 +31,7 @@ from torch import nn |
|
|
from PIL import Image as PILImage |
|
|
from PIL import Image as PILImage |
|
|
|
|
|
|
|
|
import timm |
|
|
import timm |
|
|
from timm.data import create_transform |
|
|
|
|
|
|
|
|
from timm.data import create_transform, resolve_data_config |
|
|
from timm.models import create_model, get_pretrained_cfg |
|
|
from timm.models import create_model, get_pretrained_cfg |
|
|
|
|
|
|
|
|
import warnings |
|
|
import warnings |
|
@ -90,13 +90,16 @@ class TimmImage(NNOperator): |
|
|
device=self.device, |
|
|
device=self.device, |
|
|
num_classes=num_classes |
|
|
num_classes=num_classes |
|
|
) |
|
|
) |
|
|
self.tfms = create_transform( |
|
|
|
|
|
input_size=self.config['input_size'], |
|
|
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
|
|
mean=self.config['mean'], |
|
|
|
|
|
std=self.config['std'], |
|
|
|
|
|
crop_pct=self.config['crop_pct'] |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
self.tfms = create_transform( |
|
|
|
|
|
input_size=self.config['input_size'], |
|
|
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
|
|
mean=self.config['mean'], |
|
|
|
|
|
std=self.config['std'], |
|
|
|
|
|
crop_pct=self.config['crop_pct'] |
|
|
|
|
|
) |
|
|
|
|
|
except: |
|
|
|
|
|
self.tfms = create_transform(**resolve_data_config({}, model=self.model)) |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.skip_tfms = skip_preprocess |
|
|
else: |
|
|
else: |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
log.warning('The operator is initialized without specified model.') |
|
|