logo
Browse Source

init the repo.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
0cd4e5ac80
  1. 18
      __init__.py
  2. BIN
      captioning/.DS_Store
  3. 0
      captioning/__init__.py
  4. 464
      captioning/models/AttModel.py
  5. 412
      captioning/models/CaptionModel.py
  6. 0
      captioning/models/__init__.py
  7. 92
      captioning/models/model_utils.py
  8. 1
      clip/__init__.py
  9. BIN
      clip/bpe_simple_vocab_16e6.txt.gz
  10. 193
      clip/clip.py
  11. 437
      clip/model.py
  12. 132
      clip/simple_tokenizer.py
  13. 105
      clip_caption_reward.py
  14. BIN
      configs/.DS_Store
  15. 60
      configs/phase1/FineCapEval_clipRN50_mle.yml
  16. 52
      configs/phase1/clipRN50_mle.yml
  17. 41
      configs/phase1/transformer.yml
  18. 61
      configs/phase2/FineCapEval_clipRN50_cider.yml
  19. 65
      configs/phase2/FineCapEval_clipRN50_cider_clips.yml
  20. 64
      configs/phase2/FineCapEval_clipRN50_clips.yml
  21. 64
      configs/phase2/FineCapEval_clipRN50_clips_grammar.yml
  22. 58
      configs/phase2/clipRN50_cider.yml
  23. 61
      configs/phase2/clipRN50_cider_clips.yml
  24. 58
      configs/phase2/clipRN50_clips.yml
  25. 64
      configs/phase2/clipRN50_clips_grammar.yml
  26. 41
      configs/phase2/transformer.yml
  27. 1
      data/cocotalk.json
  28. 379
      transformer_model.py
  29. 153
      utils/config.py
  30. 412
      utils/opts.py

18
__init__.py

@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from clip_caption_reward import ClipCaptionReward
def clip_caption_reward(model_name: str):
return ClipCaptionReward(model_name)

BIN
captioning/.DS_Store

Binary file not shown.

0
captioning/__init__.py

464
captioning/models/AttModel.py

@ -0,0 +1,464 @@
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
# https://arxiv.org/abs/1612.01887
# AdaAttMO is a modified version with maxout lstm
# Att2in is from Self-critical Sequence Training for Image Captioning
# https://arxiv.org/abs/1612.00563
# In this file we only have Att2in2, which is a slightly different version of att2in,
# in which the img feature embedding and word embedding is the same as what in adaatt.
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
# https://arxiv.org/abs/1707.07998
# However, it may not be identical to the author's architecture.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
#from . import utils
#utils.repeat_tensors
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from .CaptionModel import CaptionModel
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
bad_endings += ['the']
def repeat_tensors(n, x):
"""
For a tensor of size Bx..., we repeat it n times, and make it Bnx...
For collections, do nested repeat
"""
if torch.is_tensor(x):
x = x.unsqueeze(1) # Bx1x...
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
elif type(x) is list or type(x) is tuple:
x = [repeat_tensors(n, _) for _ in x]
return x
def sort_pack_padded_sequence(input, lengths):
sorted_lengths, indices = torch.sort(lengths, descending=True)
# tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
inv_ix = indices.clone()
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
return tmp, inv_ix
def pad_unsort_packed_sequence(input, inv_ix):
tmp, _ = pad_packed_sequence(input, batch_first=True)
tmp = tmp[inv_ix]
return tmp
def pack_wrapper(module, att_feats, att_masks):
if att_masks is not None:
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
else:
return module(att_feats)
class AttModel(CaptionModel):
def __init__(self, opt):
super(AttModel, self).__init__()
self.vocab_size = opt.vocab_size
self.input_encoding_size = opt.input_encoding_size
#self.rnn_type = opt.rnn_type
self.rnn_size = opt.rnn_size
self.num_layers = opt.num_layers
self.drop_prob_lm = opt.drop_prob_lm
self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
self.fc_feat_size = opt.fc_feat_size
self.att_feat_size = opt.att_feat_size
self.att_hid_size = opt.att_hid_size
self.bos_idx = getattr(opt, 'bos_idx', 0)
self.eos_idx = getattr(opt, 'eos_idx', 0)
self.pad_idx = getattr(opt, 'pad_idx', 0)
self.use_bn = getattr(opt, 'use_bn', 0)
self.ss_prob = 0.0 # Schedule sampling probability
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
nn.ReLU(),
nn.Dropout(self.drop_prob_lm))
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
nn.ReLU(),
nn.Dropout(self.drop_prob_lm))
self.att_embed = nn.Sequential(*(
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
(nn.Linear(self.att_feat_size, self.rnn_size),
nn.ReLU(),
nn.Dropout(self.drop_prob_lm))+
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
self.logit_layers = getattr(opt, 'logit_layers', 1)
if self.logit_layers == 1:
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
else:
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
# For remove bad endding
self.vocab = opt.vocab
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
def init_hidden(self, bsz):
weight = self.logit.weight \
if hasattr(self.logit, "weight") \
else self.logit[0].weight
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
def clip_att(self, att_feats, att_masks):
# Clip the length of att_masks and att_feats to the maximum length
if att_masks is not None:
max_len = att_masks.data.long().sum(1).max()
att_feats = att_feats[:, :max_len].contiguous()
att_masks = att_masks[:, :max_len].contiguous()
return att_feats, att_masks
def _prepare_feature(self, fc_feats, att_feats, att_masks):
att_feats, att_masks = self.clip_att(att_feats, att_masks)
# embed fc and att feats
fc_feats = self.fc_embed(fc_feats)
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
# Project the attention feats first to reduce memory and computation comsumptions.
p_att_feats = self.ctx2att(att_feats)
return fc_feats, att_feats, p_att_feats, att_masks
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
batch_size = fc_feats.size(0)
if seq.ndim == 3: # B * seq_per_img * seq_len
seq = seq.reshape(-1, seq.shape[2])
seq_per_img = seq.shape[0] // batch_size
state = self.init_hidden(batch_size*seq_per_img)
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
# Prepare the features
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost
if seq_per_img > 1:
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(seq_per_img,
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
)
for i in range(seq.size(1)):
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
sample_mask = sample_prob < self.ss_prob
if sample_mask.sum() == 0:
it = seq[:, i].clone()
else:
sample_ind = sample_mask.nonzero().view(-1)
it = seq[:, i].data.clone()
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
else:
it = seq[:, i].clone()
# break if all the sequences end
if i >= 1 and seq[:, i].sum() == 0:
break
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
outputs[:, i] = output
return outputs
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
# 'it' contains a word index
xt = self.embed(it)
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
if output_logsoftmax:
logprobs = F.log_softmax(self.logit(output), dim=1)
else:
logprobs = self.logit(output)
return logprobs, state
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
sample_n = opt.get('sample_n', 10)
# when sample_n == beam_size then each beam is a sample.
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
batch_size = fc_feats.size(0)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
# lets process every image independently for now, for simplicity
self.done_beams = [[] for _ in range(batch_size)]
for k in range(batch_size):
state = self.init_hidden(beam_size)
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = repeat_tensors(beam_size,
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
)
for t in range(1):
if t == 0: # input <bos>
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
if sample_n == beam_size:
for _n in range(sample_n):
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
else:
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
seqLogprobs[k, :] = self.done_beams[k][0]['logps']
# return the samples and their log likelihoods
return seq, seqLogprobs
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
sample_n = opt.get('sample_n', 10)
# when sample_n == beam_size then each beam is a sample.
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
batch_size = fc_feats.size(0)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
# lets process every image independently for now, for simplicity
self.done_beams = [[] for _ in range(batch_size)]
state = self.init_hidden(batch_size)
# first step, feed bos
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(beam_size,
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
)
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
for k in range(batch_size):
if sample_n == beam_size:
for _n in range(sample_n):
seq_len = self.done_beams[k][_n]['seq'].shape[0]
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
else:
seq_len = self.done_beams[k][0]['seq'].shape[0]
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
# return the samples and their log likelihoods
return seq, seqLogprobs
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
sample_method = opt.get('sample_method', 'greedy')
beam_size = opt.get('beam_size', 1)
temperature = opt.get('temperature', 1.0)
sample_n = int(opt.get('sample_n', 1))
group_size = opt.get('group_size', 1)
output_logsoftmax = opt.get('output_logsoftmax', 1)
decoding_constraint = opt.get('decoding_constraint', 0)
block_trigrams = opt.get('block_trigrams', 0)
remove_bad_endings = opt.get('remove_bad_endings', 0)
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
return self._sample_beam(fc_feats, att_feats, att_masks, opt)
if group_size > 1:
return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
batch_size = fc_feats.size(0)
state = self.init_hidden(batch_size*sample_n)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
if sample_n > 1:
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(sample_n,
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
)
trigrams = [] # will be a list of batch_size dictionaries
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
for t in range(self.seq_length + 1):
if t == 0: # input <bos>
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
if decoding_constraint and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
logprobs = logprobs + tmp
if remove_bad_endings and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
# Make it impossible to generate bad_endings
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
logprobs = logprobs + tmp
# Mess with trigrams
# Copy from https://github.com/lukemelas/image-paragraph-captioning
if block_trigrams and t >= 3:
# Store trigram generated at last step
prev_two_batch = seq[:,t-3:t-1]
for i in range(batch_size): # = seq.size(0)
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
current = seq[i][t-1]
if t == 3: # initialize
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
elif t > 3:
if prev_two in trigrams[i]: # add to list
trigrams[i][prev_two].append(current)
else: # create list
trigrams[i][prev_two] = [current]
# Block used trigrams at next step
prev_two_batch = seq[:,t-2:t]
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
for i in range(batch_size):
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
if prev_two in trigrams[i]:
for j in trigrams[i][prev_two]:
mask[i,j] += 1
# Apply mask to log probs
#logprobs = logprobs - (mask * 1e9)
alpha = 2.0 # = 4
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
# sample the next word
if t == self.seq_length: # skip if we achieve maximum length
break
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
# stop when all finished
if t == 0:
unfinished = it != self.eos_idx
else:
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
unfinished = unfinished & (it != self.eos_idx)
seq[:,t] = it
seqLogprobs[:,t] = logprobs
# quit loop if all sequences have finished
if unfinished.sum() == 0:
break
return seq, seqLogprobs
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
sample_method = opt.get('sample_method', 'greedy')
beam_size = opt.get('beam_size', 1)
temperature = opt.get('temperature', 1.0)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
block_trigrams = opt.get('block_trigrams', 0)
remove_bad_endings = opt.get('remove_bad_endings', 0)
batch_size = fc_feats.size(0)
state = self.init_hidden(batch_size)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
for tt in range(self.seq_length + group_size):
for divm in range(group_size):
t = tt - divm
seq = seq_table[divm]
seqLogprobs = seqLogprobs_table[divm]
trigrams = trigrams_table[divm]
if t >= 0 and t <= self.seq_length-1:
if t == 0: # input <bos>
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
else:
it = seq[:, t-1] # changed
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
# Add diversity
if divm > 0:
unaug_logprobs = logprobs.clone()
for prev_choice in range(divm):
prev_decisions = seq_table[prev_choice][:, t]
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
if decoding_constraint and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
logprobs = logprobs + tmp
if remove_bad_endings and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
# Impossible to generate remove_bad_endings
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
logprobs = logprobs + tmp
# Mess with trigrams
if block_trigrams and t >= 3:
# Store trigram generated at last step
prev_two_batch = seq[:,t-3:t-1]
for i in range(batch_size): # = seq.size(0)
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
current = seq[i][t-1]
if t == 3: # initialize
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
elif t > 3:
if prev_two in trigrams[i]: # add to list
trigrams[i][prev_two].append(current)
else: # create list
trigrams[i][prev_two] = [current]
# Block used trigrams at next step
prev_two_batch = seq[:,t-2:t]
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
for i in range(batch_size):
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
if prev_two in trigrams[i]:
for j in trigrams[i][prev_two]:
mask[i,j] += 1
# Apply mask to log probs
#logprobs = logprobs - (mask * 1e9)
alpha = 2.0 # = 4
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
# stop when all finished
if t == 0:
unfinished = it != self.eos_idx
else:
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
it[~unfinished] = self.pad_idx
unfinished = unfinished & (it != self.eos_idx) # changed
seq[:,t] = it
seqLogprobs[:,t] = sampleLogprobs.view(-1)
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)

412
captioning/models/CaptionModel.py

@ -0,0 +1,412 @@
# This file contains ShowAttendTell and AllImg model
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
# https://arxiv.org/abs/1502.03044
# AllImg is a model where
# img feature is concatenated with word embedding at every time step as the input of lstm
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
#import .model_utils as mutils
from . import model_utils
#from ..utils import misc as utils
#from ... import utils as model_utils
#model_utils split_tensors
#utils penalty_builder decode_sequence
class CaptionModel(nn.Module):
def __init__(self):
super(CaptionModel, self).__init__()
# implements beam search
# calls beam_step and returns the final set of beams
# augments log-probabilities with diversity terms when number of groups > 1
def forward(self, *args, **kwargs):
mode = kwargs.get('mode', 'forward')
if 'mode' in kwargs:
del kwargs['mode']
return getattr(self, '_'+mode)(*args, **kwargs)
def beam_search(self, init_state, init_logprobs, *args, **kwargs):
# function computes the similarity score to be augmented
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobs = logprobs.clone()
batch_size = beam_seq_table[0].shape[0]
if divm > 0:
change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
for prev_labels in range(bdash):
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
if local_time == 0:
logprobs = logprobs - change * diversity_lambda
else:
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
return logprobs, unaug_logprobs
# does one step of classical beam search
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
#INPUTS:
#logprobs: probabilities augmented after diversity N*bxV
#beam_size: obvious
#t : time instant
#beam_seq : tensor contanining the beams
#beam_seq_logprobs: tensor contanining the beam logprobs
#beam_logprobs_sum: tensor contanining joint logprobs
#OUPUTS:
#beam_seq : tensor containing the word indices of the decoded captions Nxbxl
#beam_seq_logprobs : log-probability of each decision made, NxbxlxV
#beam_logprobs_sum : joint log-probability of each beam Nxb
batch_size = beam_logprobs_sum.shape[0]
vocab_size = logprobs.shape[-1]
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
if t == 0:
assert logprobs.shape[1] == 1
beam_logprobs_sum = beam_logprobs_sum[:, :1]
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
ys, ix = ys[:,:beam_size], ix[:,:beam_size]
beam_ix = ix // vocab_size # Nxb which beam
selected_ix = ix % vocab_size # Nxb # which world
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
if t > 0:
# gather according to beam_ix
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
logprobs.reshape(batch_size, -1).gather(1, ix)
assert (beam_logprobs_sum == ys).all()
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
assert (_tmp_beam_logprobs == beam_logprobs).all()
beam_seq_logprobs = torch.cat([
beam_seq_logprobs,
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
new_state = [None for _ in state]
for _ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[_ix] = state[_ix][:, state_ix]
state = new_state
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
remove_bad_endings = opt.get('remove_bad_endings', 0)
suppress_UNK = opt.get('suppress_UNK', 0)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
bdash = beam_size // group_size # beam per group
batch_size = init_logprobs.shape[0]
device = init_logprobs.device
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
# END INIT
# Chunk elements in the args
args = list(args)
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
if self.__class__.__name__ == 'AttEnsemble':
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
else:
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
for t in range(self.seq_length + group_size - 1):
for divm in range(group_size):
if t >= divm and t <= self.seq_length + divm - 1:
# add diversity
logprobs = logprobs_table[divm]
# suppress previous word
if decoding_constraint and t-divm > 0:
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
if remove_bad_endings and t-divm > 0:
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
# diversity is added here
# the function directly modifies the logprobs values and hence, we need to return
# the unaugmented ones for sorting the candidates in the end. # for historical
# reasons :-)
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
# infer new beams
beam_seq_table[divm],\
beam_seq_logprobs_table[divm],\
beam_logprobs_sum_table[divm],\
state_table[divm] = beam_step(logprobs,
unaug_logprobs,
bdash,
t-divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for b in range(batch_size):
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
assert beam_seq_table[divm].shape[-1] == t-divm+1
if t == self.seq_length + divm - 1:
is_end.fill_(1)
for vix in range(bdash):
if is_end[vix]:
final_beam = {
'seq': beam_seq_table[divm][b, vix].clone(),
'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][b, vix].item()
}
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
done_beams_table[b][divm].append(final_beam)
beam_logprobs_sum_table[divm][b, is_end] -= 1000
# move the current group one step forward in time
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
# all beams are sorted by their log-probabilities
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
done_beams = [sum(_, []) for _ in done_beams_table]
return done_beams
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
# function computes the similarity score to be augmented
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
local_time = t - divm
unaug_logprobsf = logprobsf.clone()
for prev_choice in range(divm):
prev_decisions = beam_seq_table[prev_choice][local_time]
for sub_beam in range(bdash):
for prev_labels in range(bdash):
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
return unaug_logprobsf
# does one step of classical beam search
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
#INPUTS:
#logprobsf: probabilities augmented after diversity
#beam_size: obvious
#t : time instant
#beam_seq : tensor contanining the beams
#beam_seq_logprobs: tensor contanining the beam logprobs
#beam_logprobs_sum: tensor contanining joint logprobs
#OUPUTS:
#beam_seq : tensor containing the word indices of the decoded captions
#beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
#beam_logprobs_sum : joint log-probability of each beam
ys,ix = torch.sort(logprobsf,1,True)
candidates = []
cols = min(beam_size, ys.size(1))
rows = beam_size
if t == 0:
rows = 1
for c in range(cols): # for each column (word, essentially)
for q in range(rows): # for each beam expansion
#compute logprob of expanding beam q with word in (sorted) position c
local_logprob = ys[q,c].item()
candidate_logprob = beam_logprobs_sum[q] + local_logprob
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
candidates = sorted(candidates, key=lambda x: -x['p'])
new_state = [_.clone() for _ in state]
#beam_seq_prev, beam_seq_logprobs_prev
if t >= 1:
#we''ll need these as reference when we fork beams around
beam_seq_prev = beam_seq[:t].clone()
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
for vix in range(beam_size):
v = candidates[vix]
#fork beam index q into index vix
if t >= 1:
beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
#rearrange recurrent states
for state_ix in range(len(new_state)):
# copy over state in previous beam q to new beam at vix
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
#append new end terminal at the end of this beam
beam_seq[t, vix] = v['c'] # c'th word is the continuation
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
state = new_state
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
remove_bad_endings = opt.get('remove_bad_endings', 0)
suppress_UNK = opt.get('suppress_UNK', 0)
length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
bdash = beam_size // group_size # beam per group
# INITIALIZATIONS
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
done_beams_table = [[] for _ in range(group_size)]
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
logprobs_table = list(init_logprobs.chunk(group_size, 0))
# END INIT
# Chunk elements in the args
args = list(args)
if self.__class__.__name__ == 'AttEnsemble':
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
else:
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
for t in range(self.seq_length + group_size - 1):
for divm in range(group_size):
if t >= divm and t <= self.seq_length + divm - 1:
# add diversity
logprobsf = logprobs_table[divm]
# suppress previous word
if decoding_constraint and t-divm > 0:
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
if remove_bad_endings and t-divm > 0:
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
# suppress UNK tokens in the decoding
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
# diversity is added here
# the function directly modifies the logprobsf values and hence, we need to return
# the unaugmented ones for sorting the candidates in the end. # for historical
# reasons :-)
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
# infer new beams
beam_seq_table[divm],\
beam_seq_logprobs_table[divm],\
beam_logprobs_sum_table[divm],\
state_table[divm],\
candidates_divm = beam_step(logprobsf,
unaug_logprobsf,
bdash,
t-divm,
beam_seq_table[divm],
beam_seq_logprobs_table[divm],
beam_logprobs_sum_table[divm],
state_table[divm])
# if time's up... or if end token is reached then copy beams
for vix in range(bdash):
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
final_beam = {
'seq': beam_seq_table[divm][:, vix].clone(),
'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
'p': beam_logprobs_sum_table[divm][vix].item()
}
final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
done_beams_table[divm].append(final_beam)
# don't continue beams from finished sequences
beam_logprobs_sum_table[divm][vix] = -1000
# move the current group one step forward in time
it = beam_seq_table[divm][t-divm].to(logprobsf.device)
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
# all beams are sorted by their log-probabilities
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
done_beams = sum(done_beams_table, [])
return done_beams
def sample_next_word(self, logprobs, sample_method, temperature):
if sample_method == 'greedy':
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
elif sample_method == 'gumbel': # gumbel softmax
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
def sample_gumbel(shape, eps=1e-20):
U = torch.rand(shape).to(logprobs.device)
return -torch.log(-torch.log(U + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
y = logits + sample_gumbel(logits.size())
return F.log_softmax(y / temperature, dim=-1)
_logprobs = gumbel_softmax_sample(logprobs, temperature)
_, it = torch.max(_logprobs.data, 1)
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
else:
logprobs = logprobs / temperature
if sample_method.startswith('top'): # topk sampling
top_num = float(sample_method[3:])
if 0 < top_num < 1:
# nucleus sampling from # The Curious Case of Neural Text Degeneration
probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
_cumsum = sorted_probs.cumsum(1)
mask = _cumsum < top_num
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
sorted_probs = sorted_probs * mask.to(sorted_probs)
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
logprobs.scatter_(1, sorted_indices, sorted_probs.log())
else:
the_k = int(top_num)
tmp = torch.empty_like(logprobs).fill_(float('-inf'))
topk, indices = torch.topk(logprobs, the_k, dim=1)
tmp = tmp.scatter(1, indices, topk)
logprobs = tmp
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
return it, sampleLogprobs
def decode_sequence(self, seq):
return utils.decode_sequence(self.vocab, seq)

0
captioning/models/__init__.py

92
captioning/models/model_utils.py

@ -0,0 +1,92 @@
import torch
def split_tensors(n, x):
if torch.is_tensor(x):
assert x.shape[0] % n == 0
x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
elif type(x) is list or type(x) is tuple:
x = [split_tensors(n, _) for _ in x]
elif x is None:
x = [None] * n
return x
# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
#def decode_sequence(ix_to_word, seq):
# # N, D = seq.size()
# N, D = seq.shape
# out = []
# for i in range(N):
# txt = ''
# for j in range(D):
# ix = seq[i,j]
# if ix > 0 :
# if j >= 1:
# txt = txt + ' '
# txt = txt + ix_to_word[str(ix.item())]
# else:
# break
# if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
# flag = 0
# words = txt.split(' ')
# for j in range(len(words)):
# if words[-j-1] not in bad_endings:
# flag = -j
# break
# txt = ' '.join(words[0:len(words)+flag])
# out.append(txt.replace('@@ ', ''))
# return out
def decode_sequence(ix_to_word, seq, remove_bad_endings = True):
# N, D = seq.size()
N, D = seq.shape
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
bad_endings += ['the']
out = []
for i in range(N):
txt = ''
for j in range(D):
ix = seq[i,j]
if ix > 0 :
if j >= 1:
txt = txt + ' '
txt = txt + ix_to_word[str(ix.item())]
else:
break
if remove_bad_endings is True:
flag = 0
words = txt.split(' ')
for j in range(len(words)):
if words[-j-1] not in bad_endings:
flag = -j
break
txt = ' '.join(words[0:len(words)+flag])
out.append(txt.replace('@@ ', ''))
return out
def penalty_builder(penalty_config):
if penalty_config == '':
return lambda x,y: y
pen_type, alpha = penalty_config.split('_')
alpha = float(alpha)
if pen_type == 'wu':
return lambda x,y: length_wu(x,y,alpha)
if pen_type == 'avg':
return lambda x,y: length_average(x,y,alpha)
def length_wu(length, logprobs, alpha=0.):
"""
NMT length re-ranking score from
"Google's Neural Machine Translation System" :cite:`wu2016google`.
"""
modifier = (((5 + length) ** alpha) /
((5 + 1) ** alpha))
return (logprobs / modifier)
def length_average(length, logprobs, alpha=0.):
"""
Returns the average probability of tokens in a sequence.
"""
return logprobs / length

1
clip/__init__.py

@ -0,0 +1 @@
from .clip import *

BIN
clip/bpe_simple_vocab_16e6.txt.gz (Stored with Git LFS)

Binary file not shown.

193
clip/clip.py

@ -0,0 +1,193 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=Image.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name])
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result

437
clip/model.py

@ -0,0 +1,437 @@
from collections import OrderedDict
from typing import Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
# print(x.shape, self.positional_embedding.shape)
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=torch.ones_like(self.q_proj.weight),
out_proj_bias=torch.zeros_like(self.q_proj.bias),
# out_proj_weight=self.c_proj.weight,
# out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# print(x.shape)
# x = self.attnpool(x)
attnpool = self.attnpool(x)
return (x, attnpool)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
# x = self.ln_post(x[:, 0, :])
x = self.ln_post(x)
# if self.proj is not None:
# x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
return model.eval()

132
clip/simple_tokenizer.py

@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text

105
clip_caption_reward.py

@ -0,0 +1,105 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
from torch import nn
from timm.models.vision_transformer import resize_pos_embed
from towhee.types.image_utils import to_pil
class ClipCaptionReward(NNOperator):
"""
BLIP multi-modal embedding operator
"""
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(Path(__file__).parent))
from utils import opts
import clip
opt = opts.parse_opt(parse=False, cfg=cfg)
path = pathlib.Path(__file__).parent
dict_json = json.load(open("{}/data/cocotalk.json".format(path)))
ix_to_word = dict_json["ix_to_word"]
self.device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_transform = clip.load("RN50", jit=False, device=self.device)
self.clip_model = clip_model
self.clip_transform = clip_transform
vocab_size = len(ix_to_word)
seq_length = 1
opt.vocab_size = vocab_size
opt.seq_length = seq_length
opt.batch_size = 1
opt.vocab = ix_to_word
num_patches = 196 # 600 * 1000 // 32 // 32
pos_embed = nn.Parameter(
torch.zeros(
1,
num_patches + 1,
clip_model.visual.attnpool.positional_embedding.shape[-1],
device=self.device,
),
)
pos_embed.weight = resize_pos_embed(
clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed
)
self.clip_model.visual.attnpool.positional_embedding = pos_embed
self.model = TransformerModel(opt)
self.image_mean = (
torch.Tensor([0.48145466, 0.4578275, 0.40821073])
.to(self.device)
.reshape(3, 1, 1)
)
self.image_std = (
torch.Tensor([0.26862954, 0.26130258, 0.27577711])
.to(self.device)
.reshape(3, 1, 1)
)
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
text = self._inference_from_image(data)
return text
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = to_pil(img)
img = self._preprocess(img)
self._inference_from_image(img)
img -= self.image_mean
img /= self.image_std
tmp_att, tmp_fc = self.clip_model.encode_image(img)
tmp_att = tmp_att[0].permute(1, 2, 0)
att_feat = tmp_att
return att_feat
def __call__(self, data):
if not isinstance(data, list):
data = [data]
else:
data = data
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results

BIN
configs/.DS_Store

Binary file not shown.

60
configs/phase1/FineCapEval_clipRN50_mle.yml

@ -0,0 +1,60 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 200
learning_rate: 0.0005
checkpoint_path: ./save/clipRN50_mle/clipRN50_mle
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 50
verbose: false
precision: 32
use_clipscore: false

52
configs/phase1/clipRN50_mle.yml

@ -0,0 +1,52 @@
caption_model: transformer
noamopt: true
# noamopt: false
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
# batch_size: 600
batch_size: 200
learning_rate: 0.0005
# checkpoint_path: ./save/trans_clip_rn50_sc_pl
checkpoint_path: save/clipRN50_mle/clipRN50_mle
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
# max_epochs: 15
max_epochs: 25
train_sample_n: 5
REFORWARD: false
verbose: false
precision: 16

41
configs/phase1/transformer.yml

@ -0,0 +1,41 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_att_dir: data/cocotalk_att
seq_per_img: 5
batch_size: 10
learning_rate: 0.0005
checkpoint_path: ./save/trans_rn50_sc
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false

61
configs/phase2/FineCapEval_clipRN50_cider.yml

@ -0,0 +1,61 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 200
learning_rate: 0.0005
checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 50
verbose: false
precision: 32
# use_clipscore: true
use_clipscore: false

65
configs/phase2/FineCapEval_clipRN50_cider_clips.yml

@ -0,0 +1,65 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 200
learning_rate: 0.0005
checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 50
verbose: false
precision: 32
# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0
clipscore_mode: clip_s
use_multi_rewards: true

64
configs/phase2/FineCapEval_clipRN50_clips.yml

@ -0,0 +1,64 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005
checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
use_multi_rewards: false
use_grammar: false
use_grammar_baseline: false
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 0
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 50
verbose: false
precision: 32
# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0

64
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml

@ -0,0 +1,64 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/FineCapEval.json
input_label_h5: none
input_fc_dir: data/FineCapEval_clip_RN50_fc
input_att_dir: data/FineCapEval_clip_RN50_att
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005
checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
use_multi_rewards: true
use_grammar: true
use_grammar_baseline: true
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 0
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 50
verbose: false
precision: 32
# use_clipscore: true
use_clipscore: false
clipscore_reward_weight: 2.0

58
configs/phase2/clipRN50_cider.yml

@ -0,0 +1,58 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
# used only for evaluation
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 200
learning_rate: 0.0005
# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
checkpoint_path: save/clipRN50_cider/clipRN50_cider
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 40
verbose: false
precision: 32

61
configs/phase2/clipRN50_cider_clips.yml

@ -0,0 +1,61 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005
checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 40
verbose: false
precision: 32
use_clipscore: true
clipscore_reward_weight: 2.0
clipscore_mode: clip_s
use_multi_rewards: true

58
configs/phase2/clipRN50_clips.yml

@ -0,0 +1,58 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005
checkpoint_path: save/clipRN50_clips/clipRN50_clips
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 40
verbose: false
precision: 32
use_clipscore: true
clipscore_reward_weight: 2.0

64
configs/phase2/clipRN50_clips_grammar.yml

@ -0,0 +1,64 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_fc_dir: data/cocotalk_clip_RN50_fc
input_att_dir: data/cocotalk_clip_RN50_att
input_clipscore_vis_dir: data/cocotalk_clipscore_vis
seq_per_img: 5
batch_size: 160
learning_rate: 0.0005
checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
use_multi_rewards: true
use_grammar: true
use_grammar_baseline: true
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false
# _BASE_: transformer.yml
reduce_on_plateau: false
noamopt: false
learning_rate: 0.000005
learning_rate_decay_start: -1
self_critical_after: 15
max_epochs: 40
verbose: false
precision: 32
use_clipscore: true
clipscore_reward_weight: 2.0

41
configs/phase2/transformer.yml

@ -0,0 +1,41 @@
caption_model: transformer
noamopt: true
noamopt_warmup: 20000
label_smoothing: 0.0
input_json: data/cocotalk.json
input_label_h5: data/cocotalk_label.h5
input_att_dir: data/cocotalk_att
seq_per_img: 5
batch_size: 10
learning_rate: 0.0005
checkpoint_path: ./save/trans_rn50_sc
# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# will be ignored
num_layers: 6
input_encoding_size: 512
rnn_size: 2048
# Transformer config
N_enc: 6
N_dec: 6
d_model: 512
d_ff: 2048
num_att_heads: 8
dropout: 0.1
learning_rate_decay_start: 0
scheduled_sampling_start: -1
save_checkpoint_every: 3000
language_eval: 1
val_images_use: 5000
max_epochs: 15
train_sample_n: 5
REFORWARD: false

1
data/cocotalk.json

File diff suppressed because one or more lines are too long

379
transformer_model.py

@ -0,0 +1,379 @@
# This file contains Transformer network
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
# The cfg name correspondance:
# N=num_layers
# d_model=input_encoding_size
# d_ff=rnn_size
# h is always 8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
#from . import utils
import copy
import math
import numpy as np
#from .CaptionModel import CaptionModel
#from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
from captioning.models.CaptionModel import CaptionModel
from captioning.models.AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
def repeat_tensors(n, x):
"""
For a tensor of size Bx..., we repeat it n times, and make it Bnx...
For collections, do nested repeat
"""
if torch.is_tensor(x):
x = x.unsqueeze(1) # Bx1x...
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
elif type(x) is list or type(x) is tuple:
x = [repeat_tensors(n, _) for _ in x]
return x
class EncoderDecoder(nn.Module):
"""
A standard Encoder-Decoder architecture. Base for this and many
other models.
"""
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
def forward(self, src, tgt, src_mask, tgt_mask):
"Take in and process masked src and target sequences."
return self.decode(self.encode(src, src_mask), src_mask,
tgt, tgt_mask)
def encode(self, src, src_mask):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory, src_mask, tgt, tgt_mask):
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
class Generator(nn.Module):
"Define standard linear + softmax generation step."
def __init__(self, d_model, vocab):
super(Generator, self).__init__()
self.proj = nn.Linear(d_model, vocab)
def forward(self, x):
return F.log_softmax(self.proj(x), dim=-1)
def clones(module, N):
"Produce N identical layers."
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class Encoder(nn.Module):
"Core encoder is a stack of N layers"
def __init__(self, layer, N):
super(Encoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, mask):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class LayerNorm(nn.Module):
"Construct a layernorm module (See citation for details)."
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class SublayerConnection(nn.Module):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SublayerConnection, self).__init__()
self.norm = LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
"Apply residual connection to any sublayer with the same size."
return x + self.dropout(sublayer(self.norm(x)))
class EncoderLayer(nn.Module):
"Encoder is made up of self-attn and feed forward (defined below)"
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 2)
self.size = size
def forward(self, x, mask):
"Follow Figure 1 (left) for connections."
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
return self.sublayer[1](x, self.feed_forward)
class Decoder(nn.Module):
"Generic N layer decoder with masking."
def __init__(self, layer, N):
super(Decoder, self).__init__()
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self, x, memory, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)
class DecoderLayer(nn.Module):
"Decoder is made of self-attn, src-attn, and feed forward (defined below)"
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask):
"Follow Figure 1 (right) for connections."
m = memory
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
return self.sublayer[2](x, self.feed_forward)
def subsequent_mask(size):
"Mask out subsequent positions."
attn_shape = (1, size, size)
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
return torch.from_numpy(subsequent_mask) == 0
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) \
/ math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
p_attn = F.softmax(scores, dim = -1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
"Take in model size and number of heads."
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
# We assume d_v always equals d_k
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
"Implements Figure 2"
if mask is not None:
# Same mask applied to all h heads.
mask = mask.unsqueeze(1)
nbatches = query.size(0)
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch.
x, self.attn = attention(query, key, value, mask=mask,
dropout=self.dropout)
# 3) "Concat" using a view and apply a final linear.
x = x.transpose(1, 2).contiguous() \
.view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)
class PositionwiseFeedForward(nn.Module):
"Implements FFN equation."
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Embeddings(nn.Module):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, x):
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Module):
"Implement the PE function."
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class TransformerModel(AttModel):
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
d_model=512, d_ff=2048, h=8, dropout=0.1):
"Helper: Construct a model from hyperparameters."
c = copy.deepcopy
attn = MultiHeadedAttention(h, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
position = PositionalEncoding(d_model, dropout)
model = EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
Decoder(DecoderLayer(d_model, c(attn), c(attn),
c(ff), dropout), N_dec),
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
Generator(d_model, tgt_vocab))
# This was important from their code.
# Initialize parameters with Glorot / fan_avg.
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return model
def __init__(self, opt):
super(TransformerModel, self).__init__(opt)
self.opt = opt
# self.config = yaml.load(open(opt.config_file))
self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
self.h = getattr(opt, 'num_att_heads', 8)
self.dropout = getattr(opt, 'dropout', 0.1)
delattr(self, 'att_embed')
self.att_embed = nn.Sequential(*(
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
(nn.Linear(self.att_feat_size, self.d_model),
nn.ReLU(),
nn.Dropout(self.drop_prob_lm))+
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
delattr(self, 'embed')
self.embed = lambda x : x
delattr(self, 'fc_embed')
self.fc_embed = lambda x : x
delattr(self, 'logit')
del self.ctx2att
tgt_vocab = self.vocab_size + 1
self.model = self.make_model(0, tgt_vocab,
N_enc=self.N_enc,
N_dec=self.N_dec,
d_model=self.d_model,
d_ff=self.d_ff,
h=self.h,
dropout=self.dropout)
def logit(self, x): # unsafe way
return self.model.generator.proj(x)
def init_hidden(self, bsz):
return []
def _prepare_feature(self, fc_feats, att_feats, att_masks):
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
memory = self.model.encode(att_feats, att_masks)
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
att_feats, att_masks = self.clip_att(att_feats, att_masks)
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
if att_masks is None:
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
att_masks = att_masks.unsqueeze(-2)
if seq is not None:
# crop the last one
# seq = seq[:,:-1]
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
seq_mask[:,0] = 1 # bos
seq_mask = seq_mask.unsqueeze(-2)
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
seq_per_img = seq.shape[0] // att_feats.shape[0]
if seq_per_img > 1:
att_feats, att_masks = utils.repeat_tensors(seq_per_img,
[att_feats, att_masks]
)
else:
seq_mask = None
return att_feats, seq, att_masks, seq_mask
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
if seq.ndim == 3: # B * seq_per_img * seq_len
seq = seq.reshape(-1, seq.shape[2])
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
out = self.model(att_feats, seq, att_masks, seq_mask)
outputs = self.model.generator(out)
return outputs
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
"""
state = [ys.unsqueeze(0)]
"""
if len(state) == 0:
ys = it.unsqueeze(1)
else:
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
out = self.model.decode(memory, mask,
ys,
subsequent_mask(ys.size(1))
.to(memory.device))
return out[:, -1], [ys.unsqueeze(0)]

153
utils/config.py

@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copy from fvcore
import logging
import os
from typing import Any
import yaml
from yacs.config import CfgNode as _CfgNode
import io as PathManager
BASE_KEY = "_BASE_"
class CfgNode(_CfgNode):
"""
Our own extended version of :class:`yacs.config.CfgNode`.
It contains the following extra features:
1. The :meth:`merge_from_file` method supports the "_BASE_" key,
which allows the new CfgNode to inherit all the attributes from the
base configuration file.
2. Keys that start with "COMPUTED_" are treated as insertion-only
"computed" attributes. They can be inserted regardless of whether
the CfgNode is frozen or not.
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
expressions in config. See examples in
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
Note that this may lead to arbitrary code execution: you must not
load a config file from untrusted sources before manually inspecting
the content of the file.
"""
@staticmethod
def load_yaml_with_base(filename, allow_unsafe = False):
"""
Just like `yaml.load(open(filename))`, but inherit attributes from its
`_BASE_`.
Args:
filename (str): the file name of the current config. Will be used to
find the base config file.
allow_unsafe (bool): whether to allow loading the config file with
`yaml.unsafe_load`.
Returns:
(dict): the loaded yaml
"""
with PathManager.open(filename, "r") as f:
try:
cfg = yaml.safe_load(f)
except yaml.constructor.ConstructorError:
if not allow_unsafe:
raise
logger = logging.getLogger(__name__)
logger.warning(
"Loading config {} with yaml.unsafe_load. Your machine may "
"be at risk if the file contains malicious content.".format(
filename
)
)
f.close()
with open(filename, "r") as f:
cfg = yaml.unsafe_load(f)
def merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b.
for k, v in a.items():
if isinstance(v, dict) and k in b:
assert isinstance(
b[k], dict
), "Cannot inherit key '{}' from base!".format(k)
merge_a_into_b(v, b[k])
else:
b[k] = v
if BASE_KEY in cfg:
base_cfg_file = cfg[BASE_KEY]
if base_cfg_file.startswith("~"):
base_cfg_file = os.path.expanduser(base_cfg_file)
if not any(
map(base_cfg_file.startswith, ["/", "https://", "http://"])
):
# the path to base cfg is relative to the config file itself.
base_cfg_file = os.path.join(
os.path.dirname(filename), base_cfg_file
)
base_cfg = CfgNode.load_yaml_with_base(
base_cfg_file, allow_unsafe=allow_unsafe
)
del cfg[BASE_KEY]
merge_a_into_b(cfg, base_cfg)
return base_cfg
return cfg
def merge_from_file(self, cfg_filename, allow_unsafe = False):
"""
Merge configs from a given yaml file.
Args:
cfg_filename: the file name of the yaml config.
allow_unsafe: whether to allow loading the config file with
`yaml.unsafe_load`.
"""
loaded_cfg = CfgNode.load_yaml_with_base(
cfg_filename, allow_unsafe=allow_unsafe
)
loaded_cfg = type(self)(loaded_cfg)
self.merge_from_other_cfg(loaded_cfg)
# Forward the following calls to base, but with a check on the BASE_KEY.
def merge_from_other_cfg(self, cfg_other):
"""
Args:
cfg_other (CfgNode): configs to merge from.
"""
assert (
BASE_KEY not in cfg_other
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
return super().merge_from_other_cfg(cfg_other)
def merge_from_list(self, cfg_list):
"""
Args:
cfg_list (list): list of configs to merge from.
"""
keys = set(cfg_list[0::2])
assert (
BASE_KEY not in keys
), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
return super().merge_from_list(cfg_list)
def __setattr__(self, name, val):
if name.startswith("COMPUTED_"):
if name in self:
old_val = self[name]
if old_val == val:
return
raise KeyError(
"Computed attributed '{}' already exists "
"with a different value! old={}, new={}.".format(
name, old_val, val
)
)
self[name] = val
else:
super().__setattr__(name, val)
if __name__ == '__main__':
cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
print(cfg)

412
utils/opts.py

@ -0,0 +1,412 @@
from __future__ import print_function
import argparse
def if_use_feat(caption_model):
# Decide if load attention feature according to caption model
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
use_att, use_fc = False, True
elif caption_model == 'language_model':
use_att, use_fc = False, False
elif caption_model in ['updown', 'topdown']:
use_fc, use_att = True, True
else:
use_att, use_fc = True, False
return use_fc, use_att
import pprint
class Config(object):
def __init__(self, **kwargs):
"""Configuration Class: set kwargs as class attributes with setattr"""
for k, v in kwargs.items():
setattr(self, k, v)
@property
def config_str(self):
return pprint.pformat(self.__dict__)
def __repr__(self):
"""Pretty-print configurations in alphabetical order"""
config_str = 'Configurations\n'
config_str += self.config_str
return config_str
def parse_opt(parse=True, **optional_kwargs):
parser = argparse.ArgumentParser()
# Data input settings
parser.add_argument('--input_json', type=str, default='data/coco.json',
help='path to the json file containing additional info and vocab')
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
help='path to the directory containing the preprocessed fc feats')
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
help='path to the directory containing the preprocessed att feats')
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
help='path to the directory containing the boxes of att feats')
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--data_in_memory', action='store_true',
help='True if we want to save the features in memory')
parser.add_argument('--start_from', type=str, default=None,
help="""continue training from saved model at this path. Path must contain files saved by previous training process:
'infos.pkl' : configuration;
'model.pth' : weights
""")
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
help='Cached token file for calculating cider score during self critical training.')
# Model settings
parser.add_argument('--caption_model', type=str, default="show_tell",
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
parser.add_argument('--rnn_size', type=int, default=512,
help='size of the rnn in number of hidden nodes in each layer')
parser.add_argument('--num_layers', type=int, default=1,
help='number of layers in the RNN')
parser.add_argument('--rnn_type', type=str, default='lstm',
help='rnn, gru, or lstm')
parser.add_argument('--input_encoding_size', type=int, default=512,
help='the encoding size of each token in the vocabulary, and the image.')
parser.add_argument('--att_hid_size', type=int, default=512,
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
parser.add_argument('--fc_feat_size', type=int, default=2048,
help='2048 for resnet, 4096 for vgg')
parser.add_argument('--att_feat_size', type=int, default=2048,
help='2048 for resnet, 512 for vgg')
parser.add_argument('--logit_layers', type=int, default=1,
help='number of layers in the RNN')
parser.add_argument('--use_bn', type=int, default=0,
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
# feature manipulation
parser.add_argument('--norm_att_feat', type=int, default=0,
help='If normalize attention features')
parser.add_argument('--use_box', type=int, default=0,
help='If use box features')
parser.add_argument('--norm_box_feat', type=int, default=0,
help='If use box, do we normalize box feature')
# Optimization: General
parser.add_argument('--max_epochs', type=int, default=-1,
help='number of epochs')
parser.add_argument('--batch_size', type=int, default=16,
help='minibatch size')
parser.add_argument('--grad_clip_mode', type=str, default='value',
help='value or norm')
parser.add_argument('--grad_clip_value', type=float, default=0.1,
help='clip gradients at this value/max_norm, 0 means no clipping')
parser.add_argument('--drop_prob_lm', type=float, default=0.5,
help='strength of dropout in the Language Model RNN')
parser.add_argument('--self_critical_after', type=int, default=-1,
help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
parser.add_argument('--seq_per_img', type=int, default=5,
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
parser.add_argument('--verbose', type=int, default=0)
# Sample related
add_eval_sample_opts(parser)
#Optimization: for the Language Model
parser.add_argument('--optim', type=str, default='adam',
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
parser.add_argument('--learning_rate', type=float, default=4e-4,
help='learning rate')
parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
parser.add_argument('--learning_rate_decay_every', type=int, default=3,
help='every how many iterations thereafter to drop LR?(in epoch)')
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
help='every how many iterations thereafter to drop LR?(in epoch)')
parser.add_argument('--optim_alpha', type=float, default=0.9,
help='alpha for adam')
parser.add_argument('--optim_beta', type=float, default=0.999,
help='beta used for adam')
parser.add_argument('--optim_epsilon', type=float, default=1e-8,
help='epsilon that goes into denominator for smoothing')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight_decay')
# Transformer
parser.add_argument('--label_smoothing', type=float, default=0,
help='')
parser.add_argument('--noamopt', action='store_true',
help='')
parser.add_argument('--noamopt_warmup', type=int, default=2000,
help='')
parser.add_argument('--noamopt_factor', type=float, default=1,
help='')
parser.add_argument('--reduce_on_plateau', action='store_true',
help='')
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
help='')
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
help='')
parser.add_argument('--cached_transformer', action='store_true',
help='')
parser.add_argument('--use_warmup', action='store_true',
help='warm up the learing rate?')
parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
help='at what iteration to start decay gt probability')
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
help='every how many iterations thereafter to gt probability')
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
help='How much to update the prob')
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
help='Maximum scheduled sampling prob.')
# Evaluation/Checkpointing
parser.add_argument('--val_images_use', type=int, default=3200,
help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
help='how often to save a model checkpoint (in iterations)?')
parser.add_argument('--save_every_epoch', action='store_true',
help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
parser.add_argument('--save_history_ckpt', type=int, default=0,
help='If save checkpoints at every save point')
parser.add_argument('--checkpoint_path', type=str, default=None,
help='directory to store checkpointed models')
parser.add_argument('--language_eval', type=int, default=0,
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
parser.add_argument('--losses_log_every', type=int, default=25,
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
parser.add_argument('--load_best_score', type=int, default=1,
help='Do we load previous best score when resuming training.')
# misc
parser.add_argument('--id', type=str, default='',
help='an id identifying this run/job. used in cross-val and appended when writing progress files')
parser.add_argument('--train_only', type=int, default=0,
help='if true then use 80k, else use 110k')
# Reward
parser.add_argument('--cider_reward_weight', type=float, default=1,
help='The reward weight from cider')
parser.add_argument('--bleu_reward_weight', type=float, default=0,
help='The reward weight from bleu4')
# Reward
parser.add_argument('--clipscore_reward_weight', type=float, default=1,
help='The reward weight from clipscore')
parser.add_argument('--use_clipscore', type=float, default=0,
help='Use CLIPScore')
parser.add_argument('--clipscore_mode', type=str, default='clip_s',
help='Which CLIPScore to use: clip_s|refclip_s')
# Structure_loss
parser.add_argument('--structure_loss_weight', type=float, default=1,
help='')
parser.add_argument('--structure_after', type=int, default=-1,
help='T')
parser.add_argument('--structure_loss_type', type=str, default='seqnll',
help='')
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
parser.add_argument('--entropy_reward_weight', type=float, default=0,
help='Entropy reward, seems very interesting')
parser.add_argument('--self_cider_reward_weight', type=float, default=0,
help='self cider reward')
# Used for self critical or structure. Used when sampling is need during training
parser.add_argument('--train_sample_n', type=int, default=16,
help='The reward weight from cider')
parser.add_argument('--train_sample_method', type=str, default='sample',
help='')
parser.add_argument('--train_beam_size', type=int, default=1,
help='')
# Used for self critical
parser.add_argument('--sc_sample_method', type=str, default='greedy',
help='')
parser.add_argument('--sc_beam_size', type=int, default=1,
help='')
# For diversity evaluation during training
add_diversity_opts(parser)
# config
parser.add_argument('--cfg', type=str, default=None,
help='configuration; similar to what is used in detectron')
parser.add_argument(
'--set_cfgs', dest='set_cfgs',
help='Set config keys. Key value sequence seperate by whitespace.'
'e.g. [key] [value] [key] [value]\n This has higher priority'
'than cfg file but lower than other args. (You can only overwrite'
'arguments that have alerady been defined in config file.)',
default=[], nargs='+')
# How will config be used
# 1) read cfg argument, and load the cfg file if it's not None
# 2) Overwrite cfg argument with set_cfgs
# 3) parse config argument to args.
# 4) in the end, parse command line argument and overwrite args
# step 1: read cfg_fn
# args = parser.parse_args()
# Parse the arguments.
if parse:
args = parser.parse_args()
# For interative engironmnet (ex. jupyter)
else:
args = parser.parse_known_args()[0]
# print(args)
# Namespace => Dictionary
kwargs = vars(args)
# for k, v in optional_kwargs.items():
# setattr(args, k, v)
kwargs.update(optional_kwargs)
args = Config(**kwargs)
if args.cfg is not None or args.set_cfgs is not None:
from .config import CfgNode
if args.cfg is not None:
# print('Read Cfg')
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
# print(cn)
else:
cn = CfgNode()
if args.set_cfgs is not None:
cn.merge_from_list(args.set_cfgs)
for k,v in cn.items():
if not hasattr(args, k):
import os
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
pass
else:
print('Warning: key %s not in args' % k)
setattr(args, k, v)
if parse:
args = parser.parse_args(namespace=args)
else:
args = parser.parse_known_args(namespace=args)[0]
# Check if args are valid
assert args.rnn_size > 0, "rnn_size should be greater than 0"
assert args.num_layers > 0, "num_layers should be greater than 0"
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
assert args.batch_size > 0, "batch_size should be greater than 0"
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
assert args.beam_size > 0, "beam_size should be greater than 0"
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"
# default value for start_from and checkpoint_path
args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
args.start_from = args.start_from or args.checkpoint_path
# Deal with feature things before anything
args.use_fc, args.use_att = if_use_feat(args.caption_model)
if args.use_box: args.att_feat_size = args.att_feat_size + 5
return args
def add_eval_options(parser):
# Basic options
parser.add_argument('--batch_size', type=int, default=0,
help='if > 0 then overrule, otherwise load from checkpoint.')
parser.add_argument('--num_images', type=int, default=-1,
help='how many images to use when periodically evaluating the loss? (-1 = all)')
parser.add_argument('--language_eval', type=int, default=0,
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
parser.add_argument('--dump_images', type=int, default=1,
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
parser.add_argument('--dump_json', type=int, default=1,
help='Dump json with predictions into vis folder? (1=yes,0=no)')
parser.add_argument('--dump_path', type=int, default=0,
help='Write image paths along with predictions into vis json? (1=yes,0=no)')
# Sampling options
add_eval_sample_opts(parser)
# For evaluation on a folder of images:
parser.add_argument('--image_folder', type=str, default='',
help='If this is nonempty then will predict on the images in this folder path')
parser.add_argument('--image_root', type=str, default='',
help='In case the image paths have to be preprended with a root path to an image folder')
# For evaluation on MSCOCO images from some split:
parser.add_argument('--input_fc_dir', type=str, default='',
help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_att_dir', type=str, default='',
help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_box_dir', type=str, default='',
help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_label_h5', type=str, default='',
help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_json', type=str, default='',
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
parser.add_argument('--split', type=str, default='test',
help='if running on MSCOCO images, which split to use: val|test|train')
parser.add_argument('--coco_json', type=str, default='',
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
# misc
parser.add_argument('--id', type=str, default='',
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
parser.add_argument('--verbose_beam', type=int, default=1,
help='if we need to print out all beam search beams.')
parser.add_argument('--verbose_loss', type=int, default=0,
help='If calculate loss using ground truth during evaluation')
def add_diversity_opts(parser):
parser.add_argument('--sample_n', type=int, default=1,
help='Diverse sampling')
parser.add_argument('--sample_n_method', type=str, default='sample',
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
parser.add_argument('--eval_oracle', type=int, default=1,
help='if we need to calculate loss.')
# Sampling related options
def add_eval_sample_opts(parser):
parser.add_argument('--sample_method', type=str, default='greedy',
help='greedy; sample; gumbel; top<int>, top<0-1>')
parser.add_argument('--beam_size', type=int, default=1,
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
parser.add_argument('--max_length', type=int, default=20,
help='Maximum length during sampling')
parser.add_argument('--length_penalty', type=str, default='',
help='wu_X or avg_X, X is the alpha')
parser.add_argument('--group_size', type=int, default=1,
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
parser.add_argument('--diversity_lambda', type=float, default=0.5,
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
parser.add_argument('--temperature', type=float, default=1.0,
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
parser.add_argument('--decoding_constraint', type=int, default=0,
help='If 1, not allowing same word in a row')
parser.add_argument('--block_trigrams', type=int, default=0,
help='block repeated trigram.')
parser.add_argument('--remove_bad_endings', type=int, default=0,
help='Remove bad endings')
parser.add_argument('--suppress_UNK', type=int, default=1,
help='Not predicting UNK')
if __name__ == '__main__':
import sys
sys.argv = [sys.argv[0]]
args = parse_opt()
print(args)
print()
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
args1 = parse_opt()
print(dict(set(vars(args1).items()) - set(vars(args).items())))
print()
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
args2 = parse_opt()
print(dict(set(vars(args2).items()) - set(vars(args1).items())))
Loading…
Cancel
Save