logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
c30e175ccf
  1. 83
      expansionnet_v2.py
  2. 1
      models/End_ExpansionNet_v2.py
  3. 37
      utils/args_utils.py
  4. 58
      utils/language_utils.py
  5. 22
      utils/masking.py
  6. 109
      utils/saving_utils.py

83
expansionnet_v2.py

@ -14,9 +14,12 @@
import sys import sys
import os import os
from pathlib import Path
import pathlib
import pickle
from argparse import Namespace
import torch import torch
import torchvision
from torchvision import transforms from torchvision import transforms
from transformers import GPT2Tokenizer from transformers import GPT2Tokenizer
@ -32,11 +35,36 @@ class ExpansionNetV2(NNOperator):
""" """
def __init__(self, model_name: str): def __init__(self, model_name: str):
super().__init__() super().__init__()
sys.path.append(str(Path(__file__).parent))
sys.path.append(str(pathlib.Path(__file__).parent))
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.language_utils import convert_vector_idx2word
self.convert_vector_idx2word = convert_vector_idx2word
sys.path.pop() sys.path.pop()
with open('demo_coco_tokens.pickle') as fw:
path = pathlib.Path(__file__).parent
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
coco_tokens = pickle.load(f)
self.coco_tokens = coco_tokens
img_size = 384
self.device = "cuda" if torch.cuda.is_available() else "cpu"
drop_args = Namespace(enc=0.0,
dec=0.0,
enc_input=0.0,
dec_input=0.0,
other=0.0)
drop_args = Namespace(enc=0.0,
dec=0.0,
enc_input=0.0,
dec_input=0.0,
other=0.0)
model_args = Namespace(model_dim=512,
N_enc=3,
N_dec=3,
dropout=0.0,
drop_args=drop_args)
max_seq_len = 74
beam_size = 5
self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
@ -51,5 +79,52 @@ class ExpansionNetV2(NNOperator):
num_exp_dec=16, num_exp_dec=16,
output_word2idx=coco_tokens['word2idx_dict'], output_word2idx=coco_tokens['word2idx_dict'],
output_idx2word=coco_tokens['idx2word_list'], output_idx2word=coco_tokens['idx2word_list'],
max_seq_len=args.max_seq_len, drop_args=model_args.drop_args,
max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu') rank='cpu')
self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)), torchvision.transforms.ToTensor()])
self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
self.beam_search_kwargs = {'beam_size': beam_size,
'beam_max_seq_len': max_seq_len,
'sample_or_max': 'max',
'how_many_outputs': 1,
'sos_idx': coco_tokens['word2idx_dict'][coco_tokens['sos_str']],
'eos_idx': coco_tokens['word2idx_dict'][coco_tokens['eos_str']]}
def _preprocess(self, img):
img = to_pil(img)
processed_img = self.transf_1(img)
processed_img = self.transf_2(processed_img)
processed_img = processed_img.to(self.device)
return processed_img
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
text = self._inference_from_image(data)
return text
def __call__(self, data):
if not isinstance(data, list):
data = [data]
else:
data = data
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = self._preprocess(img).unsqueeze(0)
with torch.no_grad():
pred, _ = self.model(enc_x=img,
enc_x_num_pads=[0],
mode='beam_search', **self.beam_search_kwargs)
pred = self.convert_vector_idx2word(pred[0][0], self.coco_tokens['idx2word_list'])[1:-1]
pred[-1] = pred[-1] + '.'
pred = ' '.join(pred).capitalize()
return pred

1
models/End_ExpansionNet_v2.py

@ -74,7 +74,6 @@ class End_ExpansionNet_v2(CaptioningModel):
self.check_required_attributes() self.check_required_attributes()
def forward_enc(self, enc_input, enc_input_num_pads): def forward_enc(self, enc_input, enc_input_num_pads):
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding" assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding"
x = self.swin_transf(enc_input) x = self.swin_transf(enc_input)
# --------------- Normale parte di Captioning --------------------------------- # --------------- Normale parte di Captioning ---------------------------------

37
utils/args_utils.py

@ -0,0 +1,37 @@
import argparse
# thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def str2list(v):
if '[' in v and ']' in v:
return list(map(int, v.strip('[]').split(',')))
else:
raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]')
def scheduler_type_choice(v):
if v == 'annealing' or v == 'custom_warmup_anneal':
return v
else:
raise argparse.ArgumentTypeError('Argument must be either '
'\'annealing\', '
'\'custom_warmup_anneal\'')
def optim_type_choice(v):
if v == 'adam' or v == 'radam':
return v
else:
raise argparse.ArgumentTypeError('Argument must be either \'adam\', '
'\'radam\'.')

58
utils/language_utils.py

@ -0,0 +1,58 @@
import re
def compute_num_pads(list_bboxes):
max_len = -1
for bboxes in list_bboxes:
num_bboxes = len(bboxes)
if num_bboxes > max_len:
max_len = num_bboxes
num_pad_vector = []
for bboxes in list_bboxes:
num_pad_vector.append(max_len - len(bboxes))
return num_pad_vector
def remove_punctuations(sentences):
punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"]
res_sentences_list = []
for i in range(len(sentences)):
res_sentence = []
for word in sentences[i].split(' '):
if word not in punctuations:
res_sentence.append(word)
res_sentences_list.append(' '.join(res_sentence))
return res_sentences_list
def lowercase_and_clean_trailing_spaces(sentences):
return [(sentences[i].lower()).rstrip() for i in range(len(sentences))]
def add_space_between_non_alphanumeric_symbols(sentences):
return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))]
def tokenize(list_sentences):
res_sentences_list = []
for i in range(len(list_sentences)):
sentence = list_sentences[i].split(' ')
while '' in sentence:
sentence.remove('')
res_sentences_list.append(sentence)
return res_sentences_list
def convert_vector_word2idx(sentence, word2idx_dict):
return [word2idx_dict[word] for word in sentence]
def convert_allsentences_word2idx(sentences, word2idx_dict):
return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))]
def convert_vector_idx2word(sentence, idx2word_list):
return [idx2word_list[idx] for idx in sentence]
def convert_allsentences_idx2word(sentences, idx2word_list):
return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))]

22
utils/masking.py

@ -0,0 +1,22 @@
import torch
def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
batch_size, output_seq_len, input_seq_len = mask_size
mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank)
for batch_idx in range(batch_size):
mask[batch_idx, :, (input_seq_len - pad_along_column_input[batch_idx]):] = 0
mask[batch_idx, (output_seq_len - pad_along_row_input[batch_idx]):, :] = 0
return mask
def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
batch_size, seq_len, seq_len = mask_size
mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8),
diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank)
for batch_idx in range(batch_size):
mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0
mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0
return mask

109
utils/saving_utils.py

@ -0,0 +1,109 @@
import os
import torch
from datetime import datetime
from torch.nn.parameter import Parameter
def load_most_recent_checkpoint(model,
optimizer=None,
scheduler=None,
data_loader=None,
rank=0,
save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
verbose=True):
ls_files = os.listdir(save_model_path)
most_recent_checkpoint_datetime = None
most_recent_checkpoint_filename = None
most_recent_checkpoint_info = 'no_additional_info'
for file_name in ls_files:
if file_name.startswith('checkpoint_'):
_, datetime_str, _, info, _ = file_name.split('_')
file_datetime = datetime.strptime(datetime_str, datetime_format)
if (most_recent_checkpoint_datetime is None) or \
(most_recent_checkpoint_datetime is not None and
file_datetime > most_recent_checkpoint_datetime):
most_recent_checkpoint_datetime = file_datetime
most_recent_checkpoint_filename = file_name
most_recent_checkpoint_info = info
if most_recent_checkpoint_filename is not None:
if verbose:
print("Loading: " + str(save_model_path + most_recent_checkpoint_filename))
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename,
map_location=map_location)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler is not None:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if data_loader is not None:
data_loader.load_state(checkpoint['data_loader_state_dict'])
return True, most_recent_checkpoint_info
else:
if verbose:
print("Loading: no checkpoint found in " + str(save_model_path))
return False, most_recent_checkpoint_info
def save_last_checkpoint(model,
optimizer,
scheduler,
data_loader,
save_model_path='./',
num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
additional_info='noinfo',
verbose=True):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'data_loader_state_dict': data_loader.save_state(),
}
ls_files = os.listdir(save_model_path)
oldest_checkpoint_datetime = None
oldest_checkpoint_filename = None
num_check_points = 0
for file_name in ls_files:
if file_name.startswith('checkpoint_'):
num_check_points += 1
_, datetime_str, _, _, _ = file_name.split('_')
file_datetime = datetime.strptime(datetime_str, datetime_format)
if (oldest_checkpoint_datetime is None) or \
(oldest_checkpoint_datetime is not None and file_datetime < oldest_checkpoint_datetime):
oldest_checkpoint_datetime = file_datetime
oldest_checkpoint_filename = file_name
if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints:
os.remove(save_model_path + oldest_checkpoint_filename)
new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
'_epoch' + str(data_loader.get_epoch_it()) + \
'it' + str(data_loader.get_batch_it()) + \
'bs' + str(data_loader.get_batch_size()) + \
'_' + str(additional_info) + '_.pth'
if verbose:
print("Saved to " + str(new_checkpoint_filename))
torch.save(checkpoint, save_model_path + new_checkpoint_filename)
def partially_load_state_dict(model, state_dict, verbose=False):
own_state = model.state_dict()
num_print = 5
count_print = 0
for name, param in state_dict.items():
if name not in own_state:
if verbose:
print("Not found: " + str(name))
continue
if isinstance(param, Parameter):
param = param.data
own_state[name].copy_(param)
if verbose:
if count_print < num_print:
print("Found: " + str(name))
count_print += 1
Loading…
Cancel
Save