logo
Browse Source

enhance robustness of timm tfms

main
ChengZi 1 year ago
parent
commit
dbf64995d0
  1. 19
      timm_image.py

19
timm_image.py

@ -31,7 +31,7 @@ from torch import nn
from PIL import Image as PILImage
import timm
from timm.data import create_transform
from timm.data import create_transform, resolve_data_config
from timm.models import create_model, get_pretrained_cfg
import warnings
@ -90,13 +90,16 @@ class TimmImage(NNOperator):
device=self.device,
num_classes=num_classes
)
self.tfms = create_transform(
input_size=self.config['input_size'],
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
try:
self.tfms = create_transform(
input_size=self.config['input_size'],
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
except:
self.tfms = create_transform(**resolve_data_config({}, model=self.model))
self.skip_tfms = skip_preprocess
else:
log.warning('The operator is initialized without specified model.')

Loading…
Cancel
Save