From e582a9204cfa93c3237f1b300c16eee9afd0a99a Mon Sep 17 00:00:00 2001
From: "junjie.jiang"
Date: Wed, 7 Dec 2022 11:09:54 +0800
Subject: [PATCH] update

Signed-off-by: junjie.jiang
---
 README.md        |  2 +-
 __init__.py      | 19 +++++++++++++
 clip_vision.py   | 71 ++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  1 +
 4 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 __init__.py
 create mode 100644 clip_vision.py
 create mode 100644 requirements.txt

diff --git a/README.md b/README.md
index 8957d6b..b1e5d94 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,2 @@
-# clip-vision
+# clip_vision
 
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..25f8e0d
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .clip_vision import ClipVision
+
+def clip_vision(**kwargs):
+    return ClipVision(**kwargs)
diff --git a/clip_vision.py b/clip_vision.py
new file mode 100644
index 0000000..54961fe
--- /dev/null
+++ b/clip_vision.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+from torchvision import transforms as T
+
+from towhee.operator import NNOperator
+from towhee.models import clip
+
+
+class Model:
+    def __init__(self, model_name, device='cpu'):
+        self.model = clip.create_model(model_name=model_name, pretrained=True, device=device).visual
+        self.model.eval()
+
+    def __call__(self, data: 'Tensor'):
+        return self.model(data)
+
+
+class ClipVision(NNOperator):
+    def __init__(self, model_name='clip_vit_b32'):
+        super().__init__()
+
+        self.tfms = torch.nn.Sequential(
+            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
+            T.CenterCrop(224),
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(
+                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ).to(self.device)
+        self.model = Model(model_name, self.device)
+
+    @property
+    def device(self):
+        if self._device_id < 0:
+            return 'cpu'
+        else:
+            return 'cuda:%d' % self._device_id
+
+    def __call__(self, image: 'Image'):
+        img = np.transpose(image, [2, 0, 1])
+        data = torch.from_numpy(img)
+        data = data.to(self.device)
+        image_tensor = self.tfms(data).unsqueeze(0)
+        features = self.model(image_tensor)
+        return features.detach().cpu().numpy().flatten()
+
+    def save_model(self, model_type, output_file, args=None):
+        if model_type != 'onnx':
+            return False
+        x = torch.randn((1, 3, 224, 224), device=self.device)
+
+        torch.onnx.export(self.model.model, x, output_file, input_names=['INPUT0'],
+                          output_names=['OUTPUT0'], dynamic_axes={'INPUT0': [0]})
+        return True
+
+    @property
+    def supported_formats(self):
+        return ['onnx']
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..391b72e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+onnxruntime
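
A minimal usage sketch, assuming towhee, torch, and torchvision are installed
and the new file is importable as clip_vision; the random array stands in for
any decoded HWC uint8 RGB image:

    import numpy as np
    from clip_vision import ClipVision

    # Build the operator; 'clip_vit_b32' is the default model in this patch.
    op = ClipVision(model_name='clip_vit_b32')

    # Stand-in for a decoded RGB image in HWC uint8 layout.
    image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

    # __call__ transposes to CHW, preprocesses, and returns a flat embedding.
    embedding = op(image)
    print(embedding.shape)  # expected (512,) for clip_vit_b32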
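The ONNX path can be smoke-tested with the onnxruntime dependency added in
requirements.txt; the output file name below is only illustrative:

    import numpy as np
    import onnxruntime
    from clip_vision import ClipVision

    op = ClipVision()
    assert op.save_model('onnx', 'clip_vision.onnx')

    # INPUT0/OUTPUT0 match the names passed to torch.onnx.export;
    # the first axis is dynamic, so any batch size works.
    sess = onnxruntime.InferenceSession('clip_vision.onnx',
                                        providers=['CPUExecutionProvider'])
    x = np.random.randn(1, 3, 224, 224).astype(np.float32)
    out = sess.run(['OUTPUT0'], {'INPUT0': x})[0]
    print(out.shape)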