lightningdot
copied
wxywb
2 years ago
192 changed files with 21477 additions and 8 deletions
@@ -0,0 +1,19 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}
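This is a stock BERT-base configuration with the uncased vocabulary (30522 tokens). As a hedged sketch, a file like this can be loaded directly with the Hugging Face transformers package; the config.json path is an assumption, not a path shown in this diff:

    from transformers import BertConfig, BertForMaskedLM

    # Hypothetical path to the JSON file above
    config = BertConfig.from_json_file("config.json")
    model = BertForMaskedLM(config)  # randomly initialized weights with matching shapes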
@@ -0,0 +1,20 @@
{
  "img_model_type": "uniter-base",
  "txt_model_type": "bert-base",
  "txt_model_config": "bert-base-cased",
  "img_model_config": "./config/img_base.json",
  "itm_global_file": "./data/meta/coco_meta.json",
  "seed": 42,
  "output_dir": "/storage/debug-eval",
  "max_txt_len": 60,
  "conf_th": 0.2,
  "max_bb": 100,
  "min_bb": 10,
  "num_bb": 36,
  "project_dim": 768,
  "test_txt_db": "./data/db/itm_coco_test_base-cased.db",
  "test_img_db": "./data/img/coco_val2014/",
  "project_name": "itm-debug",
  "n_workers": 4,
  "fp16": true
}
@@ -0,0 +1,43 @@
{
  "txt_model_type": "bert-base",
  "txt_model_config": "bert-base-cased",
  "img_model_type": "uniter-base",
  "img_model_config": "./config/img_base.json",
  "img_checkpoint": "./data/model/uniter-base.pt",
  "itm_global_file": "./data/meta/coco_meta.json",
  "train_batch_size": 64,
  "val_batch_size": 256,
  "gradient_accumulation_steps": 1,
  "learning_rate": 2e-05,
  "warmup_steps": 100,
  "valid_steps": 500,
  "num_train_epochs": 20,
  "seed": 42,
  "output_dir": "/storage/debug_coco",
  "max_txt_len": 60,
  "conf_th": 0.2,
  "max_bb": 100,
  "min_bb": 10,
  "num_bb": 36,
  "project_dim": 768,
  "train_txt_dbs": [
    "./data/db/itm_coco_train_base-cased.db",
    "./data/db/itm_coco_restval_base-cased.db"
  ],
  "train_img_dbs": [
    "./data/img/coco_train2014/",
    "./data/img/coco_val2014"
  ],
  "val_txt_db": "./data/db/itm_coco_val_base-cased.db",
  "val_img_db": "./data/img/coco_val2014/",
  "test_txt_db": "./data/db/itm_coco_test_base-cased.db",
  "test_img_db": "./data/img/coco_val2014/",
  "project_name": "itm-debug",
  "num_hard_negatives": 0,
  "hard_negatives_sampling": "none",
  "inf_minibatch_size": 0,
  "n_workers": 0,
  "fp16": true,
  "compressed_db": false,
  "pin_mem": true
}
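A sketch of how a training entry point might consume a config like this (the filename and the SimpleNamespace wrapper are assumptions; the actual option loader is not shown in this diff):

    import json
    from types import SimpleNamespace

    with open("config/train_coco.json") as f:   # hypothetical filename
        opts = SimpleNamespace(**json.load(f))
    print(opts.learning_rate, opts.train_batch_size)   # 2e-05 64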
@@ -0,0 +1,22 @@
{
  "img_model_type": "uniter-base",
  "txt_model_type": "bert-base",
  "txt_model_config": "bert-base-cased",
  "img_model_config": "./config/img_base.json",
  "itm_global_file": "./data/meta/flickr_meta.json",
  "seed": 42,
  "output_dir": "/storage/debug-eval",
  "max_txt_len": 60,
  "conf_th": 0.2,
  "max_bb": 100,
  "min_bb": 10,
  "num_bb": 36,
  "project_dim": 768,
  "val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db",
  "val_img_db": "./data/img/flickr30k/",
  "test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db",
  "test_img_db": "./data/img/flickr30k/",
  "project_name": "itm-debug",
  "n_workers": 4,
  "fp16": true
}
@@ -0,0 +1,38 @@
{
  "txt_model_type": "bert-base",
  "txt_model_config": "bert-base-cased",
  "img_model_type": "uniter-base",
  "img_model_config": "./config/img_base.json",
  "img_checkpoint": "./data/model/uniter-base.pt",
  "itm_global_file": "./data/meta/flickr_meta.json",
  "train_batch_size": 64,
  "gradient_accumulation_steps": 1,
  "learning_rate": 2e-05,
  "warmup_steps": 100,
  "valid_steps": 500,
  "num_train_epochs": 15,
  "seed": 42,
  "output_dir": "/storage/debug_flickr",
  "max_txt_len": 60,
  "conf_th": 0.2,
  "max_bb": 100,
  "min_bb": 10,
  "num_bb": 36,
  "project_dim": 768,
  "train_txt_dbs": [
    "./data/db/itm_flickr30k_train_base-cased.db"
  ],
  "train_img_dbs": [
    "./data/img/flickr30k/"
  ],
  "val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db",
  "val_img_db": "./data/img/flickr30k/",
  "test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db",
  "test_img_db": "./data/img/flickr30k/",
  "project_name": "itm-debug",
  "num_hard_negatives": 0,
  "hard_negatives_sampling": "none",
  "inf_minibatch_size": 0,
  "n_workers": 0,
  "fp16": true
}
@@ -0,0 +1,16 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 28996,
  "output_hidden_states": false
}
@@ -0,0 +1,191 @@
{
  "compressed_db": false,
  "txt_model_type": "bert-base",
  "txt_model_config": "bert-base-cased",
  "img_model_type": "uniter-base",
  "img_model_config": "./config/img_base.json",
  "model_config": "./config/img_base.json",
  "output_dir": "/storage/pretrain/alltask_ot_alldata_base",
  "project_dim": 768,
  "mrm_prob": 0.15,
  "neg_size": 128,
  "nce_temp": 1.0,
  "itm_neg_prob": 0.0,
  "itm_ot_lambda": 0.0,
  "max_txt_len": 60,
  "conf_th": 0.2,
  "max_bb": 100,
  "min_bb": 10,
  "num_bb": 36,
  "train_batch_size": 10240,
  "val_batch_size": 10240,
  "gradient_accumulation_steps": 6,
  "learning_rate": 5e-05,
  "valid_steps": 10000,
  "num_train_steps": 300000,
  "optim": "adamw",
  "betas": [
    0.9,
    0.98
  ],
  "decay": "linear",
  "dropout": 0.1,
  "weight_decay": 0.01,
  "grad_norm": 5.0,
  "warmup_steps": 10000,
  "seed": 42,
  "fp16": true,
  "n_workers": 3,
  "pin_mem": true,
  "train_datasets": [
    {
      "name": "coco_cap",
      "db": [
        "./data/db/pretrain_caption_coco_train_base-cased.db/",
        "./data/db/pretrain_caption_coco_trainval_base-cased.db/"
      ],
      "img": [
        "./data/img/coco_train2014/",
        "./data/img/coco_val2014/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ],
      "mix_ratio": [
        16,
        8,
        4,
        4
      ]
    },
    {
      "name": "vg_cap",
      "db": [
        "./data/db/pretrain_caption_vg_train_base-cased.db/"
      ],
      "img": [
        "./data/img/vg/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ],
      "mix_ratio": [
        16,
        12,
        6,
        6
      ]
    },
    {
      "name": "cc",
      "db": [
        "./data/db/conceptual_caption_train_base-cased.db/"
      ],
      "img": [
        "./data/img/gcc_train/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ],
      "mix_ratio": [
        16,
        12,
        6,
        6
      ]
    },
    {
      "name": "sbu",
      "db": [
        "./data/db/sbu_caption_train_base-cased.db/"
      ],
      "img": [
        "./data/img/sbu/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ],
      "mix_ratio": [
        16,
        8,
        4,
        4
      ]
    }
  ],
  "val_datasets": [
    {
      "name": "coco_cap",
      "db": [
        "./data/db/pretrain_caption_coco_val_base-cased.db/"
      ],
      "img": [
        "./data/img/coco_val2014/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ]
    },
    {
      "name": "vg_cap",
      "db": [
        "./data/db/pretrain_caption_vg_val_base-cased.db/"
      ],
      "img": [
        "./data/img/vg/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ]
    },
    {
      "name": "cc",
      "db": [
        "./data/db/conceptual_caption_val_base-cased.db/"
      ],
      "img": [
        "./data/img/gcc_val/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ]
    },
    {
      "name": "sbu",
      "db": [
        "./data/db/sbu_caption_val_base-cased.db/"
      ],
      "img": [
        "./data/img/sbu/"
      ],
      "tasks": [
        "itm",
        "mlm",
        "mrfr",
        "mrckl"
      ]
    }
  ],
  "rank": 0
}
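Each train_datasets entry pairs text/image DBs with the four pretraining tasks (itm, mlm, mrfr, mrckl), and mix_ratio weights how often each task is drawn for that corpus. A minimal sketch of that weighted task sampling, assuming a simple random.choices-based scheduler rather than this repo's actual meta-loader:

    import random

    tasks = ["itm", "mlm", "mrfr", "mrckl"]
    mix_ratio = [16, 8, 4, 4]   # the coco_cap / sbu weighting above
    task = random.choices(tasks, weights=mix_ratio, k=1)[0]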
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,75 @@
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import numpy as np


def bbox_transform(ex_rois, gt_rois):
    ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
    ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
    ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths
    ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights

    gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
    gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
    gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths
    gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights

    targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = np.log(gt_widths / ex_widths)
    targets_dh = np.log(gt_heights / ex_heights)

    targets = np.vstack(
        (targets_dx, targets_dy, targets_dw, targets_dh)).transpose()
    return targets


def bbox_transform_inv(boxes, deltas):
    if boxes.shape[0] == 0:
        return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)
    boxes = boxes.astype(deltas.dtype, copy=False)

    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx = deltas[:, 0::4]
    dy = deltas[:, 1::4]
    dw = deltas[:, 2::4]
    dh = deltas[:, 3::4]

    pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
    pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
    pred_w = np.exp(dw) * widths[:, np.newaxis]
    pred_h = np.exp(dh) * heights[:, np.newaxis]

    pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
    # x1
    pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
    # y1
    pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
    # x2
    pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
    # y2
    pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h

    return pred_boxes


def clip_boxes(boxes, im_shape):
    """
    Clip boxes to image boundaries.
    """

    # x1 >= 0
    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
    # y1 >= 0
    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
    # x2 < im_shape[1]
    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
    # y2 < im_shape[0]
    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
    return boxes
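A quick sanity check of the encode/decode pair (a sketch, not part of the diff): identical boxes encode to all-zero regression targets, and decoding follows the classic Fast R-CNN +1 width/height convention, so x2/y2 come back shifted by one pixel:

    import numpy as np

    rois = np.array([[10., 20., 110., 220.]])
    assert np.allclose(bbox_transform(rois, rois), 0.0)   # zero deltas for identical boxes
    decoded = bbox_transform_inv(rois, np.zeros((1, 4)))
    # decoded == [[10., 20., 111., 221.]] under this corner convention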
@@ -0,0 +1,478 @@
import pickle
import ipdb
import torch
import numpy as np
import cv2
import torchvision
from torch import nn
from .rpn import RegionProposalNetwork


class ConvBlock(nn.Module):
    """Conv2d + BatchNorm2d (+ optional ReLU), matching a Caffe conv/bn/scale triple."""

    def __init__(self, i, o, k, s, p, d, use_relu=True):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(i, o, k, s, p, d)
        self.bn = nn.BatchNorm2d(o)
        self.use_relu = use_relu
        if self.use_relu:
            self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.use_relu:
            x = self.relu(x)
        return x


def load_convblock(block, convname, bnname, scalename, weights):
    """Copy Caffe conv/bn/scale weights into a ConvBlock.

    Caffe BN layers store (mean, var, scale_factor); dividing by the scale
    factor recovers the actual running statistics.
    """
    block.conv.weight = nn.Parameter(torch.FloatTensor(weights[convname][0]))
    block.conv.bias = nn.Parameter(torch.zeros_like(block.conv.bias))
    block.bn.running_mean = nn.Parameter(torch.FloatTensor(weights[bnname][0] / weights[bnname][2]))
    block.bn.running_var = nn.Parameter(torch.FloatTensor(weights[bnname][1] / weights[bnname][2]))
    block.bn.weight = nn.Parameter(torch.FloatTensor(weights[scalename][0]))
    block.bn.bias = nn.Parameter(torch.FloatTensor(weights[scalename][1]))


class Net(nn.Module):
    """ResNet-101 Faster R-CNN feature extractor (1600 object classes + background)."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = ConvBlock(3, 64, 7, 2, 3, 1, True)

        self.pool1 = nn.MaxPool2d(3, 2, 0, ceil_mode=True)

        self.res2a_branch1 = ConvBlock(64, 256, 1, 1, 0, 1, False)
        self.res2a_branch2a = ConvBlock(64, 64, 1, 1, 0, 1, True)
        self.res2a_branch2b = ConvBlock(64, 64, 3, 1, 1, 1, True)
        self.res2a_branch2c = ConvBlock(64, 256, 1, 1, 0, 1, False)

        self.res2b_branch2a = ConvBlock(256, 64, 1, 1, 0, 1, True)
        self.res2b_branch2b = ConvBlock(64, 64, 3, 1, 1, 1, True)
        self.res2b_branch2c = ConvBlock(64, 256, 1, 1, 0, 1, False)

        self.res2c_branch2a = ConvBlock(256, 64, 1, 1, 0, 1, True)
        self.res2c_branch2b = ConvBlock(64, 64, 3, 1, 1, 1, True)
        self.res2c_branch2c = ConvBlock(64, 256, 1, 1, 0, 1, False)

        self.res3a_branch1 = ConvBlock(256, 512, 1, 2, 0, 1, False)
        self.res3a_branch2a = ConvBlock(256, 128, 1, 2, 0, 1, True)
        self.res3a_branch2b = ConvBlock(128, 128, 3, 1, 1, 1, True)
        self.res3a_branch2c = ConvBlock(128, 512, 1, 1, 0, 1, False)

        # res3b1 .. res3b3: identity bottleneck blocks
        for i in range(1, 4):
            setattr(self, 'res3b%d_branch2a' % i, ConvBlock(512, 128, 1, 1, 0, 1, True))
            setattr(self, 'res3b%d_branch2b' % i, ConvBlock(128, 128, 3, 1, 1, 1, True))
            setattr(self, 'res3b%d_branch2c' % i, ConvBlock(128, 512, 1, 1, 0, 1, False))

        self.res4a_branch1 = ConvBlock(512, 1024, 1, 2, 0, 1, False)
        self.res4a_branch2a = ConvBlock(512, 256, 1, 2, 0, 1, True)
        self.res4a_branch2b = ConvBlock(256, 256, 3, 1, 1, 1, True)
        self.res4a_branch2c = ConvBlock(256, 1024, 1, 1, 0, 1, False)

        # res4b1 .. res4b22: identity bottleneck blocks
        for i in range(1, 23):
            setattr(self, 'res4b%d_branch2a' % i, ConvBlock(1024, 256, 1, 1, 0, 1, True))
            setattr(self, 'res4b%d_branch2b' % i, ConvBlock(256, 256, 3, 1, 1, 1, True))
            setattr(self, 'res4b%d_branch2c' % i, ConvBlock(256, 1024, 1, 1, 0, 1, False))

        # res5: dilated (d=2) bottleneck blocks, applied to pooled RoI features
        self.res5a_branch1 = ConvBlock(1024, 2048, 1, 1, 0, 1, False)
        self.res5a_branch2a = ConvBlock(1024, 512, 1, 1, 0, 1, True)
        self.res5a_branch2b = ConvBlock(512, 512, 3, 1, 2, 2, True)
        self.res5a_branch2c = ConvBlock(512, 2048, 1, 1, 0, 1, False)

        self.res5b_branch2a = ConvBlock(2048, 512, 1, 1, 0, 1, True)
        self.res5b_branch2b = ConvBlock(512, 512, 3, 1, 2, 2, True)
        self.res5b_branch2c = ConvBlock(512, 2048, 1, 1, 0, 1, False)

        self.res5c_branch2a = ConvBlock(2048, 512, 1, 1, 0, 1, True)
        self.res5c_branch2b = ConvBlock(512, 512, 3, 1, 2, 2, True)
        self.res5c_branch2c = ConvBlock(512, 2048, 1, 1, 0, 1, False)

        self.rpn_conv_3x3 = nn.Conv2d(1024, 512, 3, 1, 1, 1)
        self.rpn_cls_score = nn.Conv2d(512, 24, 1, 1, 0, 1)   # 2 (bg/fg) x 12 anchors
        self.rpn_bbox_pred = nn.Conv2d(512, 48, 1, 1, 0, 1)   # 4 coords x 12 anchors

        self.rpn = RegionProposalNetwork(pre_nms_topN=6000, post_nms_topN=300, nms_thresh=0.7,
                                         min_size=16, anchor_scales=(4, 8, 16, 32), feat_stride=16)

        # self.pool5 = nn.MaxPool2d(3, 2, 1, ceil_mode=True)
        self.cls_score = nn.Linear(2048, 1601)

    def infer_resblock(self, l, r, x):
        # Residual block: left branch (shortcut) + right branch (bottleneck)
        xl = x
        xr = x

        for b in l:
            xl = b(xl)
        for b in r:
            xr = b(xr)
        return xl + xr

    def _identity_block(self, name, x):
        # Identity-shortcut bottleneck: relu(x + branch2c(branch2b(branch2a(x))))
        branches = [getattr(self, name + '_branch2a'),
                    getattr(self, name + '_branch2b'),
                    getattr(self, name + '_branch2c')]
        return nn.functional.relu(self.infer_resblock([lambda x: x], branches, x))

    def forward(self, x, im_size):
        x = self.conv1(x)
        x = self.pool1(x)

        x = nn.functional.relu(self.infer_resblock(
            [self.res2a_branch1],
            [self.res2a_branch2a, self.res2a_branch2b, self.res2a_branch2c], x))
        x = self._identity_block('res2b', x)
        x = self._identity_block('res2c', x)

        x = nn.functional.relu(self.infer_resblock(
            [self.res3a_branch1],
            [self.res3a_branch2a, self.res3a_branch2b, self.res3a_branch2c], x))
        for i in range(1, 4):
            x = self._identity_block('res3b%d' % i, x)

        x = nn.functional.relu(self.infer_resblock(
            [self.res4a_branch1],
            [self.res4a_branch2a, self.res4a_branch2b, self.res4a_branch2c], x))
        for i in range(1, 23):
            x = self._identity_block('res4b%d' % i, x)

        x_rpn_output = nn.functional.relu(self.rpn_conv_3x3(x))
        x_rpn_cls_score = self.rpn_cls_score(x_rpn_output)
        x_rpn_bbox_pred = self.rpn_bbox_pred(x_rpn_output)

        n, c, h, w = x_rpn_cls_score.shape

        # softmax over bg/fg, then back to the (n, 2*A, h, w) layout the RPN expects
        x_rpn_cls_score = x_rpn_cls_score.reshape(n, 2, -1, w)
        x_rpn_cls_prob = nn.functional.softmax(x_rpn_cls_score, 1)
        x_rpn_cls_prob_reshape = x_rpn_cls_prob.reshape(n, 24, -1, w)

        rois = self.rpn.forward(x_rpn_cls_prob_reshape, x_rpn_bbox_pred, im_size)
        feats = torchvision.ops.roi_pool(x, rois, output_size=[14, 14], spatial_scale=0.0625)

        x = nn.functional.relu(self.infer_resblock(
            [self.res5a_branch1],
            [self.res5a_branch2a, self.res5a_branch2b, self.res5a_branch2c], feats))
        x = self._identity_block('res5b', x)
        x = self._identity_block('res5c', x)

        x = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))

        pool5_flat = x.reshape((x.shape[0], -1))

        x_cls_score = self.cls_score(pool5_flat)
        x_cls_prob = torch.nn.functional.softmax(x_cls_score, -1)
        x_cls_boxes = rois[:, 1:5] / im_size[2]   # undo the image rescale factor
        max_conf, keep_boxes = self.post_process(rois, x_cls_boxes, x_cls_prob, 0.2)

        MIN_BOXES = 10
        MAX_BOXES = 100

        if len(keep_boxes) < MIN_BOXES:
            keep_boxes = torch.argsort(max_conf, 0, True)[:MIN_BOXES]
        elif len(keep_boxes) > MAX_BOXES:
            keep_boxes = torch.argsort(max_conf, 0, True)[:MAX_BOXES]

        boxes = x_cls_boxes[keep_boxes]
        features = pool5_flat[keep_boxes]
        confidence = max_conf[keep_boxes]

        return boxes, features, confidence

    def post_process(self, rois, cls_boxes, cls_prob, conf_thresh=0.2):
        # Per-class NMS; keep the highest surviving class score per RoI
        max_conf = torch.zeros((rois.shape[0]), device=rois.device)
        for cls_ind in range(1, cls_prob.shape[1]):
            cls_scores = cls_prob[:, cls_ind]
            dets = torch.hstack(
                (cls_boxes, cls_scores[:, np.newaxis]))
            keep = np.array(torchvision.ops.nms(dets[:, :4], dets[:, 4], 0.3))
            max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                         cls_scores[keep], max_conf[keep])
        keep_boxes = torch.where(max_conf >= conf_thresh)[0]
        return max_conf, keep_boxes

    def load_weights_from_pkl(self, weights_kv):
        with torch.no_grad():
            load_convblock(self.conv1, 'conv1', 'bn_conv1', 'scale_conv1', weights_kv)

            # Stages res2..res5; the first ('a') block of each stage also has a
            # branch1 projection shortcut.
            stages = (('2', ['a', 'b', 'c']),
                      ('3', ['a'] + ['b%d' % i for i in range(1, 4)]),
                      ('4', ['a'] + ['b%d' % i for i in range(1, 23)]),
                      ('5', ['a', 'b', 'c']))
            for stage, blocks in stages:
                for blk in blocks:
                    name = stage + blk
                    if blk == 'a':
                        load_convblock(getattr(self, 'res%s_branch1' % name),
                                       'res%s_branch1' % name,
                                       'bn%s_branch1' % name,
                                       'scale%s_branch1' % name, weights_kv)
                    for br in ('2a', '2b', '2c'):
                        load_convblock(getattr(self, 'res%s_branch%s' % (name, br)),
                                       'res%s_branch%s' % (name, br),
                                       'bn%s_branch%s' % (name, br),
                                       'scale%s_branch%s' % (name, br), weights_kv)

            self.rpn_conv_3x3.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][0]))
            self.rpn_conv_3x3.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][1]))

            self.rpn_cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][0]))
            self.rpn_cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][1]))

            self.rpn_bbox_pred.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][0]))
            self.rpn_bbox_pred.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][1]))

            self.cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][0]))
            self.cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][1]))


def process_img(img):
    # Subtract the Caffe BGR pixel means, then rescale so the short side is
    # 600 px (capped at 1000 px on the long side).
    mean = np.array([[[102.9801, 115.9465, 122.7717]]])
    img = img - mean

    im_shape = img.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    target_size = 600
    max_size = 1000
    im_scale = float(target_size) / float(im_size_min)
    if np.round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)
    im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
    return im, np.array([im.shape[0], im.shape[1], im_scale])
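For reference, a hedged end-to-end sketch of driving this extractor (the image and weight-file paths are hypothetical; the pickle is assumed to map Caffe layer names to weight arrays, as load_weights_from_pkl expects):

    import pickle
    import cv2
    import numpy as np
    import torch

    img = cv2.imread("img2.jpg").astype(np.float32)
    im, im_info = process_img(img)                        # BGR mean-subtract + resize
    blob = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).float()

    net = Net()
    with open("resnet101_faster_rcnn.pkl", "rb") as f:    # hypothetical filename
        net.load_weights_from_pkl(pickle.load(f, encoding="latin1"))
    net.eval()
    with torch.no_grad():
        boxes, features, confidence = net(blob, im_info)
    print(features.shape)                                 # (num boxes, 2048)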
@@ -0,0 +1,105 @@
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------

import numpy as np

# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
#     -83   -39   100    56
#    -175   -87   192   104
#    -359  -183   376   200
#     -55   -55    72    72
#    -119  -119   136   136
#    -247  -247   264   264
#     -35   -79    52    96
#     -79  -167    96   184
#    -167  -343   184   360

# array([[ -83.,  -39.,  100.,   56.],
#        [-175.,  -87.,  192.,  104.],
#        [-359., -183.,  376.,  200.],
#        [ -55.,  -55.,   72.,   72.],
#        [-119., -119.,  136.,  136.],
#        [-247., -247.,  264.,  264.],
#        [ -35.,  -79.,   52.,   96.],
#        [ -79., -167.,   96.,  184.],
#        [-167., -343.,  184.,  360.]])


def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
                     scales=2**np.arange(3, 6)):
    """
    Generate anchor (reference) windows by enumerating aspect ratios X
    scales wrt a reference (0, 0, 15, 15) window.
    """

    base_anchor = np.array([1, 1, base_size, base_size]) - 1
    ratio_anchors = _ratio_enum(base_anchor, ratios)
    anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
                         for i in range(ratio_anchors.shape[0])])
    return anchors


def _whctrs(anchor):
    """
    Return width, height, x center, and y center for an anchor (window).
    """

    w = anchor[2] - anchor[0] + 1
    h = anchor[3] - anchor[1] + 1
    x_ctr = anchor[0] + 0.5 * (w - 1)
    y_ctr = anchor[1] + 0.5 * (h - 1)
    return w, h, x_ctr, y_ctr


def _mkanchors(ws, hs, x_ctr, y_ctr):
    """
    Given a vector of widths (ws) and heights (hs) around a center
    (x_ctr, y_ctr), output a set of anchors (windows).
    """

    ws = ws[:, np.newaxis]
    hs = hs[:, np.newaxis]
    anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
                         y_ctr - 0.5 * (hs - 1),
                         x_ctr + 0.5 * (ws - 1),
                         y_ctr + 0.5 * (hs - 1)))
    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate a set of anchors for each aspect ratio wrt an anchor.
    """

    w, h, x_ctr, y_ctr = _whctrs(anchor)
    size = w * h
    size_ratios = size / ratios
    ws = np.round(np.sqrt(size_ratios))
    hs = np.round(ws * ratios)
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


def _scale_enum(anchor, scales):
    """
    Enumerate a set of anchors for each scale wrt an anchor.
    """

    w, h, x_ctr, y_ctr = _whctrs(anchor)
    ws = w * scales
    hs = h * scales
    anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
    return anchors


if __name__ == '__main__':
    import time
    t = time.time()
    a = generate_anchors()
    print(time.time() - t)
    print(a)  # should reproduce the nine reference anchors listed at the top of this file
@@ -0,0 +1,136 @@
import pickle

import numpy as np
import torch as t
import torch
from torch.nn import functional as F
from torch import nn
import torchvision

from .generate_anchors import generate_anchors
from .bbox_transform import bbox_transform_inv, clip_boxes


class RegionProposalNetwork(nn.Module):
    def __init__(self, pre_nms_topN, post_nms_topN, nms_thresh, min_size, anchor_scales, feat_stride):
        super(RegionProposalNetwork, self).__init__()
        self._anchors = generate_anchors(scales=np.array(anchor_scales))
        self._num_anchors = self._anchors.shape[0]
        self._feat_stride = feat_stride
        self.pre_nms_topN = pre_nms_topN
        self.post_nms_topN = post_nms_topN
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.anchor_scales = anchor_scales
        self.feat_stride = feat_stride

    def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size):
        # foreground scores only (the second half of the bg/fg channels)
        scores = rpn_cls_prob[:, self._num_anchors:, :, :]
        bbox_deltas = rpn_bbox_pred

        pre_nms_topN = self.pre_nms_topN
        post_nms_topN = self.post_nms_topN
        nms_thresh = self.nms_thresh
        min_size = self.min_size

        n, _, hh, ww = scores.shape

        anchors = self._enumerate_shifted_anchor(self._anchors, self._feat_stride, hh, ww)
        bbox_deltas = bbox_deltas.permute((0, 2, 3, 1)).reshape((-1, 4))

        bbox_deltas = bbox_deltas.cpu().detach().numpy()
        proposals = bbox_transform_inv(anchors, bbox_deltas)
        scores = scores.permute((0, 2, 3, 1)).reshape((-1, 1))

        proposals = clip_boxes(proposals, img_size[:2])

        keep = _filter_boxes(proposals, min_size * img_size[2])
        proposals = proposals[keep, :]
        scores = scores[keep]

        order = scores.ravel().argsort(descending=True)
        proposals = t.from_numpy(proposals).float().to(rpn_cls_prob.device)

        if pre_nms_topN > 0:
            order = order[:pre_nms_topN]

        proposals = proposals[order, :]
        scores = scores[order]

        keep = torchvision.ops.nms(proposals, scores.ravel(), nms_thresh)

        if post_nms_topN > 0:
            keep = keep[:post_nms_topN]
        proposals = proposals[keep, :]
        scores = scores[keep]

        # prepend the batch index (always 0 here: single-image batches)
        batch_inds = t.zeros((proposals.shape[0], 1), dtype=proposals.dtype, device=proposals.device)
        rois = t.hstack([batch_inds, proposals])

        return rois

    def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width):
        # Enumerate all shifted anchors:
        #
        # add A anchors (1, A, 4) to
        # cell K shifts (K, 1, 4) to get
        # shift anchors (K, A, 4)
        # reshape to (K*A, 4) shifted anchors
        # return (K*A, 4)

        # TODO: add support for torch CUDA tensors; this enumeration does not
        # seem to benefit from running on GPU.
        shift_y = np.arange(0, height * feat_stride, feat_stride)
        shift_x = np.arange(0, width * feat_stride, feat_stride)
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        shift = np.stack((shift_x.ravel(), shift_y.ravel(),
                          shift_x.ravel(), shift_y.ravel()), axis=1)

        A = anchor_base.shape[0]
        K = shift.shape[0]
        anchor = anchor_base.reshape((1, A, 4)) + \
            shift.reshape((1, K, 4)).transpose((1, 0, 2))
        anchor = anchor.reshape((K * A, 4)).astype(np.float32)

        return anchor


def _filter_boxes(boxes, min_size):
    """Remove all boxes with any side smaller than min_size."""
    ws = boxes[:, 2] - boxes[:, 0] + 1
    hs = boxes[:, 3] - boxes[:, 1] + 1
    keep = np.where((ws >= min_size) & (hs >= min_size))[0]
    return keep
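As a shape check (a sketch under assumed inputs, not code from the diff), the anchor enumeration above yields height * width * A anchors; with the (4, 8, 16, 32) scales and 3 aspect ratios configured in Net, A = 12:

    import numpy as np

    rpn = RegionProposalNetwork(pre_nms_topN=6000, post_nms_topN=300, nms_thresh=0.7,
                                min_size=16, anchor_scales=(4, 8, 16, 32), feat_stride=16)
    anchors = rpn._enumerate_shifted_anchor(rpn._anchors, rpn._feat_stride, 38, 63)
    print(anchors.shape)   # (28728, 4) == (38 * 63 * 12, 4)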
Binary file not shown.
@@ -0,0 +1,3 @@
IMG_DIM = 2048
IMG_LABEL_DIM = 1601
BUCKET_SIZE = 8192
@ -0,0 +1,366 @@ |
|||||
|
import torch |
||||
|
import numpy as np |
||||
|
import itertools |
||||
|
|
||||
|
from torch.nn.utils.rnn import pad_sequence |
||||
|
from uniter_model.data.itm import DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, get_ids_and_lens |
||||
|
from uniter_model.data.data import get_gather_index |
||||
|
from toolz.sandbox import unzip |
||||
|
from cytoolz import concat |
||||
|
from GLOBAL_VARIABLES import N_EXAMPLES_TEACHER |
||||
|
|
||||
|
|
||||
|
def pad_tensors(tensors, lens=None, pad=0): |
||||
|
"""B x [T, ...]""" |
||||
|
if lens is None: |
||||
|
lens = [t.size(0) for t in tensors] |
||||
|
max_len = max(lens) |
||||
|
bs = len(tensors) |
||||
|
hid = tensors[0].size(-1) |
||||
|
dtype = tensors[0].dtype |
||||
|
output = torch.zeros(bs, max_len, hid, dtype=dtype) |
||||
|
if pad: |
||||
|
output.data.fill_(pad) |
||||
|
for i, (t, l) in enumerate(zip(tensors, lens)): |
||||
|
output.data[i, :l, ...] = t.data |
||||
|
return output |
||||
|


# for ITM task
class ItmFastDataset(DetectFeatTxtTokDataset):
    """ NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling) """
    def __init__(self, txt_db, img_db, num_hard_negatives=0, img_meta=None, tokenizer=None):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)

        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.ids_2_idx = {idx: i for i, idx in enumerate(self.ids)}
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))

        self.num_hard_negatives = num_hard_negatives
        self.img_meta = img_meta
        self.tokenizer = tokenizer
        self.train_imgs = None
        self.neg_imgs = None
        # self.new_epoch(hard_negatives)

    def new_epoch(self, hard_negatives_img=None, hard_negatives_txt=None):
        """ should be called every epoch for more randomness"""
        self.lens = []
        self.train_imgs, self.neg_imgs = [], []
        self.train_txts, self.neg_txts = [], []
        for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            self.train_imgs.append(img_fname)
            self.train_txts.append(id_)
            if hard_negatives_img is not None and self.num_hard_negatives > 0:
                self.neg_imgs.append(hard_negatives_img[id_][:self.num_hard_negatives])
                self.neg_txts.append(hard_negatives_txt[img_fname][:self.num_hard_negatives])
            else:
                self.neg_imgs.append(None)
                self.neg_txts.append(None)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # labels and negative images should be sampled every epoch
        img_fname, hard_neg_imgs = self.train_imgs[i], self.neg_imgs[i]
        txt_fname, hard_neg_txts = self.ids[i], self.neg_txts[i]

        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        if hard_neg_imgs is not None:
            # TODO: add hard negative here
            neg_imgs = {'img_input_ids': [], 'img_feat': [], 'img_pos_feat': [],
                        'num_bb': [], 'attn_masks_img': [],
                        'caption_ids': [], 'attn_masks_captions': []}
            for neg_id in hard_neg_imgs:
                neg_imgs['img_input_ids'].append(torch.Tensor([101]).long())
                t = self._get_img_feat(neg_id)
                neg_imgs['img_feat'].append(t[0])
                neg_imgs['img_pos_feat'].append(t[1])
                neg_imgs['num_bb'].append(t[2])
                neg_imgs['attn_masks_img'].append(torch.ones(t[2]+1, dtype=torch.long))
                if self.img_meta is not None:
                    tmp = [self.tokenizer.encode(c, add_special_tokens=False) + [self.tokenizer.sep_token_id]
                           for c in self.img_meta[neg_id]['caption_multiple']]
                    neg_imgs['caption_ids'].append(
                        torch.tensor([self.tokenizer.cls_token_id] + sum(tmp, []),
                                     dtype=input_ids.dtype, device=input_ids.device))
                    neg_imgs['attn_masks_captions'].append(
                        torch.ones(len(neg_imgs['caption_ids'][-1]), dtype=torch.long))
                    # debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']]
            neg_txts = {'input_ids': [], 'position_ids': [], 'attention_mask': []}
            for neg_id in hard_neg_txts:
                ei = super().__getitem__(self.ids_2_idx[neg_id])
                input_ids_ei = ei['input_ids']
                neg_txts['input_ids'].append(self.txt_db.combine_inputs(input_ids_ei))
                neg_txts['attention_mask'].append(
                    torch.ones(len(neg_txts['input_ids'][-1]), dtype=torch.long))
        else:
            neg_imgs = None
            neg_txts = None

        if self.img_meta is not None:
            caption_ids = [self.tokenizer.encode(c, add_special_tokens=False) + [self.tokenizer.sep_token_id]
                           for c in self.img_meta[img_fname]['caption_multiple']]
            caption_ids = torch.tensor([self.tokenizer.cls_token_id] + sum(caption_ids, []),
                                       dtype=input_ids.dtype, device=input_ids.device)
            attn_masks_captions = torch.ones(len(caption_ids), dtype=torch.long)
            # debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']]
        else:
            caption_ids = None
            attn_masks_captions = None

        # target = torch.Tensor(1).long()
        # target.data.fill_(ground_truth_label)
        return (input_ids, img_feat, img_pos_feat, img_input_ids,
                attn_masks, attn_masks_img, self.ids[i], img_fname,
                neg_imgs, neg_txts, caption_ids, attn_masks_captions)


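# Illustrative wiring (a sketch; the db handles and batch size are
# hypothetical): ItmFastDataset pairs with itm_fast_collate defined below via
# a standard DataLoader.
#
#   from torch.utils.data import DataLoader
#   dset = ItmFastDataset(txt_db, img_db, num_hard_negatives=0)
#   dset.new_epoch()                      # must be called before each epoch
#   loader = DataLoader(dset, batch_size=64, shuffle=True,
#                       collate_fn=itm_fast_collate)
#   batch = next(iter(loader))            # dict with 'txts'/'imgs'/'caps' sub-dicts
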
def itm_fast_collate_kd(inputs):
    (input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text,
     attn_masks_img, idx, img_fname, negs, neg_txts, caption_ids,
     attn_masks_captions) = map(list, unzip(inputs))  # neg_txts is unused here

    # txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    captions_ids = pad_sequence(caption_ids, batch_first=True, padding_value=0) if caption_ids[0] is not None else None

    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
    position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None

    if None not in negs:
        num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in negs]))
        img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in negs]))
        img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in negs]))
        img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in negs]))
        attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in negs]))
    else:
        num_bbs_neg = []
        img_feats_neg = []
        img_input_ids_neg = []
        img_pos_feat_neg = []
        attn_masks_img_neg = []

    num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
    img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs)

    img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
    attn_masks_captions = pad_sequence(attn_masks_captions, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None
    attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    # repeat each text N_EXAMPLES_TEACHER times for the teacher pass
    img_feat_teacher = img_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
    img_pos_feat_teacher = img_pos_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
    attn_masks_img_teacher = attn_masks_img[:N_EXAMPLES_TEACHER].repeat(bs, 1)[:, 1:]

    input_ids_teacher = input_ids.unsqueeze(1).repeat(1, N_EXAMPLES_TEACHER, 1).view(bs*N_EXAMPLES_TEACHER, -1)
    position_ids_teacher = position_ids
    attn_masks_text_teacher = attn_masks_text.unsqueeze(1).repeat(1, N_EXAMPLES_TEACHER, 1).view(bs*N_EXAMPLES_TEACHER, -1)

    attn_masks_teacher = torch.cat([attn_masks_text_teacher, attn_masks_img_teacher], dim=1)

    batch = {
        'txt_ids': input_ids,
        'img_ids': img_feat,
        'caption_ids': captions_ids,
        'txt_pos_ids': position_ids,
        'img_pos_ids': img_pos_feat,
        'caption_pos_ids': position_ids_captions,
        'txt_attn_masks': attn_masks_text,
        'img_attn_masks': attn_masks_img,
        'caption_attn_masks': attn_masks_captions,
        'img_txt_ids': img_input_ids,
        'img_txt_pos_ids': img_position_ids,
        'gather_index': gather_index,
        'sample_size': sample_size,
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'txt_index': idx,
        'img_fname': img_fname,

        'img_feat_teacher': img_feat_teacher,
        'img_pos_feat_teacher': img_pos_feat_teacher,
        'input_ids_teacher': input_ids_teacher,
        'position_ids_teacher': position_ids_teacher,
        'attn_masks_teacher': attn_masks_teacher
    }
    return batch


def itm_fast_collate(inputs):
    (input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text,
     attn_masks_img, idx, img_fname, neg_imgs, neg_txts, caption_ids,
     attn_masks_captions) = map(list, unzip(inputs))
    bs = len(input_ids)
    # txt_lens = [i.size(0) for i in input_ids]

    if None not in neg_imgs:
        num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in neg_imgs]))
        img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in neg_imgs]))
        img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in neg_imgs]))
        img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in neg_imgs]))
        attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in neg_imgs]))
        caption_ids_neg = list(itertools.chain(*[n['caption_ids'] for n in neg_imgs]))
        attn_masks_captions_neg = list(itertools.chain(*[n['attn_masks_captions'] for n in neg_imgs]))

        input_ids_neg = list(itertools.chain(*[n['input_ids'] for n in neg_txts]))
        attn_masks_text_neg = list(itertools.chain(*[n['attention_mask'] for n in neg_txts]))
    else:
        num_bbs_neg = []
        img_feats_neg = []
        img_input_ids_neg = []
        img_pos_feat_neg = []
        attn_masks_img_neg = []
        caption_ids_neg = []
        attn_masks_captions_neg = []

        input_ids_neg = []
        attn_masks_text_neg = []

    input_ids = pad_sequence(input_ids+input_ids_neg, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    captions_ids = pad_sequence(caption_ids+caption_ids_neg, batch_first=True, padding_value=0) if caption_ids[0] is not None else None
    position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None

    num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
    img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs)

    img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    attn_masks_text = pad_sequence(attn_masks_text+attn_masks_text_neg, batch_first=True, padding_value=0)
    attn_masks_captions = pad_sequence(attn_masks_captions+attn_masks_captions_neg, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None
    attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0)
    sample_size = bs
    # assert all(sample_size == len(i) for i in inputs)

    max_tl = input_ids.shape[1]
    out_size = attn_masks_img.size(1)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks_text,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'caps': {
            'input_ids': captions_ids,
            'position_ids': position_ids_captions,
            'attention_mask': attn_masks_captions,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'sample_size': sample_size,
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'txt_index': idx,
        'img_fname': img_fname
    }
    return batch


class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task """
    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        neg_st = i+1
        neg_end = neg_st+self.bs-1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        example = super().__getitem__(i)

        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        # attn_masks_text = torch.ones(1, tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for j, nbb in enumerate(num_bbs):
            attn_masks_img.data[j, :nbb].fill_(1)

        # out_size = attn_masks.size(1)
        gather_index = None  # get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch


# for VQA
@ -0,0 +1,592 @@
""" |
||||
|
Itm dataset |
||||
|
""" |
||||
|
from collections import defaultdict |
||||
|
import copy |
||||
|
import json |
||||
|
import random |
||||
|
|
||||
|
import torch |
||||
|
from torch.nn.utils.rnn import pad_sequence |
||||
|
import numpy as np |
||||
|
from toolz.sandbox import unzip |
||||
|
from cytoolz import concat |
||||
|
|
||||
|
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, |
||||
|
pad_tensors, get_gather_index, get_ids_and_lens) |
||||
|
from uniter_model.data.sampler import TokenBucketSampler |
||||
|
|
||||
|
|
||||
|
class TokenBucketSamplerForItm(TokenBucketSampler):
    def __init__(self, dset, *args, **kwargs):
        super().__init__(dset.lens, *args, **kwargs)
        self.dset = dset

    def __iter__(self):
        it = super().__iter__()
        self.dset.new_epoch()
        self._lens = self.dset.lens
        return it


def _has_overlap(la, lb):
    if len(la) < len(lb):
        la, lb = lb, la
    s = set(la)
    return any(b in s for b in lb)


def _sample_negative_rand(sample_pool, ground_truths, num_sample):
    """ random and retry """
    outputs = ground_truths[:1]
    while _has_overlap(outputs, ground_truths):
        outputs = random.sample(sample_pool, num_sample)
    return outputs


def _sample_negative_extra(sample_pool, ground_truths, num_sample):
    """ sample extra then remove """
    tot_size = len(ground_truths) + num_sample
    outputs = set(random.sample(sample_pool, tot_size))
    for gt in ground_truths:
        outputs.discard(gt)
    outputs = list(outputs)[:num_sample]
    return outputs


sample_negative = _sample_negative_rand  # switch between the 2 implementations


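# Illustrative behaviour of sample_negative (a sketch; the pool and ground
# truths are hypothetical): negatives are re-drawn until none of them collides
# with a ground-truth id.
#
#   >>> negs = sample_negative(['a', 'b', 'c', 'd'], ['b'], 2)
#   >>> 'b' in negs
#   False
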
class ItmDataset(DetectFeatTxtTokDataset):
    """ NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling) """
    def __init__(self, txt_db, img_db, neg_sample_p=0.0):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)

        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))

        self.neg_sample_p = neg_sample_p
        self.new_epoch()

    def new_epoch(self):
        """ should be called every epoch for more randomness"""
        self.labels = np.random.choice(
            [0, 1], size=len(self.ids),
            p=[self.neg_sample_p, 1-self.neg_sample_p])

        self.lens = []
        self.train_imgs = []
        for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            if self.labels[i] == 0:
                img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
            self.train_imgs.append(img_fname)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # labels and negative images should be sampled every epoch
        ground_truth_label = self.labels[i]
        img_fname = self.train_imgs[i]
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        target = torch.Tensor(1).long()
        target.data.fill_(ground_truth_label)

        return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target


def itm_collate(inputs):
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets
     ) = map(list, unzip(inputs))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'pos_ctx_indices': list(range(bs)),
        'neg_ctx_indices': list(range(bs, len(num_bbs))),
        'targets': targets
    }
    return batch


def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
    ot_scatter = torch.arange(0, joint_len, dtype=torch.long
                              ).unsqueeze(0).repeat(len(txt_lens), 1)
    for i, tl in enumerate(txt_lens):
        max_ind = max_txt_len + (joint_len-tl)
        ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
                                               dtype=torch.long).data
    return ot_scatter


def _compute_pad(lens, max_len):
    pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
    for i, l in enumerate(lens):
        pad.data[i, l:].fill_(1)
    return pad


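# Sketch of what the OT helpers produce (inputs are hypothetical): for
# txt_lens [2, 3], max_txt_len 3 and joint_len 7, _compute_ot_scatter maps
# each token position into the joint text+image sequence, and _compute_pad
# marks padded positions with True.
#
#   >>> _compute_ot_scatter([2, 3], 3, 7)
#   tensor([[0, 1, 3, 4, 5, 6, 7],
#           [0, 1, 2, 3, 4, 5, 6]])
#   >>> _compute_pad([2, 3], 3)
#   tensor([[False, False,  True],
#           [False, False, False]])
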
def itm_ot_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    # OT inputs
    max_tl = max(txt_lens)
    max_nbb = max(num_bbs)
    ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
    txt_pad = _compute_pad(txt_lens, max_tl)
    img_pad = _compute_pad(num_bbs, max_nbb)
    ot_inputs = {'ot_scatter': ot_scatter,
                 'scatter_max': ot_scatter.max().item(),
                 'txt_pad': txt_pad,
                 'img_pad': img_pad}

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets,
             'ot_inputs': ot_inputs}
    return batch


class ItmRankDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDataset needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        # images partitioned by rank
        self.img2txts = defaultdict(list)
        for id_, img in self.txt2img.items():
            self.img2txts[img].append(id_)
        self.img_name_list = list(self.img2txts.keys())

        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample negatives
        neg_sample_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        neg_sample_txt_ids = sample_negative(
            self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
        id_pairs.extend([(gt_txt_id, neg_img_id)
                         for neg_img_id in neg_sample_img_ids] +
                        [(neg_txt_id, gt_img_fname)
                         for neg_txt_id in neg_sample_txt_ids])
        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1 + 2*self.neg_sample_size)
        return inputs

    def _collect_inputs(self, id_pairs):
        # create input features
        inputs = []
        for txt_id, img_id in id_pairs:
            example = self.txt_db[txt_id]
            # text input
            input_ids = example['input_ids']
            input_ids = self.txt_db.combine_inputs(input_ids)
            # img input
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
            # mask
            attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
            attn_masks_img = torch.ones(num_bb, dtype=torch.long)

            inputs.append((input_ids, img_feat, img_pos_feat, attn_masks_text, attn_masks_img))

        return inputs


class ItmRankDatasetHardNeg(ItmRankDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
        assert hard_neg_size > 0, \
            "ItmRankDatasetHardNeg needs at least 1 hard negative sample"
        DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size
        self.hard_neg_size = hard_neg_size

    def reload_hard_negs(self, hard_neg_dir):
        self.txt2hardimgs = json.load(
            open(f'{hard_neg_dir}/'
                 f'txt2hardimgs_rank{hvd.rank()}.json'))
        self.img2hardtxts = json.load(
            open(f'{hard_neg_dir}/img2hardtxts.json'))

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample hard negatives
        if self.hard_neg_size > 0:
            hard_neg_img_samples = random.sample(
                self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
            hard_neg_txt_samples = random.sample(
                self.img2hardtxts[gt_img_fname], self.hard_neg_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in hard_neg_img_samples] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in hard_neg_txt_samples])
        # sample normal negatives
        if self.neg_sample_size > 0:
            neg_sample_img_ids = sample_negative(
                self.img_name_list, [gt_img_fname], self.neg_sample_size)
            neg_sample_txt_ids = sample_negative(
                self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in neg_sample_img_ids] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in neg_sample_txt_ids])

        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1
                               + 2*self.neg_sample_size
                               + 2*self.hard_neg_size)
        return inputs


def itm_rank_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
     ) = map(list, unzip(concat(i for i in inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    bs, max_tl = input_ids.size()
    # out_size = attn_masks.size(1)
    gather_index = None  # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks_text': attn_masks_text,
             'attn_masks_img': attn_masks_img,
             'gather_index': gather_index,
             'sample_size': sample_size}
    return batch


class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromText needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        input_ids = self.txt_db[gt_txt_id]['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        neg_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        img_ids = [gt_img_fname] + neg_img_ids
        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
        for j, nbb in enumerate(num_bbs):
            attn_masks.data[j, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
                                        len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromImage needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.txt_name_list = list(self.txt2img.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        gt_txt_ids = self.img2txts[gt_img_id]

        # process image features (gt always first)
        img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        # sample negative
        neg_txt_ids = sample_negative(
            self.txt_name_list, gt_txt_ids, self.neg_sample_size)
        txt_ids = [gt_txt_id] + neg_txt_ids

        # process text inputs
        all_inputs = []
        txt_lens = []
        for txt_id in txt_ids:
            input_ids = self.txt_db.combine_inputs(
                self.txt_db[txt_id]['input_ids'])
            all_inputs.append(input_ids)
            txt_lens.append(len(input_ids))
        input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
        for j, tl in enumerate(txt_lens):
            attn_masks.data[j, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        # texts are padded to max(txt_lens), so the gather uses that length
        gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
                                        len(txt_ids), max(txt_lens), out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


def itm_rank_hnv2_collate(inputs):
    assert len(inputs) == 1
    return inputs[0]


class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task """
    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        neg_st = i+1
        neg_end = neg_st+self.bs-1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        example = super().__getitem__(i)

        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        # attn_masks_text = torch.ones(1, tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for j, nbb in enumerate(num_bbs):
            attn_masks_img.data[j, :nbb].fill_(1)

        # out_size = attn_masks.size(1)
        gather_index = None  # get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch


def itm_val_collate(inputs):
    assert len(inputs) == 1, "input batch size > 1"
    return inputs[0]


class ItmHardNegDataset(ItmValDataset):
    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        all_img_ids = copy.deepcopy(self.all_img_ids)
        all_img_ids.remove(gt_img_id)
        random.shuffle(all_img_ids)
        neg_img_ids = all_img_ids[:self.bs]

        assert len(neg_img_ids) == (self.bs),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        _, neg_img_ids = self._get_batch_ids(i)
        batch = self.get_batch(i, neg_img_ids)
        batch['gt_txt_id'] = self.ids[i]
        batch['neg_img_ids'] = neg_img_ids
        return batch


itm_hn_collate = itm_val_collate


class ItmEvalDataset(ItmValDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
                                  key=lambda i: self.img_db.name2nbb[i])

    def __getitem__(self, i):
        mini_batches = []
        for st in range(0, len(self.all_img_ids), self.bs):
            mini_batches.append(
                self.get_batch(i, self.all_img_ids[st:st+self.bs]))
        return mini_batches


itm_eval_collate = itm_val_collate
@ -0,0 +1,390 @@
""" |
||||
|
MLM datasets |
||||
|
""" |
||||
|
import math |
||||
|
import random |
||||
|
|
||||
|
import torch |
||||
|
from torch.utils.data import Dataset |
||||
|
from torch.nn.utils.rnn import pad_sequence |
||||
|
from toolz.sandbox import unzip |
||||
|
|
||||
|
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, get_ids_and_lens, pad_tensors, |
||||
|
get_gather_index, get_gather_index_uniter) |
||||
|
|
||||
|
|
||||
|
def random_word(tokens, vocab_range, mask):
    """
    Masking some random tokens for Language Model task with probabilities as in
    the original BERT paper.
    :param tokens: list of int, tokenized sentence.
    :param vocab_range: for choosing a random word
    :param mask: int, id of the [MASK] token
    :return: (list of int, list of int), masked tokens and related labels for
        LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = mask

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(range(*vocab_range)))

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            output_label.append(token)
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)
    if all(o == -1 for o in output_label):
        # at least mask 1
        output_label[0] = tokens[0]
        tokens[0] = mask

    return tokens, output_label


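# Illustrative call (token ids and ranges are hypothetical): roughly 15% of
# tokens come back as labels, the rest are marked -1 and ignored by the loss;
# the fallback above guarantees at least one masked position.
#
#   >>> toks, labels = random_word([7, 8, 9, 10], vocab_range=(999, 2000), mask=103)
#   >>> len(toks) == len(labels) == 4
#   True
#   >>> any(l != -1 for l in labels)
#   True
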
class MlmDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
        - img_feat     : (num_bb, d)
        - img_pos_feat : (num_bb, 7)
        - attn_masks   : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
        - txt_labels   : (L, ), [-1, -1, wid, -1, -1, -1]
        0's padded so that (L + num_bb) % 8 == 0
        """
        example = super().__getitem__(i)

        # text input
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])

        # img input
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'])

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
        attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, txt_labels, attn_masks_teacher

    def create_mlm_io(self, input_ids):
        input_ids, txt_labels = random_word(input_ids,
                                            self.txt_db.v_range,
                                            self.txt_db.mask)
        input_ids = torch.tensor([self.txt_db.cls_]
                                 + input_ids
                                 + [self.txt_db.sep])
        txt_labels = torch.tensor([-1] + txt_labels + [-1])
        return input_ids, txt_labels


def mlm_collate(inputs):
    """
    Return:
    :input_ids    (n, max_L) padded with 0
    :position_ids (n, max_L) padded with 0
    :txt_lens     list of [txt_len]
    :img_feat     (n, max_num_bb, feat_dim)
    :img_pos_feat (n, max_num_bb, 7)
    :num_bbs      list of [num_bb]
    :attn_masks   (n, max_{L + num_bb}) padded with 0
    :txt_labels   (n, max_L) padded with -1
    """
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, txt_labels, attn_masks_teacher
     ) = map(list, unzip(inputs))

    # text batches
    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
    gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': None,
            'gather_index': gather_index
        },
        'txt_labels': txt_labels,
        'teacher': {
            'txt_lens': txt_lens,
            'num_bbs': num_bbs,
            'bs': bs,
            'max_tl': max_tl,
            'out_size': out_size,
            'gather_index': gather_index_teacher,
            'attn_masks': attn_masks_teacher
        }
    }
    return batch


class BlindMlmDataset(Dataset):
    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        example = self.txt_db[id_]
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        return input_ids, attn_masks, txt_labels

    def create_mlm_io(self, input_ids):
        # same masking i/o as MlmDataset.create_mlm_io (text-only variant)
        input_ids, txt_labels = random_word(input_ids,
                                            self.txt_db.v_range,
                                            self.txt_db.mask)
        input_ids = torch.tensor([self.txt_db.cls_]
                                 + input_ids
                                 + [self.txt_db.sep])
        txt_labels = torch.tensor([-1] + txt_labels + [-1])
        return input_ids, txt_labels


def mlm_blind_collate(inputs):
    input_ids, attn_masks, txt_labels = map(list, unzip(inputs))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'txt_labels': txt_labels}
    return batch


def eval_mask(len_, num_samples=7):
    """ build the mask for evaluating MLM
    circularly mask 1 word out of every x words
    """
    # build the random masks
    if len_ <= num_samples:
        masks = torch.eye(len_).bool()
        num_samples = len_
    else:
        mask_inds = [list(range(i, len_, num_samples))
                     for i in range(num_samples)]
        masks = torch.zeros(num_samples, len_).bool()
        for i, indices in enumerate(mask_inds):
            for j in indices:
                masks.data[i, j] = 1
    assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
    assert masks.sum().item() == len_
    return masks


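# Illustrative output of eval_mask (a sketch): for a 5-token sentence with
# num_samples=3, each token is masked in exactly one of the 3 rows.
#
#   >>> eval_mask(5, num_samples=3)
#   tensor([[ True, False, False,  True, False],
#           [False,  True, False, False,  True],
#           [False, False,  True, False, False]])
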
def eval_gather_inds(len_, num_samples=7):
    """ get the gather indices """
    inds = torch.arange(0, num_samples, dtype=torch.long)
    mul = math.ceil(len_ / num_samples)
    output = inds.repeat(mul)[:len_]
    return output


def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
    """N x [B_i, T, ...]"""
    if ns is None:
        ns = [t.size(0) for t in tensors]
    if lens is None:
        lens = [t.size(1) for t in tensors]
    max_len = max(lens)
    bs = sum(ns)
    hid_dims = tensors[0].size()[2:]
    dtype = tensors[0].dtype
    output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
    if pad:
        output.data.fill_(pad)
    i = 0
    for t, l, n in zip(tensors, lens, ns):
        output.data[i:i+n, :l, ...] = t.data
        i += n
    return output


def expand_tensors(tensors, ns):
    return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim()))
            for t, n in zip(tensors, ns)]


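# Illustrative shape behaviour of stack_pad_tensors (shapes hypothetical):
# a list of (B_i, T_i, d) tensors is stacked into one (sum(B_i), max(T_i), d)
# tensor, padding the time dimension.
#
#   >>> a, b = torch.zeros(2, 3, 4), torch.ones(1, 5, 4)
#   >>> stack_pad_tensors([a, b]).shape
#   torch.Size([3, 5, 4])
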
class MlmEvalDataset(DetectFeatTxtTokDataset):
    """ For evaluating MLM training task """
    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        example = super().__getitem__(i)

        # text input
        (input_ids, txt_labels, gather_inds
         ) = self.create_mlm_eval_io(example['input_ids'])

        # img input
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat, attn_masks,
                txt_labels, gather_inds)

    def create_mlm_eval_io(self, input_ids):
        txt_labels = torch.tensor(input_ids)
        masks = eval_mask(len(input_ids))
        n_mask = masks.size(0)
        masks = torch.cat([torch.zeros(n_mask, 1).bool(),
                           masks,
                           torch.zeros(n_mask, 1).bool()],
                          dim=1)
        input_ids = torch.tensor([[self.txt_db.cls_]
                                  + input_ids
                                  + [self.txt_db.sep]
                                  for _ in range(n_mask)])
        input_ids.data.masked_fill_(masks, self.txt_db.mask)
        gather_inds = eval_gather_inds(len(txt_labels))
        return input_ids, txt_labels, gather_inds


def _batch_gather_tgt(gather_inds, n_masks):
    gather_tgts = []
    offset = 0
    for g, n in zip(gather_inds, n_masks):
        gather_tgts.append(g + offset)
        offset += n
    gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
    return gather_tgt


def mlm_eval_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
     ) = map(list, unzip(inputs))

    # sizes
    n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
                                 num_bbs, n_masks)
    img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
                                     num_bbs, n_masks)

    bs, max_tl = input_ids.size()
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    out_size = attn_masks.size(1)
    # repeat txt_lens, num_bbs
    txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
    num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch


class BlindMlmEvalDataset(Dataset):
    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        example = self.txt_db[id_]

        # text input
        input_ids = example['input_ids']
        (input_ids, txt_labels, gather_inds
         ) = self.txt_db.create_mlm_eval_io(input_ids)

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        return input_ids, attn_masks, txt_labels, gather_inds


def mlm_blind_eval_collate(inputs):
    (input_ids, attn_masks, txt_labels, gather_inds
     ) = map(list, unzip(inputs))

    # sizes
    n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch
@ -0,0 +1,263 @@
""" |
||||
|
MRM Datasets |
||||
|
""" |
||||
|
import random |
||||
|
|
||||
|
import torch |
||||
|
from torch.utils.data import Dataset |
||||
|
from torch.nn.utils.rnn import pad_sequence |
||||
|
from toolz.sandbox import unzip |
||||
|
from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index, get_gather_index_uniter |
||||
|
|
||||
|
|
||||
|
def _get_img_mask(mask_prob, num_bb): |
||||
|
img_mask = [random.random() < mask_prob for _ in range(num_bb)] |
||||
|
if not any(img_mask): |
||||
|
# at least mask 1 |
||||
|
img_mask[random.choice(range(num_bb))] = True |
||||
|
img_mask = torch.tensor(img_mask) |
||||
|
return img_mask |
||||
|
|
||||
|
|
||||
|
def _get_img_tgt_mask(img_mask, txt_len): |
||||
|
z = torch.zeros(txt_len, dtype=torch.bool) |
||||
|
img_mask_tgt = torch.cat([z, img_mask], dim=0) |
||||
|
return img_mask_tgt |
||||
|
|
||||
|
|
||||
|
def _get_feat_target(img_feat, img_masks): |
||||
|
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) |
||||
|
feat_dim = img_feat.size(-1) |
||||
|
feat_targets = img_feat[img_masks_ext].contiguous().view( |
||||
|
-1, feat_dim) # (s, d) |
||||
|
return feat_targets |
||||
|
|
||||
|
|
||||
|
def _mask_img_feat(img_feat, img_masks): |
||||
|
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) |
||||
|
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) |
||||
|
return img_feat_masked |
||||
|
|
||||
|
|
||||
|
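# Illustrative use of the masking helpers (shapes hypothetical): mask ~30% of
# 4 regions, then zero out the masked rows of a (1, 4, d) feature batch.
#
#   >>> m = _get_img_mask(0.3, 4)           # e.g. tensor([False, True, False, False])
#   >>> feats = torch.randn(1, 4, 2048)
#   >>> masked = _mask_img_feat(feats, m.unsqueeze(0))
#   >>> bool((masked[0, m] == 0).all())
#   True
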
class MrfrDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
        - img_feat     : (num_bb, d)
        - img_pos_feat : (num_bb, 7)
        - attn_masks   : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
        - img_mask     : (num_bb, ) between {0, 1}
        """
        example = super().__getitem__(i)
        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        # image input features
        img_input_ids = torch.Tensor([101]).long()
        img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'])
        img_mask = _get_img_mask(self.mask_prob, num_bb)
        img_mask_tgt = _get_img_tgt_mask(img_mask, 1)
        img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
        attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img,
                img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher)


def mrfr_collate(inputs): |
||||
|
""" |
||||
|
Return: |
||||
|
- input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded |
||||
|
- position_ids : (n, max_L) |
||||
|
- txt_lens : list of [input_len] |
||||
|
- img_feat : (n, max_num_bb, d) |
||||
|
- img_pos_feat : (n, max_num_bb, 7) |
||||
|
- num_bbs : list of [num_bb] |
||||
|
- attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1] |
||||
|
- img_masks : (n, max_num_bb) between {0, 1} |
||||
|
""" |
||||
|
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_masks, img_mask_tgts, |
||||
|
attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs)) |
||||
|
|
||||
|
txt_lens = [i.size(0) for i in input_ids] |
||||
|
|
||||
|
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) |
||||
|
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long |
||||
|
).unsqueeze(0) |
||||
|
|
||||
|
|
||||
|
num_bbs = [f.size(0) for f in img_feats] |
||||
|
img_feat = pad_tensors(img_feats, num_bbs) |
||||
|
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) |
||||
|
img_pos_feat = pad_tensors(img_pos_feats, num_bbs) |
||||
|
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) |
||||
|
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) |
||||
|
feat_targets = _get_feat_target(img_feat, img_masks) |
||||
|
img_feat = _mask_img_feat(img_feat, img_masks) |
||||
|
img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0) |
||||
|
img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0) |
||||
|
|
||||
|
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) |
||||
|
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) |
||||
|
bs, max_tl = input_ids.size() |
||||
|
out_size = attn_masks_img.size(1) |
||||
|
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) |
||||
|
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) |
||||
|
|
||||
|
attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0) |
||||
|
gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1)) |
||||
|
|
||||
|
batch = { |
||||
|
'txts': { |
||||
|
'input_ids': input_ids, |
||||
|
'position_ids': position_ids, |
||||
|
'attention_mask': attn_masks, |
||||
|
'img_feat': None, |
||||
|
'img_pos_feat': None, |
||||
|
'img_masks': None, |
||||
|
'gather_index': None |
||||
|
}, |
||||
|
'imgs': { |
||||
|
'input_ids': img_input_ids, |
||||
|
'position_ids': img_position_ids, |
||||
|
'attention_mask': attn_masks_img, |
||||
|
'img_feat': img_feat, |
||||
|
'img_pos_feat': img_pos_feat, |
||||
|
'img_masks': img_masks, |
||||
|
'gather_index': gather_index |
||||
|
}, |
||||
|
'teacher': { |
||||
|
'txt_lens': txt_lens, |
||||
|
'num_bbs': num_bbs, |
||||
|
'bs': bs, |
||||
|
'max_tl': max_tl, |
||||
|
'out_size': out_size, |
||||
|
'gather_index': gather_index_teacher, |
||||
|
'attn_masks': attn_masks_teacher, |
||||
|
'img_mask_tgt': img_mask_tgt_teacher, |
||||
|
}, |
||||
|
'feat_targets': feat_targets, |
||||
|
'img_mask_tgt': img_mask_tgt} |
||||
|
return batch |
||||
|
|
||||
|
|
||||
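
# Example (sketch): what `get_gather_index` produces for the student inputs.
# Assuming the UNITER semantics (gather the first txt_len text slots, then the
# num_bb image slots, skipping the text padding in between); a simplified
# re-implementation for intuition only, not the library function:
def _gather_index_sketch(txt_lens, num_bbs, bs, max_tl, out_size):
    gather_index = torch.arange(0, out_size, dtype=torch.long).unsqueeze(0).repeat(bs, 1)
    for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
        # image slots start right after this sample's real text tokens
        gather_index[i, tl:tl + nbb] = torch.arange(max_tl, max_tl + nbb, dtype=torch.long)
    return gather_index
# e.g. _gather_index_sketch([2, 3], [2, 2], 2, 3, 5)[0] -> tensor([0, 1, 3, 4, 4])
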

def _get_targets(img_masks, img_soft_label):
    soft_label_dim = img_soft_label.size(-1)
    img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
    label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
        -1, soft_label_dim)
    return label_targets
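
# Example (sketch): the boolean mask, broadcast over the label dimension,
# keeps exactly the masked regions' soft-label rows (illustrative shapes):
#   soft = torch.arange(24.).view(2, 4, 3)              # (bs, num_bb, n_cls)
#   masks = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 0]], dtype=torch.bool)
#   _get_targets(masks, soft).shape  -> torch.Size([3, 3])  # one row per masked bb
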

class MrcDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def _get_img_feat(self, fname):
        img_dump = self.img_db.get_dump(fname)
        num_bb = self.img_db.name2nbb[fname]
        img_feat = torch.tensor(img_dump['features'])
        bb = torch.tensor(img_dump['norm_bb'])
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        img_soft_label = torch.tensor(img_dump['soft_labels'])
        return img_feat, img_bb, img_soft_label, num_bb

    def __getitem__(self, i):
        example = super().__getitem__(i)
        img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
            example['img_fname'])

        # image input features
        img_input_ids = torch.Tensor([101]).long()
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        img_mask_tgt = _get_img_tgt_mask(img_mask, 1)
        img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids), dtype=torch.long)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
        attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img,
                img_soft_labels, img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher)


def mrc_collate(inputs):
    (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_soft_labels,
     img_masks, img_mask_tgts, attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    img_feat = pad_tensors(img_feats, num_bbs)
    img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
    img_soft_label = pad_tensors(img_soft_labels, num_bbs)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    label_targets = _get_targets(img_masks, img_soft_label)

    img_feat = _mask_img_feat(img_feat, img_masks)
    img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0)
    img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # student image stream: only the single img [CLS] token plays the text role,
    # hence txt_lens of 1 (the fused-input variant is kept commented for reference)
    # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
    gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))

    batch = {
        'txts': {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attn_masks,
            'img_feat': None,
            'img_pos_feat': None,
            'img_masks': None,
            'gather_index': None
        },
        'imgs': {
            'input_ids': img_input_ids,
            'position_ids': img_position_ids,
            'attention_mask': attn_masks_img,
            'img_feat': img_feat,
            'img_pos_feat': img_pos_feat,
            'img_masks': img_masks,
            'gather_index': gather_index
        },
        'teacher': {
            'txt_lens': txt_lens,
            'num_bbs': num_bbs,
            'bs': bs,
            'max_tl': max_tl,
            'out_size': out_size,
            'gather_index': gather_index_teacher,
            'attn_masks': attn_masks_teacher,
            'img_mask_tgt': img_mask_tgt_teacher,
        },
        'img_mask_tgt': img_mask_tgt,
        'label_targets': label_targets}
    return batch
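
# Example (sketch): `_get_img_feat` above appends the box area to the 6-dim
# normalized box, assuming the norm_bb layout [x1, y1, x2, y2, w, h] that the
# bb[:, 4:5] * bb[:, 5:] slice implies:
#   bb = torch.tensor([[0.1, 0.2, 0.5, 0.6, 0.4, 0.4]])
#   torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1)
#   -> tensor([[0.1000, 0.2000, 0.5000, 0.6000, 0.4000, 0.4000, 0.1600]])
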
@ -0,0 +1,145 @@
"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip

from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index


def _get_vqa_target(example, num_answers):
    target = torch.zeros(num_answers)
    labels = example['target']['labels']
    scores = example['target']['scores']
    if labels and scores:
        target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
    return target
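
# Example (sketch): scattering hypothetical annotator scores into the dense
# VQA target, e.g. answer ids 3 and 7 with soft scores 0.9 and 0.3:
#   example = {'target': {'labels': [3, 7], 'scores': [0.9, 0.3]}}
#   _get_vqa_target(example, num_answers=10)
#   -> tensor([0., 0., 0., 0.9, 0., 0., 0., 0.3, 0., 0.])
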

class VqaDataset(DetectFeatTxtTokDataset):
    """ NOTE: This handles distributed inside """
    def __init__(self, num_answers, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_answers = num_answers

    def __getitem__(self, i):
        example = super().__getitem__(i)
        qid = self.ids[i]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        img_input_ids = torch.Tensor([101]).long()

        target = _get_vqa_target(example, self.num_answers)

        attn_masks_txt = torch.ones(len(input_ids), dtype=torch.long)
        attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)

        return qid, input_ids, attn_masks_txt, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target

def vqa_collate(inputs):
    (qids, input_ids, attn_masks_txt, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)

    attn_masks_txt = pad_sequence(attn_masks_txt, batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
    targets = torch.stack(targets, dim=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
    img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks_img.size(1)
    # teacher gathers the fused [txt; img] sequence; the student image stream
    # only carries its single img [CLS] token as text, hence txt_lens of 1
    gather_index_teacher = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
    gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)

    batch = {'qids': qids,
             'txts': {
                 'input_ids': input_ids,
                 'position_ids': position_ids,
                 'attention_mask': attn_masks_txt,
                 'img_feat': None,
                 'img_pos_feat': None,
                 'img_masks': None,
                 'gather_index': None
             },
             'imgs': {
                 'input_ids': img_input_ids,
                 'position_ids': img_position_ids,
                 'attention_mask': attn_masks_img,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'img_masks': None,
                 'gather_index': gather_index
             },
             'gather_index_teacher': gather_index_teacher,
             'targets': targets}
    return batch
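
# Example (sketch): wiring the dataset and collate function together; the
# txt_db/img_db handles come from the repo's db loaders and are elided here:
#   from torch.utils.data import DataLoader
#   dataset = VqaDataset(num_answers, txt_db, img_db)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4,
#                       pin_memory=True, collate_fn=vqa_collate)
#   batch = next(iter(loader))   # dict with 'qids', 'txts', 'imgs', 'targets', ...
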

class VqaEvalDataset(VqaDataset):
    def __getitem__(self, i):
        qid = self.ids[i]
        example = DetectFeatTxtTokDataset.__getitem__(self, i)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        if 'target' in example:
            target = _get_vqa_target(example, self.num_answers)
        else:
            target = None

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return qid, input_ids, img_feat, img_pos_feat, attn_masks, target


def vqa_eval_collate(inputs):
    (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    if targets[0] is None:
        targets = None
    else:
        targets = torch.stack(targets, dim=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'qids': qids,
             'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets}
    return batch
@ -0,0 +1,66 @@
import random
import logging
import collections
import json
import os
import itertools
import numpy as np
from collections import ChainMap

from dvl.trainer import build_dataloader, _save_checkpoint, eval_model_on_dataloader, load_dataset

logger = logging.getLogger()


def random_hard_neg(fname2id, num_hard_negatives, id2set, set2id):
    # rejection sampling: num_hard_negatives must stay very small relative to
    # the candidate pool, otherwise the resample loop rarely terminates quickly
    hard_negs = dict()
    for i in fname2id:
        while True:
            hard_neg = random.choices(set2id[id2set[i]], k=num_hard_negatives)
            if fname2id[i] not in hard_neg:
                break
        hard_negs[i] = hard_neg
    return hard_negs
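
# Example (sketch): rejection sampling with toy ids; each key keeps resampling
# until its own positive id is absent from the draw (ids are hypothetical):
#   fname2id = {'img_a': 'img_a', 'img_b': 'img_b'}
#   id2set = {'img_a': 'coco', 'img_b': 'coco'}
#   set2id = {'coco': ['img_a', 'img_b', 'img_c', 'img_d']}
#   random_hard_neg(fname2id, 1, id2set, set2id)
#   -> e.g. {'img_a': ['img_c'], 'img_b': ['img_d']}
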

def get_img_txt_mappings(train_txt_dbs):
    train_img2txt = dict(ChainMap(*[json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs]))
    train_txt2img = dict(itertools.chain(*[[(v, k) for v in vals] for k, vals in train_img2txt.items()]))

    # each image/text is assigned to the db folder ("set") it came from
    train_json = [json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs]
    train_img2set = dict(ChainMap(*[{k: v for k in tj} for tj, v in zip(train_json, train_txt_dbs)]))
    train_txt2set = {txt_id: train_img2set[img_id] for txt_id, img_id in train_txt2img.items()}

    train_set2img, train_set2txt = collections.defaultdict(list), collections.defaultdict(list)
    for img_id, set_id in train_img2set.items():
        train_set2img[set_id].append(img_id)
        train_set2txt[set_id] += train_img2txt[img_id]

    return train_img2txt, train_txt2img, train_img2set, train_txt2set, train_set2img, train_set2txt
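
# Example (sketch): for one db folder whose img2txts.json is
# {"img_1": ["txt_1", "txt_2"]}, the returned mappings look like:
#   img2txt -> {'img_1': ['txt_1', 'txt_2']}
#   txt2img -> {'txt_1': 'img_1', 'txt_2': 'img_1'}
#   img2set -> {'img_1': './db0'}        # the "set" id is the source db folder
#   set2img -> {'./db0': ['img_1']},  set2txt -> {'./db0': ['txt_1', 'txt_2']}
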

def sampled_hard_negatives(all_img_dbs, args, collate_func, bi_encoder, train_img2txt, train_txt2img):
    train_dataset_eval = load_dataset(all_img_dbs, args.train_txt_dbs, args.train_img_dbs, args, True)
    hard_negs_txt_all, hard_negs_img_all = [], []
    for dset in train_dataset_eval.datasets:
        dset.new_epoch()
        train_dataloader_hn = build_dataloader(dset, collate_func, True, args, args.valid_batch_size)
        logger.info(f'eval for train dataloader len (for hn) = {len(train_dataloader_hn)}')

        num_hard_sampled = min(max(args.num_hard_negatives * 2 + 10, 50), 1000)
        loss_hard, correct_ratio_hard, indexer_hard, recall_hard, (hard_neg_img, hard_neg_txt) = \
            eval_model_on_dataloader(bi_encoder, train_dataloader_hn, args, train_img2txt, num_hard_sampled)

        # remove self from the hard negatives, as those are the labels
        for k, v in hard_neg_img.items():
            if train_txt2img[k] in v:
                v.remove(train_txt2img[k])
        hard_neg_txt = {k: list(set(v) - set(train_img2txt[k])) for k, v in hard_neg_txt.items()}

        hard_negs_txt_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_txt.items()})
        hard_negs_img_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_img.items()})
    hard_negs_txt_all = dict(collections.ChainMap(*hard_negs_txt_all))
    hard_negs_img_all = dict(collections.ChainMap(*hard_negs_img_all))
    return hard_negs_txt_all, hard_negs_img_all
@ -0,0 +1,154 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
FAISS-based index components for the dense retriever
"""

import logging
import pickle
from typing import List, Tuple

import faiss
import numpy as np

logger = logging.getLogger()

class DenseIndexer(object):

    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def index_data(self, data: List[Tuple[object, np.array]]):
        raise NotImplementedError

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError

    def serialize(self, file: str):
        logger.info('Serializing index to %s', file)

        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, file: str):
        logger.info('Loading index from %s', file)

        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'

        self.index = faiss.read_index(index_file)
        logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

class DenseFlatIndexer(DenseIndexer):

    def __init__(self, vector_sz: int, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)
        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
            vectors = np.concatenate(vectors, axis=0)
            self._update_id_mapping(db_ids)
            self.index.add(vectors)

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        scores, indexes = self.index.search(query_vectors, top_docs)
        # convert to external ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result
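
# Example (sketch): a minimal round trip through the flat index, with random
# vectors standing in for encoder outputs:
#   indexer = DenseFlatIndexer(vector_sz=768)
#   data = [(f'doc_{i}', np.random.rand(768).astype('float32')) for i in range(1000)]
#   indexer.index_data(data)
#   queries = np.random.rand(4, 768).astype('float32')
#   results = indexer.search_knn(queries, top_docs=5)   # [(ids, scores), ...] per query
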

class DenseHNSWFlatIndexer(DenseIndexer):
    """
    Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage
    """

    def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512,
                 ef_search: int = 128, ef_construction: int = 200):
        super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)

        # IndexHNSWFlat supports L2 similarity only,
        # so we have to apply a DOT -> L2 similarity space conversion with the help of an extra dimension
        index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
        self.index = index
        self.phi = 0

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)

        # max norm is required before putting all vectors in the index to convert inner product similarity to L2
        if self.phi > 0:
            raise RuntimeError('DPR HNSWF index needs to index all data at once, '
                               'results will be unpredictable otherwise.')
        phi = 0
        for i, item in enumerate(data):
            db_id, doc_vector = item
            norms = (doc_vector ** 2).sum()
            phi = max(phi, norms)
        logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
        # keep the max norm so a second index_data call is rejected above
        self.phi = phi

        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]

            norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
            aux_dims = [np.sqrt(phi - norm) for norm in norms]
            hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
                            enumerate(vectors)]
            hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)

            self._update_id_mapping(db_ids)
            self.index.add(hnsw_vectors)
            logger.info('data indexed %d', len(self.index_id_to_db_id))

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:

        aux_dim = np.zeros(len(query_vectors), dtype='float32')
        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
        # logger.info('query_hnsw_vectors %s', query_hnsw_vectors.shape)
        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
        # convert to external ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

    def deserialize_from(self, file: str):
        super(DenseHNSWFlatIndexer, self).deserialize_from(file)
        # to trigger warning on subsequent indexing
        self.phi = 1
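
# Example (sketch): why the extra dimension preserves the ranking. With
# phi = max ||d||^2, storing [d, sqrt(phi - ||d||^2)] and querying [q, 0] gives
# ||q' - d'||^2 = ||q||^2 + phi - 2*(q . d), i.e. monotone in the inner product:
#   docs = np.random.rand(5, 8).astype('float32'); q = np.random.rand(8).astype('float32')
#   phi = ((docs ** 2).sum(1)).max()
#   docs_aug = np.hstack([docs, np.sqrt(phi - (docs ** 2).sum(1, keepdims=True))])
#   l2 = ((np.hstack([q, [0.]]) - docs_aug) ** 2).sum(1)
#   assert np.allclose(l2, (q ** 2).sum() + phi - 2 * (docs @ q), atol=1e-4)
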
Binary file not shown.
Binary file not shown.
@ -0,0 +1,757 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Pipeline to train the DPR Biencoder
"""

import argparse
import glob
import logging
import math
import os
import random
import time
import json
import csv
import re

import torch
import numpy as np

from typing import Tuple
from collections import defaultdict
from torch import nn
from torch import Tensor as T
from torch.optim.lr_scheduler import LambdaLR

import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers import BertModel, BertConfig, BertPreTrainedModel
from transformers.optimization import AdamW
from uniter_model.model.model import UniterPreTrainedModel, UniterModel, UniterConfig
from dvl.const import IMG_DIM
# from dvl.indexer.faiss_indexers import DenseIndexer
from uniter_model.model.layer import GELU, BertOnlyMLMHead, BertPooler
from uniter_model.model.model import RegionClassification, RegionFeatureRegression, pad_tensor_to_mul

from typing import List


logger = logging.getLogger()
logger.setLevel(logging.INFO)
if logger.hasHandlers():
    logger.handlers.clear()
console = logging.StreamHandler()
logger.addHandler(console)

def dot_product_scores(q_vectors: T, ctx_vectors: T, cosine=False) -> T:
    """
    calculates q->ctx scores for every row in ctx_vectors
    :param q_vectors: n1 x D
    :param ctx_vectors: n2 x D
    :return: n1 x n2 score matrix
    """
    r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
    if cosine:
        n1 = torch.norm(q_vectors, dim=-1)
        n2 = torch.norm(ctx_vectors, dim=-1)
        n_out = torch.ger(n1, n2)
        return r / n_out
    return r


def cosine_scores(q_vector: T, ctx_vectors: T):
    # q_vector: n x D, ctx_vectors: n x D, result: n (row-wise similarity)
    return F.cosine_similarity(q_vector, ctx_vectors, dim=1)

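# Example (sketch): shape semantics of the two scorers; dot_product_scores
# returns the full pairwise matrix, cosine_scores is row-wise:
#   q, ctx = torch.randn(3, 768), torch.randn(5, 768)
#   dot_product_scores(q, ctx).shape                 # torch.Size([3, 5])
#   cosine_scores(torch.randn(5, 768), ctx).shape    # torch.Size([5])
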

class BertEncoder(BertPreTrainedModel):
    def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
        self.bert = BertModel(config)
        assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
        # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
        if project_dim > 0:
            self.encode_proj = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size * 2),
                GELU(),
                LayerNorm(config.hidden_size * 2, eps=1e-12),
                nn.Linear(config.hidden_size * 2, project_dim)
            )
        else:
            self.encode_proj = None
        self.init_weights()

    @classmethod
    def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\
            -> 'BertEncoder':

        cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout

        if checkpoint_path is not None and len(checkpoint_path) > 0:
            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
            return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, state_dict=state_dict, **kwargs)
        else:
            return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, **kwargs)

    def forward(self, input_ids, attention_mask, position_ids,
                img_feat=None, img_pos_feat=None, img_masks=None, gather_index=None):
        if self.config.output_hidden_states:
            sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
                                                                      token_type_ids=None,
                                                                      attention_mask=attention_mask,
                                                                      position_ids=position_ids)
        else:
            hidden_states = None
            sequence_output, pooled_output = self.bert(input_ids=input_ids,
                                                       token_type_ids=None,
                                                       attention_mask=attention_mask,
                                                       position_ids=position_ids)
        pooled_output = sequence_output[:, 0, :]
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states

    def get_out_size(self):
        if self.encode_proj:
            # encode_proj is an nn.Sequential; its last Linear carries the output dim
            return self.encode_proj[-1].out_features
        return self.config.hidden_size

class UniterEncoder(UniterPreTrainedModel):
    def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
        self.bert = UniterModel(config, IMG_DIM)
        assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
        # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None  # Yen-Chun
        if project_dim > 0:
            self.encode_proj = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size * 2),
                GELU(),
                LayerNorm(config.hidden_size * 2, eps=1e-12),
                nn.Linear(config.hidden_size * 2, project_dim)
            )
        else:
            self.encode_proj = None
        self.apply(self.init_weights)

    @classmethod
    def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\
            -> 'UniterEncoder':
        cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout
        if checkpoint_path is not None and len(checkpoint_path) > 0 and checkpoint_path.lower() != 'none':
            logger.info(f'load from {checkpoint_path} for uniter encoder')
            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
            # if 'model_dict' in state_dict:
            #     state_dict = state_dict['model_dict']
        else:
            logger.info('no checkpoint, random initialization for img encoder')
            state_dict = dict()
        return cls.from_pretrained(cfg_name, state_dict=state_dict, project_dim=project_dim, **kwargs)

    def forward(self, input_ids, attention_mask, position_ids,
                img_feat, img_pos_feat, img_masks, gather_index=None) -> Tuple[T, ...]:
        if self.config.output_hidden_states:
            sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
                                                                      position_ids=position_ids,
                                                                      attention_mask=attention_mask,
                                                                      img_feat=img_feat,
                                                                      img_pos_feat=img_pos_feat,
                                                                      img_masks=img_masks,
                                                                      img_type_ids=None,
                                                                      gather_index=gather_index,
                                                                      output_all_encoded_layers=True
                                                                      )
        else:
            hidden_states = None
            sequence_output = self.bert(input_ids=input_ids,
                                        position_ids=position_ids,
                                        attention_mask=attention_mask,
                                        img_feat=img_feat,
                                        img_pos_feat=img_pos_feat,
                                        img_masks=img_masks,
                                        img_type_ids=None,
                                        gather_index=gather_index,
                                        output_all_encoded_layers=False)
        # pooled_output = self.bert.pooler(sequence_output)
        pooled_output = sequence_output[:, 0, :]
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states

    def get_out_size(self):
        if self.encode_proj:
            # encode_proj is an nn.Sequential; its last Linear carries the output dim
            return self.encode_proj[-1].out_features
        return self.config.hidden_size

class BiEncoder(nn.Module):
    """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders.
    """

    def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False, project_dim: int = 0):
        super(BiEncoder, self).__init__()
        logger.info('*'*100)
        logger.info('loading img model')
        if args.img_model_type == 'uniter-base':
            self.img_model = UniterEncoder.init_encoder(args.img_model_config, checkpoint_path=args.img_checkpoint, project_dim=project_dim)
        else:
            raise ValueError(f'image encoder does not support other types ({args.img_model_type}) for now')

        logger.info('*' * 100)
        logger.info('loading txt model')
        if args.txt_model_type == 'bert-base':
            self.txt_model = BertEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
        elif args.txt_model_type == 'uniter-base':
            self.txt_model = UniterEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
        else:
            raise ValueError(f'txt encoder does not support other types ({args.txt_model_type}) for now')

        self.fix_img_encoder = fix_img_encoder
        self.fix_txt_encoder = fix_txt_encoder
        self.project_dim = project_dim
        if fix_txt_encoder:
            for param in self.txt_model.parameters():
                param.requires_grad = False
        if fix_img_encoder:
            for param in self.img_model.parameters():
                param.requires_grad = False

    @staticmethod
    def get_representation(sub_model, input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks,
                           gather_index=None, fix_encoder=False):
        if fix_encoder:
            with torch.no_grad():
                sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
                                                                          img_feat, img_pos_feat, img_masks,
                                                                          gather_index)
        else:
            sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
                                                                      img_feat, img_pos_feat, img_masks,
                                                                      gather_index)

        if sub_model.training:
            sequence_output.requires_grad_(requires_grad=True)
            pooled_output.requires_grad_(requires_grad=True)

        return sequence_output, pooled_output, hidden_states

    def forward(self, batch, output_all_encoded_layers=False):
        # batch keys: txts / imgs / caps
        batch = defaultdict(lambda: None, batch)

        if 'txts' in batch:
            sb = batch['txts']
            txt_seq, txt_pooled, txt_hidden = self.get_representation(self.txt_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_txt_encoder)
        else:
            txt_seq, txt_pooled = None, None

        if 'imgs' in batch:
            sb = batch['imgs']
            # the image stream is frozen by fix_img_encoder, not fix_txt_encoder
            img_seq, img_pooled, img_hidden = self.get_representation(self.img_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_img_encoder)
        else:
            img_seq, img_pooled = None, None

        if 'caps' in batch and batch['caps']['input_ids'] is not None:
            sb = batch['caps']
            cap_seq, cap_pooled, cap_hidden = self.get_representation(self.txt_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_txt_encoder)
        else:
            cap_seq, cap_pooled = None, None

        if output_all_encoded_layers:
            return txt_seq, img_seq, cap_seq
        else:
            return txt_pooled, img_pooled, cap_pooled

class BiEncoderForPretraining(nn.Module):
    """ MLM + MRM """
    def __init__(self, config_file, args, project_dim, img_dim, img_label_dim, nce_temp=1, ot_pos_only=False,
                 experiment=None):
        super().__init__()
        config = UniterConfig.from_json_file(config_file)
        self.bert = BiEncoder(args, project_dim=project_dim)
        self.cls = BertOnlyMLMHead(
            config, self.bert.img_model.bert.embeddings.word_embeddings.weight)  # ???
        self.feat_regress = RegionFeatureRegression(
            config.hidden_size, img_dim,
            self.bert.img_model.bert.img_embeddings.img_linear.weight)
        self.region_classifier = RegionClassification(
            config.hidden_size, img_label_dim)
        self.itm_output = nn.Linear(config.hidden_size, 2)
        self.cls_concat = args.cls_concat
        '''
        self.nce_output = BertPredictionHeadTransform(config)
        self.nce_output = nn.Sequential(BertPredictionHeadTransform(config),
                                        nn.Linear(config.hidden_size, img_dim))
        self.nce_norm = LayerNorm(config.hidden_size, eps=1e-12)
        self.nce_temp = nce_temp  # temperature
        '''
        self.ot_pos_only = ot_pos_only
        # self.apply(self.init_weights)
        self.vocab_pad = 0
        self.experiment = experiment

    def pad_vocab(self):
        # FIXME better padding after integrating huggingface ???
        emb_w = self.bert.embeddings.word_embeddings.weight.data
        padded_emb_w, n_pad = pad_tensor_to_mul(emb_w)
        padded_emb_w = nn.Parameter(padded_emb_w)
        self.bert.embeddings.word_embeddings.weight = padded_emb_w
        self.cls.predictions.decoder.weight = padded_emb_w
        self.vocab_pad = n_pad

    def forward(self, batch, task, compute_loss=True):
        batch = defaultdict(lambda: None, batch)
        if task == 'mlm':
            txt_labels = batch['txt_labels']
            return self.forward_mlm(batch, txt_labels, compute_loss)
        elif task == 'mrfr':
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            mrfr_feat_target = batch['feat_targets']
            return self.forward_mrfr(batch, img_masks, img_mask_tgt, mrfr_feat_target, compute_loss)
        elif task == 'mrm-nce':
            raise NotImplementedError('nce does not work')
            # unreachable, kept for reference
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            img_masks_in = batch['img_masks_in']
            feat_target = batch['feat_targets']
            neg_feats = batch['neg_feats']
            return self.forward_mrm_nce(batch,
                                        img_masks_in, img_masks, img_mask_tgt,
                                        feat_target, neg_feats, compute_loss)
        elif task == 'itm':
            targets = batch['targets']
            ot_inputs = batch['ot_inputs']
            return self.forward_itm(batch,
                                    targets, ot_inputs, compute_loss)
        elif task.startswith('mrc'):
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            mrc_label_target = batch['label_targets']
            return self.forward_mrc(batch,
                                    img_masks, img_mask_tgt,
                                    mrc_label_target, task, compute_loss)
        else:
            raise ValueError('invalid task')

    # MLM
    def forward_mlm(self, batch, txt_labels, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        # fuse the image [CLS] into every text position, then keep only the text part
        img_cls = img_seq[:, 0:1, :].repeat(1, txt_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = txt_seq + img_cls
        elif self.cls_concat == 'multiply':
            sequence_output = txt_seq * img_cls
        elif len(self.cls_concat) == 0:
            sequence_output = txt_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')
        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    txt_labels != -1)
        prediction_scores = self._pad_layer_unpad(masked_output, self.cls)
        if self.vocab_pad:
            prediction_scores = prediction_scores[:, :-self.vocab_pad]

        masked_lm_loss = F.cross_entropy(prediction_scores,
                                         txt_labels[txt_labels != -1],
                                         reduction='none')
        return masked_lm_loss, prediction_scores

    def _compute_masked_hidden(self, hidden, mask):
        """ get only the masked region (don't compute unnecessary hiddens) """
        mask = mask.unsqueeze(-1).expand_as(hidden)
        hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
        return hidden_masked

    def _pad_layer_unpad(self, input_, layer):
        input_, n_pad = pad_tensor_to_mul(input_)
        output = layer(input_)
        if n_pad:
            output = output[:-n_pad, :]
        return output
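
    # Example (sketch): assuming pad_tensor_to_mul pads the leading dimension
    # up to a multiple of 8 (the usual fp16 tensor-core alignment),
    # _pad_layer_unpad is equivalent to:
    #   n_pad = (-x.size(0)) % 8
    #   x = torch.cat([x, x.new_zeros(n_pad, x.size(1))], dim=0)
    #   out = layer(x)
    #   out = out[:-n_pad, :] if n_pad else out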

    def mlm_eval(self, batch, gather_tgt):
        raise ValueError('Do not use this')
        # unreachable legacy path from the fused (single-stream) encoder;
        # `input_ids` below would come from that fused input
        sequence_output = self.bert(batch, output_all_encoded_layers=False)
        # get only the text part (excluding [CLS], [SEP])
        sequence_output = sequence_output[:, 1:input_ids.size(1)-1, :]
        # only compute masked tokens for better efficiency
        index = gather_tgt.unsqueeze(-1).expand(
            -1, -1, self.config.hidden_size)
        masked_output = torch.gather(sequence_output, dim=0, index=index)
        prediction_scores = self.cls(masked_output)
        if self.vocab_pad:
            prediction_scores = prediction_scores[..., :-self.vocab_pad]
        return prediction_scores

    # MRFR
    def forward_mrfr(self, batch, img_masks, img_mask_tgt,
                     feat_targets, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = img_seq + txt_cls
        elif self.cls_concat == 'multiply':
            sequence_output = img_seq * txt_cls
        elif len(self.cls_concat) == 0:
            sequence_output = img_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')
        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)
        prediction_feat = self._pad_layer_unpad(masked_output,
                                                self.feat_regress)

        mrfr_loss = F.mse_loss(prediction_feat, feat_targets,
                               reduction='none')
        return mrfr_loss, prediction_feat

    # MRM-NCE (dead code: the task dispatcher raises before reaching it, and it
    # relies on the nce_output/nce_norm heads that are commented out in __init__)
    def forward_mrm_nce(self, batch,
                        img_masks_in, img_masks, img_mask_tgt,
                        feat_targets, neg_feats, compute_loss=True):
        sequence_output = self.bert(batch,
                                    output_all_encoded_layers=False,
                                    img_masks=img_masks_in)

        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)

        masked_output = self._pad_layer_unpad(masked_output, self.nce_output)
        # neg within batch
        batch_neg = self._compute_masked_hidden(img_feat, ~img_masks)
        neg_feats, _ = pad_tensor_to_mul(
            torch.cat([neg_feats, batch_neg], dim=0))

        # shared image linear transform
        neg_output = self.nce_norm(
            self.bert.img_embeddings.img_linear(neg_feats))
        pos_output = self._pad_layer_unpad(feat_targets,
                                           self.bert.img_embeddings.img_linear)
        pos_output = self.nce_norm(pos_output)

        mrm_nce_loss = self.mrm_nce(masked_output, pos_output,
                                    neg_output, compute_loss=True)
        return mrm_nce_loss, masked_output

    def mrm_nce(self, masked_output, pos_output, neg_output,
                compute_loss=True):
        # dot product of ground truth feature
        masked_score = masked_output.matmul(pos_output.t())
        # dot product of negative samples
        neg_score = masked_output.matmul(neg_output.t())

        logits = torch.cat([masked_score, neg_score], dim=1).float()
        targets = torch.arange(0, masked_output.size(0),
                               dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits/self.nce_temp, targets,
                               reduction='none')
        return loss, logits

    def forward_itm(self, batch, targets, ot_inputs,
                    compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=False)
        # OT loss
        # NOTE: this branch still references fused-encoder variables
        # (sequence_output, input_ids, img_feat) that the bi-encoder does not
        # produce; it only executes when ot_inputs is provided
        if ot_inputs is not None:
            ot_scatter = ot_inputs['ot_scatter']

            b = sequence_output.size(0)
            tl = input_ids.size(1)
            il = img_feat.size(1)
            max_l = max(ot_inputs['scatter_max'] + 1, tl+il)

            ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output)
            ctx_emb = torch.zeros(b, max_l, self.config.hidden_size,
                                  dtype=sequence_output.dtype,
                                  device=sequence_output.device
                                  ).scatter_(dim=1, index=ot_scatter,
                                             src=sequence_output)
            txt_emb = ctx_emb[:, :tl, :]
            img_emb = ctx_emb[:, tl:tl+il, :]

            txt_pad = ot_inputs['txt_pad']
            img_pad = ot_inputs['img_pad']
            ot_dist = optimal_transport_dist(txt_emb, img_emb,
                                             txt_pad, img_pad)
            if self.ot_pos_only:
                ot_loss = ot_dist.masked_select(targets == 1)
            else:
                ot_pos_dist = ot_dist.masked_select(targets == 1)
                ot_neg_dist = ot_dist.masked_select(targets == 0)
                ot_loss = (ot_pos_dist, ot_neg_dist)
        else:
            ot_loss = None

        # symmetric in-batch NLL: txt->img and img->txt, averaged
        loss_function = BiEncoderNllLoss()
        itm_loss1, is_correct1, scores1 = loss_function.calc(txt_seq, img_seq, cap_seq,
                                                             batch['pos_ctx_indices'],
                                                             batch['neg_ctx_indices'],
                                                             0.0, self.experiment, 'none')
        itm_loss2, is_correct2, scores2 = loss_function.calc(img_seq, txt_seq, cap_seq,
                                                             batch['pos_ctx_indices'],
                                                             batch['neg_ctx_indices'],
                                                             0.0, self.experiment, 'none')
        if compute_loss:
            return itm_loss1*0.5 + itm_loss2*0.5, ot_loss
        else:
            return itm_loss1*0.5 + itm_loss2*0.5, ot_loss, is_correct1*0.5 + is_correct2*0.5

    # MRC
    def forward_mrc(self, batch, img_masks, img_mask_tgt,
                    label_targets, task, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = img_seq + txt_cls
        elif self.cls_concat == 'multiply':
            sequence_output = img_seq * txt_cls
        elif len(self.cls_concat) == 0:
            sequence_output = img_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')

        # sequence_output = torch.cat([txt_seq, img_seq], dim=1)
        # only compute masked regions for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output, img_mask_tgt)
        prediction_soft_label = self._pad_layer_unpad(masked_output,
                                                      self.region_classifier)

        if "kl" in task:
            prediction_soft_label = F.log_softmax(
                prediction_soft_label, dim=-1)
            mrc_loss = F.kl_div(
                prediction_soft_label, label_targets, reduction='none')
        else:
            # background class should not be the target
            label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1
            mrc_loss = F.cross_entropy(
                prediction_soft_label, label_targets,
                ignore_index=0, reduction='none')
        return mrc_loss, prediction_soft_label

def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8,
                  weight_decay: float = 0.0) -> torch.optim.Optimizer:
    no_decay = ['bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
    return optimizer

def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1,
                               local_rank: int = -1,
                               fp16: bool = False,
                               fp16_opt_level: str = "O1",
                               teacher_model=None) -> (nn.Module, torch.optim.Optimizer):
    model.to(device)
    if teacher_model is not None:
        teacher_model.to(device)
    if fp16:
        try:
            import apex
            from apex import amp
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        if optimizer is None:
            if teacher_model is None:
                model = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
            else:
                model, teacher_model = amp.initialize([model, teacher_model], optimizer, opt_level=fp16_opt_level)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    # if n_gpu > 1:
    #     model = torch.nn.DataParallel(model)

    # if local_rank != -1:
    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
    #                                                       output_device=local_rank,
    #                                                       find_unused_parameters=True)
    return model, optimizer

class BiEncoderNllLoss(object):

    def calc(self, q_vectors: T, ctx_vectors: T, caption_vectors: T, positive_idx_per_question: list,
             hard_negative_idx_per_question: list = None, caption_score_weight: float = 0.1,
             experiment=None, reduction='mean'):
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question is not currently in use, one can use it for
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :return: a tuple of loss value and amount of correct predictions per batch
        """
        scores_img = self.get_scores(q_vectors, ctx_vectors)
        if caption_vectors is not None and caption_score_weight != 0:
            scores_caption = self.get_scores(q_vectors, caption_vectors)
            scores = (1 - caption_score_weight) * scores_img + caption_score_weight * scores_caption
        else:
            scores = scores_img

        if experiment is not None:
            experiment.log_metric('score_img_diag_mean', torch.diag(scores_img).mean().item())
            experiment.log_metric('score_img_offdiag_mean', (scores_img.sum() - torch.diag(scores_img).sum()) /
                                  (torch.numel(scores_img)-len(torch.diag(scores_img))))

            experiment.log_metric('score_diag_mean', torch.diag(scores).mean().item())
            experiment.log_metric('score_offdiag_mean', (scores.sum() - torch.diag(scores).sum()) /
                                  (torch.numel(scores) - len(torch.diag(scores))))

            if caption_vectors is not None and caption_score_weight != 0:
                experiment.log_metric('score_caption_diag_mean', torch.diag(scores_caption).mean().item())
                experiment.log_metric('score_caption_offdiag_mean', (scores_caption.sum() - torch.diag(scores_caption).sum()) /
                                      (torch.numel(scores_caption) - len(torch.diag(scores_caption))))

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device),
                          reduction=reduction)

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
        return loss, correct_predictions_count, scores

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores
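
# Example (sketch): with diagonal positives (sample i's positive is candidate i)
# the loss is a per-row softmax cross-entropy over the score matrix:
#   loss_fn = BiEncoderNllLoss()
#   q, ctx = torch.randn(4, 768), torch.randn(4, 768)
#   loss, n_correct, scores = loss_fn.calc(q, ctx, caption_vectors=None,
#                                          positive_idx_per_question=list(range(4)))
#   scores.shape   # torch.Size([4, 4])
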

def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that increases linearly during the
    warmup period and then decreases linearly to zero.
    """

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
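
# Example (sketch): typical wiring of the two helpers; stepping the scheduler
# once per optimizer step traces the warmup-then-linear-decay shape
# (model/loss come from the surrounding training script and are elided):
#   optimizer = get_optimizer(model, learning_rate=2e-5, weight_decay=0.01)
#   scheduler = get_schedule_linear(optimizer, warmup_steps=100, training_steps=5000)
#   for step in range(5000):
#       loss.backward(); optimizer.step(); optimizer.zero_grad()
#       scheduler.step()   # lr ramps to 2e-5 by step 100, then decays to ~0
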

class BiEncoderForVisualQuestionAnswering(nn.Module):
    """ Finetune multi-modal BERT for VQA
    """
    def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False,
                 seperate_caption_encoder: bool = False,
                 project_dim: int = 0,
                 hidden_size: int = 0, num_answer: int = 0, intersection=False):
        super(BiEncoderForVisualQuestionAnswering, self).__init__()
        self.biencoder = BiEncoder(args, fix_img_encoder, fix_txt_encoder, project_dim)
        self.intersection = intersection
        if self.intersection:
            # the intersection variant concatenates four chunks instead of two
            hidden_size *= 2
        self.vqa_output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size*2),
            GELU(),
            LayerNorm(hidden_size*2, eps=1e-12),
            nn.Linear(hidden_size*2, num_answer)
        )
        # apply recursively so the Linear/LayerNorm children are initialized
        self.vqa_output.apply(self.init_weights)

    def forward(self, batch, compute_loss=True, targets=None) -> Tuple[T, T]:

        q_pooled, ctx_pooled, caption_pooled = self.biencoder(batch)

        if self.intersection:
            pooled_output = torch.cat([q_pooled, ctx_pooled, q_pooled*ctx_pooled, q_pooled + ctx_pooled], dim=1)
        else:
            pooled_output = torch.cat([q_pooled, ctx_pooled], dim=1)

        answer_scores = self.vqa_output(pooled_output)

        if compute_loss:
            vqa_loss = F.binary_cross_entropy_with_logits(
                answer_scores, targets, reduction='none')
            return vqa_loss
        else:
            return answer_scores

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses
            # truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

def load_biencoder_checkpoint(bi_encoder, biencoder_checkpoint):
    if biencoder_checkpoint is not None and len(biencoder_checkpoint) > 0 and biencoder_checkpoint.lower() != 'none':
        logger.info(f'loading ckpt from {biencoder_checkpoint}')
        state_dict = torch.load(biencoder_checkpoint, map_location='cpu')
        try:
            bi_encoder.load_state_dict(state_dict['model_dict'])
        except KeyError:
            logger.info('loading from pre-trained model instead')
            for k in list(state_dict.keys()):
                if k.startswith('bert.'):
                    state_dict[k[5:]] = state_dict.pop(k)
                else:
                    state_dict.pop(k)
            bi_encoder.load_state_dict(state_dict, strict=True)
    else:
        logger.info('no checkpoint provided, pass')
@ -0,0 +1,176 @@
import argparse |
||||
|
import json |
||||
|
import sys |
||||
|
import os |
||||
|
import logging |
||||
|
import torch |
||||
|
import random |
||||
|
import socket |
||||
|
import numpy as np |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger() |
||||
|
|
||||
|
|
||||
|
def default_params(parser: argparse.ArgumentParser): |
||||
|
parser.add_argument('--txt_model_type', default='bert-base', type=str, help="") |
||||
|
parser.add_argument('--txt_model_config', default='bert-base', type=str, help="") |
||||
|
parser.add_argument('--txt_checkpoint', default=None, type=str, help="") |
||||
|
parser.add_argument('--img_model_type', default='uniter-base', type=str, help="") |
||||
|
parser.add_argument('--img_model_config', default='./config/img_base.json', type=str, help="") |
||||
|
parser.add_argument('--img_checkpoint', default=None, type=str, help="") |
||||
|
parser.add_argument('--biencoder_checkpoint', default=None, type=str, help="") |
||||
|
parser.add_argument('--seperate_caption_encoder', action='store_true', help="") |
||||
|
|
||||
|
parser.add_argument('--train_batch_size', default=80, type=int, help="") |
||||
|
parser.add_argument('--valid_batch_size', default=80, type=int, help="") |
||||
|
parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help="") |
||||
|
parser.add_argument('--learning_rate', default=1e-5, type=float, help="") |
||||
|
parser.add_argument('--max_grad_norm', default=2.0, type=float, help="") |
||||
|
parser.add_argument('--warmup_steps', default=500, type=int, help="") |
||||
|
parser.add_argument('--valid_steps', default=500, type=int, help="") |
||||
|
parser.add_argument('--num_train_steps', default=5000, type=int, help="") |
||||
|
parser.add_argument('--num_train_epochs', default=0, type=int, help="") |
||||
|
|
||||
|
parser.add_argument('--fp16', action='store_true', help="") |
||||
|
parser.add_argument('--seed', default=42, type=int, help="") |
||||
|
parser.add_argument('--output_dir', default='./', type=str, help="") |
||||
|
parser.add_argument('--max_txt_len', default=64, type=int, help="") |
||||
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") |
||||
|
parser.add_argument('--config', default=None, type=str, help="") |
||||
|
parser.add_argument('--itm_global_file', default=None, type=str, help="") |
||||
|
parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") |
||||
|
parser.add_argument('--n_workers', type=int, default=2, help="number of data workers") |
||||
|
parser.add_argument('--pin_mem', action='store_true', help="pin memory") # ??? |
||||
|
parser.add_argument('--hnsw_index', action='store_true', help="") |
||||
|
parser.add_argument('--fp16_opt_level', type=str, default='O1', help="") |
||||
|
parser.add_argument('--img_meta', type=str, default=None, help="") |
||||
|
|
||||
|
|
||||
|
def add_itm_params(parser: argparse.ArgumentParser): |
||||
|
parser.add_argument('--conf_th', default=0.2, type=float, help="") |
||||
|
parser.add_argument('--caption_score_weight', default=0.0, type=float, help="") |
||||
|
parser.add_argument('--negative_size', default=10, type=int, help="") |
||||
|
parser.add_argument('--num_hard_negatives', default=0, type=int, help="") |
||||
|
parser.add_argument('--sample_init_hard_negatives', action='store_true', help="") |
||||
|
parser.add_argument('--hard_negatives_sampling', default='none', type=str, |
||||
|
choices=['none', 'random', 'top', 'top-random', '10-20', '20-30'], help="") |
||||
|
parser.add_argument('--max_bb', default=100, type=int, help="") |
||||
|
parser.add_argument('--min_bb', default=10, type=int, help="") |
||||
|
parser.add_argument('--num_bb', default=36, type=int, help="") |
||||
|
parser.add_argument('--train_txt_dbs', default=None, type=str, help="") |
||||
|
parser.add_argument('--train_img_dbs', default=None, type=str, help="") |
||||
|
|
||||
|
parser.add_argument('--txt_db_mapping', default=None, type=str, help="") |
||||
|
parser.add_argument('--img_db_mapping', default=None, type=str, help="") |
||||
|
parser.add_argument('--pretrain_mapping', default=None, type=str, help="") |
||||
|
|
||||
|
parser.add_argument('--val_txt_db', default=None, type=str, help="") |
||||
|
parser.add_argument('--val_img_db', default=None, type=str, help="") |
||||
|
parser.add_argument('--test_txt_db', default=None, type=str, help="") |
||||
|
parser.add_argument('--test_img_db', default=None, type=str, help="") |
||||
|
parser.add_argument('--steps_per_hard_neg', default=-1, type=int, help="") |
||||
|
parser.add_argument('--inf_minibatch_size', default=400, type=int, help="") |
||||
|
parser.add_argument('--project_dim', default=0, type=int, help='') |
||||
|
parser.add_argument('--cls_concat', default="", type=str, help='') |
||||
|
parser.add_argument('--fix_txt_encoder', action='store_true', help='') |
||||
|
parser.add_argument('--fix_img_encoder', action='store_true', help='') |
||||
|
parser.add_argument('--compressed_db', action='store_true', help='use compressed LMDB') |
||||
|
parser.add_argument('--retrieval_mode', default="both", |
||||
|
choices=['img_only', 'txt_only', 'both'], type=str, help="") |
||||
|
|
||||
|
|
||||
|
def add_logging_params(parser: argparse.ArgumentParser): |
||||
|
parser.add_argument('--log_result_step', default=4, type=int, help="") |
||||
|
parser.add_argument('--project_name', default='itm', type=str, help="") |
||||
|
parser.add_argument('--expr_name_prefix', default='', type=str, help="") |
||||
|
parser.add_argument('--save_all_epochs', action='store_true', help="") |
||||
|
|
||||
|
|
||||
|
def add_kd_params(parser: argparse.ArgumentParser): |
||||
|
parser.add_argument('--teacher_checkpoint', default=None, type=str, help="") |
||||
|
parser.add_argument('--T', default=1.0, type=float, help="") |
||||
|
parser.add_argument('--kd_loss_weight', default=1.0, type=float, help="") |
||||
|
|
||||
|
|
||||
|
def parse_with_config(parser, cmds=None): |
||||
|
if cmds is None: |
||||
|
args = parser.parse_args() |
||||
|
else: |
||||
|
args = parser.parse_args(cmds) |
||||
|
|
||||
|
if args.config is not None: |
||||
|
config_args = json.load(open(args.config)) |
||||
|
cmd_args = sys.argv[1:] if cmds is None else cmds
override_keys = {arg[2:].split('=')[0] for arg in cmd_args
if arg.startswith('--')}
||||
|
for k, v in config_args.items(): |
||||
|
if k not in override_keys: |
||||
|
setattr(args, k, v) |
||||
|
return args |
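# A minimal usage sketch (config path hypothetical): values from the --config
# JSON fill in every option not given explicitly on the command line, and
# explicit flags always win.
#
#   parser = argparse.ArgumentParser()
#   default_params(parser)
#   args = parse_with_config(parser, ['--config', './config/eval.json', '--seed', '7'])
#   assert args.seed == 7  # explicit flag beats the config value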
||||
|
|
||||
|
|
||||
|
def map_db_dirs(args): |
||||
|
# map img db |
||||
|
for k in args.__dict__: |
||||
|
if not isinstance(args.__dict__[k], str): |
||||
|
continue |
||||
|
if args.__dict__[k].startswith('/pretrain') and args.pretrain_mapping: |
||||
|
print('pretrain', k, args.__dict__[k]) |
||||
|
args.__dict__[k] = args.__dict__[k].replace('/pretrain', args.pretrain_mapping) |
||||
|
if args.__dict__[k].startswith('/db') and args.txt_db_mapping: |
||||
|
print('db', k, args.__dict__[k]) |
||||
|
args.__dict__[k] = args.__dict__[k].replace('/db', args.txt_db_mapping) |
||||
|
if args.__dict__[k].startswith('/img') and args.img_db_mapping: |
||||
|
print('img', k, args.__dict__[k]) |
||||
|
args.__dict__[k] = args.__dict__[k].replace('/img', args.img_db_mapping) |
||||
|
|
||||
|
if args.img_db_mapping: |
||||
|
for i in range(len(args.train_img_dbs)): |
||||
|
args.train_img_dbs[i] = args.train_img_dbs[i].replace('/img', args.img_db_mapping) |
||||
|
if args.txt_db_mapping: |
||||
|
for i in range(len(args.train_txt_dbs)): |
||||
|
args.train_txt_dbs[i] = args.train_txt_dbs[i].replace('/db', args.txt_db_mapping) |
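# Example (mapping value hypothetical): with args.txt_db_mapping='/storage/db',
# a path such as '/db/itm_coco_train_base-cased.db' is rewritten to
# '/storage/db/itm_coco_train_base-cased.db'; '/img' and '/pretrain' prefixes
# are remapped the same way via img_db_mapping and pretrain_mapping.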
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
def print_args(args): |
||||
|
logger.info(" **************** CONFIGURATION **************** ") |
||||
|
for key, val in sorted(vars(args).items()): |
||||
|
keystr = "{}".format(key) + (" " * (30 - len(key))) |
||||
|
logger.info("%s --> %s", keystr, val) |
||||
|
logger.info(" **************** END CONFIGURATION **************** ") |
||||
|
|
||||
|
|
||||
|
def set_seed(args): |
||||
|
seed = args.seed |
||||
|
random.seed(seed) |
||||
|
np.random.seed(seed) |
||||
|
torch.manual_seed(seed) |
||||
|
if args.n_gpu > 0: |
||||
|
torch.cuda.manual_seed_all(seed) |
||||
|
|
||||
|
|
||||
|
def setup_args_gpu(args): |
||||
|
""" |
||||
|
Set up CUDA, GPU, and distributed-training arguments
||||
|
""" |
||||
|
if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") |
||||
|
args.n_gpu = torch.cuda.device_count() |
||||
|
else: # distributed mode |
||||
|
torch.cuda.set_device(args.local_rank) |
||||
|
device = torch.device("cuda", args.local_rank) |
||||
|
torch.distributed.init_process_group(backend="nccl") |
||||
|
args.n_gpu = 1 |
||||
|
args.device = device |
||||
|
ws = os.environ.get('WORLD_SIZE') |
||||
|
|
||||
|
args.distributed_world_size = int(ws) if ws else 1 |
||||
|
|
||||
|
logger.info( |
||||
|
'Initialized host %s as distributed rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(),
||||
|
args.local_rank, device, |
||||
|
args.n_gpu, |
||||
|
args.distributed_world_size) |
||||
|
logger.info("16-bits training: %s ", args.fp16) |
@ -0,0 +1,209 @@ |
|||||
|
import collections |
||||
|
import os |
||||
|
import torch |
||||
|
import tqdm |
||||
|
import logging |
||||
|
import numpy as np |
||||
|
import torch.nn as nn |
||||
|
from torch.utils.data import DataLoader, ConcatDataset, ChainDataset |
||||
|
from uniter_model.data.loader import PrefetchLoader |
||||
|
|
||||
|
from dvl.data.itm import TxtTokLmdb, ItmFastDataset, ItmValDataset, itm_fast_collate |
||||
|
from dvl.models.bi_encoder import BiEncoderNllLoss |
||||
|
from dvl.utils import _calc_loss |
||||
|
from dvl.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger() |
||||
|
CheckpointState = collections.namedtuple("CheckpointState", |
||||
|
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', |
||||
|
'encoder_params']) |
||||
|
|
||||
|
|
||||
|
class BiEncoderTrainer: |
||||
|
def __init__(self, args): |
||||
|
pass |
||||
|
|
||||
|
|
||||
|
def build_dataloader(dataset, collate_fn, is_train, opts, batch_size=None): |
||||
|
if batch_size is None: |
||||
|
batch_size = opts.train_batch_size if is_train else opts.valid_batch_size |
||||
|
|
||||
|
dataloader = DataLoader(dataset, batch_size=batch_size, |
||||
|
shuffle=is_train, drop_last=False, |
||||
|
num_workers=opts.n_workers, |
||||
|
pin_memory=opts.pin_mem, collate_fn=collate_fn) |
||||
|
dataloader = PrefetchLoader(dataloader) |
||||
|
return dataloader |
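# Hedged usage sketch: wrap a validation split with the itm collate function
# imported above; the batch size falls back to opts.valid_batch_size here.
#   val_loader = build_dataloader(val_dataset, itm_fast_collate, is_train=False, opts=args)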
||||
|
|
||||
|
|
||||
|
def get_model_obj(model: nn.Module): |
||||
|
return model.module if hasattr(model, 'module') else model |
||||
|
|
||||
|
|
||||
|
def _save_checkpoint(args, biencoder, optimizer, scheduler, epoch: int, offset: int, cp_name: str = None) -> str: |
||||
|
model_to_save = get_model_obj(biencoder) |
||||
|
if cp_name is None: |
||||
|
cp = os.path.join(args.output_dir, 'biencoder.' + str(epoch) + ('.' + str(offset) if offset > 0 else '')) |
||||
|
else: |
||||
|
cp = os.path.join(args.output_dir, 'biencoder.' + cp_name) |
||||
|
cp += '.pt' |
||||
|
|
||||
|
|
||||
|
meta_params = None |
||||
|
|
||||
|
state = CheckpointState(model_to_save.state_dict(), |
||||
|
optimizer.state_dict(), |
||||
|
scheduler.state_dict(), |
||||
|
offset, |
||||
|
epoch, meta_params |
||||
|
) |
||||
|
torch.save(state._asdict(), cp) |
||||
|
logger.info('Saved checkpoint at %s', cp) |
||||
|
return cp |
||||
|
|
||||
|
|
||||
|
def load_saved_state(biencoder, optimizer=None, scheduler=None, saved_state: CheckpointState = None):
||||
|
epoch = saved_state.epoch |
||||
|
offset = saved_state.offset |
||||
|
if offset == 0: # epoch has been completed |
||||
|
epoch += 1 |
||||
|
logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch) |
||||
|
|
||||
|
model_to_load = get_model_obj(biencoder) |
||||
|
logger.info('Loading saved model state ...') |
||||
|
model_to_load.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection |
||||
|
|
||||
|
if saved_state.optimizer_dict and optimizer is not None: |
||||
|
logger.info('Loading saved optimizer state ...') |
||||
|
optimizer.load_state_dict(saved_state.optimizer_dict) |
||||
|
|
||||
|
if saved_state.scheduler_dict and scheduler is not None: |
||||
|
scheduler_state = saved_state.scheduler_dict |
||||
|
scheduler.load_state_dict(scheduler_state) |
||||
|
|
||||
|
|
||||
|
def load_states_from_checkpoint(model_file: str) -> CheckpointState: |
||||
|
logger.info('Reading saved model from %s', model_file) |
||||
|
state_dict = torch.load(model_file, map_location='cpu') |
||||
|
logger.info('model_state_dict keys %s', state_dict.keys()) |
||||
|
return CheckpointState(**state_dict) |
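# Typical restore flow (checkpoint path hypothetical): read the
# CheckpointState from disk, then push its sub-dicts into the live objects.
#   saved_state = load_states_from_checkpoint('/storage/debug_coco/biencoder.0.pt')
#   load_saved_state(bi_encoder, optimizer, scheduler, saved_state=saved_state)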
||||
|
|
||||
|
|
||||
|
def get_indexer(bi_encoder, eval_dataloader, args, hnsw_index, img_retrieval=True): |
||||
|
bi_encoder.eval() |
||||
|
img_embedding = dict() |
||||
|
|
||||
|
if hnsw_index: |
||||
|
indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future |
||||
|
else: |
||||
|
indexer_img = DenseFlatIndexer(args.vector_size) # modify in future |
||||
|
for i, batch in enumerate(tqdm.tqdm(eval_dataloader)): |
||||
|
with torch.no_grad(): |
||||
|
model_out = bi_encoder(batch) |
||||
|
local_q_vector, local_ctx_vectors, local_caption_vectors = model_out |
||||
|
if img_retrieval: |
||||
|
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) |
||||
|
else: |
||||
|
img_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)}) |
||||
|
indexer_img.index_data(list(img_embedding.items())) |
||||
|
return indexer_img |
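# Example query against the returned index (all names hypothetical): encode
# queries into an (n, project_dim) float array, then fetch the top-10 neighbours.
#   indexer = get_indexer(bi_encoder, eval_dataloader, args, hnsw_index=False)
#   ids_and_scores = indexer.search_knn(query_vectors_np, 10)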
||||
|
|
||||
|
|
||||
|
def eval_model_on_dataloader(bi_encoder, eval_dataloader, args, img2txt=None, num_tops=100, no_eval=False): |
||||
|
total_loss = 0.0 |
||||
|
bi_encoder.eval() |
||||
|
total_correct_predictions = 0 |
||||
|
batches, total_samples = 0, 0 |
||||
|
labels_img_name = [] |
||||
|
labels_txt_name = [] |
||||
|
img_embedding = dict() |
||||
|
txt_embedding = dict() |
||||
|
if args.hnsw_index: |
||||
|
indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future |
||||
|
indexer_txt = DenseHNSWFlatIndexer(args.vector_size) # modify in future |
||||
|
else: |
||||
|
indexer_img = DenseFlatIndexer(args.vector_size) # modify in future |
||||
|
indexer_txt = DenseFlatIndexer(args.vector_size) # modify in future |
||||
|
query_txt, query_txt_id = [], [] |
||||
|
query_img, query_img_id = [], [] |
||||
|
for i, batch in enumerate(eval_dataloader): |
||||
|
with torch.no_grad(): |
||||
|
model_out = bi_encoder(batch) |
||||
|
local_q_vector, local_ctx_vectors, local_caption_vectors = model_out |
||||
|
|
||||
|
query_txt.extend([out.view(-1).detach().cpu().numpy() for out in local_q_vector]) |
||||
|
query_txt_id.extend(batch['txt_index']) |
||||
|
|
||||
|
query_img.extend([out.view(-1).detach().cpu().numpy() for out in local_ctx_vectors]) |
||||
|
query_img_id.extend(batch['img_fname']) |
||||
|
|
||||
|
loss_function = BiEncoderNllLoss() |
||||
|
|
||||
|
loss, correct_cnt, score = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, |
||||
|
list(range(len(local_q_vector))), None) |
||||
|
|
||||
|
total_loss += loss.item() |
||||
|
total_correct_predictions += correct_cnt.sum().item() |
||||
|
batches += 1 |
||||
|
total_samples += batch['txts']['input_ids'].shape[0] |
||||
|
|
||||
|
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) |
||||
|
txt_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)}) |
||||
|
labels_img_name.extend(batch['img_fname']) |
||||
|
labels_txt_name.extend(batch['txt_index']) |
||||
|
|
||||
|
total_loss = total_loss / batches |
||||
|
correct_ratio = total_correct_predictions / float(total_samples) |
||||
|
|
||||
|
query_txt_np = np.array(query_txt) |
||||
|
indexer_img.index_data(list(img_embedding.items())) |
||||
|
query_img_np = np.array(query_img) |
||||
|
indexer_txt.index_data(list(txt_embedding.items())) |
||||
|
|
||||
|
if no_eval: |
||||
|
return total_loss, correct_ratio, (indexer_img, indexer_txt), (None, None), (None, None) |
||||
|
else: |
||||
|
res_txt = indexer_img.search_knn(query_txt_np, num_tops) |
||||
|
rank_txt_res = {query_txt_id[i]: r[0] for i, r in enumerate(res_txt)} |
||||
|
|
||||
|
res_img = indexer_txt.search_knn(query_img_np, num_tops) |
||||
|
rank_img_res = {query_img_id[i]: r[0] for i, r in enumerate(res_img)} |
||||
|
|
||||
|
recall_txt = {1: 0, 5: 0, 10: 0} |
||||
|
for i, q in enumerate(query_txt_id): |
||||
|
for top in recall_txt: |
||||
|
recall_txt[top] += labels_img_name[i] in rank_txt_res[q][:top] |
||||
|
|
||||
|
for top in recall_txt: |
||||
|
recall_txt[top] = recall_txt[top] / len(rank_txt_res) |
||||
|
|
||||
|
recall_img = {1: 0, 5: 0, 10: 0} |
||||
|
for i, q in enumerate(np.unique(query_img_id)): |
||||
|
for top in recall_img: |
||||
|
recall_img[top] += any([txt_id in rank_img_res[q][:top] for txt_id in img2txt[q]])
||||
|
|
||||
|
for top in recall_img: |
||||
|
recall_img[top] = recall_img[top] / len(rank_img_res) |
||||
|
|
||||
|
return total_loss, correct_ratio, (indexer_img, indexer_txt), (recall_txt, recall_img), (rank_txt_res, rank_img_res) |
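# Recall@K above counts a text query as a hit when its ground-truth image id
# appears in the top-K retrieved list (and, for image queries, when any of the
# image's captions from img2txt does). For example, 624 hits over 1,000 text
# queries at K=10 would give recall_txt[10] == 0.624.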
||||
|
|
||||
|
|
||||
|
def load_dataset(all_img_dbs, txt_dbs, img_dbs, args, is_train): |
||||
|
if is_train: |
||||
|
# train datasets |
||||
|
datasets = [] |
||||
|
for txt_path, img_path in zip(txt_dbs, img_dbs): |
||||
|
img_db = all_img_dbs[img_path] |
||||
|
txt_db = TxtTokLmdb(txt_path, args.max_txt_len) |
||||
|
datasets.append(ItmFastDataset(txt_db, img_db, args.num_hard_negatives, args.img_meta, args.tokenizer)) |
||||
|
|
||||
|
datasets = ConcatDataset(datasets)
||||
|
else: |
||||
|
# eval or test |
||||
|
img_db = all_img_dbs[img_dbs] |
||||
|
txt_db = TxtTokLmdb(txt_dbs, -1) |
||||
|
datasets = ItmFastDataset(txt_db, img_db, args.inf_minibatch_size, args.img_meta, args.tokenizer) |
||||
|
|
||||
|
return datasets |
@ -0,0 +1,234 @@ |
|||||
|
import logging |
||||
|
import random |
||||
|
import tqdm |
||||
|
import torch |
||||
|
import pickle |
||||
|
|
||||
|
import torch.distributed as dist |
||||
|
|
||||
|
from collections import defaultdict |
||||
|
from horovod import torch as hvd |
||||
|
from torch import Tensor as T |
||||
|
from typing import Tuple |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger() |
||||
|
|
||||
|
|
||||
|
def get_rank(): |
||||
|
return hvd.rank() |
||||
|
|
||||
|
|
||||
|
def get_world_size(): |
||||
|
return hvd.size() |
||||
|
|
||||
|
|
||||
|
def print_args(args): |
||||
|
logger.info(" **************** CONFIGURATION **************** ") |
||||
|
for key, val in sorted(vars(args).items()): |
||||
|
keystr = "{}".format(key) + (" " * (30 - len(key))) |
||||
|
logger.info("%s --> %s", keystr, val) |
||||
|
logger.info(" **************** CONFIGURATION **************** ") |
||||
|
|
||||
|
|
||||
|
def num_of_parameters(model, requires_grad=False): |
||||
|
if requires_grad: |
||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad) |
||||
|
else: |
||||
|
return sum(p.numel() for p in model.parameters()) |
||||
|
|
||||
|
|
||||
|
def get_default_group(): |
||||
|
return dist.group.WORLD |
||||
|
|
||||
|
|
||||
|
def all_reduce(tensor, group=None): |
||||
|
if group is None: |
||||
|
group = get_default_group() |
||||
|
return dist.all_reduce(tensor, group=group) |
||||
|
|
||||
|
|
||||
|
def all_gather_list(data, group=None, max_size=16384): |
||||
|
"""Gathers arbitrary data from all nodes into a list. |
||||
|
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python |
||||
|
data. Note that *data* must be picklable. |
||||
|
Args: |
||||
|
data (Any): data from the local worker to be gathered on other workers |
||||
|
group (optional): group of the collective |
||||
|
""" |
||||
|
SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size |
||||
|
|
||||
|
enc = pickle.dumps(data) |
||||
|
enc_size = len(enc) |
||||
|
|
||||
|
if enc_size + SIZE_STORAGE_BYTES > max_size: |
||||
|
raise ValueError( |
||||
|
'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) |
||||
|
|
||||
|
rank = get_rank() |
||||
|
world_size = get_world_size() |
||||
|
buffer_size = max_size * world_size |
||||
|
|
||||
|
if not hasattr(all_gather_list, '_buffer') or \ |
||||
|
all_gather_list._buffer.numel() < buffer_size: |
||||
|
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) |
||||
|
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() |
||||
|
|
||||
|
buffer = all_gather_list._buffer |
||||
|
buffer.zero_() |
||||
|
cpu_buffer = all_gather_list._cpu_buffer |
||||
|
|
||||
|
assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( |
||||
|
256 ** SIZE_STORAGE_BYTES) |
||||
|
|
||||
|
size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') |
||||
|
|
||||
|
cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) |
||||
|
cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) |
||||
|
|
||||
|
start = rank * max_size |
||||
|
size = enc_size + SIZE_STORAGE_BYTES |
||||
|
buffer[start: start + size].copy_(cpu_buffer[:size]) |
||||
|
|
||||
|
all_reduce(buffer, group=group) |
||||
|
|
||||
|
try: |
||||
|
result = [] |
||||
|
for i in range(world_size): |
||||
|
out_buffer = buffer[i * max_size: (i + 1) * max_size] |
||||
|
size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') |
||||
|
if size > 0: |
||||
|
result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) |
||||
|
return result |
||||
|
except pickle.UnpicklingError: |
||||
|
raise Exception( |
||||
|
'Unable to unpickle data from other workers. all_gather_list requires all ' |
||||
|
'workers to enter the function together, so this error usually indicates ' |
||||
|
'that the workers have fallen out of sync somehow. Workers can fall out of ' |
||||
|
'sync if one of them runs out of memory, or if there are other conditions ' |
||||
|
'in your training script that can cause one worker to finish an epoch ' |
||||
|
'while other workers are still iterating over their portions of the data.' |
||||
|
) |
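# Collective-usage sketch: every worker must call this together with its own
# picklable payload and gets back one payload per rank, in rank order.
#   stats = {'rank': get_rank(), 'seen': 1024}   # hypothetical payload
#   all_stats = all_gather_list(stats, max_size=16384)
#   assert len(all_stats) == get_world_size()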
||||
|
|
||||
|
|
||||
|
def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, local_positive_idxs, |
||||
|
local_hard_negatives_idxs: list = None, experiment=None |
||||
|
): |
||||
|
""" |
||||
|
Calculates the in-batch-negatives loss; supports DDP by exchanging the
representations across all nodes.
||||
|
""" |
||||
|
distributed_world_size = 1  # distributed gathering disabled; set to args.distributed_world_size to enable
||||
|
if distributed_world_size > 1: |
||||
|
# TODO: Add local_caption_vectors |
||||
|
q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_() |
||||
|
ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_() |
||||
|
|
||||
|
global_question_ctx_vectors = all_gather_list( |
||||
|
[q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs], |
||||
|
max_size=args.global_loss_buf_sz) |
||||
|
|
||||
|
global_q_vector = [] |
||||
|
global_ctxs_vector = [] |
||||
|
|
||||
|
# ctxs_per_question = local_ctx_vectors.size(0) |
||||
|
positive_idx_per_question = [] |
||||
|
hard_negatives_per_question = [] |
||||
|
|
||||
|
total_ctxs = 0 |
||||
|
|
||||
|
for i, item in enumerate(global_question_ctx_vectors): |
||||
|
q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item |
||||
|
|
||||
|
if i != args.local_rank: |
||||
|
global_q_vector.append(q_vector.to(local_q_vector.device)) |
||||
|
global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device)) |
||||
|
positive_idx_per_question.extend([v + total_ctxs for v in positive_idx]) |
||||
|
hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs]) |
||||
|
else: |
||||
|
global_q_vector.append(local_q_vector) |
||||
|
global_ctxs_vector.append(local_ctx_vectors) |
||||
|
positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs]) |
||||
|
hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs]) |
||||
|
total_ctxs += ctx_vectors.size(0) |
||||
|
|
||||
|
global_q_vector = torch.cat(global_q_vector, dim=0) |
||||
|
global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0) |
||||
|
|
||||
|
else: |
||||
|
global_q_vector = local_q_vector |
||||
|
global_ctxs_vector = local_ctx_vectors |
||||
|
global_caption_vector = local_caption_vectors |
||||
|
positive_idx_per_question = local_positive_idxs |
||||
|
hard_negatives_per_question = local_hard_negatives_idxs |
||||
|
|
||||
|
loss, is_correct, scores = loss_function.calc(global_q_vector, global_ctxs_vector, global_caption_vector, |
||||
|
positive_idx_per_question, hard_negatives_per_question, |
||||
|
args.caption_score_weight, experiment) |
||||
|
|
||||
|
return loss, is_correct, scores |
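# In-batch negatives in one line: for query i, context i is the positive and
# every other context in the (possibly gathered) batch is a negative, which is
# why the evaluation code elsewhere in this diff passes
# list(range(len(local_q_vector))) as the positive indices.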
||||
|
|
||||
|
|
||||
|
def compare_models(model_1, model_2): |
||||
|
models_differ = 0 |
||||
|
for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): |
||||
|
if torch.equal(key_item_1[1], key_item_2[1]): |
||||
|
pass |
||||
|
else: |
||||
|
models_differ += 1 |
||||
|
if (key_item_1[0] == key_item_2[0]): |
||||
|
print('Mismatch found at', key_item_1[0])
||||
|
else: |
||||
|
raise Exception('model state dicts have different keys')
||||
|
if models_differ == 0: |
||||
|
print('Models match perfectly! :)') |
||||
|
|
||||
|
|
||||
|
def is_main_process(): |
||||
|
return hvd.rank() == 0 |
||||
|
|
||||
|
|
||||
|
def display_img(img_meta, name, img_only=False): |
||||
|
import matplotlib.pyplot as plt |
||||
|
import matplotlib.image as mpimg |
||||
|
img = mpimg.imread(img_meta[name]['img_file']) |
||||
|
plt.imshow(img) |
||||
|
plt.show() |
||||
|
if not img_only: |
||||
|
print('annotation') |
||||
|
print('\t' + '\n\t'.join(img_meta[name]['annotation'])) |
||||
|
print('caption') |
||||
|
print('\t' + img_meta[name]['caption'][0]) |
||||
|
|
||||
|
|
||||
|
def retrieve_query(model, query, indexer, args, top=10): |
||||
|
input_ids = args.tokenizer.encode(query) |
||||
|
input_ids = torch.LongTensor(input_ids).to(args.device).unsqueeze(0) |
||||
|
attn_mask = torch.ones(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0) |
||||
|
pos_ids = torch.arange(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0) |
||||
|
_, query_vector, _ = model.txt_model(input_ids=input_ids, attention_mask=attn_mask, position_ids=pos_ids)
||||
|
res = indexer.search_knn(query_vector.detach().cpu().numpy(), top)
||||
|
return res |
||||
|
|
||||
|
|
||||
|
def get_model_encoded_vecs(model, dataloader): |
||||
|
img_embedding, caption_embedding, query_embedding = dict(), dict(), defaultdict(list) |
||||
|
labels_img_name = [] |
||||
|
for i, batch in enumerate(tqdm.tqdm(dataloader)):
||||
|
with torch.no_grad(): |
||||
|
model_out = model(batch) |
||||
|
local_q_vectors, local_ctx_vectors, local_caption_vectors = model_out |
||||
|
|
||||
|
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) |
||||
|
caption_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['img_fname'], local_caption_vectors)}) |
||||
|
query_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['txt_index'], local_q_vectors)}) |
||||
|
|
||||
|
labels_img_name.extend(batch['img_fname']) |
||||
|
return { |
||||
|
'img_embed': img_embedding, |
||||
|
'caption_embed': caption_embedding, |
||||
|
'txt_embed': query_embedding, |
||||
|
'img_name': labels_img_name |
||||
|
} |
||||
|
|
@ -0,0 +1,22 @@ |
|||||
|
FROM nvcr.io/nvidia/pytorch:19.05-py3 |
||||
|
COPY requirements.txt scripts/download_bert.py ./ |
||||
|
RUN pip install -r requirements.txt &&\ |
||||
|
python download_bert.py &&\ |
||||
|
rm ./requirements.txt ./download_bert.py |
||||
|
|
||||
|
################## v1 ########################## |
||||
|
|
||||
|
COPY scripts/install_horovod.sh ./ |
||||
|
RUN source install_horovod.sh &&\ |
||||
|
rm ./install_horovod.sh |
||||
|
ENV OPENMPI_VERSION=4.0.0 |
||||
|
|
||||
|
# fix ssh permissions |
||||
|
RUN bash -c "chmod -R 600 /etc/ssh/ && chmod 600 /var/run/sshd/ && chmod 600 /root" |
||||
|
|
||||
|
################## horovod, v2 ########################## |
||||
|
|
||||
|
|
||||
|
RUN bash -c "pip install lz4==2.1.9 lmdb==0.97" |
||||
|
|
||||
|
################# LMDB ########################## |
@ -0,0 +1,21 @@ |
|||||
|
MIT License |
||||
|
|
||||
|
Copyright (c) 2019 Yen-Chun Chen |
||||
|
|
||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
||||
|
of this software and associated documentation files (the "Software"), to deal |
||||
|
in the Software without restriction, including without limitation the rights |
||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
||||
|
copies of the Software, and to permit persons to whom the Software is |
||||
|
furnished to do so, subject to the following conditions: |
||||
|
|
||||
|
The above copyright notice and this permission notice shall be included in all |
||||
|
copies or substantial portions of the Software. |
||||
|
|
||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
||||
|
SOFTWARE. |
@ -0,0 +1,89 @@ |
|||||
|
# Universal-Image-Text-Transformer |
||||
|
Research code for pre-training universal vision and language models |
||||
|
|
||||
|
|
||||
|
## Requirements |
||||
|
NVIDIA driver (418.xx), Docker (19.03+), nvidia-container-toolkit
||||
|
``` |
||||
|
docker pull convaicontainerregistry1.azurecr.io/img-txt |
||||
|
``` |
||||
|
|
||||
|
## Launching the environment
||||
|
``` |
||||
|
# use CUDA_VISIBLE_DEVICES to assign separate GPUs to each container
||||
|
source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH |
||||
|
# TXT_DB: convaistorage2share2/TXT_DB_v3 |
||||
|
# IMG_DIR: convaistorage2share2/Bottom-up-features/adaptive/npy_per_img_id |
||||
|
# OUTPUT: somewhere to store model checkpoints (can be on shared storage)
||||
|
# PRETRAIN_PATH: path to pretrained model
||||
|
|
||||
|
# when preprocessing is needed
||||
|
source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH --prepro |
||||
|
# this will make /db writable |
||||
|
|
||||
|
|
||||
|
# multi-node training |
||||
|
source launch_container_dist.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH |
||||
|
``` |
||||
|
|
||||
|
## Pretrain |
||||
|
``` |
||||
|
# inside the docker container |
||||
|
horovodrun -np $N_GPU -H localhost:$N_GPU \ |
||||
|
python pretrain.py --config config/config-pretrain-alltask.json |
||||
|
``` |
||||
|
|
||||
|
## Finetune VQA
||||
|
``` |
||||
|
horovodrun -np 2 -H localhost:2 \ |
||||
|
python train_vqa.py --config config/config-vqa-bert-2gpu-alldata.json |
||||
|
``` |
||||
|
### VQA inference |
||||
|
``` |
||||
|
# single node only |
||||
|
# please refer to the code for command-line options
||||
|
horovodrun -np $N_GPU -H localhost:$N_GPU \ |
||||
|
python eval_vqa.py --txt_db /db/vqa_test_[base/large]-cased.db/ \ |
||||
|
--img_dir /img/coco_test2015 --checkpoint [NUM] \ |
||||
|
--output_dir /path/to/trained/vqa |
||||
|
``` |
||||
|
|
||||
|
### NLVR2 official evaluation |
||||
|
Use the official script to get both accuracy (our validation matched it) and consistency
||||
|
``` |
||||
|
# concat all output files |
||||
|
cat $OUTPUT/result/[val/test]_results_${STEP}_rank*.csv > $OUTPUT.csv
||||
|
python eval/nlvr2.py $OUTPUT.csv ANNOTATION.json |
||||
|
``` |
||||
|
|
||||
|
### Referring Expression Comprehension: Finetuning and Evaluation |
||||
|
``` |
||||
|
# train on ground-truth pairs of (ref, sent)
||||
|
horovodrun -np $N_GPU -H localhost:$N_GPU \ |
||||
|
python train_re.py --config config/hps-refcoco+.json |
||||
|
|
||||
|
# evaluate multiple splits on ground-truth boxes
||||
|
horovodrun -np $N_GPU -H localhost:$N_GPU \ |
||||
|
python eval_re.py \ |
||||
|
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ |
||||
|
--img_dir /img/visual_grounding_coco_gt \ |
||||
|
--output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ |
||||
|
--checkpoint 26 |
||||
|
|
||||
|
# evaluate multiple splits on detected boxes |
||||
|
horovodrun -np $N_GPU -H localhost:$N_GPU \ |
||||
|
python eval_re.py \ |
||||
|
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ |
||||
|
--img_dir /img/visual_grounding_det_coco \ |
||||
|
--output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ |
||||
|
--checkpoint 26 |
||||
|
``` |
||||
|
|
||||
|
## Misc |
||||
|
1. Without horovodrun, the code runs on a single GPU
||||
|
- useful for the debugger (`-m pdb`); see the sketch after this list
||||
|
2. Try `--pin_mem`; it might give a tiny performance improvement
||||
|
3. `--img_format [lmdb/lmdb-compress]` |
||||
|
- trade-off between memory and CPU
||||
|
- use `--n_workers $N_CPU` to specify data workers (default: 4) |
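A minimal single-GPU debug invocation combining notes 1 and 3 above (script and config reused from the VQA example; adjust to your setup):
```
python -m pdb train_vqa.py --config config/config-vqa-bert-2gpu-alldata.json --n_workers 0
```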
||||
|
|
@ -0,0 +1,36 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/", |
||||
|
"train_img_dir": "/img/vcr_gt_val/;/img/vcr_val/", |
||||
|
"val_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/", |
||||
|
"val_img_dir": "/img/vcr_gt_val/;/img/vcr_val/", |
||||
|
"checkpoint": "/storage/pretrain_vcr/mlm_qar-30k_steps-lr_3e-5-run1/ckpt/model_step_29000.pt", |
||||
|
"checkpoint_from": "vcr", |
||||
|
"task": "qa", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/debug/qa_qar-bert_base-gt_proposed_img_feat-mlm-diff_type_id_for_ra-lr_2e-5-train_step_10k", |
||||
|
"max_txt_len": 220, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 4000, |
||||
|
"val_batch_size": 12, |
||||
|
"gradient_accumulation_steps": 10, |
||||
|
"learning_rate": 5e-5, |
||||
|
"valid_steps": 10, |
||||
|
"num_train_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"warm_int": 500, |
||||
|
"decay_int": 2000, |
||||
|
"decay_st": 12000, |
||||
|
"decay_rate": 0.2, |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"warmup_steps": 1000, |
||||
|
"seed": 42, |
||||
|
"fp16": true, |
||||
|
"mcan": false |
||||
|
} |
@ -0,0 +1,11 @@ |
|||||
|
{ |
||||
|
"txt_db": "/db/itm_coco_test_1k_4_base-cased.db/", |
||||
|
"img_dir": "/img/coco_val2014/", |
||||
|
"neg_sample_size": -1, |
||||
|
"eval_split": "itm_coco_test_1k_4_fp16", |
||||
|
"checkpoint": 12000, |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/itm/coco-bert_base-weak_420k-w_train-lr_2e-5-20k_steps-rank_loss-batch_size_20_acc_8_wd_0/", |
||||
|
"fp16": true, |
||||
|
"eval_mini_batch_size": 400 |
||||
|
} |
@ -0,0 +1,11 @@ |
|||||
|
{ |
||||
|
"txt_db": "/db/itm_flickr30k_train_base-cased.db/", |
||||
|
"img_dir": "/img/flickr30k/", |
||||
|
"neg_sample_size": -1, |
||||
|
"eval_split": "itm_flickr_train", |
||||
|
"checkpoint": 6000, |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/itm/flickr30k-bert_base_weak_420k-hard_neg_finetune-lr_5e-5-train_step_5k", |
||||
|
"fp16": true, |
||||
|
"eval_mini_batch_size": 128 |
||||
|
} |
@ -0,0 +1,53 @@ |
|||||
|
{ |
||||
|
"train_txt_db": ["/db/itm_coco_train_base-cased.db/", |
||||
|
"/db/itm_coco_restval_base-cased.db"], |
||||
|
"train_img_dir": ["/img/coco_train2014/", |
||||
|
"/img/coco_val2014/"], |
||||
|
"train_neg_sample_p": 0.5, |
||||
|
"neg_sample_from": "i", |
||||
|
"eval_method": "rank", |
||||
|
"val_txt_db": ["/db/itm_coco_val_1k_0_base-cased.db/", |
||||
|
"/db/itm_coco_val_1k_1_base-cased.db/", |
||||
|
"/db/itm_coco_val_1k_2_base-cased.db/", |
||||
|
"/db/itm_coco_val_1k_3_base-cased.db/", |
||||
|
"/db/itm_coco_val_1k_4_base-cased.db/"], |
||||
|
"val_img_dir": ["/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/"], |
||||
|
"test_txt_db": ["/db/itm_coco_test_1k_0_base-cased.db/", |
||||
|
"/db/itm_coco_test_1k_1_base-cased.db/", |
||||
|
"/db/itm_coco_test_1k_2_base-cased.db/", |
||||
|
"/db/itm_coco_test_1k_3_base-cased.db/", |
||||
|
"/db/itm_coco_test_1k_4_base-cased.db/"], |
||||
|
"test_img_dir": ["/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/", |
||||
|
"/img/coco_val2014/"], |
||||
|
"checkpoint": "/pretrain/mlm_caption_bert-base.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/itm_tr/coco_bert_base-mlm_caption-corrected_img_bb_num-w_train_rv-step_40000", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 2048, |
||||
|
"val_batch_size": 4096, |
||||
|
"val_minibatch_size":400, |
||||
|
"test_minibatch_size":300, |
||||
|
"gradient_accumulation_steps": 8, |
||||
|
"learning_rate": 0.001, |
||||
|
"valid_steps": 1000, |
||||
|
"num_train_steps": 40000, |
||||
|
"optim": "adamax", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"falseweight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 1000, |
||||
|
"seed": 42, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,25 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/refcoco+_train_base-cased.db", |
||||
|
"train_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"val_txt_db": "/db/refcoco+_val_base-cased.db", |
||||
|
"val_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4", |
||||
|
"max_txt_len": 60, |
||||
|
"train_batch_size": 128, |
||||
|
"val_batch_size": 128, |
||||
|
"learning_rate": 1e-4, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"weight_decay": 0.01, |
||||
|
"dropout": 0.1, |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"num_train_steps": 24000, |
||||
|
"warmup_steps": 1500, |
||||
|
"gradient_accumulation_steps": 1, |
||||
|
"no_cuda": false, |
||||
|
"seed": 24, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,26 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/refcoco+_train_base-cased.db", |
||||
|
"train_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"val_txt_db": "/db/refcoco+_val_base-cased.db", |
||||
|
"val_img_dir": "/img/visual_grounding_det_coco", |
||||
|
"checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_200000.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5", |
||||
|
"max_txt_len": 60, |
||||
|
"train_batch_size": 128, |
||||
|
"val_batch_size": 128, |
||||
|
"learning_rate": 8e-5, |
||||
|
"valid_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"weight_decay": 0.01, |
||||
|
"dropout": 0.1, |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"num_train_steps": 24000, |
||||
|
"warmup_steps": 1500, |
||||
|
"gradient_accumulation_steps": 1, |
||||
|
"no_cuda": false, |
||||
|
"seed": 24, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,26 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/refcoco+_train_large-cased.db", |
||||
|
"train_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"val_txt_db": "/db/refcoco+_val_large-cased.db", |
||||
|
"val_img_dir": "/img/visual_grounding_det_coco", |
||||
|
"checkpoint": "/pretrain/bert-large_weak/ckpt/model_step_50000.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/refcoco+/conceptual-bert-large_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5_b64g4", |
||||
|
"max_txt_len": 60, |
||||
|
"train_batch_size": 64, |
||||
|
"val_batch_size": 256, |
||||
|
"learning_rate": 8e-5, |
||||
|
"valid_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"weight_decay": 0.01, |
||||
|
"dropout": 0.1, |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"num_train_steps": 24000, |
||||
|
"warmup_steps": 1500, |
||||
|
"gradient_accumulation_steps": 4, |
||||
|
"no_cuda": false, |
||||
|
"seed": 24, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,29 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/refcoco+_train_base-cased.db", |
||||
|
"train_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"val_txt_db": "/db/refcoco+_val_base-cased.db", |
||||
|
"val_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_420000.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4_rank_r0.2_m0.2_step30k", |
||||
|
"max_txt_len": 60, |
||||
|
"train_batch_size": 128, |
||||
|
"val_batch_size": 128, |
||||
|
"learning_rate": 1e-4, |
||||
|
"valid_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"weight_decay": 0.01, |
||||
|
"dropout": 0.1, |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"num_train_steps": 30000, |
||||
|
"warmup_steps": 1500, |
||||
|
"gradient_accumulation_steps": 1, |
||||
|
"no_cuda": false, |
||||
|
"seed": 24, |
||||
|
"fp16": true, |
||||
|
"train_loss": "rank", |
||||
|
"hard_ratio": 0.2, |
||||
|
"margin": 0.2 |
||||
|
} |
@ -0,0 +1,26 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/refcoco_train_base-cased.db", |
||||
|
"train_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"val_txt_db": "/db/refcoco_val_base-cased.db", |
||||
|
"val_img_dir": "/img/visual_grounding_coco_gt", |
||||
|
"checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/refcoco/bert-base_mlm+itm+mrfr_pretrain-refcoco_lr3e-4", |
||||
|
"max_txt_len": 60, |
||||
|
"train_batch_size": 128, |
||||
|
"val_batch_size": 128, |
||||
|
"learning_rate": 3e-4, |
||||
|
"valid_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"weight_decay": 0.01, |
||||
|
"dropout": 0.1, |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"num_train_steps": 10000, |
||||
|
"warmup_steps": 1500, |
||||
|
"gradient_accumulation_steps": 1, |
||||
|
"no_cuda": false, |
||||
|
"seed": 24, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,31 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/ve_train_large-cased.db/", |
||||
|
"train_img_dir": "/img/flickr30k/", |
||||
|
"val_txt_db": "/db/ve_dev_large-cased.db/", |
||||
|
"test_img_dir": "/img/flickr30k/", |
||||
|
"test_txt_db": "/db/ve_test_large-cased.db/", |
||||
|
"val_img_dir": "/img/flickr30k/", |
||||
|
"checkpoint": "/pretrain/bert-large_frkl_alldata.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/ve/default", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8192, |
||||
|
"val_batch_size": 8192, |
||||
|
"gradient_accumulation_steps": 4, |
||||
|
"learning_rate": 3e-5, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 6000, |
||||
|
"warmup_steps": 600, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"seed": 42, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,31 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/ve_train_base-cased.db/", |
||||
|
"train_img_dir": "/img/flickr30k/", |
||||
|
"val_txt_db": "/db/ve_dev_base-cased.db/", |
||||
|
"test_img_dir": "/img/flickr30k/", |
||||
|
"test_txt_db": "/db/ve_test_base-cased.db/", |
||||
|
"val_img_dir": "/img/flickr30k/", |
||||
|
"checkpoint": "/pretrain/base_mlm_mrfr_mrckl_itm_alldata.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/ve/default", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8192, |
||||
|
"val_batch_size": 8192, |
||||
|
"gradient_accumulation_steps": 4, |
||||
|
"learning_rate": 3e-5, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 6000, |
||||
|
"warmup_steps": 600, |
||||
|
"optim": "adamw", |
||||
|
"betas": [0.9, 0.98], |
||||
|
"grad_norm": 2.0, |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"seed": 42, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,30 @@ |
|||||
|
{ |
||||
|
"train_txt_db": "/db/vqa_train_base-cased.db/", |
||||
|
"train_img_dir": "/img/coco_train2014/", |
||||
|
"val_txt_db": "/db/vqa_val_base-cased.db/", |
||||
|
"val_img_dir": "/img/coco_val2014/", |
||||
|
"ans2label": "/db/ans2label.pkl", |
||||
|
"checkpoint": "/storage/mlm/caption-base_from-scratch_grad-acc-8_step-80k_val-step-5k_lr-2e-4_wu-0.1/ckpt/model_step_80000_final.pt", |
||||
|
"cut_bert": -1, |
||||
|
"output_dir": "/storage/vqa/bert_base-mlm_caption_nonblind_from_scratch-nowd-linear", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 2048, |
||||
|
"val_batch_size": 4096, |
||||
|
"gradient_accumulation_steps": 8, |
||||
|
"learning_rate": 0.001, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 20000, |
||||
|
"optim": "adamax", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 1000, |
||||
|
"seed": 42, |
||||
|
"fp16": false, |
||||
|
"blind": false |
||||
|
} |
@ -0,0 +1,47 @@ |
|||||
|
{ |
||||
|
"compressed_db": false, |
||||
|
"checkpoint": "/pretrain/alltask_ot_alldata.pt", |
||||
|
"output_dir": "/storage/finetune/itm/coco_ot_alldata_base_hnv2", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8, |
||||
|
"negative_size": 399, |
||||
|
"hard_neg_size": 31, |
||||
|
"inf_minibatch_size": 400, |
||||
|
"margin": 0.2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 5000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [ |
||||
|
0.9, |
||||
|
0.98 |
||||
|
], |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 500, |
||||
|
"seed": 42, |
||||
|
"full_val": true, |
||||
|
"fp16": true, |
||||
|
"n_workers": 4, |
||||
|
"pin_mem": true, |
||||
|
"train_txt_dbs": [ |
||||
|
"/db/itm_coco_train_base-cased.db", |
||||
|
"/db/itm_coco_restval_base-cased.db" |
||||
|
], |
||||
|
"train_img_dbs": [ |
||||
|
"/img/coco_train2014/", |
||||
|
"/img/coco_val2014" |
||||
|
], |
||||
|
"val_txt_db": "/db/itm_coco_val_base-cased.db", |
||||
|
"val_img_db": "/img/coco_val2014/", |
||||
|
"test_txt_db": "/db/itm_coco_test_base-cased.db", |
||||
|
"test_img_db": "/img/coco_val2014/", |
||||
|
"model_config": "/src/config/uniter-base.json", |
||||
|
"rank": 0 |
||||
|
} |
@ -0,0 +1,45 @@ |
|||||
|
{ |
||||
|
"compressed_db": false, |
||||
|
"checkpoint": "/pretrain/bert-base-cased.pt", |
||||
|
"output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 80, |
||||
|
"gradient_accumulation_steps": 8, |
||||
|
"negative_size": 1, |
||||
|
"hard_neg_size": 0, |
||||
|
"inf_minibatch_size": 400, |
||||
|
"margin": 0.2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"warmup_steps": 100, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 1000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [ |
||||
|
0.9, |
||||
|
0.98 |
||||
|
], |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"seed": 42, |
||||
|
"full_val": false, |
||||
|
"fp16": true, |
||||
|
"n_workers": 4, |
||||
|
"pin_mem": true, |
||||
|
"train_txt_dbs": [ |
||||
|
"/db/itm_flickr30k_train_base-cased.db" |
||||
|
], |
||||
|
"train_img_dbs": [ |
||||
|
"/img/flickr30k/" |
||||
|
], |
||||
|
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db", |
||||
|
"val_img_db": "/img/flickr30k/", |
||||
|
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db", |
||||
|
"test_img_db": "/img/flickr30k/", |
||||
|
"model_config": "./config/uniter-base.json" |
||||
|
} |
@ -0,0 +1,45 @@ |
|||||
|
{ |
||||
|
"compressed_db": false, |
||||
|
"checkpoint": "/pretrain/alltask_ot_alldata.pt", |
||||
|
"output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug", |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 40, |
||||
|
"gradient_accumulation_steps": 4, |
||||
|
"negative_size": 1, |
||||
|
"hard_neg_size": 0, |
||||
|
"inf_minibatch_size": 400, |
||||
|
"margin": 0.2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 500, |
||||
|
"num_train_steps": 5000, |
||||
|
"optim": "adamw", |
||||
|
"betas": [ |
||||
|
0.9, |
||||
|
0.98 |
||||
|
], |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 500, |
||||
|
"seed": 42, |
||||
|
"full_val": false, |
||||
|
"fp16": true, |
||||
|
"n_workers": 4, |
||||
|
"pin_mem": true, |
||||
|
"train_txt_dbs": [ |
||||
|
"/db/itm_flickr30k_train_base-cased.db" |
||||
|
], |
||||
|
"train_img_dbs": [ |
||||
|
"/img/flickr30k/" |
||||
|
], |
||||
|
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db", |
||||
|
"val_img_db": "/img/flickr30k/", |
||||
|
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db", |
||||
|
"test_img_db": "/img/flickr30k/", |
||||
|
"model_config": "./config/uniter-base.json" |
||||
|
} |
@ -0,0 +1,42 @@ |
|||||
|
{ |
||||
|
"train_datasets": [ |
||||
|
{"name": "gqa", |
||||
|
"db": ["/db/pretrain_gqa_train_0_large-cased.db", "/db/pretrain_gqa_train_1_base-cased.db", |
||||
|
"/db/pretrain_gqa_train_2_base-cased.db", "/db/pretrain_gqa_train_3_base-cased.db", |
||||
|
"/db/pretrain_gqa_train_4_base-cased.db", "/db/pretrain_gqa_train_5_base-cased.db", |
||||
|
"/db/pretrain_gqa_train_6_base-cased.db", "/db/pretrain_gqa_train_7_base-cased.db", |
||||
|
"/db/pretrain_gqa_train_8_base-cased.db", "/db/pretrain_gqa_train_9_base-cased.db", |
||||
|
"/db/pretrain_gqa_val_base-cased.db"], |
||||
|
"img": ["/img/gqa/"], |
||||
|
"tasks": ["mlm", "mrm", "mrckl"], |
||||
|
"mix_ratio": [2, 1, 1]} |
||||
|
], |
||||
|
"val_datasets": [ |
||||
|
{"name": "gqa", |
||||
|
"db": ["/db/pretrain_gqa_testdev_balanced_base-cased.db"], |
||||
|
"img": ["/img/gqa/"], |
||||
|
"tasks": ["mlm", "mrm", "mrckl"]} |
||||
|
], |
||||
|
"checkpoint": "/pretrain/bert-large_weak_alldata/ckpt/model_step_100000.pt", |
||||
|
"output_dir": "/storage/pretrain_gqa/bert_large_weak_alldata_100k-train_val_all-mlm_mrm_mrckl-train_batch_size_6144-500k_steps", |
||||
|
"mrm_prob": 0.15, |
||||
|
"max_txt_len": 220, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 6144, |
||||
|
"val_batch_size": 8000, |
||||
|
"gradient_accumulation_steps": 10, |
||||
|
"learning_rate": 3e-05, |
||||
|
"valid_steps": 10000, |
||||
|
"num_train_steps": 500000, |
||||
|
"optim": "adamw", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": -1, |
||||
|
"warmup_steps": 50000, |
||||
|
"seed": 42, |
||||
|
"fp16": true |
||||
|
} |
@ -0,0 +1,42 @@ |
|||||
|
{ |
||||
|
"train_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_train_base-cased.db/", |
||||
|
"/db/pretrain_caption_coco_trainval_base-cased.db/"], |
||||
|
"img": ["/img/coco_train2014/", "/img/coco_val2014/"], |
||||
|
"tasks": ["mlm"], |
||||
|
"mix_ratio": [1]} |
||||
|
], |
||||
|
"val_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"], |
||||
|
"img": ["/img/coco_val2014/"], |
||||
|
"tasks": ["mlm"]} |
||||
|
], |
||||
|
"output_dir": "/storage/pretrain/mlm_coco", |
||||
|
"mrm_prob": 0.15, |
||||
|
"neg_size": 1024, |
||||
|
"itm_neg_prob": 0.5, |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8192, |
||||
|
"val_batch_size": 8192, |
||||
|
"gradient_accumulation_steps": 2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 5000, |
||||
|
"num_train_steps": 100000, |
||||
|
"optim": "adamw", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 10000, |
||||
|
"seed": 42, |
||||
|
"fp16": true, |
||||
|
"pin_mem": true, |
||||
|
"n_workers": 4, |
||||
|
"from_scratch": false |
||||
|
} |
@ -0,0 +1,53 @@ |
|||||
|
{ |
||||
|
"train_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_train_base-cased.db/", |
||||
|
"/db/pretrain_caption_coco_trainval_base-cased.db/"], |
||||
|
"img": ["/img/coco_train2014/", "/img/coco_val2014/"], |
||||
|
"tasks": ["itm", "mlm", "mrfr", "mrckl"], |
||||
|
"mix_ratio": [2, 2, 1, 1]}, |
||||
|
{"name": "vg_cap", |
||||
|
"db": ["/db/pretrain_caption_vg_train_base-cased.db/"], |
||||
|
"img": ["/img/vg/"], |
||||
|
"tasks": ["itm", "mlm", "mrfr", "mrckl"], |
||||
|
"mix_ratio": [2, 2, 1, 1]} |
||||
|
], |
||||
|
"val_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"], |
||||
|
"img": ["/img/coco_val2014/"], |
||||
|
"tasks": ["itm", "mlm", "mrfr", "mrckl"]}, |
||||
|
{"name": "vg_cap", |
||||
|
"db": ["/db/pretrain_caption_vg_val_base-cased.db/"], |
||||
|
"img": ["/img/vg/"], |
||||
|
"tasks": ["itm", "mlm", "mrfr", "mrckl"]} |
||||
|
], |
||||
|
"model_config": "/src/config/uniter-base.json", |
||||
|
"checkpoint": "/pretrain/bert-base-cased.pt", |
||||
|
"output_dir": "/storage/pretrain/alltask_ot_indomain_base", |
||||
|
"ans2label": "/db/pretrain_ans2label.pkl", |
||||
|
"mrm_prob": 0.15, |
||||
|
"itm_neg_prob": 0.5, |
||||
|
"itm_ot_lambda": 0.1, |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 10240, |
||||
|
"val_batch_size": 10240, |
||||
|
"gradient_accumulation_steps": 2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 5000, |
||||
|
"num_train_steps": 200000, |
||||
|
"optim": "adamw", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 5.0, |
||||
|
"warmup_steps": 10000, |
||||
|
"seed": 42, |
||||
|
"fp16": true, |
||||
|
"pin_mem": true, |
||||
|
"n_workers": 4 |
||||
|
} |
@ -0,0 +1,42 @@ |
|||||
|
{ |
||||
|
"train_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_train_base-cased.db/", |
||||
|
"/db/pretrain_caption_coco_trainval_base-cased.db/"], |
||||
|
"img": ["/img/coco_train2014/", "/img/coco_val2014/"], |
||||
|
"tasks": ["mrckl"], |
||||
|
"mix_ratio": [1]} |
||||
|
], |
||||
|
"val_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"], |
||||
|
"img": ["/img/coco_val2014/"], |
||||
|
"tasks": ["mrckl"]} |
||||
|
], |
||||
|
"output_dir": "/storage/pretrain/mrckl_coco", |
||||
|
"mrm_prob": 0.15, |
||||
|
"neg_size": 1024, |
||||
|
"itm_neg_prob": 0.5, |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8192, |
||||
|
"val_batch_size": 8192, |
||||
|
"gradient_accumulation_steps": 2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 5000, |
||||
|
"num_train_steps": 100000, |
||||
|
"optim": "adamw", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 10000, |
||||
|
"seed": 42, |
||||
|
"fp16": true, |
||||
|
"pin_mem": true, |
||||
|
"n_workers": 4, |
||||
|
"from_scratch": false |
||||
|
} |
@ -0,0 +1,42 @@ |
|||||
|
{ |
||||
|
"train_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_train_base-cased.db/", |
||||
|
"/db/pretrain_caption_coco_trainval_base-cased.db/"], |
||||
|
"img": ["/img/coco_train2014/", "/img/coco_val2014/"], |
||||
|
"tasks": ["mrfr"], |
||||
|
"mix_ratio": [1]} |
||||
|
], |
||||
|
"val_datasets": [ |
||||
|
{"name": "coco_cap", |
||||
|
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"], |
||||
|
"img": ["/img/coco_val2014/"], |
||||
|
"tasks": ["mrfr"]} |
||||
|
], |
||||
|
"output_dir": "/storage/pretrain/mrfr_coco", |
||||
|
"mrm_prob": 0.15, |
||||
|
"neg_size": 1024, |
||||
|
"itm_neg_prob": 0.5, |
||||
|
"max_txt_len": 60, |
||||
|
"conf_th": 0.2, |
||||
|
"max_bb": 100, |
||||
|
"min_bb": 10, |
||||
|
"num_bb": 36, |
||||
|
"train_batch_size": 8192, |
||||
|
"val_batch_size": 8192, |
||||
|
"gradient_accumulation_steps": 2, |
||||
|
"learning_rate": 5e-05, |
||||
|
"valid_steps": 5000, |
||||
|
"num_train_steps": 100000, |
||||
|
"optim": "adamw", |
||||
|
"decay": "linear", |
||||
|
"dropout": 0.1, |
||||
|
"weight_decay": 0.01, |
||||
|
"grad_norm": 2.0, |
||||
|
"warmup_steps": 10000, |
||||
|
"seed": 42, |
||||
|
"fp16": true, |
||||
|
"pin_mem": true, |
||||
|
"n_workers": 4, |
||||
|
"from_scratch": false |
||||
|
} |
@ -0,0 +1,43 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["mrm-nce"],
"mix_ratio": [1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["mrm-nce"]}
],
"output_dir": "/storage/pretrain/mrm_nce_coco",
"mrm_prob": 0.15,
"neg_size": 1024,
"nce_temp": 1.0,
"itm_neg_prob": 0.5,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 100000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4,
"from_scratch": false
}
@ -0,0 +1,38 @@
{
"train_datasets": [
{"name": "vcr",
"db": ["/db/vcr_val_w_obj_ids_base-cased.db/"],
"img": ["/img/vcr_val/;/img/vcr_gt_val/"],
"tasks": ["mlm", "mrm", "mrckl"],
"mix_ratio": [2, 1, 1]}
],
"val_datasets": [
{"name": "vcr",
"db": ["/db/vcr_val_w_obj_ids_base-cased.db/"],
"img": ["/img/vcr_val/;/img/vcr_gt_val/"],
"tasks": ["mlm", "mrm", "mrckl"]}
],
"checkpoint": "/pretrain/bert-base_weak_w_mlm_itm_mrm_mrckl_4gpu/ckpt/model_step_500000.pt",
"vcr_task": ["qa", "qar"],
"output_dir": "/storage/debug/mlm_mrm_mrckl-qa_qar-gt_det",
"mrm_prob": 0.15,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8000,
"val_batch_size": 8000,
"gradient_accumulation_steps": 5,
"learning_rate": 3e-05,
"valid_steps": 10,
"num_train_steps": 120000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": -1,
"warmup_steps": 12000,
"seed": 42,
"fp16": true
}
@ -0,0 +1,40 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_val_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/debug/itm/flickr_default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 20,
"negative_size": 1,
"hard_neg_size": 1,
"hard_neg_pool_size": 20,
"steps_per_hard_neg": 30,
"inf_minibatch_size": 40,
"gradient_accumulation_steps": 2,
"learning_rate": 1e-05,
"valid_steps": 40,
"num_train_steps": 50,
"optim": "adamw",
"betas": [0.9, 0.98],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 1600,
"seed": 42,
"fp16": true,
"n_workers": 0,
"pin_mem": true
}
@ -0,0 +1,38 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/itm/flickr_ot_alldata_base_hnv2",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8,
"negative_size": 399,
"hard_neg_size": 31,
"inf_minibatch_size": 400,
"learning_rate": 5e-05,
"valid_steps": 500,
"num_train_steps": 5000,
"optim": "adamw",
"betas": [0.9, 0.98],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 500,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true,
"full_val": true
}
@ -0,0 +1,40 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/itm/flickr_ot_alldata_base",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 40,
"negative_size": 1,
"hard_neg_size": 0,
"hard_neg_pool_size": 20,
"steps_per_hard_neg": -1,
"inf_minibatch_size": 512,
"gradient_accumulation_steps": 4,
"learning_rate": 5e-05,
"valid_steps": 2000,
"num_train_steps": 20000,
"optim": "adamw",
"betas": [0.9, 0.98],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 2000,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}
@ -0,0 +1,37 @@
{
"train_txt_db": "/db/nlvr2_train_base-cased.db",
"train_img_db": "/img/nlvr2_train/",
"val_txt_db": "/db/nlvr2_dev_base-cased.db",
"val_img_db": "/img/nlvr2_dev/",
"test_txt_db": "/db/nlvr2_test1_base-cased.db",
"test_img_db": "/img/nlvr2_test/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"model": "paired-attn",
"use_img_type": true,
"output_dir": "/storage/nlvr2/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 1,
"learning_rate": 3e-05,
"valid_steps": 500,
"num_train_steps": 8000,
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 800,
"seed": 77,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}
@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_base-cased.db/",
"train_img_db": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_base-cased.db/",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_base-cased.db/",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/ve/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 4096,
"val_batch_size": 4096,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}
@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_large-cased.db/",
"train_img_db": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_large-cased.db/",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_large-cased.db/",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata_large.pt",
"model_config": "/src/config/uniter-large.json",
"output_dir": "/storage/finetune/ve/default_large",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 4096,
"val_batch_size": 4096,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}
@ -0,0 +1,35 @@
{
"train_txt_dbs": ["/db/vqa_train_base-cased.db",
"/db/vqa_trainval_base-cased.db",
"/db/vqa_vg_base-cased.db"],
"train_img_dbs": ["/img/coco_train2014/", "/img/coco_val2014", "/img/vg/"],
"val_txt_db": "/db/vqa_devval_base-cased.db",
"val_img_db": "/img/coco_val2014/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/vqa/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 5,
"learning_rate": 8e-05,
"valid_steps": 500,
"num_train_steps": 6000,
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 600,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}
@ -0,0 +1,14 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"num_hidden_layers_img": 1,
"type_vocab_size": 2,
"vocab_size": 28996
}
@ -0,0 +1,13 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 2,
"vocab_size": 28996
}
@ -0,0 +1,27 @@
#from .data import (TxtTokLmdb, DetectFeatLmdb,
#                   ConcatDatasetWithLens, ImageLmdbGroup)
#from .mlm import (MlmDataset, MlmEvalDataset,
#                  BlindMlmDataset, BlindMlmEvalDataset,
#                  mlm_collate, mlm_eval_collate,
#                  mlm_blind_collate, mlm_blind_eval_collate)
#from .mrm import (MrfrDataset, OnlyImgMrfrDataset,
#                  MrcDataset, OnlyImgMrcDataset,
#                  mrfr_collate, mrfr_only_img_collate,
#                  mrc_collate, mrc_only_img_collate)
from .itm import (TokenBucketSamplerForItm,
                  ItmDataset, itm_collate, itm_ot_collate,
                  ItmRankDataset, ItmRankDatasetHardNeg, itm_rank_collate,
                  ItmRankDatasetHardNegFromText,
                  ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate,
                  ItmHardNegDataset, itm_hn_collate,
                  ItmValDataset, itm_val_collate,
                  ItmEvalDataset, itm_eval_collate)
from .sampler import TokenBucketSampler, DistributedSampler
from .loader import MetaLoader, PrefetchLoader

from .vqa import VqaDataset, vqa_collate, VqaEvalDataset, vqa_eval_collate
from .nlvr2 import (Nlvr2PairedDataset, nlvr2_paired_collate,
                    Nlvr2PairedEvalDataset, nlvr2_paired_eval_collate,
                    Nlvr2TripletDataset, nlvr2_triplet_collate,
                    Nlvr2TripletEvalDataset, nlvr2_triplet_eval_collate)
from .ve import VeDataset, ve_collate, VeEvalDataset, ve_eval_collate
@ -0,0 +1,283 @@
"""
Dataset interfaces
"""
from collections import defaultdict
from contextlib import contextmanager
import io
import json
import lmdb
from os.path import exists

import numpy as np
import torch
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from lz4.frame import compress, decompress

import msgpack
import msgpack_numpy
msgpack_numpy.patch()


def _check_distributed():
    # NOTE: this helper was missing from the copied file; a minimal stand-in
    # (an assumption) that reports whether torch.distributed is active, so
    # LMDB readahead is only enabled for single-process runs
    return (torch.distributed.is_available()
            and torch.distributed.is_initialized())


def _fp16_to_fp32(feat_dict):
    # promote any fp16 arrays (stored to save disk) back to fp32 for compute
    out = {k: arr.astype(np.float32)
           if arr.dtype == np.float16 else arr
           for k, arr in feat_dict.items()}
    return out


def compute_num_bb(confs, conf_th, min_bb, max_bb):
    # number of boxes above the confidence threshold, clamped to [min_bb, max_bb]
    num_bb = max(min_bb, (confs > conf_th).sum())
    num_bb = min(max_bb, num_bb)
    return num_bb
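
# A quick worked check of the clamping above, using the thresholds every
# config in this change ships with (conf_th 0.2, min_bb 10, max_bb 100):
def _demo_compute_num_bb():
    confs = np.array([0.9, 0.8, 0.35, 0.15, 0.05])
    # only 3 boxes clear the 0.2 threshold, but min_bb forces at least 10
    assert compute_num_bb(confs, 0.2, 10, 100) == 10
    # with 120 confident boxes, max_bb caps the count at 100
    assert compute_num_bb(np.full(120, 0.9), 0.2, 10, 100) == 100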


class DetectFeatLmdb(object):
    def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
                 compress=True):
        self.img_dir = img_dir
        # keep the box-selection parameters: _compute_nbb needs them when the
        # nbb json has not been pre-computed (the original copy never stored
        # them, which made that code path crash)
        self.conf_th = conf_th
        self.max_bb = max_bb
        self.min_bb = min_bb
        if conf_th == -1:
            db_name = f'feat_numbb{num_bb}'
            self.name2nbb = defaultdict(lambda: num_bb)
        else:
            db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}'
            nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json'
            if not exists(f'{img_dir}/{nbb}'):
                # nbb is not pre-computed
                self.name2nbb = None
            else:
                self.name2nbb = json.load(open(f'{img_dir}/{nbb}'))
        self.compress = compress
        if compress:
            db_name += '_compressed'

        if self.name2nbb is None:
            if compress:
                db_name = 'all_compressed'
            else:
                db_name = 'all'
        # only read ahead on single node training
        self.env = lmdb.open(f'{img_dir}/{db_name}',
                             readonly=True, create=False,
                             readahead=not _check_distributed())
        self.txn = self.env.begin(buffers=True)
        if self.name2nbb is None:
            self.name2nbb = self._compute_nbb()

    def _compute_nbb(self):
        name2nbb = {}
        fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8'))
        for fname in tqdm(fnames, desc='reading images'):
            dump = self.txn.get(fname.encode('utf-8'))
            if self.compress:
                with io.BytesIO(dump) as reader:
                    img_dump = np.load(reader, allow_pickle=True)
                    confs = img_dump['conf']
            else:
                img_dump = msgpack.loads(dump, raw=False)
                confs = img_dump['conf']
            name2nbb[fname] = compute_num_bb(confs, self.conf_th,
                                             self.min_bb, self.max_bb)

        return name2nbb

    def __del__(self):
        self.env.close()

    def get_dump(self, file_name):
        # hack for MRC
        dump = self.txn.get(file_name.encode('utf-8'))
        nbb = self.name2nbb[file_name]
        if self.compress:
            with io.BytesIO(dump) as reader:
                img_dump = np.load(reader, allow_pickle=True)
                img_dump = _fp16_to_fp32(img_dump)
        else:
            img_dump = msgpack.loads(dump, raw=False)
            img_dump = _fp16_to_fp32(img_dump)
        img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()}
        return img_dump

    def __getitem__(self, file_name):
        dump = self.txn.get(file_name.encode('utf-8'))
        nbb = self.name2nbb[file_name]
        if self.compress:
            with io.BytesIO(dump) as reader:
                img_dump = np.load(reader, allow_pickle=True)
                img_dump = {'features': img_dump['features'],
                            'norm_bb': img_dump['norm_bb']}
        else:
            img_dump = msgpack.loads(dump, raw=False)
        img_feat = torch.tensor(img_dump['features'][:nbb, :]).float()
        img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float()
        return img_feat, img_bb

    def __contains__(self, file_name):
        return self.txn.get(file_name.encode('utf-8')) is not None
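
# Typical use of the feature DB above (a sketch: the directory and the key
# are illustrative; keys are the per-image feature file names produced by
# the detector):
def _demo_detect_feat_lmdb():
    img_db = DetectFeatLmdb('/img/coco_val2014/', conf_th=0.2, max_bb=100,
                            min_bb=10, num_bb=36, compress=True)
    feat, bb = img_db['coco_val2014_000000184613.npz']
    # feat: (nbb, feat_dim) float32; bb: normalized box coordinates, to which
    # _get_img_feat below appends w*h as a 7th position feature
    return feat.shape, bb.shape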


@contextmanager
def open_lmdb(db_dir, readonly=False):
    db = TxtLmdb(db_dir, readonly)
    try:
        yield db
    finally:
        del db


class TxtLmdb(object):
    def __init__(self, db_dir, readonly=True):
        self.readonly = readonly
        if readonly:
            # training
            self.env = lmdb.open(db_dir,
                                 readonly=True, create=False,
                                 readahead=not _check_distributed())
            self.txn = self.env.begin(buffers=True)
            self.write_cnt = None
        else:
            # prepro
            self.env = lmdb.open(db_dir, readonly=False, create=True,
                                 map_size=4 * 1024**4)
            self.txn = self.env.begin(write=True)
            self.write_cnt = 0

    def __del__(self):
        if self.write_cnt:
            self.txn.commit()
        self.env.close()

    def __getitem__(self, key):
        return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))),
                             raw=False)

    def __setitem__(self, key, value):
        # NOTE: not thread safe
        if self.readonly:
            raise ValueError('readonly text DB')
        ret = self.txn.put(key.encode('utf-8'),
                           compress(msgpack.dumps(value, use_bin_type=True)))
        self.write_cnt += 1
        if self.write_cnt % 1000 == 0:
            # commit every 1000 writes to bound the open transaction
            self.txn.commit()
            self.txn = self.env.begin(write=True)
            self.write_cnt = 0
        return ret
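
# How preprocessing is expected to fill a text DB: values are msgpack-packed
# and lz4-compressed. A sketch with an illustrative path; note the explicit
# `del`: the commit in __del__ only runs once the last reference is gone,
# and the context manager drops only its own reference.
def _demo_txt_lmdb():
    with open_lmdb('/tmp/demo_txt.db', readonly=False) as wdb:
        wdb['sample_0'] = {'input_ids': [101, 2023, 102],
                           'img_fname': 'img_0.npz'}
    del wdb
    rdb = TxtLmdb('/tmp/demo_txt.db', readonly=True)
    assert rdb['sample_0']['input_ids'] == [101, 2023, 102]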


def get_ids_and_lens(db):
    # TxtTokLmdb is the tokenized-text wrapper (ids, id2len, txt2img,
    # combine_inputs, ...) used throughout; its definition is not shown in
    # this hunk
    assert isinstance(db, TxtTokLmdb)
    lens = []
    ids = []
    for id_ in db.ids:
        lens.append(db.id2len[id_])
        ids.append(id_)
    return lens, ids


class DetectFeatTxtTokDataset(Dataset):
    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)
        self.txt_db = txt_db
        self.img_db = img_db
        txt_lens, self.ids = get_ids_and_lens(txt_db)

        txt2img = txt_db.txt2img
        self.lens = [tl + self.img_db.name2nbb[txt2img[id_]]
                     for tl, id_ in zip(txt_lens, self.ids)]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        example = self.txt_db[id_]
        return example

    def _get_img_feat(self, fname):
        img_feat, bb = self.img_db[fname]
        # append box area (w * h) as an extra position feature
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb


class ConcatDatasetWithLens(ConcatDataset):
    """ A thin wrapper on pytorch concat dataset for lens batching """
    def __init__(self, datasets):
        super().__init__(datasets)
        self.lens = [l for dset in datasets for l in dset.lens]

    def __getattr__(self, name):
        return self._run_method_on_all_dsets(name)

    def _run_method_on_all_dsets(self, name):
        def run_all(*args, **kwargs):
            return [dset.__getattribute__(name)(*args, **kwargs)
                    for dset in self.datasets]
        return run_all


def pad_tensors(tensors, lens=None, pad=0):
    """B x [T, ...]"""
    if lens is None:
        lens = [t.size(0) for t in tensors]
    max_len = max(lens)
    bs = len(tensors)
    hid = tensors[0].size(-1)
    dtype = tensors[0].dtype
    output = torch.zeros(bs, max_len, hid, dtype=dtype)
    if pad:
        output.data.fill_(pad)
    for i, (t, l) in enumerate(zip(tensors, lens)):
        output.data[i, :l, ...] = t.data
    return output
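
# pad_tensors right-pads a list of variable-length (T, d) tensors into one
# (B, max_T, d) batch; the collate functions in this package use it for
# image features. A tiny check:
def _demo_pad_tensors():
    a = torch.ones(2, 4)         # 2 boxes, 4-d features
    b = torch.ones(3, 4) * 2     # 3 boxes
    out = pad_tensors([a, b])
    assert out.shape == (2, 3, 4)
    assert out[0, 2].sum() == 0  # beyond a's 2 boxes stays zero-padded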


def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
    # two-stream variant used in this repo: text and image are encoded
    # separately, so the per-sample reordering is disabled and an identity
    # index is returned
    # assert len(txt_lens) == len(num_bbs) == batch_size
    gather_index = torch.arange(0, out_size, dtype=torch.long,
                                ).unsqueeze(0).repeat(len(num_bbs), 1)

    # for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
    #     gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
    #                                                    dtype=torch.long).data
    return gather_index


def get_gather_index_uniter(txt_lens, num_bbs, batch_size, max_len, out_size):
    # single-stream (UNITER) layout: each sample's text tokens come first,
    # then its image boxes, pulled from offset max_len in the concatenation
    assert len(txt_lens) == len(num_bbs) == batch_size
    gather_index = torch.arange(0, out_size, dtype=torch.long,
                                ).unsqueeze(0).repeat(batch_size, 1)

    for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
        gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
                                                       dtype=torch.long).data
    return gather_index


def get_gather_index_img(txt_lens, num_bbs, batch_size, max_len, out_size):
    # image-first layout: boxes first, then the text tokens
    gather_index = torch.zeros(batch_size, out_size, dtype=torch.long)

    for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
        gather_index.data[i, :nbb] = torch.arange(max_len, max_len+nbb,
                                                  dtype=torch.long).data
        gather_index.data[i, nbb:nbb+tl] = torch.arange(0, tl,
                                                        dtype=torch.long).data
    return gather_index
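
# A concrete picture of the single-stream layout: with txt_lens [3, 2] and
# num_bbs [2, 3] padded to max_len 3 (so out_size 5), sample 1 skips its
# text-padding slot and pulls its boxes from offset max_len onwards:
def _demo_gather_index_uniter():
    gi = get_gather_index_uniter(txt_lens=[3, 2], num_bbs=[2, 3],
                                 batch_size=2, max_len=3, out_size=5)
    # sample 0: 3 tokens then 2 boxes; sample 1: 2 tokens then 3 boxes
    assert gi.tolist() == [[0, 1, 2, 3, 4], [0, 1, 3, 4, 5]]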


class ImageLmdbGroup(object):
    def __init__(self, conf_th, max_bb, min_bb, num_bb, compress):
        self.path2imgdb = {}
        self.conf_th = conf_th
        self.max_bb = max_bb
        self.min_bb = min_bb
        self.num_bb = num_bb
        self.compress = compress

    def __getitem__(self, path):
        img_db = self.path2imgdb.get(path, None)
        if img_db is None:
            img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb,
                                    self.min_bb, self.num_bb, self.compress)
            # cache the handle; the original copy never stored it, so every
            # lookup re-opened the LMDB environment
            self.path2imgdb[path] = img_db
        return img_db
@ -0,0 +1,572 @@
"""
Itm dataset
"""
from collections import defaultdict
import copy
import json
import random

import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from toolz.sandbox import unzip
from cytoolz import concat

from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
                   pad_tensors, get_gather_index, get_ids_and_lens)
from .sampler import TokenBucketSampler


class TokenBucketSamplerForItm(TokenBucketSampler):
    def __init__(self, dset, *args, **kwargs):
        super().__init__(dset.lens, *args, **kwargs)
        self.dset = dset

    def __iter__(self):
        it = super().__iter__()
        # resample labels/negatives each epoch and refresh the lengths
        self.dset.new_epoch()
        self._lens = self.dset.lens
        return it


def _has_overlap(la, lb):
    if len(la) < len(lb):
        la, lb = lb, la
    s = set(la)
    return any(b in s for b in lb)


def _sample_negative_rand(sample_pool, ground_truths, num_sample):
    """ random and retry """
    outputs = ground_truths[:1]
    while _has_overlap(outputs, ground_truths):
        outputs = random.sample(sample_pool, num_sample)
    return outputs


def _sample_negative_extra(sample_pool, ground_truths, num_sample):
    """ sample extra then remove """
    tot_size = len(ground_truths) + num_sample
    outputs = set(random.sample(sample_pool, tot_size))
    for gt in ground_truths:
        outputs.discard(gt)
    outputs = list(outputs)[:num_sample]
    return outputs


sample_negative = _sample_negative_rand  # switch between the two implementations
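
# Both samplers draw `num_sample` items from the pool while excluding the
# ground truths: the rand variant retries until a clean draw, the extra
# variant over-samples once and filters. For example:
def _demo_sample_negative():
    pool = [f'img_{i}' for i in range(100)]
    negs = sample_negative(pool, ['img_3'], 5)
    assert len(negs) == 5 and 'img_3' not in negs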


class ItmDataset(DetectFeatTxtTokDataset):
    """ NOTE this Dataset handles distributed training itself
    (for more efficient negative sampling) """
    def __init__(self, txt_db, img_db, neg_sample_p=0.5):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)

        self.txt_db = txt_db
        self.img_db = img_db

        self.txt_lens, self.ids = get_ids_and_lens(txt_db)
        self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))

        self.neg_sample_p = neg_sample_p
        self.new_epoch()

    def new_epoch(self):
        """ should be called every epoch for more randomness """
        self.labels = np.random.choice(
            [0, 1], size=len(self.ids),
            p=[self.neg_sample_p, 1-self.neg_sample_p])

        self.lens = []
        self.train_imgs = []
        for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
            img_fname = super().__getitem__(i)['img_fname']
            if self.labels[i] == 0:
                # negative pair: swap in a random non-matching image
                img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
            self.train_imgs.append(img_fname)
            self.lens.append(tl + self.img_db.name2nbb[img_fname])

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # labels and negative images should be sampled every epoch
        ground_truth_label = self.labels[i]
        img_fname = self.train_imgs[i]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
        target = torch.Tensor(1).long()
        target.data.fill_(ground_truth_label)

        return input_ids, img_feat, img_pos_feat, attn_masks, target


def itm_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets}
    return batch


def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
    ot_scatter = torch.arange(0, joint_len, dtype=torch.long
                              ).unsqueeze(0).repeat(len(txt_lens), 1)
    for i, tl in enumerate(txt_lens):
        max_ind = max_txt_len + (joint_len-tl)
        ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
                                               dtype=torch.long).data
    return ot_scatter


def _compute_pad(lens, max_len):
    pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
    for i, l in enumerate(lens):
        pad.data[i, l:].fill_(1)
    return pad
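
# _compute_pad marks the padded tail of each sequence; the optimal-transport
# (OT) loss uses these masks to ignore pad positions:
def _demo_compute_pad():
    pad = _compute_pad([2, 3], max_len=4)
    assert pad.tolist() == [[False, False, True, True],
                            [False, False, False, True]]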


def itm_ot_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    # OT inputs
    max_tl = max(txt_lens)
    max_nbb = max(num_bbs)
    ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
    txt_pad = _compute_pad(txt_lens, max_tl)
    img_pad = _compute_pad(num_bbs, max_nbb)
    ot_inputs = {'ot_scatter': ot_scatter,
                 'scatter_max': ot_scatter.max().item(),
                 'txt_pad': txt_pad,
                 'img_pad': img_pad}

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets,
             'ot_inputs': ot_inputs}
    return batch


class ItmRankDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDataset needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        # images partitioned by rank
        self.img2txts = defaultdict(list)
        for id_, img in self.txt2img.items():
            self.img2txts[img].append(id_)
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample negatives
        neg_sample_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        neg_sample_txt_ids = sample_negative(
            self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
        id_pairs.extend([(gt_txt_id, neg_img_id)
                         for neg_img_id in neg_sample_img_ids] +
                        [(neg_txt_id, gt_img_fname)
                         for neg_txt_id in neg_sample_txt_ids])
        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1 + 2*self.neg_sample_size)
        return inputs

    def _collect_inputs(self, id_pairs):
        # create input features
        inputs = []
        for txt_id, img_id in id_pairs:
            example = self.txt_db[txt_id]
            # text input
            input_ids = example['input_ids']
            input_ids = self.txt_db.combine_inputs(input_ids)
            # img input
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
            # mask
            attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
            attn_masks_img = torch.ones(num_bb, dtype=torch.long)

            inputs.append((input_ids, img_feat, img_pos_feat,
                           attn_masks_text, attn_masks_img))

        return inputs


class ItmRankDatasetHardNeg(ItmRankDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
        assert hard_neg_size > 0, \
            "ItmRankDatasetHardNeg needs at least 1 hard negative sample"
        DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())

        assert neg_sample_size > 0
        self.neg_sample_size = neg_sample_size
        self.hard_neg_size = hard_neg_size

    def reload_hard_negs(self, hard_neg_dir):
        # horovod is only needed here: hard negatives are mined per rank
        import horovod.torch as hvd
        self.txt2hardimgs = json.load(
            open(f'{hard_neg_dir}/'
                 f'txt2hardimgs_rank{hvd.rank()}.json'))
        self.img2hardtxts = json.load(
            open(f'{hard_neg_dir}/img2hardtxts.json'))

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        id_pairs = [(gt_txt_id, gt_img_fname)]
        # sample hard negatives
        if self.hard_neg_size > 0:
            hard_neg_img_samples = random.sample(
                self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
            hard_neg_txt_samples = random.sample(
                self.img2hardtxts[gt_img_fname], self.hard_neg_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in hard_neg_img_samples] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in hard_neg_txt_samples])
        # sample normal negatives
        if self.neg_sample_size > 0:
            neg_sample_img_ids = sample_negative(
                self.img_name_list, [gt_img_fname], self.neg_sample_size)
            neg_sample_txt_ids = sample_negative(
                self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
            id_pairs.extend([(gt_txt_id, neg_img_id)
                             for neg_img_id in neg_sample_img_ids] +
                            [(neg_txt_id, gt_img_fname)
                             for neg_txt_id in neg_sample_txt_ids])

        inputs = self._collect_inputs(id_pairs)
        assert len(inputs) == (1
                               + 2*self.neg_sample_size
                               + 2*self.hard_neg_size)
        return inputs


def itm_rank_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
     ) = map(list, unzip(concat(i for i in inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks_text = pad_sequence(attn_masks_text,
                                   batch_first=True, padding_value=0)
    attn_masks_img = pad_sequence(attn_masks_img,
                                  batch_first=True, padding_value=0)
    sample_size = len(inputs[0])
    assert all(sample_size == len(i) for i in inputs)

    bs, max_tl = input_ids.size()
    # out_size = attn_masks.size(1)
    gather_index = None  # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks_text': attn_masks_text,
             'attn_masks_img': attn_masks_img,
             'gather_index': gather_index,
             'sample_size': sample_size}
    return batch


class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromText needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.img_name_list = list(self.img2txts.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_fname = self.txt2img[gt_txt_id]

        input_ids = self.txt_db[gt_txt_id]['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        neg_img_ids = sample_negative(
            self.img_name_list, [gt_img_fname], self.neg_sample_size)
        img_ids = [gt_img_fname] + neg_img_ids
        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
        for i, nbb in enumerate(num_bbs):
            attn_masks.data[i, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
                                        len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, neg_sample_size=1):
        assert neg_sample_size > 0, \
            "ItmRankDatasetHardNegFromImage needs at least 1 negative sample"
        super().__init__(txt_db, img_db)

        txt2img = self.txt_db.txt2img
        self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
        self.img2txts = self.txt_db.img2txts
        self.txt_name_list = list(self.txt2img.keys())
        self.neg_sample_size = neg_sample_size

    def __getitem__(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]
        gt_txt_ids = self.img2txts[gt_img_id]

        # process image features (gt always first)
        img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
        img_feat = img_feat.unsqueeze(0)
        img_pos_feat = img_pos_feat.unsqueeze(0)

        # sample negative
        neg_txt_ids = sample_negative(
            self.txt_name_list, gt_txt_ids, self.neg_sample_size)
        txt_ids = [gt_txt_id] + neg_txt_ids

        # process text inputs
        all_inputs = []
        txt_lens = []
        for txt_id in txt_ids:
            input_ids = self.txt_db.combine_inputs(
                self.txt_db[txt_id]['input_ids'])
            all_inputs.append(input_ids)
            txt_lens.append(len(input_ids))
        input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
        for i, tl in enumerate(txt_lens):
            attn_masks.data[i, :tl+nbb].fill_(1)
        out_size = attn_masks.size(1)
        # pass the padded text length (the original passed the loop variable
        # tl, i.e. the last sample's length; harmless here because this
        # repo's get_gather_index ignores max_len, but corrected for clarity)
        gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
                                        len(txt_ids), max(txt_lens), out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks': attn_masks,
                 'gather_index': gather_index}
        return batch


def itm_rank_hnv2_collate(inputs):
    assert len(inputs) == 1
    return inputs[0]


class ItmValDataset(DetectFeatTxtTokDataset):
    """ For evaluating Image-Text-Retrieval task """
    def __init__(self, db_dir, img_dir, mini_batch_size=400):
        super().__init__(db_dir, img_dir)
        del self.lens
        self.txt2img = self.txt_db.txt2img
        self.img2txts = self.txt_db.img2txts
        self.all_img_ids = list(self.img2txts.keys())

        assert len(self.img2txts) >= mini_batch_size > 0
        self.bs = mini_batch_size

    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample fixed negatives for each gt image
        i = self.all_img_ids.index(gt_img_id)
        neg_st = i+1
        neg_end = neg_st+self.bs-1
        if neg_end > len(self.all_img_ids):
            # wrap around
            neg_end -= len(self.all_img_ids)
            neg_img_ids = (self.all_img_ids[neg_st:]
                           + self.all_img_ids[:neg_end])
        else:
            neg_img_ids = self.all_img_ids[neg_st:neg_end]

        assert len(neg_img_ids) == (self.bs - 1),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids
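
    # The negatives are deterministic: the (bs - 1) images that follow the
    # ground-truth image in all_img_ids, wrapping past the end of the list,
    # so a given text is always scored against the same window. In miniature,
    # with all_img_ids = ['a', 'b', 'c', 'd', 'e'], bs = 4 and gt image 'd'
    # (index 3): neg_st, neg_end = 4, 7; 7 > 5, so neg_end wraps to 2, giving
    # ids[4:] + ids[:2] == ['e', 'a', 'b'], exactly bs - 1 negatives.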

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        gt_img_id, neg_img_ids = self._get_batch_ids(i)
        # NOTE 1st one is gt img
        batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
        return batch

    def get_batch(self, i, img_ids):
        example = super().__getitem__(i)

        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                    ).unsqueeze(0)

        # process image features (gt always first)
        img_feats, img_pos_feats, num_bbs = map(
            list, unzip(map(self._get_img_feat, img_ids)))
        img_feat = pad_tensors(img_feats, num_bbs)
        img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

        tl = input_ids.size(1)
        attn_masks_text = torch.ones(len(img_ids), tl).long()
        # attn_masks_text = torch.ones(1, tl).long()
        attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
        for i, nbb in enumerate(num_bbs):
            attn_masks_img.data[i, :nbb].fill_(1)

        # out_size = attn_masks.size(1)
        gather_index = None  # get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)

        batch = {'input_ids': input_ids,
                 'position_ids': position_ids,
                 'img_feat': img_feat,
                 'img_pos_feat': img_pos_feat,
                 'attn_masks_text': attn_masks_text,
                 'attn_masks_img': attn_masks_img,
                 'gather_index': gather_index}
        return batch


def itm_val_collate(inputs):
    assert len(inputs) == 1, "input batch size > 1"
    return inputs[0]


class ItmHardNegDataset(ItmValDataset):
    def _get_batch_ids(self, i):
        gt_txt_id = self.ids[i]
        gt_img_id = self.txt2img[gt_txt_id]

        # sample random negatives for each gt image (unlike the parent
        # class, which uses a fixed window)
        i = self.all_img_ids.index(gt_img_id)
        all_img_ids = copy.deepcopy(self.all_img_ids)
        all_img_ids.remove(gt_img_id)
        random.shuffle(all_img_ids)
        neg_img_ids = all_img_ids[:self.bs]

        assert len(neg_img_ids) == (self.bs),\
            "Did not sample enough neg samples"

        return gt_img_id, neg_img_ids

    def __getitem__(self, i):
        """ this returns list of mini-batches """
        _, neg_img_ids = self._get_batch_ids(i)
        batch = self.get_batch(i, neg_img_ids)
        batch['gt_txt_id'] = self.ids[i]
        batch['neg_img_ids'] = neg_img_ids
        return batch


itm_hn_collate = itm_val_collate


class ItmEvalDataset(ItmValDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # sort by box count so mini-batches have similar padded sizes
        self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
                                  key=lambda i: self.img_db.name2nbb[i])

    def __getitem__(self, i):
        mini_batches = []
        for st in range(0, len(self.all_img_ids), self.bs):
            mini_batches.append(
                self.get_batch(i, self.all_img_ids[st:st+self.bs]))
        return mini_batches


itm_eval_collate = itm_val_collate
@ -0,0 +1,138 @@
"""
A meta data loader for sampling from different datasets / training tasks

A prefetch loader to speed up data loading
"""
import random

import torch
from torch.utils.data import DataLoader

from uniter_model.utils.distributed import any_broadcast


class MetaLoader(object):
    """ wraps multiple data loaders """
    def __init__(self, loaders, accum_steps=1, distributed=False):
        assert isinstance(loaders, dict)
        self.name2loader = {}
        self.name2iter = {}
        self.sampling_pools = []
        for n, l in loaders.items():
            if isinstance(l, tuple):
                l, r = l
            elif isinstance(l, DataLoader):
                r = 1
            else:
                raise ValueError()
            self.name2loader[n] = l
            self.name2iter[n] = iter(l)
            # a task with ratio r appears r times in the sampling pool
            self.sampling_pools.extend([n]*r)

        self.accum_steps = accum_steps
        self.distributed = distributed
        self.step = 0

    def __iter__(self):
        """ this iterator will run indefinitely """
        task = self.sampling_pools[0]
        while True:
            if self.step % self.accum_steps == 0:
                task = random.choice(self.sampling_pools)
                if self.distributed:
                    # make sure all processes are training the same task
                    task = any_broadcast(task, 0)
            self.step += 1
            iter_ = self.name2iter[task]
            try:
                batch = next(iter_)
            except StopIteration:
                iter_ = iter(self.name2loader[task])
                batch = next(iter_)
                self.name2iter[task] = iter_

            yield task, batch
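
# This is where the `mix_ratio` values in the pretraining configs land: each
# (loader, ratio) pair contributes `ratio` copies of its task name to the
# sampling pool, so tasks are drawn proportionally, held fixed across
# gradient-accumulation steps. A sketch with stand-in loaders over toy data
# (real ones wrap the task datasets from this package):
def _demo_meta_loader():
    mlm = DataLoader(list(range(10)), batch_size=2)
    mrm = DataLoader(list(range(10)), batch_size=2)
    mrckl = DataLoader(list(range(10)), batch_size=2)
    # mirrors "mix_ratio": [2, 1, 1] from the VCR second-stage config
    meta = iter(MetaLoader({'mlm': (mlm, 2), 'mrm': (mrm, 1),
                            'mrckl': (mrckl, 1)}))
    for _ in range(6):
        task, batch = next(meta)  # tasks arrive roughly 2:1:1 over many draws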


def move_to_cuda(batch):
    if isinstance(batch, torch.Tensor):
        return batch.cuda(non_blocking=True)
    elif isinstance(batch, list):
        new_batch = [move_to_cuda(t) for t in batch]
    elif isinstance(batch, tuple):
        new_batch = tuple(move_to_cuda(t) for t in batch)
    elif isinstance(batch, dict):
        new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
    else:
        return batch
    return new_batch


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, (list, tuple)):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class PrefetchLoader(object):
    """
    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """
    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method
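
# Composing the pieces: the prefetcher copies the next batch to the GPU on a
# side stream while the current one is being consumed. The non_blocking
# copies only overlap with compute when the host tensors are pinned, which
# is why the configs set "pin_mem": true. Sketch (requires a CUDA device;
# `dataset` stands in for any dataset/collate pair from this package):
#
#     loader = PrefetchLoader(DataLoader(dataset, batch_size=32,
#                                        num_workers=4, pin_memory=True,
#                                        collate_fn=itm_collate))
#     for batch in loader:  # batches arrive already on the GPU
#         ...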
@ -0,0 +1,360 @@
"""
MLM datasets
"""
import math
import random

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip

from .data import (DetectFeatTxtTokDataset, TxtTokLmdb,
                   get_ids_and_lens, pad_tensors, get_gather_index)


def random_word(tokens, vocab_range, mask):
    """
    Mask random tokens for the masked LM task with the probabilities used in
    the original BERT paper.
    :param tokens: list of int, tokenized sentence.
    :param vocab_range: for choosing a random word
    :return: (list of int, list of int), masked tokens and related labels for
        LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = mask

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(range(*vocab_range)))

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            output_label.append(token)
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)
    if all(o == -1 for o in output_label):
        # at least mask 1
        output_label[0] = tokens[0]
        tokens[0] = mask

    return tokens, output_label
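
# Concretely: each token is selected with p = 0.15; a selected token becomes
# `mask` 80% of the time, a random vocab id 10%, and stays itself 10%, while
# its label keeps the original id. Unselected positions get label -1 so the
# loss ignores them. Invariants (ids and mask value are illustrative):
def _demo_random_word():
    tokens = [2023, 2003, 1037, 3231]
    masked, labels = random_word(list(tokens), (999, 30522), mask=103)
    for orig, new, lab in zip(tokens, masked, labels):
        assert (lab == -1 and new == orig) or lab == orig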


class MlmDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
        - img_feat     : (num_bb, d)
        - img_pos_feat : (num_bb, 7)
        - attn_masks   : (L + num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        - txt_labels   : (L, ), [-1, -1, wid, -1, -1, -1]
        0's padded so that (L + num_bb) % 8 == 0
        """
        example = super().__getitem__(i)

        # text input
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])

        # img input
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return input_ids, img_feat, img_pos_feat, attn_masks, txt_labels

    def create_mlm_io(self, input_ids):
        input_ids, txt_labels = random_word(input_ids,
                                            self.txt_db.v_range,
                                            self.txt_db.mask)
        input_ids = torch.tensor([self.txt_db.cls_]
                                 + input_ids
                                 + [self.txt_db.sep])
        txt_labels = torch.tensor([-1] + txt_labels + [-1])
        return input_ids, txt_labels


def mlm_collate(inputs):
    """
    Return:
    :input_ids    (n, max_L) padded with 0
    :position_ids (n, max_L) padded with 0
    :txt_lens     list of [txt_len]
    :img_feat     (n, max_num_bb, feat_dim)
    :img_pos_feat (n, max_num_bb, 7)
    :num_bbs      list of [num_bb]
    :attn_masks   (n, max_{L + num_bb}) padded with 0
    :txt_labels   (n, max_L) padded with -1
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels
     ) = map(list, unzip(inputs))

    # text batches
    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'txt_labels': txt_labels}
    return batch


class BlindMlmDataset(Dataset):
    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        example = self.txt_db[id_]
        input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
        attn_masks = torch.ones(len(input_ids), dtype=torch.long)

        return input_ids, attn_masks, txt_labels

    def create_mlm_io(self, input_ids):
        # same masking scheme as MlmDataset; added here because this class
        # does not inherit from it (the original copy called a method that
        # was never defined on this class)
        input_ids, txt_labels = random_word(input_ids,
                                            self.txt_db.v_range,
                                            self.txt_db.mask)
        input_ids = torch.tensor([self.txt_db.cls_]
                                 + input_ids
                                 + [self.txt_db.sep])
        txt_labels = torch.tensor([-1] + txt_labels + [-1])
        return input_ids, txt_labels


def mlm_blind_collate(inputs):
    input_ids, attn_masks, txt_labels = map(list, unzip(inputs))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'txt_labels': txt_labels}
    return batch


def eval_mask(len_, num_samples=7):
    """ build the mask for evaluating MLM:
    circularly mask 1 word out of every `num_samples` words
    """
    # build the random masks
    if len_ <= num_samples:
        masks = torch.eye(len_).bool()
        num_samples = len_
    else:
        mask_inds = [list(range(i, len_, num_samples))
                     for i in range(num_samples)]
        masks = torch.zeros(num_samples, len_).bool()
        for i, indices in enumerate(mask_inds):
            for j in indices:
                masks.data[i, j] = 1
    assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
    assert masks.sum().item() == len_
    return masks
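
# Each row hides every `num_samples`-th token at a different offset, so
# across the rows every position is masked exactly once (which is what the
# asserts above verify). For len_ = 5 and num_samples = 3:
def _demo_eval_mask():
    m = eval_mask(5, num_samples=3)
    assert m.tolist() == [[True, False, False, True, False],
                          [False, True, False, False, True],
                          [False, False, True, False, False]]
    assert m.sum(dim=0).tolist() == [1, 1, 1, 1, 1]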


def eval_gather_inds(len_, num_samples=7):
    """ get the gather indices """
    inds = torch.arange(0, num_samples, dtype=torch.long)
    mul = math.ceil(len_ / num_samples)
    output = inds.repeat(mul)[:len_]
    return output


def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
    """N x [B_i, T, ...]"""
    if ns is None:
        ns = [t.size(0) for t in tensors]
    if lens is None:
        lens = [t.size(1) for t in tensors]
    max_len = max(lens)
    bs = sum(ns)
    hid_dims = tensors[0].size()[2:]
    dtype = tensors[0].dtype
    output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
    if pad:
        output.data.fill_(pad)
    i = 0
    for t, l, n in zip(tensors, lens, ns):
        output.data[i:i+n, :l, ...] = t.data
        i += n
    return output


def expand_tensors(tensors, ns):
    return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim()))
            for t, n in zip(tensors, ns)]
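
# A small shape sketch (editorial, assuming 768-d features): stacking two
# per-example stacks of sizes (2, 3, 768) and (3, 5, 768) yields one batch
# of size (5, 5, 768), zero-padded along the length dimension.
#
#   >>> a, b = torch.ones(2, 3, 768), torch.ones(3, 5, 768)
#   >>> stack_pad_tensors([a, b]).shape
#   torch.Size([5, 5, 768])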


class MlmEvalDataset(DetectFeatTxtTokDataset):
    """ For evaluating MLM training task """
    def __init__(self, txt_db, img_db):
        assert isinstance(txt_db, TxtTokLmdb)
        super().__init__(txt_db, img_db)

    def __getitem__(self, i):
        example = super().__getitem__(i)

        # text input
        (input_ids, txt_labels, gather_inds
         ) = self.create_mlm_eval_io(example['input_ids'])

        # img input
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat, attn_masks,
                txt_labels, gather_inds)

    def create_mlm_eval_io(self, input_ids):
        txt_labels = torch.tensor(input_ids)
        masks = eval_mask(len(input_ids))
        n_mask = masks.size(0)
        masks = torch.cat([torch.zeros(n_mask, 1).bool(),
                           masks,
                           torch.zeros(n_mask, 1).bool()],
                          dim=1)
        input_ids = torch.tensor([[self.txt_db.cls_]
                                  + input_ids
                                  + [self.txt_db.sep]
                                  for _ in range(n_mask)])
        input_ids.data.masked_fill_(masks, self.txt_db.mask)
        gather_inds = eval_gather_inds(len(txt_labels))
        return input_ids, txt_labels, gather_inds


def _batch_gather_tgt(gather_inds, n_masks):
    gather_tgts = []
    offset = 0
    for g, n in zip(gather_inds, n_masks):
        gather_tgts.append(g + offset)
        offset += n
    gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
    return gather_tgt


def mlm_eval_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
     ) = map(list, unzip(inputs))

    # sizes
    n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
                                 num_bbs, n_masks)
    img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
                                     num_bbs, n_masks)

    bs, max_tl = input_ids.size()
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    out_size = attn_masks.size(1)
    # repeat txt_lens, num_bbs
    txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
    num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch


class BlindMlmEvalDataset(Dataset):
    def __init__(self, txt_db):
        assert isinstance(txt_db, TxtTokLmdb)
        self.txt_db = txt_db
        self.lens, self.ids = get_ids_and_lens(txt_db)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        example = self.txt_db[id_]

        # text input
        input_ids = example['input_ids']
        (input_ids, txt_labels, gather_inds
         ) = self.txt_db.create_mlm_eval_io(input_ids)

        # input_ids is (n_mask, L); the attention mask covers the text length
        attn_masks = torch.ones(input_ids.size(1), dtype=torch.long)

        return input_ids, attn_masks, txt_labels, gather_inds


def mlm_blind_eval_collate(inputs):
    (input_ids, attn_masks, txt_labels, gather_inds
     ) = map(list, unzip(inputs))

    # sizes
    n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))

    # text batches
    input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
                                   None, n_masks)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
    gather_tgt = _batch_gather_tgt(gather_inds, n_masks)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'attn_masks': attn_masks,
             'gather_tgt': gather_tgt,
             'txt_labels': txt_labels}
    return batch
@ -0,0 +1,287 @@
"""
MRM Datasets
"""
import random

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip

from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index


def _get_img_mask(mask_prob, num_bb):
    img_mask = [random.random() < mask_prob for _ in range(num_bb)]
    if not any(img_mask):
        # at least mask 1
        img_mask[random.choice(range(num_bb))] = True
    img_mask = torch.tensor(img_mask)
    return img_mask


def _get_img_tgt_mask(img_mask, txt_len):
    z = torch.zeros(txt_len, dtype=torch.bool)
    img_mask_tgt = torch.cat([z, img_mask], dim=0)
    return img_mask_tgt


def _get_feat_target(img_feat, img_masks):
    img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)  # (n, m, d)
    feat_dim = img_feat.size(-1)
    feat_targets = img_feat[img_masks_ext].contiguous().view(
        -1, feat_dim)  # (s, d)
    return feat_targets


def _mask_img_feat(img_feat, img_masks):
    img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
    img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
    return img_feat_masked
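
# A minimal sketch (editorial) of how the helpers above compose in the
# collates below, assuming a batch of 2 images, 3 regions, 4-d features:
#
#   img_feat = torch.randn(2, 3, 4)
#   img_masks = torch.tensor([[True, False, False],
#                             [False, True, True]])
#   tgt = _get_feat_target(img_feat, img_masks)     # (3, 4): the masked rows
#   img_feat = _mask_img_feat(img_feat, img_masks)  # masked rows zeroed out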


class MrfrDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, i):
        """
        Return:
        - input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
        - img_feat     : (num_bb, d)
        - img_pos_feat : (num_bb, 7)
        - attn_masks   : (L + num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        - img_mask     : (num_bb, ) between {0, 1}
        """
        example = super().__getitem__(i)
        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        # image input features
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])
        img_mask = _get_img_mask(self.mask_prob, num_bb)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                attn_masks, img_mask, img_mask_tgt)


def mrfr_collate(inputs):
    """
    Return:
    - input_ids    : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
    - position_ids : (n, max_L)
    - txt_lens     : list of [input_len]
    - img_feat     : (n, max_num_bb, d)
    - img_pos_feat : (n, max_num_bb, 7)
    - num_bbs      : list of [num_bb]
    - attn_masks   : (n, max_{L + num_bb}), i.e., [1, 1, ..., 0, 0, 1, 1]
    - img_masks    : (n, max_num_bb) between {0, 1}
    """
    (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat = _mask_img_feat(img_feat, img_masks)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_mask_tgt': img_mask_tgt}
    return batch


class OnlyImgMrfrDataset(Dataset):
    """ an image-only MRM """
    def __init__(self, mask_prob, img_db):
        self.mask_prob = mask_prob
        self.img_db = img_db
        self.ids, self.lens = map(list, unzip(self.img_db.name2nbb.items()))

    def __getitem__(self, i):
        id_ = self.ids[i]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(id_)
        attn_masks = torch.ones(num_bb, dtype=torch.long)
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        return img_feat, img_pos_feat, attn_masks, img_mask

    def _get_img_feat(self, fname):
        img_feat, bb = self.img_db[fname]
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb
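
    # Editorial note, assuming the usual UNITER feature layout: `bb` holds
    # the 6-d normalized box (x1, y1, x2, y2, w, h), and bb[:, 4:5]*bb[:, 5:]
    # appends the normalized area w*h, giving the 7-d img_pos_feat used
    # throughout these datasets, e.g.:
    #
    #   bb = torch.tensor([[0.1, 0.1, 0.5, 0.5, 0.4, 0.4]])
    #   torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
    #   # tensor([[0.1000, 0.1000, 0.5000, 0.5000, 0.4000, 0.4000, 0.1600]])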


def mrfr_only_img_collate(inputs):
    img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs))

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat = _mask_img_feat(img_feat, img_masks)

    batch = {'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_mask_tgt': img_masks}
    return batch


def _get_targets(img_masks, img_soft_label):
    soft_label_dim = img_soft_label.size(-1)
    img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
    label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
        -1, soft_label_dim)
    return label_targets


class MrcDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def _get_img_feat(self, fname):
        img_dump = self.img_db.get_dump(fname)
        num_bb = self.img_db.name2nbb[fname]
        img_feat = torch.tensor(img_dump['features'])
        bb = torch.tensor(img_dump['norm_bb'])
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        img_soft_label = torch.tensor(img_dump['soft_labels'])
        return img_feat, img_bb, img_soft_label, num_bb

    def __getitem__(self, i):
        example = super().__getitem__(i)
        img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
            example['img_fname'])

        # image input features
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                img_soft_labels, attn_masks, img_mask, img_mask_tgt)


def mrc_collate(inputs):
    (input_ids, img_feats, img_pos_feats, img_soft_labels,
     attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_soft_label = pad_tensors(img_soft_labels, num_bbs)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    label_targets = _get_targets(img_masks, img_soft_label)

    img_feat = _mask_img_feat(img_feat, img_masks)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'img_masks': img_masks,
             'img_mask_tgt': img_mask_tgt,
             'label_targets': label_targets}
    return batch


class OnlyImgMrcDataset(OnlyImgMrfrDataset):
    """ an image-only MRC """
    def __getitem__(self, i):
        id_ = self.ids[i]
        (img_feat, img_pos_feat, img_soft_labels, num_bb
         ) = self._get_img_feat(id_)
        attn_masks = torch.ones(num_bb, dtype=torch.long)
        img_mask = _get_img_mask(self.mask_prob, num_bb)

        return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask

    def _get_img_feat(self, fname):
        img_dump = self.img_db.get_dump(fname)
        num_bb = self.img_db.name2nbb[fname]
        img_feat = torch.tensor(img_dump['features'])
        bb = torch.tensor(img_dump['norm_bb'])
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        img_soft_labels = torch.tensor(img_dump['soft_labels'])
        return img_feat, img_bb, img_soft_labels, num_bb


def mrc_only_img_collate(inputs):
    (img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks
     ) = map(list, unzip(inputs))

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    num_bbs = [f.size(0) for f in img_feats]

    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    img_soft_label = pad_tensors(img_soft_labels, num_bbs)
    label_targets = _get_targets(img_masks, img_soft_label)

    # mask features
    img_feat = _mask_img_feat(img_feat, img_masks)

    batch = {'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'img_masks': img_masks,
             'img_mask_tgt': img_masks,
             'label_targets': label_targets}
    return batch
@ -0,0 +1,136 @@
"""
MRM Datasets (contrastive learning version)
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import curry

from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset,
                   pad_tensors, get_gather_index)
from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target
from .itm import sample_negative


# FIXME diff implementation from mrfr, mrc
def _mask_img_feat(img_feat, img_masks, neg_feats,
                   noop_prob=0.1, change_prob=0.1):
    rand = torch.rand(*img_masks.size())
    noop_mask = rand < noop_prob
    change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks
    img_masks_in = img_masks & ~noop_mask & ~change_mask

    img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat)
    img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)

    n_neg = change_mask.sum().item()
    feat_dim = neg_feats.size(-1)
    index = torch.arange(0, change_mask.numel(), dtype=torch.long
                         ).masked_select(change_mask.view(-1))
    index = index.unsqueeze(-1).expand(-1, feat_dim)
    img_feat_out = img_feat_masked.view(-1, feat_dim).scatter(
        dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size())

    return img_feat_out, img_masks_in
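
# A small editorial sketch of the policy above (with the default probs):
# for each masked region, with prob ~0.1 the feature is left untouched
# (noop), with prob ~0.1 it is swapped for a negative feature from another
# image (change), and otherwise it is zeroed out -- the BERT-style
# 80/10/10 split applied to region features. `img_masks_in` marks only the
# positions that were actually zeroed before being fed to the model.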


class MrmNceDataset(DetectFeatTxtTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob

    def __getitem__(self, i):
        example = super().__getitem__(i)
        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        # image input features
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])
        img_mask = _get_img_mask(self.mask_prob, num_bb)
        img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return (input_ids, img_feat, img_pos_feat,
                attn_masks, img_mask, img_mask_tgt,
                example['img_fname'])


class NegativeImageSampler(object):
    def __init__(self, img_dbs, neg_size, size_mul=8):
        if not isinstance(img_dbs, list):
            assert isinstance(img_dbs, DetectFeatLmdb)
            img_dbs = [img_dbs]
        self.neg_size = neg_size
        self.img_db = JoinedDetectFeatLmdb(img_dbs)
        all_imgs = []
        for db in img_dbs:
            all_imgs.extend(db.name2nbb.keys())
        self.all_imgs = all_imgs

    def sample_negative_feats(self, pos_imgs):
        neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size)
        all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids],
                                  dim=0)
        # only use multiples of 8 for tensorcores
        n_cut = all_neg_feats.size(0) % 8
        if n_cut != 0:
            return all_neg_feats[:-n_cut]
        else:
            return all_neg_feats
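
    # Editorial note on the trim above: fp16 matmuls engage NVIDIA tensor
    # cores when the relevant dimensions are multiples of 8, so up to 7
    # trailing negative features are dropped instead of padding, e.g.
    # 1003 sampled features -> 1000 used (1003 % 8 == 3), while 1000 stays
    # unchanged.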


class JoinedDetectFeatLmdb(object):
    def __init__(self, img_dbs):
        assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs)
        self.img_dbs = img_dbs

    def __getitem__(self, file_name):
        for db in self.img_dbs:
            if file_name in db:
                return db[file_name]
        raise ValueError("image does not exist")


@curry
def mrm_nce_collate(neg_sampler, inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
     positive_imgs) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    neg_feats = neg_sampler.sample_negative_feats(positive_imgs)

    # mask features
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
    feat_targets = _get_feat_target(img_feat, img_masks)
    img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats)
    img_mask_tgt = pad_sequence(img_mask_tgts,
                                batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'feat_targets': feat_targets,
             'img_masks': img_masks,
             'img_masks_in': img_masks_in,
             'img_mask_tgt': img_mask_tgt,
             'neg_feats': neg_feats}
    return batch
@ -0,0 +1,218 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

NLVR2 dataset
"""
import copy

import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import concat

from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
                   get_ids_and_lens, pad_tensors, get_gather_index)


class Nlvr2PairedDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, use_img_type=True):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)
        self.txt_db = txt_db
        self.img_db = img_db
        txt_lens, self.ids = get_ids_and_lens(txt_db)

        txt2img = txt_db.txt2img
        self.lens = [2*tl + sum(self.img_db.name2nbb[img]
                                for img in txt2img[id_])
                     for tl, id_ in zip(txt_lens, self.ids)]

        self.use_img_type = use_img_type

    def __getitem__(self, i):
        """
        [[txt, img1],
         [txt, img2]]
        """
        example = super().__getitem__(i)
        target = example['target']
        outs = []
        for i, img in enumerate(example['img_fname']):
            img_feat, img_pos_feat, num_bb = self._get_img_feat(img)

            # text input
            input_ids = copy.deepcopy(example['input_ids'])

            input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
            attn_masks = [1] * (len(input_ids) + num_bb)
            input_ids = torch.tensor(input_ids)
            attn_masks = torch.tensor(attn_masks)
            if self.use_img_type:
                img_type_ids = torch.tensor([i+1]*num_bb)
            else:
                img_type_ids = None

            outs.append((input_ids, img_feat, img_pos_feat,
                         attn_masks, img_type_ids))
        return tuple(outs), target
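
    # Editorial sketch: for one NLVR2 example with images A and B,
    # __getitem__ yields
    #   (((ids, featA, posA, maskA, type1),
    #     (ids, featB, posB, maskB, type2)), target)
    # i.e. the same statement is paired with each image separately;
    # nlvr2_paired_collate below flattens the two streams into
    # consecutive rows of one batch via concat().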


def nlvr2_paired_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks,
     img_type_ids) = map(list, unzip(concat(outs for outs, _ in inputs)))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    if img_type_ids[0] is None:
        img_type_ids = None
    else:
        img_type_ids = pad_sequence(img_type_ids,
                                    batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.Tensor([t for _, t in inputs]).long()

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'img_type_ids': img_type_ids,
             'targets': targets}
    return batch


class Nlvr2PairedEvalDataset(Nlvr2PairedDataset):
    def __getitem__(self, i):
        qid = self.ids[i]
        outs, targets = super().__getitem__(i)
        return qid, outs, targets


def nlvr2_paired_eval_collate(inputs):
    qids, batch = [], []
    for id_, *tensors in inputs:
        qids.append(id_)
        batch.append(tensors)
    batch = nlvr2_paired_collate(batch)
    batch['qids'] = qids
    return batch


class Nlvr2TripletDataset(DetectFeatTxtTokDataset):
    def __init__(self, txt_db, img_db, use_img_type=True):
        assert isinstance(txt_db, TxtTokLmdb)
        assert isinstance(img_db, DetectFeatLmdb)
        self.txt_db = txt_db
        self.img_db = img_db
        txt_lens, self.ids = get_ids_and_lens(txt_db)

        txt2img = txt_db.txt2img
        self.lens = [tl + sum(self.img_db.name2nbb[img]
                              for img in txt2img[id_])
                     for tl, id_ in zip(txt_lens, self.ids)]

        self.use_img_type = use_img_type

    def __getitem__(self, i):
        """
        [txt, img1, img2]: both images are concatenated into one stream
        """
        example = super().__getitem__(i)
        target = example['target']
        img_feats = []
        img_pos_feats = []
        num_bb = 0
        img_type_ids = []
        for i, img in enumerate(example['img_fname']):
            feat, pos, nbb = self._get_img_feat(img)
            img_feats.append(feat)
            img_pos_feats.append(pos)
            num_bb += nbb
            if self.use_img_type:
                img_type_ids.extend([i+1]*nbb)
        img_feat = torch.cat(img_feats, dim=0)
        img_pos_feat = torch.cat(img_pos_feats, dim=0)
        if self.use_img_type:
            img_type_ids = torch.tensor(img_type_ids)
        else:
            img_type_ids = None

        # text input
        input_ids = copy.deepcopy(example['input_ids'])

        input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
        attn_masks = [1] * (len(input_ids) + num_bb)
        input_ids = torch.tensor(input_ids)
        attn_masks = torch.tensor(attn_masks)

        return (input_ids, img_feat, img_pos_feat, attn_masks,
                img_type_ids, target)


def nlvr2_triplet_collate(inputs):
    (input_ids, img_feats, img_pos_feats,
     attn_masks, img_type_ids, targets) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    # image batches
    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
    if img_type_ids[0] is None:
        img_type_ids = None
    else:
        img_type_ids = pad_sequence(img_type_ids,
                                    batch_first=True, padding_value=0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.Tensor(targets).long()

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'img_type_ids': img_type_ids,
             'targets': targets}
    return batch


class Nlvr2TripletEvalDataset(Nlvr2TripletDataset):
    def __getitem__(self, i):
        qid = self.ids[i]
        tensors = super().__getitem__(i)
        return (qid, *tensors)


def nlvr2_triplet_eval_collate(inputs):
    qids, batch = [], []
    for id_, *tensors in inputs:
        qids.append(id_)
        batch.append(tensors)
    batch = nlvr2_triplet_collate(batch)
    batch['qids'] = qids
    return batch
@ -0,0 +1,319 @@
"""
Referring Expression Comprehension dataset
"""
import sys
import json
import random
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip

from .data import TxtLmdb


class ReImageFeatDir(object):
    def __init__(self, img_dir):
        self.img_dir = img_dir

    def __getitem__(self, file_name):
        img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
        img_feat = torch.tensor(img_dump['features'])
        img_bb = torch.tensor(img_dump['norm_bb'])
        return img_feat, img_bb


class ReDetectFeatDir(object):
    def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
                 format_='npz'):
        assert format_ == 'npz', 'only the npz format is supported for now.'
        assert isinstance(img_dir, str), 'img_dir is a path, not a db.'
        self.img_dir = img_dir
        self.conf_th = conf_th
        self.max_bb = max_bb
        self.min_bb = min_bb
        self.num_bb = num_bb

    def _compute_num_bb(self, img_dump):
        num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum())
        num_bb = min(self.max_bb, num_bb)
        return num_bb

    def __getitem__(self, file_name):
        # image input features
        img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
        num_bb = self._compute_num_bb(img_dump)
        img_feat = torch.tensor(img_dump['features'][:num_bb, :])
        img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :])
        return img_feat, img_bb


class ReferringExpressionDataset(Dataset):
    def __init__(self, db_dir, img_dir, max_txt_len=60):
        assert isinstance(img_dir, ReImageFeatDir) or \
            isinstance(img_dir, ReDetectFeatDir)
        self.img_dir = img_dir

        # load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}]
        refs = json.load(open(f'{db_dir}/refs.json', 'r'))
        self.ref_ids = [ref['ref_id'] for ref in refs]
        self.Refs = {ref['ref_id']: ref for ref in refs}

        # load annotations = [{id, area, bbox, image_id, category_id}]
        anns = json.load(open(f'{db_dir}/annotations.json', 'r'))
        self.Anns = {ann['id']: ann for ann in anns}

        # load categories = [{id, name, supercategory}]
        categories = json.load(open(f'{db_dir}/categories.json', 'r'))
        self.Cats = {cat['id']: cat['name'] for cat in categories}

        # load images = [{id, file_name, ann_ids, height, width}]
        images = json.load(open(f'{db_dir}/images.json', 'r'))
        self.Images = {img['id']: img for img in images}

        # id2len: sent_id -> sent_len
        id2len = json.load(open(f'{db_dir}/id2len.json', 'r'))
        self.id2len = {int(_id): _len for _id, _len in id2len.items()}
        self.max_txt_len = max_txt_len
        self.sent_ids = self._get_sent_ids()

        # db[str(sent_id)] =
        # {sent_id, sent, ref_id, ann_id, image_id,
        #  bbox, input_ids, toked_sent}
        self.db = TxtLmdb(db_dir, readonly=True)

        # meta
        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
        self.cls_ = meta['CLS']
        self.sep = meta['SEP']
        self.mask = meta['MASK']
        self.v_range = meta['v_range']

    def shuffle(self):
        # we shuffle ref_ids and rebuild sent_ids according to ref_ids
        random.shuffle(self.ref_ids)
        self.sent_ids = self._get_sent_ids()

    def _get_sent_ids(self):
        sent_ids = []
        for ref_id in self.ref_ids:
            for sent_id in self.Refs[ref_id]['sent_ids']:
                sent_len = self.id2len[sent_id]
                if self.max_txt_len == -1 or sent_len < self.max_txt_len:
                    sent_ids.append(sent_id)
        return sent_ids

    def _get_img_feat(self, fname):
        img_feat, bb = self.img_dir[fname]
        img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
        num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb

    def __len__(self):
        return len(self.sent_ids)

    def __getitem__(self, i):
        """
        Return:
        :input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
        :position_ids : range(L)
        :img_feat     : (num_bb, d)
        :img_pos_feat : (num_bb, 7)
        :attn_masks   : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        :obj_masks    : (num_bb, ) all 0's
        :target       : (1, )
        """
        # {sent_id, sent, ref_id, ann_id, image_id,
        #  bbox, input_ids, toked_sent}
        sent_id = self.sent_ids[i]
        txt_dump = self.db[str(sent_id)]
        image_id = txt_dump['image_id']
        fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # text input
        input_ids = txt_dump['input_ids']
        input_ids = [self.cls_] + input_ids + [self.sep]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)

        # target bbox
        img = self.Images[image_id]
        assert len(img['ann_ids']) == num_bb, \
            'Please use visual_grounding_coco_gt'
        target = img['ann_ids'].index(txt_dump['ann_id'])
        target = torch.tensor([target])

        # obj_masks, to be padded with 1, for masking out non-object prob.
        obj_masks = torch.tensor([0]*len(img['ann_ids'])).bool()

        return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
                obj_masks, target)


def re_collate(inputs):
    """
    Return:
    :input_ids    : (n, max_L) padded with 0
    :position_ids : (n, max_L) padded with 0
    :txt_lens     : list of [txt_len]
    :img_feat     : (n, max_num_bb, feat_dim)
    :img_pos_feat : (n, max_num_bb, 7)
    :num_bbs      : list of [num_bb]
    :attn_masks   : (n, max_{L+num_bb}) padded with 0
    :obj_masks    : (n, max_num_bb) padded with 1
    :targets      : (n, )
    """
    (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
     targets) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.cat(targets, dim=0)
    obj_masks = pad_sequence(obj_masks,
                             batch_first=True, padding_value=1).bool()

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, obj_masks, targets)


class ReferringExpressionEvalDataset(ReferringExpressionDataset):
    def __getitem__(self, i):
        """
        Return:
        :input_ids    : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
        :position_ids : range(L)
        :img_feat     : (num_bb, d)
        :img_pos_feat : (num_bb, 7)
        :attn_masks   : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
        :obj_masks    : (num_bb, ) all 0's
        :tgt_box      : ndarray (4, ) xywh
        :obj_boxes    : ndarray (num_bb, 4) xywh
        :sent_id
        """
        # {sent_id, sent, ref_id, ann_id, image_id,
        #  bbox, input_ids, toked_sent}
        sent_id = self.sent_ids[i]
        txt_dump = self.db[str(sent_id)]
        image_id = txt_dump['image_id']
        if isinstance(self.img_dir, ReImageFeatDir):
            if '_gt' in self.img_dir.img_dir:
                fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
            elif '_det' in self.img_dir.img_dir:
                fname = f'visual_grounding_det_coco_{int(image_id):012}.npz'
        elif isinstance(self.img_dir, ReDetectFeatDir):
            fname = f'coco_train2014_{int(image_id):012}.npz'
        else:
            sys.exit('%s not supported.' % self.img_dir)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)

        # image info
        img = self.Images[image_id]
        im_width, im_height = img['width'], img['height']

        # object boxes, img_pos_feat (xyxywha) -> xywh
        obj_boxes = np.stack([img_pos_feat[:, 0]*im_width,
                              img_pos_feat[:, 1]*im_height,
                              img_pos_feat[:, 4]*im_width,
                              img_pos_feat[:, 5]*im_height], axis=1)
        obj_masks = torch.tensor([0]*num_bb).bool()

        # target box
        tgt_box = np.array(txt_dump['bbox'])  # xywh

        # text input
        input_ids = txt_dump['input_ids']
        input_ids = [self.cls_] + input_ids + [self.sep]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb

        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)

        return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
                obj_masks, tgt_box, obj_boxes, sent_id)

    # IoU function
    def computeIoU(self, box1, box2):
        # each box is of [x1, y1, w, h]
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
        inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)

        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            inter = 0
        union = box1[2]*box1[3] + box2[2]*box2[3] - inter
        return float(inter)/union
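
    # A worked example (editorial) of the IoU above, boxes as [x1, y1, w, h]:
    # box1 = [0, 0, 10, 10] and box2 = [5, 5, 10, 10] intersect on the pixel
    # grid [5, 9] x [5, 9], so inter = 5*5 = 25 and
    # union = 100 + 100 - 25 = 175, giving IoU = 25/175 ~= 0.143.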


def re_eval_collate(inputs):
    """
    Return:
    :input_ids    : (n, max_L)
    :position_ids : (n, max_L)
    :txt_lens     : list of [txt_len]
    :img_feat     : (n, max_num_bb, d)
    :img_pos_feat : (n, max_num_bb, 7)
    :num_bbs      : list of [num_bb]
    :attn_masks   : (n, max{L+num_bb})
    :obj_masks    : (n, max_num_bb)
    :tgt_box      : list of n [xywh]
    :obj_boxes    : list of n [[xywh, xywh, ...]]
    :sent_ids     : list of n [sent_id]
    """
    (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
     tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    obj_masks = pad_sequence(obj_masks,
                             batch_first=True, padding_value=1).bool()

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids)
@ -0,0 +1,116 @@
""" sampler for length bucketing (batch by tokens) """
import math
import random

import torch
import horovod.torch as hvd
from torch.utils.data import Sampler
from cytoolz import partition_all


class TokenBucketSampler(Sampler):
    def __init__(self, lens, bucket_size, batch_size,
                 droplast=False, size_multiple=8):
        self._lens = lens
        self._max_tok = batch_size
        self._bucket_size = bucket_size
        self._droplast = droplast
        self._size_mul = size_multiple

    def _create_ids(self):
        return list(range(len(self._lens)))

    def _sort_fn(self, i):
        return self._lens[i]

    def __iter__(self):
        ids = self._create_ids()
        random.shuffle(ids)
        buckets = [sorted(ids[i:i+self._bucket_size],
                          key=self._sort_fn, reverse=True)
                   for i in range(0, len(ids), self._bucket_size)]
        # fill batches until max_token (including padding)
        batches = []
        for bucket in buckets:
            max_len = 0
            batch_indices = []
            for indices in partition_all(self._size_mul, bucket):
                max_len = max(max_len, max(self._lens[i] for i in indices))
                if (max_len * (len(batch_indices) + self._size_mul)
                        > self._max_tok):
                    if not batch_indices:
                        raise ValueError(
                            "max_tokens too small / max_seq_len too long")
                    assert len(batch_indices) % self._size_mul == 0
                    batches.append(batch_indices)
                    batch_indices = list(indices)
                else:
                    batch_indices.extend(indices)
            if not self._droplast and batch_indices:
                batches.append(batch_indices)
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        raise ValueError("NOT supported. "
                         "This has some randomness across epochs")


class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the
            indices.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            num_replicas = hvd.size()
        if rank is None:
            rank = hvd.rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset)
                                         * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = list(range(len(self.dataset)))
        # add extra samples to make the split evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]

        if self.shuffle:
            shuffle_ind = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in shuffle_ind]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
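
# A minimal usage sketch (editorial; assumes hvd.init() has been called).
# set_epoch must be bumped each epoch so every worker reshuffles its shard
# deterministically from the same seed:
#
#   sampler = DistributedSampler(dataset)
#   loader = DataLoader(dataset, batch_size=32, sampler=sampler)
#   for epoch in range(n_epochs):
#       sampler.set_epoch(epoch)
#       for batch in loader:
#           ...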
@ -0,0 +1,725 @@
"""
VCR dataset
"""
import json
import copy
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from torch.utils.data import Dataset

from .data import DetectFeatLmdb, TxtLmdb, random_word
from .mrc import DetectFeatDir_for_mrc


class ImageTextDataset(Dataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=120, task="qa"):
        self.txt_lens = []
        self.ids = []
        self.task = task
        for id_, len_ in json.load(open(f'{db_dir}/id2len_{task}.json')
                                   ).items():
            if max_txt_len == -1 or len_ <= max_txt_len:
                self.txt_lens.append(len_)
                self.ids.append(id_)

        self.db = TxtLmdb(db_dir, readonly=True)
        self.img_dir = img_dir
        self.img_dir_gt = img_dir_gt

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_dump_gt, img_dump = None, None
        img_fname_gt, img_fname = txt_dump['img_fname']
        if self.img_dir_gt:
            img_dump_gt = self.img_dir_gt[img_fname_gt]
        if self.img_dir:
            img_dump = self.img_dir[img_fname]
        return img_dump_gt, img_dump, txt_dump


class DetectFeatBertTokDataset(ImageTextDataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=60, task="qa"):
        assert not (img_dir_gt is None and img_dir is None),\
            "img_dir_gt and img_dir cannot both be None"
        assert task == "qa" or task == "qar",\
            "VCR only allows two tasks: qa or qar"
        assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
        assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)

        super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
        txt2img = json.load(open(f'{db_dir}/txt2img.json'))
        if self.img_dir and self.img_dir_gt:
            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]] +
                         self.img_dir.name2nbb[txt2img[id_][1]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]
        elif self.img_dir:
            self.lens = [tl+self.img_dir.name2nbb[txt2img[id_][1]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]
        else:
            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]]
                         for tl, id_ in zip(self.txt_lens, self.ids)]

        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
        self.cls_ = meta['CLS']
        self.sep = meta['SEP']
        self.mask = meta['MASK']
        self.v_range = meta['v_range']

    def _get_img_feat(self, fname_gt, fname):
        if self.img_dir and self.img_dir_gt:
            img_feat_gt, bb_gt = self.img_dir_gt[fname_gt]
            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)

            img_feat, bb = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)

            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
            num_bb = img_feat.size(0)
        elif self.img_dir:
            img_feat, bb = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        elif self.img_dir_gt:
            img_feat, bb = self.img_dir_gt[fname_gt]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        return img_feat, img_bb, num_bb


class VcrDataset(DetectFeatBertTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        if self.task == "qar":
            input_ids_rs = txt_dump['input_ids_rs']
            answer_label = txt_dump['qa_target']
            assert answer_label >= 0, "answer_label < 0"
            input_ids_gt_a = [self.sep] + copy.deepcopy(
                input_ids_as[answer_label])
            type_ids_gt_a = [2] * len(input_ids_gt_a)
            type_ids_q += type_ids_gt_a
            input_ids_q += input_ids_gt_a
            input_ids_for_choices = input_ids_rs
        else:
            input_ids_for_choices = input_ids_as
        return input_ids_q, input_ids_for_choices, type_ids_q
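
    # Editorial sketch of the sequences built above (token ids elided).
    # For task "qa", each choice k is scored as
    #   [CLS] question [SEP] answer_k [SEP]
    # For task "qar", the ground-truth answer is appended to the question
    # and the rationales become the choices:
    #   [CLS] question [SEP] gt_answer [SEP] rationale_k [SEP]
    # with segment type ids 0 (question), 2 (answer), 3 (rationale).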
||||
|
|
||||
|
def __getitem__(self, i): |
||||
|
id_ = self.ids[i] |
||||
|
txt_dump = self.db[id_] |
||||
|
img_feat, img_pos_feat, num_bb = self._get_img_feat( |
||||
|
txt_dump['img_fname'][0], txt_dump['img_fname'][1]) |
||||
|
object_targets = txt_dump["object_ids"] |
||||
|
input_ids_q, input_ids_for_choices, type_ids_q = self._get_input_ids( |
||||
|
txt_dump) |
||||
|
label = txt_dump['%s_target' % (self.task)] |
||||
|
|
||||
|
choice_num_bbs, choice_img_feats, choice_img_pos_feats = ( |
||||
|
[], [], []) |
||||
|
(choice_txt_lens, choice_input_ids, choice_txt_type_ids, |
||||
|
choice_attn_masks, choice_position_ids, choice_targets) = ( |
||||
|
[], [], [], [], [], []) |
||||
|
choice_obj_targets, choice_img_masks = ([], []) |
||||
|
|
||||
|
for index, input_ids_a in enumerate(input_ids_for_choices): |
||||
|
if index == label: |
||||
|
target = torch.tensor([1]).long() |
||||
|
else: |
||||
|
target = torch.tensor([0]).long() |
||||
|
input_ids = [self.cls_] + copy.deepcopy(input_ids_q) +\ |
||||
|
[self.sep] + input_ids_a + [self.sep] |
||||
|
type_id_for_choice = 3 if type_ids_q[-1] == 2 else 2 |
||||
|
txt_type_ids = [0] + type_ids_q + [type_id_for_choice]*( |
||||
|
len(input_ids_a)+2) |
||||
|
attn_masks = [1] * len(input_ids) |
||||
|
position_ids = list(range(len(input_ids))) |
||||
|
attn_masks += [1] * num_bb |
||||
|
|
||||
|
input_ids = torch.tensor(input_ids) |
||||
|
            position_ids = torch.tensor(position_ids)
            attn_masks = torch.tensor(attn_masks)
            txt_type_ids = torch.tensor(txt_type_ids)

            choice_txt_lens.append(len(input_ids))
            choice_input_ids.append(input_ids)
            choice_attn_masks.append(attn_masks)
            choice_position_ids.append(position_ids)
            choice_txt_type_ids.append(txt_type_ids)

            choice_num_bbs.append(num_bb)
            choice_img_feats.append(img_feat)
            choice_img_pos_feats.append(img_pos_feat)
            choice_targets.append(target)

            # mask image input features
            num_gt_bb = len(object_targets)
            num_det_bb = num_bb - num_gt_bb
            # only mask gt features
            img_mask = [random.random() < self.mask_prob
                        for _ in range(num_gt_bb)]
            if not any(img_mask):
                # at least mask 1
                img_mask[0] = True
            img_mask += [False]*num_det_bb
            img_mask = torch.tensor(img_mask)
            object_targets += [0]*num_det_bb
            obj_targets = torch.tensor(object_targets)

            choice_obj_targets.append(obj_targets)
            choice_img_masks.append(img_mask)

        return (choice_input_ids, choice_position_ids, choice_txt_lens,
                choice_txt_type_ids,
                choice_img_feats, choice_img_pos_feats, choice_num_bbs,
                choice_attn_masks, choice_targets, choice_obj_targets,
                choice_img_masks)


def vcr_collate(inputs):
    (input_ids, position_ids, txt_lens, txt_type_ids, img_feats,
     img_pos_feats, num_bbs, attn_masks, targets,
     obj_targets, img_masks) = map(list, unzip(inputs))

    all_num_bbs, all_img_feats, all_img_pos_feats = (
        [], [], [])
    all_txt_lens, all_input_ids, all_attn_masks,\
        all_position_ids, all_txt_type_ids = (
            [], [], [], [], [])
    all_obj_targets = []
    all_targets = []
    # all_targets = targets
    all_img_masks = []
    for i in range(len(num_bbs)):
        all_input_ids += input_ids[i]
        all_position_ids += position_ids[i]
        all_txt_lens += txt_lens[i]
        all_txt_type_ids += txt_type_ids[i]
        all_img_feats += img_feats[i]
        all_img_pos_feats += img_pos_feats[i]
        all_num_bbs += num_bbs[i]
        all_attn_masks += attn_masks[i]
        all_obj_targets += obj_targets[i]
        all_img_masks += img_masks[i]
        all_targets += targets[i]

    all_input_ids = pad_sequence(all_input_ids,
                                 batch_first=True, padding_value=0)
    all_position_ids = pad_sequence(all_position_ids,
                                    batch_first=True, padding_value=0)
    all_txt_type_ids = pad_sequence(all_txt_type_ids,
                                    batch_first=True, padding_value=0)
    all_attn_masks = pad_sequence(all_attn_masks,
                                  batch_first=True, padding_value=0)
    all_img_masks = pad_sequence(all_img_masks,
                                 batch_first=True, padding_value=0)
    # all_targets = pad_sequence(all_targets,
    #                            batch_first=True, padding_value=0)
    all_targets = torch.stack(all_targets, dim=0)

    batch_size = len(all_img_feats)
    num_bb = max(all_num_bbs)
    feat_dim = all_img_feats[0].size(1)
    pos_dim = all_img_pos_feats[0].size(1)
    all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    all_obj_target = torch.zeros(batch_size, num_bb)
    for i, (im, pos, label) in enumerate(zip(
            all_img_feats, all_img_pos_feats, all_obj_targets)):
        len_ = im.size(0)
        all_img_feat.data[i, :len_, :] = im.data
        all_img_pos_feat.data[i, :len_, :] = pos.data
        all_obj_target.data[i, :len_] = label.data

    obj_targets = all_obj_target[all_img_masks].contiguous()
    return (all_input_ids, all_position_ids, all_txt_lens,
            all_txt_type_ids,
            all_img_feat, all_img_pos_feat, all_num_bbs,
            all_attn_masks, all_targets, obj_targets, all_img_masks)
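
# A hedged usage sketch (dataset name assumed from this module): vcr_collate
# flattens the per-example answer choices into one batch, so it is meant to be
# plugged into a DataLoader, e.g.
#   loader = DataLoader(vcr_train_dataset, batch_size=4,
#                       collate_fn=vcr_collate)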


class VcrEvalDataset(DetectFeatBertTokDataset):
    def __init__(self, split, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.split = split
        del self.txt_lens

    def _get_input_ids(self, txt_dump):
        # text input
        input_ids_for_choices = []
        type_ids_for_choices = []
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        input_ids_as = txt_dump['input_ids_as']
        input_ids_rs = txt_dump['input_ids_rs']
        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
                [self.sep] + input_ids_a + [self.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2]*(
                len(input_ids_a)+2)
            input_ids_for_choices.append(curr_input_ids_qa)
            type_ids_for_choices.append(curr_type_ids_qa)
        for index, input_ids_a in enumerate(input_ids_as):
            curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
                [self.sep] + input_ids_a + [self.sep]
            curr_type_ids_qa = [0] + type_ids_q + [2]*(
                len(input_ids_a)+1)
            if (self.split == "val" and index == txt_dump["qa_target"]) or\
                    self.split == "test":
                for input_ids_r in input_ids_rs:
                    curr_input_ids_qar = copy.deepcopy(curr_input_ids_qa) +\
                        input_ids_r + [self.sep]
                    curr_type_ids_qar = copy.deepcopy(curr_type_ids_qa) +\
                        [3]*(len(input_ids_r)+2)
                    input_ids_for_choices.append(curr_input_ids_qar)
                    type_ids_for_choices.append(curr_type_ids_qar)
        return input_ids_for_choices, type_ids_for_choices

    def __getitem__(self, i):
        qid = self.ids[i]
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])
        object_targets = txt_dump["object_ids"]
        input_ids_for_choices, type_ids_for_choices = self._get_input_ids(
            txt_dump)
        qa_target = torch.tensor([int(txt_dump["qa_target"])])
        qar_target = torch.tensor([int(txt_dump["qar_target"])])

        choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
            [], [], [])
        (choice_txt_lens, choice_input_ids, choice_attn_masks,
         choice_position_ids, choice_txt_type_ids) = (
            [], [], [], [], [])
        choice_obj_targets = []
        for index, input_ids in enumerate(input_ids_for_choices):
            txt_type_ids = type_ids_for_choices[index]
            attn_masks = [1] * len(input_ids)
            position_ids = list(range(len(input_ids)))
            attn_masks += [1] * num_bb

            input_ids = torch.tensor(input_ids)
            position_ids = torch.tensor(position_ids)
            attn_masks = torch.tensor(attn_masks)
            txt_type_ids = torch.tensor(txt_type_ids)

            choice_txt_lens.append(len(input_ids))
            choice_input_ids.append(input_ids)
            choice_attn_masks.append(attn_masks)
            choice_position_ids.append(position_ids)
            choice_txt_type_ids.append(txt_type_ids)

            choice_num_bbs.append(num_bb)
            choice_img_feats.append(img_feat)
            choice_img_pos_feats.append(img_pos_feat)

            obj_targets = torch.tensor(object_targets)
            choice_obj_targets.append(obj_targets)

        return (qid, choice_input_ids, choice_position_ids, choice_txt_lens,
                choice_txt_type_ids,
                choice_img_feats, choice_img_pos_feats, choice_num_bbs,
                choice_attn_masks, qa_target, qar_target, choice_obj_targets)


def vcr_eval_collate(inputs):
    (qids, input_ids, position_ids, txt_lens, txt_type_ids,
     img_feats, img_pos_feats,
     num_bbs, attn_masks, qa_targets, qar_targets,
     obj_targets) = map(list, unzip(inputs))

    all_num_bbs, all_img_feats, all_img_pos_feats = (
        [], [], [])
    all_txt_lens, all_input_ids, all_attn_masks, all_position_ids,\
        all_txt_type_ids = (
            [], [], [], [], [])
    # all_qa_targets = qa_targets
    # all_qar_targets = qar_targets
    all_obj_targets = []
    for i in range(len(num_bbs)):
        all_input_ids += input_ids[i]
        all_position_ids += position_ids[i]
        all_txt_lens += txt_lens[i]
        all_img_feats += img_feats[i]
        all_img_pos_feats += img_pos_feats[i]
        all_num_bbs += num_bbs[i]
        all_attn_masks += attn_masks[i]
        all_txt_type_ids += txt_type_ids[i]
        all_obj_targets += obj_targets[i]

    all_input_ids = pad_sequence(all_input_ids,
                                 batch_first=True, padding_value=0)
    all_position_ids = pad_sequence(all_position_ids,
                                    batch_first=True, padding_value=0)
    all_txt_type_ids = pad_sequence(all_txt_type_ids,
                                    batch_first=True, padding_value=0)
    all_attn_masks = pad_sequence(all_attn_masks,
                                  batch_first=True, padding_value=0)
    all_obj_targets = pad_sequence(all_obj_targets,
                                   batch_first=True, padding_value=0)
    all_qa_targets = torch.stack(qa_targets, dim=0)
    all_qar_targets = torch.stack(qar_targets, dim=0)

    batch_size = len(all_img_feats)
    num_bb = max(all_num_bbs)
    feat_dim = all_img_feats[0].size(1)
    pos_dim = all_img_pos_feats[0].size(1)
    all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(
            all_img_feats, all_img_pos_feats)):
        len_ = im.size(0)
        all_img_feat.data[i, :len_, :] = im.data
        all_img_pos_feat.data[i, :len_, :] = pos.data

    return (qids, all_input_ids, all_position_ids, all_txt_lens,
            all_txt_type_ids,
            all_img_feat, all_img_pos_feat, all_num_bbs,
            all_attn_masks, all_qa_targets, all_qar_targets, all_obj_targets)


class MlmDatasetForVCR(DetectFeatBertTokDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)
        if mask:
            input_ids_q, txt_labels_q = random_word(
                input_ids_q, self.v_range, self.mask)
        else:
            txt_labels_q = input_ids_q

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)
        if mask:
            input_ids_a, txt_labels_a = random_word(
                input_ids_a, self.v_range, self.mask)
        else:
            txt_labels_a = input_ids_a

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a
        txt_labels = txt_labels_q + [-1] + txt_labels_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)
            if mask:
                input_ids_r, txt_labels_r = random_word(
                    input_ids_r, self.v_range, self.mask)
            else:
                txt_labels_r = input_ids_r

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
            txt_labels += [-1] + txt_labels_r
        return input_ids, type_ids, txt_labels
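
    # Note (an assumption, inferred from the padding_value=-1 used in
    # mlm_collate_for_vcr below): a label of -1 marks positions excluded from
    # the MLM loss; random_word is expected to return -1 for unmasked tokens,
    # and the [SEP] joints above are set to -1 explicitly.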

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # txt inputs
        input_ids, type_ids, txt_labels = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        txt_labels = [-1] + txt_labels + [-1]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        txt_labels = torch.tensor(txt_labels)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                attn_masks, txt_labels)


def mlm_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats, attn_masks,
     txt_labels) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, txt_labels)


class MrmDatasetForVCR(DetectFeatBertTokDataset):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
        return input_ids, type_ids

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # image input features
        img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
        if not any(img_mask):
            # at least mask 1
            img_mask[0] = True
        img_mask = torch.tensor(img_mask)

        # text input
        input_ids, type_ids = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                attn_masks, img_mask)


def mrm_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats,
     attn_masks, img_masks) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data

    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, img_masks)


class DetectFeatBertTokDataset_for_mrc_vcr(DetectFeatBertTokDataset):
    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
                 max_txt_len=60, task="qa"):
        assert not (img_dir_gt is None and img_dir is None),\
            "img_dir_gt and img_dir cannot both be None"
        assert task == "qa" or task == "qar",\
            "VCR only allows two tasks: qa or qar"
        assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
        assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)
        super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
        if self.img_dir:
            self.img_dir = DetectFeatDir_for_mrc(img_dir)
        if self.img_dir_gt:
            self.img_dir_gt = DetectFeatDir_for_mrc(img_dir_gt)

    def _get_img_feat(self, fname_gt, fname):
        if self.img_dir and self.img_dir_gt:
            img_feat_gt, bb_gt,\
                img_soft_labels_gt = self.img_dir_gt[fname_gt]
            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)

            img_feat, bb, img_soft_labels = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)

            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
            img_soft_labels = torch.cat(
                [img_soft_labels_gt, img_soft_labels], dim=0)
            num_bb = img_feat.size(0)
        elif self.img_dir:
            img_feat, bb, img_soft_labels = self.img_dir[fname]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        elif self.img_dir_gt:
            img_feat, bb, img_soft_labels = self.img_dir_gt[fname_gt]
            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
            num_bb = img_feat.size(0)
        return img_feat, img_bb, img_soft_labels, num_bb


class MrcDatasetForVCR(DetectFeatBertTokDataset_for_mrc_vcr):
    def __init__(self, mask_prob, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_prob = mask_prob
        del self.txt_lens

    def _get_input_ids(self, txt_dump, mask=True):
        # text input
        input_ids_q = txt_dump['input_ids']
        type_ids_q = [0]*len(input_ids_q)

        answer_label = txt_dump['qa_target']
        assert answer_label >= 0, "answer_label < 0"

        input_ids_a = txt_dump['input_ids_as'][answer_label]
        type_ids_a = [2]*len(input_ids_a)

        input_ids = input_ids_q + [self.sep] + input_ids_a
        type_ids = type_ids_q + [0] + type_ids_a

        if self.task == "qar":
            rationale_label = txt_dump['qar_target']
            assert rationale_label >= 0, "rationale_label < 0"

            input_ids_r = txt_dump['input_ids_rs'][rationale_label]
            type_ids_r = [3]*len(input_ids_r)

            input_ids += [self.sep] + input_ids_r
            type_ids += [2] + type_ids_r
        return input_ids, type_ids

    def __getitem__(self, i):
        id_ = self.ids[i]
        txt_dump = self.db[id_]
        img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
            txt_dump['img_fname'][0], txt_dump['img_fname'][1])

        # image input features
        img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
        if not any(img_mask):
            # at least mask 1
            img_mask[0] = True
        img_mask = torch.tensor(img_mask)

        # text input
        input_ids, type_ids = self._get_input_ids(txt_dump)
        input_ids = [self.cls_] + input_ids + [self.sep]
        type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
        attn_masks = [1] * len(input_ids)
        position_ids = list(range(len(input_ids)))
        attn_masks += [1] * num_bb
        input_ids = torch.tensor(input_ids)
        position_ids = torch.tensor(position_ids)
        attn_masks = torch.tensor(attn_masks)
        type_ids = torch.tensor(type_ids)

        return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
                img_soft_labels, attn_masks, img_mask)


def mrc_collate_for_vcr(inputs):
    (input_ids, position_ids, type_ids, img_feats, img_pos_feats,
     img_soft_labels, attn_masks, img_masks
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    num_bbs = [f.size(0) for f in img_feats]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = pad_sequence(position_ids,
                                batch_first=True, padding_value=0)
    type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)

    batch_size = len(img_feats)
    num_bb = max(num_bbs)
    feat_dim = img_feats[0].size(1)
    soft_label_dim = img_soft_labels[0].size(1)
    pos_dim = img_pos_feats[0].size(1)
    img_feat = torch.zeros(batch_size, num_bb, feat_dim)
    img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
    img_soft_label = torch.zeros(batch_size, num_bb, soft_label_dim)
    for i, (im, pos, label) in enumerate(zip(img_feats,
                                             img_pos_feats,
                                             img_soft_labels)):
        len_ = im.size(0)
        img_feat.data[i, :len_, :] = im.data
        img_pos_feat.data[i, :len_, :] = pos.data
        img_soft_label.data[i, :len_, :] = label.data

    img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
    label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
        -1, soft_label_dim)
    return (input_ids, position_ids, type_ids, txt_lens,
            img_feat, img_pos_feat, num_bbs,
            attn_masks, (img_masks, label_targets))
@@ -0,0 +1,19 @@
"""
Visual entailment dataset
# NOTE: basically reuse VQA dataset
"""
from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate


class VeDataset(VqaDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(3, *args, **kwargs)


class VeEvalDataset(VqaEvalDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(3, *args, **kwargs)


ve_collate = vqa_collate
ve_eval_collate = vqa_eval_collate
@@ -0,0 +1,124 @@
"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip

from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index


def _get_vqa_target(example, num_answers):
    target = torch.zeros(num_answers)
    labels = example['target']['labels']
    scores = example['target']['scores']
    if labels and scores:
        target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
    return target
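
# A minimal worked example of the scatter_ call above (toy values assumed):
#   target = torch.zeros(4)
#   target.scatter_(0, torch.tensor([1, 3]), torch.tensor([0.6, 1.0]))
#   # -> tensor([0.0000, 0.6000, 0.0000, 1.0000])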


class VqaDataset(DetectFeatTxtTokDataset):
    """ NOTE: this handles the distributed setup internally """
    def __init__(self, num_answers, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_answers = num_answers

    def __getitem__(self, i):
        example = super().__getitem__(i)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        target = _get_vqa_target(example, self.num_answers)

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return input_ids, img_feat, img_pos_feat, attn_masks, target


def vqa_collate(inputs):
    (input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)

    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    targets = torch.stack(targets, dim=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets}
    return batch


class VqaEvalDataset(VqaDataset):
    def __getitem__(self, i):
        qid = self.ids[i]
        example = DetectFeatTxtTokDataset.__getitem__(self, i)
        img_feat, img_pos_feat, num_bb = self._get_img_feat(
            example['img_fname'])

        # text input
        input_ids = example['input_ids']
        input_ids = self.txt_db.combine_inputs(input_ids)

        if 'target' in example:
            target = _get_vqa_target(example, self.num_answers)
        else:
            target = None

        attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)

        return qid, input_ids, img_feat, img_pos_feat, attn_masks, target


def vqa_eval_collate(inputs):
    (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
     ) = map(list, unzip(inputs))

    txt_lens = [i.size(0) for i in input_ids]

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
                                ).unsqueeze(0)
    attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
    if targets[0] is None:
        targets = None
    else:
        targets = torch.stack(targets, dim=0)

    num_bbs = [f.size(0) for f in img_feats]
    img_feat = pad_tensors(img_feats, num_bbs)
    img_pos_feat = pad_tensors(img_pos_feats, num_bbs)

    bs, max_tl = input_ids.size()
    out_size = attn_masks.size(1)
    gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)

    batch = {'qids': qids,
             'input_ids': input_ids,
             'position_ids': position_ids,
             'img_feat': img_feat,
             'img_pos_feat': img_pos_feat,
             'attn_masks': attn_masks,
             'gather_index': gather_index,
             'targets': targets}
    return batch
@@ -0,0 +1,53 @@
""" Image Text Retrieval evaluation helper """
import torch


@torch.no_grad()
def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts):
    # image retrieval
    img2j = {i: j for j, i in enumerate(img_ids)}
    _, rank_txt = score_matrix.topk(10, dim=1)
    gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]]
                                 for txt_id in txt_ids],
                                ).to(rank_txt.device
                                     ).unsqueeze(1).expand_as(rank_txt)
    rank = (rank_txt == gt_img_j).nonzero()
    if rank.numel():
        ir_r1 = (rank < 1).sum().item() / len(txt_ids)
        ir_r5 = (rank < 5).sum().item() / len(txt_ids)
        ir_r10 = (rank < 10).sum().item() / len(txt_ids)
    else:
        ir_r1, ir_r5, ir_r10 = 0, 0, 0

    # text retrieval
    txt2i = {t: i for i, t in enumerate(txt_ids)}
    _, rank_img = score_matrix.topk(10, dim=0)
    tr_r1, tr_r5, tr_r10 = 0, 0, 0
    for j, img_id in enumerate(img_ids):
        gt_is = [txt2i[t] for t in img2txts[img_id]]
        ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is]
        rank = min([10] + [r.item() for r in ranks if r.numel()])
        if rank < 1:
            tr_r1 += 1
        if rank < 5:
            tr_r5 += 1
        if rank < 10:
            tr_r10 += 1
    tr_r1 /= len(img_ids)
    tr_r5 /= len(img_ids)
    tr_r10 /= len(img_ids)

    tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3
    ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_log = {'txt_r1': tr_r1,
                'txt_r5': tr_r5,
                'txt_r10': tr_r10,
                'txt_r_mean': tr_mean,
                'img_r1': ir_r1,
                'img_r5': ir_r5,
                'img_r10': ir_r10,
                'img_r_mean': ir_mean,
                'r_mean': r_mean}
    return eval_log
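
# A minimal usage sketch (toy inputs assumed; both dimensions of score_matrix
# must be at least 10 for the topk(10) calls above):
#   score_matrix = torch.randn(len(txt_ids), len(img_ids))
#   log = itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts)
#   print(log['r_mean'])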
@@ -0,0 +1,62 @@
"""
copied from official NLVR2 github
python eval/nlvr2.py <output.csv> <annotation.json>
"""
import json
import sys

# Load the predictions file. Assume it is a CSV.
predictions = {}
for line in open(sys.argv[1]).readlines():
    if line:
        splits = line.strip().split(",")
        # We assume identifiers are in the format "split-####-#-#.png".
        identifier = splits[0]
        prediction = splits[1]
        predictions[identifier] = prediction

# Load the labeled examples.
labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line]

# Make sure there is a prediction for every labeled example.
# If not, identify the ones that are missing, and exit.
total_num = len(labeled_examples)
if len(predictions) < total_num:
    print("Some predictions are missing!")
    print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num))

    for example in labeled_examples:
        lookup = example["identifier"]
        if lookup not in predictions:
            print("Missing prediction for item " + str(lookup))
    exit()

# Get the precision by iterating through the examples and checking the value
# that was predicted.
# Also update the "consistency" dictionary that keeps track of whether all
# predictions for a given sentence were correct.
num_correct = 0.
consistency_dict = {}

for example in labeled_examples:
    anon_label = example["identifier"].split("-")
    anon_label[2] = ''
    anon_label = '-'.join(anon_label)
    if anon_label not in consistency_dict:
        consistency_dict[anon_label] = True
    lookup = example["identifier"]
    prediction = predictions[lookup]
    if prediction.lower() == example["label"].lower():
        num_correct += 1.
    else:
        consistency_dict[anon_label] = False

# Calculate consistency.
num_consistent = 0.
unique_sentence = len(consistency_dict)
for identifier, consistent in consistency_dict.items():
    if consistent:
        num_consistent += 1

# Report values.
print("accuracy=" + str(num_correct / total_num))
print("consistency=" + str(num_consistent / unique_sentence))
@@ -0,0 +1,218 @@
# coding=utf-8
# copied from huggingface github
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT for Referring Expression Comprehension Evaluation"""
import argparse
import json
import os
from os.path import exists
from time import time

import torch
from torch.utils.data import DataLoader

# to be deprecated once upgraded to 1.2
# from torch.utils.data.distributed import DistributedSampler
from data import DistributedSampler

from apex import amp
from horovod import torch as hvd

from data import (ReImageFeatDir, ReferringExpressionEvalDataset,
                  re_eval_collate, PrefetchLoader)
from model import BertForReferringExpressionComprehension

from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import Struct


def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info(f"device: {device}, n_gpu: {n_gpu}, rank: {hvd.rank()}, "
                f"16-bits training: {opts.fp16}")

    hps_file = f'{opts.output_dir}/log/hps.json'
    model_opts = json.load(open(hps_file))
    if 'mlp' not in model_opts:
        model_opts['mlp'] = 1
    model_opts = Struct(model_opts)

    # Prepro txt_dbs
    txt_dbs = opts.txt_db.split(':')

    # Prepro model
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_epoch_{opts.checkpoint}.pt'
    checkpoint = torch.load(ckpt_file)
    bert_model = json.load(open(f'{txt_dbs[0]}/meta.json'))['bert']
    model = BertForReferringExpressionComprehension.from_pretrained(
        bert_model, img_dim=2048, mlp=model_opts.mlp,
        state_dict=checkpoint
    )
    if model_opts.cut_bert != -1:
        # cut some layers of BERT
        model.bert.encoder.layer = torch.nn.ModuleList(
            model.bert.encoder.layer[:model_opts.cut_bert]
        )
    model.to(device)

    if opts.fp16:
        model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    # load DBs and image dirs
    eval_img_dir = ReImageFeatDir(opts.img_dir)
    for txt_db in txt_dbs:
        print(f'Evaluating {txt_db}')
        eval_dataset = ReferringExpressionEvalDataset(txt_db, eval_img_dir,
                                                      max_txt_len=-1)
        eval_sampler = DistributedSampler(eval_dataset, num_replicas=n_gpu,
                                          rank=rank, shuffle=False)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=opts.batch_size,
                                     num_workers=opts.n_workers,
                                     pin_memory=opts.pin_mem,
                                     collate_fn=re_eval_collate)
        eval_dataloader = PrefetchLoader(eval_dataloader)

        # evaluate
        val_log, results = validate(model, eval_dataloader)

        # save
        result_dir = f'{opts.output_dir}/results_test'
        if not exists(result_dir) and rank == 0:
            os.makedirs(result_dir)

        # dummy sync
        _ = None
        all_gather_list(_)
        db_split = txt_db.split('/')[-1].split('-')[0]  # refcoco+_val_large
        img_dir = opts.img_dir.split('/')[-1]  # visual_grounding_coco_gt
        if n_gpu > 1:
            with open(f'{opts.output_dir}/results_test/'
                      f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
                      f'_rank{rank}.json',
                      'w') as f:
                json.dump(results, f)
            # dummy sync
            _ = None
            all_gather_list(_)

        # join results
        if n_gpu > 1:
            results = []
            for r in range(n_gpu):
                results.extend(json.load(open(
                    f'{opts.output_dir}/results_test/'
                    f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
                    f'_rank{r}.json')))
        if rank == 0:
            with open(f'{opts.output_dir}/results_test/'
                      f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
                      f'_all.json', 'w') as f:
                json.dump(results, f)

    # print
    print(f'{opts.output_dir}/results_test')


@torch.no_grad()
def validate(model, val_dataloader):
    LOGGER.info("start running evaluation.")
    model.eval()
    tot_score = 0
    n_ex = 0
    st = time()
    predictions = []
    for i, batch in enumerate(val_dataloader):
        # inputs
        (*batch_inputs, tgt_box_list, obj_boxes_list, sent_ids) = batch

        # scores (n, max_num_bb)
        scores = model(*batch_inputs, targets=None, compute_loss=False)
        ixs = torch.argmax(scores, 1).cpu().detach().numpy()  # (n, )

        # pred_boxes
        for ix, obj_boxes, tgt_box, sent_id in \
                zip(ixs, obj_boxes_list, tgt_box_list, sent_ids):
            pred_box = obj_boxes[ix]
            predictions.append({'sent_id': sent_id,
                                'pred_box': pred_box.tolist(),
                                'tgt_box': tgt_box.tolist()})
            if (val_dataloader.loader.dataset.computeIoU(pred_box, tgt_box)
                    > .5):
                tot_score += 1
            n_ex += 1

    tot_time = time()-st
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    val_acc = tot_score / n_ex
    val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation ({n_ex} sents) finished in "
                f"{int(tot_time)} seconds"
                f", accuracy: {val_acc*100:.2f}%")

    # summarize
    results = {'acc': val_acc, 'predictions': predictions}
    return val_log, results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument('--txt_db',
                        default=None, type=str,
                        help="The input train corpus. (LMDB)")
    parser.add_argument('--img_dir',
                        default=None, type=str,
                        help="The input train images.")
    parser.add_argument('--checkpoint',
                        default=None, type=str,
                        help="pretrained model (can take 'google-bert')")
    parser.add_argument('--batch_size',
                        default=256, type=int,
                        help="number of sentences per batch")
    parser.add_argument('--output_dir',
                        default=None, type=str,
                        help="The output directory where the model "
                             "checkpoints will be written.")

    # Device parameters
    parser.add_argument('--fp16',
                        action='store_true',
                        help="whether to use fp16 float precision instead of "
                             "32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true',
                        help="pin memory")

    args = parser.parse_args()

    main(args)
@@ -0,0 +1,268 @@
"""run inference of VCR for submission"""
import argparse
import json
import os
from os.path import exists
from time import time

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from apex import amp
from horovod import torch as hvd

from data import (DetectFeatLmdb, VcrEvalDataset, vcr_eval_collate,
                  PrefetchLoader)
from torch.utils.data.distributed import DistributedSampler
from model import BertForVisualCommonsenseReasoning

from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import NoOp, Struct

NUM_SPECIAL_TOKENS = 81


def load_img_feat(dir_list, path2imgdir, opts):
    dir_ = dir_list.split(";")
    assert len(dir_) <= 2, "More than two img_dirs found"
    img_dir_gt, img_dir = None, None
    gt_dir_path, dir_path = "", ""
    for d in dir_:
        if "gt" in d:
            gt_dir_path = d
        else:
            dir_path = d
    if gt_dir_path != "":
        img_dir_gt = path2imgdir.get(gt_dir_path, None)
        if img_dir_gt is None:
            img_dir_gt = DetectFeatLmdb(gt_dir_path, -1,
                                        opts.max_bb, opts.min_bb, 100,
                                        opts.compressed_db)
            path2imgdir[gt_dir_path] = img_dir_gt
    if dir_path != "":
        img_dir = path2imgdir.get(dir_path, None)
        if img_dir is None:
            img_dir = DetectFeatLmdb(dir_path, opts.conf_th,
                                     opts.max_bb, opts.min_bb, opts.num_bb,
                                     opts.compressed_db)
            path2imgdir[dir_path] = img_dir
    return img_dir, img_dir_gt, path2imgdir
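
# Hedged usage note: dir_list is one or two LMDB paths joined by ";", where
# the path containing "gt" holds ground-truth-box features, e.g. (paths
# assumed for illustration) "/img/vcr_gt_val;/img/vcr_val".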


def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    hps_file = f'{opts.output_dir}/log/hps.json'
    model_opts = Struct(json.load(open(hps_file)))

    path2imgdir = {}
    # load DBs and image dirs
    val_img_dir, val_img_dir_gt, path2imgdir = load_img_feat(
        opts.img_dir, path2imgdir, model_opts)
    eval_dataset = VcrEvalDataset("test", opts.txt_db,
                                  val_img_dir_gt, val_img_dir,
                                  max_txt_len=-1)

    # Prepare model
    bert_model = json.load(open(f'{opts.txt_db}/meta.json'))['bert']
    model = BertForVisualCommonsenseReasoning.from_pretrained(
        bert_model, img_dim=2048, obj_cls=False,
        state_dict={})
    model.init_type_embedding()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    checkpoint = torch.load(ckpt_file)
    state_dict = checkpoint.get('model_state', checkpoint)
    matched_state_dict = {}
    unexpected_keys = set()
    missing_keys = set()
    for name, param in model.named_parameters():
        missing_keys.add(name)
    for key, data in state_dict.items():
        if key in missing_keys:
            matched_state_dict[key] = data
            missing_keys.remove(key)
        else:
            unexpected_keys.add(key)
    print("Unexpected_keys:", list(unexpected_keys))
    print("Missing_keys:", list(missing_keys))
    model.load_state_dict(matched_state_dict, strict=False)
    if model_opts.cut_bert != -1:
        # cut some layers of BERT
        model.bert.encoder.layer = torch.nn.ModuleList(
            model.bert.encoder.layer[:model_opts.cut_bert])
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    sampler = DistributedSampler(
        eval_dataset, num_replicas=n_gpu, rank=rank)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=opts.batch_size,
                                 sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=vcr_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    val_log, results = evaluate(model, eval_dataloader)
    result_dir = f'{opts.output_dir}/results_{opts.split}'
    if not exists(result_dir) and rank == 0:
        os.makedirs(result_dir)
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_rank{rank}.json',
                  'w') as f:
            json.dump(results, f)
        # dummy sync
        _ = None
        all_gather_list(_)
    # join results
    if n_gpu > 1:
        results = []
        for r in range(n_gpu):
            results.extend(json.load(open(
                f'{result_dir}/'
                f'results_{opts.checkpoint}_rank{r}.json')))
    if rank == 0:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_all.json', 'w') as f:
            json.dump(results, f)


def compute_accuracies(out_qa, labels_qa, out_qar, labels_qar):
    outputs_qa = out_qa.max(dim=-1)[1]
    outputs_qar = out_qar.max(dim=-1)[1]
    matched_qa = outputs_qa.squeeze() == labels_qa.squeeze()
    matched_qar = outputs_qar.squeeze() == labels_qar.squeeze()
    matched_joined = matched_qa & matched_qar
    n_correct_qa = matched_qa.sum().item()
    n_correct_qar = matched_qar.sum().item()
    n_correct_joined = matched_joined.sum().item()
    return n_correct_qa, n_correct_qar, n_correct_joined
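
# A tiny worked example of the joint metric above (toy tensors assumed): with
# out_qa = [[0.9, 0.1, 0, 0]], labels_qa = [0], out_qar = [[0, 1, 0, 0]] and
# labels_qar = [1], all three counts are 1, since the joint score only counts
# examples where both the answer and the rationale are predicted correctly.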


@torch.no_grad()
def evaluate(model, val_loader):
    if hvd.rank() == 0:
        val_pbar = tqdm(total=len(val_loader))
    else:
        val_pbar = NoOp()
    LOGGER.info("start running evaluation ...")
    model.eval()
    val_qa_loss, val_qar_loss = 0, 0
    tot_qa_score, tot_qar_score, tot_score = 0, 0, 0
    n_ex = 0
    st = time()
    results = {}
    for i, batch in enumerate(val_loader):
        qids, *inputs, qa_targets, qar_targets, _ = batch
        scores = model(
            *inputs, targets=None, compute_loss=False)
        scores = scores.view(len(qids), -1)
        if torch.max(qa_targets) > -1:
            vcr_qa_loss = F.cross_entropy(
                scores[:, :4], qa_targets.squeeze(-1), reduction="sum")
            if scores.shape[1] > 8:
                qar_scores = []
                for batch_id in range(scores.shape[0]):
                    answer_ind = qa_targets[batch_id].item()
                    qar_index = [4+answer_ind*4+i
                                 for i in range(4)]
                    qar_scores.append(scores[batch_id, qar_index])
                qar_scores = torch.stack(qar_scores, dim=0)
            else:
                qar_scores = scores[:, 4:]
            vcr_qar_loss = F.cross_entropy(
                qar_scores, qar_targets.squeeze(-1), reduction="sum")
            val_qa_loss += vcr_qa_loss.item()
            val_qar_loss += vcr_qar_loss.item()

            curr_qa_score, curr_qar_score, curr_score = compute_accuracies(
                scores[:, :4], qa_targets, qar_scores, qar_targets)
            tot_qar_score += curr_qar_score
            tot_qa_score += curr_qa_score
            tot_score += curr_score
        for qid, score in zip(qids, scores):
            results[qid] = score.cpu().tolist()
        n_ex += len(qids)
        val_pbar.update(1)
    val_qa_loss = sum(all_gather_list(val_qa_loss))
    val_qar_loss = sum(all_gather_list(val_qar_loss))
    tot_qa_score = sum(all_gather_list(tot_qa_score))
    tot_qar_score = sum(all_gather_list(tot_qar_score))
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_qa_loss /= n_ex
    val_qar_loss /= n_ex
    val_qa_acc = tot_qa_score / n_ex
    val_qar_acc = tot_qar_score / n_ex
    val_acc = tot_score / n_ex
    val_log = {'valid/vcr_qa_loss': val_qa_loss,
               'valid/vcr_qar_loss': val_qar_loss,
               'valid/acc_qa': val_qa_acc,
               'valid/acc_qar': val_qar_acc,
               'valid/acc': val_acc,
               'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score_qa: {val_qa_acc*100:.2f} "
                f"score_qar: {val_qar_acc*100:.2f} "
                f"score: {val_acc*100:.2f} ")
    return val_log, results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--txt_db",
                        default=None, type=str,
                        help="The input train corpus. (LMDB)")
    parser.add_argument("--img_dir",
                        default=None, type=str,
                        help="The input train images.")
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--split",
                        default="test", type=str,
                        help="The input split")
    parser.add_argument("--checkpoint",
                        default=None, type=str,
                        help="pretrained model (can take 'google-bert')")
    parser.add_argument("--batch_size",
                        default=10, type=int,
                        help="number of examples in a batch")
    parser.add_argument(
        "--output_dir", default=None, type=str,
        help="The output directory where the model checkpoints will be "
             "written.")
    # device parameters
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true',
                        help="pin memory")

    args = parser.parse_args()

    main(args)
@@ -0,0 +1,180 @@
"""run inference of VQA for submission"""
import argparse
import json
import os
from os.path import exists
from time import time

import torch
from torch.utils.data import DataLoader

from apex import amp
from horovod import torch as hvd
import numpy as np
from cytoolz import concat

from data import (TokenBucketSampler, PrefetchLoader,
                  DetectFeatLmdb, TxtTokLmdb, VqaEvalDataset,
                  vqa_eval_collate)
from model import UniterForVisualQuestionAnswering

from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import Struct
from utils.const import BUCKET_SIZE, IMG_DIM


def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    hps_file = f'{opts.output_dir}/log/hps.json'
    model_opts = Struct(json.load(open(hps_file)))

    # train_examples = None
    ans2label_file = f'{opts.output_dir}/ckpt/ans2label.json'
    ans2label = json.load(open(ans2label_file))
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(opts.img_db,
                                 model_opts.conf_th, model_opts.max_bb,
                                 model_opts.min_bb, model_opts.num_bb,
                                 opts.compressed_db)
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db)

    # Prepare model
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    checkpoint = torch.load(ckpt_file)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        f'{opts.output_dir}/log/model.json', checkpoint,
        img_dim=IMG_DIM, num_answer=len(ans2label))
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    sampler = TokenBucketSampler(eval_dataset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=opts.batch_size, droplast=False)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=vqa_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    val_log, results, logits = evaluate(model, eval_dataloader, label2ans,
                                        opts.save_logits)
    result_dir = f'{opts.output_dir}/results_test'
    if not exists(result_dir) and rank == 0:
        os.makedirs(result_dir)

    all_results = list(concat(all_gather_list(results)))
    if opts.save_logits:
        all_logits = {}
        for id2logit in all_gather_list(logits):
            all_logits.update(id2logit)
    if hvd.rank() == 0:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_all.json', 'w') as f:
            json.dump(all_results, f)
        if opts.save_logits:
            np.savez(f'{result_dir}/logits_{opts.checkpoint}_all.npz',
                     **all_logits)


@torch.no_grad()
def evaluate(model, eval_loader, label2ans, save_logits=False):
    LOGGER.info("start running evaluation...")
    model.eval()
    n_ex = 0
    st = time()
    results = []
    logits = {}
    for i, batch in enumerate(eval_loader):
        qids = batch['qids']
        scores = model(batch, compute_loss=False)
        answers = [label2ans[i]
                   for i in scores.max(dim=-1, keepdim=False
                                       )[1].cpu().tolist()]
        for qid, answer in zip(qids, answers):
            results.append({'answer': answer, 'question_id': int(qid)})
        if save_logits:
            scores = scores.cpu()
            for j, qid in enumerate(qids):
                logits[qid] = scores[j].half().numpy()
        if i % 100 == 0 and hvd.rank() == 0:
            n_results = len(results)
            n_results *= hvd.size()   # an approximation to avoid hangs
            LOGGER.info(f'{n_results}/{len(eval_loader.dataset)} '
                        'answers predicted')
        n_ex += len(qids)
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_log = {'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"evaluation finished in {int(tot_time)} seconds "
                f"at {int(n_ex/tot_time)} examples per second")
    return val_log, results, logits


def compute_score_with_logits(logits, labels):
    logits = torch.max(logits, 1)[1]  # argmax
    one_hots = torch.zeros(*labels.size(), device=labels.device)
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    scores = (one_hots * labels)
    return scores
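
# A brief note with a toy example (values assumed): for VQA-style soft labels,
# the score of a prediction is the soft score of the chosen answer, e.g. with
# logits = [[2.0, 0.5]] and labels = [[0.9, 0.0]], the argmax picks answer 0
# and compute_score_with_logits returns [[0.9, 0.0]], which sums to 0.9.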


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--txt_db",
                        default=None, type=str,
                        help="The input train corpus. (LMDB)")
    parser.add_argument("--img_db",
                        default=None, type=str,
                        help="The input train images.")
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--checkpoint",
                        default=None, type=str,
                        help="pretrained model (can take 'google-bert')")
    parser.add_argument("--batch_size",
                        default=8192, type=int,
                        help="number of tokens in a batch")

    parser.add_argument(
        "--output_dir", default=None, type=str,
        help="The output directory where the model checkpoints will be "
             "written.")

    parser.add_argument("--save_logits", action='store_true',
                        help="Whether to save logits (for making ensemble)")

    # Prepro parameters

    # device parameters
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true',
                        help="pin memory")

    args = parser.parse_args()

    # options safe guard
    # TODO

    main(args)
@@ -0,0 +1,71 @@
# Supports ablation study of the following:
# 1) scratch
# 2) bert
# 3) mrfr
# 4) mlm
# 5) itm
# 6) mlm_itm
# 7) mlm_mrfr_itm
# 8) mlm_mrc_itm
# 9) mlm_mrckl_itm
# 10) mlm_mrfr_mrc_itm
# 11) mlm_mrfr_mrckl_itm
# 12) mlm_mrfr_mrckl_itm_jrm
# 13) mlm_mrfr_mrckl_itm_jrm+

ablation_pretrained_model=$1

case $ablation_pretrained_model in
    scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+)
        echo running $ablation_pretrained_model ...;;
    *)
        echo "$ablation_pretrained_model" not supported.;
        exit 1;
esac
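
# Example invocation (script name assumed):
#   bash train_re_ablation.sh mlm_itm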

if [ "$ablation_pretrained_model" == "mrfr" ]; then
    cut_bert=1
else
    cut_bert=-1
fi

case $ablation_pretrained_model in
    scratch)
        cut_bert=1;
        checkpoint="scratch";;
    bert)
        cut_bert=1;
        checkpoint="google-bert";;
    mrfr)
        cut_bert=1;
        checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;;
    *)
        cut_bert=-1;
        checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;;
esac

horovodrun -np 1 -H localhost:1 \
    python train_re.py \
        --train_txt_db /db/refcoco+_train_base-cased.db \
        --train_img_dir /img/visual_grounding_coco_gt \
        --val_txt_db /db/refcoco+_val_base-cased.db \
        --val_img_dir /img/visual_grounding_det_coco \
        --checkpoint ${checkpoint} \
        --cut_bert ${cut_bert} \
        --output_dir /storage/refcoco+/ablation_"${ablation_pretrained_model}" \
        --max_txt_len 60 \
        --train_batch_size 128 \
        --val_batch_size 128 \
        --learning_rate 8e-5 \
        --optim adamw \
        --betas 0.9 0.98 \
        --weight_decay 0.01 \
        --dropout 0.1 \
        --grad_norm 2.0 \
        --decay linear \
        --num_train_steps 24000 \
        --warmup_steps 1500 \
        --gradient_accumulation_steps 1 \
        --seed 24 \
        --mlp 1 \
        --fp16
@@ -0,0 +1,38 @@
# Supports ablation study of the following:
# 1) scratch
# 2) bert
# 3) mrfr
# 4) mlm
# 5) itm
# 6) mlm_itm
# 7) mlm_mrfr_itm
# 8) mlm_mrc_itm
# 9) mlm_mrckl_itm
# 10) mlm_mrfr_mrc_itm
# 11) mlm_mrfr_mrckl_itm
# 12) mlm_mrfr_mrckl_itm_jrm
# 13) mlm_mrfr_mrckl_itm_jrm+

ablation_pretrained_model=$1

case $ablation_pretrained_model in
    scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+)
        echo running $ablation_pretrained_model ...;;
    *)
        echo "$ablation_pretrained_model" not supported.;
        exit 1;
esac
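
# Example invocation (script name assumed):
#   bash eval_re_ablation.sh mlm_itm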

horovodrun -np 1 -H localhost:1 \
    python eval_re.py \
        --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
        --img_dir /img/visual_grounding_coco_gt \
        --output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \
        --checkpoint best

horovodrun -np 1 -H localhost:1 \
    python eval_re.py \
        --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
        --img_dir /img/visual_grounding_det_coco \
        --output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \
        --checkpoint best