frozen-in-time
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
197 lines
7.0 KiB
197 lines
7.0 KiB
# Original pytorch implementation by:
|
|
# 'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval'
|
|
# - https://arxiv.org/abs/2104.00650
|
|
# Original code by / Copyright 2021, Max Bain.
|
|
# Modifications & additions by / Copyright 2021 Zilliz. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import time
|
|
from collections import OrderedDict
|
|
from datetime import datetime
|
|
from functools import reduce
|
|
from operator import getitem
|
|
from pathlib import Path
|
|
import json
|
|
|
|
|
|
def read_json(fname):
    """
    Load a JSON file, preserving the key order of every object.

    fname:
        Path of the file to read; a ``str`` or a ``pathlib.Path``.

    Returns the decoded document. Every JSON object (including nested ones)
    is returned as an ``OrderedDict``.
    """
    # Path() is a no-op for Path inputs and lets callers pass plain strings.
    # JSON is UTF-8 by specification, so do not depend on the locale default.
    with Path(fname).open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)
|
|
|
|
|
|
def write_json(content, fname):
    """
    Serialize ``content`` to a JSON file, indented by 4 spaces.

    content:
        JSON-serializable object; key order is kept as given (no sorting).
    fname:
        Destination path; a ``str`` or a ``pathlib.Path``. The parent
        directory must already exist.
    """
    # Path() is a no-op for Path inputs and lets callers pass plain strings.
    with Path(fname).open('wt', encoding='utf-8') as handle:
        json.dump(content, handle, indent=4, sort_keys=False)
|
|
|
|
|
|
def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration

    save_dir:
        Directory (``str`` or ``pathlib.Path``) that is prepended to the
        ``filename`` of every handler in the logging config.
    log_config:
        Path of the JSON logging configuration. When it is not an existing
        file, a warning is printed and ``logging.basicConfig`` is used.
    default_level:
        Level passed to ``logging.basicConfig`` in that fallback case.
    """
    # `import logging` alone does not load the `logging.config` submodule, so
    # import dictConfig explicitly instead of relying on some other module
    # having imported it first.
    from logging.config import dictConfig

    log_config = Path(log_config)
    if not log_config.is_file():
        print(f'Warning: logging configuration file is not found in {log_config}.')
        logging.basicConfig(level=default_level)
        return

    config = read_json(log_config)
    save_dir = Path(save_dir)
    # modify logging paths based on run config
    for handler in config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(save_dir / handler['filename'])

    dictConfig(config)
|
|
|
|
|
|
class ConfigParser:
    """
    Combines CLI arguments and a JSON config file into one experiment
    configuration, prepares the output directories and configures logging.

    args:
        Parser object offering `add_argument` and `parse_args`; the parsed
        result must expose `device`, `resume` and `config`. When None,
        construction stops early and only `log_levels` is set.
    options:
        Iterable of custom CLI options, each with `flags`, `type` and
        `target` attributes.
    timestamp:
        When truthy, a `%m%d_%H%M%S` timestamp sub-directory is appended to
        the save / log / web-log directories.
    test:
        When truthy, no directories are created and the config is not
        written to disk.
    """
    def __init__(self, args=None, options='', timestamp=True, test=False):
        # Defined before the early return below so that get_logger() also
        # works on an instance that was built without a CLI parser.
        self.log_levels = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG
        }

        # parse default and custom cli options
        if args is None:
            return
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        args = args.parse_args()

        if args.device:
            os.environ['CUDA_VISIBLE_DEVICES'] = args.device

        if args.resume is None:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            self.cfg_fname = Path(args.config)
            self.resume = None
            cfg_path = self.cfg_fname
        else:
            self.resume = Path(args.resume)
            if args.config is not None:
                cfg_path = Path(args.config)
            else:
                # No explicit config given: fall back to the config stored
                # next to the checkpoint instead of failing on Path(None).
                cfg_path = self.resume.parent / 'config.json'
        config = read_json(cfg_path)

        # load config file and apply custom cli options
        self._config = _update_config(config, options, args)

        # set save_dir where trained model and log will be saved.
        save_dir = Path(self.config['trainer']['save_dir'])
        timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''

        exper_name = self.config['name']
        self._save_dir = save_dir / 'models' / exper_name / timestamp
        self._web_log_dir = save_dir / 'web' / exper_name / timestamp
        self._log_dir = save_dir / 'log' / exper_name / timestamp

        # if set, remove all previous experiments with the current config.
        # This must happen BEFORE the directories of this run are created,
        # otherwise the purge deletes them and writing config.json fails.
        if vars(args).get('purge_exp_dir', False):
            # Name the per-experiment directories explicitly: with an empty
            # timestamp, `self._save_dir.parent` would be the directory that
            # is shared by all experiments.
            self._purge_experiment_dirs(
                save_dir / kind / exper_name for kind in ('models', 'log', 'web')
            )

        if not test:
            self.save_dir.mkdir(parents=True, exist_ok=True)
            self.log_dir.mkdir(parents=True, exist_ok=True)
            # save updated config file to the checkpoint dir
            write_json(self.config, self.save_dir / 'config.json')

        # configure logging module
        setup_logging(self.log_dir)

    @staticmethod
    def _purge_experiment_dirs(exper_dirs):
        """Recursively delete each directory in `exper_dirs`, reporting the time taken."""
        # Local import: shutil is only needed on this rarely used path.
        import shutil

        for config_dir in exper_dirs:
            existing = list(config_dir.glob('*'))
            print(f'purging {len(existing)} directories from {config_dir}...')
            tic = time.time()
            # Like `rm -rf`, a missing directory is not an error; unlike a
            # shell command, paths with spaces or metacharacters are safe.
            shutil.rmtree(config_dir, ignore_errors=True)
            print(f'Finished purge in {time.time() - tic:.3f}s')

    def initialize(self, name, module, *args, index=None, **kwargs):
        """
        finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding keyword args given as 'args'.

        name:
            Key of the config entry; the entry (or, with `index`, its
            `index`-th element) must hold 'type' and 'args'.
        module:
            Object on which the attribute named by 'type' is looked up.
        index:
            Optional position inside `self[name]` when that entry is a sequence.
        *args / **kwargs:
            Extra positional / keyword arguments for the constructor. A keyword
            that is already present in the config 'args' raises AssertionError.
        """
        entry = self[name] if index is None else self[name][index]
        module_name = entry['type']
        module_args = dict(entry['args'])
        # kwargs are applied in both the indexed and the non-indexed case;
        # they used to be silently dropped whenever `index` was given.
        assert all(k not in module_args for k in kwargs), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)

        factory = getattr(module, module_name)
        logging.getLogger(__name__).debug('initializing %s', module_name)

        # if parameter not in config subdict, then check if it's in global config.
        signature = inspect.signature(factory.__init__)
        for param in signature.parameters:
            if param not in module_args and param in self.config:
                module_args[param] = self[param]

        return factory(*args, **module_args)

    def __getitem__(self, name):
        """Return the top-level config entry stored under `name`."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """
        Return the logger called `name`, with its level set from `verbosity`.

        verbosity:
            Key of `self.log_levels` (0: WARNING, 1: INFO, 2: DEBUG); any
            other value raises AssertionError.
        """
        msg_verbosity = f'verbosity option {verbosity} is invalid. Valid options are {self.log_levels.keys()}.'
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    # setting read-only attributes
    @property
    def config(self):
        """The parsed configuration, with CLI overrides applied."""
        return self._config

    @property
    def save_dir(self):
        """Directory for this run under `<save_dir>/models/<name>`."""
        return self._save_dir

    @property
    def log_dir(self):
        """Directory for this run under `<save_dir>/log/<name>`."""
        return self._log_dir
|
|
|
|
|
|
# helper functions used to update config dict with custom cli options
|
|
def _update_config(config, options, args):
|
|
for opt in options:
|
|
value = getattr(args, _get_opt_name(opt.flags))
|
|
if value is not None:
|
|
_set_by_path(config, opt.target, value)
|
|
return config
|
|
|
|
|
|
def _get_opt_name(flags):
|
|
for flg in flags:
|
|
if flg.startswith('--'):
|
|
return flg.replace('--', '')
|
|
return flags[0].replace('--', '')
|
|
|
|
|
|
def _set_by_path(tree, keys, value):
    """Store ``value`` in the nested container ``tree`` at the position given by the key sequence ``keys``."""
    parent_path, leaf_key = keys[:-1], keys[-1]
    # Walk down to the container holding the final key, then assign into it.
    parent = _get_by_path(tree, parent_path)
    parent[leaf_key] = value
|
|
|
|
|
|
def _get_by_path(tree, keys):
|
|
"""Access a nested object in tree by sequence of keys."""
|
|
return reduce(getitem, keys, tree)
|