|
|
@ -30,7 +30,14 @@ import torch |
|
|
|
from torch import nn |
|
|
|
from PIL import Image as PILImage |
|
|
|
from timm.data import create_transform |
|
|
|
from timm.models import create_model, get_pretrained_cfg |
|
|
|
from timm import create_model |
|
|
|
|
|
|
|
try: |
|
|
|
from timm.models import get_pretrained_cfg |
|
|
|
except ImportError: |
|
|
|
from timm.models.registry import _model_default_cfgs |
|
|
|
def get_pretrained_cfg(model_name): |
|
|
|
return _model_default_cfgs[model_name] |
|
|
|
|
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger('isc_op') |
|
|
@ -80,12 +87,12 @@ class Isc(NNOperator): |
|
|
|
self.model = Model(self.timm_backbone, checkpoint_path, self.device) |
|
|
|
|
|
|
|
self.tfms = create_transform( |
|
|
|
input_size=img_size, |
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
mean=self.config['mean'], |
|
|
|
std=self.config['std'], |
|
|
|
crop_pct=self.config['crop_pct'] |
|
|
|
) |
|
|
|
input_size=img_size, |
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
mean=self.config['mean'], |
|
|
|
std=self.config['std'], |
|
|
|
crop_pct=self.config['crop_pct'] |
|
|
|
) |
|
|
|
|
|
|
|
def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']): |
|
|
|
if not isinstance(data, list): |
|
|
|