import argparse
import json
import logging
import os
import random
import socket
import sys

import numpy as np
import torch

logger = logging.getLogger()


def default_params(parser: argparse.ArgumentParser):
    # Encoder architectures, configs and checkpoints.
    parser.add_argument('--txt_model_type', default='bert-base', type=str, help="text encoder type")
    parser.add_argument('--txt_model_config', default='bert-base', type=str, help="text encoder config")
    parser.add_argument('--txt_checkpoint', default=None, type=str, help="text encoder checkpoint to load")
    parser.add_argument('--img_model_type', default='uniter-base', type=str, help="image encoder type")
    parser.add_argument('--img_model_config', default='./config/img_base.json', type=str, help="path to image encoder config JSON")
    parser.add_argument('--img_checkpoint', default=None, type=str, help="image encoder checkpoint to load")
    parser.add_argument('--biencoder_checkpoint', default=None, type=str, help="bi-encoder checkpoint to load")
    # Flag name kept as-is (including the misspelling) for CLI compatibility.
    parser.add_argument('--seperate_caption_encoder', action='store_true', help="use a separate encoder for captions")

    # Optimization.
    parser.add_argument('--train_batch_size', default=80, type=int, help="training batch size")
    parser.add_argument('--valid_batch_size', default=80, type=int, help="validation batch size")
    parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help="number of steps to accumulate gradients before an optimizer update")
    parser.add_argument('--learning_rate', default=1e-5, type=float, help="learning rate")
    parser.add_argument('--max_grad_norm', default=2.0, type=float, help="max norm for gradient clipping")
    parser.add_argument('--warmup_steps', default=500, type=int, help="number of learning-rate warmup steps")
    parser.add_argument('--valid_steps', default=500, type=int, help="run validation every N steps")
    parser.add_argument('--num_train_steps', default=5000, type=int, help="total number of training steps")
    parser.add_argument('--num_train_epochs', default=0, type=int, help="number of training epochs")

    # Runtime and I/O.
    parser.add_argument('--fp16', action='store_true', help="use mixed-precision (fp16) training")
    parser.add_argument('--seed', default=42, type=int, help="random seed")
    parser.add_argument('--output_dir', default='./', type=str, help="directory for checkpoints and logs")
    parser.add_argument('--max_txt_len', default=64, type=int, help="max number of text tokens")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument('--config', default=None, type=str, help="JSON config file (command-line flags take precedence)")
    parser.add_argument('--itm_global_file', default=None, type=str, help="")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--n_workers', type=int, default=2, help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true', help="pin memory in data loaders")
    parser.add_argument('--hnsw_index', action='store_true', help="use an HNSW index for retrieval")
    parser.add_argument('--fp16_opt_level', type=str, default='O1', help="apex AMP optimization level")
    parser.add_argument('--img_meta', type=str, default=None, help="")

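# Example (illustrative sketch; the flag values are hypothetical):
# default_params only registers options on a parser you supply.
#
#     parser = argparse.ArgumentParser()
#     default_params(parser)
#     args = parser.parse_args(['--fp16', '--learning_rate', '2e-5'])
#     # args.fp16 is True and args.learning_rate == 2e-5
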
def add_itm_params(parser: argparse.ArgumentParser):
    # Scoring and image region features.
    parser.add_argument('--conf_th', default=0.2, type=float, help="confidence threshold for image region features")
    parser.add_argument('--caption_score_weight', default=0.0, type=float, help="weight of the caption score")
    # Negative sampling.
    parser.add_argument('--negative_size', default=10, type=int, help="number of negatives per positive pair")
    parser.add_argument('--num_hard_negatives', default=0, type=int, help="number of hard negatives per positive pair")
    parser.add_argument('--sample_init_hard_negatives', action='store_true', help="sample hard negatives at initialization")
    parser.add_argument('--hard_negatives_sampling', default='none', type=str,
                        choices=['none', 'random', 'top', 'top-random', '10-20', '20-30'],
                        help="hard negative sampling strategy")
    parser.add_argument('--max_bb', default=100, type=int, help="max number of bounding boxes")
    parser.add_argument('--min_bb', default=10, type=int, help="min number of bounding boxes")
    parser.add_argument('--num_bb', default=36, type=int, help="static number of bounding boxes")
    # Dataset paths.
    parser.add_argument('--train_txt_dbs', default=None, type=str, help="training text databases")
    parser.add_argument('--train_img_dbs', default=None, type=str, help="training image databases")

    # Prefix remappings applied by map_db_dirs below.
    parser.add_argument('--txt_db_mapping', default=None, type=str, help="replacement prefix for '/db' paths")
    parser.add_argument('--img_db_mapping', default=None, type=str, help="replacement prefix for '/img' paths")
    parser.add_argument('--pretrain_mapping', default=None, type=str, help="replacement prefix for '/pretrain' paths")

    parser.add_argument('--val_txt_db', default=None, type=str, help="validation text database")
    parser.add_argument('--val_img_db', default=None, type=str, help="validation image database")
    parser.add_argument('--test_txt_db', default=None, type=str, help="test text database")
    parser.add_argument('--test_img_db', default=None, type=str, help="test image database")
    parser.add_argument('--steps_per_hard_neg', default=-1, type=int, help="")
    parser.add_argument('--inf_minibatch_size', default=400, type=int, help="minibatch size at inference time")
    parser.add_argument('--project_dim', default=0, type=int, help='')
    parser.add_argument('--cls_concat', default="", type=str, help='')
    parser.add_argument('--fix_txt_encoder', action='store_true', help='freeze the text encoder')
    parser.add_argument('--fix_img_encoder', action='store_true', help='freeze the image encoder')
    parser.add_argument('--compressed_db', action='store_true', help='use compressed LMDB')
    parser.add_argument('--retrieval_mode', default="both",
                        choices=['img_only', 'txt_only', 'both'], type=str,
                        help="which retrieval direction(s) to run")

def add_logging_params(parser: argparse.ArgumentParser):
    parser.add_argument('--log_result_step', default=4, type=int, help="log training results every N steps")
    parser.add_argument('--project_name', default='itm', type=str, help="project name for experiment tracking")
    parser.add_argument('--expr_name_prefix', default='', type=str, help="prefix for the experiment name")
    parser.add_argument('--save_all_epochs', action='store_true', help="save a checkpoint after every epoch")

def add_kd_params(parser: argparse.ArgumentParser):
    # Knowledge distillation.
    parser.add_argument('--teacher_checkpoint', default=None, type=str, help="teacher model checkpoint for distillation")
    parser.add_argument('--T', default=1.0, type=float, help="distillation softmax temperature")
    parser.add_argument('--kd_loss_weight', default=1.0, type=float, help="weight of the distillation loss")

def parse_with_config(parser, cmds=None):
    """Parse arguments, then fill in any option not explicitly given
    from the JSON file passed via --config."""
    if cmds is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(cmds)

    if args.config is not None:
        with open(args.config) as f:
            config_args = json.load(f)
        # Options passed explicitly (on the command line or via `cmds`)
        # take precedence over values from the config file.
        cli_args = sys.argv[1:] if cmds is None else cmds
        override_keys = {arg[2:].split('=')[0] for arg in cli_args
                         if arg.startswith('--')}
        for k, v in config_args.items():
            if k not in override_keys:
                setattr(args, k, v)
    return args

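# Example (illustrative sketch; 'cfg.json' is a hypothetical file): an option
# given explicitly still wins over the value in the config file.
#
#     parser = argparse.ArgumentParser()
#     default_params(parser)
#     args = parse_with_config(parser, cmds=['--config', 'cfg.json',
#                                            '--learning_rate', '1e-4'])
#     # With cfg.json = {"learning_rate": 5e-5, "fp16": true}:
#     # args.learning_rate == 1e-4  (explicit flag wins)
#     # args.fp16 is True           (filled in from cfg.json)
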
def map_db_dirs(args):
    """Remap hard-coded db path prefixes ('/pretrain', '/db', '/img')
    to the locations given by the *_mapping arguments."""
    # Remap scalar string arguments.
    for k in vars(args):
        v = getattr(args, k)
        if not isinstance(v, str):
            continue
        new_v = v
        if args.pretrain_mapping and new_v.startswith('/pretrain'):
            new_v = new_v.replace('/pretrain', args.pretrain_mapping)
        if args.txt_db_mapping and new_v.startswith('/db'):
            new_v = new_v.replace('/db', args.txt_db_mapping)
        if args.img_db_mapping and new_v.startswith('/img'):
            new_v = new_v.replace('/img', args.img_db_mapping)
        if new_v != v:
            logger.info('remapped %s: %s -> %s', k, v, new_v)
            setattr(args, k, new_v)

    # Remap list-valued training db paths (typically set via the JSON
    # config); guard against None.
    if args.img_db_mapping and args.train_img_dbs:
        args.train_img_dbs = [p.replace('/img', args.img_db_mapping)
                              for p in args.train_img_dbs]
    if args.txt_db_mapping and args.train_txt_dbs:
        args.train_txt_dbs = [p.replace('/db', args.txt_db_mapping)
                              for p in args.train_txt_dbs]

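# Example (illustrative sketch; paths are hypothetical): with
# img_db_mapping='/data/img', the '/img' prefix is rewritten everywhere.
#
#     ns = argparse.Namespace(
#         val_img_db='/img/coco_val2014',
#         train_img_dbs=['/img/coco_train2014'], train_txt_dbs=None,
#         img_db_mapping='/data/img', txt_db_mapping=None,
#         pretrain_mapping=None)
#     map_db_dirs(ns)
#     # ns.val_img_db == '/data/img/coco_val2014'
#     # ns.train_img_dbs == ['/data/img/coco_train2014']
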
def print_args(args):
    logger.info(" **************** CONFIGURATION **************** ")
    for key, val in sorted(vars(args).items()):
        logger.info("%s --> %s", key.ljust(30), val)
    logger.info(" **************** END CONFIGURATION **************** ")

def set_seed(args):
    """Seed all RNGs; expects args.n_gpu, which setup_args_gpu sets."""
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

def setup_args_gpu(args):
    """Set up CUDA, GPU and distributed-training arguments."""
    if args.local_rank == -1 or args.no_cuda:  # single-node multi-gpu (or cpu) mode
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # distributed mode: one process per GPU
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device
    ws = os.environ.get('WORLD_SIZE')
    args.distributed_world_size = int(ws) if ws else 1
    logger.info(
        'Initialized host %s as distributed rank %d on device=%s, n_gpu=%d, world size=%d',
        socket.gethostname(), args.local_rank, device, args.n_gpu,
        args.distributed_world_size)
    logger.info("16-bits training: %s", args.fp16)

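# Minimal end-to-end usage sketch (not part of the original module; the flag
# values below are hypothetical): build a parser from the helpers above, then
# run the usual setup sequence.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    _parser = argparse.ArgumentParser()
    default_params(_parser)
    add_itm_params(_parser)
    add_logging_params(_parser)
    add_kd_params(_parser)
    demo_args = parse_with_config(_parser, cmds=['--no_cuda', '--seed', '7'])
    setup_args_gpu(demo_args)  # sets demo_args.device and demo_args.n_gpu
    set_seed(demo_args)        # requires n_gpu from setup_args_gpu
    print_args(demo_args)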