From c30e175ccf10c4512582dc051755dd89af3f4516 Mon Sep 17 00:00:00 2001 From: wxywb Date: Tue, 18 Oct 2022 19:58:43 +0800 Subject: [PATCH] update the operator. Signed-off-by: wxywb --- expansionnet_v2.py | 83 ++++++++++++++++++++++++-- models/End_ExpansionNet_v2.py | 1 - utils/args_utils.py | 37 ++++++++++++ utils/language_utils.py | 58 ++++++++++++++++++ utils/masking.py | 22 +++++++ utils/saving_utils.py | 109 ++++++++++++++++++++++++++++++++++ 6 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 utils/args_utils.py create mode 100644 utils/language_utils.py create mode 100644 utils/masking.py create mode 100644 utils/saving_utils.py diff --git a/expansionnet_v2.py b/expansionnet_v2.py index cbaceb2..2020cf3 100644 --- a/expansionnet_v2.py +++ b/expansionnet_v2.py @@ -14,9 +14,12 @@ import sys import os -from pathlib import Path +import pathlib +import pickle +from argparse import Namespace import torch +import torchvision from torchvision import transforms from transformers import GPT2Tokenizer @@ -32,11 +35,36 @@ class ExpansionNetV2(NNOperator): """ def __init__(self, model_name: str): super().__init__() - sys.path.append(str(Path(__file__).parent)) + sys.path.append(str(pathlib.Path(__file__).parent)) from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 + from utils.language_utils import convert_vector_idx2word + self.convert_vector_idx2word = convert_vector_idx2word sys.path.pop() - with open('demo_coco_tokens.pickle') as fw: + path = pathlib.Path(__file__).parent + with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f: + coco_tokens = pickle.load(f) + self.coco_tokens = coco_tokens + img_size = 384 + self.device = "cuda" if torch.cuda.is_available() else "cpu" + drop_args = Namespace(enc=0.0, + dec=0.0, + enc_input=0.0, + dec_input=0.0, + other=0.0) + + drop_args = Namespace(enc=0.0, + dec=0.0, + enc_input=0.0, + dec_input=0.0, + other=0.0) + model_args = Namespace(model_dim=512, + N_enc=3, + N_dec=3, + dropout=0.0, + drop_args=drop_args) + max_seq_len = 74 + beam_size = 5 self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, @@ -51,5 +79,52 @@ class ExpansionNetV2(NNOperator): num_exp_dec=16, output_word2idx=coco_tokens['word2idx_dict'], output_idx2word=coco_tokens['idx2word_list'], - max_seq_len=args.max_seq_len, drop_args=model_args.drop_args, + max_seq_len=max_seq_len, drop_args=model_args.drop_args, rank='cpu') + + self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)), torchvision.transforms.ToTensor()]) + self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) + self.beam_search_kwargs = {'beam_size': beam_size, + 'beam_max_seq_len': max_seq_len, + 'sample_or_max': 'max', + 'how_many_outputs': 1, + 'sos_idx': coco_tokens['word2idx_dict'][coco_tokens['sos_str']], + 'eos_idx': coco_tokens['word2idx_dict'][coco_tokens['eos_str']]} + def _preprocess(self, img): + img = to_pil(img) + processed_img = self.transf_1(img) + processed_img = self.transf_2(processed_img) + processed_img = processed_img.to(self.device) + return processed_img + + @arg(1, to_image_color('RGB')) + def inference_single_data(self, data): + text = self._inference_from_image(data) + return text + + def __call__(self, data): + if not isinstance(data, list): + data = [data] + else: + data = data + results = [] + for single_data in data: + result = self.inference_single_data(single_data) + results.append(result) + if len(data) == 1: + return results[0] + else: + return results + + @arg(1, to_image_color('RGB')) + def _inference_from_image(self, img): + img = self._preprocess(img).unsqueeze(0) + with torch.no_grad(): + pred, _ = self.model(enc_x=img, + enc_x_num_pads=[0], + mode='beam_search', **self.beam_search_kwargs) + pred = self.convert_vector_idx2word(pred[0][0], self.coco_tokens['idx2word_list'])[1:-1] + pred[-1] = pred[-1] + '.' + pred = ' '.join(pred).capitalize() + + return pred diff --git a/models/End_ExpansionNet_v2.py b/models/End_ExpansionNet_v2.py index 712f760..90b519b 100644 --- a/models/End_ExpansionNet_v2.py +++ b/models/End_ExpansionNet_v2.py @@ -74,7 +74,6 @@ class End_ExpansionNet_v2(CaptioningModel): self.check_required_attributes() def forward_enc(self, enc_input, enc_input_num_pads): - assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding" x = self.swin_transf(enc_input) # --------------- Normale parte di Captioning --------------------------------- diff --git a/utils/args_utils.py b/utils/args_utils.py new file mode 100644 index 0000000..cd7b174 --- /dev/null +++ b/utils/args_utils.py @@ -0,0 +1,37 @@ +import argparse + + +# thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2list(v): + if '[' in v and ']' in v: + return list(map(int, v.strip('[]').split(','))) + else: + raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]') + + +def scheduler_type_choice(v): + if v == 'annealing' or v == 'custom_warmup_anneal': + return v + else: + raise argparse.ArgumentTypeError('Argument must be either ' + '\'annealing\', ' + '\'custom_warmup_anneal\'') + + +def optim_type_choice(v): + if v == 'adam' or v == 'radam': + return v + else: + raise argparse.ArgumentTypeError('Argument must be either \'adam\', ' + '\'radam\'.') \ No newline at end of file diff --git a/utils/language_utils.py b/utils/language_utils.py new file mode 100644 index 0000000..a4415e6 --- /dev/null +++ b/utils/language_utils.py @@ -0,0 +1,58 @@ +import re + + +def compute_num_pads(list_bboxes): + max_len = -1 + for bboxes in list_bboxes: + num_bboxes = len(bboxes) + if num_bboxes > max_len: + max_len = num_bboxes + num_pad_vector = [] + for bboxes in list_bboxes: + num_pad_vector.append(max_len - len(bboxes)) + return num_pad_vector + + +def remove_punctuations(sentences): + punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"] + res_sentences_list = [] + for i in range(len(sentences)): + res_sentence = [] + for word in sentences[i].split(' '): + if word not in punctuations: + res_sentence.append(word) + res_sentences_list.append(' '.join(res_sentence)) + return res_sentences_list + + +def lowercase_and_clean_trailing_spaces(sentences): + return [(sentences[i].lower()).rstrip() for i in range(len(sentences))] + + +def add_space_between_non_alphanumeric_symbols(sentences): + return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))] + + +def tokenize(list_sentences): + res_sentences_list = [] + for i in range(len(list_sentences)): + sentence = list_sentences[i].split(' ') + while '' in sentence: + sentence.remove('') + res_sentences_list.append(sentence) + return res_sentences_list + +def convert_vector_word2idx(sentence, word2idx_dict): + return [word2idx_dict[word] for word in sentence] + + +def convert_allsentences_word2idx(sentences, word2idx_dict): + return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))] + + +def convert_vector_idx2word(sentence, idx2word_list): + return [idx2word_list[idx] for idx in sentence] + + +def convert_allsentences_idx2word(sentences, idx2word_list): + return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))] \ No newline at end of file diff --git a/utils/masking.py b/utils/masking.py new file mode 100644 index 0000000..2d9cf44 --- /dev/null +++ b/utils/masking.py @@ -0,0 +1,22 @@ + +import torch + + +def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank): + batch_size, output_seq_len, input_seq_len = mask_size + mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank) + + for batch_idx in range(batch_size): + mask[batch_idx, :, (input_seq_len - pad_along_column_input[batch_idx]):] = 0 + mask[batch_idx, (output_seq_len - pad_along_row_input[batch_idx]):, :] = 0 + return mask + + +def create_no_peak_and_pad_mask(mask_size, num_pads, rank): + batch_size, seq_len, seq_len = mask_size + mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8), + diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank) + for batch_idx in range(batch_size): + mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0 + mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0 + return mask diff --git a/utils/saving_utils.py b/utils/saving_utils.py new file mode 100644 index 0000000..5eafdc4 --- /dev/null +++ b/utils/saving_utils.py @@ -0,0 +1,109 @@ + +import os +import torch +from datetime import datetime + +from torch.nn.parameter import Parameter + +def load_most_recent_checkpoint(model, + optimizer=None, + scheduler=None, + data_loader=None, + rank=0, + save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S', + verbose=True): + ls_files = os.listdir(save_model_path) + most_recent_checkpoint_datetime = None + most_recent_checkpoint_filename = None + most_recent_checkpoint_info = 'no_additional_info' + for file_name in ls_files: + if file_name.startswith('checkpoint_'): + _, datetime_str, _, info, _ = file_name.split('_') + file_datetime = datetime.strptime(datetime_str, datetime_format) + if (most_recent_checkpoint_datetime is None) or \ + (most_recent_checkpoint_datetime is not None and + file_datetime > most_recent_checkpoint_datetime): + most_recent_checkpoint_datetime = file_datetime + most_recent_checkpoint_filename = file_name + most_recent_checkpoint_info = info + + if most_recent_checkpoint_filename is not None: + if verbose: + print("Loading: " + str(save_model_path + most_recent_checkpoint_filename)) + map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} + checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename, + map_location=map_location) + model.load_state_dict(checkpoint['model_state_dict']) + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if scheduler is not None: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + if data_loader is not None: + data_loader.load_state(checkpoint['data_loader_state_dict']) + return True, most_recent_checkpoint_info + else: + if verbose: + print("Loading: no checkpoint found in " + str(save_model_path)) + return False, most_recent_checkpoint_info + + +def save_last_checkpoint(model, + optimizer, + scheduler, + data_loader, + save_model_path='./', + num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S', + additional_info='noinfo', + verbose=True): + + checkpoint = { + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'data_loader_state_dict': data_loader.save_state(), + } + + ls_files = os.listdir(save_model_path) + oldest_checkpoint_datetime = None + oldest_checkpoint_filename = None + num_check_points = 0 + for file_name in ls_files: + if file_name.startswith('checkpoint_'): + num_check_points += 1 + _, datetime_str, _, _, _ = file_name.split('_') + file_datetime = datetime.strptime(datetime_str, datetime_format) + if (oldest_checkpoint_datetime is None) or \ + (oldest_checkpoint_datetime is not None and file_datetime < oldest_checkpoint_datetime): + oldest_checkpoint_datetime = file_datetime + oldest_checkpoint_filename = file_name + + if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints: + os.remove(save_model_path + oldest_checkpoint_filename) + + new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \ + '_epoch' + str(data_loader.get_epoch_it()) + \ + 'it' + str(data_loader.get_batch_it()) + \ + 'bs' + str(data_loader.get_batch_size()) + \ + '_' + str(additional_info) + '_.pth' + if verbose: + print("Saved to " + str(new_checkpoint_filename)) + torch.save(checkpoint, save_model_path + new_checkpoint_filename) + + +def partially_load_state_dict(model, state_dict, verbose=False): + own_state = model.state_dict() + num_print = 5 + count_print = 0 + for name, param in state_dict.items(): + if name not in own_state: + if verbose: + print("Not found: " + str(name)) + continue + if isinstance(param, Parameter): + param = param.data + own_state[name].copy_(param) + if verbose: + if count_print < num_print: + print("Found: " + str(name)) + count_print += 1 +