logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
39c24e6e61
  1. 25
      isc.py

25
isc.py

@ -28,9 +28,9 @@ import sys
import torch
from torch import nn
from torchvision import transforms
from PIL import Image as PILImage
import timm
from timm.data import create_transform
from timm.models import create_model, get_pretrained_cfg
warnings.filterwarnings('ignore')
log = logging.getLogger('isc_op')
@ -41,7 +41,7 @@ _ = sys.modules[__name__]
class Model:
def __init__(self, timm_backbone, checkpoint_path, device):
self.device = device
self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
self.backbone = create_model(timm_backbone, features_only=True, pretrained=False)
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
backbone=self.backbone, p=1.0, eval_p=1.0)
self.model.eval()
@ -79,12 +79,13 @@ class Isc(NNOperator):
self.model = Model(self.timm_backbone, checkpoint_path, self.device)
self.tfms = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=self.backbone.default_cfg['mean'],
std=self.backbone.default_cfg['std'])
])
self.tfms = create_transform(
input_size=img_size,
interpolation=self.config['interpolation'],
mean=self.config['mean'],
std=self.config['std'],
crop_pct=self.config['crop_pct']
)
def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']):
if not isinstance(data, list):
@ -113,9 +114,9 @@ class Isc(NNOperator):
return self.model.model
@property
def backbone(self):
backbone = timm.create_model(self.timm_backbone, features_only=True, pretrained=False)
return backbone
def config(self):
config = get_pretrained_cfg(self.timm_backbone)
return config
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':

Loading…
Cancel
Save