expansionnet-v2
copied · wxywb · 2 years ago
6 changed files with 305 additions and 5 deletions
@@ -0,0 +1,37 @@
import argparse


# thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def str2list(v):
    if '[' in v and ']' in v:
        return list(map(int, v.strip('[]').split(',')))
    else:
        raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]')


def scheduler_type_choice(v):
    if v == 'annealing' or v == 'custom_warmup_anneal':
        return v
    else:
        raise argparse.ArgumentTypeError("Argument must be either 'annealing' or 'custom_warmup_anneal'.")


def optim_type_choice(v):
    if v == 'adam' or v == 'radam':
        return v
    else:
        raise argparse.ArgumentTypeError("Argument must be either 'adam' or 'radam'.")
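A minimal sketch of how these converters plug into an argparse parser as type callbacks; the flag names below are invented for illustration, not taken from the project:

    parser = argparse.ArgumentParser()
    parser.add_argument('--use_reinforce', type=str2bool, default=False)
    parser.add_argument('--anneal_epochs', type=str2list, default=[1, 2, 3])
    parser.add_argument('--sched_type', type=scheduler_type_choice, default='annealing')
    parser.add_argument('--optim_type', type=optim_type_choice, default='adam')
    args = parser.parse_args(['--use_reinforce', 'yes', '--anneal_epochs', '[5,10,15]'])
    print(args.use_reinforce, args.anneal_epochs)   # True [5, 10, 15]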
@@ -0,0 +1,58 @@
import re


def compute_num_pads(list_bboxes):
    # how many padding elements each bbox list needs in order to reach
    # the length of the longest list in the batch
    max_len = -1
    for bboxes in list_bboxes:
        num_bboxes = len(bboxes)
        if num_bboxes > max_len:
            max_len = num_bboxes
    num_pad_vector = []
    for bboxes in list_bboxes:
        num_pad_vector.append(max_len - len(bboxes))
    return num_pad_vector


def remove_punctuations(sentences):
    punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"]
    res_sentences_list = []
    for i in range(len(sentences)):
        res_sentence = []
        for word in sentences[i].split(' '):
            if word not in punctuations:
                res_sentence.append(word)
        res_sentences_list.append(' '.join(res_sentence))
    return res_sentences_list


def lowercase_and_clean_trailing_spaces(sentences):
    return [(sentences[i].lower()).rstrip() for i in range(len(sentences))]


def add_space_between_non_alphanumeric_symbols(sentences):
    return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))]


def tokenize(list_sentences):
    res_sentences_list = []
    for i in range(len(list_sentences)):
        sentence = list_sentences[i].split(' ')
        while '' in sentence:   # drop empty tokens left by repeated spaces
            sentence.remove('')
        res_sentences_list.append(sentence)
    return res_sentences_list


def convert_vector_word2idx(sentence, word2idx_dict):
    return [word2idx_dict[word] for word in sentence]


def convert_allsentences_word2idx(sentences, word2idx_dict):
    return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))]


def convert_vector_idx2word(sentence, idx2word_list):
    return [idx2word_list[idx] for idx in sentence]


def convert_allsentences_idx2word(sentences, idx2word_list):
    return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))]
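A small end-to-end sketch of the preprocessing chain; the sample sentence, the ordering of the steps, and the toy vocabulary are assumptions made purely for illustration:

    sentences = ["A man, riding a horse!  "]
    clean = remove_punctuations(
        add_space_between_non_alphanumeric_symbols(
            lowercase_and_clean_trailing_spaces(sentences)))
    tokens = tokenize(clean)        # [['a', 'man', 'riding', 'a', 'horse']]
    word2idx = {'a': 0, 'man': 1, 'riding': 2, 'horse': 3}
    print(convert_allsentences_word2idx(tokens, word2idx))   # [[0, 1, 2, 0, 3]]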
@@ -0,0 +1,22 @@

import torch


def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
    batch_size, output_seq_len, input_seq_len = mask_size
    mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank)

    for batch_idx in range(batch_size):
        # zero out the padded tail along both the input (columns) and output (rows) axes
        mask[batch_idx, :, (input_seq_len - pad_along_column_input[batch_idx]):] = 0
        mask[batch_idx, (output_seq_len - pad_along_row_input[batch_idx]):, :] = 0
    return mask


def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
    batch_size, seq_len, _ = mask_size   # mask is square along the last two dims
    # lower-triangular mask: position i may only attend to positions <= i
    mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8),
                      diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank)
    for batch_idx in range(batch_size):
        mask[batch_idx, :, (seq_len - num_pads[batch_idx]):] = 0
        mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0
    return mask
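A quick sketch of the combined causal/padding mask for a toy batch; the sizes and pad counts are invented, and rank can be any torch device (here 'cpu'):

    mask = create_no_peak_and_pad_mask((2, 5, 5), num_pads=[0, 2], rank='cpu')
    print(mask[1])
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0],
    #         [0, 0, 0, 0, 0],
    #         [0, 0, 0, 0, 0]], dtype=torch.int8)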
@@ -0,0 +1,109 @@

import os
import torch
from datetime import datetime

from torch.nn.parameter import Parameter


def load_most_recent_checkpoint(model,
                                optimizer=None,
                                scheduler=None,
                                data_loader=None,
                                rank=0,
                                save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
                                verbose=True):
    ls_files = os.listdir(save_model_path)
    most_recent_checkpoint_datetime = None
    most_recent_checkpoint_filename = None
    most_recent_checkpoint_info = 'no_additional_info'
    # file names look like: checkpoint_<datetime>_epoch<E>it<I>bs<B>_<info>_.pth
    for file_name in ls_files:
        if file_name.startswith('checkpoint_'):
            _, datetime_str, _, info, _ = file_name.split('_')
            file_datetime = datetime.strptime(datetime_str, datetime_format)
            if (most_recent_checkpoint_datetime is None) or \
                    (file_datetime > most_recent_checkpoint_datetime):
                most_recent_checkpoint_datetime = file_datetime
                most_recent_checkpoint_filename = file_name
                most_recent_checkpoint_info = info

    if most_recent_checkpoint_filename is not None:
        if verbose:
            print("Loading: " + str(save_model_path + most_recent_checkpoint_filename))
        # remap tensors saved from cuda:0 onto this process's device
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename,
                                map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if data_loader is not None:
            data_loader.load_state(checkpoint['data_loader_state_dict'])
        return True, most_recent_checkpoint_info
    else:
        if verbose:
            print("Loading: no checkpoint found in " + str(save_model_path))
        return False, most_recent_checkpoint_info
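A resume-from-checkpoint sketch; the model, optimizer, scheduler, and the './checkpoints/' directory (note the trailing slash, since paths are concatenated directly) are all placeholders:

    net = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(net.parameters())
    sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1)
    found, info = load_most_recent_checkpoint(net, optimizer=optim, scheduler=sched,
                                              save_model_path='./checkpoints/')
    if found:
        print('Resumed, additional info: ' + info)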


def save_last_checkpoint(model,
                         optimizer,
                         scheduler,
                         data_loader,
                         save_model_path='./',
                         num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
                         additional_info='noinfo',
                         verbose=True):

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'data_loader_state_dict': data_loader.save_state(),
    }

    # keep at most num_max_checkpoints on disk: locate the oldest and evict it
    ls_files = os.listdir(save_model_path)
    oldest_checkpoint_datetime = None
    oldest_checkpoint_filename = None
    num_check_points = 0
    for file_name in ls_files:
        if file_name.startswith('checkpoint_'):
            num_check_points += 1
            _, datetime_str, _, _, _ = file_name.split('_')
            file_datetime = datetime.strptime(datetime_str, datetime_format)
            if (oldest_checkpoint_datetime is None) or \
                    (file_datetime < oldest_checkpoint_datetime):
                oldest_checkpoint_datetime = file_datetime
                oldest_checkpoint_filename = file_name

    if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints:
        os.remove(save_model_path + oldest_checkpoint_filename)

    new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
                              '_epoch' + str(data_loader.get_epoch_it()) + \
                              'it' + str(data_loader.get_batch_it()) + \
                              'bs' + str(data_loader.get_batch_size()) + \
                              '_' + str(additional_info) + '_.pth'
    if verbose:
        print("Saved to " + str(new_checkpoint_filename))
    torch.save(checkpoint, save_model_path + new_checkpoint_filename)
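Saving follows the same file-name convention. Reusing net, optim, and sched from the sketch above, the stub below only mimics the four methods the function expects from the project's custom data loader (save_state, get_epoch_it, get_batch_it, get_batch_size) and is purely illustrative:

    class LoaderStub:
        def save_state(self): return {}
        def get_epoch_it(self): return 3
        def get_batch_it(self): return 120
        def get_batch_size(self): return 16

    save_last_checkpoint(net, optim, sched, LoaderStub(),
                         save_model_path='./checkpoints/', additional_info='xe')
    # writes e.g. checkpoint_<now>_epoch3it120bs16_xe_.pth and evicts
    # the oldest file once num_max_checkpoints already exist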


def partially_load_state_dict(model, state_dict, verbose=False):
    own_state = model.state_dict()
    num_print = 5
    count_print = 0
    for name, param in state_dict.items():
        if name not in own_state:
            if verbose:
                print("Not found: " + str(name))
            continue
        if isinstance(param, Parameter):
            # unwrap nn.Parameter so copy_ receives a plain tensor
            param = param.data
        own_state[name].copy_(param)
        if verbose:
            # report only the first few matches to keep the log short
            if count_print < num_print:
                print("Found: " + str(name))
                count_print += 1
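Finally, a partial-loading sketch for warm-starting one model from another; both modules here are invented stand-ins, and only parameters whose names match (and whose shapes agree) are copied:

    src = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))   # pretrained
    dst = torch.nn.Sequential(torch.nn.Linear(4, 8))                          # backbone only
    partially_load_state_dict(dst, src.state_dict(), verbose=True)
    # '0.weight' and '0.bias' are copied; '1.weight' and '1.bias' print "Not found" and are skipped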