logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
e3792cc483
  1. 84
      README.md
  2. 45
      capdec.py
  3. 365
      modules.py

84
README.md

@ -1,2 +1,84 @@
# capdec
# Image Captioning with CapDec
*author: David Wang*
<br />
## Description
This operator generates the caption with [CapDec](https://arxiv.org/abs/2211.00575) which describes the content of the given image. ExpansionNet v2 introduces the Block Static Expansion which distributes and processes the input over a heterogeneous and arbitrarily big collection of sequences characterized by a different length compared to the input one. This is an adaptation from [DavidHuji/CapDec](https://github.com/DavidHuji/CapDec).
<br />
## Code Example
Load an image from path './image.jpg' to generate the caption.
*Write the pipeline in simplified style*:
```python
import towhee
towhee.glob('./image.jpg') \
.image_decode() \
.image_captioning.capdec(model_name='capdec_noise_0') \
.show()
```
<img src="./cap.png" alt="result1" style="height:20px;"/>
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee
towhee.glob['path']('./image.jpg') \
.image_decode['path', 'img']() \
.image_captioning.capdec['img', 'text'](model_name='capdec_noise_0') \
.select['img', 'text']() \
.show()
```
<img src="./tabular.png" alt="result2" style="height:60px;"/>
<br />
## Factory Constructor
Create the operator via the following factory method
***capdec(model_name)***
**Parameters:**
***model_name:*** *str*
​ The model name of CapDec. Supported model names:
- capdec_noise_0
- capdec_noise_01
- capdec_noise_001
- capdec_noise_0001
<br />
## Interface
An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption.
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
​ The image to generate caption.
**Returns:** *str*
​ The caption generated by model.

45
capdec.py

@ -25,6 +25,7 @@ from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register from towhee import register
from towhee.models import clip from towhee.models import clip
from towhee.command.s3 import S3Bucket
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
@ -36,8 +37,27 @@ class Capdec(NNOperator):
def __init__(self, model_name: str): def __init__(self, model_name: str):
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
from modules import ClipCaptionModel, generate_beam, generate2
path = str(Path(__file__).parent)
config = self._configs()[model_name]
s3_bucket = S3Bucket()
s3_bucket.download_file(config['weights'], path + '/weights/')
model_path = path + '/weights/' + os.path.basename(config['weights'])
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.clip_caption_model = ClipCaptionModel()
self.clip_caption_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
self.clip_caption_model.to(self.device)
self.clip_caption_model.eval()
self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
self.clip_model.to(self.device)
self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2").to(self.device)
self.generate_beam = generate_beam
self.generate2 = generate2
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def inference_single_data(self, data): def inference_single_data(self, data):
@ -66,18 +86,25 @@ class Capdec(NNOperator):
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def _inference_from_image(self, img): def _inference_from_image(self, img):
img = self._preprocess(img) img = self._preprocess(img)
clip_feat = self.clip_model.encode_image(img)
self.prefix_length = 10
prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
use_beam_search = True
with torch.no_grad():
prefix = self.clip_model.encode_image(img)[0].to(self.device, dtype=torch.float32).unsqueeze(0)
prefix_embed = self.clip_caption_model.clip_project(prefix).reshape(1, 40, -1)
if use_beam_search:
generated_text_prefix = self.generate_beam(self.clip_caption_model, self.tokenizer, embed=prefix_embed)[0]
else:
generated_text_prefix = self.generate2(self.clip_caption_model, self.tokenizer, embed=prefix_embed)
generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
return generated_text_prefix return generated_text_prefix
def _configs(self): def _configs(self):
config = {} config = {}
config['clipcap_coco'] = {}
config['clipcap_coco']['weights'] = 'coco_weights.pt'
config['clipcap_conceptual'] = {}
config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
config['capdec_noise_0'] = {}
config['capdec_noise_0']['weights'] = 'image-captioning/capdec/0.pt'
config['capdec_noise_01'] = {}
config['capdec_noise_01']['weights'] = 'image-captioning/capdec/01.pt'
config['capdec_noise_001'] = {}
config['capdec_noise_001']['weights'] = 'image-captioning/capdec/001.pt'
config['capdec_noise_0001'] = {}
config['capdec_noise_0001']['weights'] = 'image-captioning/capdec/0001.pt'
return config return config

365
modules.py

@ -0,0 +1,365 @@
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
class ClipCaptionModel(nn.Module):
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None):
embedding_text = self.gpt.transformer.wte(tokens)
prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
labels = torch.cat((dummy_token, tokens), dim=1)
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
return out
def __init__(self):
super(ClipCaptionModel, self).__init__()
self.prefix_length = 40
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
self.clip_project = TransformerMapper(640, self.gpt_embedding_size, 40, 40, 8)
class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
super(MLP, self).__init__()
layers = []
for i in range(len(sizes) -1):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
if i < len(sizes) - 2:
layers.append(act())
self.model = nn.Sequential(*layers)
class MlpTransformer(nn.Module):
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
super().__init__()
out_d = out_d if out_d is not None else in_dim
self.fc1 = nn.Linear(in_dim, h_dim)
self.act = act
self.fc2 = nn.Linear(h_dim, out_d)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class MultiHeadAttention(nn.Module):
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim_self // num_heads
self.scale = head_dim ** -0.5
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
self.project = nn.Linear(dim_self, dim_self)
self.dropout = nn.Dropout(dropout)
def forward(self, x, y=None, mask=None):
y = y if y is not None else x
b, n, c = x.shape
_, m, d = y.shape
# b n h dh
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
# b m 2 h dh
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1)
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
attention = attention.softmax(dim=2)
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
out = self.project(out)
return out, attention
class TransformerLayer(nn.Module):
def forward_with_attention(self, x, y=None, mask=None):
x_, attention = self.attn(self.norm1(x), y, mask)
x = x + x_
x = x + self.mlp(self.norm2(x))
return x, attention
def forward(self, x, y=None, mask=None):
x = x + self.attn(self.norm1(x), y, mask)[0]
x = x + self.mlp(self.norm2(x))
return x
def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
norm_layer: nn.Module = nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim_self)
self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
self.norm2 = norm_layer(dim_self)
self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
class Transformer(nn.Module):
def forward_with_attention(self, x, y=None, mask=None):
attentions = []
for layer in self.layers:
x, att = layer.forward_with_attention(x, y, mask)
attentions.append(att)
return x, attentions
def forward(self, x, y=None, mask=None):
for i, layer in enumerate(self.layers):
if i % 2 == 0 and self.enc_dec: # cross
x = layer(x, y)
elif self.enc_dec: # self
x = layer(x, x, mask)
else: # self or cross
x = layer(x, y, mask)
return x
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
super(Transformer, self).__init__()
dim_ref = dim_ref if dim_ref is not None else dim_self
self.enc_dec = enc_dec
if enc_dec:
num_layers = num_layers * 2
layers = []
for i in range(num_layers):
if i % 2 == 0 and enc_dec: # cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
elif enc_dec: # self
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
else: # self or cross
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
self.layers = nn.ModuleList(layers)
class TransformerMapper(nn.Module):
def forward(self, x):
x = self.linear(x).view(x.shape[0], self.clip_length, -1)
prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
prefix = torch.cat((x, prefix), dim=1)
out = self.transformer(prefix)[:, self.clip_length:]
return out
def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
super(TransformerMapper, self).__init__()
self.clip_length = clip_length
self.transformer = Transformer(dim_embedding, 8, num_layers)
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
class ClipCaptionPrefix(ClipCaptionModel):
def parameters(self, recurse: bool = True):
return self.clip_project.parameters()
def train(self, mode: bool = True):
super(ClipCaptionPrefix, self).train(mode)
self.gpt.eval()
return self
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
entry_length=67, temperature=1., stop_token: str = '.'):
model.eval()
stop_token_index = tokenizer.encode(stop_token)[0]
tokens = None
scores = None
device = next(model.parameters()).device
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = model.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts
def generate2(
model,
tokenizer,
tokens=None,
prompt=None,
embed=None,
entry_count=1,
entry_length=67, # maximum number of words
top_p=0.8,
temperature=1.,
stop_token: str = '.',
):
model.eval()
generated_num = 0
generated_list = []
stop_token_index = tokenizer.encode(stop_token)[0]
filter_value = -float("Inf")
device = next(model.parameters()).device
with torch.no_grad():
for entry_idx in trange(entry_count):
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = model.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[:, indices_to_remove] = filter_value
next_token = torch.argmax(logits, -1).unsqueeze(0)
next_token_embed = model.gpt.transformer.wte(next_token)
if tokens is None:
tokens = next_token
else:
tokens = torch.cat((tokens, next_token), dim=1)
generated = torch.cat((generated, next_token_embed), dim=1)
if stop_token_index == next_token.item():
break
output_list = list(tokens.squeeze().cpu().numpy())
output_text = tokenizer.decode(output_list)
generated_list.append(output_text)
return generated_list[0]
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
entry_length=67, temperature=1., stop_token: str = '.'):
model.eval()
stop_token_index = tokenizer.encode(stop_token)[0]
tokens = None
scores = None
device = next(model.parameters()).device
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = model.gpt.transformer.wte(tokens)
for i in range(entry_length):
outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts
pretrained_model_variance = "0.1" #@param ["0.0", "0.0001", "0.001", "0.015", "0.1", "2.5"]
Loading…
Cancel
Save