lightningdot
"""
saving utilities
"""
import json
import os
from os.path import abspath, dirname, exists, join
import subprocess

import torch

from uniter_model.utils.logger import LOGGER


def save_training_meta(args):
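    """Dump training metadata (hyper-parameters and git state) under args.output_dir.

    Only rank 0 writes; other distributed workers return immediately.
    """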
    if args.rank > 0:
        return

    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)

    with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer:
        json.dump(vars(args), writer, indent=4)
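    # Note: the model-config dump below is intentionally disabled in this file.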
    if False:
        model_config = json.load(open(args.model_config))
        with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer:
            json.dump(model_config, writer, indent=4)
    # git info
    try:
        LOGGER.info("Waiting on git info....")
        c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
                           timeout=10, stdout=subprocess.PIPE)
        git_branch_name = c.stdout.decode().strip()
        LOGGER.info("Git branch: %s", git_branch_name)
        c = subprocess.run(["git", "rev-parse", "HEAD"],
                           timeout=10, stdout=subprocess.PIPE)
        git_sha = c.stdout.decode().strip()
        LOGGER.info("Git SHA: %s", git_sha)
        git_dir = abspath(dirname(__file__))
        git_status = subprocess.check_output(
            ['git', 'status', '--short'],
            cwd=git_dir, universal_newlines=True).strip()
        with open(join(args.output_dir, 'log', 'git_info.json'),
                  'w') as writer:
            json.dump({'branch': git_branch_name,
                       'is_dirty': bool(git_status),
                       'status': git_status,
                       'sha': git_sha},
                      writer, indent=4)
    except subprocess.TimeoutExpired as e:
        LOGGER.exception(e)
        LOGGER.warn("Git info not found. Moving right along...")


class ModelSaver(object):
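    """Save model (and optionally optimizer) state to numbered checkpoint files."""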
    def __init__(self, output_dir, prefix='model_step', suffix='pt'):
        self.output_dir = output_dir
        self.prefix = prefix
        self.suffix = suffix

    def save(self, model, step, optimizer=None):
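        """Write {prefix}_{step}.{suffix}; if an optimizer is given, also write train_state_{step}.pt."""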
        output_model_file = join(self.output_dir,
                                 f"{self.prefix}_{step}.{self.suffix}")
        state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                      for k, v in model.state_dict().items()}
        if hasattr(model, 'vocab_pad') and model.vocab_pad:
            # store vocab embeddings before padding
            emb_w = state_dict['bert.embeddings.word_embeddings.weight']
            emb_w = emb_w[:-model.vocab_pad, :]
            state_dict['bert.embeddings.word_embeddings.weight'] = emb_w
            state_dict['cls.predictions.decoder.weight'] = emb_w
        torch.save(state_dict, output_model_file)
        if optimizer is not None:
            dump = {'step': step, 'optimizer': optimizer.state_dict()}
            if hasattr(optimizer, '_amp_stash'):
                pass  # TODO fp16 optimizer
            torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')
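
For reference, a minimal sketch of how these utilities are typically wired into a training script. It is not part of this file; the args fields, the toy model, and the optimizer below are illustrative placeholders, not LightningDOT's actual training setup.

    from argparse import Namespace

    args = Namespace(rank=0, output_dir='/tmp/example_run',
                     num_train_steps=100, valid_steps=50)
    save_training_meta(args)                      # rank 0 writes log/hps.json and log/git_info.json
    saver = ModelSaver(join(args.output_dir, 'ckpt'))

    model = torch.nn.Linear(4, 2)                 # placeholder for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for step in range(1, args.num_train_steps + 1):
        # ... forward / backward / optimizer.step() ...
        if step % args.valid_steps == 0:
            saver.save(model, step, optimizer)    # ckpt/model_step_{step}.pt and train_state_{step}.pt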