lightningdot
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
67 lines
1.4 KiB
67 lines
1.4 KiB
"""
|
|
Misc utilities
|
|
"""
|
|
import json
|
|
import random
|
|
import sys
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
from uniter_model.utils.logger import LOGGER
|
|
|
|
|
|
class NoOp:
    """Placeholder object on which every attribute is a do-nothing callable.

    Useful for distributed training: it can stand in for an object whose
    methods should have no effect.
    """

    def noop(self, *args, **kwargs):
        """Accept any arguments, do nothing, and return ``None``."""
        return None

    def __getattr__(self, name):
        # Only reached for attributes that normal lookup could not find,
        # so every unknown name resolves to the same no-op method.
        return self.noop
|
|
|
|
|
|
def parse_with_config(parser):
    """Parse command-line arguments, merging in an optional JSON config file.

    ``parser`` must define a ``config`` argument. When it is given, the file
    it names is loaded as JSON and each top-level key is set as an attribute
    on the parsed namespace, unless the same option was passed explicitly on
    the command line (explicit flags win over the config file).

    Explicit flags are detected by scanning ``sys.argv[1:]`` for tokens that
    start with ``--``; both ``--flag value`` and ``--flag=value`` are
    recognised.

    Returns:
        The parsed namespace with the ``config`` attribute removed.
    """
    args = parser.parse_args()
    if args.config is not None:
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(args.config) as config_file:
            config_args = json.load(config_file)

        override_keys = set()
        for arg in sys.argv[1:]:
            if not arg.startswith('--'):
                continue
            flag = arg[2:].split('=')[0]
            override_keys.add(flag)
            # argparse stores ``--some-flag`` under ``some_flag``; register
            # that spelling too so a config key named like the attribute
            # cannot clobber a value given explicitly on the command line.
            override_keys.add(flag.replace('-', '_'))

        for k, v in config_args.items():
            if k not in override_keys:
                setattr(args, k, v)
    del args.config
    return args
|
|
|
|
|
|
# Entailment label name -> integer class index.
VE_ENT2IDX = dict(contradiction=0, entailment=1, neutral=2)

# Integer class index -> entailment label name (inverse of VE_ENT2IDX).
VE_IDX2ENT = {index: label for label, index in VE_ENT2IDX.items()}
|
|
|
|
|
|
class Struct:
    """Expose the entries of a mapping as attributes of an object."""

    def __init__(self, dict_):
        # Copy every key/value pair of ``dict_`` into the instance namespace.
        vars(self).update(dict_)
|
|
|
|
|
|
def set_dropout(model, drop_p):
    """Set the drop probability of every ``torch.nn.Dropout`` in ``model``.

    Walks ``model.named_modules()``; each ``torch.nn.Dropout`` whose ``p``
    differs from ``drop_p`` is updated in place and the change is logged.
    Handy when dropout should be tuned for a smaller dataset.
    """
    for module_name, submodule in model.named_modules():
        needs_update = (isinstance(submodule, torch.nn.Dropout)
                        and submodule.p != drop_p)
        if needs_update:
            submodule.p = drop_p
            LOGGER.info(f'{module_name} set to {drop_p}')
|
|
|
|
|
|
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as listed above; each generator receives the same seed.
    for apply_seed in seeders:
        apply_seed(seed)
|