clip-caption-reward
copied
wxywb
2 years ago
30 changed files with 3491 additions and 0 deletions
@@ -0,0 +1,18 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from clip_caption_reward import ClipCaptionReward |
|||
|
|||
def clip_caption_reward(model_name: str): |
|||
return ClipCaptionReward(model_name) |
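# Usage sketch for the factory above (illustrative; the checkpoint name and the call
# signature of the returned operator are assumptions, not taken from this diff):
#
#   op = clip_caption_reward('clipRN50_clips_grammar')   # hypothetical model name
#   caption = op(some_pil_image)                          # assumed: captions one image
#
# The factory only forwards model_name to ClipCaptionReward, which is expected to
# resolve it to a CLIP backbone plus captioning weights.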
Binary file not shown.
@@ -0,0 +1,464 @@ |
|||
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model |
|||
|
|||
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning |
|||
# https://arxiv.org/abs/1612.01887 |
|||
# AdaAttMO is a modified version with maxout lstm |
|||
|
|||
# Att2in is from Self-critical Sequence Training for Image Captioning |
|||
# https://arxiv.org/abs/1612.00563 |
|||
# In this file we only have Att2in2, which is a slightly different version of att2in, |
|||
# in which the img feature embedding and word embedding are the same as in AdaAtt. |
|||
|
|||
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA |
|||
# https://arxiv.org/abs/1707.07998 |
|||
# However, it may not be identical to the author's architecture. |
|||
|
|||
from __future__ import absolute_import |
|||
from __future__ import division |
|||
from __future__ import print_function |
|||
|
|||
import numpy as np |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
#from . import utils |
|||
#utils.repeat_tensors |
|||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence |
|||
|
|||
from .CaptionModel import CaptionModel |
|||
|
|||
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] |
|||
bad_endings += ['the'] |
|||
|
|||
|
|||
def repeat_tensors(n, x): |
|||
""" |
|||
For a tensor of size Bx..., we repeat it n times, and make it Bnx... |
|||
For collections, do nested repeat |
|||
""" |
|||
if torch.is_tensor(x): |
|||
x = x.unsqueeze(1) # Bx1x... |
|||
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... |
|||
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... |
|||
elif type(x) is list or type(x) is tuple: |
|||
x = [repeat_tensors(n, _) for _ in x] |
|||
return x |
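# Shape sketch for repeat_tensors (illustrative): each of the B rows is repeated n
# times consecutively, so [a, b] with n=3 becomes [a, a, a, b, b, b].
#
#   feats = torch.arange(6.).reshape(2, 3)            # B=2, feature dim 3
#   repeat_tensors(3, feats).shape                     # torch.Size([6, 3])
#   repeat_tensors(3, [feats, feats.unsqueeze(-1)])    # lists/tuples are repeated element-wise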
|||
|
|||
def sort_pack_padded_sequence(input, lengths): |
|||
sorted_lengths, indices = torch.sort(lengths, descending=True) |
|||
# tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) |
|||
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) |
|||
inv_ix = indices.clone() |
|||
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) |
|||
return tmp, inv_ix |
|||
|
|||
def pad_unsort_packed_sequence(input, inv_ix): |
|||
tmp, _ = pad_packed_sequence(input, batch_first=True) |
|||
tmp = tmp[inv_ix] |
|||
return tmp |
|||
|
|||
def pack_wrapper(module, att_feats, att_masks): |
|||
if att_masks is not None: |
|||
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) |
|||
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) |
|||
else: |
|||
return module(att_feats) |
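# Usage sketch for pack_wrapper (illustrative): apply a per-position module to a padded
# batch of region features, packing so that only the valid positions are processed.
#
#   att_feats = torch.randn(2, 5, 8)                      # B=2, up to 5 regions, dim 8
#   att_masks = torch.tensor([[1, 1, 1, 0, 0],
#                             [1, 1, 1, 1, 1]]).float()   # 3 and 5 valid regions
#   out = pack_wrapper(nn.Linear(8, 4), att_feats, att_masks)
#   out.shape                                              # torch.Size([2, 5, 4]); padded slots come back as zeros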
|||
|
|||
class AttModel(CaptionModel): |
|||
def __init__(self, opt): |
|||
super(AttModel, self).__init__() |
|||
self.vocab_size = opt.vocab_size |
|||
self.input_encoding_size = opt.input_encoding_size |
|||
#self.rnn_type = opt.rnn_type |
|||
self.rnn_size = opt.rnn_size |
|||
self.num_layers = opt.num_layers |
|||
self.drop_prob_lm = opt.drop_prob_lm |
|||
self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length |
|||
self.fc_feat_size = opt.fc_feat_size |
|||
self.att_feat_size = opt.att_feat_size |
|||
self.att_hid_size = opt.att_hid_size |
|||
|
|||
self.bos_idx = getattr(opt, 'bos_idx', 0) |
|||
self.eos_idx = getattr(opt, 'eos_idx', 0) |
|||
self.pad_idx = getattr(opt, 'pad_idx', 0) |
|||
|
|||
self.use_bn = getattr(opt, 'use_bn', 0) |
|||
|
|||
self.ss_prob = 0.0 # Schedule sampling probability |
|||
|
|||
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), |
|||
nn.ReLU(), |
|||
nn.Dropout(self.drop_prob_lm)) |
|||
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), |
|||
nn.ReLU(), |
|||
nn.Dropout(self.drop_prob_lm)) |
|||
self.att_embed = nn.Sequential(*( |
|||
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ |
|||
(nn.Linear(self.att_feat_size, self.rnn_size), |
|||
nn.ReLU(), |
|||
nn.Dropout(self.drop_prob_lm))+ |
|||
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) |
|||
|
|||
self.logit_layers = getattr(opt, 'logit_layers', 1) |
|||
if self.logit_layers == 1: |
|||
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) |
|||
else: |
|||
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] |
|||
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) |
|||
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) |
|||
|
|||
# For removing bad endings |
|||
self.vocab = opt.vocab |
|||
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings] |
|||
|
|||
def init_hidden(self, bsz): |
|||
weight = self.logit.weight \ |
|||
if hasattr(self.logit, "weight") \ |
|||
else self.logit[0].weight |
|||
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), |
|||
weight.new_zeros(self.num_layers, bsz, self.rnn_size)) |
|||
|
|||
def clip_att(self, att_feats, att_masks): |
|||
# Clip the length of att_masks and att_feats to the maximum length |
|||
if att_masks is not None: |
|||
max_len = att_masks.data.long().sum(1).max() |
|||
att_feats = att_feats[:, :max_len].contiguous() |
|||
att_masks = att_masks[:, :max_len].contiguous() |
|||
return att_feats, att_masks |
|||
|
|||
def _prepare_feature(self, fc_feats, att_feats, att_masks): |
|||
att_feats, att_masks = self.clip_att(att_feats, att_masks) |
|||
|
|||
# embed fc and att feats |
|||
fc_feats = self.fc_embed(fc_feats) |
|||
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) |
|||
|
|||
# Project the attention feats first to reduce memory and computation consumption. |
|||
p_att_feats = self.ctx2att(att_feats) |
|||
|
|||
return fc_feats, att_feats, p_att_feats, att_masks |
|||
|
|||
def _forward(self, fc_feats, att_feats, seq, att_masks=None): |
|||
batch_size = fc_feats.size(0) |
|||
if seq.ndim == 3: # B * seq_per_img * seq_len |
|||
seq = seq.reshape(-1, seq.shape[2]) |
|||
seq_per_img = seq.shape[0] // batch_size |
|||
state = self.init_hidden(batch_size*seq_per_img) |
|||
|
|||
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1) |
|||
|
|||
# Prepare the features |
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
|||
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost |
|||
|
|||
if seq_per_img > 1: |
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(seq_per_img, |
|||
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
|||
) |
|||
|
|||
for i in range(seq.size(1)): |
|||
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample |
|||
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1) |
|||
sample_mask = sample_prob < self.ss_prob |
|||
if sample_mask.sum() == 0: |
|||
it = seq[:, i].clone() |
|||
else: |
|||
sample_ind = sample_mask.nonzero().view(-1) |
|||
it = seq[:, i].data.clone() |
|||
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) |
|||
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) |
|||
else: |
|||
it = seq[:, i].clone() |
|||
# break if all the sequences end |
|||
if i >= 1 and seq[:, i].sum() == 0: |
|||
break |
|||
|
|||
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) |
|||
outputs[:, i] = output |
|||
|
|||
return outputs |
|||
|
|||
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): |
|||
# 'it' contains a word index |
|||
xt = self.embed(it) |
|||
|
|||
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) |
|||
if output_logsoftmax: |
|||
logprobs = F.log_softmax(self.logit(output), dim=1) |
|||
else: |
|||
logprobs = self.logit(output) |
|||
|
|||
return logprobs, state |
|||
|
|||
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): |
|||
beam_size = opt.get('beam_size', 10) |
|||
group_size = opt.get('group_size', 1) |
|||
sample_n = opt.get('sample_n', 10) |
|||
# when sample_n == beam_size then each beam is a sample. |
|||
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam searching, sample_n must be 1 or beam_size // group_size' |
|||
batch_size = fc_feats.size(0) |
|||
|
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
|||
|
|||
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' |
|||
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
|||
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
|||
# lets process every image independently for now, for simplicity |
|||
|
|||
self.done_beams = [[] for _ in range(batch_size)] |
|||
for k in range(batch_size): |
|||
state = self.init_hidden(beam_size) |
|||
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = repeat_tensors(beam_size, |
|||
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None] |
|||
) |
|||
|
|||
for t in range(1): |
|||
if t == 0: # input <bos> |
|||
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long) |
|||
|
|||
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) |
|||
|
|||
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) |
|||
if sample_n == beam_size: |
|||
for _n in range(sample_n): |
|||
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq'] |
|||
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps'] |
|||
else: |
|||
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score |
|||
seqLogprobs[k, :] = self.done_beams[k][0]['logps'] |
|||
# return the samples and their log likelihoods |
|||
return seq, seqLogprobs |
|||
|
|||
|
|||
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): |
|||
beam_size = opt.get('beam_size', 10) |
|||
group_size = opt.get('group_size', 1) |
|||
sample_n = opt.get('sample_n', 10) |
|||
# when sample_n == beam_size then each beam is a sample. |
|||
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam searching, sample_n must be 1 or beam_size // group_size' |
|||
batch_size = fc_feats.size(0) |
|||
|
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
|||
|
|||
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' |
|||
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
|||
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
|||
# lets process every image independently for now, for simplicity |
|||
|
|||
self.done_beams = [[] for _ in range(batch_size)] |
|||
|
|||
state = self.init_hidden(batch_size) |
|||
|
|||
# first step, feed bos |
|||
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) |
|||
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) |
|||
|
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(beam_size, |
|||
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
|||
) |
|||
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) |
|||
for k in range(batch_size): |
|||
if sample_n == beam_size: |
|||
for _n in range(sample_n): |
|||
seq_len = self.done_beams[k][_n]['seq'].shape[0] |
|||
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq'] |
|||
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps'] |
|||
else: |
|||
seq_len = self.done_beams[k][0]['seq'].shape[0] |
|||
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score |
|||
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] |
|||
# return the samples and their log likelihoods |
|||
return seq, seqLogprobs |
|||
|
|||
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): |
|||
|
|||
sample_method = opt.get('sample_method', 'greedy') |
|||
beam_size = opt.get('beam_size', 1) |
|||
temperature = opt.get('temperature', 1.0) |
|||
sample_n = int(opt.get('sample_n', 1)) |
|||
group_size = opt.get('group_size', 1) |
|||
output_logsoftmax = opt.get('output_logsoftmax', 1) |
|||
decoding_constraint = opt.get('decoding_constraint', 0) |
|||
block_trigrams = opt.get('block_trigrams', 0) |
|||
remove_bad_endings = opt.get('remove_bad_endings', 0) |
|||
if beam_size > 1 and sample_method in ['greedy', 'beam_search']: |
|||
return self._sample_beam(fc_feats, att_feats, att_masks, opt) |
|||
if group_size > 1: |
|||
return self._diverse_sample(fc_feats, att_feats, att_masks, opt) |
|||
|
|||
batch_size = fc_feats.size(0) |
|||
state = self.init_hidden(batch_size*sample_n) |
|||
|
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
|||
|
|||
if sample_n > 1: |
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(sample_n, |
|||
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
|||
) |
|||
|
|||
trigrams = [] # will be a list of batch_size dictionaries |
|||
|
|||
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
|||
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
|||
for t in range(self.seq_length + 1): |
|||
if t == 0: # input <bos> |
|||
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long) |
|||
|
|||
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax) |
|||
|
|||
if decoding_constraint and t > 0: |
|||
tmp = logprobs.new_zeros(logprobs.size()) |
|||
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) |
|||
logprobs = logprobs + tmp |
|||
|
|||
if remove_bad_endings and t > 0: |
|||
tmp = logprobs.new_zeros(logprobs.size()) |
|||
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) |
|||
# Make it impossible to generate bad_endings |
|||
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') |
|||
logprobs = logprobs + tmp |
|||
|
|||
# Mess with trigrams |
|||
# Copy from https://github.com/lukemelas/image-paragraph-captioning |
|||
if block_trigrams and t >= 3: |
|||
# Store trigram generated at last step |
|||
prev_two_batch = seq[:,t-3:t-1] |
|||
for i in range(batch_size): # = seq.size(0) |
|||
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
|||
current = seq[i][t-1] |
|||
if t == 3: # initialize |
|||
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} |
|||
elif t > 3: |
|||
if prev_two in trigrams[i]: # add to list |
|||
trigrams[i][prev_two].append(current) |
|||
else: # create list |
|||
trigrams[i][prev_two] = [current] |
|||
# Block used trigrams at next step |
|||
prev_two_batch = seq[:,t-2:t] |
|||
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size |
|||
for i in range(batch_size): |
|||
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
|||
if prev_two in trigrams[i]: |
|||
for j in trigrams[i][prev_two]: |
|||
mask[i,j] += 1 |
|||
# Apply mask to log probs |
|||
#logprobs = logprobs - (mask * 1e9) |
|||
alpha = 2.0 # = 4 |
|||
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) |
|||
|
|||
# sample the next word |
|||
if t == self.seq_length: # skip if we achieve maximum length |
|||
break |
|||
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) |
|||
|
|||
# stop when all finished |
|||
if t == 0: |
|||
unfinished = it != self.eos_idx |
|||
else: |
|||
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 |
|||
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs) |
|||
unfinished = unfinished & (it != self.eos_idx) |
|||
seq[:,t] = it |
|||
seqLogprobs[:,t] = logprobs |
|||
# quit loop if all sequences have finished |
|||
if unfinished.sum() == 0: |
|||
break |
|||
|
|||
return seq, seqLogprobs |
|||
|
|||
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): |
|||
|
|||
sample_method = opt.get('sample_method', 'greedy') |
|||
beam_size = opt.get('beam_size', 1) |
|||
temperature = opt.get('temperature', 1.0) |
|||
group_size = opt.get('group_size', 1) |
|||
diversity_lambda = opt.get('diversity_lambda', 0.5) |
|||
decoding_constraint = opt.get('decoding_constraint', 0) |
|||
block_trigrams = opt.get('block_trigrams', 0) |
|||
remove_bad_endings = opt.get('remove_bad_endings', 0) |
|||
|
|||
batch_size = fc_feats.size(0) |
|||
state = self.init_hidden(batch_size) |
|||
|
|||
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
|||
|
|||
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries |
|||
|
|||
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)] |
|||
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)] |
|||
state_table = [self.init_hidden(batch_size) for _ in range(group_size)] |
|||
|
|||
for tt in range(self.seq_length + group_size): |
|||
for divm in range(group_size): |
|||
t = tt - divm |
|||
seq = seq_table[divm] |
|||
seqLogprobs = seqLogprobs_table[divm] |
|||
trigrams = trigrams_table[divm] |
|||
if t >= 0 and t <= self.seq_length-1: |
|||
if t == 0: # input <bos> |
|||
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) |
|||
else: |
|||
it = seq[:, t-1] # changed |
|||
|
|||
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed |
|||
logprobs = F.log_softmax(logprobs / temperature, dim=-1) |
|||
|
|||
# Add diversity |
|||
if divm > 0: |
|||
unaug_logprobs = logprobs.clone() |
|||
for prev_choice in range(divm): |
|||
prev_decisions = seq_table[prev_choice][:, t] |
|||
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda |
|||
|
|||
if decoding_constraint and t > 0: |
|||
tmp = logprobs.new_zeros(logprobs.size()) |
|||
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) |
|||
logprobs = logprobs + tmp |
|||
|
|||
if remove_bad_endings and t > 0: |
|||
tmp = logprobs.new_zeros(logprobs.size()) |
|||
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) |
|||
# Make it impossible to generate bad endings |
|||
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') |
|||
logprobs = logprobs + tmp |
|||
|
|||
# Mess with trigrams |
|||
if block_trigrams and t >= 3: |
|||
# Store trigram generated at last step |
|||
prev_two_batch = seq[:,t-3:t-1] |
|||
for i in range(batch_size): # = seq.size(0) |
|||
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
|||
current = seq[i][t-1] |
|||
if t == 3: # initialize |
|||
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} |
|||
elif t > 3: |
|||
if prev_two in trigrams[i]: # add to list |
|||
trigrams[i][prev_two].append(current) |
|||
else: # create list |
|||
trigrams[i][prev_two] = [current] |
|||
# Block used trigrams at next step |
|||
prev_two_batch = seq[:,t-2:t] |
|||
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size |
|||
for i in range(batch_size): |
|||
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
|||
if prev_two in trigrams[i]: |
|||
for j in trigrams[i][prev_two]: |
|||
mask[i,j] += 1 |
|||
# Apply mask to log probs |
|||
#logprobs = logprobs - (mask * 1e9) |
|||
alpha = 2.0 # = 4 |
|||
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) |
|||
|
|||
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) |
|||
|
|||
# stop when all finished |
|||
if t == 0: |
|||
unfinished = it != self.eos_idx |
|||
else: |
|||
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx) |
|||
it[~unfinished] = self.pad_idx |
|||
unfinished = unfinished & (it != self.eos_idx) # changed |
|||
seq[:,t] = it |
|||
seqLogprobs[:,t] = sampleLogprobs.view(-1) |
|||
|
|||
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1) |
@@ -0,0 +1,412 @@ |
|||
# This file contains ShowAttendTell and AllImg model |
|||
|
|||
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention |
|||
# https://arxiv.org/abs/1502.03044 |
|||
|
|||
# AllImg is a model where |
|||
# img feature is concatenated with word embedding at every time step as the input of lstm |
|||
from __future__ import absolute_import |
|||
from __future__ import division |
|||
from __future__ import print_function |
|||
|
|||
import numpy as np |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
from torch.autograd import * |
|||
#import .model_utils as mutils |
|||
from . import model_utils |
|||
|
|||
#from ..utils import misc as utils |
|||
#from ... import utils as model_utils |
|||
|
|||
#model_utils split_tensors |
|||
#utils penalty_builder decode_sequence |
|||
|
|||
class CaptionModel(nn.Module): |
|||
def __init__(self): |
|||
super(CaptionModel, self).__init__() |
|||
|
|||
# implements beam search |
|||
# calls beam_step and returns the final set of beams |
|||
# augments log-probabilities with diversity terms when number of groups > 1 |
|||
|
|||
def forward(self, *args, **kwargs): |
|||
mode = kwargs.get('mode', 'forward') |
|||
if 'mode' in kwargs: |
|||
del kwargs['mode'] |
|||
return getattr(self, '_'+mode)(*args, **kwargs) |
|||
|
|||
def beam_search(self, init_state, init_logprobs, *args, **kwargs): |
|||
|
|||
# function computes the similarity score to be augmented |
|||
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash): |
|||
local_time = t - divm |
|||
unaug_logprobs = logprobs.clone() |
|||
batch_size = beam_seq_table[0].shape[0] |
|||
|
|||
if divm > 0: |
|||
change = logprobs.new_zeros(batch_size, logprobs.shape[-1]) |
|||
for prev_choice in range(divm): |
|||
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb |
|||
for prev_labels in range(bdash): |
|||
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1)) |
|||
|
|||
if local_time == 0: |
|||
logprobs = logprobs - change * diversity_lambda |
|||
else: |
|||
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda |
|||
|
|||
return logprobs, unaug_logprobs |
|||
|
|||
|
|||
# does one step of classical beam search |
|||
|
|||
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): |
|||
#INPUTS: |
|||
#logprobs: probabilities augmented after diversity N*bxV |
|||
#beam_size: obvious |
|||
#t : time instant |
|||
#beam_seq : tensor containing the beams |
|||
#beam_seq_logprobs: tensor containing the beam logprobs |
|||
#beam_logprobs_sum: tensor containing joint logprobs |
|||
#OUTPUTS: |
|||
#beam_seq : tensor containing the word indices of the decoded captions Nxbxl |
|||
#beam_seq_logprobs : log-probability of each decision made, NxbxlxV |
|||
#beam_logprobs_sum : joint log-probability of each beam Nxb |
|||
|
|||
batch_size = beam_logprobs_sum.shape[0] |
|||
vocab_size = logprobs.shape[-1] |
|||
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV |
|||
if t == 0: |
|||
assert logprobs.shape[1] == 1 |
|||
beam_logprobs_sum = beam_logprobs_sum[:, :1] |
|||
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV |
|||
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) |
|||
ys, ix = ys[:,:beam_size], ix[:,:beam_size] |
|||
beam_ix = ix // vocab_size # Nxb which beam |
|||
selected_ix = ix % vocab_size # Nxb # which word |
|||
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams |
|||
|
|||
|
|||
if t > 0: |
|||
# gather according to beam_ix |
|||
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() |
|||
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) |
|||
|
|||
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs)) |
|||
|
|||
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl |
|||
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ |
|||
logprobs.reshape(batch_size, -1).gather(1, ix) |
|||
assert (beam_logprobs_sum == ys).all() |
|||
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) |
|||
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV |
|||
assert (_tmp_beam_logprobs == beam_logprobs).all() |
|||
beam_seq_logprobs = torch.cat([ |
|||
beam_seq_logprobs, |
|||
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) |
|||
|
|||
new_state = [None for _ in state] |
|||
for _ix in range(len(new_state)): |
|||
# copy over state in previous beam q to new beam at vix |
|||
new_state[_ix] = state[_ix][:, state_ix] |
|||
state = new_state |
|||
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state |
|||
|
|||
# Start diverse_beam_search |
|||
opt = kwargs['opt'] |
|||
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs |
|||
beam_size = opt.get('beam_size', 10) |
|||
group_size = opt.get('group_size', 1) |
|||
diversity_lambda = opt.get('diversity_lambda', 0.5) |
|||
decoding_constraint = opt.get('decoding_constraint', 0) |
|||
remove_bad_endings = opt.get('remove_bad_endings', 0) |
|||
suppress_UNK = opt.get('suppress_UNK', 0) |
|||
length_penalty = model_utils.penalty_builder(opt.get('length_penalty', '')) |
|||
bdash = beam_size // group_size # beam per group |
|||
|
|||
batch_size = init_logprobs.shape[0] |
|||
device = init_logprobs.device |
|||
# INITIALIZATIONS |
|||
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)] |
|||
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)] |
|||
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)] |
|||
|
|||
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) |
|||
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)] |
|||
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] |
|||
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state])) |
|||
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)] |
|||
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0)) |
|||
logprobs_table = [init_logprobs.clone() for _ in range(group_size)] |
|||
# END INIT |
|||
|
|||
# Chunk elements in the args |
|||
args = list(args) |
|||
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x... |
|||
if self.__class__.__name__ == 'AttEnsemble': |
|||
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name |
|||
else: |
|||
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] |
|||
|
|||
for t in range(self.seq_length + group_size - 1): |
|||
for divm in range(group_size): |
|||
if t >= divm and t <= self.seq_length + divm - 1: |
|||
# add diversity |
|||
logprobs = logprobs_table[divm] |
|||
# suppress previous word |
|||
if decoding_constraint and t-divm > 0: |
|||
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf')) |
|||
if remove_bad_endings and t-divm > 0: |
|||
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf') |
|||
# suppress UNK tokens in the decoding |
|||
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK': |
|||
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000 |
|||
# diversity is added here |
|||
# the function directly modifies the logprobs values and hence, we need to return |
|||
# the unaugmented ones for sorting the candidates in the end. # for historical |
|||
# reasons :-) |
|||
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash) |
|||
|
|||
# infer new beams |
|||
beam_seq_table[divm],\ |
|||
beam_seq_logprobs_table[divm],\ |
|||
beam_logprobs_sum_table[divm],\ |
|||
state_table[divm] = beam_step(logprobs, |
|||
unaug_logprobs, |
|||
bdash, |
|||
t-divm, |
|||
beam_seq_table[divm], |
|||
beam_seq_logprobs_table[divm], |
|||
beam_logprobs_sum_table[divm], |
|||
state_table[divm]) |
|||
|
|||
# if time's up... or if end token is reached then copy beams |
|||
for b in range(batch_size): |
|||
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx |
|||
assert beam_seq_table[divm].shape[-1] == t-divm+1 |
|||
if t == self.seq_length + divm - 1: |
|||
is_end.fill_(1) |
|||
for vix in range(bdash): |
|||
if is_end[vix]: |
|||
final_beam = { |
|||
'seq': beam_seq_table[divm][b, vix].clone(), |
|||
'logps': beam_seq_logprobs_table[divm][b, vix].clone(), |
|||
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(), |
|||
'p': beam_logprobs_sum_table[divm][b, vix].item() |
|||
} |
|||
final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) |
|||
done_beams_table[b][divm].append(final_beam) |
|||
beam_logprobs_sum_table[divm][b, is_end] -= 1000 |
|||
|
|||
# move the current group one step forward in time |
|||
|
|||
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device) |
|||
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) |
|||
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) |
|||
|
|||
# all beams are sorted by their log-probabilities |
|||
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)] |
|||
done_beams = [sum(_, []) for _ in done_beams_table] |
|||
return done_beams |
|||
|
|||
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): |
|||
|
|||
# function computes the similarity score to be augmented |
|||
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): |
|||
local_time = t - divm |
|||
unaug_logprobsf = logprobsf.clone() |
|||
for prev_choice in range(divm): |
|||
prev_decisions = beam_seq_table[prev_choice][local_time] |
|||
for sub_beam in range(bdash): |
|||
for prev_labels in range(bdash): |
|||
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda |
|||
return unaug_logprobsf |
|||
|
|||
# does one step of classical beam search |
|||
|
|||
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): |
|||
#INPUTS: |
|||
#logprobsf: probabilities augmented after diversity |
|||
#beam_size: obvious |
|||
#t : time instant |
|||
#beam_seq : tensor containing the beams |
|||
#beam_seq_logprobs: tensor containing the beam logprobs |
|||
#beam_logprobs_sum: tensor containing joint logprobs |
|||
#OUTPUTS: |
|||
#beam_seq : tensor containing the word indices of the decoded captions |
|||
#beam_seq_logprobs : log-probability of each decision made, same size as beam_seq |
|||
#beam_logprobs_sum : joint log-probability of each beam |
|||
|
|||
ys,ix = torch.sort(logprobsf,1,True) |
|||
candidates = [] |
|||
cols = min(beam_size, ys.size(1)) |
|||
rows = beam_size |
|||
if t == 0: |
|||
rows = 1 |
|||
for c in range(cols): # for each column (word, essentially) |
|||
for q in range(rows): # for each beam expansion |
|||
#compute logprob of expanding beam q with word in (sorted) position c |
|||
local_logprob = ys[q,c].item() |
|||
candidate_logprob = beam_logprobs_sum[q] + local_logprob |
|||
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] |
|||
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]}) |
|||
candidates = sorted(candidates, key=lambda x: -x['p']) |
|||
|
|||
new_state = [_.clone() for _ in state] |
|||
#beam_seq_prev, beam_seq_logprobs_prev |
|||
if t >= 1: |
|||
#we'll need these as reference when we fork beams around |
|||
beam_seq_prev = beam_seq[:t].clone() |
|||
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() |
|||
for vix in range(beam_size): |
|||
v = candidates[vix] |
|||
#fork beam index q into index vix |
|||
if t >= 1: |
|||
beam_seq[:t, vix] = beam_seq_prev[:, v['q']] |
|||
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] |
|||
#rearrange recurrent states |
|||
for state_ix in range(len(new_state)): |
|||
# copy over state in previous beam q to new beam at vix |
|||
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step |
|||
#append new end terminal at the end of this beam |
|||
beam_seq[t, vix] = v['c'] # c'th word is the continuation |
|||
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here |
|||
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam |
|||
state = new_state |
|||
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates |
|||
|
|||
# Start diverse_beam_search |
|||
opt = kwargs['opt'] |
|||
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs |
|||
beam_size = opt.get('beam_size', 10) |
|||
group_size = opt.get('group_size', 1) |
|||
diversity_lambda = opt.get('diversity_lambda', 0.5) |
|||
decoding_constraint = opt.get('decoding_constraint', 0) |
|||
remove_bad_endings = opt.get('remove_bad_endings', 0) |
|||
suppress_UNK = opt.get('suppress_UNK', 0) |
|||
length_penalty = model_utils.penalty_builder(opt.get('length_penalty', '')) |
|||
bdash = beam_size // group_size # beam per group |
|||
|
|||
# INITIALIZATIONS |
|||
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] |
|||
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)] |
|||
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] |
|||
|
|||
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) |
|||
done_beams_table = [[] for _ in range(group_size)] |
|||
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] |
|||
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) |
|||
logprobs_table = list(init_logprobs.chunk(group_size, 0)) |
|||
# END INIT |
|||
|
|||
# Chunk elements in the args |
|||
args = list(args) |
|||
if self.__class__.__name__ == 'AttEnsemble': |
|||
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name |
|||
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name |
|||
else: |
|||
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args] |
|||
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] |
|||
|
|||
for t in range(self.seq_length + group_size - 1): |
|||
for divm in range(group_size): |
|||
if t >= divm and t <= self.seq_length + divm - 1: |
|||
# add diversity |
|||
logprobsf = logprobs_table[divm] |
|||
# suppress previous word |
|||
if decoding_constraint and t-divm > 0: |
|||
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf')) |
|||
if remove_bad_endings and t-divm > 0: |
|||
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf') |
|||
# suppress UNK tokens in the decoding |
|||
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK': |
|||
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 |
|||
# diversity is added here |
|||
# the function directly modifies the logprobsf values and hence, we need to return |
|||
# the unaugmented ones for sorting the candidates in the end. # for historical |
|||
# reasons :-) |
|||
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) |
|||
|
|||
# infer new beams |
|||
beam_seq_table[divm],\ |
|||
beam_seq_logprobs_table[divm],\ |
|||
beam_logprobs_sum_table[divm],\ |
|||
state_table[divm],\ |
|||
candidates_divm = beam_step(logprobsf, |
|||
unaug_logprobsf, |
|||
bdash, |
|||
t-divm, |
|||
beam_seq_table[divm], |
|||
beam_seq_logprobs_table[divm], |
|||
beam_logprobs_sum_table[divm], |
|||
state_table[divm]) |
|||
|
|||
# if time's up... or if end token is reached then copy beams |
|||
for vix in range(bdash): |
|||
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1: |
|||
final_beam = { |
|||
'seq': beam_seq_table[divm][:, vix].clone(), |
|||
'logps': beam_seq_logprobs_table[divm][:, vix].clone(), |
|||
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), |
|||
'p': beam_logprobs_sum_table[divm][vix].item() |
|||
} |
|||
final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) |
|||
done_beams_table[divm].append(final_beam) |
|||
# don't continue beams from finished sequences |
|||
beam_logprobs_sum_table[divm][vix] = -1000 |
|||
|
|||
# move the current group one step forward in time |
|||
|
|||
it = beam_seq_table[divm][t-divm].to(logprobsf.device) |
|||
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) |
|||
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) |
|||
|
|||
# all beams are sorted by their log-probabilities |
|||
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] |
|||
done_beams = sum(done_beams_table, []) |
|||
return done_beams |
|||
|
|||
def sample_next_word(self, logprobs, sample_method, temperature): |
|||
if sample_method == 'greedy': |
|||
sampleLogprobs, it = torch.max(logprobs.data, 1) |
|||
it = it.view(-1).long() |
|||
elif sample_method == 'gumbel': # gumbel softmax |
|||
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f |
|||
def sample_gumbel(shape, eps=1e-20): |
|||
U = torch.rand(shape).to(logprobs.device) |
|||
return -torch.log(-torch.log(U + eps) + eps) |
|||
def gumbel_softmax_sample(logits, temperature): |
|||
y = logits + sample_gumbel(logits.size()) |
|||
return F.log_softmax(y / temperature, dim=-1) |
|||
_logprobs = gumbel_softmax_sample(logprobs, temperature) |
|||
_, it = torch.max(_logprobs.data, 1) |
|||
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions |
|||
else: |
|||
logprobs = logprobs / temperature |
|||
if sample_method.startswith('top'): # topk sampling |
|||
top_num = float(sample_method[3:]) |
|||
if 0 < top_num < 1: |
|||
# nucleus sampling from # The Curious Case of Neural Text Degeneration |
|||
probs = F.softmax(logprobs, dim=1) |
|||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) |
|||
_cumsum = sorted_probs.cumsum(1) |
|||
mask = _cumsum < top_num |
|||
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) |
|||
sorted_probs = sorted_probs * mask.to(sorted_probs) |
|||
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) |
|||
logprobs.scatter_(1, sorted_indices, sorted_probs.log()) |
|||
else: |
|||
the_k = int(top_num) |
|||
tmp = torch.empty_like(logprobs).fill_(float('-inf')) |
|||
topk, indices = torch.topk(logprobs, the_k, dim=1) |
|||
tmp = tmp.scatter(1, indices, topk) |
|||
logprobs = tmp |
|||
it = torch.distributions.Categorical(logits=logprobs.detach()).sample() |
|||
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions |
|||
return it, sampleLogprobs |
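# Standalone sketch of the nucleus (top-p) masking used above, on a toy distribution
# (illustrative only):
#
#   probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
#   sorted_probs, sorted_idx = probs.sort(dim=1, descending=True)
#   keep = sorted_probs.cumsum(1) < 0.9
#   keep = torch.cat([torch.ones_like(keep[:, :1]), keep[:, :-1]], 1)
#   # keep == [[True, True, True, False]]: only tokens inside the 0.9 nucleus survive,
#   # and the remaining mass is renormalized before sampling.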
|||
|
|||
|
|||
def decode_sequence(self, seq): |
|||
return model_utils.decode_sequence(self.vocab, seq) |
@@ -0,0 +1,92 @@ |
|||
import torch |
|||
|
|||
def split_tensors(n, x): |
|||
if torch.is_tensor(x): |
|||
assert x.shape[0] % n == 0 |
|||
x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) |
|||
elif type(x) is list or type(x) is tuple: |
|||
x = [split_tensors(n, _) for _ in x] |
|||
elif x is None: |
|||
x = [None] * n |
|||
return x |
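# Shape sketch for split_tensors (illustrative): the inverse of repeating along the
# batch dimension, turning a (B*n) x ... tensor into n tensors of shape B x ...
#
#   x = torch.arange(12).reshape(6, 2)     # B*n = 6 with n = 3, so B = 2
#   a, b, c = split_tensors(3, x)          # three tensors, each of shape (2, 2)
#   # rows 0 and 3 go to a, rows 1 and 4 to b, rows 2 and 5 to c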
|||
|
|||
# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. |
|||
#def decode_sequence(ix_to_word, seq): |
|||
# # N, D = seq.size() |
|||
# N, D = seq.shape |
|||
# out = [] |
|||
# for i in range(N): |
|||
# txt = '' |
|||
# for j in range(D): |
|||
# ix = seq[i,j] |
|||
# if ix > 0 : |
|||
# if j >= 1: |
|||
# txt = txt + ' ' |
|||
# txt = txt + ix_to_word[str(ix.item())] |
|||
# else: |
|||
# break |
|||
# if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): |
|||
# flag = 0 |
|||
# words = txt.split(' ') |
|||
# for j in range(len(words)): |
|||
# if words[-j-1] not in bad_endings: |
|||
# flag = -j |
|||
# break |
|||
# txt = ' '.join(words[0:len(words)+flag]) |
|||
# out.append(txt.replace('@@ ', '')) |
|||
# return out |
|||
|
|||
def decode_sequence(ix_to_word, seq, remove_bad_endings = True): |
|||
# N, D = seq.size() |
|||
N, D = seq.shape |
|||
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] |
|||
bad_endings += ['the'] |
|||
out = [] |
|||
for i in range(N): |
|||
txt = '' |
|||
for j in range(D): |
|||
ix = seq[i,j] |
|||
if ix > 0 : |
|||
if j >= 1: |
|||
txt = txt + ' ' |
|||
txt = txt + ix_to_word[str(ix.item())] |
|||
else: |
|||
break |
|||
if remove_bad_endings is True: |
|||
flag = 0 |
|||
words = txt.split(' ') |
|||
for j in range(len(words)): |
|||
if words[-j-1] not in bad_endings: |
|||
flag = -j |
|||
break |
|||
txt = ' '.join(words[0:len(words)+flag]) |
|||
out.append(txt.replace('@@ ', '')) |
|||
return out |
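# Usage sketch for decode_sequence (illustrative vocabulary): index 0 acts as the
# END/pad token, and trailing stop-words from bad_endings are stripped.
#
#   ix_to_word = {'1': 'a', '2': 'dog', '3': 'on', '4': 'grass'}
#   seq = torch.tensor([[1, 2, 3, 4, 0],
#                       [1, 2, 3, 0, 0]])
#   decode_sequence(ix_to_word, seq)       # -> ['a dog on grass', 'a dog']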
|||
|
|||
|
|||
|
|||
def penalty_builder(penalty_config): |
|||
if penalty_config == '': |
|||
return lambda x,y: y |
|||
pen_type, alpha = penalty_config.split('_') |
|||
alpha = float(alpha) |
|||
if pen_type == 'wu': |
|||
return lambda x,y: length_wu(x,y,alpha) |
|||
if pen_type == 'avg': |
|||
return lambda x,y: length_average(x,y,alpha) |
|||
|
|||
def length_wu(length, logprobs, alpha=0.): |
|||
""" |
|||
NMT length re-ranking score from |
|||
"Google's Neural Machine Translation System" :cite:`wu2016google`. |
|||
""" |
|||
|
|||
modifier = (((5 + length) ** alpha) / |
|||
((5 + 1) ** alpha)) |
|||
return (logprobs / modifier) |
|||
|
|||
def length_average(length, logprobs, alpha=0.): |
|||
""" |
|||
Returns the average probability of tokens in a sequence. |
|||
""" |
|||
return logprobs / length |
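# Worked example for the length penalties (illustrative numbers): penalty_builder
# returns a function f(length, sum_logprob) used to re-rank finished beams.
#
#   f = penalty_builder('wu_0.7')
#   f(10, -6.0)    # -6.0 / ((5 + 10)**0.7 / (5 + 1)**0.7) ≈ -6.0 / 1.90 ≈ -3.16
#   g = penalty_builder('avg_1')
#   g(10, -6.0)    # -0.6, the per-token average log-probability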
@@ -0,0 +1 @@ |
|||
from .clip import * |
Binary file not shown.
@@ -0,0 +1,193 @@ |
|||
import hashlib |
|||
import os |
|||
import urllib |
|||
import warnings |
|||
from typing import Union, List |
|||
|
|||
import torch |
|||
from PIL import Image |
|||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
|||
from tqdm import tqdm |
|||
|
|||
from .model import build_model |
|||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer |
|||
|
|||
__all__ = ["available_models", "load", "tokenize"] |
|||
_tokenizer = _Tokenizer() |
|||
|
|||
_MODELS = { |
|||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", |
|||
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", |
|||
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", |
|||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", |
|||
} |
|||
|
|||
|
|||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): |
|||
os.makedirs(root, exist_ok=True) |
|||
filename = os.path.basename(url) |
|||
|
|||
expected_sha256 = url.split("/")[-2] |
|||
download_target = os.path.join(root, filename) |
|||
|
|||
if os.path.exists(download_target) and not os.path.isfile(download_target): |
|||
raise RuntimeError(f"{download_target} exists and is not a regular file") |
|||
|
|||
if os.path.isfile(download_target): |
|||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: |
|||
return download_target |
|||
else: |
|||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") |
|||
|
|||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: |
|||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: |
|||
while True: |
|||
buffer = source.read(8192) |
|||
if not buffer: |
|||
break |
|||
|
|||
output.write(buffer) |
|||
loop.update(len(buffer)) |
|||
|
|||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: |
|||
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") |
|||
|
|||
return download_target |
|||
|
|||
|
|||
def _transform(n_px): |
|||
return Compose([ |
|||
Resize(n_px, interpolation=Image.BICUBIC), |
|||
CenterCrop(n_px), |
|||
lambda image: image.convert("RGB"), |
|||
ToTensor(), |
|||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |
|||
]) |
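# Usage sketch for _transform (illustrative): the returned torchvision pipeline maps a
# PIL image to a normalized float tensor of shape (3, n_px, n_px).
#
#   preprocess = _transform(224)
#   img = Image.new("RGB", (640, 480))     # any PIL image
#   preprocess(img).shape                   # torch.Size([3, 224, 224])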
|||
|
|||
|
|||
def available_models() -> List[str]: |
|||
"""Returns the names of available CLIP models""" |
|||
return list(_MODELS.keys()) |
|||
|
|||
|
|||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): |
|||
"""Load a CLIP model |
|||
|
|||
Parameters |
|||
---------- |
|||
name : str |
|||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict |
|||
|
|||
device : Union[str, torch.device] |
|||
The device to put the loaded model |
|||
|
|||
jit : bool |
|||
Whether to load the optimized JIT model (default) or more hackable non-JIT model. |
|||
|
|||
Returns |
|||
------- |
|||
model : torch.nn.Module |
|||
The CLIP model |
|||
|
|||
preprocess : Callable[[PIL.Image], torch.Tensor] |
|||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input |
|||
""" |
|||
if name in _MODELS: |
|||
model_path = _download(_MODELS[name]) |
|||
elif os.path.isfile(name): |
|||
model_path = name |
|||
else: |
|||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}") |
|||
|
|||
try: |
|||
# loading JIT archive |
|||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() |
|||
state_dict = None |
|||
except RuntimeError: |
|||
# loading saved state dict |
|||
if jit: |
|||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") |
|||
jit = False |
|||
state_dict = torch.load(model_path, map_location="cpu") |
|||
|
|||
if not jit: |
|||
model = build_model(state_dict or model.state_dict()).to(device) |
|||
if str(device) == "cpu": |
|||
model.float() |
|||
return model, _transform(model.visual.input_resolution) |
|||
|
|||
# patch the device names |
|||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) |
|||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] |
|||
|
|||
def patch_device(module): |
|||
graphs = [module.graph] if hasattr(module, "graph") else [] |
|||
if hasattr(module, "forward1"): |
|||
graphs.append(module.forward1.graph) |
|||
|
|||
for graph in graphs: |
|||
for node in graph.findAllNodes("prim::Constant"): |
|||
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): |
|||
node.copyAttributes(device_node) |
|||
|
|||
model.apply(patch_device) |
|||
patch_device(model.encode_image) |
|||
patch_device(model.encode_text) |
|||
|
|||
# patch dtype to float32 on CPU |
|||
if str(device) == "cpu": |
|||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) |
|||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] |
|||
float_node = float_input.node() |
|||
|
|||
def patch_float(module): |
|||
graphs = [module.graph] if hasattr(module, "graph") else [] |
|||
if hasattr(module, "forward1"): |
|||
graphs.append(module.forward1.graph) |
|||
|
|||
for graph in graphs: |
|||
for node in graph.findAllNodes("aten::to"): |
|||
inputs = list(node.inputs()) |
|||
for i in [1, 2]: # dtype can be the second or third argument to aten::to() |
|||
if inputs[i].node()["value"] == 5: |
|||
inputs[i].node().copyAttributes(float_node) |
|||
|
|||
model.apply(patch_float) |
|||
patch_float(model.encode_image) |
|||
patch_float(model.encode_text) |
|||
|
|||
model.float() |
|||
|
|||
return model, _transform(model.input_resolution.item()) |
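# Usage sketch for load (model names come from _MODELS above; encode_image and
# encode_text are the methods patched further up). Note that in this fork the ResNet
# visual branch also returns the spatial grid features, so the exact output of
# encode_image may differ from stock CLIP; "example.jpg" is a placeholder path.
#
#   model, preprocess = load("RN50", device="cpu", jit=False)
#   image = preprocess(Image.open("example.jpg")).unsqueeze(0)
#   text = tokenize(["a photo of a dog", "a photo of a cat"])
#   with torch.no_grad():
#       image_features = model.encode_image(image)
#       text_features = model.encode_text(text)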
|||
|
|||
|
|||
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: |
|||
""" |
|||
Returns the tokenized representation of given input string(s) |
|||
|
|||
Parameters |
|||
---------- |
|||
texts : Union[str, List[str]] |
|||
An input string or a list of input strings to tokenize |
|||
|
|||
context_length : int |
|||
The context length to use; all CLIP models use 77 as the context length |
|||
|
|||
Returns |
|||
------- |
|||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] |
|||
""" |
|||
if isinstance(texts, str): |
|||
texts = [texts] |
|||
|
|||
sot_token = _tokenizer.encoder["<|startoftext|>"] |
|||
eot_token = _tokenizer.encoder["<|endoftext|>"] |
|||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] |
|||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) |
|||
|
|||
for i, tokens in enumerate(all_tokens): |
|||
if len(tokens) > context_length: |
|||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") |
|||
result[i, :len(tokens)] = torch.tensor(tokens) |
|||
|
|||
return result |
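# Usage sketch for tokenize: each string is wrapped with <|startoftext|> and
# <|endoftext|> and right-padded with zeros up to context_length.
#
#   toks = tokenize(["a diagram", "a photo of a cat"])
#   toks.shape            # torch.Size([2, 77])
#   (toks[0] == 0).sum()  # number of padding positions in the first row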
@@ -0,0 +1,437 @@ |
|||
from collections import OrderedDict |
|||
from typing import Tuple, Union |
|||
|
|||
import torch |
|||
import torch.nn.functional as F |
|||
from torch import nn |
|||
|
|||
|
|||
class Bottleneck(nn.Module): |
|||
expansion = 4 |
|||
|
|||
def __init__(self, inplanes, planes, stride=1): |
|||
super().__init__() |
|||
|
|||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 |
|||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) |
|||
self.bn1 = nn.BatchNorm2d(planes) |
|||
|
|||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) |
|||
self.bn2 = nn.BatchNorm2d(planes) |
|||
|
|||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() |
|||
|
|||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) |
|||
self.bn3 = nn.BatchNorm2d(planes * self.expansion) |
|||
|
|||
self.relu = nn.ReLU(inplace=True) |
|||
self.downsample = None |
|||
self.stride = stride |
|||
|
|||
if stride > 1 or inplanes != planes * Bottleneck.expansion: |
|||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 |
|||
self.downsample = nn.Sequential(OrderedDict([ |
|||
("-1", nn.AvgPool2d(stride)), |
|||
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), |
|||
("1", nn.BatchNorm2d(planes * self.expansion)) |
|||
])) |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
identity = x |
|||
|
|||
out = self.relu(self.bn1(self.conv1(x))) |
|||
out = self.relu(self.bn2(self.conv2(out))) |
|||
out = self.avgpool(out) |
|||
out = self.bn3(self.conv3(out)) |
|||
|
|||
if self.downsample is not None: |
|||
identity = self.downsample(x) |
|||
|
|||
out += identity |
|||
out = self.relu(out) |
|||
return out |
|||
|
|||
|
|||
class AttentionPool2d(nn.Module): |
|||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): |
|||
super().__init__() |
|||
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) |
|||
self.k_proj = nn.Linear(embed_dim, embed_dim) |
|||
self.q_proj = nn.Linear(embed_dim, embed_dim) |
|||
self.v_proj = nn.Linear(embed_dim, embed_dim) |
|||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) |
|||
self.num_heads = num_heads |
|||
|
|||
def forward(self, x): |
|||
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC |
|||
# print(x.shape, self.positional_embedding.shape) |
|||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC |
|||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC |
|||
x, _ = F.multi_head_attention_forward( |
|||
query=x, key=x, value=x, |
|||
embed_dim_to_check=x.shape[-1], |
|||
num_heads=self.num_heads, |
|||
q_proj_weight=self.q_proj.weight, |
|||
k_proj_weight=self.k_proj.weight, |
|||
v_proj_weight=self.v_proj.weight, |
|||
in_proj_weight=None, |
|||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), |
|||
bias_k=None, |
|||
bias_v=None, |
|||
add_zero_attn=False, |
|||
dropout_p=0, |
|||
out_proj_weight=torch.ones_like(self.q_proj.weight), |
|||
out_proj_bias=torch.zeros_like(self.q_proj.bias), |
|||
# out_proj_weight=self.c_proj.weight, |
|||
# out_proj_bias=self.c_proj.bias, |
|||
use_separate_proj_weight=True, |
|||
training=self.training, |
|||
need_weights=False |
|||
) |
|||
|
|||
return x[0] |
|||
|
|||
|
|||
class ModifiedResNet(nn.Module): |
|||
""" |
|||
A ResNet class that is similar to torchvision's but contains the following changes: |
|||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. |
|||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 |
|||
- The final pooling layer is a QKV attention instead of an average pool |
|||
""" |
|||
|
|||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): |
|||
super().__init__() |
|||
self.output_dim = output_dim |
|||
self.input_resolution = input_resolution |
|||
|
|||
# the 3-layer stem |
|||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) |
|||
self.bn1 = nn.BatchNorm2d(width // 2) |
|||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) |
|||
self.bn2 = nn.BatchNorm2d(width // 2) |
|||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) |
|||
self.bn3 = nn.BatchNorm2d(width) |
|||
self.avgpool = nn.AvgPool2d(2) |
|||
self.relu = nn.ReLU(inplace=True) |
|||
|
|||
# residual layers |
|||
self._inplanes = width # this is a *mutable* variable used during construction |
|||
self.layer1 = self._make_layer(width, layers[0]) |
|||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2) |
|||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2) |
|||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2) |
|||
|
|||
embed_dim = width * 32 # the ResNet feature dimension |
|||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) |
|||
|
|||
def _make_layer(self, planes, blocks, stride=1): |
|||
layers = [Bottleneck(self._inplanes, planes, stride)] |
|||
|
|||
self._inplanes = planes * Bottleneck.expansion |
|||
for _ in range(1, blocks): |
|||
layers.append(Bottleneck(self._inplanes, planes)) |
|||
|
|||
return nn.Sequential(*layers) |
|||
|
|||
def forward(self, x): |
|||
def stem(x): |
|||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: |
|||
x = self.relu(bn(conv(x))) |
|||
x = self.avgpool(x) |
|||
return x |
|||
|
|||
x = x.type(self.conv1.weight.dtype) |
|||
x = stem(x) |
|||
x = self.layer1(x) |
|||
x = self.layer2(x) |
|||
x = self.layer3(x) |
|||
x = self.layer4(x) |
|||
# print(x.shape) |
|||
# x = self.attnpool(x) |
|||
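# Unlike the original CLIP ResNet, both the spatial grid features and the attention-pooled
# vector are returned; the captioning code downstream uses the grid as per-region attention
# features and the pooled vector as the fc feature.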
attnpool = self.attnpool(x) |
|||
|
|||
return (x, attnpool) |
|||
|
|||
|
|||
class LayerNorm(nn.LayerNorm): |
|||
"""Subclass torch's LayerNorm to handle fp16.""" |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
orig_type = x.dtype |
|||
ret = super().forward(x.type(torch.float32)) |
|||
return ret.type(orig_type) |
|||
|
|||
|
|||
class QuickGELU(nn.Module): |
|||
def forward(self, x: torch.Tensor): |
|||
return x * torch.sigmoid(1.702 * x) |
|||
|
|||
|
|||
class ResidualAttentionBlock(nn.Module): |
|||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): |
|||
super().__init__() |
|||
|
|||
self.attn = nn.MultiheadAttention(d_model, n_head) |
|||
self.ln_1 = LayerNorm(d_model) |
|||
self.mlp = nn.Sequential(OrderedDict([ |
|||
("c_fc", nn.Linear(d_model, d_model * 4)), |
|||
("gelu", QuickGELU()), |
|||
("c_proj", nn.Linear(d_model * 4, d_model)) |
|||
])) |
|||
self.ln_2 = LayerNorm(d_model) |
|||
self.attn_mask = attn_mask |
|||
|
|||
def attention(self, x: torch.Tensor): |
|||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
|||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
x = x + self.attention(self.ln_1(x)) |
|||
x = x + self.mlp(self.ln_2(x)) |
|||
return x |
|||
|
|||
|
|||
class Transformer(nn.Module): |
|||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): |
|||
super().__init__() |
|||
self.width = width |
|||
self.layers = layers |
|||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
return self.resblocks(x) |
|||
|
|||
|
|||
class VisualTransformer(nn.Module): |
|||
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): |
|||
super().__init__() |
|||
self.input_resolution = input_resolution |
|||
self.output_dim = output_dim |
|||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) |
|||
|
|||
scale = width ** -0.5 |
|||
self.class_embedding = nn.Parameter(scale * torch.randn(width)) |
|||
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) |
|||
self.ln_pre = LayerNorm(width) |
|||
|
|||
self.transformer = Transformer(width, layers, heads) |
|||
|
|||
self.ln_post = LayerNorm(width) |
|||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
x = self.conv1(x) # shape = [*, width, grid, grid] |
|||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] |
|||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] |
|||
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] |
|||
x = x + self.positional_embedding.to(x.dtype) |
|||
x = self.ln_pre(x) |
|||
|
|||
x = x.permute(1, 0, 2) # NLD -> LND |
|||
x = self.transformer(x) |
|||
x = x.permute(1, 0, 2) # LND -> NLD |
|||
|
|||
# x = self.ln_post(x[:, 0, :]) |
|||
|
|||
x = self.ln_post(x) |
|||
# if self.proj is not None: |
|||
# x = x @ self.proj |
|||
|
|||
return x |
|||
|
|||
|
|||
class CLIP(nn.Module): |
|||
def __init__(self, |
|||
embed_dim: int, |
|||
# vision |
|||
image_resolution: int, |
|||
vision_layers: Union[Tuple[int, int, int, int], int], |
|||
vision_width: int, |
|||
vision_patch_size: int, |
|||
# text |
|||
context_length: int, |
|||
vocab_size: int, |
|||
transformer_width: int, |
|||
transformer_heads: int, |
|||
transformer_layers: int |
|||
): |
|||
super().__init__() |
|||
|
|||
self.context_length = context_length |
|||
|
|||
if isinstance(vision_layers, (tuple, list)): |
|||
vision_heads = vision_width * 32 // 64 |
|||
self.visual = ModifiedResNet( |
|||
layers=vision_layers, |
|||
output_dim=embed_dim, |
|||
heads=vision_heads, |
|||
input_resolution=image_resolution, |
|||
width=vision_width |
|||
) |
|||
else: |
|||
vision_heads = vision_width // 64 |
|||
self.visual = VisualTransformer( |
|||
input_resolution=image_resolution, |
|||
patch_size=vision_patch_size, |
|||
width=vision_width, |
|||
layers=vision_layers, |
|||
heads=vision_heads, |
|||
output_dim=embed_dim |
|||
) |
|||
|
|||
self.transformer = Transformer( |
|||
width=transformer_width, |
|||
layers=transformer_layers, |
|||
heads=transformer_heads, |
|||
attn_mask=self.build_attention_mask() |
|||
) |
|||
|
|||
self.vocab_size = vocab_size |
|||
self.token_embedding = nn.Embedding(vocab_size, transformer_width) |
|||
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) |
|||
self.ln_final = LayerNorm(transformer_width) |
|||
|
|||
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) |
|||
self.logit_scale = nn.Parameter(torch.ones([])) |
|||
|
|||
self.initialize_parameters() |
|||
|
|||
def initialize_parameters(self): |
|||
nn.init.normal_(self.token_embedding.weight, std=0.02) |
|||
nn.init.normal_(self.positional_embedding, std=0.01) |
|||
|
|||
if isinstance(self.visual, ModifiedResNet): |
|||
if self.visual.attnpool is not None: |
|||
std = self.visual.attnpool.c_proj.in_features ** -0.5 |
|||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) |
|||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) |
|||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) |
|||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) |
|||
|
|||
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: |
|||
for name, param in resnet_block.named_parameters(): |
|||
if name.endswith("bn3.weight"): |
|||
nn.init.zeros_(param) |
|||
|
|||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) |
|||
attn_std = self.transformer.width ** -0.5 |
|||
fc_std = (2 * self.transformer.width) ** -0.5 |
|||
for block in self.transformer.resblocks: |
|||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) |
|||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) |
|||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) |
|||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) |
|||
|
|||
if self.text_projection is not None: |
|||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) |
|||
|
|||
def build_attention_mask(self): |
|||
# lazily create causal attention mask, with full attention between the vision tokens |
|||
# pytorch uses additive attention mask; fill with -inf |
|||
mask = torch.empty(self.context_length, self.context_length) |
|||
mask.fill_(float("-inf")) |
|||
mask.triu_(1) # zero out the lower diagonal |
|||
return mask |
|||
|
|||
@property |
|||
def dtype(self): |
|||
return self.visual.conv1.weight.dtype |
|||
|
|||
def encode_image(self, image): |
|||
return self.visual(image.type(self.dtype)) |
|||
|
|||
def encode_text(self, text): |
|||
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] |
|||
|
|||
x = x + self.positional_embedding.type(self.dtype) |
|||
x = x.permute(1, 0, 2) # NLD -> LND |
|||
x = self.transformer(x) |
|||
x = x.permute(1, 0, 2) # LND -> NLD |
|||
x = self.ln_final(x).type(self.dtype) |
|||
|
|||
# x.shape = [batch_size, n_ctx, transformer.width] |
|||
# take features from the eot embedding (eot_token is the highest number in each sequence) |
|||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection |
|||
|
|||
return x |
|||
|
|||
def forward(self, image, text): |
|||
image_features = self.encode_image(image) |
|||
text_features = self.encode_text(text) |
|||
|
|||
# normalized features |
|||
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|||
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|||
|
|||
# cosine similarity as logits |
|||
logit_scale = self.logit_scale.exp() |
|||
logits_per_image = logit_scale * image_features @ text_features.t() |
|||
logits_per_text = logit_scale * text_features @ image_features.t() |
|||
|
|||
# shape = [global_batch_size, global_batch_size] |
|||
return logits_per_image, logits_per_text |
|||
|
|||
|
|||
def convert_weights(model: nn.Module): |
|||
"""Convert applicable model parameters to fp16""" |
|||
|
|||
def _convert_weights_to_fp16(l): |
|||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): |
|||
l.weight.data = l.weight.data.half() |
|||
if l.bias is not None: |
|||
l.bias.data = l.bias.data.half() |
|||
|
|||
if isinstance(l, nn.MultiheadAttention): |
|||
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: |
|||
tensor = getattr(l, attr) |
|||
if tensor is not None: |
|||
tensor.data = tensor.data.half() |
|||
|
|||
for name in ["text_projection", "proj"]: |
|||
if hasattr(l, name): |
|||
attr = getattr(l, name) |
|||
if attr is not None: |
|||
attr.data = attr.data.half() |
|||
|
|||
model.apply(_convert_weights_to_fp16) |
|||
|
|||
|
|||
def build_model(state_dict: dict): |
|||
vit = "visual.proj" in state_dict |
|||
|
|||
if vit: |
|||
vision_width = state_dict["visual.conv1.weight"].shape[0] |
|||
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) |
|||
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] |
|||
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) |
|||
image_resolution = vision_patch_size * grid_size |
|||
else: |
|||
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] |
|||
vision_layers = tuple(counts) |
|||
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] |
|||
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) |
|||
vision_patch_size = None |
|||
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] |
|||
image_resolution = output_width * 32 |
|||
|
|||
embed_dim = state_dict["text_projection"].shape[1] |
|||
context_length = state_dict["positional_embedding"].shape[0] |
|||
vocab_size = state_dict["token_embedding.weight"].shape[0] |
|||
transformer_width = state_dict["ln_final.weight"].shape[0] |
|||
transformer_heads = transformer_width // 64 |
|||
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) |
|||
|
|||
model = CLIP( |
|||
embed_dim, |
|||
image_resolution, vision_layers, vision_width, vision_patch_size, |
|||
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers |
|||
) |
|||
|
|||
for key in ["input_resolution", "context_length", "vocab_size"]: |
|||
if key in state_dict: |
|||
del state_dict[key] |
|||
|
|||
convert_weights(model) |
|||
model.load_state_dict(state_dict) |
|||
return model.eval() |
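# --- usage sketch (not part of the original file) ---------------------------------
# Assumes a locally downloaded CLIP checkpoint; official releases are TorchScript
# archives, so the state dict is read from the scripted module. convert_weights()
# leaves the model in fp16, so inference is normally run on GPU.
#
#   import torch
#   jit_archive = torch.jit.load("RN50.pt", map_location="cpu")   # hypothetical path
#   clip_model = build_model(jit_archive.state_dict()).to("cuda")
#   with torch.no_grad():
#       grid_feats, pooled_feats = clip_model.encode_image(images_on_gpu)
#   # for the ModifiedResNet variant, encode_image returns (grid features, pooled feature)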
@ -0,0 +1,132 @@ |
|||
import gzip |
|||
import html |
|||
import os |
|||
from functools import lru_cache |
|||
|
|||
import ftfy |
|||
import regex as re |
|||
|
|||
|
|||
@lru_cache() |
|||
def default_bpe(): |
|||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") |
|||
|
|||
|
|||
@lru_cache() |
|||
def bytes_to_unicode(): |
|||
""" |
|||
Returns a list of utf-8 bytes and a corresponding list of unicode strings. |
|||
The reversible bpe codes work on unicode strings. |
|||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. |
|||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. |
|||
This is a significant percentage of your normal, say, 32K bpe vocab. |
|||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. |
|||
And avoids mapping to whitespace/control characters the bpe code barfs on. |
|||
""" |
|||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) |
|||
cs = bs[:] |
|||
n = 0 |
|||
for b in range(2**8): |
|||
if b not in bs: |
|||
bs.append(b) |
|||
cs.append(2**8+n) |
|||
n += 1 |
|||
cs = [chr(n) for n in cs] |
|||
return dict(zip(bs, cs)) |
|||
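# For example, bytes_to_unicode()[ord('a')] == 'a', while control bytes (which would
# otherwise break BPE) are remapped to printable code points at 256 and above.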
|
|||
|
|||
def get_pairs(word): |
|||
"""Return set of symbol pairs in a word. |
|||
Word is represented as tuple of symbols (symbols being variable-length strings). |
|||
""" |
|||
pairs = set() |
|||
prev_char = word[0] |
|||
for char in word[1:]: |
|||
pairs.add((prev_char, char)) |
|||
prev_char = char |
|||
return pairs |
|||
|
|||
|
|||
def basic_clean(text): |
|||
text = ftfy.fix_text(text) |
|||
text = html.unescape(html.unescape(text)) |
|||
return text.strip() |
|||
|
|||
|
|||
def whitespace_clean(text): |
|||
text = re.sub(r'\s+', ' ', text) |
|||
text = text.strip() |
|||
return text |
|||
|
|||
|
|||
class SimpleTokenizer(object): |
|||
def __init__(self, bpe_path: str = default_bpe()): |
|||
self.byte_encoder = bytes_to_unicode() |
|||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} |
|||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') |
|||
merges = merges[1:49152-256-2+1] |
|||
merges = [tuple(merge.split()) for merge in merges] |
|||
vocab = list(bytes_to_unicode().values()) |
|||
vocab = vocab + [v+'</w>' for v in vocab] |
|||
for merge in merges: |
|||
vocab.append(''.join(merge)) |
|||
vocab.extend(['<|startoftext|>', '<|endoftext|>']) |
|||
self.encoder = dict(zip(vocab, range(len(vocab)))) |
|||
self.decoder = {v: k for k, v in self.encoder.items()} |
|||
self.bpe_ranks = dict(zip(merges, range(len(merges)))) |
|||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} |
|||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) |
|||
|
|||
def bpe(self, token): |
|||
if token in self.cache: |
|||
return self.cache[token] |
|||
word = tuple(token[:-1]) + ( token[-1] + '</w>',) |
|||
pairs = get_pairs(word) |
|||
|
|||
if not pairs: |
|||
return token+'</w>' |
|||
|
|||
while True: |
|||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) |
|||
if bigram not in self.bpe_ranks: |
|||
break |
|||
first, second = bigram |
|||
new_word = [] |
|||
i = 0 |
|||
while i < len(word): |
|||
try: |
|||
j = word.index(first, i) |
|||
new_word.extend(word[i:j]) |
|||
i = j |
|||
except ValueError: |
|||
new_word.extend(word[i:]) |
|||
break |
|||
|
|||
if word[i] == first and i < len(word)-1 and word[i+1] == second: |
|||
new_word.append(first+second) |
|||
i += 2 |
|||
else: |
|||
new_word.append(word[i]) |
|||
i += 1 |
|||
new_word = tuple(new_word) |
|||
word = new_word |
|||
if len(word) == 1: |
|||
break |
|||
else: |
|||
pairs = get_pairs(word) |
|||
word = ' '.join(word) |
|||
self.cache[token] = word |
|||
return word |
|||
|
|||
def encode(self, text): |
|||
bpe_tokens = [] |
|||
text = whitespace_clean(basic_clean(text)).lower() |
|||
for token in re.findall(self.pat, text): |
|||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) |
|||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) |
|||
return bpe_tokens |
|||
|
|||
def decode(self, tokens): |
|||
text = ''.join([self.decoder[token] for token in tokens]) |
|||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') |
|||
return text |
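# Round-trip sketch (the BPE vocabulary is resolved by default_bpe()):
#   tok = SimpleTokenizer()
#   ids = tok.encode("a photo of a cat")
#   tok.decode(ids)   # -> "a photo of a cat " (decode re-inserts word-end spaces)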
@ -0,0 +1,105 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
import json |
|||
import pathlib |
|||
import sys |
|||
from pathlib import Path |
|||
 
|||
import torch |
|||
from torch import nn |
|||
from timm.models.vision_transformer import resize_pos_embed |
|||
from towhee.operator.base import NNOperator |
|||
from towhee.types.arg import arg, to_image_color |
|||
from towhee.types.image_utils import to_pil |
|||
|
|||
|
|||
class ClipCaptionReward(NNOperator): |
|||
""" |
|||
CLIP-caption-reward operator: generates an image caption with a transformer fine-tuned using CLIP-based rewards. |
|||
""" |
|||
def __init__(self, model_name: str): |
|||
super().__init__() |
|||
sys.path.append(str(Path(__file__).parent)) |
|||
from utils import opts |
|||
from captioning.models.TransformerModel import TransformerModel |
|||
import clip |
|||
opt = opts.parse_opt(parse=False, cfg=cfg) |
|||
path = pathlib.Path(__file__).parent |
|||
dict_json = json.load(open("{}/data/cocotalk.json".format(path))) |
|||
ix_to_word = dict_json["ix_to_word"] |
|||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|||
|
|||
clip_model, clip_transform = clip.load("RN50", jit=False, device=self.device) |
|||
self.clip_model = clip_model |
|||
self.clip_transform = clip_transform |
|||
|
|||
vocab_size = len(ix_to_word) |
|||
seq_length = 1 |
|||
opt.vocab_size = vocab_size |
|||
opt.seq_length = seq_length |
|||
opt.batch_size = 1 |
|||
opt.vocab = ix_to_word |
|||
|
|||
num_patches = 196  # 14 x 14 spatial positions from the stride-32 backbone |
|||
|
|||
pos_embed = nn.Parameter( |
|||
torch.zeros( |
|||
1, |
|||
num_patches + 1, |
|||
clip_model.visual.attnpool.positional_embedding.shape[-1], |
|||
device=self.device, |
|||
), |
|||
) |
|||
pos_embed.weight = resize_pos_embed( |
|||
clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed |
|||
) |
|||
self.clip_model.visual.attnpool.positional_embedding = pos_embed |
|||
|
|||
self.model = TransformerModel(opt) |
|||
self.image_mean = ( |
|||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]) |
|||
.to(self.device) |
|||
.reshape(3, 1, 1) |
|||
) |
|||
self.image_std = ( |
|||
torch.Tensor([0.26862954, 0.26130258, 0.27577711]) |
|||
.to(self.device) |
|||
.reshape(3, 1, 1) |
|||
) |
|||
|
|||
@arg(1, to_image_color('RGB')) |
|||
def inference_single_data(self, data): |
|||
text = self._inference_from_image(data) |
|||
return text |
|||
|
|||
@arg(1, to_image_color('RGB')) |
|||
def _inference_from_image(self, img): |
|||
img = to_pil(img) |
|||
img = self._preprocess(img) |
|||
img -= self.image_mean |
|||
img /= self.image_std |
|||
tmp_att, tmp_fc = self.clip_model.encode_image(img) |
|||
tmp_att = tmp_att[0].permute(1, 2, 0) |
|||
|
|||
att_feat = tmp_att |
|||
|
|||
return att_feat |
|||
|
|||
def __call__(self, data): |
|||
if not isinstance(data, list): |
|||
data = [data] |
|||
results = [] |
|||
for single_data in data: |
|||
result = self.inference_single_data(single_data) |
|||
results.append(result) |
|||
if len(data) == 1: |
|||
return results[0] |
|||
return results |
|||
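# Usage sketch (the model name is illustrative; it should match one of the shipped configs):
#   op = ClipCaptionReward("clipRN50_clips_grammar")
#   caption = op(rgb_image)          # a single towhee image in, one caption out
#   captions = op([img_a, img_b])    # a list in, a list of captions out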
|
Binary file not shown.
@ -0,0 +1,60 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/FineCapEval.json |
|||
input_label_h5: none |
|||
input_fc_dir: data/FineCapEval_clip_RN50_fc |
|||
input_att_dir: data/FineCapEval_clip_RN50_att |
|||
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
|||
|
|||
seq_per_img: 5 |
|||
batch_size: 200 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/clipRN50_mle/clipRN50_mle |
|||
|
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 50 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
use_clipscore: false |
@ -0,0 +1,52 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
# noamopt: false |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_fc_dir: data/cocotalk_clip_RN50_fc |
|||
input_att_dir: data/cocotalk_clip_RN50_att |
|||
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
|||
seq_per_img: 5 |
|||
# batch_size: 600 |
|||
batch_size: 200 |
|||
|
|||
learning_rate: 0.0005 |
|||
|
|||
# checkpoint_path: ./save/trans_clip_rn50_sc_pl |
|||
checkpoint_path: save/clipRN50_mle/clipRN50_mle |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
# max_epochs: 15 |
|||
max_epochs: 25 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
|
|||
verbose: false |
|||
precision: 16 |
@ -0,0 +1,41 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_att_dir: data/cocotalk_att |
|||
seq_per_img: 5 |
|||
batch_size: 10 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/trans_rn50_sc |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
@ -0,0 +1,61 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/FineCapEval.json |
|||
input_label_h5: none |
|||
input_fc_dir: data/FineCapEval_clip_RN50_fc |
|||
input_att_dir: data/FineCapEval_clip_RN50_att |
|||
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
|||
|
|||
seq_per_img: 5 |
|||
batch_size: 200 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/clipRN50_cider/clipRN50_cider |
|||
|
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 50 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
# use_clipscore: true |
|||
use_clipscore: false |
@ -0,0 +1,65 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/FineCapEval.json |
|||
input_label_h5: none |
|||
input_fc_dir: data/FineCapEval_clip_RN50_fc |
|||
input_att_dir: data/FineCapEval_clip_RN50_att |
|||
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
|||
|
|||
seq_per_img: 5 |
|||
batch_size: 200 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips |
|||
|
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 50 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
# use_clipscore: true |
|||
use_clipscore: false |
|||
clipscore_reward_weight: 2.0 |
|||
clipscore_mode: clip_s |
|||
|
|||
use_multi_rewards: true |
@ -0,0 +1,64 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/FineCapEval.json |
|||
input_label_h5: none |
|||
input_fc_dir: data/FineCapEval_clip_RN50_fc |
|||
input_att_dir: data/FineCapEval_clip_RN50_att |
|||
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
|||
seq_per_img: 5 |
|||
batch_size: 160 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/clipRN50_clips/clipRN50_clips |
|||
|
|||
use_multi_rewards: false |
|||
use_grammar: false |
|||
use_grammar_baseline: false |
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 0 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 50 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
# use_clipscore: true |
|||
use_clipscore: false |
|||
clipscore_reward_weight: 2.0 |
@ -0,0 +1,64 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/FineCapEval.json |
|||
input_label_h5: none |
|||
input_fc_dir: data/FineCapEval_clip_RN50_fc |
|||
input_att_dir: data/FineCapEval_clip_RN50_att |
|||
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
|||
seq_per_img: 5 |
|||
batch_size: 160 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar |
|||
|
|||
use_multi_rewards: true |
|||
use_grammar: true |
|||
use_grammar_baseline: true |
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 0 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 50 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
# use_clipscore: true |
|||
use_clipscore: false |
|||
clipscore_reward_weight: 2.0 |
@ -0,0 +1,58 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_fc_dir: data/cocotalk_clip_RN50_fc |
|||
input_att_dir: data/cocotalk_clip_RN50_att |
|||
# used only for evaluation |
|||
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
|||
|
|||
seq_per_img: 5 |
|||
batch_size: 200 |
|||
learning_rate: 0.0005 |
|||
|
|||
# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider |
|||
checkpoint_path: save/clipRN50_cider/clipRN50_cider |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 40 |
|||
|
|||
verbose: false |
|||
precision: 32 |
@ -0,0 +1,61 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_fc_dir: data/cocotalk_clip_RN50_fc |
|||
input_att_dir: data/cocotalk_clip_RN50_att |
|||
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
|||
seq_per_img: 5 |
|||
batch_size: 160 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 40 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
use_clipscore: true |
|||
clipscore_reward_weight: 2.0 |
|||
clipscore_mode: clip_s |
|||
|
|||
use_multi_rewards: true |
@ -0,0 +1,58 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_fc_dir: data/cocotalk_clip_RN50_fc |
|||
input_att_dir: data/cocotalk_clip_RN50_att |
|||
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
|||
seq_per_img: 5 |
|||
batch_size: 160 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: save/clipRN50_clips/clipRN50_clips |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 40 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
use_clipscore: true |
|||
clipscore_reward_weight: 2.0 |
@ -0,0 +1,64 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_fc_dir: data/cocotalk_clip_RN50_fc |
|||
input_att_dir: data/cocotalk_clip_RN50_att |
|||
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
|||
seq_per_img: 5 |
|||
batch_size: 160 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar |
|||
|
|||
use_multi_rewards: true |
|||
use_grammar: true |
|||
use_grammar_baseline: true |
|||
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
|||
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt' |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
|||
|
|||
# _BASE_: transformer.yml |
|||
reduce_on_plateau: false |
|||
noamopt: false |
|||
learning_rate: 0.000005 |
|||
learning_rate_decay_start: -1 |
|||
|
|||
self_critical_after: 15 |
|||
max_epochs: 40 |
|||
|
|||
verbose: false |
|||
precision: 32 |
|||
|
|||
use_clipscore: true |
|||
clipscore_reward_weight: 2.0 |
@ -0,0 +1,41 @@ |
|||
caption_model: transformer |
|||
noamopt: true |
|||
noamopt_warmup: 20000 |
|||
label_smoothing: 0.0 |
|||
input_json: data/cocotalk.json |
|||
input_label_h5: data/cocotalk_label.h5 |
|||
input_att_dir: data/cocotalk_att |
|||
seq_per_img: 5 |
|||
batch_size: 10 |
|||
learning_rate: 0.0005 |
|||
|
|||
checkpoint_path: ./save/trans_rn50_sc |
|||
|
|||
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
|
|||
# will be ignored |
|||
num_layers: 6 |
|||
input_encoding_size: 512 |
|||
rnn_size: 2048 |
|||
|
|||
# Transformer config |
|||
N_enc: 6 |
|||
N_dec: 6 |
|||
d_model: 512 |
|||
d_ff: 2048 |
|||
num_att_heads: 8 |
|||
dropout: 0.1 |
|||
|
|||
|
|||
learning_rate_decay_start: 0 |
|||
scheduled_sampling_start: -1 |
|||
save_checkpoint_every: 3000 |
|||
language_eval: 1 |
|||
val_images_use: 5000 |
|||
max_epochs: 15 |
|||
train_sample_n: 5 |
|||
|
|||
REFORWARD: false |
File diff suppressed because one or more lines are too long
@ -0,0 +1,379 @@ |
|||
# This file contains Transformer network |
|||
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html |
|||
|
|||
# The cfg name correspondence: |
|||
# N=num_layers |
|||
# d_model=input_encoding_size |
|||
# d_ff=rnn_size |
|||
# h is always 8 |
|||
|
|||
from __future__ import absolute_import |
|||
from __future__ import division |
|||
from __future__ import print_function |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
#from . import utils |
|||
|
|||
import copy |
|||
import math |
|||
import numpy as np |
|||
|
|||
#from .CaptionModel import CaptionModel |
|||
#from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel |
|||
from captioning.models.CaptionModel import CaptionModel |
|||
from captioning.models.AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel |
|||
|
|||
def repeat_tensors(n, x): |
|||
""" |
|||
For a tensor of size Bx..., we repeat it n times, and make it Bnx... |
|||
For collections, do nested repeat |
|||
""" |
|||
if torch.is_tensor(x): |
|||
x = x.unsqueeze(1) # Bx1x... |
|||
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... |
|||
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... |
|||
elif type(x) is list or type(x) is tuple: |
|||
x = [repeat_tensors(n, _) for _ in x] |
|||
return x |
|||
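# e.g. a (B, L, D) tensor becomes (B*n, L, D) with each sample's copies kept adjacent
# (sample 0 repeated n times, then sample 1, ...), matching the seq_per_img layout
# of the caption labels.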
|
|||
|
|||
class EncoderDecoder(nn.Module): |
|||
""" |
|||
A standard Encoder-Decoder architecture. Base for this and many |
|||
other models. |
|||
""" |
|||
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): |
|||
super(EncoderDecoder, self).__init__() |
|||
self.encoder = encoder |
|||
self.decoder = decoder |
|||
self.src_embed = src_embed |
|||
self.tgt_embed = tgt_embed |
|||
self.generator = generator |
|||
|
|||
def forward(self, src, tgt, src_mask, tgt_mask): |
|||
"Take in and process masked src and target sequences." |
|||
return self.decode(self.encode(src, src_mask), src_mask, |
|||
tgt, tgt_mask) |
|||
|
|||
def encode(self, src, src_mask): |
|||
return self.encoder(self.src_embed(src), src_mask) |
|||
|
|||
def decode(self, memory, src_mask, tgt, tgt_mask): |
|||
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) |
|||
|
|||
class Generator(nn.Module): |
|||
"Define standard linear + softmax generation step." |
|||
def __init__(self, d_model, vocab): |
|||
super(Generator, self).__init__() |
|||
self.proj = nn.Linear(d_model, vocab) |
|||
|
|||
def forward(self, x): |
|||
return F.log_softmax(self.proj(x), dim=-1) |
|||
|
|||
def clones(module, N): |
|||
"Produce N identical layers." |
|||
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) |
|||
|
|||
class Encoder(nn.Module): |
|||
"Core encoder is a stack of N layers" |
|||
def __init__(self, layer, N): |
|||
super(Encoder, self).__init__() |
|||
self.layers = clones(layer, N) |
|||
self.norm = LayerNorm(layer.size) |
|||
|
|||
def forward(self, x, mask): |
|||
"Pass the input (and mask) through each layer in turn." |
|||
for layer in self.layers: |
|||
x = layer(x, mask) |
|||
return self.norm(x) |
|||
|
|||
class LayerNorm(nn.Module): |
|||
"Construct a layernorm module (See citation for details)." |
|||
def __init__(self, features, eps=1e-6): |
|||
super(LayerNorm, self).__init__() |
|||
self.a_2 = nn.Parameter(torch.ones(features)) |
|||
self.b_2 = nn.Parameter(torch.zeros(features)) |
|||
self.eps = eps |
|||
|
|||
def forward(self, x): |
|||
mean = x.mean(-1, keepdim=True) |
|||
std = x.std(-1, keepdim=True) |
|||
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 |
|||
|
|||
class SublayerConnection(nn.Module): |
|||
""" |
|||
A residual connection followed by a layer norm. |
|||
Note for code simplicity the norm is first as opposed to last. |
|||
""" |
|||
def __init__(self, size, dropout): |
|||
super(SublayerConnection, self).__init__() |
|||
self.norm = LayerNorm(size) |
|||
self.dropout = nn.Dropout(dropout) |
|||
|
|||
def forward(self, x, sublayer): |
|||
"Apply residual connection to any sublayer with the same size." |
|||
return x + self.dropout(sublayer(self.norm(x))) |
|||
|
|||
class EncoderLayer(nn.Module): |
|||
"Encoder is made up of self-attn and feed forward (defined below)" |
|||
def __init__(self, size, self_attn, feed_forward, dropout): |
|||
super(EncoderLayer, self).__init__() |
|||
self.self_attn = self_attn |
|||
self.feed_forward = feed_forward |
|||
self.sublayer = clones(SublayerConnection(size, dropout), 2) |
|||
self.size = size |
|||
|
|||
def forward(self, x, mask): |
|||
"Follow Figure 1 (left) for connections." |
|||
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) |
|||
return self.sublayer[1](x, self.feed_forward) |
|||
|
|||
class Decoder(nn.Module): |
|||
"Generic N layer decoder with masking." |
|||
def __init__(self, layer, N): |
|||
super(Decoder, self).__init__() |
|||
self.layers = clones(layer, N) |
|||
self.norm = LayerNorm(layer.size) |
|||
|
|||
def forward(self, x, memory, src_mask, tgt_mask): |
|||
for layer in self.layers: |
|||
x = layer(x, memory, src_mask, tgt_mask) |
|||
return self.norm(x) |
|||
|
|||
class DecoderLayer(nn.Module): |
|||
"Decoder is made of self-attn, src-attn, and feed forward (defined below)" |
|||
def __init__(self, size, self_attn, src_attn, feed_forward, dropout): |
|||
super(DecoderLayer, self).__init__() |
|||
self.size = size |
|||
self.self_attn = self_attn |
|||
self.src_attn = src_attn |
|||
self.feed_forward = feed_forward |
|||
self.sublayer = clones(SublayerConnection(size, dropout), 3) |
|||
|
|||
def forward(self, x, memory, src_mask, tgt_mask): |
|||
"Follow Figure 1 (right) for connections." |
|||
m = memory |
|||
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) |
|||
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) |
|||
return self.sublayer[2](x, self.feed_forward) |
|||
|
|||
def subsequent_mask(size): |
|||
"Mask out subsequent positions." |
|||
attn_shape = (1, size, size) |
|||
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') |
|||
return torch.from_numpy(subsequent_mask) == 0 |
|||
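# subsequent_mask(3) ->
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])
# i.e. position i may attend only to positions <= i.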
|
|||
def attention(query, key, value, mask=None, dropout=None): |
|||
"Compute 'Scaled Dot Product Attention'" |
|||
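# Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V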
d_k = query.size(-1) |
|||
scores = torch.matmul(query, key.transpose(-2, -1)) \ |
|||
/ math.sqrt(d_k) |
|||
if mask is not None: |
|||
scores = scores.masked_fill(mask == 0, float('-inf')) |
|||
p_attn = F.softmax(scores, dim = -1) |
|||
if dropout is not None: |
|||
p_attn = dropout(p_attn) |
|||
return torch.matmul(p_attn, value), p_attn |
|||
|
|||
class MultiHeadedAttention(nn.Module): |
|||
def __init__(self, h, d_model, dropout=0.1): |
|||
"Take in model size and number of heads." |
|||
super(MultiHeadedAttention, self).__init__() |
|||
assert d_model % h == 0 |
|||
# We assume d_v always equals d_k |
|||
self.d_k = d_model // h |
|||
self.h = h |
|||
self.linears = clones(nn.Linear(d_model, d_model), 4) |
|||
self.attn = None |
|||
self.dropout = nn.Dropout(p=dropout) |
|||
|
|||
def forward(self, query, key, value, mask=None): |
|||
"Implements Figure 2" |
|||
if mask is not None: |
|||
# Same mask applied to all h heads. |
|||
mask = mask.unsqueeze(1) |
|||
nbatches = query.size(0) |
|||
|
|||
# 1) Do all the linear projections in batch from d_model => h x d_k |
|||
query, key, value = \ |
|||
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) |
|||
for l, x in zip(self.linears, (query, key, value))] |
|||
|
|||
# 2) Apply attention on all the projected vectors in batch. |
|||
x, self.attn = attention(query, key, value, mask=mask, |
|||
dropout=self.dropout) |
|||
|
|||
# 3) "Concat" using a view and apply a final linear. |
|||
x = x.transpose(1, 2).contiguous() \ |
|||
.view(nbatches, -1, self.h * self.d_k) |
|||
return self.linears[-1](x) |
|||
|
|||
class PositionwiseFeedForward(nn.Module): |
|||
"Implements FFN equation." |
|||
def __init__(self, d_model, d_ff, dropout=0.1): |
|||
super(PositionwiseFeedForward, self).__init__() |
|||
self.w_1 = nn.Linear(d_model, d_ff) |
|||
self.w_2 = nn.Linear(d_ff, d_model) |
|||
self.dropout = nn.Dropout(dropout) |
|||
|
|||
def forward(self, x): |
|||
return self.w_2(self.dropout(F.relu(self.w_1(x)))) |
|||
|
|||
class Embeddings(nn.Module): |
|||
def __init__(self, d_model, vocab): |
|||
super(Embeddings, self).__init__() |
|||
self.lut = nn.Embedding(vocab, d_model) |
|||
self.d_model = d_model |
|||
|
|||
def forward(self, x): |
|||
return self.lut(x) * math.sqrt(self.d_model) |
|||
|
|||
class PositionalEncoding(nn.Module): |
|||
"Implement the PE function." |
|||
def __init__(self, d_model, dropout, max_len=5000): |
|||
super(PositionalEncoding, self).__init__() |
|||
self.dropout = nn.Dropout(p=dropout) |
|||
|
|||
# Compute the positional encodings once in log space. |
|||
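# PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# div_term below equals 1 / 10000^(2i / d_model), computed in log space for stability.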
pe = torch.zeros(max_len, d_model) |
|||
position = torch.arange(0, max_len).unsqueeze(1).float() |
|||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * |
|||
-(math.log(10000.0) / d_model)) |
|||
pe[:, 0::2] = torch.sin(position * div_term) |
|||
pe[:, 1::2] = torch.cos(position * div_term) |
|||
pe = pe.unsqueeze(0) |
|||
self.register_buffer('pe', pe) |
|||
|
|||
def forward(self, x): |
|||
x = x + self.pe[:, :x.size(1)] |
|||
return self.dropout(x) |
|||
|
|||
class TransformerModel(AttModel): |
|||
|
|||
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, |
|||
d_model=512, d_ff=2048, h=8, dropout=0.1): |
|||
"Helper: Construct a model from hyperparameters." |
|||
c = copy.deepcopy |
|||
attn = MultiHeadedAttention(h, d_model, dropout) |
|||
ff = PositionwiseFeedForward(d_model, d_ff, dropout) |
|||
position = PositionalEncoding(d_model, dropout) |
|||
model = EncoderDecoder( |
|||
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc), |
|||
Decoder(DecoderLayer(d_model, c(attn), c(attn), |
|||
c(ff), dropout), N_dec), |
|||
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), |
|||
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), |
|||
Generator(d_model, tgt_vocab)) |
|||
|
|||
# This was important from their code. |
|||
# Initialize parameters with Glorot / fan_avg. |
|||
for p in model.parameters(): |
|||
if p.dim() > 1: |
|||
nn.init.xavier_uniform_(p) |
|||
return model |
|||
|
|||
def __init__(self, opt): |
|||
super(TransformerModel, self).__init__(opt) |
|||
self.opt = opt |
|||
# self.config = yaml.load(open(opt.config_file)) |
|||
|
|||
self.N_enc = getattr(opt, 'N_enc', opt.num_layers) |
|||
self.N_dec = getattr(opt, 'N_dec', opt.num_layers) |
|||
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) |
|||
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) |
|||
self.h = getattr(opt, 'num_att_heads', 8) |
|||
self.dropout = getattr(opt, 'dropout', 0.1) |
|||
|
|||
delattr(self, 'att_embed') |
|||
self.att_embed = nn.Sequential(*( |
|||
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ |
|||
(nn.Linear(self.att_feat_size, self.d_model), |
|||
nn.ReLU(), |
|||
nn.Dropout(self.drop_prob_lm))+ |
|||
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) |
|||
|
|||
delattr(self, 'embed') |
|||
self.embed = lambda x : x |
|||
delattr(self, 'fc_embed') |
|||
self.fc_embed = lambda x : x |
|||
delattr(self, 'logit') |
|||
del self.ctx2att |
|||
|
|||
tgt_vocab = self.vocab_size + 1 |
|||
|
|||
|
|||
self.model = self.make_model(0, tgt_vocab, |
|||
N_enc=self.N_enc, |
|||
N_dec=self.N_dec, |
|||
d_model=self.d_model, |
|||
d_ff=self.d_ff, |
|||
h=self.h, |
|||
dropout=self.dropout) |
|||
|
|||
def logit(self, x): # unsafe way |
|||
return self.model.generator.proj(x) |
|||
|
|||
def init_hidden(self, bsz): |
|||
return [] |
|||
|
|||
def _prepare_feature(self, fc_feats, att_feats, att_masks): |
|||
|
|||
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) |
|||
memory = self.model.encode(att_feats, att_masks) |
|||
|
|||
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks |
|||
|
|||
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): |
|||
att_feats, att_masks = self.clip_att(att_feats, att_masks) |
|||
|
|||
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) |
|||
|
|||
if att_masks is None: |
|||
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) |
|||
att_masks = att_masks.unsqueeze(-2) |
|||
|
|||
if seq is not None: |
|||
# crop the last one |
|||
# seq = seq[:,:-1] |
|||
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) |
|||
seq_mask[:,0] = 1 # bos |
|||
|
|||
seq_mask = seq_mask.unsqueeze(-2) |
|||
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) |
|||
|
|||
seq_per_img = seq.shape[0] // att_feats.shape[0] |
|||
if seq_per_img > 1: |
|||
att_feats, att_masks = repeat_tensors(seq_per_img, |
|||
[att_feats, att_masks] |
|||
) |
|||
else: |
|||
seq_mask = None |
|||
|
|||
return att_feats, seq, att_masks, seq_mask |
|||
|
|||
def _forward(self, fc_feats, att_feats, seq, att_masks=None): |
|||
if seq.ndim == 3: # B * seq_per_img * seq_len |
|||
seq = seq.reshape(-1, seq.shape[2]) |
|||
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) |
|||
|
|||
out = self.model(att_feats, seq, att_masks, seq_mask) |
|||
|
|||
outputs = self.model.generator(out) |
|||
return outputs |
|||
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1) |
|||
|
|||
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): |
|||
""" |
|||
state = [ys.unsqueeze(0)] |
|||
""" |
|||
if len(state) == 0: |
|||
ys = it.unsqueeze(1) |
|||
else: |
|||
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) |
|||
out = self.model.decode(memory, mask, |
|||
ys, |
|||
subsequent_mask(ys.size(1)) |
|||
.to(memory.device)) |
|||
return out[:, -1], [ys.unsqueeze(0)] |
@ -0,0 +1,153 @@ |
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. |
|||
# Copy from fvcore |
|||
|
|||
import logging |
|||
import os |
|||
from typing import Any |
|||
import yaml |
|||
from yacs.config import CfgNode as _CfgNode |
|||
|
|||
import io as PathManager |
|||
|
|||
BASE_KEY = "_BASE_" |
|||
|
|||
|
|||
class CfgNode(_CfgNode): |
|||
""" |
|||
Our own extended version of :class:`yacs.config.CfgNode`. |
|||
It contains the following extra features: |
|||
|
|||
1. The :meth:`merge_from_file` method supports the "_BASE_" key, |
|||
which allows the new CfgNode to inherit all the attributes from the |
|||
base configuration file. |
|||
2. Keys that start with "COMPUTED_" are treated as insertion-only |
|||
"computed" attributes. They can be inserted regardless of whether |
|||
the CfgNode is frozen or not. |
|||
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate |
|||
expressions in config. See examples in |
|||
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types |
|||
Note that this may lead to arbitrary code execution: you must not |
|||
load a config file from untrusted sources before manually inspecting |
|||
the content of the file. |
|||
""" |
|||
|
|||
@staticmethod |
|||
def load_yaml_with_base(filename, allow_unsafe = False): |
|||
""" |
|||
Just like `yaml.load(open(filename))`, but inherit attributes from its |
|||
`_BASE_`. |
|||
|
|||
Args: |
|||
filename (str): the file name of the current config. Will be used to |
|||
find the base config file. |
|||
allow_unsafe (bool): whether to allow loading the config file with |
|||
`yaml.unsafe_load`. |
|||
|
|||
Returns: |
|||
(dict): the loaded yaml |
|||
""" |
|||
with PathManager.open(filename, "r") as f: |
|||
try: |
|||
cfg = yaml.safe_load(f) |
|||
except yaml.constructor.ConstructorError: |
|||
if not allow_unsafe: |
|||
raise |
|||
logger = logging.getLogger(__name__) |
|||
logger.warning( |
|||
"Loading config {} with yaml.unsafe_load. Your machine may " |
|||
"be at risk if the file contains malicious content.".format( |
|||
filename |
|||
) |
|||
) |
|||
f.close() |
|||
with open(filename, "r") as f: |
|||
cfg = yaml.unsafe_load(f) |
|||
|
|||
def merge_a_into_b(a, b): |
|||
# merge dict a into dict b. values in a will overwrite b. |
|||
for k, v in a.items(): |
|||
if isinstance(v, dict) and k in b: |
|||
assert isinstance( |
|||
b[k], dict |
|||
), "Cannot inherit key '{}' from base!".format(k) |
|||
merge_a_into_b(v, b[k]) |
|||
else: |
|||
b[k] = v |
|||
|
|||
if BASE_KEY in cfg: |
|||
base_cfg_file = cfg[BASE_KEY] |
|||
if base_cfg_file.startswith("~"): |
|||
base_cfg_file = os.path.expanduser(base_cfg_file) |
|||
if not any( |
|||
map(base_cfg_file.startswith, ["/", "https://", "http://"]) |
|||
): |
|||
# the path to base cfg is relative to the config file itself. |
|||
base_cfg_file = os.path.join( |
|||
os.path.dirname(filename), base_cfg_file |
|||
) |
|||
base_cfg = CfgNode.load_yaml_with_base( |
|||
base_cfg_file, allow_unsafe=allow_unsafe |
|||
) |
|||
del cfg[BASE_KEY] |
|||
|
|||
merge_a_into_b(cfg, base_cfg) |
|||
return base_cfg |
|||
return cfg |
|||
|
|||
def merge_from_file(self, cfg_filename, allow_unsafe = False): |
|||
""" |
|||
Merge configs from a given yaml file. |
|||
|
|||
Args: |
|||
cfg_filename: the file name of the yaml config. |
|||
allow_unsafe: whether to allow loading the config file with |
|||
`yaml.unsafe_load`. |
|||
""" |
|||
loaded_cfg = CfgNode.load_yaml_with_base( |
|||
cfg_filename, allow_unsafe=allow_unsafe |
|||
) |
|||
loaded_cfg = type(self)(loaded_cfg) |
|||
self.merge_from_other_cfg(loaded_cfg) |
|||
|
|||
# Forward the following calls to base, but with a check on the BASE_KEY. |
|||
def merge_from_other_cfg(self, cfg_other): |
|||
""" |
|||
Args: |
|||
cfg_other (CfgNode): configs to merge from. |
|||
""" |
|||
assert ( |
|||
BASE_KEY not in cfg_other |
|||
), "The reserved key '{}' can only be used in files!".format(BASE_KEY) |
|||
return super().merge_from_other_cfg(cfg_other) |
|||
|
|||
def merge_from_list(self, cfg_list): |
|||
""" |
|||
Args: |
|||
cfg_list (list): list of configs to merge from. |
|||
""" |
|||
keys = set(cfg_list[0::2]) |
|||
assert ( |
|||
BASE_KEY not in keys |
|||
), "The reserved key '{}' can only be used in files!".format(BASE_KEY) |
|||
return super().merge_from_list(cfg_list) |
|||
|
|||
def __setattr__(self, name, val): |
|||
if name.startswith("COMPUTED_"): |
|||
if name in self: |
|||
old_val = self[name] |
|||
if old_val == val: |
|||
return |
|||
raise KeyError( |
|||
"Computed attributed '{}' already exists " |
|||
"with a different value! old={}, new={}.".format( |
|||
name, old_val, val |
|||
) |
|||
) |
|||
self[name] = val |
|||
else: |
|||
super().__setattr__(name, val) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') |
|||
print(cfg) |
@ -0,0 +1,412 @@ |
|||
from __future__ import print_function |
|||
import argparse |
|||
|
|||
|
|||
def if_use_feat(caption_model): |
|||
# Decide whether to load fc / attention features according to the caption model |
|||
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']: |
|||
use_att, use_fc = False, True |
|||
elif caption_model == 'language_model': |
|||
use_att, use_fc = False, False |
|||
elif caption_model in ['updown', 'topdown']: |
|||
use_fc, use_att = True, True |
|||
else: |
|||
use_att, use_fc = True, False |
|||
return use_fc, use_att |
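# Illustrative examples (not in the original): |
#     if_use_feat('fc')             -> (True, False)   # fc feats only |
#     if_use_feat('updown')         -> (True, True)    # fc + att feats |
#     if_use_feat('language_model') -> (False, False)  # no image feats |
#     if_use_feat('transformer')    -> (False, True)   # att feats only |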
|||
|
|||
import pprint |
|||
class Config(object): |
|||
def __init__(self, **kwargs): |
|||
"""Configuration Class: set kwargs as class attributes with setattr""" |
|||
for k, v in kwargs.items(): |
|||
setattr(self, k, v) |
|||
|
|||
@property |
|||
def config_str(self): |
|||
return pprint.pformat(self.__dict__) |
|||
|
|||
def __repr__(self): |
|||
"""Pretty-print configurations in alphabetical order""" |
|||
config_str = 'Configurations\n' |
|||
config_str += self.config_str |
|||
return config_str |
|||
|
|||
|
|||
def parse_opt(parse=True, **optional_kwargs): |
|||
parser = argparse.ArgumentParser() |
|||
# Data input settings |
|||
parser.add_argument('--input_json', type=str, default='data/coco.json', |
|||
help='path to the json file containing additional info and vocab') |
|||
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', |
|||
help='path to the directory containing the preprocessed fc feats') |
|||
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', |
|||
help='path to the directory containing the preprocessed att feats') |
|||
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', |
|||
help='path to the directory containing the boxes of att feats') |
|||
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', |
|||
help='path to the h5file containing the preprocessed dataset') |
|||
parser.add_argument('--data_in_memory', action='store_true', |
|||
help='True if we want to save the features in memory') |
|||
parser.add_argument('--start_from', type=str, default=None, |
|||
help="""continue training from saved model at this path. Path must contain files saved by previous training process: |
|||
'infos.pkl' : configuration; |
|||
'model.pth' : weights |
|||
""") |
|||
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', |
|||
help='Cached token file for calculating cider score during self critical training.') |
|||
|
|||
# Model settings |
|||
parser.add_argument('--caption_model', type=str, default="show_tell", |
|||
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer') |
|||
parser.add_argument('--rnn_size', type=int, default=512, |
|||
help='size of the rnn in number of hidden nodes in each layer') |
|||
parser.add_argument('--num_layers', type=int, default=1, |
|||
help='number of layers in the RNN') |
|||
parser.add_argument('--rnn_type', type=str, default='lstm', |
|||
help='rnn, gru, or lstm') |
|||
parser.add_argument('--input_encoding_size', type=int, default=512, |
|||
help='the encoding size of each token in the vocabulary, and the image.') |
|||
parser.add_argument('--att_hid_size', type=int, default=512, |
|||
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') |
|||
parser.add_argument('--fc_feat_size', type=int, default=2048, |
|||
help='2048 for resnet, 4096 for vgg') |
|||
parser.add_argument('--att_feat_size', type=int, default=2048, |
|||
help='2048 for resnet, 512 for vgg') |
|||
parser.add_argument('--logit_layers', type=int, default=1, |
|||
help='number of layers in the output logit layer') |
|||
|
|||
|
|||
parser.add_argument('--use_bn', type=int, default=0, |
|||
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') |
|||
|
|||
# feature manipulation |
|||
parser.add_argument('--norm_att_feat', type=int, default=0, |
|||
help='If normalize attention features') |
|||
parser.add_argument('--use_box', type=int, default=0, |
|||
help='If use box features') |
|||
parser.add_argument('--norm_box_feat', type=int, default=0, |
|||
help='If use box, do we normalize box feature') |
|||
|
|||
# Optimization: General |
|||
parser.add_argument('--max_epochs', type=int, default=-1, |
|||
help='number of epochs') |
|||
parser.add_argument('--batch_size', type=int, default=16, |
|||
help='minibatch size') |
|||
parser.add_argument('--grad_clip_mode', type=str, default='value', |
|||
help='value or norm') |
|||
parser.add_argument('--grad_clip_value', type=float, default=0.1, |
|||
help='clip gradients at this value/max_norm, 0 means no clipping') |
|||
parser.add_argument('--drop_prob_lm', type=float, default=0.5, |
|||
help='strength of dropout in the Language Model RNN') |
|||
parser.add_argument('--self_critical_after', type=int, default=-1, |
|||
help='After what epoch do we start self-critical training? (-1 = disable; 0 = from the start)') |
|||
parser.add_argument('--seq_per_img', type=int, default=5, |
|||
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') |
|||
|
|||
parser.add_argument('--verbose', type=int, default=0) |
|||
|
|||
# Sample related |
|||
add_eval_sample_opts(parser) |
|||
|
|||
#Optimization: for the Language Model |
|||
parser.add_argument('--optim', type=str, default='adam', |
|||
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw') |
|||
parser.add_argument('--learning_rate', type=float, default=4e-4, |
|||
help='learning rate') |
|||
parser.add_argument('--learning_rate_decay_start', type=int, default=-1, |
|||
help='at what epoch to start decaying the learning rate? (-1 = do not decay)') |
|||
parser.add_argument('--learning_rate_decay_every', type=int, default=3, |
|||
help='decay the learning rate every this many epochs thereafter') |
|||
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, |
|||
help='multiply the learning rate by this factor at every decay step') |
|||
parser.add_argument('--optim_alpha', type=float, default=0.9, |
|||
help='alpha for adam') |
|||
parser.add_argument('--optim_beta', type=float, default=0.999, |
|||
help='beta used for adam') |
|||
parser.add_argument('--optim_epsilon', type=float, default=1e-8, |
|||
help='epsilon that goes into denominator for smoothing') |
|||
parser.add_argument('--weight_decay', type=float, default=0, |
|||
help='weight_decay') |
|||
# Transformer |
|||
parser.add_argument('--label_smoothing', type=float, default=0, |
|||
help='') |
|||
parser.add_argument('--noamopt', action='store_true', |
|||
help='') |
|||
parser.add_argument('--noamopt_warmup', type=int, default=2000, |
|||
help='') |
|||
parser.add_argument('--noamopt_factor', type=float, default=1, |
|||
help='') |
|||
parser.add_argument('--reduce_on_plateau', action='store_true', |
|||
help='') |
|||
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, |
|||
help='') |
|||
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, |
|||
help='') |
|||
parser.add_argument('--cached_transformer', action='store_true', |
|||
help='') |
|||
|
|||
|
|||
parser.add_argument('--use_warmup', action='store_true', |
|||
help='warm up the learning rate?') |
|||
|
|||
parser.add_argument('--scheduled_sampling_start', type=int, default=-1, |
|||
help='at what iteration to start decaying the ground-truth probability') |
|||
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, |
|||
help='every how many iterations thereafter to increase the scheduled sampling probability') |
|||
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, |
|||
help='How much to update the prob') |
|||
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, |
|||
help='Maximum scheduled sampling prob.') |
|||
|
|||
|
|||
# Evaluation/Checkpointing |
|||
parser.add_argument('--val_images_use', type=int, default=3200, |
|||
help='how many images to use when periodically evaluating the validation loss? (-1 = all)') |
|||
parser.add_argument('--save_checkpoint_every', type=int, default=2500, |
|||
help='how often to save a model checkpoint (in iterations)?') |
|||
parser.add_argument('--save_every_epoch', action='store_true', |
|||
help='Save checkpoint every epoch, will overwrite save_checkpoint_every') |
|||
parser.add_argument('--save_history_ckpt', type=int, default=0, |
|||
help='If 1, keep a separate historical checkpoint at every save point') |
|||
parser.add_argument('--checkpoint_path', type=str, default=None, |
|||
help='directory to store checkpointed models') |
|||
parser.add_argument('--language_eval', type=int, default=0, |
|||
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') |
|||
parser.add_argument('--losses_log_every', type=int, default=25, |
|||
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') |
|||
parser.add_argument('--load_best_score', type=int, default=1, |
|||
help='Do we load previous best score when resuming training.') |
|||
|
|||
# misc |
|||
parser.add_argument('--id', type=str, default='', |
|||
help='an id identifying this run/job. used in cross-val and appended when writing progress files') |
|||
parser.add_argument('--train_only', type=int, default=0, |
|||
help='if 1, train only on the ~80k train split; else also use restval (~110k images)') |
|||
|
|||
|
|||
# Reward |
|||
parser.add_argument('--cider_reward_weight', type=float, default=1, |
|||
help='The reward weight from cider') |
|||
parser.add_argument('--bleu_reward_weight', type=float, default=0, |
|||
help='The reward weight from bleu4') |
|||
|
|||
# Reward |
|||
parser.add_argument('--clipscore_reward_weight', type=float, default=1, |
|||
help='The reward weight from clipscore') |
|||
parser.add_argument('--use_clipscore', type=float, default=0, |
|||
help='Use CLIPScore') |
|||
parser.add_argument('--clipscore_mode', type=str, default='clip_s', |
|||
help='Which CLIPScore to use: clip_s|refclip_s') |
|||
|
|||
|
|||
# Structure_loss |
|||
parser.add_argument('--structure_loss_weight', type=float, default=1, |
|||
help='') |
|||
parser.add_argument('--structure_after', type=int, default=-1, |
|||
help='After what epoch do we start using the structure loss? (-1 = disable)') |
|||
parser.add_argument('--structure_loss_type', type=str, default='seqnll', |
|||
help='') |
|||
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='') |
|||
parser.add_argument('--entropy_reward_weight', type=float, default=0, |
|||
help='Weight of the entropy reward term') |
|||
parser.add_argument('--self_cider_reward_weight', type=float, default=0, |
|||
help='self cider reward') |
|||
|
|||
# Used for self critical or structure. Used when sampling is need during training |
|||
parser.add_argument('--train_sample_n', type=int, default=16, |
|||
help='Number of samples per image when sampling is needed during training') |
|||
parser.add_argument('--train_sample_method', type=str, default='sample', |
|||
help='') |
|||
parser.add_argument('--train_beam_size', type=int, default=1, |
|||
help='') |
|||
|
|||
# Used for self critical |
|||
parser.add_argument('--sc_sample_method', type=str, default='greedy', |
|||
help='') |
|||
parser.add_argument('--sc_beam_size', type=int, default=1, |
|||
help='') |
|||
|
|||
|
|||
# For diversity evaluation during training |
|||
add_diversity_opts(parser) |
|||
|
|||
|
|||
# config |
|||
parser.add_argument('--cfg', type=str, default=None, |
|||
help='configuration; similar to what is used in detectron') |
|||
parser.add_argument( |
|||
'--set_cfgs', dest='set_cfgs', |
|||
help='Set config keys. Key-value sequence separated by whitespace, ' |
|||
'e.g. [key] [value] [key] [value].\n This has higher priority ' |
|||
'than the cfg file but lower than other args. (You can only overwrite ' |
|||
'arguments that have already been defined in the config file.)', |
|||
default=[], nargs='+') |
|||
# How will config be used |
|||
# 1) read cfg argument, and load the cfg file if it's not None |
|||
# 2) Overwrite cfg argument with set_cfgs |
|||
# 3) parse config argument to args. |
|||
# 4) in the end, parse command line argument and overwrite args |
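# Illustrative precedence sketch (hypothetical script name and values; assumes |
# learning_rate is defined in the yaml so merge_from_list can overwrite it): |
#     python train.py --cfg configs/updown_long.yml \ |
#                     --set_cfgs learning_rate 1e-4 --batch_size 32 |
# loads the yaml first (step 1), set_cfgs then overwrites learning_rate |
# (step 2), and the explicit --batch_size flag finally wins over both |
# (steps 3-4). |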
|||
|
|||
# step 1: read cfg_fn |
|||
# args = parser.parse_args() |
|||
# Parse the arguments. |
|||
if parse: |
|||
args = parser.parse_args() |
|||
# For interactive environments (e.g. Jupyter) |
|||
else: |
|||
args = parser.parse_known_args()[0] |
|||
# print(args) |
|||
|
|||
# Namespace => Dictionary |
|||
kwargs = vars(args) |
|||
# for k, v in optional_kwargs.items(): |
|||
# setattr(args, k, v) |
|||
kwargs.update(optional_kwargs) |
|||
|
|||
args = Config(**kwargs) |
|||
|
|||
|
|||
if args.cfg is not None or args.set_cfgs is not None: |
|||
from .config import CfgNode |
|||
if args.cfg is not None: |
|||
# print('Read Cfg') |
|||
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg)) |
|||
# print(cn) |
|||
else: |
|||
cn = CfgNode() |
|||
if args.set_cfgs is not None: |
|||
cn.merge_from_list(args.set_cfgs) |
|||
for k,v in cn.items(): |
|||
if not hasattr(args, k): |
|||
import os |
|||
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': |
|||
pass |
|||
else: |
|||
print('Warning: key %s not in args' % k) |
|||
|
|||
setattr(args, k, v) |
|||
|
|||
if parse: |
|||
args = parser.parse_args(namespace=args) |
|||
else: |
|||
args = parser.parse_known_args(namespace=args)[0] |
|||
|
|||
# Check if args are valid |
|||
assert args.rnn_size > 0, "rnn_size should be greater than 0" |
|||
assert args.num_layers > 0, "num_layers should be greater than 0" |
|||
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" |
|||
assert args.batch_size > 0, "batch_size should be greater than 0" |
|||
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" |
|||
assert args.seq_per_img > 0, "seq_per_img should be greater than 0" |
|||
assert args.beam_size > 0, "beam_size should be greater than 0" |
|||
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" |
|||
assert args.losses_log_every > 0, "losses_log_every should be greater than 0" |
|||
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" |
|||
assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1" |
|||
assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1" |
|||
|
|||
# default value for start_from and checkpoint_path |
|||
args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id |
|||
args.start_from = args.start_from or args.checkpoint_path |
|||
|
|||
# Deal with feature things before anything |
|||
args.use_fc, args.use_att = if_use_feat(args.caption_model) |
|||
if args.use_box: args.att_feat_size = args.att_feat_size + 5 |
|||
|
|||
return args |
|||
|
|||
|
|||
def add_eval_options(parser): |
|||
# Basic options |
|||
parser.add_argument('--batch_size', type=int, default=0, |
|||
help='if > 0 then overrule, otherwise load from checkpoint.') |
|||
parser.add_argument('--num_images', type=int, default=-1, |
|||
help='how many images to use when periodically evaluating the loss? (-1 = all)') |
|||
parser.add_argument('--language_eval', type=int, default=0, |
|||
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') |
|||
parser.add_argument('--dump_images', type=int, default=1, |
|||
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') |
|||
parser.add_argument('--dump_json', type=int, default=1, |
|||
help='Dump json with predictions into vis folder? (1=yes,0=no)') |
|||
parser.add_argument('--dump_path', type=int, default=0, |
|||
help='Write image paths along with predictions into vis json? (1=yes,0=no)') |
|||
|
|||
# Sampling options |
|||
add_eval_sample_opts(parser) |
|||
|
|||
# For evaluation on a folder of images: |
|||
parser.add_argument('--image_folder', type=str, default='', |
|||
help='If this is nonempty then will predict on the images in this folder path') |
|||
parser.add_argument('--image_root', type=str, default='', |
|||
help='In case the image paths have to be prepended with a root path to an image folder') |
|||
# For evaluation on MSCOCO images from some split: |
|||
parser.add_argument('--input_fc_dir', type=str, default='', |
|||
help='path to the directory containing the preprocessed fc feats') |
|||
parser.add_argument('--input_att_dir', type=str, default='', |
|||
help='path to the directory containing the preprocessed att feats') |
|||
parser.add_argument('--input_box_dir', type=str, default='', |
|||
help='path to the directory containing the boxes of att feats') |
|||
parser.add_argument('--input_label_h5', type=str, default='', |
|||
help='path to the h5file containing the preprocessed dataset') |
|||
parser.add_argument('--input_json', type=str, default='', |
|||
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') |
|||
parser.add_argument('--split', type=str, default='test', |
|||
help='if running on MSCOCO images, which split to use: val|test|train') |
|||
parser.add_argument('--coco_json', type=str, default='', |
|||
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') |
|||
# misc |
|||
parser.add_argument('--id', type=str, default='', |
|||
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') |
|||
parser.add_argument('--verbose_beam', type=int, default=1, |
|||
help='if we need to print out all beam search beams.') |
|||
parser.add_argument('--verbose_loss', type=int, default=0, |
|||
help='If calculate loss using ground truth during evaluation') |
|||
|
|||
def add_diversity_opts(parser): |
|||
parser.add_argument('--sample_n', type=int, default=1, |
|||
help='Diverse sampling') |
|||
parser.add_argument('--sample_n_method', type=str, default='sample', |
|||
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp') |
|||
parser.add_argument('--eval_oracle', type=int, default=1, |
|||
help='If 1, evaluate the oracle score over the sampled captions') |
|||
|
|||
|
|||
# Sampling related options |
|||
def add_eval_sample_opts(parser): |
|||
parser.add_argument('--sample_method', type=str, default='greedy', |
|||
help='greedy; sample; gumbel; top<int>, top<0-1>') |
|||
parser.add_argument('--beam_size', type=int, default=1, |
|||
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') |
|||
parser.add_argument('--max_length', type=int, default=20, |
|||
help='Maximum length during sampling') |
|||
parser.add_argument('--length_penalty', type=str, default='', |
|||
help='wu_X or avg_X, X is the alpha') |
|||
parser.add_argument('--group_size', type=int, default=1, |
|||
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') |
|||
parser.add_argument('--diversity_lambda', type=float, default=0.5, |
|||
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') |
|||
parser.add_argument('--temperature', type=float, default=1.0, |
|||
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.') |
|||
parser.add_argument('--decoding_constraint', type=int, default=0, |
|||
help='If 1, not allowing same word in a row') |
|||
parser.add_argument('--block_trigrams', type=int, default=0, |
|||
help='block repeated trigram.') |
|||
parser.add_argument('--remove_bad_endings', type=int, default=0, |
|||
help='Remove bad endings') |
|||
parser.add_argument('--suppress_UNK', type=int, default=1, |
|||
help='Not predicting UNK') |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
import sys |
|||
sys.argv = [sys.argv[0]] |
|||
args = parse_opt() |
|||
print(args) |
|||
print() |
|||
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml'] |
|||
args1 = parse_opt() |
|||
print(dict(set(vars(args1).items()) - set(vars(args).items()))) |
|||
print() |
|||
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2'] |
|||
args2 = parse_opt() |
|||
print(dict(set(vars(args2).items()) - set(vars(args1).items()))) |