logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
39c24e6e61
  1. 25
      isc.py

25
isc.py

@ -28,9 +28,9 @@ import sys
import torch import torch
from torch import nn from torch import nn
from torchvision import transforms
from PIL import Image as PILImage from PIL import Image as PILImage
import timm
from timm.data import create_transform
from timm.models import create_model, get_pretrained_cfg
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
log = logging.getLogger('isc_op') log = logging.getLogger('isc_op')
@ -41,7 +41,7 @@ _ = sys.modules[__name__]
class Model: class Model:
def __init__(self, timm_backbone, checkpoint_path, device): def __init__(self, timm_backbone, checkpoint_path, device):
self.device = device self.device = device
self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
self.backbone = create_model(timm_backbone, features_only=True, pretrained=False)
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
backbone=self.backbone, p=1.0, eval_p=1.0) backbone=self.backbone, p=1.0, eval_p=1.0)
self.model.eval() self.model.eval()
@ -79,12 +79,13 @@ class Isc(NNOperator):
self.model = Model(self.timm_backbone, checkpoint_path, self.device) self.model = Model(self.timm_backbone, checkpoint_path, self.device)
self.tfms = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=self.backbone.default_cfg['mean'],
std=self.backbone.default_cfg['std'])
])
self.tfms = create_transform(
input_size=img_size,
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']): def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']):
if not isinstance(data, list): if not isinstance(data, list):
@ -113,9 +114,9 @@ class Isc(NNOperator):
return self.model.model return self.model.model
@property @property
def backbone(self):
backbone = timm.create_model(self.timm_backbone, features_only=True, pretrained=False)
return backbone
def config(self):
config = get_pretrained_cfg(self.timm_backbone)
return config
def save_model(self, format: str = 'pytorch', path: str = 'default'): def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default': if path == 'default':

Loading…
Cancel
Save