logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
c30e175ccf
  1. 83
      expansionnet_v2.py
  2. 1
      models/End_ExpansionNet_v2.py
  3. 37
      utils/args_utils.py
  4. 58
      utils/language_utils.py
  5. 22
      utils/masking.py
  6. 109
      utils/saving_utils.py

83
expansionnet_v2.py

@ -14,9 +14,12 @@
import sys
import os
from pathlib import Path
import pathlib
import pickle
from argparse import Namespace
import torch
import torchvision
from torchvision import transforms
from transformers import GPT2Tokenizer
@ -32,11 +35,36 @@ class ExpansionNetV2(NNOperator):
"""
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(pathlib.Path(__file__).parent))
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.language_utils import convert_vector_idx2word
self.convert_vector_idx2word = convert_vector_idx2word
sys.path.pop()
with open('demo_coco_tokens.pickle') as fw:
path = pathlib.Path(__file__).parent
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
coco_tokens = pickle.load(f)
self.coco_tokens = coco_tokens
img_size = 384
self.device = "cuda" if torch.cuda.is_available() else "cpu"
drop_args = Namespace(enc=0.0,
dec=0.0,
enc_input=0.0,
dec_input=0.0,
other=0.0)
drop_args = Namespace(enc=0.0,
dec=0.0,
enc_input=0.0,
dec_input=0.0,
other=0.0)
model_args = Namespace(model_dim=512,
N_enc=3,
N_dec=3,
dropout=0.0,
drop_args=drop_args)
max_seq_len = 74
beam_size = 5
self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
@ -51,5 +79,52 @@ class ExpansionNetV2(NNOperator):
num_exp_dec=16,
output_word2idx=coco_tokens['word2idx_dict'],
output_idx2word=coco_tokens['idx2word_list'],
max_seq_len=args.max_seq_len, drop_args=model_args.drop_args,
max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu')
self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)), torchvision.transforms.ToTensor()])
self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
self.beam_search_kwargs = {'beam_size': beam_size,
'beam_max_seq_len': max_seq_len,
'sample_or_max': 'max',
'how_many_outputs': 1,
'sos_idx': coco_tokens['word2idx_dict'][coco_tokens['sos_str']],
'eos_idx': coco_tokens['word2idx_dict'][coco_tokens['eos_str']]}
def _preprocess(self, img):
img = to_pil(img)
processed_img = self.transf_1(img)
processed_img = self.transf_2(processed_img)
processed_img = processed_img.to(self.device)
return processed_img
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
text = self._inference_from_image(data)
return text
def __call__(self, data):
if not isinstance(data, list):
data = [data]
else:
data = data
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = self._preprocess(img).unsqueeze(0)
with torch.no_grad():
pred, _ = self.model(enc_x=img,
enc_x_num_pads=[0],
mode='beam_search', **self.beam_search_kwargs)
pred = self.convert_vector_idx2word(pred[0][0], self.coco_tokens['idx2word_list'])[1:-1]
pred[-1] = pred[-1] + '.'
pred = ' '.join(pred).capitalize()
return pred

1
models/End_ExpansionNet_v2.py

@ -74,7 +74,6 @@ class End_ExpansionNet_v2(CaptioningModel):
self.check_required_attributes()
def forward_enc(self, enc_input, enc_input_num_pads):
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding"
x = self.swin_transf(enc_input)
# --------------- Normale parte di Captioning ---------------------------------

37
utils/args_utils.py

@ -0,0 +1,37 @@
import argparse
# thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def str2list(v):
if '[' in v and ']' in v:
return list(map(int, v.strip('[]').split(',')))
else:
raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]')
def scheduler_type_choice(v):
if v == 'annealing' or v == 'custom_warmup_anneal':
return v
else:
raise argparse.ArgumentTypeError('Argument must be either '
'\'annealing\', '
'\'custom_warmup_anneal\'')
def optim_type_choice(v):
if v == 'adam' or v == 'radam':
return v
else:
raise argparse.ArgumentTypeError('Argument must be either \'adam\', '
'\'radam\'.')

58
utils/language_utils.py

@ -0,0 +1,58 @@
import re
def compute_num_pads(list_bboxes):
max_len = -1
for bboxes in list_bboxes:
num_bboxes = len(bboxes)
if num_bboxes > max_len:
max_len = num_bboxes
num_pad_vector = []
for bboxes in list_bboxes:
num_pad_vector.append(max_len - len(bboxes))
return num_pad_vector
def remove_punctuations(sentences):
punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"]
res_sentences_list = []
for i in range(len(sentences)):
res_sentence = []
for word in sentences[i].split(' '):
if word not in punctuations:
res_sentence.append(word)
res_sentences_list.append(' '.join(res_sentence))
return res_sentences_list
def lowercase_and_clean_trailing_spaces(sentences):
return [(sentences[i].lower()).rstrip() for i in range(len(sentences))]
def add_space_between_non_alphanumeric_symbols(sentences):
return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))]
def tokenize(list_sentences):
res_sentences_list = []
for i in range(len(list_sentences)):
sentence = list_sentences[i].split(' ')
while '' in sentence:
sentence.remove('')
res_sentences_list.append(sentence)
return res_sentences_list
def convert_vector_word2idx(sentence, word2idx_dict):
return [word2idx_dict[word] for word in sentence]
def convert_allsentences_word2idx(sentences, word2idx_dict):
return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))]
def convert_vector_idx2word(sentence, idx2word_list):
return [idx2word_list[idx] for idx in sentence]
def convert_allsentences_idx2word(sentences, idx2word_list):
return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))]

22
utils/masking.py

@ -0,0 +1,22 @@
import torch
def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
batch_size, output_seq_len, input_seq_len = mask_size
mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank)
for batch_idx in range(batch_size):
mask[batch_idx, :, (input_seq_len - pad_along_column_input[batch_idx]):] = 0
mask[batch_idx, (output_seq_len - pad_along_row_input[batch_idx]):, :] = 0
return mask
def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
batch_size, seq_len, seq_len = mask_size
mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8),
diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank)
for batch_idx in range(batch_size):
mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0
mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0
return mask

109
utils/saving_utils.py

@ -0,0 +1,109 @@
import os
import torch
from datetime import datetime
from torch.nn.parameter import Parameter
def load_most_recent_checkpoint(model,
optimizer=None,
scheduler=None,
data_loader=None,
rank=0,
save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
verbose=True):
ls_files = os.listdir(save_model_path)
most_recent_checkpoint_datetime = None
most_recent_checkpoint_filename = None
most_recent_checkpoint_info = 'no_additional_info'
for file_name in ls_files:
if file_name.startswith('checkpoint_'):
_, datetime_str, _, info, _ = file_name.split('_')
file_datetime = datetime.strptime(datetime_str, datetime_format)
if (most_recent_checkpoint_datetime is None) or \
(most_recent_checkpoint_datetime is not None and
file_datetime > most_recent_checkpoint_datetime):
most_recent_checkpoint_datetime = file_datetime
most_recent_checkpoint_filename = file_name
most_recent_checkpoint_info = info
if most_recent_checkpoint_filename is not None:
if verbose:
print("Loading: " + str(save_model_path + most_recent_checkpoint_filename))
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename,
map_location=map_location)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler is not None:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if data_loader is not None:
data_loader.load_state(checkpoint['data_loader_state_dict'])
return True, most_recent_checkpoint_info
else:
if verbose:
print("Loading: no checkpoint found in " + str(save_model_path))
return False, most_recent_checkpoint_info
def save_last_checkpoint(model,
optimizer,
scheduler,
data_loader,
save_model_path='./',
num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
additional_info='noinfo',
verbose=True):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'data_loader_state_dict': data_loader.save_state(),
}
ls_files = os.listdir(save_model_path)
oldest_checkpoint_datetime = None
oldest_checkpoint_filename = None
num_check_points = 0
for file_name in ls_files:
if file_name.startswith('checkpoint_'):
num_check_points += 1
_, datetime_str, _, _, _ = file_name.split('_')
file_datetime = datetime.strptime(datetime_str, datetime_format)
if (oldest_checkpoint_datetime is None) or \
(oldest_checkpoint_datetime is not None and file_datetime < oldest_checkpoint_datetime):
oldest_checkpoint_datetime = file_datetime
oldest_checkpoint_filename = file_name
if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints:
os.remove(save_model_path + oldest_checkpoint_filename)
new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
'_epoch' + str(data_loader.get_epoch_it()) + \
'it' + str(data_loader.get_batch_it()) + \
'bs' + str(data_loader.get_batch_size()) + \
'_' + str(additional_info) + '_.pth'
if verbose:
print("Saved to " + str(new_checkpoint_filename))
torch.save(checkpoint, save_model_path + new_checkpoint_filename)
def partially_load_state_dict(model, state_dict, verbose=False):
own_state = model.state_dict()
num_print = 5
count_print = 0
for name, param in state_dict.items():
if name not in own_state:
if verbose:
print("Not found: " + str(name))
continue
if isinstance(param, Parameter):
param = param.data
own_state[name].copy_(param)
if verbose:
if count_print < num_print:
print("Found: " + str(name))
count_print += 1
Loading…
Cancel
Save