# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import List, Union
from torch import nn
from towhee.types.arg import arg, to_image_color
from timm.data import resolve_data_config, create_transform
from towhee.models import mpvit
from towhee.operator.base import NNOperator
from towhee import register
import towhee
from PIL import Image as PILImage


@register(output_schema=['vec'])
class MPViT(NNOperator):
    """
    MPViT embedding operator.

    Turns one image, or a list of images, into pooled feature vector(s)
    taken from the MPViT model's ``forward_features`` output.

    Args:
        model_name: Name forwarded to ``mpvit.create_model``.
        num_classes (`int`): Forwarded to ``mpvit.create_model``.
        weights_path (`str`): Forwarded to ``mpvit.create_model``.
        device (`str`): Torch device string. When ``None``, ``'cuda'`` is
            used if ``torch.cuda.is_available()``, otherwise ``'cpu'``.
        skip_preprocess (`bool`): When true, the timm transform built in
            ``__init__`` is not applied to the images.
    """

    def __init__(self,
                 model_name,
                 num_classes: int = 1000,
                 weights_path: str = None,
                 device: str = None,
                 skip_preprocess: bool = False):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = mpvit.create_model(model_name=model_name,
                                        num_classes=num_classes,
                                        pretrained=True,
                                        weights_path=weights_path,
                                        device=device)
        self.model.eval()
        # Build the preprocessing pipeline from the model's data config.
        self.config = resolve_data_config({}, model=self.model)
        self.tfms = create_transform(**self.config)
        self.skip_tfms = skip_preprocess
        self.device = device
        # The pooling layer is stateless, so build it once instead of on
        # every call.
        self._global_pool = nn.AdaptiveAvgPool2d(1)

    def __call__(self, data: Union[List['towhee.types.Image'], 'towhee.types.Image']):
        """
        Compute the embedding(s) for ``data``.

        Args:
            data: A single image or a list of images.

        Returns:
            For a list input, a list with one numpy vector per image (an
            empty list for an empty input). For a single image, one numpy
            array.
        """
        is_batch = isinstance(data, list)
        imgs = data if is_batch else [data]
        if not imgs:
            # torch.stack() raises on an empty sequence; nothing to embed.
            return []

        img_list = []
        for img in imgs:
            img = self.convert_img(img)
            # NOTE(review): with skip_preprocess=True the PIL image from
            # convert_img is passed to torch.stack unchanged -- confirm the
            # intended contract for that mode.
            img_list.append(img if self.skip_tfms else self.tfms(img))

        inputs = torch.stack(img_list).to(self.device)

        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            features = self.model.forward_features(inputs)
            # presumably (batch, channels, H, W) feature maps -- TODO confirm
            features = self._global_pool(features)

        features = features.detach().to('cpu').flatten(1)

        if is_batch:
            return list(features.numpy())
        return features.squeeze(0).numpy()

    @arg(1, to_image_color('RGB'))
    def convert_img(self, img: 'towhee.types.Image'):
        """
        Build a PIL RGB image from ``img`` after casting it to ``uint8``.

        The ``arg`` decorator is configured with ``to_image_color('RGB')``
        for argument 1; presumably it converts the input to RGB first.
        """
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        return img