diff --git a/README.md b/README.md
index c1e7afc..cfcbc6a 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# clipcap
+# Image Captioning with ClipCap
+*author: David Wang*
+
+
+
+
+
+
+## Description
+
+This operator generates a caption with [ClipCap](https://arxiv.org/abs/2111.09734), which describes the content of the given image. ClipCap uses the CLIP image encoding as a prefix, maps it through a small mapping network, and decodes the caption with a GPT-2 language model. This is an adaptation from [rmokady/CLIP_prefix_caption](https://github.com/rmokady/CLIP_prefix_caption).
+
+
+
+
+
+## Code Example
+
+Load an image from path './animals.jpg' to generate its caption.
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.glob('./animals.jpg') \
+      .image_decode() \
+      .image_captioning.clipcap(model_name='clipcap_coco') \
+      .select() \
+      .show()
+```
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./animals.jpg') \
+      .image_decode['path', 'img']() \
+      .image_captioning.clipcap['img', 'text'](model_name='clipcap_coco') \
+      .select['img', 'text']() \
+      .show()
+```
+
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***clipcap(model_name)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+    The model name of ClipCap. Supported model names (see the usage sketch below):
+- clipcap_coco
+- clipcap_conceptual
+
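+A minimal sketch of constructing the operator directly through towhee's `ops` namespace. The `ops.image_captioning.clipcap` resolver path is an assumption based on this repository's name and the pipeline examples above:
+
+```python
+from towhee import ops
+
+# Instantiate the operator; weights are loaded from the repository's `weights/` folder.
+op = ops.image_captioning.clipcap(model_name='clipcap_coco')
+```
+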
+
+
+
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+    The image to generate the caption for.
+
+
+
+**Returns:** *str*
+
+    The caption generated by the model.
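+
+A hedged end-to-end sketch of calling the operator outside a pipeline (assuming a local './animals.jpg'; per the interface above, `data` may be a towhee image or a path string):
+
+```python
+from towhee import ops
+
+op = ops.image_captioning.clipcap(model_name='clipcap_coco')
+caption = op('./animals.jpg')  # a str describing the image
+print(caption)
+```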
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e6759c2
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .clipcap import ClipCap
+
+def clipcap(model_name: str):
+ return ClipCap(model_name)
diff --git a/clipcap.py b/clipcap.py
new file mode 100644
index 0000000..f99c8f8
--- /dev/null
+++ b/clipcap.py
@@ -0,0 +1,79 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import torch
+from pathlib import Path
+from torchvision import transforms
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee import register
+from towhee.models import clip
+
+class ClipCap(NNOperator):
+ """
+ ClipCap image captioning operator
+ """
+    def __init__(self, model_name: str):
+        super().__init__()
+        sys.path.append(str(Path(__file__).parent))
+        from models.clipcap import ClipCaptionModel, generate_beam
+        self._generate_beam = generate_beam
+        config = self._configs()[model_name]
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # CLIP image preprocessing (resize, center-crop, normalize with CLIP statistics).
+        self.clip_tfms = transforms.Compose([
+            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
+        # CLIP ViT-B/32 image encoder used to compute the caption prefix.
+        clip_model_type = 'clip_vit_b32'
+        self.clip_model = clip.create_model(model_name=clip_model_type, pretrained=True, jit=True)
+        self.clip_model.to(self.device)
+
+        # GPT-2 tokenizer and the ClipCap mapping/decoding model with its pretrained weights.
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        self.model = ClipCaptionModel(prefix_length=10)
+        model_path = os.path.join(os.path.dirname(__file__), config['weights'])
+        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
+        self.model.to(self.device)
+        self.model.eval()
+
+
+    @arg(1, to_image_color('RGB'))
+    def __call__(self, data):
+        caption = self._inference_from_image(data)
+        return caption
+
+    def _preprocess(self, img):
+        img = to_pil(img)
+        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
+        return processed_img
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        img = self._preprocess(img)
+        with torch.no_grad():
+            # Encode the image with CLIP and map the feature to a GPT-2 prefix embedding.
+            clip_feat = self.clip_model.encode_image(img).to(self.device, dtype=torch.float32)
+            prefix_length = 10
+            prefix_embed = self.model.clip_project(clip_feat).reshape(1, prefix_length, -1)
+
+        # Decode the caption from the prefix with beam search.
+        generated_text_prefix = self._generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
+        return generated_text_prefix
+
+ def _configs(self):
+ config = {}
+ config['clipcap_coco'] = {}
+ config['clipcap_coco']['weights'] = 'weights/coco_weights.pt'
+ config['clipcap_conceptual'] = {}
+ config['clipcap_conceptual']['weights'] = 'weights/conceptual_weights.pt'
+ return config
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..adcb976
--- /dev/null
+++ b/main.py
@@ -0,0 +1,166 @@
+import clip
+import torch
+import skimage.io as io
+import PIL.Image
+import numpy as np
+import torch.nn.functional as nnf
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from tqdm import tqdm, trange
+from models.clipcap import MLP, ClipCaptionModel, ClipCaptionPrefix
+
+is_gpu = False
+device = torch.device('cuda:0') if is_gpu else torch.device('cpu')
+clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+CPU = torch.device('cpu')
+
+
+def generate2(
+ model,
+ tokenizer,
+ tokens=None,
+ prompt=None,
+ embed=None,
+ entry_count=1,
+ entry_length=67, # maximum number of words
+ top_p=0.8,
+ temperature=1.,
+ stop_token: str = '.',
+):
+ model.eval()
+ generated_num = 0
+ generated_list = []
+ stop_token_index = tokenizer.encode(stop_token)[0]
+ filter_value = -float("Inf")
+ device = next(model.parameters()).device
+
+ with torch.no_grad():
+
+ for entry_idx in trange(entry_count):
+ if embed is not None:
+ generated = embed
+ else:
+ if tokens is None:
+ tokens = torch.tensor(tokenizer.encode(prompt))
+ tokens = tokens.unsqueeze(0).to(device)
+
+ generated = model.gpt.transformer.wte(tokens)
+
+ for i in range(entry_length):
+
+ outputs = model.gpt(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ ..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
+ logits[:, indices_to_remove] = filter_value
+ next_token = torch.argmax(logits, -1).unsqueeze(0)
+ next_token_embed = model.gpt.transformer.wte(next_token)
+ if tokens is None:
+ tokens = next_token
+ else:
+ tokens = torch.cat((tokens, next_token), dim=1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ if stop_token_index == next_token.item():
+ break
+
+ output_list = list(tokens.squeeze().cpu().numpy())
+ output_text = tokenizer.decode(output_list)
+ generated_list.append(output_text)
+
+ return generated_list[0]
+
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+ entry_length=67, temperature=1., stop_token: str = '.'):
+
+ model.eval()
+ stop_token_index = tokenizer.encode(stop_token)[0]
+ tokens = None
+ scores = None
+ device = next(model.parameters()).device
+ seq_lengths = torch.ones(beam_size, device=device)
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+ with torch.no_grad():
+ if embed is not None:
+ generated = embed
+ else:
+ if tokens is None:
+ tokens = torch.tensor(tokenizer.encode(prompt))
+ tokens = tokens.unsqueeze(0).to(device)
+ generated = model.gpt.transformer.wte(tokens)
+ for i in range(entry_length):
+ outputs = model.gpt(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ logits = logits.softmax(-1).log()
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ generated = generated.expand(beam_size, *generated.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+ if tokens is None:
+ tokens = next_tokens
+ else:
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ tokens = tokens[next_tokens_source]
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ generated = generated[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+ next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+ if is_stopped.all():
+ break
+ scores = scores / seq_lengths
+ output_list = tokens.cpu().numpy()
+ output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+ order = scores.argsort(descending=True)
+ output_texts = [output_texts[i] for i in order]
+ return output_texts
+
+prefix_length = 10
+
+model = ClipCaptionModel(prefix_length)
+model_path = 'weights/coco_weights.pt'
+model.load_state_dict(torch.load(model_path, map_location=CPU))
+model = model.eval()
+
+use_beam_search = True
+
+UPLOADED_FILE = 'einstein.jpg'
+image = io.imread(UPLOADED_FILE)
+pil_image = PIL.Image.fromarray(image)
+
+image = preprocess(pil_image).unsqueeze(0).to(device)
+with torch.no_grad():
+ prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+ prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+if use_beam_search:
+ generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+else:
+ generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+print(generated_text_prefix)
+
+
diff --git a/models/clipcap.py b/models/clipcap.py
new file mode 100644
index 0000000..ef15bb3
--- /dev/null
+++ b/models/clipcap.py
@@ -0,0 +1,136 @@
+import os
+import sys
+from typing import Tuple, List, Union, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as nnf
+from torch import nn
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+
+T = torch.Tensor
+D = torch.device
+is_gpu = False
+
+def get_device(device_id: int) -> D:
+ if not torch.cuda.is_available():
+        return torch.device('cpu')
+ device_id = min(torch.cuda.device_count() - 1, device_id)
+ return torch.device(f'cuda:{device_id}')
+
+class MLP(nn.Module):
+
+ def forward(self, x: T) -> T:
+ return self.model(x)
+
+ def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+ super(MLP, self).__init__()
+ layers = []
+ for i in range(len(sizes) -1):
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+ if i < len(sizes) - 2:
+ layers.append(act())
+ self.model = nn.Sequential(*layers)
+
+class ClipCaptionModel(nn.Module):
+
+ #@functools.lru_cache #FIXME
+ def get_dummy_token(self, batch_size: int, device: D) -> T:
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+ def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+ embedding_text = self.gpt.transformer.wte(tokens)
+ prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+ #print(embedding_text.size()) #torch.Size([5, 67, 768])
+ #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+ embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+ if labels is not None:
+ dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+ labels = torch.cat((dummy_token, tokens), dim=1)
+ out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+ return out
+
+ def __init__(self, prefix_length: int, prefix_size: int = 512):
+ super(ClipCaptionModel, self).__init__()
+ self.prefix_length = prefix_length
+ self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+ self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+ if prefix_length > 10: # not enough memory
+ self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+ else:
+ self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+class ClipCaptionPrefix(ClipCaptionModel):
+
+ def parameters(self, recurse: bool = True):
+ return self.clip_project.parameters()
+
+ def train(self, mode: bool = True):
+ super(ClipCaptionPrefix, self).train(mode)
+ self.gpt.eval()
+ return self
+
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+ entry_length=67, temperature=1., stop_token: str = '.'):
+
+ model.eval()
+ stop_token_index = tokenizer.encode(stop_token)[0]
+ tokens = None
+ scores = None
+ device = next(model.parameters()).device
+ seq_lengths = torch.ones(beam_size, device=device)
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+ with torch.no_grad():
+ if embed is not None:
+ generated = embed
+ else:
+ if tokens is None:
+ tokens = torch.tensor(tokenizer.encode(prompt))
+ tokens = tokens.unsqueeze(0).to(device)
+ generated = model.gpt.transformer.wte(tokens)
+ for i in range(entry_length):
+ outputs = model.gpt(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ logits = logits.softmax(-1).log()
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ generated = generated.expand(beam_size, *generated.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+ if tokens is None:
+ tokens = next_tokens
+ else:
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ tokens = tokens[next_tokens_source]
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ generated = generated[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+ next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+ if is_stopped.all():
+ break
+ scores = scores / seq_lengths
+ output_list = tokens.cpu().numpy()
+ output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+ order = scores.argsort(descending=True)
+ output_texts = [output_texts[i] for i in order]
+ return output_texts
+
diff --git a/weights/coco_weights.pt b/weights/coco_weights.pt
new file mode 100644
index 0000000..41fbcc7
--- /dev/null
+++ b/weights/coco_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f179e3da4662f132d181f5aef4989d72c7e3b61c2fe04691fa72c45047c6b2f
+size 636286431
diff --git a/weights/conceptual_weights.pt b/weights/conceptual_weights.pt
new file mode 100644
index 0000000..bfe7876
--- /dev/null
+++ b/weights/conceptual_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f09faf2b3b390c201ec3b80b223ea0baa2b303074d43dc3dec5663a9ecd34607
+size 636286431