logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

91 lines
2.4 KiB

"""
helper for logging
NOTE: loggers are global objects use with caution
"""
import logging
import math
import tensorboardX
_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_DATE_FMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
LOGGER = logging.getLogger('__main__') # this is the global logger
def add_log_to_file(log_path):
fh = logging.FileHandler(log_path)
formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
fh.setFormatter(formatter)
LOGGER.addHandler(fh)
class TensorboardLogger(object):
def __init__(self):
self._logger = None
self._global_step = 0
def create(self, path):
self._logger = tensorboardX.SummaryWriter(path)
def noop(self, *args, **kwargs):
return
def step(self):
self._global_step += 1
@property
def global_step(self):
return self._global_step
def log_scaler_dict(self, log_dict, prefix=''):
""" log a dictionary of scalar values"""
if self._logger is None:
return
if prefix:
prefix = f'{prefix}_'
for name, value in log_dict.items():
if isinstance(value, dict):
self.log_scaler_dict(value, self._global_step,
prefix=f'{prefix}{name}')
else:
self._logger.add_scalar(f'{prefix}{name}', value,
self._global_step)
def __getattr__(self, name):
if self._logger is None:
return self.noop
return self._logger.__getattribute__(name)
TB_LOGGER = TensorboardLogger()
class RunningMeter(object):
""" running meteor of a scalar value
(useful for monitoring training loss)
"""
def __init__(self, name, val=None, smooth=0.99):
self._name = name
self._sm = smooth
self._val = val
def __call__(self, value):
val = (value if self._val is None
else value*(1-self._sm) + self._val*self._sm)
if not math.isnan(val) and not math.isinf(val):
self._val = val
else:
print(f'Inf/Nan in {self._name}')
def __str__(self):
return f'{self._name}: {self._val:.4f}'
@property
def val(self):
return self._val
@property
def name(self):
return self._name