diff --git a/isc.py b/isc.py index 52dcfc5..8cb869e 100644 --- a/isc.py +++ b/isc.py @@ -28,9 +28,9 @@ import sys import torch from torch import nn -from torchvision import transforms from PIL import Image as PILImage -import timm +from timm.data import create_transform +from timm.models import create_model, get_pretrained_cfg warnings.filterwarnings('ignore') log = logging.getLogger('isc_op') @@ -41,7 +41,7 @@ _ = sys.modules[__name__] class Model: def __init__(self, timm_backbone, checkpoint_path, device): self.device = device - self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) + self.backbone = create_model(timm_backbone, features_only=True, pretrained=False) self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, backbone=self.backbone, p=1.0, eval_p=1.0) self.model.eval() @@ -79,12 +79,13 @@ class Isc(NNOperator): self.model = Model(self.timm_backbone, checkpoint_path, self.device) - self.tfms = transforms.Compose([ - transforms.Resize((img_size, img_size)), - transforms.ToTensor(), - transforms.Normalize(mean=self.backbone.default_cfg['mean'], - std=self.backbone.default_cfg['std']) - ]) + self.tfms = create_transform( + input_size=img_size, + interpolation=self.config['interpolation'], + mean=self.config['mean'], + std=self.config['std'], + crop_pct=self.config['crop_pct'] + ) def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']): if not isinstance(data, list): @@ -113,9 +114,9 @@ class Isc(NNOperator): return self.model.model @property - def backbone(self): - backbone = timm.create_model(self.timm_backbone, features_only=True, pretrained=False) - return backbone + def config(self): + config = get_pretrained_cfg(self.timm_backbone) + return config def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default':