From fa4a0e42f07712e0a6b29bca5c89d194b766b82b Mon Sep 17 00:00:00 2001
From: wxywb
Date: Wed, 19 Oct 2022 20:02:42 +0800
Subject: [PATCH] update the ru-clip.

Signed-off-by: wxywb
---
 README.md  | 67 ++++++++++++++++++++++++++++++++++++++++++++++-
 ru_clip.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 39fd745..ac94256 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,67 @@
-# ru-clip
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+@register(output_schema=['vec'])
+class Clip(NNOperator):
+    """
+    CLIP multi-modal embedding operator
+    """
+    def __init__(self, model_name: str, modality: str):
+        self.modality = modality
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
+        self.tokenize = clip.tokenize
+        self.tfms = transforms.Compose([
+            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
+    def inference_single_data(self, data):
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        return vec.detach().cpu().numpy().flatten()
+
+    def __call__(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        results = []
+        for single_data in data:
+            result = self.inference_single_data(single_data)
+            results.append(result)
+        if len(data) == 1:
+            return results[0]
+        return results
+
+    def _inference_from_text(self, text):
+        text = self.tokenize(text).to(self.device)
+        text_features = self.model.encode_text(text)
+        return text_features
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        img = to_pil(img)
+        image = self.tfms(img).unsqueeze(0).to(self.device)
+        image_features = self.model.encode_image(image)
+        return image_features
diff --git a/ru_clip.py b/ru_clip.py
index e69de29..941c807 100644
--- a/ru_clip.py
+++ b/ru_clip.py
@@ -0,0 +1,77 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from pathlib import Path
+
+import torch
+
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator
+from towhee.types.arg import arg, to_image_color
+from towhee import register
+
+
+@register(output_schema=['vec'])
+class RuClip(NNOperator):
+    """
+    Russian CLIP multi-modal embedding operator
+    """
+    def __init__(self, model_name: str, modality: str):
+        super().__init__()
+        self.modality = modality
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Make the bundled ruclip package importable from the operator directory.
+        path = str(Path(__file__).parent)
+        sys.path.append(path)
+        import ruclip
+        sys.path.pop()
+
+        # Russian prompt templates used by the predictor
+        # ('{}', 'this is {}', 'in the picture {}', 'this is {}, a pet').
+        templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']
+        clip, processor = ruclip.load(model_name, device=self.device)
+        self.predictor = ruclip.Predictor(clip, processor, self.device, bs=1, templates=templates)
+
+    def inference_single_data(self, data):
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        return vec.detach().cpu().numpy().flatten()
+
+    def __call__(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        results = []
+        for single_data in data:
+            result = self.inference_single_data(single_data)
+            results.append(result)
+        if len(data) == 1:
+            return results[0]
+        return results
+
+    def _inference_from_text(self, text):
+        text_features = self.predictor.get_text_latents([text])
+        return text_features
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        img = to_pil(img)
+        image_features = self.predictor.get_image_latents([img])
+        return image_features
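
For reference, a minimal usage sketch of the RuClip operator added above. The direct import, the model name, and the sample caption are illustrative assumptions rather than part of the repository; in a real pipeline the operator would normally be resolved through towhee's operator registry.

    from ru_clip import RuClip

    # Text modality: embed a Russian caption into the shared CLIP space.
    # The model name is assumed to be one accepted by ruclip.load().
    text_op = RuClip(model_name='ruclip-vit-base-patch32-384', modality='text')
    text_vec = text_op('кошка сидит на окне')  # "a cat is sitting on the window"
    print(text_vec.shape)  # 1-D numpy embedding

    # Image modality expects a decoded image; the @arg decorator converts it to RGB
    # and _inference_from_image turns it into a PIL image for the ruclip predictor.
    image_op = RuClip(model_name='ruclip-vit-base-patch32-384', modality='image')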