# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy
import os
from pathlib import Path

import towhee
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register

import torch
from torch import nn
from torchvision import transforms
from PIL import Image as PILImage

import warnings
warnings.filterwarnings('ignore')

log = logging.getLogger()


@register(output_schema=['vec'])
class Swag(NNOperator):
    """
    PyTorch image embedding operator that uses SWAG models (weakly-supervised
    pretrained vision models) loaded via torch.hub from facebookresearch/swag.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
        skip_preprocess (`bool = False`):
            Whether to skip image transforms.
    """
    def __init__(self, model_name: str, skip_preprocess: bool = False) -> None:
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.skip_tfms = skip_preprocess
        self.tfms = self.get_transforms(model_name)
        self.model_name = model_name
        self.model = torch.hub.load("facebookresearch/swag", model=model_name)
        self.model.to(self.device)
        self.model.head = None  # Drop the classification head to extract features only
        self.model.eval()

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: 'towhee.types.Image') -> numpy.ndarray:
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        if not self.skip_tfms:
            img = self.tfms(img).unsqueeze(0)
        img = img.to(self.device)
        features = self.model(img)
        if features.dim() == 4:
            # Convolutional feature maps (N, C, H, W): pool to one vector per image
            global_pool = nn.AdaptiveAvgPool2d(1)
            features = global_pool(features)
        features = features.to('cpu')
        vec = features.flatten().detach().numpy()
        return vec

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        if path == 'default':
            path = str(Path(__file__).parent)
            name = self.model_name.replace('/', '-')
            path = os.path.join(path, name)
        inputs = torch.ones(1, 3, 224, 224)  # dummy input used for tracing
        if format == 'pytorch':
            torch.save(self.model, path)
        elif format == 'torchscript':
            path = path + '.pt'
            try:
                try:
                    jit_model = torch.jit.script(self.model)
                except Exception:
                    jit_model = torch.jit.trace(self.model, inputs, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.')
        elif format == 'onnx':
            pass  # TODO: export to ONNX
        else:
            log.error(f'Save model: unsupported format "{format}".')

    @staticmethod
    def supported_model_names(format: str = None):
        full_list = [
            'vit_h14_in1k',
            'vit_l16_in1k',
            'vit_b16_in1k',
            'regnety_16gf_in1k',
            'regnety_32gf_in1k',
            'regnety_128gf_in1k',
        ]
        full_list.sort()
        if format is None:
            model_list = full_list
        elif format == 'pytorch':
            to_remove = []
            assert set(to_remove).issubset(set(full_list))
            model_list = list(set(full_list) - set(to_remove))
        else:
            # todo: format in {'torchscript', 'onnx', 'tensorrt'}
            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
            model_list = []
        return model_list

    @staticmethod
    def get_transforms(model_name):
        model_resolution = {
            'vit_h14_in1k': 518,
            'vit_l16_in1k': 512,
            'vit_b16_in1k': 384,
            'regnety_16gf_in1k': 384,
            'regnety_32gf_in1k': 384,
            'regnety_128gf_in1k': 384
        }
        if model_name not in model_resolution.keys():
            log.warning('No transforms specified for model "%s", using resolution 384.', model_name)
            resolution = 384
        else:
            resolution = model_resolution[model_name]
        transform = transforms.Compose([
            transforms.Resize(
                resolution,
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        return transform


if __name__ == '__main__':
    from towhee import ops

    path = '/Users/mengjiagu/Desktop/models/data/image/animals10/bird.jpg'
    decoder = ops.image_decode.cv2()
    img = decoder(path)
    # op = Swag('vit_b16_in1k')
    op = Swag('regnety_16gf_in1k')
    out = op(img)
    print(out.shape)
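
# A minimal usage sketch of the auxiliary methods defined above (kept as comments so the
# module stays importable; the model choice is illustrative only, and the embedding call
# itself is demonstrated in the __main__ block):
#
#     print(Swag.supported_model_names())   # list the bundled SWAG model names
#     op = Swag('vit_b16_in1k')             # downloads weights via torch.hub on first use
#     op.save_model(format='pytorch')       # torch.save() of the model next to this file
#     op.save_model(format='torchscript')   # scripted/traced model saved as "<model_name>.pt"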