expansionnet-v2
copied
wxywb
2 years ago
6 changed files with 305 additions and 5 deletions
@@ -0,0 +1,37 @@
import argparse


# thanks Maxim from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def str2bool(v):
    # accept real booleans as well as the usual textual spellings from the command line
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def str2list(v):
    # parse a bracketed list of integers, e.g. "[8,16,32]" -> [8, 16, 32]
    if '[' in v and ']' in v:
        return list(map(int, v.strip('[]').split(',')))
    else:
        raise argparse.ArgumentTypeError('Input expected in the form [b1,b2,b3,...]')


def scheduler_type_choice(v):
    if v == 'annealing' or v == 'custom_warmup_anneal':
        return v
    else:
        raise argparse.ArgumentTypeError('Argument must be either '
                                         '\'annealing\' or \'custom_warmup_anneal\'.')


def optim_type_choice(v):
    if v == 'adam' or v == 'radam':
        return v
    else:
        raise argparse.ArgumentTypeError('Argument must be either '
                                         '\'adam\' or \'radam\'.')
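
For context, a minimal usage sketch showing how these type functions plug into an argparse.ArgumentParser; the flag names below are hypothetical, not taken from this diff:

import argparse

parser = argparse.ArgumentParser(description='training script sketch')
parser.add_argument('--is_end_to_end', type=str2bool, default=True)          # hypothetical flag
parser.add_argument('--anneal_epochs', type=str2list, default=[8, 10, 12])   # hypothetical flag
parser.add_argument('--sched_type', type=scheduler_type_choice, default='annealing')
parser.add_argument('--optim_type', type=optim_type_choice, default='adam')

args = parser.parse_args(['--is_end_to_end', 'no', '--anneal_epochs', '[3,6,9]'])
print(args.is_end_to_end, args.anneal_epochs)  # False [3, 6, 9]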
@@ -0,0 +1,58 @@
import re


def compute_num_pads(list_bboxes):
    # number of padding entries needed to bring each bbox list up to the batch maximum
    max_len = -1
    for bboxes in list_bboxes:
        num_bboxes = len(bboxes)
        if num_bboxes > max_len:
            max_len = num_bboxes
    num_pad_vector = []
    for bboxes in list_bboxes:
        num_pad_vector.append(max_len - len(bboxes))
    return num_pad_vector


def remove_punctuations(sentences):
    # drop standalone punctuation tokens from space-separated sentences
    punctuations = ["''", "'", "``", "`", ".", "?", "!", ",", ":", "-", "--", "...", ";"]
    res_sentences_list = []
    for i in range(len(sentences)):
        res_sentence = []
        for word in sentences[i].split(' '):
            if word not in punctuations:
                res_sentence.append(word)
        res_sentences_list.append(' '.join(res_sentence))
    return res_sentences_list


def lowercase_and_clean_trailing_spaces(sentences):
    return [(sentences[i].lower()).rstrip() for i in range(len(sentences))]


def add_space_between_non_alphanumeric_symbols(sentences):
    # surround every non-alphanumeric character with spaces so punctuation becomes its own token
    return [re.sub(r'([^\w0-9])', r" \1 ", sentences[i]) for i in range(len(sentences))]


def tokenize(list_sentences):
    # split on spaces and drop the empty strings produced by repeated spaces
    res_sentences_list = []
    for i in range(len(list_sentences)):
        sentence = list_sentences[i].split(' ')
        while '' in sentence:
            sentence.remove('')
        res_sentences_list.append(sentence)
    return res_sentences_list


def convert_vector_word2idx(sentence, word2idx_dict):
    return [word2idx_dict[word] for word in sentence]


def convert_allsentences_word2idx(sentences, word2idx_dict):
    return [convert_vector_word2idx(sentences[i], word2idx_dict) for i in range(len(sentences))]


def convert_vector_idx2word(sentence, idx2word_list):
    return [idx2word_list[idx] for idx in sentence]


def convert_allsentences_idx2word(sentences, idx2word_list):
    return [convert_vector_idx2word(sentences[i], idx2word_list) for i in range(len(sentences))]
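
A sketch of the preprocessing pipeline these helpers suggest; the ordering and the toy vocabulary are assumptions for illustration, not taken from this diff:

captions = ["A dog runs.", "Two dogs play!"]

captions = lowercase_and_clean_trailing_spaces(captions)
captions = add_space_between_non_alphanumeric_symbols(captions)  # punctuation becomes separate tokens
captions = remove_punctuations(captions)                         # ".", "!" are dropped
tokens = tokenize(captions)                                      # [['a', 'dog', 'runs'], ['two', 'dogs', 'play']]

word2idx = {'a': 0, 'dog': 1, 'runs': 2, 'two': 3, 'dogs': 4, 'play': 5}  # toy vocabulary
encoded = convert_allsentences_word2idx(tokens, word2idx)
print(encoded)  # [[0, 1, 2], [3, 4, 5]]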
@@ -0,0 +1,22 @@
import torch


def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
    batch_size, output_seq_len, input_seq_len = mask_size
    mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8).to(rank)

    for batch_idx in range(batch_size):
        # zero out the padded tail along both the key (column) and query (row) axes
        mask[batch_idx, :, (input_seq_len - pad_along_column_input[batch_idx]):] = 0
        mask[batch_idx, (output_seq_len - pad_along_row_input[batch_idx]):, :] = 0
    return mask


def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
    batch_size, seq_len, _ = mask_size  # the mask is square: rows == columns
    # lower-triangular matrix forbids attending to future positions
    mask = torch.tril(torch.ones(size=(seq_len, seq_len), dtype=torch.int8),
                      diagonal=0).unsqueeze(0).repeat(batch_size, 1, 1).to(rank)
    for batch_idx in range(batch_size):
        mask[batch_idx, :, seq_len - num_pads[batch_idx]:] = 0
        mask[batch_idx, (seq_len - num_pads[batch_idx]):, :] = 0
    return mask
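
A toy check of both masks on CPU; the shapes and pad counts below are made up for illustration:

# Batch of 2 decoder sequences of length 4; the second one has 1 padded position.
causal = create_no_peak_and_pad_mask(mask_size=(2, 4, 4), num_pads=[0, 1], rank='cpu')
print(causal[1])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [0, 0, 0, 0]], dtype=torch.int8)

# Cross-attention grid of 3 queries over 5 keys; sample 1 pads 1 query row and 2 key columns.
pad_mask = create_pad_mask(mask_size=(2, 3, 5),
                           pad_along_row_input=[0, 1],
                           pad_along_column_input=[0, 2],
                           rank='cpu')
print(pad_mask[1])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 0, 0]], dtype=torch.int8)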
@@ -0,0 +1,109 @@
import os
import torch
from datetime import datetime

from torch.nn.parameter import Parameter


def load_most_recent_checkpoint(model,
                                optimizer=None,
                                scheduler=None,
                                data_loader=None,
                                rank=0,
                                save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
                                verbose=True):
    ls_files = os.listdir(save_model_path)
    most_recent_checkpoint_datetime = None
    most_recent_checkpoint_filename = None
    most_recent_checkpoint_info = 'no_additional_info'
    for file_name in ls_files:
        if file_name.startswith('checkpoint_'):
            # file names look like: checkpoint_<datetime>_<epoch/it/bs>_<info>_.pth
            _, datetime_str, _, info, _ = file_name.split('_')
            file_datetime = datetime.strptime(datetime_str, datetime_format)
            if (most_recent_checkpoint_datetime is None) or \
                    (file_datetime > most_recent_checkpoint_datetime):
                most_recent_checkpoint_datetime = file_datetime
                most_recent_checkpoint_filename = file_name
                most_recent_checkpoint_info = info

    if most_recent_checkpoint_filename is not None:
        if verbose:
            print("Loading: " + str(save_model_path + most_recent_checkpoint_filename))
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(save_model_path + most_recent_checkpoint_filename,
                                map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if data_loader is not None:
            data_loader.load_state(checkpoint['data_loader_state_dict'])
        return True, most_recent_checkpoint_info
    else:
        if verbose:
            print("Loading: no checkpoint found in " + str(save_model_path))
        return False, most_recent_checkpoint_info


def save_last_checkpoint(model,
                         optimizer,
                         scheduler,
                         data_loader,
                         save_model_path='./',
                         num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
                         additional_info='noinfo',
                         verbose=True):

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'data_loader_state_dict': data_loader.save_state(),
    }

    # keep at most num_max_checkpoints files: delete the oldest once the limit is reached
    ls_files = os.listdir(save_model_path)
    oldest_checkpoint_datetime = None
    oldest_checkpoint_filename = None
    num_check_points = 0
    for file_name in ls_files:
        if file_name.startswith('checkpoint_'):
            num_check_points += 1
            _, datetime_str, _, _, _ = file_name.split('_')
            file_datetime = datetime.strptime(datetime_str, datetime_format)
            if (oldest_checkpoint_datetime is None) or \
                    (file_datetime < oldest_checkpoint_datetime):
                oldest_checkpoint_datetime = file_datetime
                oldest_checkpoint_filename = file_name

    if oldest_checkpoint_filename is not None and num_check_points == num_max_checkpoints:
        os.remove(save_model_path + oldest_checkpoint_filename)

    new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
                              '_epoch' + str(data_loader.get_epoch_it()) + \
                              'it' + str(data_loader.get_batch_it()) + \
                              'bs' + str(data_loader.get_batch_size()) + \
                              '_' + str(additional_info) + '_.pth'
    if verbose:
        print("Saved to " + str(new_checkpoint_filename))
    torch.save(checkpoint, save_model_path + new_checkpoint_filename)


def partially_load_state_dict(model, state_dict, verbose=False):
    # copy only the entries whose names exist in the target model, skipping the rest
    own_state = model.state_dict()
    num_print = 5
    count_print = 0
    for name, param in state_dict.items():
        if name not in own_state:
            if verbose:
                print("Not found: " + str(name))
            continue
        if isinstance(param, Parameter):
            param = param.data
        own_state[name].copy_(param)
        if verbose:
            if count_print < num_print:
                print("Found: " + str(name))
                count_print += 1
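
A usage sketch under assumptions: these helpers expect a data-loader object exposing save_state/load_state plus the get_epoch_it/get_batch_it/get_batch_size accessors used to build the filename. The stub below is hypothetical and exists only to show the call pattern:

class DummyLoader:  # hypothetical stand-in for the project's resumable data loader
    def save_state(self): return {}
    def load_state(self, state): pass
    def get_epoch_it(self): return 0
    def get_batch_it(self): return 0
    def get_batch_size(self): return 8


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
loader = DummyLoader()

save_last_checkpoint(model, optimizer, scheduler, loader, save_model_path='./')
found, info = load_most_recent_checkpoint(model, optimizer, scheduler, loader,
                                          rank=0, save_model_path='./')
print(found, info)  # True noinfo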