# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
from torchvision import transforms as T

from towhee.operator import NNOperator, SharedType
from towhee.dc2 import accelerate
from towhee.models import clip


@accelerate
class Model:
    """Callable wrapper around the visual tower of a pretrained CLIP model.

    Args:
        model_name: Name passed to ``clip.create_model``.
        device: Device passed to ``clip.create_model``; defaults to ``'cpu'``.
    """

    def __init__(self, model_name, device='cpu'):
        # Only the `.visual` sub-module of the created CLIP model is kept.
        self.model = clip.create_model(model_name=model_name, pretrained=True, device=device).visual
        self.model.eval()
        print('Create local model')

    def __call__(self, data: 'Tensor'):
        """Run the visual model on ``data`` and return its raw output."""
        return self.model(data)


class ClipVision(NNOperator):
    """Operator that turns one image into a flat CLIP visual feature vector.

    Args:
        model_name: CLIP model name forwarded to ``Model``; defaults to
            ``'clip_vit_b32'``.
    """

    def __init__(self, model_name='clip_vit_b32'):
        super().__init__()
        self._device = None
        # Preprocessing pipeline: resize (bicubic) -> 224 center crop ->
        # float conversion -> per-channel normalisation.
        self.tfms = torch.nn.Sequential(
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
        ).to(self.device)
        self.model = Model(model_name, self.device)

    @property
    def device(self):
        """Return the torch device for this operator, resolving it on first use.

        A negative ``self._device_id`` selects the CPU; any other value is
        handed to ``torch.device``. The result is cached in ``self._device``.
        """
        # NOTE(review): `_device_id` is not assigned in this file; presumably
        # it is provided by the base class / framework -- confirm.
        if self._device is None:
            if self._device_id < 0:
                self._device = torch.device('cpu')
            else:
                self._device = torch.device(self._device_id)
        return self._device

    def __call__(self, image: 'Image'):
        """Compute the feature vector for a single image.

        Args:
            image: Three-axis array-like image. Axis 2 is moved to the front
                before preprocessing (presumably HWC -> CHW; confirm against
                the caller's image type).

        Returns:
            A flattened ``numpy.ndarray`` holding the model output.
        """
        img = np.transpose(image, [2, 0, 1])
        data = torch.from_numpy(img).to(self.device)
        # Inference only: skip autograd bookkeeping for the preprocessing and
        # the forward pass. Output values are unaffected.
        with torch.no_grad():
            image_tensor = self.tfms(data).unsqueeze(0)
            features = self.model(image_tensor)
        return features.detach().cpu().numpy().flatten()

    def save_model(self, model_type, output_file, args=None):
        """Export the wrapped visual model to ``output_file``.

        Args:
            model_type: Target format; only ``'onnx'`` is handled.
            output_file: Destination passed to ``torch.onnx.export``.
            args: Unused; kept for interface compatibility.

        Returns:
            ``True`` after an ONNX export, ``False`` for any other
            ``model_type``.
        """
        if model_type != 'onnx':
            return False
        # The dummy input must live on the same device as the model weights,
        # otherwise tracing fails with a device mismatch on non-CPU devices.
        x = torch.randn((1, 3, 224, 224), device=self.device)
        torch.onnx.export(self.model.model, x, output_file,
                          input_names=['INPUT0'],
                          output_names=['OUTPUT0'],
                          dynamic_axes={'INPUT0': [0]})
        return True

    @property
    def supported_formats(self):
        """List of export formats accepted by ``save_model``."""
        return ['onnx']

    @property
    def shared_type(self):
        """Sharing mode reported for this operator (``SharedType.Shareable``)."""
        return SharedType.Shareable