# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import warnings
from pathlib import Path

import numpy
import torch
from torch import nn
from PIL import Image as PILImage

from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config
from timm.models.factory import create_model

import towhee
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color

warnings.filterwarnings('ignore')

log = logging.getLogger()


@register(output_schema=['vec'])
class TimmImage(NNOperator):
    """
    PyTorch image embedding operator backed by the PyTorch Image Models (timm) collection.

    Args:
        model_name (`str`):
            Which timm model to use for the embeddings.
        num_classes (`int = 1000`):
            Number of classes for the classification head.
        skip_preprocess (`bool = False`):
            Whether to skip the timm image transforms.
    """

    def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None:
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = model_name
        # Build the pretrained timm model and put it into inference mode.
        self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
        self.model.to(self.device)
        self.model.eval()
        # Derive the model's expected input transforms (resize, crop, normalize).
        config = resolve_data_config({}, model=self.model)
        self.tfms = create_transform(**config)
        self.skip_tfms = skip_preprocess

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        if not self.skip_tfms:
            # Apply the timm transforms and add a batch dimension.
            img = self.tfms(img).unsqueeze(0)
        else:
            # Minimal fallback when preprocessing is skipped: convert the PIL image
            # to a batched float tensor so it can still be moved to the device.
            img = torch.from_numpy(numpy.asarray(img)).permute(2, 0, 1).unsqueeze(0).float()
        img = img.to(self.device)
        # Extract backbone features; 4-D (N, C, H, W) feature maps are pooled to a vector.
        features = self.model.forward_features(img)
        if features.dim() == 4:
            global_pool = nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
        features = features.to('cpu')
        vec = features.flatten().detach().numpy()
        return vec

    def save_model(self, jit: bool = True, destination: str = 'default'):
        """Save the model, either as TorchScript (`jit=True`) or as a pickled module."""
        if destination == 'default':
            path = str(Path(__file__).parent)
            destination = os.path.join(path, self.model_name + '.pt')
        if jit:
            try:
                scripted_model = torch.jit.script(self.model)
                torch.jit.save(scripted_model, destination)
            except Exception as e:
                raise RuntimeError(f'Failed to save model as TorchScript: {e}.')
        else:
            torch.save(self.model, destination)


# if __name__ == '__main__':
#     from towhee import ops
#
#     path = '/image/path/or/link'
#
#     decoder = ops.image_decode.cv2()
#     img = decoder(path)
#
#     op = TimmImage('resnet50')
#     out = op(img)
#     print(out)
#
#     op.model = torch.jit.load('resnet50.pt')
#     out2 = op(img)
#     print(out2)
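

# A minimal, hedged sketch of sanity-checking the TorchScript export on its own,
# independent of the Towhee pipeline shown above. The 'resnet50' model name, the
# (1, 3, 224, 224) input shape, and the destination path are illustrative
# assumptions; torch.jit.script can fail for some timm architectures, in which
# case save_model raises a RuntimeError instead of writing the file.
#
# if __name__ == '__main__':
#     op = TimmImage('resnet50')
#     op.save_model(jit=True)  # writes resnet50.pt next to this file by default
#
#     dest = str(Path(__file__).parent / 'resnet50.pt')
#     scripted = torch.jit.load(dest, map_location='cpu')  # reload the scripted model
#     dummy = torch.randn(1, 3, 224, 224)                  # assumed ImageNet-style input
#     with torch.no_grad():
#         logits = scripted(dummy)                         # full forward pass -> class logits
#     print(logits.shape)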