diff --git a/README.md b/README.md
index ac94256..57a9adb 100644
--- a/README.md
+++ b/README.md
@@ -1,67 +1,110 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-@register(output_schema=['vec'])
-class Clip(NNOperator):
- """
- CLIP multi-modal embedding operator
- """
- def __init__(self, model_name: str, modality: str):
- self.modality = modality
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
- self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
- self.tokenize = clip.tokenize
- self.tfms = transforms.Compose([
- transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(
- (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])
-
- def inference_single_data(self, data):
- if self.modality == 'image':
- vec = self._inference_from_image(data)
- elif self.modality == 'text':
- vec = self._inference_from_text(data)
- else:
- raise ValueError("modality[{}] not implemented.".format(self._modality))
- return vec.detach().cpu().numpy().flatten()
-
- def __call__(self, data):
- if not isinstance(data, list):
- data = [data]
- else:
- data = data
- results = []
- for single_data in data:
- result = self.inference_single_data(single_data)
- results.append(result)
- if len(data) == 1:
- return results[0]
- else:
- return results
-
- def _inference_from_text(self, text):
- text = self.tokenize(text).to(self.device)
- text_features = self.model.encode_text(text)
- return text_features
-
- @arg(1, to_image_color('RGB'))
- def _inference_from_image(self, img):
- img = to_pil(img)
- image = self.tfms(img).unsqueeze(0).to(self.device)
- image_features = self.model.encode_image(image)
- return image_features
+# Russian Image-Text Retrieval Embedding with CLIP
+
+*author: David Wang*
+
+
+
+
+
+
+## Description
+
+This operator extracts features for images or text with [CLIP](https://arxiv.org/abs/2103.00020), which generates embeddings for both modalities by jointly training an image encoder and a text encoder to maximize the cosine similarity of matching pairs. This is a Russian version of CLIP adapted from [ai-forever/ru-clip](https://github.com/ai-forever/ru-clip).
+
+
+
+
+
+
+## Code Example
+
+Load an image from path './teddy.jpg' to generate an image embedding.
+
+Read the text 'Плюшевый мишка на скейтборде на Таймс-сквер.' to generate a text embedding.
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.glob('./teddy.jpg') \
+ .image_decode() \
+ .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='image') \
+ .show()
+
+towhee.dc(["'Плюшевый мишка на скейтборде на Таймс-сквер."]) \
+ .image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='text') \
+ .show()
+```
+
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./teddy.jpg') \
+ .image_decode['path', 'img']() \
+ .image_text_embedding.ru_clip['img', 'vec'](model_name='ruclip-vit-base-patch32-224', modality='image') \
+ .select['img', 'vec']() \
+ .show()
+
+towhee.dc['text'](["Плюшевый мишка на скейтборде на Таймс-сквер."]) \
+ .image_text_embedding.ru_clip['text','vec'](model_name='ruclip-vit-base-patch32-224', modality='text') \
+ .select['text', 'vec']() \
+ .show()
+```
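+
+Because both pipelines project their inputs into the same embedding space, an image-text match can be scored with cosine similarity. Below is a minimal sketch of that comparison using numpy; the `img_vec` and `text_vec` arrays are placeholders for the outputs of the two pipelines above, and the 512-dimensional size is only illustrative.
+
+```python
+import numpy as np
+
+# Placeholder embeddings: replace these with the real outputs of the
+# 'image' and 'text' pipelines shown above.
+img_vec = np.random.rand(512).astype('float32')
+text_vec = np.random.rand(512).astype('float32')
+
+# Cosine similarity = dot product of L2-normalized vectors.
+img_vec = img_vec / np.linalg.norm(img_vec)
+text_vec = text_vec / np.linalg.norm(text_vec)
+score = float(np.dot(img_vec, text_vec))
+print(score)  # values closer to 1.0 indicate a better image-text match
+```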
+
+
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***ru_clip(model_name, modality)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+ The model name of CLIP. Supported model names:
+- ruclip-vit-base-patch32-224
+- ruclip-vit-base-patch16-224
+- ruclip-vit-large-patch14-224
+- ruclip-vit-large-patch14-336
+- ruclip-vit-base-patch32-384
+- ruclip-vit-base-patch16-384
+
+
+ ***modality:*** *str*
+
+ Which modality (*image* or *text*) is used to generate the embedding.
+
+
+
+
+
+## Interface
+
+An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or a string as input and generates an embedding as a numpy.ndarray.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+ The data (image or text, depending on the specified modality) used to generate the embedding.
+
+
+
+**Returns:** *numpy.ndarray*
+
+ The data embedding extracted by the model.
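+
+For quick experiments outside a pipeline, the operator can also be constructed through the factory and called directly. The snippet below is a minimal sketch that assumes a towhee release in which operators resolved via `towhee.ops` are directly callable; the model name and sentence are only examples.
+
+```python
+from towhee import ops
+
+# Resolve the operator from the hub and configure it for text inputs.
+text_op = ops.image_text_embedding.ru_clip(
+    model_name='ruclip-vit-base-patch32-224', modality='text')
+
+vec = text_op('Плюшевый мишка на скейтборде на Таймс-сквер.')
+print(type(vec), vec.shape)  # numpy.ndarray holding the flattened embedding
+```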
+
diff --git a/__init__.py b/__init__.py
index e69de29..fd7985e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ru_clip import RuClip
+
+
+def ru_clip(model_name: str, modality: str):
+ return RuClip(model_name, modality)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bf48659
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+numpy
+torch
+torchvision
+youtokentome
+more_itertools
+
diff --git a/ru_clip.py b/ru_clip.py
index 941c807..37059aa 100644
--- a/ru_clip.py
+++ b/ru_clip.py
@@ -31,9 +31,6 @@ class RuClip(NNOperator):
def __init__(self, model_name: str, modality: str):
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
- self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
- self.tokenize = clip.tokenize
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
path = str(Path(__file__).parent)
sys.path.append(path)
@@ -41,7 +38,7 @@ class RuClip(NNOperator):
sys.path.pop()
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=self.device)
templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']
- self.predictor = ruclip.Predictor(clip, processor, device, bs=1, templates=templates)
+ self.predictor = ruclip.Predictor(clip, processor, self.device, bs=1, templates=templates)
def inference_single_data(self, data):
if self.modality == 'image':
@@ -49,7 +46,7 @@ class RuClip(NNOperator):
elif self.modality == 'text':
vec = self._inference_from_text(data)
else:
- raise ValueError("modality[{}] not implemented.".format(self._modality))
+ raise ValueError("modality[{}] not implemented.".format(self.modality))
return vec.detach().cpu().numpy().flatten()
def __call__(self, data):
diff --git a/ruclip/predictor.py b/ruclip/predictor.py
index 15fd370..8d94ffb 100644
--- a/ruclip/predictor.py
+++ b/ruclip/predictor.py
@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
import torch
import more_itertools
-from tqdm import tqdm
class Predictor:
- def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
+ def __init__(self, clip_model, clip_processor, device, templates=None, bs=8):
self.device = device
self.clip_model = clip_model.to(self.device)
self.clip_model.eval()
self.clip_processor = clip_processor
self.bs = bs
- self.quiet = quiet
self.templates = templates or [
'{}',
'фото, на котором изображено {}',
@@ -35,8 +33,6 @@ class Predictor:
return text_latents
def run(self, images, text_latents):
- if not self.quiet:
- pbar = tqdm()
labels = []
logit_scale = self.clip_model.logit_scale.exp()
for pil_images in more_itertools.chunked(images, self.bs):
@@ -45,21 +41,14 @@ class Predictor:
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
_labels = logits_per_text.argmax(0).cpu().numpy().tolist()
- if not self.quiet:
- pbar.update(len(_labels))
labels.extend(_labels)
- pbar.close()
return labels
def get_image_latents(self, images):
- if not self.quiet:
- pbar = tqdm()
image_latents = []
for pil_images in more_itertools.chunked(images, self.bs):
inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
- if not self.quiet:
- pbar.update(len(pil_images))
image_latents = torch.cat(image_latents)
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
return image_latents
diff --git a/tabular1.png b/tabular1.png
new file mode 100644
index 0000000..58a96b7
Binary files /dev/null and b/tabular1.png differ
diff --git a/tabular2.png b/tabular2.png
new file mode 100644
index 0000000..2f9e186
Binary files /dev/null and b/tabular2.png differ
diff --git a/vec1.png b/vec1.png
new file mode 100644
index 0000000..8a80429
Binary files /dev/null and b/vec1.png differ
diff --git a/vec2.png b/vec2.png
new file mode 100644
index 0000000..48a7890
Binary files /dev/null and b/vec2.png differ