diff --git a/README.md b/README.md
index f428144..f430260 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,84 @@
-# capdec
+# Image Captioning with CapDec
+
+*author: David Wang*
+
+
+
+
+
+## Description
+
+This operator generates a caption with [CapDec](https://arxiv.org/abs/2211.00575), which describes the content of the given image. CapDec trains the caption decoder on text only: it injects noise into CLIP text embeddings during training so that, at inference time, the decoder can be driven by CLIP image embeddings despite the modality gap. This is an adaptation from [DavidHuji/CapDec](https://github.com/DavidHuji/CapDec).
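+
+The core idea, sketched conceptually below (illustrative only, not part of the operator's code; the `std` value is a placeholder for the noise level implied by the model names):
+
+```python
+import torch
+
+def noisy_prefix(clip_text_embed: torch.Tensor, std: float = 0.1) -> torch.Tensor:
+    # During training, CapDec feeds the decoder CLIP *text* embeddings perturbed
+    # with Gaussian noise; at inference, the clean CLIP *image* embedding is used.
+    return clip_text_embed + torch.randn_like(clip_text_embed) * std
+```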
+
+
+
+
+
+## Code Example
+
+Load an image from the path './image.jpg' to generate its caption.
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.glob('./image.jpg') \
+ .image_decode() \
+ .image_captioning.capdec(model_name='capdec_noise_0') \
+ .show()
+```
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./image.jpg') \
+ .image_decode['path', 'img']() \
+ .image_captioning.capdec['img', 'text'](model_name='capdec_noise_0') \
+ .select['img', 'text']() \
+ .show()
+```
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***capdec(model_name)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+ The model name of CapDec. Supported model names:
+- capdec_noise_0
+- capdec_noise_01
+- capdec_noise_001
+- capdec_noise_0001
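+
+A minimal construction sketch (assuming the operator resolves from the Towhee hub, as in the pipelines above):
+
+```python
+from towhee import ops
+
+# Each model name selects a checkpoint trained with a different noise level.
+op = ops.image_captioning.capdec(model_name='capdec_noise_01')
+```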
+
+
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+    The image to generate the caption for.
+
+
+
+**Returns:** *str*
+
+    The caption generated by the model.
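+
+A direct-call sketch (assuming the operator instances returned by `towhee.ops` are callable, as in the pipelines above):
+
+```python
+from towhee import ops
+
+img = ops.image_decode()('./image.jpg')  # towhee.types.Image (ndarray sub-class)
+capdec = ops.image_captioning.capdec(model_name='capdec_noise_0')
+print(capdec(img))  # str: a one-sentence description of the image
+```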
diff --git a/capdec.py b/capdec.py
index c628d0c..ad2a844 100644
--- a/capdec.py
+++ b/capdec.py
@@ -25,6 +25,7 @@ from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
+from towhee.command.s3 import S3Bucket
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
@@ -36,8 +37,27 @@ class Capdec(NNOperator):
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(Path(__file__).parent))
+ from modules import ClipCaptionModel, generate_beam, generate2
+
+ path = str(Path(__file__).parent)
+ config = self._configs()[model_name]
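+        # Download the checkpoint for the selected model from S3 into the local weights/ directory.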
+ s3_bucket = S3Bucket()
+ s3_bucket.download_file(config['weights'], path + '/weights/')
+ model_path = path + '/weights/' + os.path.basename(config['weights'])
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.clip_caption_model = ClipCaptionModel()
+ self.clip_caption_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+ self.clip_caption_model.to(self.device)
+ self.clip_caption_model.eval()
+
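+        # CLIP RN50x4 image encoder; its image embedding is mapped into a GPT-2 prefix by clip_project.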
+ self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
+ self.clip_model.to(self.device)
+ self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
+        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # the tokenizer is not a torch module, so no .to(device)
+ self.generate_beam = generate_beam
+ self.generate2 = generate2
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
@@ -66,18 +86,25 @@ class Capdec(NNOperator):
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = self._preprocess(img)
- clip_feat = self.clip_model.encode_image(img)
-
- self.prefix_length = 10
- prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
+ use_beam_search = True
+ with torch.no_grad():
+ prefix = self.clip_model.encode_image(img)[0].to(self.device, dtype=torch.float32).unsqueeze(0)
+ prefix_embed = self.clip_caption_model.clip_project(prefix).reshape(1, 40, -1)
+ if use_beam_search:
+ generated_text_prefix = self.generate_beam(self.clip_caption_model, self.tokenizer, embed=prefix_embed)[0]
+ else:
+ generated_text_prefix = self.generate2(self.clip_caption_model, self.tokenizer, embed=prefix_embed)
- generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
return generated_text_prefix
def _configs(self):
config = {}
- config['clipcap_coco'] = {}
- config['clipcap_coco']['weights'] = 'coco_weights.pt'
- config['clipcap_conceptual'] = {}
- config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
+ config['capdec_noise_0'] = {}
+ config['capdec_noise_0']['weights'] = 'image-captioning/capdec/0.pt'
+ config['capdec_noise_01'] = {}
+ config['capdec_noise_01']['weights'] = 'image-captioning/capdec/01.pt'
+ config['capdec_noise_001'] = {}
+ config['capdec_noise_001']['weights'] = 'image-captioning/capdec/001.pt'
+ config['capdec_noise_0001'] = {}
+ config['capdec_noise_0001']['weights'] = 'image-captioning/capdec/0001.pt'
return config
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000..341cee6
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,365 @@
+import os
+from torch import nn
+import numpy as np
+import torch
+import torch.nn.functional as nnf
+from typing import Tuple, List, Union, Optional
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from tqdm import tqdm, trange
+
+class ClipCaptionModel(nn.Module):
+
+ def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+ def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None):
+ embedding_text = self.gpt.transformer.wte(tokens)
+ prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+ embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+ if labels is not None:
+ dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+ labels = torch.cat((dummy_token, tokens), dim=1)
+ out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+ return out
+
+ def __init__(self):
+ super(ClipCaptionModel, self).__init__()
+ self.prefix_length = 40
+ self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+ self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+ self.clip_project = TransformerMapper(640, self.gpt_embedding_size, 40, 40, 8)
+
+class MLP(nn.Module):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.model(x)
+
+ def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+ super(MLP, self).__init__()
+ layers = []
+        for i in range(len(sizes) - 1):
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+ if i < len(sizes) - 2:
+ layers.append(act())
+ self.model = nn.Sequential(*layers)
+
+class MlpTransformer(nn.Module):
+ def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
+ super().__init__()
+ out_d = out_d if out_d is not None else in_dim
+ self.fc1 = nn.Linear(in_dim, h_dim)
+ self.act = act
+ self.fc2 = nn.Linear(h_dim, out_d)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.dropout(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ return x
+
+class MultiHeadAttention(nn.Module):
+
+ def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim_self // num_heads
+ self.scale = head_dim ** -0.5
+ self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+ self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+ self.project = nn.Linear(dim_self, dim_self)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x, y=None, mask=None):
+ y = y if y is not None else x
+ b, n, c = x.shape
+ _, m, d = y.shape
+ # b n h dh
+ queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+ # b m 2 h dh
+ keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
+ keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
+ attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
+ if mask is not None:
+ if mask.dim() == 2:
+ mask = mask.unsqueeze(1)
+ attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+ attention = attention.softmax(dim=2)
+ out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
+ out = self.project(out)
+ return out, attention
+
+
+class TransformerLayer(nn.Module):
+
+ def forward_with_attention(self, x, y=None, mask=None):
+ x_, attention = self.attn(self.norm1(x), y, mask)
+ x = x + x_
+ x = x + self.mlp(self.norm2(x))
+ return x, attention
+
+ def forward(self, x, y=None, mask=None):
+ x = x + self.attn(self.norm1(x), y, mask)[0]
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+ def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
+ norm_layer: nn.Module = nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim_self)
+ self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
+ self.norm2 = norm_layer(dim_self)
+ self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+
+class Transformer(nn.Module):
+
+ def forward_with_attention(self, x, y=None, mask=None):
+ attentions = []
+ for layer in self.layers:
+ x, att = layer.forward_with_attention(x, y, mask)
+ attentions.append(att)
+ return x, attentions
+
+ def forward(self, x, y=None, mask=None):
+ for i, layer in enumerate(self.layers):
+ if i % 2 == 0 and self.enc_dec: # cross
+ x = layer(x, y)
+ elif self.enc_dec: # self
+ x = layer(x, x, mask)
+ else: # self or cross
+ x = layer(x, y, mask)
+ return x
+
+ def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
+ mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
+ super(Transformer, self).__init__()
+ dim_ref = dim_ref if dim_ref is not None else dim_self
+ self.enc_dec = enc_dec
+ if enc_dec:
+ num_layers = num_layers * 2
+ layers = []
+ for i in range(num_layers):
+ if i % 2 == 0 and enc_dec: # cross
+ layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+ elif enc_dec: # self
+ layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+ else: # self or cross
+ layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+ self.layers = nn.ModuleList(layers)
+
+
+class TransformerMapper(nn.Module):
+
+ def forward(self, x):
+ x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+ prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
+ prefix = torch.cat((x, prefix), dim=1)
+ out = self.transformer(prefix)[:, self.clip_length:]
+ return out
+
+ def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
+ super(TransformerMapper, self).__init__()
+ self.clip_length = clip_length
+ self.transformer = Transformer(dim_embedding, 8, num_layers)
+ self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+ self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
+
+
+class ClipCaptionPrefix(ClipCaptionModel):
+
+ def parameters(self, recurse: bool = True):
+ return self.clip_project.parameters()
+
+ def train(self, mode: bool = True):
+ super(ClipCaptionPrefix, self).train(mode)
+ self.gpt.eval()
+ return self
+
+def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+ entry_length=67, temperature=1., stop_token: str = '.'):
+
+ model.eval()
+ stop_token_index = tokenizer.encode(stop_token)[0]
+ tokens = None
+ scores = None
+ device = next(model.parameters()).device
+ seq_lengths = torch.ones(beam_size, device=device)
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+ with torch.no_grad():
+ if embed is not None:
+ generated = embed
+ else:
+ if tokens is None:
+ tokens = torch.tensor(tokenizer.encode(prompt))
+ tokens = tokens.unsqueeze(0).to(device)
+ generated = model.gpt.transformer.wte(tokens)
+ for i in range(entry_length):
+ outputs = model.gpt(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ logits = logits.softmax(-1).log()
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ generated = generated.expand(beam_size, *generated.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+ if tokens is None:
+ tokens = next_tokens
+ else:
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ tokens = tokens[next_tokens_source]
+ tokens = torch.cat((tokens, next_tokens), dim=1)
+ generated = generated[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+ next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+ if is_stopped.all():
+ break
+ scores = scores / seq_lengths
+ output_list = tokens.cpu().numpy()
+ output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+ order = scores.argsort(descending=True)
+ output_texts = [output_texts[i] for i in order]
+ return output_texts
+
+
+def generate2(
+ model,
+ tokenizer,
+ tokens=None,
+ prompt=None,
+ embed=None,
+ entry_count=1,
+ entry_length=67, # maximum number of words
+ top_p=0.8,
+ temperature=1.,
+ stop_token: str = '.',
+):
+ model.eval()
+ generated_num = 0
+ generated_list = []
+ stop_token_index = tokenizer.encode(stop_token)[0]
+ filter_value = -float("Inf")
+ device = next(model.parameters()).device
+
+ with torch.no_grad():
+
+ for entry_idx in trange(entry_count):
+ if embed is not None:
+ generated = embed
+ else:
+ if tokens is None:
+ tokens = torch.tensor(tokenizer.encode(prompt))
+ tokens = tokens.unsqueeze(0).to(device)
+
+ generated = model.gpt.transformer.wte(tokens)
+
+ for i in range(entry_length):
+
+ outputs = model.gpt(inputs_embeds=generated)
+ logits = outputs.logits
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+ ..., :-1
+ ].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
+ logits[:, indices_to_remove] = filter_value
+ next_token = torch.argmax(logits, -1).unsqueeze(0)
+ next_token_embed = model.gpt.transformer.wte(next_token)
+ if tokens is None:
+ tokens = next_token
+ else:
+ tokens = torch.cat((tokens, next_token), dim=1)
+ generated = torch.cat((generated, next_token_embed), dim=1)
+ if stop_token_index == next_token.item():
+ break
+
+ output_list = list(tokens.squeeze().cpu().numpy())
+ output_text = tokenizer.decode(output_list)
+ generated_list.append(output_text)
+
+ return generated_list[0]
+