diff --git a/README.md b/README.md
index c1e7afc..cfcbc6a 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# clipcap
+# Image Captioning with ClipCap
+*author: David Wang*
+
+
+<br />
+
+
+
+## Description
+
+This operator generates a caption that describes the content of the given image with [ClipCap](https://arxiv.org/abs/2111.09734), which feeds a CLIP image encoding as a prefix to a GPT-2 language model. This is an adaptation from [rmokady/CLIP_prefix_caption](https://github.com/rmokady/CLIP_prefix_caption).
+
+
+<br />
+
+
+## Code Example
+
+Load an image from path './animals.jpg' to generate the caption.
+
+*Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./animals.jpg') \
+      .image_decode() \
+      .image_captioning.clipcap(model_name='clipcap_coco') \
+      .select() \
+      .show()
+```
+result1
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./animals.jpg') \
+      .image_decode['path', 'img']() \
+      .image_captioning.clipcap['img', 'text'](model_name='clipcap_coco') \
+      .select['img', 'text']() \
+      .show()
+```
+result2
+
+
+<br />
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***clipcap(model_name)***
+
+**Parameters:**
+
+***model_name:*** *str*
+
+The model name of ClipCap. Supported model names:
+- clipcap_coco
+- clipcap_conceptual
+
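+Below is a minimal sketch of calling the operator outside of a DataCollection pipeline. It assumes towhee's `ops` interface allows direct invocation of hub operators and that a decoded towhee image can be passed straight to the operator; the image path and model name are illustrative.
+
+```python
+from towhee import ops
+
+# Assumed API: ops.<namespace>.<operator>() returns a callable operator instance.
+img = ops.image_decode()('./animals.jpg')                     # decode the image file (illustrative path)
+op = ops.image_captioning.clipcap(model_name='clipcap_coco')  # build the captioning operator
+print(op(img))                                                # prints the generated caption
+```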
+
+<br />
+
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+
+The image to generate a caption for.
+
+
+
+**Returns:** *str*
+
+The caption generated by the model.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e6759c2
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .clipcap import ClipCap
+
+def clipcap(model_name: str):
+    return ClipCap(model_name)
diff --git a/clipcap.py b/clipcap.py
new file mode 100644
index 0000000..f99c8f8
--- /dev/null
+++ b/clipcap.py
@@ -0,0 +1,79 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import torch
+from pathlib import Path
+from torchvision import transforms
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee import register
+from towhee.models import clip
+
+
+class ClipCap(NNOperator):
+    """
+    ClipCap image captioning operator.
+    """
+    def __init__(self, model_name: str):
+        super().__init__()
+        sys.path.append(str(Path(__file__).parent))
+        from models.clipcap import ClipCaptionModel, generate_beam
+        self._generate_beam = generate_beam
+        config = self._configs()[model_name]
+
+        self.device = torch.device('cpu')
+        self.clip_tfms = transforms.Compose([
+            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+        ])
+
+        clip_model_type = 'clip_vit_b32'
+        self.clip_model = clip.create_model(model_name=clip_model_type, pretrained=True, jit=True)
+        # ClipCap decodes captions with GPT-2, so the GPT-2 tokenizer is needed for generation.
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+
+        self.prefix_length = 10
+        self.model = ClipCaptionModel(prefix_length=self.prefix_length)
+        model_path = os.path.join(os.path.dirname(__file__), config['weights'])
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+        self.model.eval()
+
+    @arg(1, to_image_color('RGB'))
+    def __call__(self, data):
+        return self._inference_from_image(data)
+
+    def _preprocess(self, img):
+        img = to_pil(img)
+        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
+        return processed_img
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        # Encode the image with CLIP, project it to a GPT-2 prefix and decode a caption.
+        img = self._preprocess(img)
+        clip_feat = self.clip_model.encode_image(img).float()
+
+        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
+        generated_text_prefix = self._generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
+        return generated_text_prefix
+
+    def _configs(self):
+        config = {}
+        config['clipcap_coco'] = {}
+        config['clipcap_coco']['weights'] = 'weights/coco_weights.pt'
+        config['clipcap_conceptual'] = {}
+        config['clipcap_conceptual']['weights'] = 'weights/conceptual_weights.pt'
+        return config
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..adcb976
--- /dev/null
+++ b/main.py
@@ -0,0 +1,166 @@
+import clip
+import torch
+import skimage.io as io
+import PIL.Image
+import numpy as np
+import torch.nn.functional as nnf
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from tqdm import tqdm, trange
+from models.clipcap import MLP, ClipCaptionModel, ClipCaptionPrefix
+
+is_gpu = False
+device = torch.device('cuda:0') if is_gpu else torch.device('cpu')
+clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+CPU = torch.device('cpu')
+
+
+# Caption decoding without beam search: filter the logits with top-p and take the argmax token.
+def generate2(
+        model,
+        tokenizer,
+        tokens=None,
+        prompt=None,
+        embed=None,
+        entry_count=1,
+        entry_length=67,  # maximum number of words
+        top_p=0.8,
+        temperature=1.,
+        stop_token: str = '.',
+):
+    model.eval()
+    generated_num = 0
+    generated_list = []
+    stop_token_index = tokenizer.encode(stop_token)[0]
+    filter_value = -float("Inf")
+    device = next(model.parameters()).device
+
+    with torch.no_grad():
+
+        for entry_idx in trange(entry_count):
+            if embed is not None:
+                generated = embed
+            else:
+                if tokens is None:
+                    tokens = torch.tensor(tokenizer.encode(prompt))
+                    tokens = tokens.unsqueeze(0).to(device)
+
+                generated = model.gpt.transformer.wte(tokens)
+
+            for i in range(entry_length):
+
+                outputs = model.gpt(inputs_embeds=generated)
+                logits = outputs.logits
+                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                sorted_indices_to_remove[..., 0] = 0
+
+                indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                logits[:, indices_to_remove] = filter_value
+                next_token = torch.argmax(logits, -1).unsqueeze(0)
+                next_token_embed = model.gpt.transformer.wte(next_token)
+                if tokens is None:
+                    tokens = next_token
+                else:
+                    tokens = torch.cat((tokens, next_token), dim=1)
+                generated = torch.cat((generated, next_token_embed), dim=1)
+                if stop_token_index == next_token.item():
+                    break
+
+            output_list = list(tokens.squeeze().cpu().numpy())
+            output_text = tokenizer.decode(output_list)
+            generated_list.append(output_text)
+
+    return generated_list[0]
+
+
+# Beam-search decoding over GPT-2, starting from the projected CLIP prefix embedding.
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                  entry_length=67, temperature=1., stop_token: str = '.'):
+
+    model.eval()
+    stop_token_index = tokenizer.encode(stop_token)[0]
+    tokens = None
+    scores = None
+    device = next(model.parameters()).device
+    seq_lengths = torch.ones(beam_size, device=device)
+    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+    with torch.no_grad():
+        if embed is not None:
+            generated = embed
+        else:
+            if tokens is None:
+                tokens = torch.tensor(tokenizer.encode(prompt))
+                tokens = tokens.unsqueeze(0).to(device)
+                generated = model.gpt.transformer.wte(tokens)
+        for i in range(entry_length):
+            outputs = model.gpt(inputs_embeds=generated)
+            logits = outputs.logits
+            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+            logits = logits.softmax(-1).log()
+            if scores is None:
+                scores, next_tokens = logits.topk(beam_size, -1)
+                generated = generated.expand(beam_size, *generated.shape[1:])
+                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                if tokens is None:
+                    tokens = next_tokens
+                else:
+                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                    tokens = torch.cat((tokens, next_tokens), dim=1)
+            else:
+                logits[is_stopped] = -float(np.inf)
+                logits[is_stopped, 0] = 0
+                scores_sum = scores[:, None] + logits
+                seq_lengths[~is_stopped] += 1
+                scores_sum_average = scores_sum / seq_lengths[:, None]
+                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                next_tokens_source = next_tokens // scores_sum.shape[1]
+                seq_lengths = seq_lengths[next_tokens_source]
+                next_tokens = next_tokens % scores_sum.shape[1]
+                next_tokens = next_tokens.unsqueeze(1)
+                tokens = tokens[next_tokens_source]
+                tokens = torch.cat((tokens, next_tokens), dim=1)
+                generated = generated[next_tokens_source]
+                scores = scores_sum_average * seq_lengths
+                is_stopped = is_stopped[next_tokens_source]
+            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+            generated = torch.cat((generated, next_token_embed), dim=1)
+            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+            if is_stopped.all():
+                break
+    scores = scores / seq_lengths
+    output_list = tokens.cpu().numpy()
+    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+    order = scores.argsort(descending=True)
+    output_texts = [output_texts[i] for i in order]
+    return output_texts
+
+
+prefix_length = 10
+
+model = ClipCaptionModel(prefix_length)
+model_path = '/Users/zilliz/git/image_captioning/git/clipcap/weights/coco_weights.pt'
+model.load_state_dict(torch.load(model_path, map_location=CPU))
+model = model.eval()
+
+use_beam_search = True  #@param {type:"boolean"}
+
+UPLOADED_FILE = 'einstein.jpg'
+image = io.imread(UPLOADED_FILE)
+pil_image = PIL.Image.fromarray(image)
+
+image = preprocess(pil_image).unsqueeze(0).to(device)
+with torch.no_grad():
+    # if type(model) is ClipCaptionE2E:
+    #     prefix_embed = model.forward_image(image)
+    # else:
+    prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+if use_beam_search:
+    generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+else:
+    generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+print(generated_text_prefix)
diff --git a/models/.utils.py.swp b/models/.utils.py.swp
new file mode 100644
index 0000000..ac298e3
Binary files /dev/null and b/models/.utils.py.swp differ
diff --git a/models/clipcap.py b/models/clipcap.py
new file mode 100644
index 0000000..ef15bb3
--- /dev/null
+++ b/models/clipcap.py
@@ -0,0 +1,136 @@
+#@title Imports
+import os
+import sys
+from typing import Tuple, List, Union, Optional
+
+import clip
+import numpy as np
+import torch
+import torch.nn.functional as nnf
+from torch import nn
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+
+T = torch.Tensor
+D = torch.device
+CPU = torch.device('cpu')
+is_gpu = False
+
+
+def get_device(device_id: int) -> D:
+    if not torch.cuda.is_available():
+        return CPU
+    device_id = min(torch.cuda.device_count() - 1, device_id)
+    return torch.device(f'cuda:{device_id}')
+
+
+class MLP(nn.Module):
+
+    def forward(self, x: T) -> T:
+        return self.model(x)
+
+    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+        super(MLP, self).__init__()
+        layers = []
+        for i in range(len(sizes) - 1):
+            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+            if i < len(sizes) - 2:
+                layers.append(act())
+        self.model = nn.Sequential(*layers)
+
+
+class ClipCaptionModel(nn.Module):
+
+    #@functools.lru_cache #FIXME
+    def get_dummy_token(self, batch_size: int, device: D) -> T:
+        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+        embedding_text = self.gpt.transformer.wte(tokens)
+        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+        #print(embedding_text.size())  #torch.Size([5, 67, 768])
+        #print(prefix_projections.size())  #torch.Size([5, 1, 768])
+        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+        if labels is not None:
+            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+            labels = torch.cat((dummy_token, tokens), dim=1)
+        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+        return out
+
+    def __init__(self, prefix_length: int, prefix_size: int = 512):
+        super(ClipCaptionModel, self).__init__()
+        self.prefix_length = prefix_length
+        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+        if prefix_length > 10:  # not enough memory
+            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+        else:
+            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
+                                     self.gpt_embedding_size * prefix_length))
+
+
+class ClipCaptionPrefix(ClipCaptionModel):
+
+    def parameters(self, recurse: bool = True):
+        return self.clip_project.parameters()
+
+    def train(self, mode: bool = True):
+        super(ClipCaptionPrefix, self).train(mode)
+        self.gpt.eval()
+        return self
+
+
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                  entry_length=67, temperature=1., stop_token: str = '.'):
+
+    model.eval()
+    stop_token_index = tokenizer.encode(stop_token)[0]
+    tokens = None
+    scores = None
+    device = next(model.parameters()).device
+    seq_lengths = torch.ones(beam_size, device=device)
+    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+    with torch.no_grad():
+        if embed is not None:
+            generated = embed
+        else:
+            if tokens is None:
+                tokens = torch.tensor(tokenizer.encode(prompt))
+                tokens = tokens.unsqueeze(0).to(device)
+                generated = model.gpt.transformer.wte(tokens)
+        for i in range(entry_length):
+            outputs = model.gpt(inputs_embeds=generated)
+            logits = outputs.logits
+            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+            logits = logits.softmax(-1).log()
+            if scores is None:
+                scores, next_tokens = logits.topk(beam_size, -1)
+                generated = generated.expand(beam_size, *generated.shape[1:])
+                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                if tokens is None:
+                    tokens = next_tokens
+                else:
+                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                    tokens = torch.cat((tokens, next_tokens), dim=1)
+            else:
+                logits[is_stopped] = -float(np.inf)
+                logits[is_stopped, 0] = 0
+                scores_sum = scores[:, None] + logits
+                seq_lengths[~is_stopped] += 1
+                scores_sum_average = scores_sum / seq_lengths[:, None]
+                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                next_tokens_source = next_tokens // scores_sum.shape[1]
+                seq_lengths = seq_lengths[next_tokens_source]
+                next_tokens = next_tokens % scores_sum.shape[1]
+                next_tokens = next_tokens.unsqueeze(1)
+                tokens = tokens[next_tokens_source]
+                tokens = torch.cat((tokens, next_tokens), dim=1)
+                generated = generated[next_tokens_source]
+                scores = scores_sum_average * seq_lengths
+                is_stopped = is_stopped[next_tokens_source]
+            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+            generated = torch.cat((generated, next_token_embed), dim=1)
+            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+            if is_stopped.all():
+                break
+    scores = scores / seq_lengths
+    output_list = tokens.cpu().numpy()
+    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+    order = scores.argsort(descending=True)
+    output_texts = [output_texts[i] for i in order]
+    return output_texts
diff --git a/weights/coco_weights.pt b/weights/coco_weights.pt
new file mode 100644
index 0000000..41fbcc7
--- /dev/null
+++ b/weights/coco_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f179e3da4662f132d181f5aef4989d72c7e3b61c2fe04691fa72c45047c6b2f
+size 636286431
diff --git a/weights/conceptual_weights.pt b/weights/conceptual_weights.pt
new file mode 100644
index 0000000..bfe7876
--- /dev/null
+++ b/weights/conceptual_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f09faf2b3b390c201ec3b80b223ea0baa2b303074d43dc3dec5663a9ecd34607
+size 636286431