# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Union, List
from pathlib import Path

import towhee
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from towhee.models import isc
# from towhee.dc2 import accelerate

import torch
from torch import nn
from torchvision import transforms
from PIL import Image as PILImage

import timm
import warnings

warnings.filterwarnings('ignore')

log = logging.getLogger('isc_op')


# @accelerate
class Model:
    """Thin wrapper that builds the ISC model on the configured device and runs it in eval mode."""

    def __init__(self, timm_backbone, checkpoint_path, device):
        self.device = device
        self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
        self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
                                      backbone=self.backbone, p=1.0, eval_p=1.0)
        self.model.eval()

    def __call__(self, x):
        x = x.to(self.device)
        return self.model(x)


@register(output_schema=['vec'])
class Isc(NNOperator):
    """
    The operator uses a pretrained ISC model to extract features for an image input.

    Args:
        timm_backbone (`str = 'tf_efficientnetv2_m_in21ft1k'`):
            Name of the timm backbone used by the ISC model.
        img_size (`int = 512`):
            Height and width images are resized to before inference.
        checkpoint_path (`str = None`):
            Path to the model checkpoint; defaults to `checkpoints/<timm_backbone>.pth`
            next to this file.
        skip_preprocess (`bool = False`):
            Whether to skip image transforms.
        device (`str = None`):
            Device to run on; defaults to 'cuda' if available, otherwise 'cpu'.
    """

    def __init__(self,
                 timm_backbone: str = 'tf_efficientnetv2_m_in21ft1k',
                 img_size: int = 512,
                 checkpoint_path: str = None,
                 skip_preprocess: bool = False,
                 device: str = None) -> None:
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.skip_tfms = skip_preprocess
        self.timm_backbone = timm_backbone
        if checkpoint_path is None:
            checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')
        self.model = Model(self.timm_backbone, checkpoint_path, self.device)
        # Normalization statistics come from the timm backbone's default config.
        self.tfms = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.backbone.default_cfg['mean'],
                                 std=self.backbone.default_cfg['std'])
        ])

    def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
        # Accept either a single image or a list of images and batch them.
        if isinstance(data, towhee._types.Image):
            imgs = [data]
        else:
            imgs = data
        img_list = []
        for img in imgs:
            img = self.convert_img(img)
            img = img if self.skip_tfms else self.tfms(img)
            img_list.append(img)
        inputs = torch.stack(img_list)
        inputs = inputs.to(self.device)
        features = self.model(inputs)
        features = features.to('cpu')
        # Return a list of vectors for list input, a single vector otherwise.
        if isinstance(data, list):
            vecs = list(features.detach().numpy())
        else:
            vecs = features.squeeze(0).detach().numpy()
        return vecs

    @property
    def _model(self):
        return self.model.model

    @property
    def backbone(self):
        backbone = timm.create_model(self.timm_backbone, features_only=True, pretrained=False)
        return backbone

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        if path == 'default':
            path = str(Path(__file__).parent)
        path = os.path.join(path, 'saved', format)
        os.makedirs(path, exist_ok=True)
        name = self.timm_backbone.replace('/', '-')
        path = os.path.join(path, name)
        if format in ['pytorch', 'torchscript']:
            path = path + '.pt'
        elif format == 'onnx':
            path = path + '.onnx'
        else:
            raise ValueError(f'Invalid format {format}.')
        dummy_input = torch.rand(1, 3, 224, 224)
        if format == 'pytorch':
            torch.save(self._model, path)
        elif format == 'torchscript':
            try:
                try:
                    jit_model = torch.jit.script(self._model)
                except Exception:
                    # Fall back to tracing if scripting fails.
                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.')
        elif format == 'onnx':
            try:
                torch.onnx.export(self._model,
                                  dummy_input,
                                  path,
                                  input_names=['input_0'],
                                  output_names=['output_0'],
                                  opset_version=14,
                                  dynamic_axes={
                                      'input_0': {0: 'batch_size', 2: 'height', 3: 'width'},
                                      'output_0': {0: 'batch_size', 1: 'dim'}
                                  },
                                  do_constant_folding=True
                                  )
            except Exception as e:
                log.error(f'Fail to save as onnx: {e}.')
                raise RuntimeError(f'Fail to save as onnx: {e}.')
        # todo: elif format == 'tensorrt':
        else:
            log.error(f'Unsupported format "{format}".')
        return path

    @arg(1, to_image_color('RGB'))
    def convert_img(self, img: towhee._types.Image):
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        return img

    @property
    def supported_formats(self):
        return ['onnx']

    def train(self, training_config=None, train_dataset=None, eval_dataset=None,
              resume_checkpoint_path=None, **kwargs):
        from .train_isc import train_isc
        training_args = kwargs.pop('training_args', None)
        train_isc(self._model, training_args)


# if __name__ == '__main__':
#     from towhee import ops
#
#     path = 'https://github.com/towhee-io/towhee/raw/main/towhee_logo.png'
#
#     decoder = ops.image_decode.cv2()
#     img = decoder(path)
#
#     op = Isc()
#     out = op(img)
#     assert out.shape == (256,)
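
# A minimal sketch (not part of the operator) of exporting the model to ONNX via
# save_model() and running the result with onnxruntime. It assumes onnxruntime and
# numpy are installed; the names 'input_0'/'output_0' come from the export call above,
# the 512x512 shape matches the default img_size, and the 256-d output follows from
# the assert in the example above.
#
# import numpy as np
# import onnxruntime
#
# op = Isc()
# onnx_path = op.save_model(format='onnx')
#
# sess = onnxruntime.InferenceSession(onnx_path)
# dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
# vec = sess.run(['output_0'], {'input_0': dummy})[0]
# print(vec.shape)  # expected (1, 256) for the default backbone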