logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

176 lines
8.3 KiB

import argparse
import json
import sys
import os
import logging
import torch
import random
import socket
import numpy as np
# Module-wide logger. getLogger() is called without a name, so this is the
# ROOT logger: records go to whatever handlers/level are configured on root.
logger = logging.getLogger()
def default_params(parser: argparse.ArgumentParser):
    """Register the general model, training and runtime options on ``parser``.

    The options are added to ``parser`` in place; nothing is returned.
    """
    add = parser.add_argument

    # --- encoders and checkpoints ---------------------------------------
    add('--txt_model_type', default='bert-base', type=str, help="")
    add('--txt_model_config', default='bert-base', type=str, help="")
    add('--txt_checkpoint', default=None, type=str, help="")
    add('--img_model_type', default='uniter-base', type=str, help="")
    add('--img_model_config', default='./config/img_base.json', type=str, help="")
    add('--img_checkpoint', default=None, type=str, help="")
    add('--biencoder_checkpoint', default=None, type=str, help="")
    # The misspelled flag name is part of the public CLI; keep it as-is.
    add('--seperate_caption_encoder', action='store_true', help="")

    # --- training loop ----------------------------------------------------
    add('--train_batch_size', default=80, type=int, help="")
    add('--valid_batch_size', default=80, type=int, help="")
    add('--gradient_accumulation_steps', default=1, type=int, help="")
    add('--learning_rate', default=1e-5, type=float, help="")
    add('--max_grad_norm', default=2.0, type=float, help="")
    add('--warmup_steps', default=500, type=int, help="")
    add('--valid_steps', default=500, type=int, help="")
    add('--num_train_steps', default=5000, type=int, help="")
    add('--num_train_epochs', default=0, type=int, help="")
    add('--fp16', action='store_true', help="")
    add('--seed', default=42, type=int, help="")
    add('--output_dir', default='./', type=str, help="")
    add('--max_txt_len', default=64, type=int, help="")

    # --- runtime / devices / I/O ------------------------------------------
    # -1 selects the non-distributed branch of setup_args_gpu.
    add("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    # Optional JSON file whose keys parse_with_config copies onto the args.
    add('--config', default=None, type=str, help="")
    add('--itm_global_file', default=None, type=str, help="")
    add("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    add('--n_workers', type=int, default=2, help="number of data workers")
    # NOTE(review): the original author marked this flag "???"; presumably it
    # is forwarded to DataLoader(pin_memory=...) elsewhere — confirm.
    add('--pin_mem', action='store_true', help="pin memory")
    add('--hnsw_index', action='store_true', help="")
    # NOTE(review): 'O1' looks like an apex.amp opt_level — confirm at use site.
    add('--fp16_opt_level', type=str, default='O1', help="")
    add('--img_meta', type=str, default=None, help="")
def add_itm_params(parser: argparse.ArgumentParser):
    """Register the ITM-specific options on ``parser``.

    These cover negative sampling, box counts, dataset paths and
    path-prefix mappings, and encoder/retrieval settings. The options are
    added in place; nothing is returned.
    """
    add = parser.add_argument

    add('--conf_th', default=0.2, type=float, help="")
    add('--caption_score_weight', default=0.0, type=float, help="")

    # --- negatives ----------------------------------------------------------
    add('--negative_size', default=10, type=int, help="")
    add('--num_hard_negatives', default=0, type=int, help="")
    add('--sample_init_hard_negatives', action='store_true', help="")
    add('--hard_negatives_sampling', default='none', type=str,
        choices=['none', 'random', 'top', 'top-random', '10-20', '20-30'], help="")

    # --- *_bb counts (presumably bounding boxes per image — confirm) -------
    add('--max_bb', default=100, type=int, help="")
    add('--min_bb', default=10, type=int, help="")
    add('--num_bb', default=36, type=int, help="")

    # --- dataset locations and the prefix mappings map_db_dirs applies ---
    # All are optional plain strings; registration order is preserved.
    for flag in ('--train_txt_dbs', '--train_img_dbs',
                 '--txt_db_mapping', '--img_db_mapping', '--pretrain_mapping',
                 '--val_txt_db', '--val_img_db',
                 '--test_txt_db', '--test_img_db'):
        add(flag, default=None, type=str, help="")

    # --- model / inference --------------------------------------------------
    add('--steps_per_hard_neg', default=-1, type=int, help="")
    add('--inf_minibatch_size', default=400, type=int, help="")
    add('--project_dim', default=0, type=int, help='')
    add('--cls_concat', default="", type=str, help='')
    add('--fix_txt_encoder', action='store_true', help='')
    add('--fix_img_encoder', action='store_true', help='')
    add('--compressed_db', action='store_true', help='use compressed LMDB')
    add('--retrieval_mode', default="both",
        choices=['img_only', 'txt_only', 'both'], type=str, help="")
def add_logging_params(parser: argparse.ArgumentParser):
    """Register logging / experiment-naming options on ``parser`` (in place)."""
    specs = (
        ('--log_result_step', dict(default=4, type=int, help="")),
        ('--project_name', dict(default='itm', type=str, help="")),
        ('--expr_name_prefix', dict(default='', type=str, help="")),
        ('--save_all_epochs', dict(action='store_true', help="")),
    )
    for flag, options in specs:
        parser.add_argument(flag, **options)
def add_kd_params(parser: argparse.ArgumentParser):
    """Register the ``kd`` options on ``parser`` (in place).

    These are a teacher checkpoint path, ``--T`` and a loss weight;
    presumably knowledge distillation with ``T`` as the softmax temperature
    — confirm at the use site.
    """
    specs = (
        ('--teacher_checkpoint', dict(default=None, type=str, help="")),
        ('--T', dict(default=1.0, type=float, help="")),
        ('--kd_loss_weight', dict(default=1.0, type=float, help="")),
    )
    for flag, options in specs:
        parser.add_argument(flag, **options)
def parse_with_config(parser, cmds=None):
    """Parse arguments, then fill in values from an optional JSON config file.

    Args:
        parser: the ``argparse.ArgumentParser`` to parse with. It must define
            a ``--config`` option (``args.config`` is read unconditionally).
        cmds: explicit list of argument strings. When ``None``, the process
            command line (``sys.argv[1:]``) is parsed.

    Returns:
        The parsed ``argparse.Namespace``. If ``--config`` was given, every
        key of the JSON object in that file is copied onto the namespace via
        ``setattr``; keys the parser does not define become new attributes.
        A key is skipped when the same option was passed explicitly as
        ``--key`` or ``--key=value`` in the parsed arguments, so explicit
        command-line values win over the config file.
    """
    # parse_args(None) itself falls back to sys.argv[1:], so one call covers
    # both the explicit-cmds and the real-command-line cases.
    args = parser.parse_args(cmds)
    if args.config is None:
        return args

    # Context manager so the config file handle is always closed.
    with open(args.config, encoding='utf-8') as f:
        config_args = json.load(f)

    # Look for explicit overrides in the arguments that were actually parsed.
    # When `cmds` is given, those are the source, not the process sys.argv.
    # Option names are normalised to argparse dest form ('-' -> '_').
    # NOTE: abbreviated options (argparse prefix matching) are not detected.
    cli_tokens = sys.argv[1:] if cmds is None else cmds
    override_keys = {
        token[2:].split('=', 1)[0].replace('-', '_')
        for token in cli_tokens
        if token.startswith('--')
    }
    for key, value in config_args.items():
        if key not in override_keys:
            setattr(args, key, value)
    return args
def map_db_dirs(args):
    """Remap dataset/checkpoint path prefixes on ``args`` in place.

    Rewrite rules are applied to every string-valued attribute except the
    three ``*_mapping`` options themselves. They run in this order, each on
    the result of the previous one:

      * value starts with ``/pretrain`` and ``args.pretrain_mapping`` is set:
        ``/pretrain`` -> ``args.pretrain_mapping``
      * value starts with ``/db`` and ``args.txt_db_mapping`` is set:
        ``/db`` -> ``args.txt_db_mapping``
      * value starts with ``/img`` and ``args.img_db_mapping`` is set:
        ``/img`` -> ``args.img_db_mapping``

    Each applied rule prints ``<tag> <attr> <value before rewrite>`` to
    stdout. The rewrite uses ``str.replace``, so every occurrence of the
    prefix inside the value is substituted.

    ``train_img_dbs`` / ``train_txt_dbs`` may also be lists (for example when
    set from a JSON config by ``parse_with_config``). When the matching
    mapping is set, their elements receive the ``/img`` (resp. ``/db``)
    substitution in place. ``None`` and string values are not touched by this
    list pass; strings were already handled above.
    """
    mapping_keys = ('pretrain_mapping', 'txt_db_mapping', 'img_db_mapping')
    # Snapshot the targets before the loop. The mapping options are
    # themselves strings in the namespace; rewriting them mid-loop (e.g.
    # img_db_mapping='/img_local' matches prefix '/img') would corrupt every
    # later rewrite.
    rules = (
        ('pretrain', '/pretrain', args.pretrain_mapping),
        ('db', '/db', args.txt_db_mapping),
        ('img', '/img', args.img_db_mapping),
    )
    for key, value in list(vars(args).items()):
        if key in mapping_keys or not isinstance(value, str):
            continue
        for tag, prefix, target in rules:
            if target and value.startswith(prefix):
                print(tag, key, value)
                value = value.replace(prefix, target)
        setattr(args, key, value)

    # List-valued training DBs are skipped by the loop above. Only real lists
    # are handled here: None (option not given) and str (already remapped,
    # and immutable) previously raised TypeError.
    list_rules = (
        ('train_img_dbs', '/img', args.img_db_mapping),
        ('train_txt_dbs', '/db', args.txt_db_mapping),
    )
    for key, prefix, target in list_rules:
        paths = getattr(args, key, None)
        if target and isinstance(paths, list):
            # Slice assignment keeps the same list object (mutation in place).
            paths[:] = [path.replace(prefix, target) for path in paths]
def print_args(args):
    """Log every attribute of ``args`` at INFO level, in name order.

    Each line has the attribute name left-justified to 30 characters, then
    ``-->`` and its value. The lines are framed by CONFIGURATION banner lines.
    """
    logger.info(" **************** CONFIGURATION **************** ")
    settings = vars(args)
    for name in sorted(settings):
        logger.info("%s --> %s", name.ljust(30), settings[name])
    logger.info(" **************** END CONFIGURATION **************** ")
def set_seed(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.seed``.

    CUDA generators on all devices are seeded too, but only when
    ``args.n_gpu`` is positive. ``args.n_gpu`` must already be set, e.g. by
    ``setup_args_gpu``.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)
def setup_args_gpu(args):
    """
    Setup arguments CUDA, GPU & distributed training.

    Sets on ``args``:
      * ``device``: a ``torch.device``. In single-node mode it is "cuda" when
        CUDA is available and ``--no_cuda`` is not set, otherwise "cpu". In
        distributed mode it is ``cuda:<local_rank>``.
      * ``n_gpu``: GPUs used by this process. This is all visible GPUs in
        single-node mode, 0 when ``--no_cuda`` is given, and 1 in
        distributed mode.
      * ``distributed_world_size``: ``int(os.environ['WORLD_SIZE'])``, or 1
        when that variable is unset or empty.

    Distributed mode (``local_rank != -1`` and ``--no_cuda`` not set) binds
    this process to GPU ``local_rank`` and initialises the default process
    group with the NCCL backend.
    """
    if args.local_rank == -1 or args.no_cuda:  # single-node multi-gpu (or cpu) mode
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda" if use_cuda else "cpu")
        # With --no_cuda the process runs on CPU, so it must not report any
        # GPUs. Otherwise downstream `n_gpu > 0` checks (e.g. set_seed) would
        # still take their CUDA code paths.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # distributed mode
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    ws = os.environ.get('WORLD_SIZE')
    args.distributed_world_size = int(ws) if ws else 1

    logger.info(
        'Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(),
        args.local_rank, device,
        args.n_gpu,
        args.distributed_world_size)
    logger.info("16-bits training: %s ", args.fp16)