diff --git a/README.md b/README.md index ac94256..57a9adb 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,110 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -@register(output_schema=['vec']) -class Clip(NNOperator): - """ - CLIP multi-modal embedding operator - """ - def __init__(self, model_name: str, modality: str): - self.modality = modality - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True) - self.tokenize = clip.tokenize - self.tfms = transforms.Compose([ - transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ]) - - def inference_single_data(self, data): - if self.modality == 'image': - vec = self._inference_from_image(data) - elif self.modality == 'text': - vec = self._inference_from_text(data) - else: - raise ValueError("modality[{}] not implemented.".format(self._modality)) - return vec.detach().cpu().numpy().flatten() - - def __call__(self, data): - if not isinstance(data, list): - data = [data] - else: - data = data - results = [] - for single_data in data: - result = self.inference_single_data(single_data) - results.append(result) - if len(data) == 1: - return results[0] - else: - return results - - def _inference_from_text(self, text): - text = self.tokenize(text).to(self.device) - text_features = self.model.encode_text(text) - return text_features - - @arg(1, to_image_color('RGB')) - def _inference_from_image(self, img): - img = to_pil(img) - image = self.tfms(img).unsqueeze(0).to(self.device) - image_features = self.model.encode_image(image) - return image_features +# Russian Image-Text Retrieval Embedding with CLIP + +*author: David Wang* + + +
+ + + +## Description + +This operator extracts features from images or text with [CLIP](https://arxiv.org/abs/2103.00020), which generates embeddings for text and images by jointly training an image encoder and a text encoder to maximize the cosine similarity of matching pairs. This is a Russian version of CLIP adapted from [ai-forever/ru-clip](https://github.com/ai-forever/ru-clip). + + + +
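Because the two encoders are trained so that a matching image-text pair gets a high cosine similarity, the embeddings returned by this operator can be compared directly. Below is a minimal sketch with plain numpy (the variable names are illustrative, not part of the operator API):

```python
import numpy as np

def cosine_similarity(img_vec: np.ndarray, txt_vec: np.ndarray) -> float:
    # normalize both embeddings to unit length, then take the dot product
    img_vec = img_vec / np.linalg.norm(img_vec)
    txt_vec = txt_vec / np.linalg.norm(txt_vec)
    return float(np.dot(img_vec, txt_vec))

# img_vec and txt_vec would be the flattened ndarrays produced by this operator;
# a higher score means the Russian caption describes the image better.
```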
+ + +## Code Example + +Load an image from path './teddy.jpg' to generate an image embedding. + +Read the text 'Плюшевый мишка на скейтборде на Таймс-сквер.' ('A teddy bear on a skateboard in Times Square.') to generate a text embedding. + +*Write the pipeline in simplified style*: + +```python +import towhee + +towhee.glob('./teddy.jpg') \ + .image_decode() \ + .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='image') \ + .show() + +towhee.dc(["Плюшевый мишка на скейтборде на Таймс-сквер."]) \ + .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='text') \ + .show() +``` +<img src="./vec1.png" alt="result1" /> +<img src="./vec2.png" alt="result2" /> + +*Write the same pipeline with explicit input/output name specifications:* + +```python +import towhee + +towhee.glob['path']('./teddy.jpg') \ + .image_decode['path', 'img']() \ + .image_text_embedding.ru_clip['img', 'vec'](model_name='ruclip-vit-base-patch32-224', modality='image') \ + .select['img', 'vec']() \ + .show() + +towhee.dc['text'](["Плюшевый мишка на скейтборде на Таймс-сквер."]) \ + .image_text_embedding.ru_clip['text', 'vec'](model_name='ruclip-vit-base-patch32-224', modality='text') \ + .select['text', 'vec']() \ + .show() +``` +<img src="./tabular1.png" alt="result1" /> +<img src="./tabular2.png" alt="result2" /> + + +
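To consume the vectors instead of only displaying them, the same pipelines can end in `to_list()` rather than `show()`. This is a sketch under the assumption that the DataCollection `to_list()` call behaves as in other towhee operator examples:

```python
import numpy as np
import towhee

# image embedding from the decoded image
img_vec = towhee.glob('./teddy.jpg') \
    .image_decode() \
    .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='image') \
    .to_list()[0]

# text embedding for the Russian caption
txt_vec = towhee.dc(['Плюшевый мишка на скейтборде на Таймс-сквер.']) \
    .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='text') \
    .to_list()[0]

# cosine similarity between the two embeddings; higher means a better image-text match
score = float(np.dot(img_vec, txt_vec) / (np.linalg.norm(img_vec) * np.linalg.norm(txt_vec)))
print(score)
```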
+ + + +## Factory Constructor + +Create the operator via the following factory method: + +***ru_clip(model_name, modality)*** + +**Parameters:** + +***model_name:*** *str* + +The model name of CLIP. Supported model names: +- ruclip-vit-base-patch32-224 +- ruclip-vit-base-patch16-224 +- ruclip-vit-large-patch14-224 +- ruclip-vit-large-patch14-336 +- ruclip-vit-base-patch32-384 +- ruclip-vit-base-patch16-384 + + +***modality:*** *str* + +Which modality (*image* or *text*) is used to generate the embedding. + +
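Outside a DataCollection pipeline, the factory can also be reached through `towhee.ops`. The snippet below is a sketch assuming the standard towhee hub operator resolution; the model name is one of those listed above:

```python
from towhee import ops

# build a text-modality operator from the factory
op = ops.image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='text')

vec = op('Плюшевый мишка на скейтборде на Таймс-сквер.')
print(vec.shape)  # a flattened 1-D numpy.ndarray
```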
+ + + +## Interface + +An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generates an embedding in ndarray. + + +**Parameters:** + +***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str* + +The data (image or text, depending on the specified modality) used to generate the embedding. + + + +**Returns:** *numpy.ndarray* + +The data embedding extracted by the model. + diff --git a/__init__.py b/__init__.py index e69de29..fd7985e 100644 --- a/__init__.py +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ru_clip import RuClip + + +def ru_clip(model_name: str, modality: str): + return RuClip(model_name, modality) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bf48659 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy +torch +torchvision +youtokentome + + diff --git a/ru_clip.py b/ru_clip.py index 941c807..37059aa 100644 --- a/ru_clip.py +++ b/ru_clip.py @@ -31,9 +31,6 @@ class RuClip(NNOperator): def __init__(self, model_name: str, modality: str): self.modality = modality self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True) - self.tokenize = clip.tokenize - self.device = "cuda" if torch.cuda.is_available() else "cpu" path = str(Path(__file__).parent) sys.path.append(path) @@ -41,7 +38,7 @@ class RuClip(NNOperator): sys.path.pop() clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=self.device) templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное'] - self.predictor = ruclip.Predictor(clip, processor, device, bs=1, templates=templates) + self.predictor = ruclip.Predictor(clip, processor, self.device, bs=1, templates=templates) def inference_single_data(self, data): if self.modality == 'image': @@ -49,7 +46,7 @@ class RuClip(NNOperator): elif self.modality == 'text': vec = self._inference_from_text(data) else: - raise ValueError("modality[{}] not implemented.".format(self._modality)) + raise ValueError("modality[{}] not implemented.".format(self.modality)) return vec.detach().cpu().numpy().flatten() def __call__(self, data): diff --git a/ruclip/predictor.py b/ruclip/predictor.py index 15fd370..8d94ffb 100644 --- a/ruclip/predictor.py +++ b/ruclip/predictor.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- import torch import more_itertools -from tqdm import tqdm class Predictor: - def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False): + def __init__(self, clip_model, clip_processor, device, templates=None, bs=8): self.device = device self.clip_model = clip_model.to(self.device) self.clip_model.eval() self.clip_processor = clip_processor self.bs = bs - self.quiet = quiet self.templates = templates or [ '{}', 'фото, на котором изображено {}', @@ -35,8 +33,6 @@ class Predictor: return text_latents def run(self, images, text_latents): - if not self.quiet: -
pbar = tqdm() labels = [] logit_scale = self.clip_model.logit_scale.exp() for pil_images in more_itertools.chunked(images, self.bs): @@ -45,21 +41,14 @@ class Predictor: image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True) logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale _labels = logits_per_text.argmax(0).cpu().numpy().tolist() - if not self.quiet: - pbar.update(len(_labels)) labels.extend(_labels) - pbar.close() return labels def get_image_latents(self, images): - if not self.quiet: - pbar = tqdm() image_latents = [] for pil_images in more_itertools.chunked(images, self.bs): inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True) image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device))) - if not self.quiet: - pbar.update(len(pil_images)) image_latents = torch.cat(image_latents) image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True) return image_latents diff --git a/tabular1.png b/tabular1.png new file mode 100644 index 0000000..58a96b7 Binary files /dev/null and b/tabular1.png differ diff --git a/tabular2.png b/tabular2.png new file mode 100644 index 0000000..2f9e186 Binary files /dev/null and b/tabular2.png differ diff --git a/vec1.png b/vec1.png new file mode 100644 index 0000000..8a80429 Binary files /dev/null and b/vec1.png differ diff --git a/vec2.png b/vec2.png new file mode 100644 index 0000000..48a7890 Binary files /dev/null and b/vec2.png differ
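For context on the `ruclip/predictor.py` changes above, here is a hedged sketch of driving the bundled `Predictor` directly for zero-shot labelling. `ruclip.load`, `ruclip.Predictor`, `get_text_latents` and `run` appear in the patch; the model choice, label set and image path are illustrative assumptions:

```python
import torch
from PIL import Image

import ruclip  # the bundled package patched above; assumes it is importable (e.g. repo root on sys.path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, processor = ruclip.load('ruclip-vit-base-patch32-224', device=device)
predictor = ruclip.Predictor(clip_model, processor, device, bs=8)

# candidate Russian labels and a single test image (label set and path are illustrative)
labels = ['кошка', 'собака', 'плюшевый мишка']
images = [Image.open('./teddy.jpg')]

with torch.no_grad():
    text_latents = predictor.get_text_latents(labels)  # one embedding per candidate label
    best = predictor.run(images, text_latents)          # index of the best label for each image

print(labels[best[0]])
```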