logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

82 lines
2.6 KiB

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
from torchvision import transforms as T
2 years ago
from towhee.operator import NNOperator, SharedType
from towhee.dc2 import accelerate
from towhee.models import clip
@accelerate
class Model:
    """Callable wrapper around the ``visual`` branch of a pretrained CLIP model."""

    def __init__(self, model_name, device='cpu'):
        """Create the pretrained CLIP model and keep only its visual branch.

        Args:
            model_name: Model name forwarded to ``clip.create_model``.
            device: Device forwarded to ``clip.create_model``.
        """
        clip_model = clip.create_model(
            model_name=model_name,
            pretrained=True,
            device=device,
        )
        self.model = clip_model.visual
        # Inference only: put the visual branch in eval mode.
        self.model.eval()
        print('Create local model')

    def __call__(self, data: 'Tensor'):
        """Run the visual branch on ``data`` and return its output."""
        output = self.model(data)
        return output
class ClipVision(NNOperator):
    """Operator that turns one image into a flat CLIP visual feature vector.

    The image is preprocessed with a fixed torchvision pipeline (resize,
    center-crop, float conversion, normalization) and passed through ``Model``,
    which wraps the visual branch of the CLIP model named by ``model_name``.
    """

    def __init__(self, model_name='clip_vit_b32'):
        """Build the preprocessing pipeline and the visual model.

        Args:
            model_name: Model name forwarded to ``Model``.
        """
        super().__init__()
        # Resolved lazily by the ``device`` property.
        self._device = None
        self.tfms = torch.nn.Sequential(
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            # Per-channel mean and std; presumably the CLIP preprocessing
            # statistics -- TODO confirm against the model's reference values.
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ).to(self.device)
        self.model = Model(model_name, self.device)

    @property
    def device(self):
        """Return the ``torch.device`` to run on, resolving it on first access.

        ``self._device_id`` is not set in this class; presumably it is provided
        by ``NNOperator`` -- TODO confirm. A negative id selects the CPU,
        otherwise the id is handed to ``torch.device`` as a device ordinal.
        """
        if self._device is None:
            if self._device_id < 0:
                self._device = torch.device('cpu')
            else:
                self._device = torch.device(self._device_id)
        return self._device

    def __call__(self, image: 'Image'):
        """Compute the feature vector of a single image.

        Args:
            image: Array-like image accepted by ``np.transpose``; assumed to be
                laid out as (height, width, channels) -- TODO confirm.

        Returns:
            A 1-D ``numpy.ndarray`` holding the flattened model output.
        """
        # Move the last axis first so the torchvision transforms see
        # a channels-first tensor.
        img = np.transpose(image, [2, 0, 1])
        data = torch.from_numpy(img).to(self.device)
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            image_tensor = self.tfms(data).unsqueeze(0)
            features = self.model(image_tensor)
        return features.detach().cpu().numpy().flatten()

    def save_model(self, model_type, output_file, args=None):
        """Export the wrapped visual model to ``output_file``.

        Args:
            model_type: Target format; only ``'onnx'`` is handled.
            output_file: Destination passed to ``torch.onnx.export``.
            args: Unused; kept for interface compatibility.

        Returns:
            ``True`` after an ONNX export, ``False`` for any other format.
        """
        if model_type != 'onnx':
            return False
        # The model was created with ``self.device``; the dummy input has to
        # live on the same device or tracing fails with a device mismatch.
        x = torch.randn((1, 3, 224, 224), device=self.device)
        torch.onnx.export(
            self.model.model,
            x,
            output_file,
            input_names=['INPUT0'],
            output_names=['OUTPUT0'],
            # Axis 0 of the input is exported as a dynamic axis.
            dynamic_axes={'INPUT0': [0]},
        )
        return True

    @property
    def supported_formats(self):
        """List the export formats accepted by ``save_model``."""
        return ['onnx']

    @property
    def shared_type(self):
        """Report this operator as ``SharedType.Shareable``."""
        return SharedType.Shareable