import torch
import utils


class BeamSearch(object):
    """Beam-search decoder driving a stateful model one token at a time.

    The wrapped ``model`` must provide (as used below):
      * ``step(t, prev_words, visual, **kwargs)`` returning logits that can be
        viewed as ``(b_s, cur_beam_size, vocab)``;
      * ``statefulness(b_s)`` usable as a context manager;
      * ``apply_to_states(fn)``, which presumably applies ``fn`` to every
        cached state tensor -- TODO confirm against the model implementation.
    """

    def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
        """
        Args:
            model: decoder object with the interface described on the class.
            max_len: number of decoding steps performed by ``apply``.
            eos_idx: token index whose emission finishes a hypothesis.
            beam_size: number of hypotheses kept per batch element.
        """
        self.model = model
        self.max_len = max_len
        self.eos_idx = eos_idx
        self.beam_size = beam_size
        # Per-call decoding state; (re)initialised at the start of ``apply``.
        self.b_s = None
        self.device = None
        self.seq_mask = None
        self.seq_logprob = None
        self.outputs = None
        self.log_probs = None
        self.selected_words = None
        self.all_logits = None

    @staticmethod
    def _gather_beams(tensor, selected_beam):
        """Reorder dim 1 of ``tensor`` by ``selected_beam``.

        ``tensor`` has shape ``(b_s, cur_beam, *rest)`` and ``selected_beam``
        has shape ``(b_s, beam_size)``; the result has shape
        ``(b_s, beam_size, *rest)`` where row ``j`` is a copy of row
        ``selected_beam[:, j]`` of the input.
        """
        trailing = tuple(tensor.shape[2:])
        index = selected_beam.view(tuple(selected_beam.shape) + (1,) * len(trailing))
        index = index.expand(tuple(selected_beam.shape) + trailing)
        return torch.gather(tensor, 1, index)

    def _expand_state(self, selected_beam, cur_beam_size):
        """Return a function that reorders a flattened ``(b_s * cur_beam_size, ...)``
        state tensor so it follows the surviving beams in ``selected_beam``.

        The returned tensor is flattened back to ``(b_s * beam_size, ...)``.
        """
        def fn(s):
            trailing = [int(sh) for sh in s.shape[1:]]
            s = s.view(self.b_s, cur_beam_size, *trailing)
            s = self._gather_beams(s, selected_beam)
            return s.view(-1, *trailing)
        return fn

    def _expand_visual_tensor(self, tensor, cur_beam_size, selected_beam):
        """Reorder one flattened ``(b_s * cur_beam_size, ...)`` visual tensor by beam."""
        trailing = tuple(tensor.shape[1:])
        expanded = tensor.view((self.b_s, cur_beam_size) + trailing)
        gathered = self._gather_beams(expanded, selected_beam)
        return gathered.view((self.b_s * self.beam_size,) + trailing)

    def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
        """Reorder the visual input (a tensor or a sequence of tensors) by beam.

        A sequence input is returned as a tuple of reordered tensors.
        """
        if isinstance(visual, torch.Tensor):
            return self._expand_visual_tensor(visual, cur_beam_size, selected_beam)
        return tuple(self._expand_visual_tensor(im, cur_beam_size, selected_beam) for im in visual)

    def apply(self, visual: utils.TensorOrSequence, out_size=1, return_logits=False, **kwargs):
        """Run beam search for ``max_len`` steps.

        Args:
            visual: model conditioning input, a tensor or a sequence of tensors;
                batch size and device are obtained through ``utils`` helpers.
            out_size: number of best hypotheses to return per batch element.
            return_logits: if True, also return the raw per-step logits of the
                returned hypotheses.
            **kwargs: forwarded to ``model.step`` and ``select``.

        Returns:
            ``(outputs, log_probs)`` or ``(outputs, log_probs, all_logits)``.
            ``outputs`` / ``log_probs`` have shape ``(b_s, out_size, max_len)``
            and ``all_logits`` has shape ``(b_s, out_size, max_len, vocab)``;
            the ``out_size`` dimension is squeezed away when ``out_size == 1``.
            Hypotheses are ordered by decreasing accumulated log-probability.
        """
        self.b_s = utils.get_batch_size(visual)
        self.device = utils.get_device(visual)
        self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
        self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
        self.log_probs = []
        self.selected_words = None
        if return_logits:
            self.all_logits = []

        outputs = []
        with self.model.statefulness(self.b_s):
            for t in range(self.max_len):
                visual, outputs = self.iter(t, visual, outputs, return_logits, **kwargs)

        # Sort result: rank final beams by accumulated log-probability.
        _, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
        order = sort_idxs[..., 0]  # (b_s, beam_size)

        outputs = self._gather_beams(torch.cat(outputs, -1), order)[:, :out_size]
        log_probs = self._gather_beams(torch.cat(self.log_probs, -1), order)[:, :out_size]
        if return_logits:
            all_logits = self._gather_beams(torch.cat(self.all_logits, 2), order)[:, :out_size]

        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)
            if return_logits:
                all_logits = all_logits.squeeze(1)

        if return_logits:
            return outputs, log_probs, all_logits
        return outputs, log_probs

    def select(self, t, candidate_logprob, **kwargs):
        """Pick the ``beam_size`` best candidates over all (beam, word) pairs.

        Returns:
            ``(selected_idx, selected_logprob)``, both ``(b_s, beam_size)``;
            ``selected_idx`` indexes the flattened ``cur_beam * vocab`` axis.
        """
        selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
        return selected_idx, selected_logprob

    def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_logits, **kwargs):
        """Perform decoding step ``t`` and reorder all per-beam state.

        Returns:
            ``(visual, outputs)``: the beam-reordered visual input and the list
            of per-step chosen words, each ``(b_s, beam_size, 1)``.
        """
        # Step 0 starts from a single (empty) hypothesis per batch element.
        cur_beam_size = 1 if t == 0 else self.beam_size

        word_logits = self.model.step(t, self.selected_words, visual, **kwargs)
        word_logits = word_logits.view(self.b_s, cur_beam_size, -1)
        word_logprob = torch.log_softmax(word_logits, dim=-1)
        candidate_logprob = self.seq_logprob + word_logprob

        # Mask sequence if it reaches EOS
        if t > 0:
            # Cast with the logits' dtype: ``visual`` may be a sequence of
            # tensors, which has no ``.dtype`` attribute.
            mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx)
            mask = mask.to(word_logprob.dtype).unsqueeze(-1)
            self.seq_mask = self.seq_mask * mask
            word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
            # A finished beam keeps its score only through word 0; all other
            # continuations get a very low score. A finite value is used
            # because -inf * 0 below would produce NaN.
            old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
            old_seq_logprob[:, :, 1:] = -999
            candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)

        selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
        vocab_size = candidate_logprob.shape[-1]
        # Split the flat (beam, word) index; indices are non-negative.
        selected_beam = torch.div(selected_idx, vocab_size, rounding_mode='floor')
        selected_words = selected_idx - selected_beam * vocab_size

        self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
        visual = self._expand_visual(visual, cur_beam_size, selected_beam)

        self.seq_logprob = selected_logprob.unsqueeze(-1)
        self.seq_mask = self._gather_beams(self.seq_mask, selected_beam)
        outputs = [self._gather_beams(o, selected_beam) for o in outputs]
        outputs.append(selected_words.unsqueeze(-1))

        if return_logits:
            # Keep logits aligned with outputs/log_probs: reorder the history
            # and take this step's logits from each new beam's parent beam
            # (at t == 0 this broadcasts the single initial beam).
            self.all_logits = [self._gather_beams(o, selected_beam) for o in self.all_logits]
            self.all_logits.append(self._gather_beams(word_logits, selected_beam).unsqueeze(2))

        this_word_logprob = self._gather_beams(word_logprob, selected_beam)
        this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
        self.log_probs = [self._gather_beams(o, selected_beam) for o in self.log_probs]
        self.log_probs.append(this_word_logprob)
        self.selected_words = selected_words.view(-1, 1)

        return visual, outputs