logo
Browse Source

enhance robustness of timm tfms

main
ChengZi 2 years ago
parent
commit
dbf64995d0
  1. 5
      timm_image.py

5
timm_image.py

@ -31,7 +31,7 @@ from torch import nn
from PIL import Image as PILImage from PIL import Image as PILImage
import timm import timm
from timm.data import create_transform
from timm.data import create_transform, resolve_data_config
from timm.models import create_model, get_pretrained_cfg from timm.models import create_model, get_pretrained_cfg
import warnings import warnings
@ -90,6 +90,7 @@ class TimmImage(NNOperator):
device=self.device, device=self.device,
num_classes=num_classes num_classes=num_classes
) )
try:
self.tfms = create_transform( self.tfms = create_transform(
input_size=self.config['input_size'], input_size=self.config['input_size'],
interpolation=self.config['interpolation'], interpolation=self.config['interpolation'],
@ -97,6 +98,8 @@ class TimmImage(NNOperator):
std=self.config['std'], std=self.config['std'],
crop_pct=self.config['crop_pct'] crop_pct=self.config['crop_pct']
) )
except:
self.tfms = create_transform(**resolve_data_config({}, model=self.model))
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
else: else:
log.warning('The operator is initialized without specified model.') log.warning('The operator is initialized without specified model.')

Loading…
Cancel
Save