logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

71 lines
2.3 KiB

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
from torchvision import transforms as T
from towhee.operator import NNOperator
from towhee.models import clip
class Model:
    """
    Callable wrapper around the visual tower of a towhee CLIP model.

    Args:
        model_name (`str`):
            Name forwarded to `clip.create_model`.
        device:
            Device forwarded to `clip.create_model`, defaults to 'cpu'.
    """

    def __init__(self, model_name, device='cpu'):
        # Build the pretrained CLIP model, keep only its `.visual` sub-model and
        # switch that sub-model to evaluation mode.
        full_clip = clip.create_model(model_name=model_name, pretrained=True, device=device)
        visual_tower = full_clip.visual
        visual_tower.eval()
        self.model = visual_tower

    def __call__(self, data: 'Tensor'):
        """Run the stored visual sub-model on `data` and return its output unchanged."""
        encoder = self.model
        return encoder(data)
class ClipVision(NNOperator):
    """
    Towhee operator that embeds a single image with the visual encoder of CLIP.

    Args:
        model_name (`str`):
            CLIP model name passed to `towhee.models.clip.create_model`,
            defaults to 'clip_vit_b32'.
    """
    def __init__(self, model_name='clip_vit_b32'):
        super().__init__()
        # Tensor-based preprocessing: bicubic resize to 224, center crop to
        # 224x224, convert to float (integer images are rescaled by
        # ConvertImageDtype) and normalize with the per-channel mean / std below.
        self.tfms = torch.nn.Sequential(
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ).to(self.device)
        self.model = Model(model_name, self.device)

    @property
    def device(self):
        """
        Device used for preprocessing and inference.

        Returns 'cpu' when `self._device_id` is negative, otherwise returns
        `self._device_id` itself.
        NOTE(review): `_device_id` is not assigned in this class; it is assumed
        to be provided by the `NNOperator` base class — confirm.
        """
        if self._device_id < 0:
            return 'cpu'
        else:
            return self._device_id

    def __call__(self, image: 'Image'):
        """
        Compute the CLIP image embedding of one image.

        Args:
            image:
                Array laid out as (height, width, channels), as implied by the
                [2, 0, 1] transpose below. Presumably 3-channel uint8 RGB —
                TODO confirm against callers.

        Returns:
            A 1-D numpy array holding the flattened embedding.
        """
        # HWC -> CHW, then add a batch dimension: the visual encoder takes
        # (batch, channels, height, width), matching the dummy input used in
        # save_model.
        img = np.transpose(image, [2, 0, 1])
        data = torch.from_numpy(img).unsqueeze(0).to(self.device)
        # The transforms operate on tensors, so feed the device tensor
        # (previously the numpy array was passed and `data` went unused).
        image_tensor = self.tfms(data)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            features = self.model(image_tensor)
        return features.detach().cpu().numpy().flatten()

    def save_model(self, model_type, output_file, args=None):
        """
        Export the visual encoder.

        Args:
            model_type (`str`):
                Target format; only 'onnx' is handled.
            output_file:
                Destination passed to `torch.onnx.export`.
            args:
                Unused; kept for interface compatibility.

        Returns:
            False if `model_type` is not 'onnx', True after a successful export.
        """
        if model_type != 'onnx':
            return False
        # torch.onnx.export requires a torch.nn.Module, so export the wrapped
        # visual module rather than the plain `Model` wrapper.
        visual = self.model.model
        # Create the dummy input on the same device as the module's weights so
        # export also works when the model lives on a GPU.
        model_device = next(visual.parameters()).device
        x = torch.randn((1, 3, 224, 224), device=model_device)
        # Batch axis is dynamic on both input and output.
        torch.onnx.export(visual, x, output_file, input_names=['INPUT0'],
                          output_names=['OUTPUT0'],
                          dynamic_axes={'INPUT0': [0], 'OUTPUT0': [0]})
        return True

    @property
    def supported_formats(self):
        """Formats accepted by `save_model`."""
        return ['onnx']