diff --git a/README.md b/README.md
index f428144..f430260 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,84 @@
-# capdec
+# Image Captioning with CapDec
+
+*author: David Wang*
+
+
+
+
+## Description
+
+This operator generates a caption that describes the content of the given image with [CapDec](https://arxiv.org/abs/2211.00575). CapDec trains its caption decoder from text only: it relies on CLIP's shared image-text embedding space and injects Gaussian noise into CLIP text embeddings during training to bridge the modality gap, so that at inference time CLIP image embeddings can be decoded into captions by a GPT-2 based decoder. This is an adaptation from [DavidHuji/CapDec](https://github.com/DavidHuji/CapDec).
+
+
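+
+At training time, CapDec only needs captions: it encodes each caption with CLIP's text encoder, perturbs the embedding with Gaussian noise, and trains the decoder to reconstruct the caption from that noisy vector. A minimal sketch of this idea is shown below (illustrative only; the actual training code and noise variance live in the upstream CapDec repository, and the `capdec_noise_*` model names presumably correspond to different noise settings):
+
+```python
+import torch
+
+def noise_injected_prefix(clip_text_embed: torch.Tensor, variance: float = 0.016) -> torch.Tensor:
+    # Zero-mean Gaussian noise makes the decoder robust to the CLIP image/text
+    # modality gap, so CLIP image embeddings can be decoded at inference time.
+    # The default variance here is only a placeholder.
+    return clip_text_embed + torch.randn_like(clip_text_embed) * variance ** 0.5
+```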
+
+
+## Code Example
+
+Load an image from './image.jpg' and generate its caption.
+
+*Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./image.jpg') \
+      .image_decode() \
+      .image_captioning.capdec(model_name='capdec_noise_0') \
+      .show()
+```
+result1
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./image.jpg') \
+      .image_decode['path', 'img']() \
+      .image_captioning.capdec['img', 'text'](model_name='capdec_noise_0') \
+      .select['img', 'text']() \
+      .show()
+```
+result2
+
+
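+
+The same pipeline also works on a whole folder of images; a small variation of the example above (the glob pattern and folder name are illustrative):
+
+```python
+import towhee
+
+towhee.glob['path']('./images/*.jpg') \
+      .image_decode['path', 'img']() \
+      .image_captioning.capdec['img', 'text'](model_name='capdec_noise_0') \
+      .select['img', 'text']() \
+      .show()
+```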
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***capdec(model_name)***
+
+**Parameters:**
+
+​ ***model_name:*** *str*
+
+​ The model name of CapDec. Supported model names:
+- capdec_noise_0
+- capdec_noise_01
+- capdec_noise_001
+- capdec_noise_0001
+
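+
+Besides the pipeline style above, the operator can also be constructed and called directly through the `towhee.ops` namespace, for example (a sketch; it assumes the hub operator resolves as `ops.image_captioning.capdec`, matching the pipeline examples):
+
+```python
+from towhee import ops
+
+# Instantiate the operator with one of the supported model names.
+op = ops.image_captioning.capdec(model_name='capdec_noise_01')
+
+# Decode an image and generate its caption.
+img = ops.image_decode()('./image.jpg')
+print(op(img))
+```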
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+​ The image to generate the caption for.
+
+
+
+**Returns:** *str*
+
+​ The caption generated by the model.
diff --git a/capdec.py b/capdec.py
index c628d0c..ad2a844 100644
--- a/capdec.py
+++ b/capdec.py
@@ -25,6 +25,7 @@ from towhee.types.image_utils import to_pil
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee import register
 from towhee.models import clip
+from towhee.command.s3 import S3Bucket
 
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
 
@@ -36,8 +37,27 @@ class Capdec(NNOperator):
     def __init__(self, model_name: str):
         super().__init__()
         sys.path.append(str(Path(__file__).parent))
+        from modules import ClipCaptionModel, generate_beam, generate2
+
+        path = str(Path(__file__).parent)
+        config = self._configs()[model_name]
+        s3_bucket = S3Bucket()
+        s3_bucket.download_file(config['weights'], path + '/weights/')
+        model_path = path + '/weights/' + os.path.basename(config['weights'])
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.clip_caption_model = ClipCaptionModel()
+        self.clip_caption_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        self.clip_caption_model.to(self.device)
+        self.clip_caption_model.eval()
+
+        self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
+        self.clip_model.to(self.device)
+        self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
+        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        self.generate_beam = generate_beam
+        self.generate2 = generate2
 
     @arg(1, to_image_color('RGB'))
     def inference_single_data(self, data):
@@ -66,18 +86,25 @@ class Capdec(NNOperator):
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        clip_feat = self.clip_model.encode_image(img)
-
-        self.prefix_length = 10
-        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
+        use_beam_search = True
+        with torch.no_grad():
+            prefix = self.clip_model.encode_image(img)[0].to(self.device, dtype=torch.float32).unsqueeze(0)
+            prefix_embed = self.clip_caption_model.clip_project(prefix).reshape(1, 40, -1)
+            if use_beam_search:
+                generated_text_prefix = self.generate_beam(self.clip_caption_model, self.tokenizer, embed=prefix_embed)[0]
+            else:
+                generated_text_prefix = self.generate2(self.clip_caption_model, self.tokenizer, embed=prefix_embed)
 
-        generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
         return generated_text_prefix
 
     def _configs(self):
         config = {}
-        config['clipcap_coco'] = {}
-        config['clipcap_coco']['weights'] = 'coco_weights.pt'
-        config['clipcap_conceptual'] = {}
-        config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
+        config['capdec_noise_0'] = {}
+        config['capdec_noise_0']['weights'] = 'image-captioning/capdec/0.pt'
+        config['capdec_noise_01'] = {}
+        config['capdec_noise_01']['weights'] = 'image-captioning/capdec/01.pt'
+        config['capdec_noise_001'] = {}
+        config['capdec_noise_001']['weights'] = 'image-captioning/capdec/001.pt'
+        config['capdec_noise_0001'] = {}
+        config['capdec_noise_0001']['weights'] = 'image-captioning/capdec/0001.pt'
         return config
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000..341cee6
---
/dev/null +++ b/modules.py @@ -0,0 +1,365 @@ +import os +from torch import nn +import numpy as np +import torch +import torch.nn.functional as nnf +from typing import Tuple, List, Union, Optional +from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup +from tqdm import tqdm, trange + +class ClipCaptionModel(nn.Module): + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None): + embedding_text = self.gpt.transformer.wte(tokens) + prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size) + embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1) + if labels is not None: + dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) + labels = torch.cat((dummy_token, tokens), dim=1) + out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) + return out + + def __init__(self): + super(ClipCaptionModel, self).__init__() + self.prefix_length = 40 + self.gpt = GPT2LMHeadModel.from_pretrained('gpt2') + self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1] + self.clip_project = TransformerMapper(640, self.gpt_embedding_size, 40, 40, 8) + +class MLP(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): + super(MLP, self).__init__() + layers = [] + for i in range(len(sizes) -1): + layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) + if i < len(sizes) - 2: + layers.append(act()) + self.model = nn.Sequential(*layers) + +class MlpTransformer(nn.Module): + def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.): + super().__init__() + out_d = out_d if out_d is not None else in_dim + self.fc1 = nn.Linear(in_dim, h_dim) + self.act = act + self.fc2 = nn.Linear(h_dim, out_d) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class MultiHeadAttention(nn.Module): + + def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim_self // num_heads + self.scale = head_dim ** -0.5 + self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) + self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) + self.project = nn.Linear(dim_self, dim_self) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, y=None, mask=None): + y = y if y is not None else x + b, n, c = x.shape + _, m, d = y.shape + # b n h dh + queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads) + # b m 2 h dh + keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads) + keys, values = keys_values[:, :, 0], keys_values[:, :, 1] + attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale + if mask is not None: + if mask.dim() == 2: + mask = mask.unsqueeze(1) + attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) + attention = attention.softmax(dim=2) + out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c) + out = self.project(out) + return out, attention + + +class TransformerLayer(nn.Module): 
+ + def forward_with_attention(self, x, y=None, mask=None): + x_, attention = self.attn(self.norm1(x), y, mask) + x = x + x_ + x = x + self.mlp(self.norm2(x)) + return x, attention + + def forward(self, x, y=None, mask=None): + x = x + self.attn(self.norm1(x), y, mask)[0] + x = x + self.mlp(self.norm2(x)) + return x + + def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu, + norm_layer: nn.Module = nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim_self) + self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout) + self.norm2 = norm_layer(dim_self) + self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) + + +class Transformer(nn.Module): + + def forward_with_attention(self, x, y=None, mask=None): + attentions = [] + for layer in self.layers: + x, att = layer.forward_with_attention(x, y, mask) + attentions.append(att) + return x, attentions + + def forward(self, x, y=None, mask=None): + for i, layer in enumerate(self.layers): + if i % 2 == 0 and self.enc_dec: # cross + x = layer(x, y) + elif self.enc_dec: # self + x = layer(x, x, mask) + else: # self or cross + x = layer(x, y, mask) + return x + + def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None, + mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False): + super(Transformer, self).__init__() + dim_ref = dim_ref if dim_ref is not None else dim_self + self.enc_dec = enc_dec + if enc_dec: + num_layers = num_layers * 2 + layers = [] + for i in range(num_layers): + if i % 2 == 0 and enc_dec: # cross + layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) + elif enc_dec: # self + layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) + else: # self or cross + layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) + self.layers = nn.ModuleList(layers) + + +class TransformerMapper(nn.Module): + + def forward(self, x): + x = self.linear(x).view(x.shape[0], self.clip_length, -1) + prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape) + prefix = torch.cat((x, prefix), dim=1) + out = self.transformer(prefix)[:, self.clip_length:] + return out + + def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8): + super(TransformerMapper, self).__init__() + self.clip_length = clip_length + self.transformer = Transformer(dim_embedding, 8, num_layers) + self.linear = nn.Linear(dim_clip, clip_length * dim_embedding) + self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True) + + +class ClipCaptionPrefix(ClipCaptionModel): + + def parameters(self, recurse: bool = True): + return self.clip_project.parameters() + + def train(self, mode: bool = True): + super(ClipCaptionPrefix, self).train(mode) + self.gpt.eval() + return self + +def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None, + entry_length=67, temperature=1., stop_token: str = '.'): + + model.eval() + stop_token_index = tokenizer.encode(stop_token)[0] + tokens = None + scores = None + device = next(model.parameters()).device + seq_lengths = torch.ones(beam_size, device=device) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + with torch.no_grad(): + if embed is not None: + 
generated = embed + else: + if tokens is None: + tokens = torch.tensor(tokenizer.encode(prompt)) + tokens = tokens.unsqueeze(0).to(device) + generated = model.gpt.transformer.wte(tokens) + for i in range(entry_length): + outputs = model.gpt(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + scores = scores / seq_lengths + output_list = tokens.cpu().numpy() + output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)] + order = scores.argsort(descending=True) + output_texts = [output_texts[i] for i in order] + return output_texts + + +def generate2( + model, + tokenizer, + tokens=None, + prompt=None, + embed=None, + entry_count=1, + entry_length=67, # maximum number of words + top_p=0.8, + temperature=1., + stop_token: str = '.', +): + model.eval() + generated_num = 0 + generated_list = [] + stop_token_index = tokenizer.encode(stop_token)[0] + filter_value = -float("Inf") + device = next(model.parameters()).device + + with torch.no_grad(): + + for entry_idx in trange(entry_count): + if embed is not None: + generated = embed + else: + if tokens is None: + tokens = torch.tensor(tokenizer.encode(prompt)) + tokens = tokens.unsqueeze(0).to(device) + + generated = model.gpt.transformer.wte(tokens) + + for i in range(entry_length): + + outputs = model.gpt(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[:, indices_to_remove] = filter_value + next_token = torch.argmax(logits, -1).unsqueeze(0) + next_token_embed = model.gpt.transformer.wte(next_token) + if tokens is None: + tokens = next_token + else: + tokens = torch.cat((tokens, next_token), dim=1) + generated = 
torch.cat((generated, next_token_embed), dim=1)
+                if stop_token_index == next_token.item():
+                    break
+
+            output_list = list(tokens.squeeze().cpu().numpy())
+            output_text = tokenizer.decode(output_list)
+            generated_list.append(output_text)
+
+    return generated_list[0]
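+
+# For reference, the top-p (nucleus) filtering step used inside generate2 can be
+# isolated as the helper below. This is an illustrative sketch rather than part of
+# the upstream CapDec code; it assumes `logits` has shape (1, vocab_size), as in
+# generate2, and reuses the torch / nnf imports at the top of this module.
+def top_p_filter(logits, top_p: float = 0.8, filter_value: float = -float("Inf")):
+    # Keep the smallest set of tokens whose cumulative probability exceeds top_p
+    # and mask out everything else, mirroring the logic in generate2.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cumulative_probs > top_p
+    # Shift right so the first token that crosses the threshold is still kept.
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = 0
+    indices_to_remove = sorted_indices[sorted_indices_to_remove]
+    filtered = logits.clone()
+    filtered[:, indices_to_remove] = filter_value
+    return filtered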