logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
2e241bd580
  1. 175
      README.md
  2. 19
      __init__.py
  3. 6
      requirements.txt
  4. 7
      ru_clip.py
  5. 13
      ruclip/predictor.py
  6. BIN
      tabular1.png
  7. BIN
      tabular2.png
  8. BIN
      vec1.png
  9. BIN
      vec2.png

175
README.md

@ -1,67 +1,110 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@register(output_schema=['vec'])
class Clip(NNOperator):
"""
CLIP multi-modal embedding operator
"""
def __init__(self, model_name: str, modality: str):
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
self.tokenize = clip.tokenize
self.tfms = transforms.Compose([
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
def inference_single_data(self, data):
if self.modality == 'image':
vec = self._inference_from_image(data)
elif self.modality == 'text':
vec = self._inference_from_text(data)
else:
raise ValueError("modality[{}] not implemented.".format(self._modality))
return vec.detach().cpu().numpy().flatten()
def __call__(self, data):
if not isinstance(data, list):
data = [data]
else:
data = data
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
def _inference_from_text(self, text):
text = self.tokenize(text).to(self.device)
text_features = self.model.encode_text(text)
return text_features
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = to_pil(img)
image = self.tfms(img).unsqueeze(0).to(self.device)
image_features = self.model.encode_image(image)
return image_features
# Russian Image-Text Retrieval Embdding with CLIP
*author: David Wang*
<br />
## Description
This operator extracts features for image or text with [CLIP](https://arxiv.org/abs/2103.00020) which can generate embeddings for text and image by jointly training an image encoder and text encoder to maximize the cosine similarity. This is a Russian version of CLIP adopted from [ai-forever/ru-clip](https://github.com/ai-forever/ru-clip).
<br />
## Code Example
Load an image from path './teddy.jpg' to generate an image embedding.
Read the text 'Плюшевый мишка на скейтборде на Таймс-сквер.' to generate an text embedding.
*Write the pipeline in simplified style*:
```python
import towhee
towhee.glob('./teddy.jpg') \
.image_decode() \
.image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='image') \
.show()
towhee.dc(["'Плюшевый мишка на скейтборде на Таймс-сквер."]) \
.image_text_embedding.ru_clip(model_name='ruclip-vit-base-patch32-224', modality='text') \
.show()
```
<img src="./vec1.png" alt="result1" style="height:20px;"/>
<img src="./vec2.png" alt="result2" style="height:20px;"/>
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee
towhee.glob['path']('./teddy.jpg') \
.image_decode['path', 'img']() \
.image_text_embedding.ru_clip['img', 'vec'](model_name='ruclip-vit-base-patch32-224', modality='image') \
.select['img', 'vec']() \
.show()
towhee.dc['text'](["Плюшевый мишка на скейтборде на Таймс-сквер."]) \
.image_text_embedding.ru_clip['text','vec'](model_name='ruclip-vit-base-patch32-224', modality='text') \
.select['text', 'vec']() \
.show()
```
<img src="./tabular1.png" alt="result1" style="height:60px;"/>
<img src="./tabular2.png" alt="result2" style="height:60px;"/>
<br />
## Factory Constructor
Create the operator via the following factory method
***ru_clip(model_name, modality)***
**Parameters:**
***model_name:*** *str*
​ The model name of CLIP. Supported model names:
- ruclip-vit-base-patch32-224
- ruclip-vit-base-patch16-224
- ruclip-vit-large-patch14-224
- ruclip-vit-large-patch14-336
- ruclip-vit-base-patch32-384
- ruclip-vit-base-patch16-384
***modality:*** *str*
​ Which modality(*image* or *text*) is used to generate the embedding.
<br />
## Interface
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) or string as input and generate an embedding in ndarray.
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
​ The data (image or text based on specified modality) to generate embedding.
**Returns:** *numpy.ndarray*
​ The data embedding extracted by model.

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ru_clip import RuClip
def ru_clip(model_name: str, modality: str):
return RuClip(model_name, modality)

6
requirements.txt

@ -0,0 +1,6 @@
numpy
torch
torchvision
youtokentome

7
ru_clip.py

@ -31,9 +31,6 @@ class RuClip(NNOperator):
def __init__(self, model_name: str, modality: str):
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True)
self.tokenize = clip.tokenize
self.device = "cuda" if torch.cuda.is_available() else "cpu"
path = str(Path(__file__).parent)
sys.path.append(path)
@ -41,7 +38,7 @@ class RuClip(NNOperator):
sys.path.pop()
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=self.device)
templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']
self.predictor = ruclip.Predictor(clip, processor, device, bs=1, templates=templates)
self.predictor = ruclip.Predictor(clip, processor, self.device, bs=1, templates=templates)
def inference_single_data(self, data):
if self.modality == 'image':
@ -49,7 +46,7 @@ class RuClip(NNOperator):
elif self.modality == 'text':
vec = self._inference_from_text(data)
else:
raise ValueError("modality[{}] not implemented.".format(self._modality))
raise ValueError("modality[{}] not implemented.".format(self.modality))
return vec.detach().cpu().numpy().flatten()
def __call__(self, data):

13
ruclip/predictor.py

@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
import torch
import more_itertools
from tqdm import tqdm
class Predictor:
def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
def __init__(self, clip_model, clip_processor, device, templates=None, bs=8):
self.device = device
self.clip_model = clip_model.to(self.device)
self.clip_model.eval()
self.clip_processor = clip_processor
self.bs = bs
self.quiet = quiet
self.templates = templates or [
'{}',
'фото, на котором изображено {}',
@ -35,8 +33,6 @@ class Predictor:
return text_latents
def run(self, images, text_latents):
if not self.quiet:
pbar = tqdm()
labels = []
logit_scale = self.clip_model.logit_scale.exp()
for pil_images in more_itertools.chunked(images, self.bs):
@ -45,21 +41,14 @@ class Predictor:
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
_labels = logits_per_text.argmax(0).cpu().numpy().tolist()
if not self.quiet:
pbar.update(len(_labels))
labels.extend(_labels)
pbar.close()
return labels
def get_image_latents(self, images):
if not self.quiet:
pbar = tqdm()
image_latents = []
for pil_images in more_itertools.chunked(images, self.bs):
inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
if not self.quiet:
pbar.update(len(pil_images))
image_latents = torch.cat(image_latents)
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
return image_latents

BIN
tabular1.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

BIN
tabular2.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

BIN
vec1.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
vec2.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Loading…
Cancel
Save