"""Saving utilities: training metadata dump and model checkpointing."""
import json
import os
import subprocess
from os.path import abspath, dirname, exists, join

import torch

from uniter_model.utils.logger import LOGGER


def save_training_meta(args):
    """Persist run metadata under ``args.output_dir`` (rank-0 only).

    Creates ``log/`` and ``ckpt/`` sub-directories, dumps ``vars(args)``
    to ``log/hps.json``, copies the model config (when ``args`` carries a
    ``model_config`` path) to ``log/model.json``, and records the current
    git branch / SHA / dirty status to ``log/git_info.json``.

    Args:
        args: parsed training arguments; must expose ``rank`` and
            ``output_dir``, and may expose ``model_config``.
    """
    # only the rank-0 process writes metadata in distributed training
    if args.rank > 0:
        return

    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)

    with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer:
        json.dump(vars(args), writer, indent=4)

    # BUGFIX: was `if False:` (dead code) — the body clearly intends to
    # snapshot the model config next to the hyper-params so the run is
    # reproducible from its log directory; guarded so args without a
    # model_config attribute keep the old no-op behavior.
    model_config_path = getattr(args, 'model_config', None)
    if model_config_path and exists(model_config_path):
        with open(model_config_path) as f:  # was a leaked open() handle
            model_config = json.load(f)
        with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer:
            json.dump(model_config, writer, indent=4)

    # git info — best effort: a hung git, a missing git binary, or a
    # non-repo checkout must not abort training
    try:
        LOGGER.info("Waiting on git info....")
        c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
                           timeout=10, stdout=subprocess.PIPE)
        git_branch_name = c.stdout.decode().strip()
        LOGGER.info("Git branch: %s", git_branch_name)
        c = subprocess.run(["git", "rev-parse", "HEAD"],
                           timeout=10, stdout=subprocess.PIPE)
        git_sha = c.stdout.decode().strip()
        LOGGER.info("Git SHA: %s", git_sha)
        git_dir = abspath(dirname(__file__))
        git_status = subprocess.check_output(
            ['git', 'status', '--short'],
            cwd=git_dir, universal_newlines=True).strip()
        with open(join(args.output_dir, 'log', 'git_info.json'),
                  'w') as writer:
            json.dump({'branch': git_branch_name,
                       'is_dirty': bool(git_status),
                       'status': git_status,
                       'sha': git_sha},
                      writer, indent=4)
    # BUGFIX: check_output raises CalledProcessError outside a git repo and
    # run raises FileNotFoundError when git is absent — previously only
    # TimeoutExpired was caught, so those cases crashed training.
    except (subprocess.TimeoutExpired, subprocess.CalledProcessError,
            FileNotFoundError) as e:
        LOGGER.exception(e)
        # logger.warn is deprecated in favor of logger.warning
        LOGGER.warning("Git info not found. Moving right along...")


class ModelSaver(object):
    """Writes model (and optionally optimizer) checkpoints to a directory."""

    def __init__(self, output_dir, prefix='model_step', suffix='pt'):
        # checkpoint files are named f"{prefix}_{step}.{suffix}"
        self.output_dir = output_dir
        self.prefix = prefix
        self.suffix = suffix

    def save(self, model, step, optimizer=None):
        """Save ``model.state_dict()`` (tensors moved to CPU) for ``step``.

        When ``optimizer`` is given, its state is additionally dumped to
        ``train_state_{step}.pt`` in the same directory.
        """
        output_model_file = join(self.output_dir,
                                 f"{self.prefix}_{step}.{self.suffix}")
        # move tensors to CPU so the checkpoint loads without the
        # original device being present
        state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                      for k, v in model.state_dict().items()}
        if hasattr(model, 'vocab_pad') and model.vocab_pad:
            # strip the padding rows so the stored vocab embedding matches
            # the original (unpadded) vocabulary; the decoder weight is
            # tied to the embedding, so overwrite it with the same tensor
            emb_w = state_dict['bert.embeddings.word_embeddings.weight']
            emb_w = emb_w[:-model.vocab_pad, :]
            state_dict['bert.embeddings.word_embeddings.weight'] = emb_w
            state_dict['cls.predictions.decoder.weight'] = emb_w
        torch.save(state_dict, output_model_file)
        if optimizer is not None:
            dump = {'step': step, 'optimizer': optimizer.state_dict()}
            if hasattr(optimizer, '_amp_stash'):
                pass  # TODO fp16 optimizer
            torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')