clip-caption-reward
copied
wxywb
2 years ago
30 changed files with 3491 additions and 0 deletions
@ -0,0 +1,18 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from clip_caption_reward import ClipCaptionReward |
||||
|
|
||||
|
def clip_caption_reward(model_name: str): |
||||
|
return ClipCaptionReward(model_name) |
Binary file not shown.
@ -0,0 +1,464 @@ |
|||||
|
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model |
||||
|
|
||||
|
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning |
||||
|
# https://arxiv.org/abs/1612.01887 |
||||
|
# AdaAttMO is a modified version with maxout lstm |
||||
|
|
||||
|
# Att2in is from Self-critical Sequence Training for Image Captioning |
||||
|
# https://arxiv.org/abs/1612.00563 |
||||
|
# In this file we only have Att2in2, which is a slightly different version of att2in, |
||||
|
# in which the img feature embedding and word embedding is the same as what in adaatt. |
||||
|
|
||||
|
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA |
||||
|
# https://arxiv.org/abs/1707.07998 |
||||
|
# However, it may not be identical to the author's architecture. |
||||
|
|
||||
|
from __future__ import absolute_import |
||||
|
from __future__ import division |
||||
|
from __future__ import print_function |
||||
|
|
||||
|
import numpy as np |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
#from . import utils |
||||
|
#utils.repeat_tensors |
||||
|
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence |
||||
|
|
||||
|
from .CaptionModel import CaptionModel |
||||
|
|
||||
|
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] |
||||
|
bad_endings += ['the'] |
||||
|
|
||||
|
|
||||
|
def repeat_tensors(n, x): |
||||
|
""" |
||||
|
For a tensor of size Bx..., we repeat it n times, and make it Bnx... |
||||
|
For collections, do nested repeat |
||||
|
""" |
||||
|
if torch.is_tensor(x): |
||||
|
x = x.unsqueeze(1) # Bx1x... |
||||
|
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... |
||||
|
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... |
||||
|
elif type(x) is list or type(x) is tuple: |
||||
|
x = [repeat_tensors(n, _) for _ in x] |
||||
|
return x |
||||
|
|
||||
|
def sort_pack_padded_sequence(input, lengths): |
||||
|
sorted_lengths, indices = torch.sort(lengths, descending=True) |
||||
|
# tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) |
||||
|
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True) |
||||
|
inv_ix = indices.clone() |
||||
|
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) |
||||
|
return tmp, inv_ix |
||||
|
|
||||
|
def pad_unsort_packed_sequence(input, inv_ix): |
||||
|
tmp, _ = pad_packed_sequence(input, batch_first=True) |
||||
|
tmp = tmp[inv_ix] |
||||
|
return tmp |
||||
|
|
||||
|
def pack_wrapper(module, att_feats, att_masks): |
||||
|
if att_masks is not None: |
||||
|
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) |
||||
|
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) |
||||
|
else: |
||||
|
return module(att_feats) |
||||
|
|
||||
|
class AttModel(CaptionModel): |
||||
|
def __init__(self, opt): |
||||
|
super(AttModel, self).__init__() |
||||
|
self.vocab_size = opt.vocab_size |
||||
|
self.input_encoding_size = opt.input_encoding_size |
||||
|
#self.rnn_type = opt.rnn_type |
||||
|
self.rnn_size = opt.rnn_size |
||||
|
self.num_layers = opt.num_layers |
||||
|
self.drop_prob_lm = opt.drop_prob_lm |
||||
|
self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length |
||||
|
self.fc_feat_size = opt.fc_feat_size |
||||
|
self.att_feat_size = opt.att_feat_size |
||||
|
self.att_hid_size = opt.att_hid_size |
||||
|
|
||||
|
self.bos_idx = getattr(opt, 'bos_idx', 0) |
||||
|
self.eos_idx = getattr(opt, 'eos_idx', 0) |
||||
|
self.pad_idx = getattr(opt, 'pad_idx', 0) |
||||
|
|
||||
|
self.use_bn = getattr(opt, 'use_bn', 0) |
||||
|
|
||||
|
self.ss_prob = 0.0 # Schedule sampling probability |
||||
|
|
||||
|
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), |
||||
|
nn.ReLU(), |
||||
|
nn.Dropout(self.drop_prob_lm)) |
||||
|
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), |
||||
|
nn.ReLU(), |
||||
|
nn.Dropout(self.drop_prob_lm)) |
||||
|
self.att_embed = nn.Sequential(*( |
||||
|
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ |
||||
|
(nn.Linear(self.att_feat_size, self.rnn_size), |
||||
|
nn.ReLU(), |
||||
|
nn.Dropout(self.drop_prob_lm))+ |
||||
|
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) |
||||
|
|
||||
|
self.logit_layers = getattr(opt, 'logit_layers', 1) |
||||
|
if self.logit_layers == 1: |
||||
|
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) |
||||
|
else: |
||||
|
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] |
||||
|
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) |
||||
|
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) |
||||
|
|
||||
|
# For remove bad endding |
||||
|
self.vocab = opt.vocab |
||||
|
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings] |
||||
|
|
||||
|
def init_hidden(self, bsz): |
||||
|
weight = self.logit.weight \ |
||||
|
if hasattr(self.logit, "weight") \ |
||||
|
else self.logit[0].weight |
||||
|
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), |
||||
|
weight.new_zeros(self.num_layers, bsz, self.rnn_size)) |
||||
|
|
||||
|
def clip_att(self, att_feats, att_masks): |
||||
|
# Clip the length of att_masks and att_feats to the maximum length |
||||
|
if att_masks is not None: |
||||
|
max_len = att_masks.data.long().sum(1).max() |
||||
|
att_feats = att_feats[:, :max_len].contiguous() |
||||
|
att_masks = att_masks[:, :max_len].contiguous() |
||||
|
return att_feats, att_masks |
||||
|
|
||||
|
def _prepare_feature(self, fc_feats, att_feats, att_masks): |
||||
|
att_feats, att_masks = self.clip_att(att_feats, att_masks) |
||||
|
|
||||
|
# embed fc and att feats |
||||
|
fc_feats = self.fc_embed(fc_feats) |
||||
|
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) |
||||
|
|
||||
|
# Project the attention feats first to reduce memory and computation comsumptions. |
||||
|
p_att_feats = self.ctx2att(att_feats) |
||||
|
|
||||
|
return fc_feats, att_feats, p_att_feats, att_masks |
||||
|
|
||||
|
def _forward(self, fc_feats, att_feats, seq, att_masks=None): |
||||
|
batch_size = fc_feats.size(0) |
||||
|
if seq.ndim == 3: # B * seq_per_img * seq_len |
||||
|
seq = seq.reshape(-1, seq.shape[2]) |
||||
|
seq_per_img = seq.shape[0] // batch_size |
||||
|
state = self.init_hidden(batch_size*seq_per_img) |
||||
|
|
||||
|
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1) |
||||
|
|
||||
|
# Prepare the features |
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
||||
|
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost |
||||
|
|
||||
|
if seq_per_img > 1: |
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(seq_per_img, |
||||
|
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
||||
|
) |
||||
|
|
||||
|
for i in range(seq.size(1)): |
||||
|
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample |
||||
|
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1) |
||||
|
sample_mask = sample_prob < self.ss_prob |
||||
|
if sample_mask.sum() == 0: |
||||
|
it = seq[:, i].clone() |
||||
|
else: |
||||
|
sample_ind = sample_mask.nonzero().view(-1) |
||||
|
it = seq[:, i].data.clone() |
||||
|
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) |
||||
|
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) |
||||
|
else: |
||||
|
it = seq[:, i].clone() |
||||
|
# break if all the sequences end |
||||
|
if i >= 1 and seq[:, i].sum() == 0: |
||||
|
break |
||||
|
|
||||
|
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) |
||||
|
outputs[:, i] = output |
||||
|
|
||||
|
return outputs |
||||
|
|
||||
|
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): |
||||
|
# 'it' contains a word index |
||||
|
xt = self.embed(it) |
||||
|
|
||||
|
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) |
||||
|
if output_logsoftmax: |
||||
|
logprobs = F.log_softmax(self.logit(output), dim=1) |
||||
|
else: |
||||
|
logprobs = self.logit(output) |
||||
|
|
||||
|
return logprobs, state |
||||
|
|
||||
|
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): |
||||
|
beam_size = opt.get('beam_size', 10) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
sample_n = opt.get('sample_n', 10) |
||||
|
# when sample_n == beam_size then each beam is a sample. |
||||
|
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' |
||||
|
batch_size = fc_feats.size(0) |
||||
|
|
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
||||
|
|
||||
|
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' |
||||
|
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
||||
|
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
||||
|
# lets process every image independently for now, for simplicity |
||||
|
|
||||
|
self.done_beams = [[] for _ in range(batch_size)] |
||||
|
for k in range(batch_size): |
||||
|
state = self.init_hidden(beam_size) |
||||
|
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = repeat_tensors(beam_size, |
||||
|
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None] |
||||
|
) |
||||
|
|
||||
|
for t in range(1): |
||||
|
if t == 0: # input <bos> |
||||
|
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long) |
||||
|
|
||||
|
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) |
||||
|
|
||||
|
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) |
||||
|
if sample_n == beam_size: |
||||
|
for _n in range(sample_n): |
||||
|
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq'] |
||||
|
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps'] |
||||
|
else: |
||||
|
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score |
||||
|
seqLogprobs[k, :] = self.done_beams[k][0]['logps'] |
||||
|
# return the samples and their log likelihoods |
||||
|
return seq, seqLogprobs |
||||
|
|
||||
|
|
||||
|
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): |
||||
|
beam_size = opt.get('beam_size', 10) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
sample_n = opt.get('sample_n', 10) |
||||
|
# when sample_n == beam_size then each beam is a sample. |
||||
|
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' |
||||
|
batch_size = fc_feats.size(0) |
||||
|
|
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
||||
|
|
||||
|
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' |
||||
|
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
||||
|
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
||||
|
# lets process every image independently for now, for simplicity |
||||
|
|
||||
|
self.done_beams = [[] for _ in range(batch_size)] |
||||
|
|
||||
|
state = self.init_hidden(batch_size) |
||||
|
|
||||
|
# first step, feed bos |
||||
|
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) |
||||
|
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) |
||||
|
|
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(beam_size, |
||||
|
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
||||
|
) |
||||
|
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt) |
||||
|
for k in range(batch_size): |
||||
|
if sample_n == beam_size: |
||||
|
for _n in range(sample_n): |
||||
|
seq_len = self.done_beams[k][_n]['seq'].shape[0] |
||||
|
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq'] |
||||
|
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps'] |
||||
|
else: |
||||
|
seq_len = self.done_beams[k][0]['seq'].shape[0] |
||||
|
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score |
||||
|
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] |
||||
|
# return the samples and their log likelihoods |
||||
|
return seq, seqLogprobs |
||||
|
|
||||
|
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): |
||||
|
|
||||
|
sample_method = opt.get('sample_method', 'greedy') |
||||
|
beam_size = opt.get('beam_size', 1) |
||||
|
temperature = opt.get('temperature', 1.0) |
||||
|
sample_n = int(opt.get('sample_n', 1)) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
output_logsoftmax = opt.get('output_logsoftmax', 1) |
||||
|
decoding_constraint = opt.get('decoding_constraint', 0) |
||||
|
block_trigrams = opt.get('block_trigrams', 0) |
||||
|
remove_bad_endings = opt.get('remove_bad_endings', 0) |
||||
|
if beam_size > 1 and sample_method in ['greedy', 'beam_search']: |
||||
|
return self._sample_beam(fc_feats, att_feats, att_masks, opt) |
||||
|
if group_size > 1: |
||||
|
return self._diverse_sample(fc_feats, att_feats, att_masks, opt) |
||||
|
|
||||
|
batch_size = fc_feats.size(0) |
||||
|
state = self.init_hidden(batch_size*sample_n) |
||||
|
|
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
||||
|
|
||||
|
if sample_n > 1: |
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = repeat_tensors(sample_n, |
||||
|
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks] |
||||
|
) |
||||
|
|
||||
|
trigrams = [] # will be a list of batch_size dictionaries |
||||
|
|
||||
|
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) |
||||
|
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) |
||||
|
for t in range(self.seq_length + 1): |
||||
|
if t == 0: # input <bos> |
||||
|
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long) |
||||
|
|
||||
|
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax) |
||||
|
|
||||
|
if decoding_constraint and t > 0: |
||||
|
tmp = logprobs.new_zeros(logprobs.size()) |
||||
|
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) |
||||
|
logprobs = logprobs + tmp |
||||
|
|
||||
|
if remove_bad_endings and t > 0: |
||||
|
tmp = logprobs.new_zeros(logprobs.size()) |
||||
|
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) |
||||
|
# Make it impossible to generate bad_endings |
||||
|
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') |
||||
|
logprobs = logprobs + tmp |
||||
|
|
||||
|
# Mess with trigrams |
||||
|
# Copy from https://github.com/lukemelas/image-paragraph-captioning |
||||
|
if block_trigrams and t >= 3: |
||||
|
# Store trigram generated at last step |
||||
|
prev_two_batch = seq[:,t-3:t-1] |
||||
|
for i in range(batch_size): # = seq.size(0) |
||||
|
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
||||
|
current = seq[i][t-1] |
||||
|
if t == 3: # initialize |
||||
|
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} |
||||
|
elif t > 3: |
||||
|
if prev_two in trigrams[i]: # add to list |
||||
|
trigrams[i][prev_two].append(current) |
||||
|
else: # create list |
||||
|
trigrams[i][prev_two] = [current] |
||||
|
# Block used trigrams at next step |
||||
|
prev_two_batch = seq[:,t-2:t] |
||||
|
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size |
||||
|
for i in range(batch_size): |
||||
|
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
||||
|
if prev_two in trigrams[i]: |
||||
|
for j in trigrams[i][prev_two]: |
||||
|
mask[i,j] += 1 |
||||
|
# Apply mask to log probs |
||||
|
#logprobs = logprobs - (mask * 1e9) |
||||
|
alpha = 2.0 # = 4 |
||||
|
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) |
||||
|
|
||||
|
# sample the next word |
||||
|
if t == self.seq_length: # skip if we achieve maximum length |
||||
|
break |
||||
|
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) |
||||
|
|
||||
|
# stop when all finished |
||||
|
if t == 0: |
||||
|
unfinished = it != self.eos_idx |
||||
|
else: |
||||
|
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 |
||||
|
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs) |
||||
|
unfinished = unfinished & (it != self.eos_idx) |
||||
|
seq[:,t] = it |
||||
|
seqLogprobs[:,t] = logprobs |
||||
|
# quit loop if all sequences have finished |
||||
|
if unfinished.sum() == 0: |
||||
|
break |
||||
|
|
||||
|
return seq, seqLogprobs |
||||
|
|
||||
|
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}): |
||||
|
|
||||
|
sample_method = opt.get('sample_method', 'greedy') |
||||
|
beam_size = opt.get('beam_size', 1) |
||||
|
temperature = opt.get('temperature', 1.0) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
diversity_lambda = opt.get('diversity_lambda', 0.5) |
||||
|
decoding_constraint = opt.get('decoding_constraint', 0) |
||||
|
block_trigrams = opt.get('block_trigrams', 0) |
||||
|
remove_bad_endings = opt.get('remove_bad_endings', 0) |
||||
|
|
||||
|
batch_size = fc_feats.size(0) |
||||
|
state = self.init_hidden(batch_size) |
||||
|
|
||||
|
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) |
||||
|
|
||||
|
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries |
||||
|
|
||||
|
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)] |
||||
|
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)] |
||||
|
state_table = [self.init_hidden(batch_size) for _ in range(group_size)] |
||||
|
|
||||
|
for tt in range(self.seq_length + group_size): |
||||
|
for divm in range(group_size): |
||||
|
t = tt - divm |
||||
|
seq = seq_table[divm] |
||||
|
seqLogprobs = seqLogprobs_table[divm] |
||||
|
trigrams = trigrams_table[divm] |
||||
|
if t >= 0 and t <= self.seq_length-1: |
||||
|
if t == 0: # input <bos> |
||||
|
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) |
||||
|
else: |
||||
|
it = seq[:, t-1] # changed |
||||
|
|
||||
|
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed |
||||
|
logprobs = F.log_softmax(logprobs / temperature, dim=-1) |
||||
|
|
||||
|
# Add diversity |
||||
|
if divm > 0: |
||||
|
unaug_logprobs = logprobs.clone() |
||||
|
for prev_choice in range(divm): |
||||
|
prev_decisions = seq_table[prev_choice][:, t] |
||||
|
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda |
||||
|
|
||||
|
if decoding_constraint and t > 0: |
||||
|
tmp = logprobs.new_zeros(logprobs.size()) |
||||
|
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) |
||||
|
logprobs = logprobs + tmp |
||||
|
|
||||
|
if remove_bad_endings and t > 0: |
||||
|
tmp = logprobs.new_zeros(logprobs.size()) |
||||
|
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) |
||||
|
# Impossible to generate remove_bad_endings |
||||
|
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') |
||||
|
logprobs = logprobs + tmp |
||||
|
|
||||
|
# Mess with trigrams |
||||
|
if block_trigrams and t >= 3: |
||||
|
# Store trigram generated at last step |
||||
|
prev_two_batch = seq[:,t-3:t-1] |
||||
|
for i in range(batch_size): # = seq.size(0) |
||||
|
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
||||
|
current = seq[i][t-1] |
||||
|
if t == 3: # initialize |
||||
|
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} |
||||
|
elif t > 3: |
||||
|
if prev_two in trigrams[i]: # add to list |
||||
|
trigrams[i][prev_two].append(current) |
||||
|
else: # create list |
||||
|
trigrams[i][prev_two] = [current] |
||||
|
# Block used trigrams at next step |
||||
|
prev_two_batch = seq[:,t-2:t] |
||||
|
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size |
||||
|
for i in range(batch_size): |
||||
|
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) |
||||
|
if prev_two in trigrams[i]: |
||||
|
for j in trigrams[i][prev_two]: |
||||
|
mask[i,j] += 1 |
||||
|
# Apply mask to log probs |
||||
|
#logprobs = logprobs - (mask * 1e9) |
||||
|
alpha = 2.0 # = 4 |
||||
|
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) |
||||
|
|
||||
|
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1) |
||||
|
|
||||
|
# stop when all finished |
||||
|
if t == 0: |
||||
|
unfinished = it != self.eos_idx |
||||
|
else: |
||||
|
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx) |
||||
|
it[~unfinished] = self.pad_idx |
||||
|
unfinished = unfinished & (it != self.eos_idx) # changed |
||||
|
seq[:,t] = it |
||||
|
seqLogprobs[:,t] = sampleLogprobs.view(-1) |
||||
|
|
||||
|
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1) |
@ -0,0 +1,412 @@ |
|||||
|
# This file contains ShowAttendTell and AllImg model |
||||
|
|
||||
|
# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention |
||||
|
# https://arxiv.org/abs/1502.03044 |
||||
|
|
||||
|
# AllImg is a model where |
||||
|
# img feature is concatenated with word embedding at every time step as the input of lstm |
||||
|
from __future__ import absolute_import |
||||
|
from __future__ import division |
||||
|
from __future__ import print_function |
||||
|
|
||||
|
import numpy as np |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
from torch.autograd import * |
||||
|
#import .model_utils as mutils |
||||
|
from . import model_utils |
||||
|
|
||||
|
#from ..utils import misc as utils |
||||
|
#from ... import utils as model_utils |
||||
|
|
||||
|
#model_utils split_tensors |
||||
|
#utils penalty_builder decode_sequence |
||||
|
|
||||
|
class CaptionModel(nn.Module): |
||||
|
def __init__(self): |
||||
|
super(CaptionModel, self).__init__() |
||||
|
|
||||
|
# implements beam search |
||||
|
# calls beam_step and returns the final set of beams |
||||
|
# augments log-probabilities with diversity terms when number of groups > 1 |
||||
|
|
||||
|
def forward(self, *args, **kwargs): |
||||
|
mode = kwargs.get('mode', 'forward') |
||||
|
if 'mode' in kwargs: |
||||
|
del kwargs['mode'] |
||||
|
return getattr(self, '_'+mode)(*args, **kwargs) |
||||
|
|
||||
|
def beam_search(self, init_state, init_logprobs, *args, **kwargs): |
||||
|
|
||||
|
# function computes the similarity score to be augmented |
||||
|
def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash): |
||||
|
local_time = t - divm |
||||
|
unaug_logprobs = logprobs.clone() |
||||
|
batch_size = beam_seq_table[0].shape[0] |
||||
|
|
||||
|
if divm > 0: |
||||
|
change = logprobs.new_zeros(batch_size, logprobs.shape[-1]) |
||||
|
for prev_choice in range(divm): |
||||
|
prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb |
||||
|
for prev_labels in range(bdash): |
||||
|
change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1)) |
||||
|
|
||||
|
if local_time == 0: |
||||
|
logprobs = logprobs - change * diversity_lambda |
||||
|
else: |
||||
|
logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda |
||||
|
|
||||
|
return logprobs, unaug_logprobs |
||||
|
|
||||
|
|
||||
|
# does one step of classical beam search |
||||
|
|
||||
|
def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): |
||||
|
#INPUTS: |
||||
|
#logprobs: probabilities augmented after diversity N*bxV |
||||
|
#beam_size: obvious |
||||
|
#t : time instant |
||||
|
#beam_seq : tensor contanining the beams |
||||
|
#beam_seq_logprobs: tensor contanining the beam logprobs |
||||
|
#beam_logprobs_sum: tensor contanining joint logprobs |
||||
|
#OUPUTS: |
||||
|
#beam_seq : tensor containing the word indices of the decoded captions Nxbxl |
||||
|
#beam_seq_logprobs : log-probability of each decision made, NxbxlxV |
||||
|
#beam_logprobs_sum : joint log-probability of each beam Nxb |
||||
|
|
||||
|
batch_size = beam_logprobs_sum.shape[0] |
||||
|
vocab_size = logprobs.shape[-1] |
||||
|
logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV |
||||
|
if t == 0: |
||||
|
assert logprobs.shape[1] == 1 |
||||
|
beam_logprobs_sum = beam_logprobs_sum[:, :1] |
||||
|
candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV |
||||
|
ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) |
||||
|
ys, ix = ys[:,:beam_size], ix[:,:beam_size] |
||||
|
beam_ix = ix // vocab_size # Nxb which beam |
||||
|
selected_ix = ix % vocab_size # Nxb # which world |
||||
|
state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams |
||||
|
|
||||
|
|
||||
|
if t > 0: |
||||
|
# gather according to beam_ix |
||||
|
assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() |
||||
|
beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) |
||||
|
|
||||
|
beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs)) |
||||
|
|
||||
|
beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl |
||||
|
beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ |
||||
|
logprobs.reshape(batch_size, -1).gather(1, ix) |
||||
|
assert (beam_logprobs_sum == ys).all() |
||||
|
_tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) |
||||
|
beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV |
||||
|
assert (_tmp_beam_logprobs == beam_logprobs).all() |
||||
|
beam_seq_logprobs = torch.cat([ |
||||
|
beam_seq_logprobs, |
||||
|
beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) |
||||
|
|
||||
|
new_state = [None for _ in state] |
||||
|
for _ix in range(len(new_state)): |
||||
|
# copy over state in previous beam q to new beam at vix |
||||
|
new_state[_ix] = state[_ix][:, state_ix] |
||||
|
state = new_state |
||||
|
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state |
||||
|
|
||||
|
# Start diverse_beam_search |
||||
|
opt = kwargs['opt'] |
||||
|
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs |
||||
|
beam_size = opt.get('beam_size', 10) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
diversity_lambda = opt.get('diversity_lambda', 0.5) |
||||
|
decoding_constraint = opt.get('decoding_constraint', 0) |
||||
|
remove_bad_endings = opt.get('remove_bad_endings', 0) |
||||
|
suppress_UNK = opt.get('suppress_UNK', 0) |
||||
|
length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) |
||||
|
bdash = beam_size // group_size # beam per group |
||||
|
|
||||
|
batch_size = init_logprobs.shape[0] |
||||
|
device = init_logprobs.device |
||||
|
# INITIALIZATIONS |
||||
|
beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)] |
||||
|
beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)] |
||||
|
beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)] |
||||
|
|
||||
|
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) |
||||
|
done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)] |
||||
|
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] |
||||
|
# state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state])) |
||||
|
state_table = [[_.clone() for _ in init_state] for _ in range(group_size)] |
||||
|
# logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0)) |
||||
|
logprobs_table = [init_logprobs.clone() for _ in range(group_size)] |
||||
|
# END INIT |
||||
|
|
||||
|
# Chunk elements in the args |
||||
|
args = list(args) |
||||
|
args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x... |
||||
|
if self.__class__.__name__ == 'AttEnsemble': |
||||
|
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name |
||||
|
else: |
||||
|
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] |
||||
|
|
||||
|
for t in range(self.seq_length + group_size - 1): |
||||
|
for divm in range(group_size): |
||||
|
if t >= divm and t <= self.seq_length + divm - 1: |
||||
|
# add diversity |
||||
|
logprobs = logprobs_table[divm] |
||||
|
# suppress previous word |
||||
|
if decoding_constraint and t-divm > 0: |
||||
|
logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf')) |
||||
|
if remove_bad_endings and t-divm > 0: |
||||
|
logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf') |
||||
|
# suppress UNK tokens in the decoding |
||||
|
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK': |
||||
|
logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000 |
||||
|
# diversity is added here |
||||
|
# the function directly modifies the logprobs values and hence, we need to return |
||||
|
# the unaugmented ones for sorting the candidates in the end. # for historical |
||||
|
# reasons :-) |
||||
|
logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash) |
||||
|
|
||||
|
# infer new beams |
||||
|
beam_seq_table[divm],\ |
||||
|
beam_seq_logprobs_table[divm],\ |
||||
|
beam_logprobs_sum_table[divm],\ |
||||
|
state_table[divm] = beam_step(logprobs, |
||||
|
unaug_logprobs, |
||||
|
bdash, |
||||
|
t-divm, |
||||
|
beam_seq_table[divm], |
||||
|
beam_seq_logprobs_table[divm], |
||||
|
beam_logprobs_sum_table[divm], |
||||
|
state_table[divm]) |
||||
|
|
||||
|
# if time's up... or if end token is reached then copy beams |
||||
|
for b in range(batch_size): |
||||
|
is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx |
||||
|
assert beam_seq_table[divm].shape[-1] == t-divm+1 |
||||
|
if t == self.seq_length + divm - 1: |
||||
|
is_end.fill_(1) |
||||
|
for vix in range(bdash): |
||||
|
if is_end[vix]: |
||||
|
final_beam = { |
||||
|
'seq': beam_seq_table[divm][b, vix].clone(), |
||||
|
'logps': beam_seq_logprobs_table[divm][b, vix].clone(), |
||||
|
'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(), |
||||
|
'p': beam_logprobs_sum_table[divm][b, vix].item() |
||||
|
} |
||||
|
final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) |
||||
|
done_beams_table[b][divm].append(final_beam) |
||||
|
beam_logprobs_sum_table[divm][b, is_end] -= 1000 |
||||
|
|
||||
|
# move the current group one step forward in time |
||||
|
|
||||
|
it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device) |
||||
|
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) |
||||
|
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) |
||||
|
|
||||
|
# all beams are sorted by their log-probabilities |
||||
|
done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)] |
||||
|
done_beams = [sum(_, []) for _ in done_beams_table] |
||||
|
return done_beams |
||||
|
|
||||
|
def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): |
||||
|
|
||||
|
# function computes the similarity score to be augmented |
||||
|
def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): |
||||
|
local_time = t - divm |
||||
|
unaug_logprobsf = logprobsf.clone() |
||||
|
for prev_choice in range(divm): |
||||
|
prev_decisions = beam_seq_table[prev_choice][local_time] |
||||
|
for sub_beam in range(bdash): |
||||
|
for prev_labels in range(bdash): |
||||
|
logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda |
||||
|
return unaug_logprobsf |
||||
|
|
||||
|
# does one step of classical beam search |
||||
|
|
||||
|
def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): |
||||
|
#INPUTS: |
||||
|
#logprobsf: probabilities augmented after diversity |
||||
|
#beam_size: obvious |
||||
|
#t : time instant |
||||
|
#beam_seq : tensor contanining the beams |
||||
|
#beam_seq_logprobs: tensor contanining the beam logprobs |
||||
|
#beam_logprobs_sum: tensor contanining joint logprobs |
||||
|
#OUPUTS: |
||||
|
#beam_seq : tensor containing the word indices of the decoded captions |
||||
|
#beam_seq_logprobs : log-probability of each decision made, same size as beam_seq |
||||
|
#beam_logprobs_sum : joint log-probability of each beam |
||||
|
|
||||
|
ys,ix = torch.sort(logprobsf,1,True) |
||||
|
candidates = [] |
||||
|
cols = min(beam_size, ys.size(1)) |
||||
|
rows = beam_size |
||||
|
if t == 0: |
||||
|
rows = 1 |
||||
|
for c in range(cols): # for each column (word, essentially) |
||||
|
for q in range(rows): # for each beam expansion |
||||
|
#compute logprob of expanding beam q with word in (sorted) position c |
||||
|
local_logprob = ys[q,c].item() |
||||
|
candidate_logprob = beam_logprobs_sum[q] + local_logprob |
||||
|
# local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] |
||||
|
candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]}) |
||||
|
candidates = sorted(candidates, key=lambda x: -x['p']) |
||||
|
|
||||
|
new_state = [_.clone() for _ in state] |
||||
|
#beam_seq_prev, beam_seq_logprobs_prev |
||||
|
if t >= 1: |
||||
|
#we''ll need these as reference when we fork beams around |
||||
|
beam_seq_prev = beam_seq[:t].clone() |
||||
|
beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() |
||||
|
for vix in range(beam_size): |
||||
|
v = candidates[vix] |
||||
|
#fork beam index q into index vix |
||||
|
if t >= 1: |
||||
|
beam_seq[:t, vix] = beam_seq_prev[:, v['q']] |
||||
|
beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] |
||||
|
#rearrange recurrent states |
||||
|
for state_ix in range(len(new_state)): |
||||
|
# copy over state in previous beam q to new beam at vix |
||||
|
new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step |
||||
|
#append new end terminal at the end of this beam |
||||
|
beam_seq[t, vix] = v['c'] # c'th word is the continuation |
||||
|
beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here |
||||
|
beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam |
||||
|
state = new_state |
||||
|
return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates |
||||
|
|
||||
|
# Start diverse_beam_search |
||||
|
opt = kwargs['opt'] |
||||
|
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs |
||||
|
beam_size = opt.get('beam_size', 10) |
||||
|
group_size = opt.get('group_size', 1) |
||||
|
diversity_lambda = opt.get('diversity_lambda', 0.5) |
||||
|
decoding_constraint = opt.get('decoding_constraint', 0) |
||||
|
remove_bad_endings = opt.get('remove_bad_endings', 0) |
||||
|
suppress_UNK = opt.get('suppress_UNK', 0) |
||||
|
length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) |
||||
|
bdash = beam_size // group_size # beam per group |
||||
|
|
||||
|
# INITIALIZATIONS |
||||
|
beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] |
||||
|
beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)] |
||||
|
beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] |
||||
|
|
||||
|
# logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) |
||||
|
done_beams_table = [[] for _ in range(group_size)] |
||||
|
# state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] |
||||
|
state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state])) |
||||
|
logprobs_table = list(init_logprobs.chunk(group_size, 0)) |
||||
|
# END INIT |
||||
|
|
||||
|
# Chunk elements in the args |
||||
|
args = list(args) |
||||
|
if self.__class__.__name__ == 'AttEnsemble': |
||||
|
args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name |
||||
|
args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name |
||||
|
else: |
||||
|
args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args] |
||||
|
args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] |
||||
|
|
||||
|
for t in range(self.seq_length + group_size - 1): |
||||
|
for divm in range(group_size): |
||||
|
if t >= divm and t <= self.seq_length + divm - 1: |
||||
|
# add diversity |
||||
|
logprobsf = logprobs_table[divm] |
||||
|
# suppress previous word |
||||
|
if decoding_constraint and t-divm > 0: |
||||
|
logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf')) |
||||
|
if remove_bad_endings and t-divm > 0: |
||||
|
logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf') |
||||
|
# suppress UNK tokens in the decoding |
||||
|
if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK': |
||||
|
logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 |
||||
|
# diversity is added here |
||||
|
# the function directly modifies the logprobsf values and hence, we need to return |
||||
|
# the unaugmented ones for sorting the candidates in the end. # for historical |
||||
|
# reasons :-) |
||||
|
unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) |
||||
|
|
||||
|
# infer new beams |
||||
|
beam_seq_table[divm],\ |
||||
|
beam_seq_logprobs_table[divm],\ |
||||
|
beam_logprobs_sum_table[divm],\ |
||||
|
state_table[divm],\ |
||||
|
candidates_divm = beam_step(logprobsf, |
||||
|
unaug_logprobsf, |
||||
|
bdash, |
||||
|
t-divm, |
||||
|
beam_seq_table[divm], |
||||
|
beam_seq_logprobs_table[divm], |
||||
|
beam_logprobs_sum_table[divm], |
||||
|
state_table[divm]) |
||||
|
|
||||
|
# if time's up... or if end token is reached then copy beams |
||||
|
for vix in range(bdash): |
||||
|
if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1: |
||||
|
final_beam = { |
||||
|
'seq': beam_seq_table[divm][:, vix].clone(), |
||||
|
'logps': beam_seq_logprobs_table[divm][:, vix].clone(), |
||||
|
'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), |
||||
|
'p': beam_logprobs_sum_table[divm][vix].item() |
||||
|
} |
||||
|
final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) |
||||
|
done_beams_table[divm].append(final_beam) |
||||
|
# don't continue beams from finished sequences |
||||
|
beam_logprobs_sum_table[divm][vix] = -1000 |
||||
|
|
||||
|
# move the current group one step forward in time |
||||
|
|
||||
|
it = beam_seq_table[divm][t-divm].to(logprobsf.device) |
||||
|
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) |
||||
|
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) |
||||
|
|
||||
|
# all beams are sorted by their log-probabilities |
||||
|
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] |
||||
|
done_beams = sum(done_beams_table, []) |
||||
|
return done_beams |
||||
|
|
||||
|
def sample_next_word(self, logprobs, sample_method, temperature): |
||||
|
if sample_method == 'greedy': |
||||
|
sampleLogprobs, it = torch.max(logprobs.data, 1) |
||||
|
it = it.view(-1).long() |
||||
|
elif sample_method == 'gumbel': # gumbel softmax |
||||
|
# ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f |
||||
|
def sample_gumbel(shape, eps=1e-20): |
||||
|
U = torch.rand(shape).to(logprobs.device) |
||||
|
return -torch.log(-torch.log(U + eps) + eps) |
||||
|
def gumbel_softmax_sample(logits, temperature): |
||||
|
y = logits + sample_gumbel(logits.size()) |
||||
|
return F.log_softmax(y / temperature, dim=-1) |
||||
|
_logprobs = gumbel_softmax_sample(logprobs, temperature) |
||||
|
_, it = torch.max(_logprobs.data, 1) |
||||
|
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions |
||||
|
else: |
||||
|
logprobs = logprobs / temperature |
||||
|
if sample_method.startswith('top'): # topk sampling |
||||
|
top_num = float(sample_method[3:]) |
||||
|
if 0 < top_num < 1: |
||||
|
# nucleus sampling from # The Curious Case of Neural Text Degeneration |
||||
|
probs = F.softmax(logprobs, dim=1) |
||||
|
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1) |
||||
|
_cumsum = sorted_probs.cumsum(1) |
||||
|
mask = _cumsum < top_num |
||||
|
mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1) |
||||
|
sorted_probs = sorted_probs * mask.to(sorted_probs) |
||||
|
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True) |
||||
|
logprobs.scatter_(1, sorted_indices, sorted_probs.log()) |
||||
|
else: |
||||
|
the_k = int(top_num) |
||||
|
tmp = torch.empty_like(logprobs).fill_(float('-inf')) |
||||
|
topk, indices = torch.topk(logprobs, the_k, dim=1) |
||||
|
tmp = tmp.scatter(1, indices, topk) |
||||
|
logprobs = tmp |
||||
|
it = torch.distributions.Categorical(logits=logprobs.detach()).sample() |
||||
|
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions |
||||
|
return it, sampleLogprobs |
||||
|
|
||||
|
|
||||
|
def decode_sequence(self, seq): |
||||
|
return utils.decode_sequence(self.vocab, seq) |
@ -0,0 +1,92 @@ |
|||||
|
import torch |
||||
|
|
||||
|
def split_tensors(n, x): |
||||
|
if torch.is_tensor(x): |
||||
|
assert x.shape[0] % n == 0 |
||||
|
x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1) |
||||
|
elif type(x) is list or type(x) is tuple: |
||||
|
x = [split_tensors(n, _) for _ in x] |
||||
|
elif x is None: |
||||
|
x = [None] * n |
||||
|
return x |
||||
|
|
||||
|
# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. |
||||
|
#def decode_sequence(ix_to_word, seq): |
||||
|
# # N, D = seq.size() |
||||
|
# N, D = seq.shape |
||||
|
# out = [] |
||||
|
# for i in range(N): |
||||
|
# txt = '' |
||||
|
# for j in range(D): |
||||
|
# ix = seq[i,j] |
||||
|
# if ix > 0 : |
||||
|
# if j >= 1: |
||||
|
# txt = txt + ' ' |
||||
|
# txt = txt + ix_to_word[str(ix.item())] |
||||
|
# else: |
||||
|
# break |
||||
|
# if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): |
||||
|
# flag = 0 |
||||
|
# words = txt.split(' ') |
||||
|
# for j in range(len(words)): |
||||
|
# if words[-j-1] not in bad_endings: |
||||
|
# flag = -j |
||||
|
# break |
||||
|
# txt = ' '.join(words[0:len(words)+flag]) |
||||
|
# out.append(txt.replace('@@ ', '')) |
||||
|
# return out |
||||
|
|
||||
|
def decode_sequence(ix_to_word, seq, remove_bad_endings = True): |
||||
|
# N, D = seq.size() |
||||
|
N, D = seq.shape |
||||
|
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] |
||||
|
bad_endings += ['the'] |
||||
|
out = [] |
||||
|
for i in range(N): |
||||
|
txt = '' |
||||
|
for j in range(D): |
||||
|
ix = seq[i,j] |
||||
|
if ix > 0 : |
||||
|
if j >= 1: |
||||
|
txt = txt + ' ' |
||||
|
txt = txt + ix_to_word[str(ix.item())] |
||||
|
else: |
||||
|
break |
||||
|
if remove_bad_endings is True: |
||||
|
flag = 0 |
||||
|
words = txt.split(' ') |
||||
|
for j in range(len(words)): |
||||
|
if words[-j-1] not in bad_endings: |
||||
|
flag = -j |
||||
|
break |
||||
|
txt = ' '.join(words[0:len(words)+flag]) |
||||
|
out.append(txt.replace('@@ ', '')) |
||||
|
return out |
||||
|
|
||||
|
|
||||
|
|
||||
|
def penalty_builder(penalty_config): |
||||
|
if penalty_config == '': |
||||
|
return lambda x,y: y |
||||
|
pen_type, alpha = penalty_config.split('_') |
||||
|
alpha = float(alpha) |
||||
|
if pen_type == 'wu': |
||||
|
return lambda x,y: length_wu(x,y,alpha) |
||||
|
if pen_type == 'avg': |
||||
|
return lambda x,y: length_average(x,y,alpha) |
||||
|
|
||||
|
def length_wu(length, logprobs, alpha=0.): |
||||
|
""" |
||||
|
NMT length re-ranking score from |
||||
|
"Google's Neural Machine Translation System" :cite:`wu2016google`. |
||||
|
""" |
||||
|
|
||||
|
modifier = (((5 + length) ** alpha) / |
||||
|
((5 + 1) ** alpha)) |
||||
|
return (logprobs / modifier) |
||||
|
|
||||
|
def length_average(length, logprobs, alpha=0.): |
||||
|
""" |
||||
|
Returns the average probability of tokens in a sequence. |
||||
|
""" |
||||
|
return logprobs / length |
@ -0,0 +1 @@ |
|||||
|
from .clip import * |
Binary file not shown.
@ -0,0 +1,193 @@ |
|||||
|
import hashlib |
||||
|
import os |
||||
|
import urllib |
||||
|
import warnings |
||||
|
from typing import Union, List |
||||
|
|
||||
|
import torch |
||||
|
from PIL import Image |
||||
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
||||
|
from tqdm import tqdm |
||||
|
|
||||
|
from .model import build_model |
||||
|
from .simple_tokenizer import SimpleTokenizer as _Tokenizer |
||||
|
|
||||
|
__all__ = ["available_models", "load", "tokenize"] |
||||
|
_tokenizer = _Tokenizer() |
||||
|
|
||||
|
_MODELS = { |
||||
|
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", |
||||
|
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", |
||||
|
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", |
||||
|
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", |
||||
|
} |
||||
|
|
||||
|
|
||||
|
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): |
||||
|
os.makedirs(root, exist_ok=True) |
||||
|
filename = os.path.basename(url) |
||||
|
|
||||
|
expected_sha256 = url.split("/")[-2] |
||||
|
download_target = os.path.join(root, filename) |
||||
|
|
||||
|
if os.path.exists(download_target) and not os.path.isfile(download_target): |
||||
|
raise RuntimeError(f"{download_target} exists and is not a regular file") |
||||
|
|
||||
|
if os.path.isfile(download_target): |
||||
|
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: |
||||
|
return download_target |
||||
|
else: |
||||
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") |
||||
|
|
||||
|
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: |
||||
|
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: |
||||
|
while True: |
||||
|
buffer = source.read(8192) |
||||
|
if not buffer: |
||||
|
break |
||||
|
|
||||
|
output.write(buffer) |
||||
|
loop.update(len(buffer)) |
||||
|
|
||||
|
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: |
||||
|
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") |
||||
|
|
||||
|
return download_target |
||||
|
|
||||
|
|
||||
|
def _transform(n_px): |
||||
|
return Compose([ |
||||
|
Resize(n_px, interpolation=Image.BICUBIC), |
||||
|
CenterCrop(n_px), |
||||
|
lambda image: image.convert("RGB"), |
||||
|
ToTensor(), |
||||
|
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |
||||
|
]) |
||||
|
|
||||
|
|
||||
|
def available_models() -> List[str]: |
||||
|
"""Returns the names of available CLIP models""" |
||||
|
return list(_MODELS.keys()) |
||||
|
|
||||
|
|
||||
|
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): |
||||
|
"""Load a CLIP model |
||||
|
|
||||
|
Parameters |
||||
|
---------- |
||||
|
name : str |
||||
|
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict |
||||
|
|
||||
|
device : Union[str, torch.device] |
||||
|
The device to put the loaded model |
||||
|
|
||||
|
jit : bool |
||||
|
Whether to load the optimized JIT model (default) or more hackable non-JIT model. |
||||
|
|
||||
|
Returns |
||||
|
------- |
||||
|
model : torch.nn.Module |
||||
|
The CLIP model |
||||
|
|
||||
|
preprocess : Callable[[PIL.Image], torch.Tensor] |
||||
|
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input |
||||
|
""" |
||||
|
if name in _MODELS: |
||||
|
model_path = _download(_MODELS[name]) |
||||
|
elif os.path.isfile(name): |
||||
|
model_path = name |
||||
|
else: |
||||
|
raise RuntimeError(f"Model {name} not found; available models = {available_models()}") |
||||
|
|
||||
|
try: |
||||
|
# loading JIT archive |
||||
|
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() |
||||
|
state_dict = None |
||||
|
except RuntimeError: |
||||
|
# loading saved state dict |
||||
|
if jit: |
||||
|
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") |
||||
|
jit = False |
||||
|
state_dict = torch.load(model_path, map_location="cpu") |
||||
|
|
||||
|
if not jit: |
||||
|
model = build_model(state_dict or model.state_dict()).to(device) |
||||
|
if str(device) == "cpu": |
||||
|
model.float() |
||||
|
return model, _transform(model.visual.input_resolution) |
||||
|
|
||||
|
# patch the device names |
||||
|
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) |
||||
|
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] |
||||
|
|
||||
|
def patch_device(module): |
||||
|
graphs = [module.graph] if hasattr(module, "graph") else [] |
||||
|
if hasattr(module, "forward1"): |
||||
|
graphs.append(module.forward1.graph) |
||||
|
|
||||
|
for graph in graphs: |
||||
|
for node in graph.findAllNodes("prim::Constant"): |
||||
|
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): |
||||
|
node.copyAttributes(device_node) |
||||
|
|
||||
|
model.apply(patch_device) |
||||
|
patch_device(model.encode_image) |
||||
|
patch_device(model.encode_text) |
||||
|
|
||||
|
# patch dtype to float32 on CPU |
||||
|
if str(device) == "cpu": |
||||
|
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) |
||||
|
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] |
||||
|
float_node = float_input.node() |
||||
|
|
||||
|
def patch_float(module): |
||||
|
graphs = [module.graph] if hasattr(module, "graph") else [] |
||||
|
if hasattr(module, "forward1"): |
||||
|
graphs.append(module.forward1.graph) |
||||
|
|
||||
|
for graph in graphs: |
||||
|
for node in graph.findAllNodes("aten::to"): |
||||
|
inputs = list(node.inputs()) |
||||
|
for i in [1, 2]: # dtype can be the second or third argument to aten::to() |
||||
|
if inputs[i].node()["value"] == 5: |
||||
|
inputs[i].node().copyAttributes(float_node) |
||||
|
|
||||
|
model.apply(patch_float) |
||||
|
patch_float(model.encode_image) |
||||
|
patch_float(model.encode_text) |
||||
|
|
||||
|
model.float() |
||||
|
|
||||
|
return model, _transform(model.input_resolution.item()) |
||||
|
|
||||
|
|
||||
|
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: |
||||
|
""" |
||||
|
Returns the tokenized representation of given input string(s) |
||||
|
|
||||
|
Parameters |
||||
|
---------- |
||||
|
texts : Union[str, List[str]] |
||||
|
An input string or a list of input strings to tokenize |
||||
|
|
||||
|
context_length : int |
||||
|
The context length to use; all CLIP models use 77 as the context length |
||||
|
|
||||
|
Returns |
||||
|
------- |
||||
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] |
||||
|
""" |
||||
|
if isinstance(texts, str): |
||||
|
texts = [texts] |
||||
|
|
||||
|
sot_token = _tokenizer.encoder["<|startoftext|>"] |
||||
|
eot_token = _tokenizer.encoder["<|endoftext|>"] |
||||
|
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] |
||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) |
||||
|
|
||||
|
for i, tokens in enumerate(all_tokens): |
||||
|
if len(tokens) > context_length: |
||||
|
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") |
||||
|
result[i, :len(tokens)] = torch.tensor(tokens) |
||||
|
|
||||
|
return result |
@ -0,0 +1,437 @@ |
|||||
|
from collections import OrderedDict |
||||
|
from typing import Tuple, Union |
||||
|
|
||||
|
import torch |
||||
|
import torch.nn.functional as F |
||||
|
from torch import nn |
||||
|
|
||||
|
|
||||
|
class Bottleneck(nn.Module): |
||||
|
expansion = 4 |
||||
|
|
||||
|
def __init__(self, inplanes, planes, stride=1): |
||||
|
super().__init__() |
||||
|
|
||||
|
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 |
||||
|
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) |
||||
|
self.bn1 = nn.BatchNorm2d(planes) |
||||
|
|
||||
|
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) |
||||
|
self.bn2 = nn.BatchNorm2d(planes) |
||||
|
|
||||
|
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() |
||||
|
|
||||
|
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) |
||||
|
self.bn3 = nn.BatchNorm2d(planes * self.expansion) |
||||
|
|
||||
|
self.relu = nn.ReLU(inplace=True) |
||||
|
self.downsample = None |
||||
|
self.stride = stride |
||||
|
|
||||
|
if stride > 1 or inplanes != planes * Bottleneck.expansion: |
||||
|
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 |
||||
|
self.downsample = nn.Sequential(OrderedDict([ |
||||
|
("-1", nn.AvgPool2d(stride)), |
||||
|
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), |
||||
|
("1", nn.BatchNorm2d(planes * self.expansion)) |
||||
|
])) |
||||
|
|
||||
|
def forward(self, x: torch.Tensor): |
||||
|
identity = x |
||||
|
|
||||
|
out = self.relu(self.bn1(self.conv1(x))) |
||||
|
out = self.relu(self.bn2(self.conv2(out))) |
||||
|
out = self.avgpool(out) |
||||
|
out = self.bn3(self.conv3(out)) |
||||
|
|
||||
|
if self.downsample is not None: |
||||
|
identity = self.downsample(x) |
||||
|
|
||||
|
out += identity |
||||
|
out = self.relu(out) |
||||
|
return out |
||||
|
|
||||
|
|
||||
|
class AttentionPool2d(nn.Module): |
||||
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): |
||||
|
super().__init__() |
||||
|
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) |
||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim) |
||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim) |
||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim) |
||||
|
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) |
||||
|
self.num_heads = num_heads |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC |
||||
|
# print(x.shape, self.positional_embedding.shape) |
||||
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC |
||||
|
x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC |
||||
|
x, _ = F.multi_head_attention_forward( |
||||
|
query=x, key=x, value=x, |
||||
|
embed_dim_to_check=x.shape[-1], |
||||
|
num_heads=self.num_heads, |
||||
|
q_proj_weight=self.q_proj.weight, |
||||
|
k_proj_weight=self.k_proj.weight, |
||||
|
v_proj_weight=self.v_proj.weight, |
||||
|
in_proj_weight=None, |
||||
|
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), |
||||
|
bias_k=None, |
||||
|
bias_v=None, |
||||
|
add_zero_attn=False, |
||||
|
dropout_p=0, |
||||
|
out_proj_weight=torch.ones_like(self.q_proj.weight), |
||||
|
out_proj_bias=torch.zeros_like(self.q_proj.bias), |
||||
|
# out_proj_weight=self.c_proj.weight, |
||||
|
# out_proj_bias=self.c_proj.bias, |
||||
|
use_separate_proj_weight=True, |
||||
|
training=self.training, |
||||
|
need_weights=False |
||||
|
) |
||||
|
|
||||
|
return x[0] |
||||
|
|
||||
|
|
||||
|
class ModifiedResNet(nn.Module): |
||||
|
""" |
||||
|
A ResNet class that is similar to torchvision's but contains the following changes: |
||||
|
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. |
||||
|
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 |
||||
|
- The final pooling layer is a QKV attention instead of an average pool |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): |
||||
|
super().__init__() |
||||
|
self.output_dim = output_dim |
||||
|
self.input_resolution = input_resolution |
||||
|
|
||||
|
# the 3-layer stem |
||||
|
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) |
||||
|
self.bn1 = nn.BatchNorm2d(width // 2) |
||||
|
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) |
||||
|
self.bn2 = nn.BatchNorm2d(width // 2) |
||||
|
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) |
||||
|
self.bn3 = nn.BatchNorm2d(width) |
||||
|
self.avgpool = nn.AvgPool2d(2) |
||||
|
self.relu = nn.ReLU(inplace=True) |
||||
|
|
||||
|
# residual layers |
||||
|
self._inplanes = width # this is a *mutable* variable used during construction |
||||
|
self.layer1 = self._make_layer(width, layers[0]) |
||||
|
self.layer2 = self._make_layer(width * 2, layers[1], stride=2) |
||||
|
self.layer3 = self._make_layer(width * 4, layers[2], stride=2) |
||||
|
self.layer4 = self._make_layer(width * 8, layers[3], stride=2) |
||||
|
|
||||
|
embed_dim = width * 32 # the ResNet feature dimension |
||||
|
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) |
||||
|
|
||||
|
def _make_layer(self, planes, blocks, stride=1): |
||||
|
layers = [Bottleneck(self._inplanes, planes, stride)] |
||||
|
|
||||
|
self._inplanes = planes * Bottleneck.expansion |
||||
|
for _ in range(1, blocks): |
||||
|
layers.append(Bottleneck(self._inplanes, planes)) |
||||
|
|
||||
|
return nn.Sequential(*layers) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
def stem(x): |
||||
|
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: |
||||
|
x = self.relu(bn(conv(x))) |
||||
|
x = self.avgpool(x) |
||||
|
return x |
||||
|
|
||||
|
x = x.type(self.conv1.weight.dtype) |
||||
|
x = stem(x) |
||||
|
x = self.layer1(x) |
||||
|
x = self.layer2(x) |
||||
|
x = self.layer3(x) |
||||
|
x = self.layer4(x) |
||||
|
# print(x.shape) |
||||
|
# x = self.attnpool(x) |
||||
|
attnpool = self.attnpool(x) |
||||
|
|
||||
|
return (x, attnpool) |
||||
|
|
||||
|
|
||||
|
class LayerNorm(nn.LayerNorm): |
||||
|
"""Subclass torch's LayerNorm to handle fp16.""" |
||||
|
|
||||
|
def forward(self, x: torch.Tensor): |
||||
|
orig_type = x.dtype |
||||
|
ret = super().forward(x.type(torch.float32)) |
||||
|
return ret.type(orig_type) |
||||
|
|
||||
|
|
||||
|
class QuickGELU(nn.Module): |
||||
|
def forward(self, x: torch.Tensor): |
||||
|
return x * torch.sigmoid(1.702 * x) |
||||
|
|
||||
|
|
||||
|
class ResidualAttentionBlock(nn.Module): |
||||
|
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): |
||||
|
super().__init__() |
||||
|
|
||||
|
self.attn = nn.MultiheadAttention(d_model, n_head) |
||||
|
self.ln_1 = LayerNorm(d_model) |
||||
|
self.mlp = nn.Sequential(OrderedDict([ |
||||
|
("c_fc", nn.Linear(d_model, d_model * 4)), |
||||
|
("gelu", QuickGELU()), |
||||
|
("c_proj", nn.Linear(d_model * 4, d_model)) |
||||
|
])) |
||||
|
self.ln_2 = LayerNorm(d_model) |
||||
|
self.attn_mask = attn_mask |
||||
|
|
||||
|
def attention(self, x: torch.Tensor): |
||||
|
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
||||
|
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] |
||||
|
|
||||
|
def forward(self, x: torch.Tensor): |
||||
|
x = x + self.attention(self.ln_1(x)) |
||||
|
x = x + self.mlp(self.ln_2(x)) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class Transformer(nn.Module): |
||||
|
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): |
||||
|
super().__init__() |
||||
|
self.width = width |
||||
|
self.layers = layers |
||||
|
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) |
||||
|
|
||||
|
def forward(self, x: torch.Tensor): |
||||
|
return self.resblocks(x) |
||||
|
|
||||
|
|
||||
|
class VisualTransformer(nn.Module): |
||||
|
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): |
||||
|
super().__init__() |
||||
|
self.input_resolution = input_resolution |
||||
|
self.output_dim = output_dim |
||||
|
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) |
||||
|
|
||||
|
scale = width ** -0.5 |
||||
|
self.class_embedding = nn.Parameter(scale * torch.randn(width)) |
||||
|
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) |
||||
|
self.ln_pre = LayerNorm(width) |
||||
|
|
||||
|
self.transformer = Transformer(width, layers, heads) |
||||
|
|
||||
|
self.ln_post = LayerNorm(width) |
||||
|
self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) |
||||
|
|
||||
|
def forward(self, x: torch.Tensor): |
||||
|
x = self.conv1(x) # shape = [*, width, grid, grid] |
||||
|
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] |
||||
|
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] |
||||
|
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] |
||||
|
x = x + self.positional_embedding.to(x.dtype) |
||||
|
x = self.ln_pre(x) |
||||
|
|
||||
|
x = x.permute(1, 0, 2) # NLD -> LND |
||||
|
x = self.transformer(x) |
||||
|
x = x.permute(1, 0, 2) # LND -> NLD |
||||
|
|
||||
|
# x = self.ln_post(x[:, 0, :]) |
||||
|
|
||||
|
x = self.ln_post(x) |
||||
|
# if self.proj is not None: |
||||
|
# x = x @ self.proj |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class CLIP(nn.Module): |
||||
|
def __init__(self, |
||||
|
embed_dim: int, |
||||
|
# vision |
||||
|
image_resolution: int, |
||||
|
vision_layers: Union[Tuple[int, int, int, int], int], |
||||
|
vision_width: int, |
||||
|
vision_patch_size: int, |
||||
|
# text |
||||
|
context_length: int, |
||||
|
vocab_size: int, |
||||
|
transformer_width: int, |
||||
|
transformer_heads: int, |
||||
|
transformer_layers: int |
||||
|
): |
||||
|
super().__init__() |
||||
|
|
||||
|
self.context_length = context_length |
||||
|
|
||||
|
if isinstance(vision_layers, (tuple, list)): |
||||
|
vision_heads = vision_width * 32 // 64 |
||||
|
self.visual = ModifiedResNet( |
||||
|
layers=vision_layers, |
||||
|
output_dim=embed_dim, |
||||
|
heads=vision_heads, |
||||
|
input_resolution=image_resolution, |
||||
|
width=vision_width |
||||
|
) |
||||
|
else: |
||||
|
vision_heads = vision_width // 64 |
||||
|
self.visual = VisualTransformer( |
||||
|
input_resolution=image_resolution, |
||||
|
patch_size=vision_patch_size, |
||||
|
width=vision_width, |
||||
|
layers=vision_layers, |
||||
|
heads=vision_heads, |
||||
|
output_dim=embed_dim |
||||
|
) |
||||
|
|
||||
|
self.transformer = Transformer( |
||||
|
width=transformer_width, |
||||
|
layers=transformer_layers, |
||||
|
heads=transformer_heads, |
||||
|
attn_mask=self.build_attention_mask() |
||||
|
) |
||||
|
|
||||
|
self.vocab_size = vocab_size |
||||
|
self.token_embedding = nn.Embedding(vocab_size, transformer_width) |
||||
|
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) |
||||
|
self.ln_final = LayerNorm(transformer_width) |
||||
|
|
||||
|
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) |
||||
|
self.logit_scale = nn.Parameter(torch.ones([])) |
||||
|
|
||||
|
self.initialize_parameters() |
||||
|
|
||||
|
def initialize_parameters(self): |
||||
|
nn.init.normal_(self.token_embedding.weight, std=0.02) |
||||
|
nn.init.normal_(self.positional_embedding, std=0.01) |
||||
|
|
||||
|
if isinstance(self.visual, ModifiedResNet): |
||||
|
if self.visual.attnpool is not None: |
||||
|
std = self.visual.attnpool.c_proj.in_features ** -0.5 |
||||
|
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) |
||||
|
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) |
||||
|
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) |
||||
|
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) |
||||
|
|
||||
|
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: |
||||
|
for name, param in resnet_block.named_parameters(): |
||||
|
if name.endswith("bn3.weight"): |
||||
|
nn.init.zeros_(param) |
||||
|
|
||||
|
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) |
||||
|
attn_std = self.transformer.width ** -0.5 |
||||
|
fc_std = (2 * self.transformer.width) ** -0.5 |
||||
|
for block in self.transformer.resblocks: |
||||
|
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) |
||||
|
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) |
||||
|
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) |
||||
|
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) |
||||
|
|
||||
|
if self.text_projection is not None: |
||||
|
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) |
||||
|
|
||||
|
def build_attention_mask(self): |
||||
|
# lazily create causal attention mask, with full attention between the vision tokens |
||||
|
# pytorch uses additive attention mask; fill with -inf |
||||
|
mask = torch.empty(self.context_length, self.context_length) |
||||
|
mask.fill_(float("-inf")) |
||||
|
mask.triu_(1) # zero out the lower diagonal |
||||
|
return mask |
||||
|
|
||||
|
@property |
||||
|
def dtype(self): |
||||
|
return self.visual.conv1.weight.dtype |
||||
|
|
||||
|
def encode_image(self, image): |
||||
|
return self.visual(image.type(self.dtype)) |
||||
|
|
||||
|
def encode_text(self, text): |
||||
|
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] |
||||
|
|
||||
|
x = x + self.positional_embedding.type(self.dtype) |
||||
|
x = x.permute(1, 0, 2) # NLD -> LND |
||||
|
x = self.transformer(x) |
||||
|
x = x.permute(1, 0, 2) # LND -> NLD |
||||
|
x = self.ln_final(x).type(self.dtype) |
||||
|
|
||||
|
# x.shape = [batch_size, n_ctx, transformer.width] |
||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence) |
||||
|
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
def forward(self, image, text): |
||||
|
image_features = self.encode_image(image) |
||||
|
text_features = self.encode_text(text) |
||||
|
|
||||
|
# normalized features |
||||
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
||||
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
||||
|
|
||||
|
# cosine similarity as logits |
||||
|
logit_scale = self.logit_scale.exp() |
||||
|
logits_per_image = logit_scale * image_features @ text_features.t() |
||||
|
logits_per_text = logit_scale * text_features @ image_features.t() |
||||
|
|
||||
|
# shape = [global_batch_size, global_batch_size] |
||||
|
return logits_per_image, logits_per_text |
||||
|
|
||||
|
|
||||
|
def convert_weights(model: nn.Module): |
||||
|
"""Convert applicable model parameters to fp16""" |
||||
|
|
||||
|
def _convert_weights_to_fp16(l): |
||||
|
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): |
||||
|
l.weight.data = l.weight.data.half() |
||||
|
if l.bias is not None: |
||||
|
l.bias.data = l.bias.data.half() |
||||
|
|
||||
|
if isinstance(l, nn.MultiheadAttention): |
||||
|
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: |
||||
|
tensor = getattr(l, attr) |
||||
|
if tensor is not None: |
||||
|
tensor.data = tensor.data.half() |
||||
|
|
||||
|
for name in ["text_projection", "proj"]: |
||||
|
if hasattr(l, name): |
||||
|
attr = getattr(l, name) |
||||
|
if attr is not None: |
||||
|
attr.data = attr.data.half() |
||||
|
|
||||
|
model.apply(_convert_weights_to_fp16) |
||||
|
|
||||
|
|
||||
|
def build_model(state_dict: dict): |
||||
|
vit = "visual.proj" in state_dict |
||||
|
|
||||
|
if vit: |
||||
|
vision_width = state_dict["visual.conv1.weight"].shape[0] |
||||
|
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) |
||||
|
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] |
||||
|
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) |
||||
|
image_resolution = vision_patch_size * grid_size |
||||
|
else: |
||||
|
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] |
||||
|
vision_layers = tuple(counts) |
||||
|
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] |
||||
|
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) |
||||
|
vision_patch_size = None |
||||
|
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] |
||||
|
image_resolution = output_width * 32 |
||||
|
|
||||
|
embed_dim = state_dict["text_projection"].shape[1] |
||||
|
context_length = state_dict["positional_embedding"].shape[0] |
||||
|
vocab_size = state_dict["token_embedding.weight"].shape[0] |
||||
|
transformer_width = state_dict["ln_final.weight"].shape[0] |
||||
|
transformer_heads = transformer_width // 64 |
||||
|
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) |
||||
|
|
||||
|
model = CLIP( |
||||
|
embed_dim, |
||||
|
image_resolution, vision_layers, vision_width, vision_patch_size, |
||||
|
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers |
||||
|
) |
||||
|
|
||||
|
for key in ["input_resolution", "context_length", "vocab_size"]: |
||||
|
if key in state_dict: |
||||
|
del state_dict[key] |
||||
|
|
||||
|
convert_weights(model) |
||||
|
model.load_state_dict(state_dict) |
||||
|
return model.eval() |
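A rough sketch of rebuilding the CLIP backbone from a checkpoint with build_model(); the checkpoint path is illustrative, and note that encode_image() of the ModifiedResNet above requires visual.attnpool.positional_embedding to be replaced with a resized 3-D parameter (as ClipCaptionReward.__init__ does later in this diff) before it can be called:

import torch

state_dict = torch.jit.load("RN50.pt", map_location="cpu").state_dict()  # illustrative path to an OpenAI CLIP JIT archive
model = build_model(state_dict).float()  # build_model converts weights to fp16; cast back for CPU use
dummy_tokens = torch.zeros(1, model.context_length, dtype=torch.long)    # dummy token ids
text_features = model.encode_text(dummy_tokens)                          # shape [1, embed_dim]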
@ -0,0 +1,132 @@ |
|||||
|
import gzip |
||||
|
import html |
||||
|
import os |
||||
|
from functools import lru_cache |
||||
|
|
||||
|
import ftfy |
||||
|
import regex as re |
||||
|
|
||||
|
|
||||
|
@lru_cache() |
||||
|
def default_bpe(): |
||||
|
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") |
||||
|
|
||||
|
|
||||
|
@lru_cache() |
||||
|
def bytes_to_unicode(): |
||||
|
""" |
||||
|
Returns a list of utf-8 bytes and a corresponding list of unicode strings. |

||||
|
The reversible bpe codes work on unicode strings. |
||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. |
||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. |
||||
|
This is a significant percentage of your normal, say, 32K bpe vocab. |
||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. |
||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on. |
||||
|
""" |
||||
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) |
||||
|
cs = bs[:] |
||||
|
n = 0 |
||||
|
for b in range(2**8): |
||||
|
if b not in bs: |
||||
|
bs.append(b) |
||||
|
cs.append(2**8+n) |
||||
|
n += 1 |
||||
|
cs = [chr(n) for n in cs] |
||||
|
return dict(zip(bs, cs)) |
||||
|
|
||||
|
|
||||
|
def get_pairs(word): |
||||
|
"""Return set of symbol pairs in a word. |
||||
|
Word is represented as tuple of symbols (symbols being variable-length strings). |
||||
|
""" |
||||
|
pairs = set() |
||||
|
prev_char = word[0] |
||||
|
for char in word[1:]: |
||||
|
pairs.add((prev_char, char)) |
||||
|
prev_char = char |
||||
|
return pairs |
||||
|
|
||||
|
|
||||
|
def basic_clean(text): |
||||
|
text = ftfy.fix_text(text) |
||||
|
text = html.unescape(html.unescape(text)) |
||||
|
return text.strip() |
||||
|
|
||||
|
|
||||
|
def whitespace_clean(text): |
||||
|
text = re.sub(r'\s+', ' ', text) |
||||
|
text = text.strip() |
||||
|
return text |
||||
|
|
||||
|
|
||||
|
class SimpleTokenizer(object): |
||||
|
def __init__(self, bpe_path: str = default_bpe()): |
||||
|
self.byte_encoder = bytes_to_unicode() |
||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} |
||||
|
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') |
||||
|
merges = merges[1:49152-256-2+1] |
||||
|
merges = [tuple(merge.split()) for merge in merges] |
||||
|
vocab = list(bytes_to_unicode().values()) |
||||
|
vocab = vocab + [v+'</w>' for v in vocab] |
||||
|
for merge in merges: |
||||
|
vocab.append(''.join(merge)) |
||||
|
vocab.extend(['<|startoftext|>', '<|endoftext|>']) |
||||
|
self.encoder = dict(zip(vocab, range(len(vocab)))) |
||||
|
self.decoder = {v: k for k, v in self.encoder.items()} |
||||
|
self.bpe_ranks = dict(zip(merges, range(len(merges)))) |
||||
|
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} |
||||
|
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) |
||||
|
|
||||
|
def bpe(self, token): |
||||
|
if token in self.cache: |
||||
|
return self.cache[token] |
||||
|
word = tuple(token[:-1]) + ( token[-1] + '</w>',) |
||||
|
pairs = get_pairs(word) |
||||
|
|
||||
|
if not pairs: |
||||
|
return token+'</w>' |
||||
|
|
||||
|
while True: |
||||
|
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) |
||||
|
if bigram not in self.bpe_ranks: |
||||
|
break |
||||
|
first, second = bigram |
||||
|
new_word = [] |
||||
|
i = 0 |
||||
|
while i < len(word): |
||||
|
try: |
||||
|
j = word.index(first, i) |
||||
|
new_word.extend(word[i:j]) |
||||
|
i = j |
||||
|
except: |
||||
|
new_word.extend(word[i:]) |
||||
|
break |
||||
|
|
||||
|
if word[i] == first and i < len(word)-1 and word[i+1] == second: |
||||
|
new_word.append(first+second) |
||||
|
i += 2 |
||||
|
else: |
||||
|
new_word.append(word[i]) |
||||
|
i += 1 |
||||
|
new_word = tuple(new_word) |
||||
|
word = new_word |
||||
|
if len(word) == 1: |
||||
|
break |
||||
|
else: |
||||
|
pairs = get_pairs(word) |
||||
|
word = ' '.join(word) |
||||
|
self.cache[token] = word |
||||
|
return word |
||||
|
|
||||
|
def encode(self, text): |
||||
|
bpe_tokens = [] |
||||
|
text = whitespace_clean(basic_clean(text)).lower() |
||||
|
for token in re.findall(self.pat, text): |
||||
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) |
||||
|
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) |
||||
|
return bpe_tokens |
||||
|
|
||||
|
def decode(self, tokens): |
||||
|
text = ''.join([self.decoder[token] for token in tokens]) |
||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') |
||||
|
return text |
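A short round-trip sketch with the tokenizer defined above (the sentence is illustrative); encode() lower-cases and BPE-splits the text, and decode() rejoins the pieces with spaces:

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("A dog runs on the beach.")
print(ids)                    # list of BPE token ids
print(tokenizer.decode(ids))  # roughly "a dog runs on the beach . "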
@ -0,0 +1,105 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
import json |
||||
 |
import pathlib |
||||
 |
import sys |
||||
 |
from pathlib import Path |
||||
 |
 |
||||
 |
import torch |
||||
 |
from torch import nn |
||||
 |
from timm.models.vision_transformer import resize_pos_embed |
||||
 |
from towhee.operator.base import NNOperator  # towhee operator base class used below (assumed import path) |
||||
 |
from towhee.types.arg import arg, to_image_color  # decorators used below (assumed import path) |
||||
 |
from towhee.types.image_utils import to_pil |
||||
|
|
||||
|
|
||||
|
class ClipCaptionReward(NNOperator): |
||||
|
""" |
||||
|
Image captioning operator that decodes captions with a CLIP-reward-trained transformer (CLIP-Caption-Reward) |
||||
|
""" |
||||
|
def __init__(self, model_name: str): |
||||
|
super().__init__() |
||||
|
sys.path.append(str(Path(__file__).parent)) |
||||
|
from utils import opts |
||||
 |
import clip |
||||
 |
from TransformerModel import TransformerModel  # assumed to resolve via the sys.path entry added above |
||||
 |
# NOTE: `cfg` below is assumed to be the path of one of the YAML configs bundled with this operator. |
||||
 |
opt = opts.parse_opt(parse=False, cfg=cfg) |
||||
|
path = pathlib.Path(__file__).parent |
||||
|
dict_json = json.load(open("{}/data/cocotalk.json".format(path))) |
||||
|
ix_to_word = dict_json["ix_to_word"] |
||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
||||
|
|
||||
|
clip_model, clip_transform = clip.load("RN50", jit=False, device=self.device) |
||||
|
self.clip_model = clip_model |
||||
|
self.clip_transform = clip_transform |
||||
|
|
||||
|
vocab_size = len(ix_to_word) |
||||
|
seq_length = 1 |
||||
|
opt.vocab_size = vocab_size |
||||
|
opt.seq_length = seq_length |
||||
|
opt.batch_size = 1 |
||||
|
opt.vocab = ix_to_word |
||||
|
|
||||
|
num_patches = 196  # 14 * 14 spatial grid expected by the resized positional embedding |
||||
|
|
||||
|
pos_embed = nn.Parameter( |
||||
|
torch.zeros( |
||||
|
1, |
||||
|
num_patches + 1, |
||||
|
clip_model.visual.attnpool.positional_embedding.shape[-1], |
||||
|
device=self.device, |
||||
|
), |
||||
|
) |
||||
|
pos_embed.weight = resize_pos_embed( |
||||
|
clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed |
||||
|
) |
||||
|
self.clip_model.visual.attnpool.positional_embedding = pos_embed |
||||
|
|
||||
|
self.model = TransformerModel(opt) |
||||
|
self.image_mean = ( |
||||
|
torch.Tensor([0.48145466, 0.4578275, 0.40821073]) |
||||
|
.to(self.device) |
||||
|
.reshape(3, 1, 1) |
||||
|
) |
||||
|
self.image_std = ( |
||||
|
torch.Tensor([0.26862954, 0.26130258, 0.27577711]) |
||||
|
.to(self.device) |
||||
|
.reshape(3, 1, 1) |
||||
|
) |
||||
|
|
||||
|
@arg(1, to_image_color('RGB')) |
||||
|
def inference_single_data(self, data): |
||||
|
text = self._inference_from_image(data) |
||||
|
return text |
||||
|
|
||||
|
@arg(1, to_image_color('RGB')) |
||||
|
def _inference_from_image(self, img): |
||||
|
img = to_pil(img) |
||||
|
img = self._preprocess(img)  # assumed helper that resizes the PIL image and returns a tensor on self.device |
||||
|
img -= self.image_mean |
||||
|
img /= self.image_std |
||||
|
tmp_att, tmp_fc = self.clip_model.encode_image(img) |
||||
|
tmp_att = tmp_att[0].permute(1, 2, 0) |
||||
|
|
||||
|
att_feat = tmp_att |
||||
|
|
||||
|
return att_feat |
||||
|
|
||||
|
def __call__(self, data): |
||||
|
if not isinstance(data, list): |
||||
|
data = [data] |
||||
|
results = [] |
||||
 |
for single_data in data: |
||||
|
result = self.inference_single_data(single_data) |
||||
|
results.append(result) |
||||
|
if len(data) == 1: |
||||
|
return results[0] |
||||
|
else: |
||||
|
return results |
||||
|
|
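A rough usage sketch for the operator above; the model name and image path are illustrative, and the exact input type depends on how towhee's @arg/to_image_color decorators convert the image:

import numpy as np
from PIL import Image

op = ClipCaptionReward("clipRN50_clips_grammar")             # hypothetical model name
img = np.asarray(Image.open("example.jpg").convert("RGB"))   # illustrative RGB input
print(op(img))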
Binary file not shown.
@ -0,0 +1,60 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/FineCapEval.json |
||||
|
input_label_h5: none |
||||
|
input_fc_dir: data/FineCapEval_clip_RN50_fc |
||||
|
input_att_dir: data/FineCapEval_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
||||
|
|
||||
|
seq_per_img: 5 |
||||
|
batch_size: 200 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/clipRN50_mle/clipRN50_mle |
||||
|
|
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 50 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
use_clipscore: false |
@ -0,0 +1,52 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
# noamopt: false |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_fc_dir: data/cocotalk_clip_RN50_fc |
||||
|
input_att_dir: data/cocotalk_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
# batch_size: 600 |
||||
|
batch_size: 200 |
||||
|
|
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
# checkpoint_path: ./save/trans_clip_rn50_sc_pl |
||||
|
checkpoint_path: save/clipRN50_mle/clipRN50_mle |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
# max_epochs: 15 |
||||
|
max_epochs: 25 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
|
||||
|
verbose: false |
||||
|
precision: 16 |
@ -0,0 +1,41 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_att_dir: data/cocotalk_att |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 10 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/trans_rn50_sc |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
@ -0,0 +1,61 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/FineCapEval.json |
||||
|
input_label_h5: none |
||||
|
input_fc_dir: data/FineCapEval_clip_RN50_fc |
||||
|
input_att_dir: data/FineCapEval_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
||||
|
|
||||
|
seq_per_img: 5 |
||||
|
batch_size: 200 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/clipRN50_cider/clipRN50_cider |
||||
|
|
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 50 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
# use_clipscore: true |
||||
|
use_clipscore: false |
@ -0,0 +1,65 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/FineCapEval.json |
||||
|
input_label_h5: none |
||||
|
input_fc_dir: data/FineCapEval_clip_RN50_fc |
||||
|
input_att_dir: data/FineCapEval_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
||||
|
|
||||
|
seq_per_img: 5 |
||||
|
batch_size: 200 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips |
||||
|
|
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 50 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
# use_clipscore: true |
||||
|
use_clipscore: false |
||||
|
clipscore_reward_weight: 2.0 |
||||
|
clipscore_mode: clip_s |
||||
|
|
||||
|
use_multi_rewards: true |
@ -0,0 +1,64 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/FineCapEval.json |
||||
|
input_label_h5: none |
||||
|
input_fc_dir: data/FineCapEval_clip_RN50_fc |
||||
|
input_att_dir: data/FineCapEval_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 160 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/clipRN50_clips/clipRN50_clips |
||||
|
|
||||
|
use_multi_rewards: false |
||||
|
use_grammar: false |
||||
|
use_grammar_baseline: false |
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 0 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 50 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
# use_clipscore: true |
||||
|
use_clipscore: false |
||||
|
clipscore_reward_weight: 2.0 |
@ -0,0 +1,64 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/FineCapEval.json |
||||
|
input_label_h5: none |
||||
|
input_fc_dir: data/FineCapEval_clip_RN50_fc |
||||
|
input_att_dir: data/FineCapEval_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/FineCapEval_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 160 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar |
||||
|
|
||||
|
use_multi_rewards: true |
||||
|
use_grammar: true |
||||
|
use_grammar_baseline: true |
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 0 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 50 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
# use_clipscore: true |
||||
|
use_clipscore: false |
||||
|
clipscore_reward_weight: 2.0 |
@ -0,0 +1,58 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_fc_dir: data/cocotalk_clip_RN50_fc |
||||
|
input_att_dir: data/cocotalk_clip_RN50_att |
||||
|
# used only for evaluation |
||||
|
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
||||
|
|
||||
|
seq_per_img: 5 |
||||
|
batch_size: 200 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider |
||||
|
checkpoint_path: save/clipRN50_cider/clipRN50_cider |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 40 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
@ -0,0 +1,61 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_fc_dir: data/cocotalk_clip_RN50_fc |
||||
|
input_att_dir: data/cocotalk_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 160 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 40 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
use_clipscore: true |
||||
|
clipscore_reward_weight: 2.0 |
||||
|
clipscore_mode: clip_s |
||||
|
|
||||
|
use_multi_rewards: true |
@ -0,0 +1,58 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_fc_dir: data/cocotalk_clip_RN50_fc |
||||
|
input_att_dir: data/cocotalk_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 160 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: save/clipRN50_clips/clipRN50_clips |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 40 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
use_clipscore: true |
||||
|
clipscore_reward_weight: 2.0 |
@ -0,0 +1,64 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_fc_dir: data/cocotalk_clip_RN50_fc |
||||
|
input_att_dir: data/cocotalk_clip_RN50_att |
||||
|
input_clipscore_vis_dir: data/cocotalk_clipscore_vis |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 160 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar |
||||
|
|
||||
|
use_multi_rewards: true |
||||
|
use_grammar: true |
||||
|
use_grammar_baseline: true |
||||
|
# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' |
||||
|
clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt' |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
||||
|
|
||||
|
# _BASE_: transformer.yml |
||||
|
reduce_on_plateau: false |
||||
|
noamopt: false |
||||
|
learning_rate: 0.000005 |
||||
|
learning_rate_decay_start: -1 |
||||
|
|
||||
|
self_critical_after: 15 |
||||
|
max_epochs: 40 |
||||
|
|
||||
|
verbose: false |
||||
|
precision: 32 |
||||
|
|
||||
|
use_clipscore: true |
||||
|
clipscore_reward_weight: 2.0 |
@ -0,0 +1,41 @@ |
|||||
|
caption_model: transformer |
||||
|
noamopt: true |
||||
|
noamopt_warmup: 20000 |
||||
|
label_smoothing: 0.0 |
||||
|
input_json: data/cocotalk.json |
||||
|
input_label_h5: data/cocotalk_label.h5 |
||||
|
input_att_dir: data/cocotalk_att |
||||
|
seq_per_img: 5 |
||||
|
batch_size: 10 |
||||
|
learning_rate: 0.0005 |
||||
|
|
||||
|
checkpoint_path: ./save/trans_rn50_sc |
||||
|
|
||||
|
# Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
|
||||
|
# will be ignored |
||||
|
num_layers: 6 |
||||
|
input_encoding_size: 512 |
||||
|
rnn_size: 2048 |
||||
|
|
||||
|
# Transformer config |
||||
|
N_enc: 6 |
||||
|
N_dec: 6 |
||||
|
d_model: 512 |
||||
|
d_ff: 2048 |
||||
|
num_att_heads: 8 |
||||
|
dropout: 0.1 |
||||
|
|
||||
|
|
||||
|
learning_rate_decay_start: 0 |
||||
|
scheduled_sampling_start: -1 |
||||
|
save_checkpoint_every: 3000 |
||||
|
language_eval: 1 |
||||
|
val_images_use: 5000 |
||||
|
max_epochs: 15 |
||||
|
train_sample_n: 5 |
||||
|
|
||||
|
REFORWARD: false |
File diff suppressed because one or more lines are too long
@ -0,0 +1,379 @@ |
|||||
|
# This file contains Transformer network |
||||
|
# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html |
||||
|
|
||||
|
# The cfg name correspondence: |
||||
|
# N=num_layers |
||||
|
# d_model=input_encoding_size |
||||
|
# d_ff=rnn_size |
||||
|
# h is always 8 |
||||
|
|
||||
|
from __future__ import absolute_import |
||||
|
from __future__ import division |
||||
|
from __future__ import print_function |
||||
|
|
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
#from . import utils |
||||
|
|
||||
|
import copy |
||||
|
import math |
||||
|
import numpy as np |
||||
|
|
||||
|
#from .CaptionModel import CaptionModel |
||||
|
#from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel |
||||
|
from captioning.models.CaptionModel import CaptionModel |
||||
|
from captioning.models.AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel |
||||
|
|
||||
|
def repeat_tensors(n, x): |
||||
|
""" |
||||
|
For a tensor of size Bx..., we repeat it n times, and make it Bnx... |
||||
|
For collections, do nested repeat |
||||
|
""" |
||||
|
if torch.is_tensor(x): |
||||
|
x = x.unsqueeze(1) # Bx1x... |
||||
|
x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx... |
||||
|
x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx... |
||||
|
elif type(x) is list or type(x) is tuple: |
||||
|
x = [repeat_tensors(n, _) for _ in x] |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class EncoderDecoder(nn.Module): |
||||
|
""" |
||||
|
A standard Encoder-Decoder architecture. Base for this and many |
||||
|
other models. |
||||
|
""" |
||||
|
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): |
||||
|
super(EncoderDecoder, self).__init__() |
||||
|
self.encoder = encoder |
||||
|
self.decoder = decoder |
||||
|
self.src_embed = src_embed |
||||
|
self.tgt_embed = tgt_embed |
||||
|
self.generator = generator |
||||
|
|
||||
|
def forward(self, src, tgt, src_mask, tgt_mask): |
||||
|
"Take in and process masked src and target sequences." |
||||
|
return self.decode(self.encode(src, src_mask), src_mask, |
||||
|
tgt, tgt_mask) |
||||
|
|
||||
|
def encode(self, src, src_mask): |
||||
|
return self.encoder(self.src_embed(src), src_mask) |
||||
|
|
||||
|
def decode(self, memory, src_mask, tgt, tgt_mask): |
||||
|
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) |
||||
|
|
||||
|
class Generator(nn.Module): |
||||
|
"Define standard linear + softmax generation step." |
||||
|
def __init__(self, d_model, vocab): |
||||
|
super(Generator, self).__init__() |
||||
|
self.proj = nn.Linear(d_model, vocab) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
return F.log_softmax(self.proj(x), dim=-1) |
||||
|
|
||||
|
def clones(module, N): |
||||
|
"Produce N identical layers." |
||||
|
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) |
||||
|
|
||||
|
class Encoder(nn.Module): |
||||
|
"Core encoder is a stack of N layers" |
||||
|
def __init__(self, layer, N): |
||||
|
super(Encoder, self).__init__() |
||||
|
self.layers = clones(layer, N) |
||||
|
self.norm = LayerNorm(layer.size) |
||||
|
|
||||
|
def forward(self, x, mask): |
||||
|
"Pass the input (and mask) through each layer in turn." |
||||
|
for layer in self.layers: |
||||
|
x = layer(x, mask) |
||||
|
return self.norm(x) |
||||
|
|
||||
|
class LayerNorm(nn.Module): |
||||
|
"Construct a layernorm module (See citation for details)." |
||||
|
def __init__(self, features, eps=1e-6): |
||||
|
super(LayerNorm, self).__init__() |
||||
|
self.a_2 = nn.Parameter(torch.ones(features)) |
||||
|
self.b_2 = nn.Parameter(torch.zeros(features)) |
||||
|
self.eps = eps |
||||
|
|
||||
|
def forward(self, x): |
||||
|
mean = x.mean(-1, keepdim=True) |
||||
|
std = x.std(-1, keepdim=True) |
||||
|
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 |
||||
|
|
||||
|
class SublayerConnection(nn.Module): |
||||
|
""" |
||||
|
A residual connection followed by a layer norm. |
||||
|
Note for code simplicity the norm is first as opposed to last. |
||||
|
""" |
||||
|
def __init__(self, size, dropout): |
||||
|
super(SublayerConnection, self).__init__() |
||||
|
self.norm = LayerNorm(size) |
||||
|
self.dropout = nn.Dropout(dropout) |
||||
|
|
||||
|
def forward(self, x, sublayer): |
||||
|
"Apply residual connection to any sublayer with the same size." |
||||
|
return x + self.dropout(sublayer(self.norm(x))) |
||||
|
|
||||
|
class EncoderLayer(nn.Module): |
||||
|
"Encoder is made up of self-attn and feed forward (defined below)" |
||||
|
def __init__(self, size, self_attn, feed_forward, dropout): |
||||
|
super(EncoderLayer, self).__init__() |
||||
|
self.self_attn = self_attn |
||||
|
self.feed_forward = feed_forward |
||||
|
self.sublayer = clones(SublayerConnection(size, dropout), 2) |
||||
|
self.size = size |
||||
|
|
||||
|
def forward(self, x, mask): |
||||
|
"Follow Figure 1 (left) for connections." |
||||
|
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) |
||||
|
return self.sublayer[1](x, self.feed_forward) |
||||
|
|
||||
|
class Decoder(nn.Module): |
||||
|
"Generic N layer decoder with masking." |
||||
|
def __init__(self, layer, N): |
||||
|
super(Decoder, self).__init__() |
||||
|
self.layers = clones(layer, N) |
||||
|
self.norm = LayerNorm(layer.size) |
||||
|
|
||||
|
def forward(self, x, memory, src_mask, tgt_mask): |
||||
|
for layer in self.layers: |
||||
|
x = layer(x, memory, src_mask, tgt_mask) |
||||
|
return self.norm(x) |
||||
|
|
||||
|
class DecoderLayer(nn.Module): |
||||
|
"Decoder is made of self-attn, src-attn, and feed forward (defined below)" |
||||
|
def __init__(self, size, self_attn, src_attn, feed_forward, dropout): |
||||
|
super(DecoderLayer, self).__init__() |
||||
|
self.size = size |
||||
|
self.self_attn = self_attn |
||||
|
self.src_attn = src_attn |
||||
|
self.feed_forward = feed_forward |
||||
|
self.sublayer = clones(SublayerConnection(size, dropout), 3) |
||||
|
|
||||
|
def forward(self, x, memory, src_mask, tgt_mask): |
||||
|
"Follow Figure 1 (right) for connections." |
||||
|
m = memory |
||||
|
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) |
||||
|
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) |
||||
|
return self.sublayer[2](x, self.feed_forward) |
||||
|
|
||||
|
def subsequent_mask(size): |
||||
|
"Mask out subsequent positions." |
||||
|
attn_shape = (1, size, size) |
||||
|
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') |
||||
|
return torch.from_numpy(subsequent_mask) == 0 |
||||
|
|
||||
|
def attention(query, key, value, mask=None, dropout=None): |
||||
|
"Compute 'Scaled Dot Product Attention'" |
||||
|
d_k = query.size(-1) |
||||
|
scores = torch.matmul(query, key.transpose(-2, -1)) \ |
||||
|
/ math.sqrt(d_k) |
||||
|
if mask is not None: |
||||
|
scores = scores.masked_fill(mask == 0, float('-inf')) |
||||
|
p_attn = F.softmax(scores, dim = -1) |
||||
|
if dropout is not None: |
||||
|
p_attn = dropout(p_attn) |
||||
|
return torch.matmul(p_attn, value), p_attn |
||||
|
|
||||
|
class MultiHeadedAttention(nn.Module): |
||||
|
def __init__(self, h, d_model, dropout=0.1): |
||||
|
"Take in model size and number of heads." |
||||
|
super(MultiHeadedAttention, self).__init__() |
||||
|
assert d_model % h == 0 |
||||
|
# We assume d_v always equals d_k |
||||
|
self.d_k = d_model // h |
||||
|
self.h = h |
||||
|
self.linears = clones(nn.Linear(d_model, d_model), 4) |
||||
|
self.attn = None |
||||
|
self.dropout = nn.Dropout(p=dropout) |
||||
|
|
||||
|
def forward(self, query, key, value, mask=None): |
||||
|
"Implements Figure 2" |
||||
|
if mask is not None: |
||||
|
# Same mask applied to all h heads. |
||||
|
mask = mask.unsqueeze(1) |
||||
|
nbatches = query.size(0) |
||||
|
|
||||
|
# 1) Do all the linear projections in batch from d_model => h x d_k |
||||
|
query, key, value = \ |
||||
|
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) |
||||
|
for l, x in zip(self.linears, (query, key, value))] |
||||
|
|
||||
|
# 2) Apply attention on all the projected vectors in batch. |
||||
|
x, self.attn = attention(query, key, value, mask=mask, |
||||
|
dropout=self.dropout) |
||||
|
|
||||
|
# 3) "Concat" using a view and apply a final linear. |
||||
|
x = x.transpose(1, 2).contiguous() \ |
||||
|
.view(nbatches, -1, self.h * self.d_k) |
||||
|
return self.linears[-1](x) |
||||
|
|
||||
|
class PositionwiseFeedForward(nn.Module): |
||||
|
"Implements FFN equation." |
||||
|
def __init__(self, d_model, d_ff, dropout=0.1): |
||||
|
super(PositionwiseFeedForward, self).__init__() |
||||
|
self.w_1 = nn.Linear(d_model, d_ff) |
||||
|
self.w_2 = nn.Linear(d_ff, d_model) |
||||
|
self.dropout = nn.Dropout(dropout) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
return self.w_2(self.dropout(F.relu(self.w_1(x)))) |
||||
|
|
||||
|
class Embeddings(nn.Module): |
||||
|
def __init__(self, d_model, vocab): |
||||
|
super(Embeddings, self).__init__() |
||||
|
self.lut = nn.Embedding(vocab, d_model) |
||||
|
self.d_model = d_model |
||||
|
|
||||
|
def forward(self, x): |
||||
|
return self.lut(x) * math.sqrt(self.d_model) |
||||
|
|
||||
|
class PositionalEncoding(nn.Module): |
||||
|
"Implement the PE function." |
||||
|
def __init__(self, d_model, dropout, max_len=5000): |
||||
|
super(PositionalEncoding, self).__init__() |
||||
|
self.dropout = nn.Dropout(p=dropout) |
||||
|
|
||||
|
# Compute the positional encodings once in log space. |
||||
|
pe = torch.zeros(max_len, d_model) |
||||
|
position = torch.arange(0, max_len).unsqueeze(1).float() |
||||
|
div_term = torch.exp(torch.arange(0, d_model, 2).float() * |
||||
|
-(math.log(10000.0) / d_model)) |
||||
|
pe[:, 0::2] = torch.sin(position * div_term) |
||||
|
pe[:, 1::2] = torch.cos(position * div_term) |
||||
|
pe = pe.unsqueeze(0) |
||||
|
self.register_buffer('pe', pe) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = x + self.pe[:, :x.size(1)] |
||||
|
return self.dropout(x) |
||||
|
|
||||
|
class TransformerModel(AttModel): |
||||
|
|
||||
|
def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, |
||||
|
d_model=512, d_ff=2048, h=8, dropout=0.1): |
||||
|
"Helper: Construct a model from hyperparameters." |
||||
|
c = copy.deepcopy |
||||
|
attn = MultiHeadedAttention(h, d_model, dropout) |
||||
|
ff = PositionwiseFeedForward(d_model, d_ff, dropout) |
||||
|
position = PositionalEncoding(d_model, dropout) |
||||
|
model = EncoderDecoder( |
||||
|
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc), |
||||
|
Decoder(DecoderLayer(d_model, c(attn), c(attn), |
||||
|
c(ff), dropout), N_dec), |
||||
|
lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), |
||||
|
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), |
||||
|
Generator(d_model, tgt_vocab)) |
||||
|
|
||||
|
# This was important from their code. |
||||
|
# Initialize parameters with Glorot / fan_avg. |
||||
|
for p in model.parameters(): |
||||
|
if p.dim() > 1: |
||||
|
nn.init.xavier_uniform_(p) |
||||
|
return model |
||||
|
|
||||
|
def __init__(self, opt): |
||||
|
super(TransformerModel, self).__init__(opt) |
||||
|
self.opt = opt |
||||
|
# self.config = yaml.load(open(opt.config_file)) |
||||
|
|
||||
|
self.N_enc = getattr(opt, 'N_enc', opt.num_layers) |
||||
|
self.N_dec = getattr(opt, 'N_dec', opt.num_layers) |
||||
|
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) |
||||
|
self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) |
||||
|
self.h = getattr(opt, 'num_att_heads', 8) |
||||
|
self.dropout = getattr(opt, 'dropout', 0.1) |
||||
|
|
||||
|
delattr(self, 'att_embed') |
||||
|
self.att_embed = nn.Sequential(*( |
||||
|
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ |
||||
|
(nn.Linear(self.att_feat_size, self.d_model), |
||||
|
nn.ReLU(), |
||||
|
nn.Dropout(self.drop_prob_lm))+ |
||||
|
((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) |
||||
|
|
||||
|
delattr(self, 'embed') |
||||
|
self.embed = lambda x : x |
||||
|
delattr(self, 'fc_embed') |
||||
|
self.fc_embed = lambda x : x |
||||
|
delattr(self, 'logit') |
||||
|
del self.ctx2att |
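        # The Transformer builds its own target embedding and generator inside make_model, |
        # so the AttModel-level embed/fc_embed become identities and logit/ctx2att are dropped. |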
||||
|
|
||||
|
tgt_vocab = self.vocab_size + 1 |
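        # +1 presumably reserves index 0 for the special start/end token used by this codebase. |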
||||
|
|
||||
|
|
||||
|
self.model = self.make_model(0, tgt_vocab, |
||||
|
N_enc=self.N_enc, |
||||
|
N_dec=self.N_dec, |
||||
|
d_model=self.d_model, |
||||
|
d_ff=self.d_ff, |
||||
|
h=self.h, |
||||
|
dropout=self.dropout) |
||||
|
|
||||
|
def logit(self, x): # unsafe way |
||||
|
return self.model.generator.proj(x) |
||||
|
|
||||
|
def init_hidden(self, bsz): |
||||
|
return [] |
||||
|
|
||||
|
def _prepare_feature(self, fc_feats, att_feats, att_masks): |
||||
|
|
||||
|
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) |
||||
|
memory = self.model.encode(att_feats, att_masks) |
||||
|
|
||||
|
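        # fc_feats/att_feats are sliced to zero width: only the encoded memory and the masks |
        # are used by the decoder, but the caller still expects four return values. |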
return fc_feats[...,:0], att_feats[...,:0], memory, att_masks |
||||
|
|
||||
|
def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): |
||||
|
att_feats, att_masks = self.clip_att(att_feats, att_masks) |
||||
|
|
||||
|
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) |
||||
|
|
||||
|
if att_masks is None: |
||||
|
att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) |
||||
|
att_masks = att_masks.unsqueeze(-2) |
||||
|
|
||||
|
if seq is not None: |
||||
|
# crop the last one |
||||
|
# seq = seq[:,:-1] |
||||
|
seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) |
||||
|
seq_mask[:,0] = 1 # bos |
||||
|
|
||||
|
seq_mask = seq_mask.unsqueeze(-2) |
||||
|
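            # Combine the padding mask with a causal (subsequent) mask so each position |
            # attends only to itself and earlier tokens. |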
seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) |
||||
|
|
||||
|
seq_per_img = seq.shape[0] // att_feats.shape[0] |
||||
|
if seq_per_img > 1: |
||||
|
att_feats, att_masks = utils.repeat_tensors(seq_per_img, |
||||
|
[att_feats, att_masks] |
||||
|
) |
||||
|
else: |
||||
|
seq_mask = None |
||||
|
|
||||
|
return att_feats, seq, att_masks, seq_mask |
||||
|
|
||||
|
def _forward(self, fc_feats, att_feats, seq, att_masks=None): |
||||
|
if seq.ndim == 3: # B * seq_per_img * seq_len |
||||
|
seq = seq.reshape(-1, seq.shape[2]) |
||||
|
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) |
||||
|
|
||||
|
out = self.model(att_feats, seq, att_masks, seq_mask) |
||||
|
|
||||
|
outputs = self.model.generator(out) |
||||
|
return outputs |
||||
|
# return torch.cat([_.unsqueeze(1) for _ in outputs], 1) |
||||
|
|
||||
|
def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): |
||||
|
""" |
||||
|
state = [ys.unsqueeze(0)] |
||||
|
""" |
||||
|
if len(state) == 0: |
||||
|
ys = it.unsqueeze(1) |
||||
|
else: |
||||
|
ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) |
||||
|
out = self.model.decode(memory, mask, |
||||
|
ys, |
||||
|
subsequent_mask(ys.size(1)) |
||||
|
.to(memory.device)) |
||||
|
return out[:, -1], [ys.unsqueeze(0)] |
@ -0,0 +1,153 @@ |
|||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. |
||||
|
# Copy from fvcore |
||||
|
|
||||
|
import logging |
||||
|
import os |
||||
|
from typing import Any |
||||
|
import yaml |
||||
|
from yacs.config import CfgNode as _CfgNode |
||||
|
|
||||
|
import io as PathManager |
||||
|
|
||||
|
BASE_KEY = "_BASE_" |
||||
|
|
||||
|
|
||||
|
class CfgNode(_CfgNode): |
||||
|
""" |
||||
|
Our own extended version of :class:`yacs.config.CfgNode`. |
||||
|
It contains the following extra features: |
||||
|
|
||||
|
1. The :meth:`merge_from_file` method supports the "_BASE_" key, |
||||
|
which allows the new CfgNode to inherit all the attributes from the |
||||
|
base configuration file. |
||||
|
2. Keys that start with "COMPUTED_" are treated as insertion-only |
||||
|
"computed" attributes. They can be inserted regardless of whether |
||||
|
the CfgNode is frozen or not. |
||||
|
3. With "allow_unsafe=True", it supports pyyaml tags that evaluate |
||||
|
expressions in config. See examples in |
||||
|
https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types |
||||
|
Note that this may lead to arbitrary code execution: you must not |
||||
|
load a config file from untrusted sources before manually inspecting |
||||
|
the content of the file. |
||||
|
""" |
||||
|
|
||||
|
@staticmethod |
||||
|
def load_yaml_with_base(filename, allow_unsafe = False): |
||||
|
""" |
||||
|
Just like `yaml.load(open(filename))`, but inherit attributes from its |
||||
|
`_BASE_`. |
||||
|
|
||||
|
Args: |
||||
|
filename (str): the file name of the current config. Will be used to |
||||
|
find the base config file. |
||||
|
allow_unsafe (bool): whether to allow loading the config file with |
||||
|
`yaml.unsafe_load`. |
||||
|
|
||||
|
Returns: |
||||
|
(dict): the loaded yaml |
||||
|
""" |
||||
|
with PathManager.open(filename, "r") as f: |
||||
|
try: |
||||
|
cfg = yaml.safe_load(f) |
||||
|
except yaml.constructor.ConstructorError: |
||||
|
if not allow_unsafe: |
||||
|
raise |
||||
|
logger = logging.getLogger(__name__) |
||||
|
logger.warning( |
||||
|
"Loading config {} with yaml.unsafe_load. Your machine may " |
||||
|
"be at risk if the file contains malicious content.".format( |
||||
|
filename |
||||
|
) |
||||
|
) |
||||
|
f.close() |
||||
|
with open(filename, "r") as f: |
||||
|
cfg = yaml.unsafe_load(f) |
||||
|
|
||||
|
def merge_a_into_b(a, b): |
||||
|
# merge dict a into dict b. values in a will overwrite b. |
||||
|
for k, v in a.items(): |
||||
|
if isinstance(v, dict) and k in b: |
||||
|
assert isinstance( |
||||
|
b[k], dict |
||||
|
), "Cannot inherit key '{}' from base!".format(k) |
||||
|
merge_a_into_b(v, b[k]) |
||||
|
else: |
||||
|
b[k] = v |
||||
|
|
||||
|
if BASE_KEY in cfg: |
||||
|
base_cfg_file = cfg[BASE_KEY] |
||||
|
if base_cfg_file.startswith("~"): |
||||
|
base_cfg_file = os.path.expanduser(base_cfg_file) |
||||
|
if not any( |
||||
|
map(base_cfg_file.startswith, ["/", "https://", "http://"]) |
||||
|
): |
||||
|
# the path to base cfg is relative to the config file itself. |
||||
|
base_cfg_file = os.path.join( |
||||
|
os.path.dirname(filename), base_cfg_file |
||||
|
) |
||||
|
base_cfg = CfgNode.load_yaml_with_base( |
||||
|
base_cfg_file, allow_unsafe=allow_unsafe |
||||
|
) |
||||
|
del cfg[BASE_KEY] |
||||
|
|
||||
|
merge_a_into_b(cfg, base_cfg) |
||||
|
return base_cfg |
||||
|
return cfg |
||||
|
|
||||
|
def merge_from_file(self, cfg_filename, allow_unsafe = False): |
||||
|
""" |
||||
|
Merge configs from a given yaml file. |
||||
|
|
||||
|
Args: |
||||
|
cfg_filename: the file name of the yaml config. |
||||
|
allow_unsafe: whether to allow loading the config file with |
||||
|
`yaml.unsafe_load`. |
||||
|
""" |
||||
|
loaded_cfg = CfgNode.load_yaml_with_base( |
||||
|
cfg_filename, allow_unsafe=allow_unsafe |
||||
|
) |
||||
|
loaded_cfg = type(self)(loaded_cfg) |
||||
|
self.merge_from_other_cfg(loaded_cfg) |
||||
|
|
||||
|
# Forward the following calls to base, but with a check on the BASE_KEY. |
||||
|
def merge_from_other_cfg(self, cfg_other): |
||||
|
""" |
||||
|
Args: |
||||
|
cfg_other (CfgNode): configs to merge from. |
||||
|
""" |
||||
|
assert ( |
||||
|
BASE_KEY not in cfg_other |
||||
|
), "The reserved key '{}' can only be used in files!".format(BASE_KEY) |
||||
|
return super().merge_from_other_cfg(cfg_other) |
||||
|
|
||||
|
def merge_from_list(self, cfg_list): |
||||
|
""" |
||||
|
Args: |
||||
|
cfg_list (list): list of configs to merge from. |
||||
|
""" |
||||
|
keys = set(cfg_list[0::2]) |
||||
|
assert ( |
||||
|
BASE_KEY not in keys |
||||
|
), "The reserved key '{}' can only be used in files!".format(BASE_KEY) |
||||
|
return super().merge_from_list(cfg_list) |
||||
|
|
||||
|
def __setattr__(self, name, val): |
||||
|
if name.startswith("COMPUTED_"): |
||||
|
if name in self: |
||||
|
old_val = self[name] |
||||
|
if old_val == val: |
||||
|
return |
||||
|
raise KeyError( |
||||
|
"Computed attributed '{}' already exists " |
||||
|
"with a different value! old={}, new={}.".format( |
||||
|
name, old_val, val |
||||
|
) |
||||
|
) |
||||
|
self[name] = val |
||||
|
else: |
||||
|
super().__setattr__(name, val) |
||||
|
|
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') |
||||
|
print(cfg) |
@ -0,0 +1,412 @@ |
|||||
|
from __future__ import print_function |
||||
|
import argparse |
||||
|
|
||||
|
|
||||
|
def if_use_feat(caption_model): |
||||
|
    # Decide whether to load fc / attention features according to the caption model |
||||
|
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']: |
||||
|
use_att, use_fc = False, True |
||||
|
elif caption_model == 'language_model': |
||||
|
use_att, use_fc = False, False |
||||
|
elif caption_model in ['updown', 'topdown']: |
||||
|
use_fc, use_att = True, True |
||||
|
else: |
||||
|
use_att, use_fc = True, False |
||||
|
return use_fc, use_att |
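# e.g. 'transformer' falls through to the final branch: attention features only |
# (use_att=True, use_fc=False). |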
||||
|
|
||||
|
import pprint |
||||
|
class Config(object): |
||||
|
def __init__(self, **kwargs): |
||||
|
"""Configuration Class: set kwargs as class attributes with setattr""" |
||||
|
for k, v in kwargs.items(): |
||||
|
setattr(self, k, v) |
||||
|
|
||||
|
@property |
||||
|
def config_str(self): |
||||
|
return pprint.pformat(self.__dict__) |
||||
|
|
||||
|
def __repr__(self): |
||||
|
"""Pretty-print configurations in alphabetical order""" |
||||
|
config_str = 'Configurations\n' |
||||
|
config_str += self.config_str |
||||
|
return config_str |
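# Hypothetical usage: cfg = Config(batch_size=16, caption_model='transformer'); |
# print(cfg) pretty-prints every stored field. |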
||||
|
|
||||
|
|
||||
|
def parse_opt(parse=True, **optional_kwargs): |
||||
|
parser = argparse.ArgumentParser() |
||||
|
# Data input settings |
||||
|
parser.add_argument('--input_json', type=str, default='data/coco.json', |
||||
|
help='path to the json file containing additional info and vocab') |
||||
|
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', |
||||
|
help='path to the directory containing the preprocessed fc feats') |
||||
|
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', |
||||
|
help='path to the directory containing the preprocessed att feats') |
||||
|
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', |
||||
|
help='path to the directory containing the boxes of att feats') |
||||
|
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', |
||||
|
help='path to the h5file containing the preprocessed dataset') |
||||
|
parser.add_argument('--data_in_memory', action='store_true', |
||||
|
help='True if we want to save the features in memory') |
||||
|
parser.add_argument('--start_from', type=str, default=None, |
||||
|
help="""continue training from saved model at this path. Path must contain files saved by previous training process: |
||||
|
'infos.pkl' : configuration; |
||||
|
'model.pth' : weights |
||||
|
""") |
||||
|
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', |
||||
|
help='Cached token file for calculating cider score during self critical training.') |
||||
|
|
||||
|
# Model settings |
||||
|
parser.add_argument('--caption_model', type=str, default="show_tell", |
||||
|
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer') |
||||
|
parser.add_argument('--rnn_size', type=int, default=512, |
||||
|
help='size of the rnn in number of hidden nodes in each layer') |
||||
|
parser.add_argument('--num_layers', type=int, default=1, |
||||
|
help='number of layers in the RNN') |
||||
|
parser.add_argument('--rnn_type', type=str, default='lstm', |
||||
|
help='rnn, gru, or lstm') |
||||
|
parser.add_argument('--input_encoding_size', type=int, default=512, |
||||
|
help='the encoding size of each token in the vocabulary, and the image.') |
||||
|
parser.add_argument('--att_hid_size', type=int, default=512, |
||||
|
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') |
||||
|
parser.add_argument('--fc_feat_size', type=int, default=2048, |
||||
|
help='2048 for resnet, 4096 for vgg') |
||||
|
parser.add_argument('--att_feat_size', type=int, default=2048, |
||||
|
help='2048 for resnet, 512 for vgg') |
||||
|
parser.add_argument('--logit_layers', type=int, default=1, |
||||
|
                    help='number of layers in the output word-prediction (logit) layer') |
||||
|
|
||||
|
|
||||
|
parser.add_argument('--use_bn', type=int, default=0, |
||||
|
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') |
||||
|
|
||||
|
# feature manipulation |
||||
|
parser.add_argument('--norm_att_feat', type=int, default=0, |
||||
|
help='If normalize attention features') |
||||
|
parser.add_argument('--use_box', type=int, default=0, |
||||
|
help='If use box features') |
||||
|
parser.add_argument('--norm_box_feat', type=int, default=0, |
||||
|
help='If use box, do we normalize box feature') |
||||
|
|
||||
|
# Optimization: General |
||||
|
parser.add_argument('--max_epochs', type=int, default=-1, |
||||
|
help='number of epochs') |
||||
|
parser.add_argument('--batch_size', type=int, default=16, |
||||
|
help='minibatch size') |
||||
|
parser.add_argument('--grad_clip_mode', type=str, default='value', |
||||
|
help='value or norm') |
||||
|
parser.add_argument('--grad_clip_value', type=float, default=0.1, |
||||
|
help='clip gradients at this value/max_norm, 0 means no clipping') |
||||
|
parser.add_argument('--drop_prob_lm', type=float, default=0.5, |
||||
|
help='strength of dropout in the Language Model RNN') |
||||
|
parser.add_argument('--self_critical_after', type=int, default=-1, |
||||
|
                    help='After what epoch do we start self-critical (RL) training? (-1 = disable, 0 = from the start)') |
||||
|
parser.add_argument('--seq_per_img', type=int, default=5, |
||||
|
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') |
||||
|
|
||||
|
parser.add_argument('--verbose', type=int, default=0) |
||||
|
|
||||
|
# Sample related |
||||
|
add_eval_sample_opts(parser) |
||||
|
|
||||
|
#Optimization: for the Language Model |
||||
|
parser.add_argument('--optim', type=str, default='adam', |
||||
|
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw') |
||||
|
parser.add_argument('--learning_rate', type=float, default=4e-4, |
||||
|
help='learning rate') |
||||
|
parser.add_argument('--learning_rate_decay_start', type=int, default=-1, |
||||
|
                    help='at what epoch to start decaying the learning rate? (-1 = never)') |
||||
|
parser.add_argument('--learning_rate_decay_every', type=int, default=3, |
||||
|
                    help='every how many epochs thereafter to decay the learning rate') |
||||
|
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, |
||||
|
                    help='multiplicative factor applied to the learning rate at each decay step') |
||||
|
parser.add_argument('--optim_alpha', type=float, default=0.9, |
||||
|
help='alpha for adam') |
||||
|
parser.add_argument('--optim_beta', type=float, default=0.999, |
||||
|
help='beta used for adam') |
||||
|
parser.add_argument('--optim_epsilon', type=float, default=1e-8, |
||||
|
help='epsilon that goes into denominator for smoothing') |
||||
|
parser.add_argument('--weight_decay', type=float, default=0, |
||||
|
help='weight_decay') |
||||
|
# Transformer |
||||
|
parser.add_argument('--label_smoothing', type=float, default=0, |
||||
|
help='') |
||||
|
parser.add_argument('--noamopt', action='store_true', |
||||
|
help='') |
||||
|
parser.add_argument('--noamopt_warmup', type=int, default=2000, |
||||
|
help='') |
||||
|
parser.add_argument('--noamopt_factor', type=float, default=1, |
||||
|
help='') |
||||
|
parser.add_argument('--reduce_on_plateau', action='store_true', |
||||
|
help='') |
||||
|
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5, |
||||
|
help='') |
||||
|
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3, |
||||
|
help='') |
||||
|
parser.add_argument('--cached_transformer', action='store_true', |
||||
|
help='') |
||||
|
|
||||
|
|
||||
|
parser.add_argument('--use_warmup', action='store_true', |
||||
|
                    help='warm up the learning rate?') |
||||
|
|
||||
|
parser.add_argument('--scheduled_sampling_start', type=int, default=-1, |
||||
|
                    help='at what epoch to start decaying the ground-truth feeding probability (scheduled sampling)') |
||||
|
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, |
||||
|
                    help='every how many epochs thereafter to increase the scheduled sampling probability') |
||||
|
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, |
||||
|
help='How much to update the prob') |
||||
|
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, |
||||
|
help='Maximum scheduled sampling prob.') |
||||
|
|
||||
|
|
||||
|
# Evaluation/Checkpointing |
||||
|
parser.add_argument('--val_images_use', type=int, default=3200, |
||||
|
help='how many images to use when periodically evaluating the validation loss? (-1 = all)') |
||||
|
parser.add_argument('--save_checkpoint_every', type=int, default=2500, |
||||
|
help='how often to save a model checkpoint (in iterations)?') |
||||
|
parser.add_argument('--save_every_epoch', action='store_true', |
||||
|
help='Save checkpoint every epoch, will overwrite save_checkpoint_every') |
||||
|
parser.add_argument('--save_history_ckpt', type=int, default=0, |
||||
|
                    help='Whether to keep a history of checkpoints at every save point') |
||||
|
parser.add_argument('--checkpoint_path', type=str, default=None, |
||||
|
help='directory to store checkpointed models') |
||||
|
parser.add_argument('--language_eval', type=int, default=0, |
||||
|
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') |
||||
|
parser.add_argument('--losses_log_every', type=int, default=25, |
||||
|
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') |
||||
|
parser.add_argument('--load_best_score', type=int, default=1, |
||||
|
help='Do we load previous best score when resuming training.') |
||||
|
|
||||
|
# misc |
||||
|
parser.add_argument('--id', type=str, default='', |
||||
|
help='an id identifying this run/job. used in cross-val and appended when writing progress files') |
||||
|
parser.add_argument('--train_only', type=int, default=0, |
||||
|
help='if true then use 80k, else use 110k') |
||||
|
|
||||
|
|
||||
|
# Reward |
||||
|
parser.add_argument('--cider_reward_weight', type=float, default=1, |
||||
|
help='The reward weight from cider') |
||||
|
parser.add_argument('--bleu_reward_weight', type=float, default=0, |
||||
|
help='The reward weight from bleu4') |
||||
|
|
||||
|
    # CLIPScore reward |
||||
|
parser.add_argument('--clipscore_reward_weight', type=float, default=1, |
||||
|
help='The reward weight from clipscore') |
||||
|
parser.add_argument('--use_clipscore', type=float, default=0, |
||||
|
help='Use CLIPScore') |
||||
|
parser.add_argument('--clipscore_mode', type=str, default='clip_s', |
||||
|
help='Which CLIPScore to use: clip_s|refclip_s') |
||||
|
|
||||
|
|
||||
|
# Structure_loss |
||||
|
parser.add_argument('--structure_loss_weight', type=float, default=1, |
||||
|
help='') |
||||
|
parser.add_argument('--structure_after', type=int, default=-1, |
||||
|
                    help='After what epoch to start using the structure loss (-1 = disable)') |
||||
|
parser.add_argument('--structure_loss_type', type=str, default='seqnll', |
||||
|
help='') |
||||
|
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='') |
||||
|
parser.add_argument('--entropy_reward_weight', type=float, default=0, |
||||
|
help='Entropy reward, seems very interesting') |
||||
|
parser.add_argument('--self_cider_reward_weight', type=float, default=0, |
||||
|
help='self cider reward') |
||||
|
|
||||
|
    # Used for self-critical or structure loss; applies when sampling is needed during training |
||||
|
parser.add_argument('--train_sample_n', type=int, default=16, |
||||
|
                    help='number of captions sampled per image during self-critical/structure training') |
||||
|
parser.add_argument('--train_sample_method', type=str, default='sample', |
||||
|
help='') |
||||
|
parser.add_argument('--train_beam_size', type=int, default=1, |
||||
|
help='') |
||||
|
|
||||
|
# Used for self critical |
||||
|
parser.add_argument('--sc_sample_method', type=str, default='greedy', |
||||
|
help='') |
||||
|
parser.add_argument('--sc_beam_size', type=int, default=1, |
||||
|
help='') |
||||
|
|
||||
|
|
||||
|
# For diversity evaluation during training |
||||
|
add_diversity_opts(parser) |
||||
|
|
||||
|
|
||||
|
# config |
||||
|
parser.add_argument('--cfg', type=str, default=None, |
||||
|
help='configuration; similar to what is used in detectron') |
||||
|
parser.add_argument( |
||||
|
'--set_cfgs', dest='set_cfgs', |
||||
|
        help='Set config keys. Key value sequence separated by whitespace.' |
||||
|
'e.g. [key] [value] [key] [value]\n This has higher priority' |
||||
|
'than cfg file but lower than other args. (You can only overwrite' |
||||
|
        'arguments that have already been defined in the config file.)', |
||||
|
default=[], nargs='+') |
||||
|
# How will config be used |
||||
|
# 1) read cfg argument, and load the cfg file if it's not None |
||||
|
# 2) Overwrite cfg argument with set_cfgs |
||||
|
# 3) parse config argument to args. |
||||
|
# 4) in the end, parse command line argument and overwrite args |
||||
|
|
||||
|
# step 1: read cfg_fn |
||||
|
# args = parser.parse_args() |
||||
|
# Parse the arguments. |
||||
|
if parse: |
||||
|
args = parser.parse_args() |
||||
|
    # For interactive environments (e.g. Jupyter) |
||||
|
else: |
||||
|
args = parser.parse_known_args()[0] |
||||
|
# print(args) |
||||
|
|
||||
|
# Namespace => Dictionary |
||||
|
kwargs = vars(args) |
||||
|
# for k, v in optional_kwargs.items(): |
||||
|
# setattr(args, k, v) |
||||
|
kwargs.update(optional_kwargs) |
||||
|
|
||||
|
args = Config(**kwargs) |
||||
|
|
||||
|
|
||||
|
if args.cfg is not None or args.set_cfgs is not None: |
||||
|
from .config import CfgNode |
||||
|
if args.cfg is not None: |
||||
|
# print('Read Cfg') |
||||
|
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg)) |
||||
|
# print(cn) |
||||
|
else: |
||||
|
cn = CfgNode() |
||||
|
if args.set_cfgs is not None: |
||||
|
cn.merge_from_list(args.set_cfgs) |
||||
|
for k,v in cn.items(): |
||||
|
if not hasattr(args, k): |
||||
|
import os |
||||
|
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': |
||||
|
pass |
||||
|
else: |
||||
|
print('Warning: key %s not in args' % k) |
||||
|
|
||||
|
setattr(args, k, v) |
||||
|
|
||||
|
if parse: |
||||
|
args = parser.parse_args(namespace=args) |
||||
|
else: |
||||
|
args = parser.parse_known_args(namespace=args)[0] |
||||
|
|
||||
|
# Check if args are valid |
||||
|
assert args.rnn_size > 0, "rnn_size should be greater than 0" |
||||
|
assert args.num_layers > 0, "num_layers should be greater than 0" |
||||
|
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" |
||||
|
assert args.batch_size > 0, "batch_size should be greater than 0" |
||||
|
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" |
||||
|
assert args.seq_per_img > 0, "seq_per_img should be greater than 0" |
||||
|
assert args.beam_size > 0, "beam_size should be greater than 0" |
||||
|
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" |
||||
|
assert args.losses_log_every > 0, "losses_log_every should be greater than 0" |
||||
|
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" |
||||
|
    assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1" |
||||
|
    assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1" |
||||
|
|
||||
|
# default value for start_from and checkpoint_path |
||||
|
args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id |
||||
|
args.start_from = args.start_from or args.checkpoint_path |
||||
|
|
||||
|
# Deal with feature things before anything |
||||
|
args.use_fc, args.use_att = if_use_feat(args.caption_model) |
||||
|
if args.use_box: args.att_feat_size = args.att_feat_size + 5 |
||||
|
|
||||
|
return args |
||||
|
|
||||
|
|
||||
|
def add_eval_options(parser): |
||||
|
# Basic options |
||||
|
parser.add_argument('--batch_size', type=int, default=0, |
||||
|
help='if > 0 then overrule, otherwise load from checkpoint.') |
||||
|
parser.add_argument('--num_images', type=int, default=-1, |
||||
|
help='how many images to use when periodically evaluating the loss? (-1 = all)') |
||||
|
parser.add_argument('--language_eval', type=int, default=0, |
||||
|
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') |
||||
|
parser.add_argument('--dump_images', type=int, default=1, |
||||
|
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') |
||||
|
parser.add_argument('--dump_json', type=int, default=1, |
||||
|
help='Dump json with predictions into vis folder? (1=yes,0=no)') |
||||
|
parser.add_argument('--dump_path', type=int, default=0, |
||||
|
help='Write image paths along with predictions into vis json? (1=yes,0=no)') |
||||
|
|
||||
|
# Sampling options |
||||
|
add_eval_sample_opts(parser) |
||||
|
|
||||
|
# For evaluation on a folder of images: |
||||
|
parser.add_argument('--image_folder', type=str, default='', |
||||
|
help='If this is nonempty then will predict on the images in this folder path') |
||||
|
parser.add_argument('--image_root', type=str, default='', |
||||
|
                    help='In case the image paths have to be prepended with a root path to an image folder') |
||||
|
# For evaluation on MSCOCO images from some split: |
||||
|
parser.add_argument('--input_fc_dir', type=str, default='', |
||||
|
                    help='path to the directory containing the preprocessed fc feats') |
||||
|
parser.add_argument('--input_att_dir', type=str, default='', |
||||
|
                    help='path to the directory containing the preprocessed att feats') |
||||
|
parser.add_argument('--input_box_dir', type=str, default='', |
||||
|
                    help='path to the directory containing the boxes of att feats') |
||||
|
parser.add_argument('--input_label_h5', type=str, default='', |
||||
|
help='path to the h5file containing the preprocessed dataset') |
||||
|
parser.add_argument('--input_json', type=str, default='', |
||||
|
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') |
||||
|
parser.add_argument('--split', type=str, default='test', |
||||
|
help='if running on MSCOCO images, which split to use: val|test|train') |
||||
|
parser.add_argument('--coco_json', type=str, default='', |
||||
|
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') |
||||
|
# misc |
||||
|
parser.add_argument('--id', type=str, default='', |
||||
|
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') |
||||
|
parser.add_argument('--verbose_beam', type=int, default=1, |
||||
|
help='if we need to print out all beam search beams.') |
||||
|
parser.add_argument('--verbose_loss', type=int, default=0, |
||||
|
help='If calculate loss using ground truth during evaluation') |
||||
|
|
||||
|
def add_diversity_opts(parser): |
||||
|
parser.add_argument('--sample_n', type=int, default=1, |
||||
|
help='Diverse sampling') |
||||
|
parser.add_argument('--sample_n_method', type=str, default='sample', |
||||
|
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp') |
||||
|
parser.add_argument('--eval_oracle', type=int, default=1, |
||||
|
                    help='whether to compute oracle scores over the sampled captions') |
||||
|
|
||||
|
|
||||
|
# Sampling related options |
||||
|
def add_eval_sample_opts(parser): |
||||
|
parser.add_argument('--sample_method', type=str, default='greedy', |
||||
|
help='greedy; sample; gumbel; top<int>, top<0-1>') |
||||
|
parser.add_argument('--beam_size', type=int, default=1, |
||||
|
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') |
||||
|
parser.add_argument('--max_length', type=int, default=20, |
||||
|
help='Maximum length during sampling') |
||||
|
parser.add_argument('--length_penalty', type=str, default='', |
||||
|
help='wu_X or avg_X, X is the alpha') |
||||
|
parser.add_argument('--group_size', type=int, default=1, |
||||
|
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') |
||||
|
parser.add_argument('--diversity_lambda', type=float, default=0.5, |
||||
|
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') |
||||
|
parser.add_argument('--temperature', type=float, default=1.0, |
||||
|
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.') |
||||
|
parser.add_argument('--decoding_constraint', type=int, default=0, |
||||
|
help='If 1, not allowing same word in a row') |
||||
|
parser.add_argument('--block_trigrams', type=int, default=0, |
||||
|
help='block repeated trigram.') |
||||
|
parser.add_argument('--remove_bad_endings', type=int, default=0, |
||||
|
help='Remove bad endings') |
||||
|
parser.add_argument('--suppress_UNK', type=int, default=1, |
||||
|
help='Not predicting UNK') |
||||
|
|
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
import sys |
||||
|
sys.argv = [sys.argv[0]] |
||||
|
args = parse_opt() |
||||
|
print(args) |
||||
|
print() |
||||
|
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml'] |
||||
|
args1 = parse_opt() |
||||
|
print(dict(set(vars(args1).items()) - set(vars(args).items()))) |
||||
|
print() |
||||
|
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2'] |
||||
|
args2 = parse_opt() |
||||
|
print(dict(set(vars(args2).items()) - set(vars(args1).items()))) |