expansionnet-v2
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
110 lines
4.6 KiB
110 lines
4.6 KiB
2 years ago
|
|
||
|
import os
|
||
|
import torch
|
||
|
from datetime import datetime
|
||
|
|
||
|
from torch.nn.parameter import Parameter
|
||
|
|
||
|
def load_most_recent_checkpoint(model,
                                optimizer=None,
                                scheduler=None,
                                data_loader=None,
                                rank=0,
                                save_model_path='./', datetime_format='%Y-%m-%d-%H:%M:%S',
                                verbose=True):
    """Find and load the newest checkpoint file in ``save_model_path``.

    Checkpoint files must be named
    ``checkpoint_<datetime>_<tag>_<info>_.pth`` — exactly four underscores,
    with ``<datetime>`` formatted per ``datetime_format`` (which therefore
    must not itself contain underscores).

    Args:
        model: module whose ``load_state_dict`` receives 'model_state_dict'.
        optimizer: if not None, restored from 'optimizer_state_dict'.
        scheduler: if not None, restored from 'scheduler_state_dict'.
        data_loader: if not None, restored via its ``load_state`` method
            from 'data_loader_state_dict'.
        rank: target CUDA device index; storages saved on 'cuda:0' are
            remapped to 'cuda:<rank>'.
        save_model_path: directory scanned for checkpoint files.
        datetime_format: ``strptime`` format of the filename timestamp.
        verbose: print which file is loaded (or that none was found).

    Returns:
        ``(True, info)`` when a checkpoint was loaded, where ``info`` is the
        filename's info field; ``(False, 'no_additional_info')`` otherwise.
    """
    most_recent_checkpoint_datetime = None
    most_recent_checkpoint_filename = None
    most_recent_checkpoint_info = 'no_additional_info'
    for file_name in os.listdir(save_model_path):
        if not file_name.startswith('checkpoint_'):
            continue
        _, datetime_str, _, info, _ = file_name.split('_')
        file_datetime = datetime.strptime(datetime_str, datetime_format)
        # keep the newest timestamp seen so far
        if most_recent_checkpoint_datetime is None or \
                file_datetime > most_recent_checkpoint_datetime:
            most_recent_checkpoint_datetime = file_datetime
            most_recent_checkpoint_filename = file_name
            most_recent_checkpoint_info = info

    if most_recent_checkpoint_filename is None:
        if verbose:
            print("Loading: no checkpoint found in " + str(save_model_path))
        return False, most_recent_checkpoint_info

    # os.path.join works whether or not save_model_path ends with a
    # separator; plain '+' concatenation silently produced a wrong path
    # when the trailing separator was missing.
    checkpoint_path = os.path.join(save_model_path, most_recent_checkpoint_filename)
    if verbose:
        print("Loading: " + str(checkpoint_path))
    # remap storages that were saved on cuda:0 onto this process's device
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if data_loader is not None:
        data_loader.load_state(checkpoint['data_loader_state_dict'])
    return True, most_recent_checkpoint_info
|
||
|
|
||
|
|
||
|
def save_last_checkpoint(model,
                         optimizer,
                         scheduler,
                         data_loader,
                         save_model_path='./',
                         num_max_checkpoints=3, datetime_format='%Y-%m-%d-%H:%M:%S',
                         additional_info='noinfo',
                         verbose=True):
    """Save a rotating checkpoint, keeping at most ``num_max_checkpoints``.

    The new file is named ``checkpoint_<now>_epoch<E>it<I>bs<B>_<info>_.pth``
    so that ``load_most_recent_checkpoint`` can parse it back with
    ``split('_')``; for that reason ``additional_info`` must not contain
    underscores — TODO confirm callers respect this.

    Args:
        model, optimizer, scheduler: their ``state_dict()`` is stored.
        data_loader: must provide ``save_state()``, ``get_epoch_it()``,
            ``get_batch_it()`` and ``get_batch_size()``.
        save_model_path: destination directory.
        num_max_checkpoints: cap on checkpoint files kept in the directory.
        datetime_format: ``strftime`` format for the filename timestamp.
        additional_info: free-form tag embedded in the filename.
        verbose: print the name of the saved file.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'data_loader_state_dict': data_loader.save_state(),
    }

    # locate the oldest existing checkpoint and count how many there are
    oldest_checkpoint_datetime = None
    oldest_checkpoint_filename = None
    num_check_points = 0
    for file_name in os.listdir(save_model_path):
        if not file_name.startswith('checkpoint_'):
            continue
        num_check_points += 1
        _, datetime_str, _, _, _ = file_name.split('_')
        file_datetime = datetime.strptime(datetime_str, datetime_format)
        if oldest_checkpoint_datetime is None or \
                file_datetime < oldest_checkpoint_datetime:
            oldest_checkpoint_datetime = file_datetime
            oldest_checkpoint_filename = file_name

    # prune when at or above the limit — '>=' (not '==') also recovers a
    # directory that already holds more files than num_max_checkpoints
    if oldest_checkpoint_filename is not None and num_check_points >= num_max_checkpoints:
        os.remove(os.path.join(save_model_path, oldest_checkpoint_filename))

    new_checkpoint_filename = 'checkpoint_' + datetime.now().strftime(datetime_format) + \
                              '_epoch' + str(data_loader.get_epoch_it()) + \
                              'it' + str(data_loader.get_batch_it()) + \
                              'bs' + str(data_loader.get_batch_size()) + \
                              '_' + str(additional_info) + '_.pth'
    # os.path.join tolerates a missing trailing separator in save_model_path
    torch.save(checkpoint, os.path.join(save_model_path, new_checkpoint_filename))
    if verbose:
        print("Saved to " + str(new_checkpoint_filename))
|
||
|
|
||
|
|
||
|
def partially_load_state_dict(model, state_dict, verbose=False):
    """Copy into ``model`` every entry of ``state_dict`` whose name matches.

    Entries whose name is absent from the model's own state dict — or whose
    tensor shape differs — are skipped instead of raising, so a partial
    load never aborts midway with some tensors already overwritten.

    Args:
        model: destination module; its ``state_dict()`` tensors are
            modified in place via ``copy_``.
        state_dict: mapping of parameter name to tensor (or Parameter).
        verbose: print skipped names and up to the first five loaded ones.
    """
    own_state = model.state_dict()
    num_print = 5      # cap on "Found" messages to keep logs short
    count_print = 0
    for name, param in state_dict.items():
        if name not in own_state:
            if verbose:
                print("Not found: " + str(name))
            continue
        if isinstance(param, Parameter):
            # unwrap Parameter to its backing tensor before copying
            param = param.data
        if own_state[name].shape != param.shape:
            # copy_ would raise on mismatched shapes; skip so the load
            # stays "partial" instead of aborting with a RuntimeError
            if verbose:
                print("Shape mismatch (skipped): " + str(name))
            continue
        own_state[name].copy_(param)
        if verbose:
            if count_print < num_print:
                print("Found: " + str(name))
            count_print += 1
|
||
|
|