logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

149 lines
7.3 KiB

import torch
import utils
class BeamSearch(object):
    """Beam-search decoding on top of a stateful, step-wise sequence model.

    ``model`` is used only through three hooks:

    * ``model.statefulness(batch_size)`` -- context manager entered for the
      whole decode (see ``apply``);
    * ``model.step(t, prev_words, visual, **kwargs)`` -- returns next-word
      logits that can be viewed as ``(b_s, cur_beam_size, vocab)``.
      ``prev_words`` is ``None`` at ``t == 0`` and a ``(b_s * beam_size, 1)``
      tensor of word indices afterwards;
    * ``model.apply_to_states(fn)`` -- presumably applies ``fn`` to each of the
      model's internal state tensors (TODO confirm). ``fn`` re-orders a
      ``(b_s * cur_beam_size, ...)`` tensor along the beam axis.
    """

    def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
        self.model = model
        self.max_len = max_len
        self.eos_idx = eos_idx
        self.beam_size = beam_size
        # Per-decode state, (re)initialised by ``apply``.
        self.b_s = None
        self.device = None
        self.seq_mask = None  # (b_s, beam_size, 1): 1 while a beam has not emitted EOS
        self.seq_logprob = None  # (b_s, cur_beam_size, 1): running sequence log-prob
        self.outputs = None
        self.log_probs = None  # list of (b_s, beam_size, 1) per-step token log-probs
        self.selected_words = None  # (b_s * beam_size, 1) words chosen at the last step
        self.all_logits = None  # list of (b_s, beam_size, 1, vocab), only if return_logits

    def _gather_beams(self, x, selected_beam, cur_beam_size):
        """Re-order a flattened ``(b_s * cur_beam_size, ...)`` tensor along the beam axis.

        Row ``b * beam_size + j`` of the result is row
        ``b * cur_beam_size + selected_beam[b, j]`` of ``x``. The result has
        shape ``(b_s * beam_size, ...)``.
        """
        tail = tuple(x.shape[1:])
        x_exp = x.view((self.b_s, cur_beam_size) + tail)
        index = selected_beam.view((self.b_s, self.beam_size) + (1,) * len(tail))
        index = index.expand((self.b_s, self.beam_size) + tail)
        return torch.gather(x_exp, 1, index).view((self.b_s * self.beam_size,) + tail)

    def _expand_state(self, selected_beam, cur_beam_size):
        """Return a function that re-orders one model state tensor to follow ``selected_beam``."""
        def fn(s):
            return self._gather_beams(s, selected_beam, cur_beam_size)
        return fn

    def _expand_visual(self, visual: "utils.TensorOrSequence", cur_beam_size: int, selected_beam: torch.Tensor):
        """Re-order ``visual`` (a tensor or a sequence of tensors) to follow ``selected_beam``.

        A tensor input yields a tensor. Any other input is iterated, and a
        tuple of re-ordered tensors is returned.
        """
        if isinstance(visual, torch.Tensor):
            return self._gather_beams(visual, selected_beam, cur_beam_size)
        return tuple(self._gather_beams(im, selected_beam, cur_beam_size) for im in visual)

    def apply(self, visual: "utils.TensorOrSequence", out_size=1, return_logits=False, **kwargs):
        """Decode ``max_len`` steps with beam search and return the best ``out_size`` beams.

        Args:
            visual: Model input (tensor or sequence of tensors) passed to
                ``model.step``. The batch size and device are read from it via ``utils``.
            out_size: Number of top-scoring beams to return per batch item.
            return_logits: Whether to also return the per-step logits of the
                returned beams.
            **kwargs: Forwarded to ``model.step`` and ``select``.

        Returns:
            ``(outputs, log_probs)`` or ``(outputs, log_probs, all_logits)``.
            ``outputs`` holds word indices and ``log_probs`` holds per-token
            log-probabilities, both of shape ``(b_s, out_size, max_len)``.
            ``all_logits`` has shape ``(b_s, out_size, max_len, vocab)``.
            When ``out_size == 1`` the beam axis is squeezed out.
        """
        self.b_s = utils.get_batch_size(visual)
        self.device = utils.get_device(visual)
        self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
        self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
        self.log_probs = []
        self.selected_words = None
        if return_logits:
            self.all_logits = []

        outputs = []
        with self.model.statefulness(self.b_s):
            for t in range(self.max_len):
                visual, outputs = self.iter(t, visual, outputs, return_logits, **kwargs)

        # Sort beams by their final sequence log-probability, best first.
        _, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
        token_index = sort_idxs.expand(self.b_s, self.beam_size, self.max_len)

        outputs = torch.gather(torch.cat(outputs, -1), 1, token_index)
        log_probs = torch.gather(torch.cat(self.log_probs, -1), 1, token_index)
        outputs = outputs.contiguous()[:, :out_size]
        log_probs = log_probs.contiguous()[:, :out_size]

        if return_logits:
            all_logits = torch.cat(self.all_logits, 2)
            logit_index = sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
                                                         self.max_len, all_logits.shape[-1])
            all_logits = torch.gather(all_logits, 1, logit_index)
            all_logits = all_logits.contiguous()[:, :out_size]

        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)
            if return_logits:
                all_logits = all_logits.squeeze(1)

        if return_logits:
            return outputs, log_probs, all_logits
        return outputs, log_probs

    def select(self, t, candidate_logprob, **kwargs):
        """Pick the ``beam_size`` best (beam, word) continuations per batch item.

        ``candidate_logprob`` is viewed as ``(b_s, cur_beam_size * vocab)``, so
        the returned flat index encodes ``beam * vocab + word``.

        Returns:
            ``(selected_idx, selected_logprob)``, each of shape
            ``(b_s, beam_size)``, ordered by decreasing log-probability.
        """
        # topk avoids fully sorting all cur_beam_size * vocab candidates.
        selected_logprob, selected_idx = torch.topk(candidate_logprob.view(self.b_s, -1),
                                                    self.beam_size, dim=-1,
                                                    largest=True, sorted=True)
        return selected_idx, selected_logprob

    def iter(self, t: int, visual: "utils.TensorOrSequence", outputs, return_logits, **kwargs):
        """Advance the search by one decoding step.

        Args:
            t: Current step (0-based). At ``t == 0`` each batch item has a single
                hypothesis; afterwards it has ``beam_size``.
            visual: Tensor or sequence of tensors, expected to be laid out as
                ``(b_s * cur_beam_size, ...)``. It is re-ordered to follow the
                selected beams.
            outputs: List of ``(b_s, beam_size, 1)`` word-index tensors, one per
                previous step.
            return_logits: Whether to record this step's logits in ``self.all_logits``.
            **kwargs: Forwarded to ``model.step`` and ``select``.

        Returns:
            ``(visual, outputs)``, both re-ordered to the new beams, with this
            step's words appended to ``outputs``.
        """
        cur_beam_size = 1 if t == 0 else self.beam_size

        word_logits = self.model.step(t, self.selected_words, visual, **kwargs)
        word_logits = word_logits.view(self.b_s, cur_beam_size, -1)
        word_logprob = torch.log_softmax(word_logits, dim=-1)
        candidate_logprob = self.seq_logprob + word_logprob

        # Freeze beams whose previous word was EOS. They contribute 0 to the
        # per-token log-probs and keep their old score through their first
        # candidate only (every other candidate of a finished beam gets -999).
        if t > 0:
            # Take the dtype from the log-probs: ``visual`` may be a tuple of
            # tensors and has no ``.dtype`` of its own.
            alive = self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx
            mask = alive.to(word_logprob.dtype).unsqueeze(-1)
            self.seq_mask = self.seq_mask * mask
            word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
            old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
            old_seq_logprob[:, :, 1:] = -999
            candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)

        selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
        vocab_size = candidate_logprob.shape[-1]
        # Decode the flat index beam * vocab + word (indices are non-negative).
        selected_beam = torch.floor_divide(selected_idx, vocab_size)
        selected_words = selected_idx - selected_beam * vocab_size

        self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
        visual = self._expand_visual(visual, cur_beam_size, selected_beam)

        self.seq_logprob = selected_logprob.unsqueeze(-1)
        beam_index = selected_beam.unsqueeze(-1)  # (b_s, beam_size, 1)
        vocab_index = beam_index.expand(self.b_s, self.beam_size, vocab_size)
        self.seq_mask = torch.gather(self.seq_mask, 1, beam_index)
        outputs = [torch.gather(o, 1, beam_index) for o in outputs]
        outputs.append(selected_words.unsqueeze(-1))

        if return_logits:
            # Keep the logit history aligned with outputs/log_probs. Re-order
            # past steps to the new beams, and store this step's logits from
            # each new beam's source beam. At t == 0 this equals broadcasting
            # the single hypothesis over all beams.
            self.all_logits = [torch.gather(o, 1, beam_index.unsqueeze(-1).expand_as(o))
                               for o in self.all_logits]
            self.all_logits.append(torch.gather(word_logits, 1, vocab_index).unsqueeze(2))

        this_word_logprob = torch.gather(word_logprob, 1, vocab_index)
        this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
        self.log_probs = [torch.gather(o, 1, beam_index) for o in self.log_probs]
        self.log_probs.append(this_word_logprob)

        self.selected_words = selected_words.view(-1, 1)
        return visual, outputs