|
|
@ -28,9 +28,9 @@ import sys |
|
|
|
|
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
from torchvision import transforms |
|
|
|
from PIL import Image as PILImage |
|
|
|
import timm |
|
|
|
from timm.data import create_transform |
|
|
|
from timm.models import create_model, get_pretrained_cfg |
|
|
|
|
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger('isc_op') |
|
|
@ -41,7 +41,7 @@ _ = sys.modules[__name__] |
|
|
|
class Model: |
|
|
|
def __init__(self, timm_backbone, checkpoint_path, device): |
|
|
|
self.device = device |
|
|
|
self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) |
|
|
|
self.backbone = create_model(timm_backbone, features_only=True, pretrained=False) |
|
|
|
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, |
|
|
|
backbone=self.backbone, p=1.0, eval_p=1.0) |
|
|
|
self.model.eval() |
|
|
@ -79,12 +79,13 @@ class Isc(NNOperator): |
|
|
|
|
|
|
|
self.model = Model(self.timm_backbone, checkpoint_path, self.device) |
|
|
|
|
|
|
|
self.tfms = transforms.Compose([ |
|
|
|
transforms.Resize((img_size, img_size)), |
|
|
|
transforms.ToTensor(), |
|
|
|
transforms.Normalize(mean=self.backbone.default_cfg['mean'], |
|
|
|
std=self.backbone.default_cfg['std']) |
|
|
|
]) |
|
|
|
self.tfms = create_transform( |
|
|
|
input_size=img_size, |
|
|
|
interpolation=self.config['interpolation'], |
|
|
|
mean=self.config['mean'], |
|
|
|
std=self.config['std'], |
|
|
|
crop_pct=self.config['crop_pct'] |
|
|
|
) |
|
|
|
|
|
|
|
def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']): |
|
|
|
if not isinstance(data, list): |
|
|
@ -113,9 +114,9 @@ class Isc(NNOperator): |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
@property |
|
|
|
def backbone(self): |
|
|
|
backbone = timm.create_model(self.timm_backbone, features_only=True, pretrained=False) |
|
|
|
return backbone |
|
|
|
def config(self): |
|
|
|
config = get_pretrained_cfg(self.timm_backbone) |
|
|
|
return config |
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|