logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
4059059c1b
  1. 4
      __init__.py
  2. 19
      config/bert_base.json
  3. 20
      config/coco_eval_config.json
  4. 43
      config/coco_ft_config.json
  5. 22
      config/flickr30k_eval_config.json
  6. 38
      config/flickr30k_ft_config.json
  7. 16
      config/img_base.json
  8. 191
      config/pretrain-alldata-base.json
  9. BIN
      data/model/resnet101_faster_rcnn_final.pth
  10. BIN
      data/model/uniter-base.pt
  11. 0
      detector/__init__.py
  12. BIN
      detector/__pycache__/__init__.cpython-38.pyc
  13. BIN
      detector/__pycache__/bbox_transform.cpython-38.pyc
  14. BIN
      detector/__pycache__/faster_rcnn.cpython-38.pyc
  15. BIN
      detector/__pycache__/generate_anchors.cpython-38.pyc
  16. BIN
      detector/__pycache__/rpn.cpython-38.pyc
  17. 75
      detector/bbox_transform.py
  18. 478
      detector/faster_rcnn.py
  19. 105
      detector/generate_anchors.py
  20. 136
      detector/rpn.py
  21. BIN
      dvl/__pycache__/const.cpython-38.pyc
  22. 3
      dvl/const.py
  23. 366
      dvl/data/itm.py
  24. 592
      dvl/data/itm_pre.py
  25. 390
      dvl/data/mlm.py
  26. 263
      dvl/data/mrm.py
  27. 145
      dvl/data/vqa.py
  28. 66
      dvl/hn.py
  29. 154
      dvl/indexer/faiss_indexers.py
  30. 0
      dvl/models/__init__.py
  31. BIN
      dvl/models/__pycache__/__init__.cpython-38.pyc
  32. BIN
      dvl/models/__pycache__/bi_encoder.cpython-38.pyc
  33. 757
      dvl/models/bi_encoder.py
  34. 176
      dvl/options.py
  35. 209
      dvl/trainer.py
  36. 234
      dvl/utils.py
  37. 52
      lightningdot.py
  38. 3
      requirements.txt
  39. 22
      uniter_model/Dockerfile
  40. 21
      uniter_model/LICENSE
  41. 89
      uniter_model/README.md
  42. 36
      uniter_model/config/config-vcr-bert-2gpu.json
  43. 11
      uniter_model/config/eval-itm-coco.json
  44. 11
      uniter_model/config/eval-itm-flickr.json
  45. 53
      uniter_model/config/hps-itm.json
  46. 25
      uniter_model/config/hps-refcoco+.json
  47. 26
      uniter_model/config/hps-refcoco+_conceptual.json
  48. 26
      uniter_model/config/hps-refcoco+_conceptual_large_weak.json
  49. 29
      uniter_model/config/hps-refcoco+_conceptual_rank.json
  50. 26
      uniter_model/config/hps-refcoco.json
  51. 31
      uniter_model/config/hps-ve-large.json
  52. 31
      uniter_model/config/hps-ve.json
  53. 30
      uniter_model/config/hps-vqa.json
  54. 47
      uniter_model/config/itm-coco-base.json
  55. 45
      uniter_model/config/itm-ot-base-16gpus.json
  56. 45
      uniter_model/config/itm-ot-base-16gpus_philly.json
  57. 42
      uniter_model/config/pretrain-gqa-alltask.json
  58. 42
      uniter_model/config/pretrain-mlm-coco.json
  59. 53
      uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json
  60. 42
      uniter_model/config/pretrain-mrckl-coco.json
  61. 42
      uniter_model/config/pretrain-mrfr-coco.json
  62. 43
      uniter_model/config/pretrain-mrm-nce-coco.json
  63. 38
      uniter_model/config/pretrain-vcr-alltask.json
  64. 40
      uniter_model/config/train-itm-debug.json
  65. 38
      uniter_model/config/train-itm-flickr-base-hnv2.json
  66. 40
      uniter_model/config/train-itm-flickr-base.json
  67. 37
      uniter_model/config/train-nlvr2-base-1gpu.json
  68. 31
      uniter_model/config/train-ve-base-2gpu.json
  69. 31
      uniter_model/config/train-ve-large-2gpu.json
  70. 35
      uniter_model/config/train-vqa-base-2gpu.json
  71. 14
      uniter_model/config/uniter-base.json
  72. 13
      uniter_model/config/uniter-large.json
  73. 27
      uniter_model/data/__init__.py
  74. 283
      uniter_model/data/data.py
  75. 572
      uniter_model/data/itm.py
  76. 138
      uniter_model/data/loader.py
  77. 360
      uniter_model/data/mlm.py
  78. 287
      uniter_model/data/mrm.py
  79. 136
      uniter_model/data/mrm_nce.py
  80. 218
      uniter_model/data/nlvr2.py
  81. 319
      uniter_model/data/re.py
  82. 116
      uniter_model/data/sampler.py
  83. BIN
      uniter_model/data/test_data/input0.txt
  84. BIN
      uniter_model/data/test_data/input1.txt
  85. BIN
      uniter_model/data/test_data/input2.txt
  86. BIN
      uniter_model/data/test_data/input3.txt
  87. BIN
      uniter_model/data/test_data/input4.txt
  88. BIN
      uniter_model/data/test_data/input5.txt
  89. BIN
      uniter_model/data/test_data/input6.txt
  90. BIN
      uniter_model/data/test_data/input7.txt
  91. 725
      uniter_model/data/vcr.py
  92. 19
      uniter_model/data/ve.py
  93. 124
      uniter_model/data/vqa.py
  94. 53
      uniter_model/eval/itm.py
  95. 62
      uniter_model/eval/nlvr2.py
  96. 218
      uniter_model/eval_re.py
  97. 268
      uniter_model/eval_vcr.py
  98. 180
      uniter_model/eval_vqa.py
  99. 71
      uniter_model/experiments/ablation_refcoco+.sh
  100. 38
      uniter_model/experiments/eval_ablation_refcoco+.sh

4
__init__.py

@ -14,5 +14,5 @@
from .lightningdot import LightningDOT
def lightningdot(modality: str):
return LightningDOT(modality)
def lightningdot(model_name: str, modality: str):
return LightningDOT(model_name, modality)

19
config/bert_base.json

@ -0,0 +1,19 @@
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522
}

20
config/coco_eval_config.json

@ -0,0 +1,20 @@
{
"img_model_type": "uniter-base",
"txt_model_type": "bert-base",
"txt_model_config": "bert-base-cased",
"img_model_config": "./config/img_base.json",
"itm_global_file":"./data/meta/coco_meta.json",
"seed": 42,
"output_dir": "/storage/debug-eval",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"project_dim": 768,
"test_txt_db": "./data/db/itm_coco_test_base-cased.db",
"test_img_db": "./data/img/coco_val2014/",
"project_name": "itm-debug",
"n_workers": 4,
"fp16": true
}

43
config/coco_ft_config.json

@ -0,0 +1,43 @@
{
"txt_model_type": "bert-base",
"txt_model_config": "bert-base-cased",
"img_model_type": "uniter-base",
"img_model_config": "./config/img_base.json",
"img_checkpoint": "./data/model/uniter-base.pt",
"itm_global_file":"./data/meta/coco_meta.json",
"train_batch_size": 64,
"val_batch_size": 256,
"gradient_accumulation_steps": 1,
"learning_rate": 2e-05,
"warmup_steps": 100,
"valid_steps": 500,
"num_train_epochs": 20,
"seed": 42,
"output_dir": "/storage/debug_coco",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"project_dim": 768,
"train_txt_dbs": [
"./data/db/itm_coco_train_base-cased.db",
"./data/db/itm_coco_restval_base-cased.db"
],
"train_img_dbs": [
"./data/img/coco_train2014/",
"./data/img/coco_val2014"
],
"val_txt_db": "./data/db/itm_coco_val_base-cased.db",
"val_img_db": "./data/img/coco_val2014/",
"test_txt_db": "./data/db/itm_coco_test_base-cased.db",
"test_img_db": "./data/img/coco_val2014/",
"project_name": "itm-debug",
"num_hard_negatives": 0,
"hard_negatives_sampling": "none",
"inf_minibatch_size": 0,
"n_workers": 0,
"fp16": true,
"compressed_db": false,
"pin_mem": true
}

22
config/flickr30k_eval_config.json

@ -0,0 +1,22 @@
{
"img_model_type": "uniter-base",
"txt_model_type": "bert-base",
"txt_model_config": "bert-base-cased",
"img_model_config": "./config/img_base.json",
"itm_global_file":"./data/meta/flickr_meta.json",
"seed": 42,
"output_dir": "/storage/debug-eval",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"project_dim": 768,
"val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "./data/img/flickr30k/",
"test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "./data/img/flickr30k/",
"project_name": "itm-debug",
"n_workers": 4,
"fp16": true
}

38
config/flickr30k_ft_config.json

@ -0,0 +1,38 @@
{
"txt_model_type": "bert-base",
"txt_model_config": "bert-base-cased",
"img_model_type": "uniter-base",
"img_model_config": "./config/img_base.json",
"img_checkpoint": "./data/model/uniter-base.pt",
"itm_global_file":"./data/meta/flickr_meta.json",
"train_batch_size": 64,
"gradient_accumulation_steps": 1,
"learning_rate": 2e-05,
"warmup_steps": 100,
"valid_steps": 500,
"num_train_epochs": 15,
"seed": 42,
"output_dir": "/storage/debug_flickr",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"project_dim": 768,
"train_txt_dbs": [
"./data/db/itm_flickr30k_train_base-cased.db"
],
"train_img_dbs": [
"./data/img/flickr30k/"
],
"val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "./data/img/flickr30k/",
"test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "./data/img/flickr30k/",
"project_name": "itm-debug",
"num_hard_negatives": 0,
"hard_negatives_sampling": "none",
"inf_minibatch_size": 0,
"n_workers": 0,
"fp16": true
}

16
config/img_base.json

@ -0,0 +1,16 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 28996,
"output_hidden_states": false
}

191
config/pretrain-alldata-base.json

@ -0,0 +1,191 @@
{
"compressed_db": false,
"txt_model_type": "bert-base",
"txt_model_config": "bert-base-cased",
"img_model_type": "uniter-base",
"img_model_config": "./config/img_base.json",
"model_config": "./config/img_base.json",
"output_dir": "/storage/pretrain/alltask_ot_alldata_base",
"project_dim": 768,
"mrm_prob": 0.15,
"neg_size": 128,
"nce_temp": 1.0,
"itm_neg_prob": 0.0,
"itm_ot_lambda": 0.0,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 6,
"learning_rate": 5e-05,
"valid_steps": 10000,
"num_train_steps": 300000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 5.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"n_workers": 3,
"pin_mem": true,
"train_datasets": [
{
"name": "coco_cap",
"db": [
"./data/db/pretrain_caption_coco_train_base-cased.db/",
"./data/db/pretrain_caption_coco_trainval_base-cased.db/"
],
"img": [
"./data/img/coco_train2014/",
"./data/img/coco_val2014/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
],
"mix_ratio": [
16,
8,
4,
4
]
},
{
"name": "vg_cap",
"db": [
"./data/db/pretrain_caption_vg_train_base-cased.db/"
],
"img": [
"./data/img/vg/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
],
"mix_ratio": [
16,
12,
6,
6
]
},
{
"name": "cc",
"db": [
"./data/db/conceptual_caption_train_base-cased.db/"
],
"img": [
"./data/img/gcc_train/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
],
"mix_ratio": [
16,
12,
6,
6
]
},
{
"name": "sbu",
"db": [
"./data/db/sbu_caption_train_base-cased.db/"
],
"img": [
"./data/img/sbu/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
],
"mix_ratio": [
16,
8,
4,
4
]
}
],
"val_datasets": [
{
"name": "coco_cap",
"db": [
"./data/db/pretrain_caption_coco_val_base-cased.db/"
],
"img": [
"./data/img/coco_val2014/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
]
},
{
"name": "vg_cap",
"db": [
"./data/db/pretrain_caption_vg_val_base-cased.db/"
],
"img": [
"./data/img/vg/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
]
},
{
"name": "cc",
"db": [
"./data/db/conceptual_caption_val_base-cased.db/"
],
"img": [
"./data/img/gcc_val/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
]
},
{
"name": "sbu",
"db": [
"./data/db/sbu_caption_val_base-cased.db/"
],
"img": [
"./data/img/sbu/"
],
"tasks": [
"itm",
"mlm",
"mrfr",
"mrckl"
]
}
],
"rank": 0
}

BIN
data/model/resnet101_faster_rcnn_final.pth (Stored with Git LFS)

Binary file not shown.

BIN
data/model/uniter-base.pt (Stored with Git LFS)

Binary file not shown.

0
detector/__init__.py

BIN
detector/__pycache__/__init__.cpython-38.pyc

Binary file not shown.

BIN
detector/__pycache__/bbox_transform.cpython-38.pyc

Binary file not shown.

BIN
detector/__pycache__/faster_rcnn.cpython-38.pyc

Binary file not shown.

BIN
detector/__pycache__/generate_anchors.cpython-38.pyc

Binary file not shown.

BIN
detector/__pycache__/rpn.cpython-38.pyc

Binary file not shown.

75
detector/bbox_transform.py

@ -0,0 +1,75 @@
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import numpy as np
def bbox_transform(ex_rois, gt_rois):
ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths
ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights
gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths
gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights
targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = np.log(gt_widths / ex_widths)
targets_dh = np.log(gt_heights / ex_heights)
targets = np.vstack(
(targets_dx, targets_dy, targets_dw, targets_dh)).transpose()
return targets
def bbox_transform_inv(boxes, deltas):
if boxes.shape[0] == 0:
return np.zeros((0, deltas.shape[1]), dtype=deltas.dtype)
boxes = boxes.astype(deltas.dtype, copy=False)
widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
dx = deltas[:, 0::4]
dy = deltas[:, 1::4]
dw = deltas[:, 2::4]
dh = deltas[:, 3::4]
pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis]
pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis]
pred_w = np.exp(dw) * widths[:, np.newaxis]
pred_h = np.exp(dh) * heights[:, np.newaxis]
pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype)
# x1
pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
# y1
pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
# x2
pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w
# y2
pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h
return pred_boxes
def clip_boxes(boxes, im_shape):
"""
Clip boxes to image boundaries.
"""
# x1 >= 0
boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
# y1 >= 0
boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
# x2 < im_shape[1]
boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
# y2 < im_shape[0]
boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
return boxes

478
detector/faster_rcnn.py

@ -0,0 +1,478 @@
import pickle
import ipdb
import torch
import numpy as np
import cv2
import torchvision
from torch import nn
from .rpn import RegionProposalNetwork
class ConvBlock(nn.Module):
def __init__(self,i,o,k,s,p,d,use_relu = True):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(i, o, k, s, p, d)
self.bn = nn.BatchNorm2d(o)
self.use_relu = use_relu
if self.use_relu == True:
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_relu == True:
x = self.relu(x)
return x
def load_convblock(block, convname, bnname, scalename, weights):
block.conv.weight = nn.Parameter(torch.FloatTensor(weights[convname][0]))
block.conv.bias = nn.Parameter(torch.zeros_like(block.conv.bias))
block.bn.running_mean = nn.Parameter(torch.FloatTensor(weights[bnname][0] / weights[bnname][2]))
block.bn.running_var = nn.Parameter(torch.FloatTensor(weights[bnname][1] / weights[bnname][2]))
block.bn.weight = nn.Parameter(torch.FloatTensor(weights[scalename][0]))
block.bn.bias = nn.Parameter(torch.FloatTensor(weights[scalename][1]))
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = ConvBlock(3,64,7,2,3,1,True)
self.pool1 = nn.MaxPool2d(3,2,0,ceil_mode=True)
self.res2a_branch1 = ConvBlock(64,256,1,1,0,1,False)
self.res2a_branch2a = ConvBlock(64,64,1,1,0,1,True)
self.res2a_branch2b = ConvBlock(64,64,3,1,1,1,True)
self.res2a_branch2c = ConvBlock(64,256,1,1,0,1,False)
self.res2b_branch2a = ConvBlock(256,64,1,1,0,1,True)
self.res2b_branch2b = ConvBlock(64,64,3,1,1,1,True)
self.res2b_branch2c = ConvBlock(64,256,1,1,0,1,False)
self.res2c_branch2a = ConvBlock(256,64,1,1,0,1,True)
self.res2c_branch2b = ConvBlock(64,64,3,1,1,1,True)
self.res2c_branch2c = ConvBlock(64,256,1,1,0,1,False)
self.res3a_branch1 = ConvBlock(256,512,1,2,0,1,False)
self.res3a_branch2a = ConvBlock(256,128,1,2,0,1,True)
self.res3a_branch2b = ConvBlock(128,128,3,1,1,1,True)
self.res3a_branch2c = ConvBlock(128,512,1,1,0,1,False)
self.res3b1_branch2a = ConvBlock(512,128,1,1,0,1,True)
self.res3b1_branch2b = ConvBlock(128,128,3,1,1,1,True)
self.res3b1_branch2c = ConvBlock(128,512,1,1,0,1,False)
self.res3b2_branch2a = ConvBlock(512,128,1,1,0,1,True)
self.res3b2_branch2b = ConvBlock(128,128,3,1,1,1,True)
self.res3b2_branch2c = ConvBlock(128,512,1,1,0,1,False)
self.res3b3_branch2a = ConvBlock(512,128,1,1,0,1,True)
self.res3b3_branch2b = ConvBlock(128,128,3,1,1,1,True)
self.res3b3_branch2c = ConvBlock(128,512,1,1,0,1,False)
self.res4a_branch1 = ConvBlock(512,1024,1,2,0,1,False)
self.res4a_branch2a = ConvBlock(512,256,1,2,0,1,True)
self.res4a_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4a_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b1_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b1_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b1_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b2_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b2_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b2_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b3_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b3_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b3_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b4_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b4_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b4_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b5_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b5_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b5_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b6_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b6_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b6_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b7_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b7_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b7_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b8_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b8_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b8_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b9_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b9_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b9_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b10_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b10_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b10_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b11_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b11_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b11_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b12_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b12_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b12_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b13_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b13_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b13_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b14_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b14_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b14_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b15_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b15_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b15_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b16_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b16_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b16_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b17_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b17_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b17_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b18_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b18_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b18_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b19_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b19_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b19_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b20_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b20_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b20_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b21_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b21_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b21_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res4b22_branch2a = ConvBlock(1024,256,1,1,0,1,True)
self.res4b22_branch2b = ConvBlock(256,256,3,1,1,1,True)
self.res4b22_branch2c = ConvBlock(256,1024,1,1,0,1,False)
self.res5a_branch1 = ConvBlock(1024,2048,1,1,0,1,False)
self.res5a_branch2a = ConvBlock(1024,512,1,1,0,1,True)
self.res5a_branch2b = ConvBlock(512,512,3,1,2,2,True)
self.res5a_branch2c = ConvBlock(512,2048,1,1,0,1,False)
self.res5b_branch2a = ConvBlock(2048,512,1,1,0,1,True)
self.res5b_branch2b = ConvBlock(512,512,3,1,2,2,True)
self.res5b_branch2c = ConvBlock(512,2048,1,1,0,1,False)
self.res5c_branch2a = ConvBlock(2048,512,1,1,0,1,True)
self.res5c_branch2b = ConvBlock(512,512,3,1,2,2,True)
self.res5c_branch2c = ConvBlock(512,2048,1,1,0,1,False)
self.rpn_conv_3x3 = nn.Conv2d(1024,512,3,1,1,1)
self.rpn_cls_score = nn.Conv2d(512,24,1,1,0,1)
self.rpn_bbox_pred = nn.Conv2d(512,48,1,1,0,1)
self.rpn = RegionProposalNetwork(pre_nms_topN = 6000, post_nms_topN = 300, nms_thresh = 0.7, min_size = 16, anchor_scales = (4, 8, 16, 32), feat_stride=16)
#self.pool5 = nn.MaxPool2d(3,2,1,ceil_mode=True)
self.cls_score = nn.Linear(2048, 1601)
def infer_resblock(self, l, r, x):
xl = x
xr = x
for b in l:
xl = b(xl)
for b in r:
xr = b(xr)
return xl + xr
def forward(self, x, im_size):
x = self.conv1(x)
x = self.pool1(x)
x = nn.functional.relu(self.infer_resblock([self.res2a_branch1], [self.res2a_branch2a,self.res2a_branch2b,self.res2a_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res2b_branch2a,self.res2b_branch2b,self.res2b_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res2c_branch2a,self.res2c_branch2b,self.res2c_branch2c],x))
x = nn.functional.relu(self.infer_resblock([self.res3a_branch1], [self.res3a_branch2a,self.res3a_branch2b,self.res3a_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b1_branch2a,self.res3b1_branch2b,self.res3b1_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b2_branch2a,self.res3b2_branch2b,self.res3b2_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b3_branch2a,self.res3b3_branch2b,self.res3b3_branch2c],x))
x = nn.functional.relu(self.infer_resblock([self.res4a_branch1], [self.res4a_branch2a,self.res4a_branch2b,self.res4a_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b1_branch2a,self.res4b1_branch2b,self.res4b1_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b2_branch2a,self.res4b2_branch2b,self.res4b2_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b3_branch2a,self.res4b3_branch2b,self.res4b3_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b4_branch2a,self.res4b4_branch2b,self.res4b4_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b5_branch2a,self.res4b5_branch2b,self.res4b5_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b6_branch2a,self.res4b6_branch2b,self.res4b6_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b7_branch2a,self.res4b7_branch2b,self.res4b7_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b8_branch2a,self.res4b8_branch2b,self.res4b8_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b9_branch2a,self.res4b9_branch2b,self.res4b9_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b10_branch2a,self.res4b10_branch2b,self.res4b10_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b11_branch2a,self.res4b11_branch2b,self.res4b11_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b12_branch2a,self.res4b12_branch2b,self.res4b12_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b13_branch2a,self.res4b13_branch2b,self.res4b13_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b14_branch2a,self.res4b14_branch2b,self.res4b14_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b15_branch2a,self.res4b15_branch2b,self.res4b15_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b16_branch2a,self.res4b16_branch2b,self.res4b16_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b17_branch2a,self.res4b17_branch2b,self.res4b17_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b18_branch2a,self.res4b18_branch2b,self.res4b18_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b19_branch2a,self.res4b19_branch2b,self.res4b19_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b20_branch2a,self.res4b20_branch2b,self.res4b20_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b21_branch2a,self.res4b21_branch2b,self.res4b21_branch2c],x))
#x = data_kv['res4b21']
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b22_branch2a,self.res4b22_branch2b,self.res4b22_branch2c],x))
x_rpn_output = nn.functional.relu(self.rpn_conv_3x3(x))
x_rpn_cls_score = self.rpn_cls_score(x_rpn_output)
x_rpn_bbox_pred = self.rpn_bbox_pred(x_rpn_output)
n, c, h, w = x_rpn_cls_score.shape
x_rpn_cls_score = x_rpn_cls_score.reshape(n,2,-1,w)
x_rpn_cls_prob = nn.functional.softmax(x_rpn_cls_score, 1)
x_rpn_cls_prob_reshape = x_rpn_cls_prob.reshape(n,24,-1,w)
#im_size = np.array([600. , 600. , 2.6785715])
#im_size = np.array([5.6200000e+02, 1.0000000e+03, 8.9285713e-01])
rois = self.rpn.forward(x_rpn_cls_prob_reshape, x_rpn_bbox_pred, im_size)
feats = torchvision.ops.roi_pool(x, rois, output_size=[14,14], spatial_scale=0.0625)
x = nn.functional.relu(self.infer_resblock([self.res5a_branch1], [self.res5a_branch2a,self.res5a_branch2b,self.res5a_branch2c],feats))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res5b_branch2a,self.res5b_branch2b,self.res5b_branch2c],x))
x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res5c_branch2a,self.res5c_branch2b,self.res5c_branch2c],x))
x = torch.nn.functional.adaptive_avg_pool2d(x, (1,1))
pool5_flat = x.reshape((x.shape[0], -1))
x_cls_score = self.cls_score(pool5_flat)
x_cls_prob = torch.nn.functional.softmax(x_cls_score, -1)
x_cls_boxes = rois[:, 1:5] / im_size[2]
max_conf, keep_boxes = self.post_process(rois, x_cls_boxes, x_cls_prob, 0.2)
MIN_BOXES = 10
MAX_BOXES = 100
if len(keep_boxes) < MIN_BOXES:
keep_boxes = torch.argsort(max_conf, 0, True)[:MIN_BOXES]
elif len(keep_boxes) > MAX_BOXES:
keep_boxes = torch.argsort(max_conf, 0, True)[:MAX_BOXES]
boxes = x_cls_boxes[keep_boxes]
features = pool5_flat[keep_boxes]
confidence = max_conf[keep_boxes]
return boxes, features, confidence
def post_process(self, rois, cls_boxes, cls_prob, conf_thresh = 0.2):
max_conf = torch.zeros((rois.shape[0]), device = rois.device )
for cls_ind in range(1, cls_prob.shape[1]):
#cls_scores = scores[:, cls_ind]
cls_scores = cls_prob[:, cls_ind]
dets = torch.hstack(
(cls_boxes, cls_scores[:, np.newaxis]))
keep = np.array(torchvision.ops.nms(dets[:,:4],dets[:,4], 0.3 ))
max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
cls_scores[keep], max_conf[keep])
keep_boxes = torch.where(max_conf >= conf_thresh)[0]
return max_conf, keep_boxes
def load_weights_from_pkl(self, weights):
with torch.no_grad():
load_convblock(self.conv1, 'conv1', 'bn_conv1', 'scale_conv1', weights_kv)
load_convblock(self.res2a_branch1, 'res2a_branch1', 'bn2a_branch1', 'scale2a_branch1', weights_kv)
load_convblock(self.res2a_branch2a, 'res2a_branch2a', 'bn2a_branch2a', 'scale2a_branch2a', weights_kv)
load_convblock(self.res2a_branch2b, 'res2a_branch2b', 'bn2a_branch2b', 'scale2a_branch2b', weights_kv)
load_convblock(self.res2a_branch2c, 'res2a_branch2c', 'bn2a_branch2c', 'scale2a_branch2c', weights_kv)
load_convblock(self.res2b_branch2a, 'res2b_branch2a', 'bn2b_branch2a', 'scale2b_branch2a', weights_kv)
load_convblock(self.res2b_branch2b, 'res2b_branch2b', 'bn2b_branch2b', 'scale2b_branch2b', weights_kv)
load_convblock(self.res2b_branch2c, 'res2b_branch2c', 'bn2b_branch2c', 'scale2b_branch2c', weights_kv)
load_convblock(self.res2c_branch2a, 'res2c_branch2a', 'bn2c_branch2a', 'scale2c_branch2a', weights_kv)
load_convblock(self.res2c_branch2b, 'res2c_branch2b', 'bn2c_branch2b', 'scale2c_branch2b', weights_kv)
load_convblock(self.res2c_branch2c, 'res2c_branch2c', 'bn2c_branch2c', 'scale2c_branch2c', weights_kv)
load_convblock(self.res3a_branch1, 'res3a_branch1', 'bn3a_branch1', 'scale3a_branch1', weights_kv)
load_convblock(self.res3a_branch2a, 'res3a_branch2a', 'bn3a_branch2a', 'scale3a_branch2a', weights_kv)
load_convblock(self.res3a_branch2b, 'res3a_branch2b', 'bn3a_branch2b', 'scale3a_branch2b', weights_kv)
load_convblock(self.res3a_branch2c, 'res3a_branch2c', 'bn3a_branch2c', 'scale3a_branch2c', weights_kv)
load_convblock(self.res3b1_branch2a, 'res3b1_branch2a', 'bn3b1_branch2a', 'scale3b1_branch2a', weights_kv)
load_convblock(self.res3b1_branch2b, 'res3b1_branch2b', 'bn3b1_branch2b', 'scale3b1_branch2b', weights_kv)
load_convblock(self.res3b1_branch2c, 'res3b1_branch2c', 'bn3b1_branch2c', 'scale3b1_branch2c', weights_kv)
load_convblock(self.res3b2_branch2a, 'res3b2_branch2a', 'bn3b2_branch2a', 'scale3b2_branch2a', weights_kv)
load_convblock(self.res3b2_branch2b, 'res3b2_branch2b', 'bn3b2_branch2b', 'scale3b2_branch2b', weights_kv)
load_convblock(self.res3b2_branch2c, 'res3b2_branch2c', 'bn3b2_branch2c', 'scale3b2_branch2c', weights_kv)
load_convblock(self.res3b3_branch2a, 'res3b3_branch2a', 'bn3b3_branch2a', 'scale3b3_branch2a', weights_kv)
load_convblock(self.res3b3_branch2b, 'res3b3_branch2b', 'bn3b3_branch2b', 'scale3b3_branch2b', weights_kv)
load_convblock(self.res3b3_branch2c, 'res3b3_branch2c', 'bn3b3_branch2c', 'scale3b3_branch2c', weights_kv)
load_convblock(self.res4a_branch1, 'res4a_branch1', 'bn4a_branch1', 'scale4a_branch1', weights_kv)
load_convblock(self.res4a_branch2a, 'res4a_branch2a', 'bn4a_branch2a', 'scale4a_branch2a', weights_kv)
load_convblock(self.res4a_branch2b, 'res4a_branch2b', 'bn4a_branch2b', 'scale4a_branch2b', weights_kv)
load_convblock(self.res4a_branch2c, 'res4a_branch2c', 'bn4a_branch2c', 'scale4a_branch2c', weights_kv)
load_convblock(self.res4b1_branch2a, 'res4b1_branch2a', 'bn4b1_branch2a', 'scale4b1_branch2a', weights_kv)
load_convblock(self.res4b1_branch2b, 'res4b1_branch2b', 'bn4b1_branch2b', 'scale4b1_branch2b', weights_kv)
load_convblock(self.res4b1_branch2c, 'res4b1_branch2c', 'bn4b1_branch2c', 'scale4b1_branch2c', weights_kv)
load_convblock(self.res4b2_branch2a, 'res4b2_branch2a', 'bn4b2_branch2a', 'scale4b2_branch2a', weights_kv)
load_convblock(self.res4b2_branch2b, 'res4b2_branch2b', 'bn4b2_branch2b', 'scale4b2_branch2b', weights_kv)
load_convblock(self.res4b2_branch2c, 'res4b2_branch2c', 'bn4b2_branch2c', 'scale4b2_branch2c', weights_kv)
load_convblock(self.res4b3_branch2a, 'res4b3_branch2a', 'bn4b3_branch2a', 'scale4b3_branch2a', weights_kv)
load_convblock(self.res4b3_branch2b, 'res4b3_branch2b', 'bn4b3_branch2b', 'scale4b3_branch2b', weights_kv)
load_convblock(self.res4b3_branch2c, 'res4b3_branch2c', 'bn4b3_branch2c', 'scale4b3_branch2c', weights_kv)
load_convblock(self.res4b4_branch2a, 'res4b4_branch2a', 'bn4b4_branch2a', 'scale4b4_branch2a', weights_kv)
load_convblock(self.res4b4_branch2b, 'res4b4_branch2b', 'bn4b4_branch2b', 'scale4b4_branch2b', weights_kv)
load_convblock(self.res4b4_branch2c, 'res4b4_branch2c', 'bn4b4_branch2c', 'scale4b4_branch2c', weights_kv)
load_convblock(self.res4b5_branch2a, 'res4b5_branch2a', 'bn4b5_branch2a', 'scale4b5_branch2a', weights_kv)
load_convblock(self.res4b5_branch2b, 'res4b5_branch2b', 'bn4b5_branch2b', 'scale4b5_branch2b', weights_kv)
load_convblock(self.res4b5_branch2c, 'res4b5_branch2c', 'bn4b5_branch2c', 'scale4b5_branch2c', weights_kv)
load_convblock(self.res4b6_branch2a, 'res4b6_branch2a', 'bn4b6_branch2a', 'scale4b6_branch2a', weights_kv)
load_convblock(self.res4b6_branch2b, 'res4b6_branch2b', 'bn4b6_branch2b', 'scale4b6_branch2b', weights_kv)
load_convblock(self.res4b6_branch2c, 'res4b6_branch2c', 'bn4b6_branch2c', 'scale4b6_branch2c', weights_kv)
load_convblock(self.res4b7_branch2a, 'res4b7_branch2a', 'bn4b7_branch2a', 'scale4b7_branch2a', weights_kv)
load_convblock(self.res4b7_branch2b, 'res4b7_branch2b', 'bn4b7_branch2b', 'scale4b7_branch2b', weights_kv)
load_convblock(self.res4b7_branch2c, 'res4b7_branch2c', 'bn4b7_branch2c', 'scale4b7_branch2c', weights_kv)
load_convblock(self.res4b8_branch2a, 'res4b8_branch2a', 'bn4b8_branch2a', 'scale4b8_branch2a', weights_kv)
load_convblock(self.res4b8_branch2b, 'res4b8_branch2b', 'bn4b8_branch2b', 'scale4b8_branch2b', weights_kv)
load_convblock(self.res4b8_branch2c, 'res4b8_branch2c', 'bn4b8_branch2c', 'scale4b8_branch2c', weights_kv)
load_convblock(self.res4b9_branch2a, 'res4b9_branch2a', 'bn4b9_branch2a', 'scale4b9_branch2a', weights_kv)
load_convblock(self.res4b9_branch2b, 'res4b9_branch2b', 'bn4b9_branch2b', 'scale4b9_branch2b', weights_kv)
load_convblock(self.res4b9_branch2c, 'res4b9_branch2c', 'bn4b9_branch2c', 'scale4b9_branch2c', weights_kv)
load_convblock(self.res4b10_branch2a, 'res4b10_branch2a', 'bn4b10_branch2a', 'scale4b10_branch2a', weights_kv)
load_convblock(self.res4b10_branch2b, 'res4b10_branch2b', 'bn4b10_branch2b', 'scale4b10_branch2b', weights_kv)
load_convblock(self.res4b10_branch2c, 'res4b10_branch2c', 'bn4b10_branch2c', 'scale4b10_branch2c', weights_kv)
load_convblock(self.res4b11_branch2a, 'res4b11_branch2a', 'bn4b11_branch2a', 'scale4b11_branch2a', weights_kv)
load_convblock(self.res4b11_branch2b, 'res4b11_branch2b', 'bn4b11_branch2b', 'scale4b11_branch2b', weights_kv)
load_convblock(self.res4b11_branch2c, 'res4b11_branch2c', 'bn4b11_branch2c', 'scale4b11_branch2c', weights_kv)
load_convblock(self.res4b12_branch2a, 'res4b12_branch2a', 'bn4b12_branch2a', 'scale4b12_branch2a', weights_kv)
load_convblock(self.res4b12_branch2b, 'res4b12_branch2b', 'bn4b12_branch2b', 'scale4b12_branch2b', weights_kv)
load_convblock(self.res4b12_branch2c, 'res4b12_branch2c', 'bn4b12_branch2c', 'scale4b12_branch2c', weights_kv)
load_convblock(self.res4b13_branch2a, 'res4b13_branch2a', 'bn4b13_branch2a', 'scale4b13_branch2a', weights_kv)
load_convblock(self.res4b13_branch2b, 'res4b13_branch2b', 'bn4b13_branch2b', 'scale4b13_branch2b', weights_kv)
load_convblock(self.res4b13_branch2c, 'res4b13_branch2c', 'bn4b13_branch2c', 'scale4b13_branch2c', weights_kv)
load_convblock(self.res4b14_branch2a, 'res4b14_branch2a', 'bn4b14_branch2a', 'scale4b14_branch2a', weights_kv)
load_convblock(self.res4b14_branch2b, 'res4b14_branch2b', 'bn4b14_branch2b', 'scale4b14_branch2b', weights_kv)
load_convblock(self.res4b14_branch2c, 'res4b14_branch2c', 'bn4b14_branch2c', 'scale4b14_branch2c', weights_kv)
load_convblock(self.res4b15_branch2a, 'res4b15_branch2a', 'bn4b15_branch2a', 'scale4b15_branch2a', weights_kv)
load_convblock(self.res4b15_branch2b, 'res4b15_branch2b', 'bn4b15_branch2b', 'scale4b15_branch2b', weights_kv)
load_convblock(self.res4b15_branch2c, 'res4b15_branch2c', 'bn4b15_branch2c', 'scale4b15_branch2c', weights_kv)
load_convblock(self.res4b16_branch2a, 'res4b16_branch2a', 'bn4b16_branch2a', 'scale4b16_branch2a', weights_kv)
load_convblock(self.res4b16_branch2b, 'res4b16_branch2b', 'bn4b16_branch2b', 'scale4b16_branch2b', weights_kv)
load_convblock(self.res4b16_branch2c, 'res4b16_branch2c', 'bn4b16_branch2c', 'scale4b16_branch2c', weights_kv)
load_convblock(self.res4b17_branch2a, 'res4b17_branch2a', 'bn4b17_branch2a', 'scale4b17_branch2a', weights_kv)
load_convblock(self.res4b17_branch2b, 'res4b17_branch2b', 'bn4b17_branch2b', 'scale4b17_branch2b', weights_kv)
load_convblock(self.res4b17_branch2c, 'res4b17_branch2c', 'bn4b17_branch2c', 'scale4b17_branch2c', weights_kv)
load_convblock(self.res4b18_branch2a, 'res4b18_branch2a', 'bn4b18_branch2a', 'scale4b18_branch2a', weights_kv)
load_convblock(self.res4b18_branch2b, 'res4b18_branch2b', 'bn4b18_branch2b', 'scale4b18_branch2b', weights_kv)
load_convblock(self.res4b18_branch2c, 'res4b18_branch2c', 'bn4b18_branch2c', 'scale4b18_branch2c', weights_kv)
load_convblock(self.res4b19_branch2a, 'res4b19_branch2a', 'bn4b19_branch2a', 'scale4b19_branch2a', weights_kv)
load_convblock(self.res4b19_branch2b, 'res4b19_branch2b', 'bn4b19_branch2b', 'scale4b19_branch2b', weights_kv)
load_convblock(self.res4b19_branch2c, 'res4b19_branch2c', 'bn4b19_branch2c', 'scale4b19_branch2c', weights_kv)
load_convblock(self.res4b20_branch2a, 'res4b20_branch2a', 'bn4b20_branch2a', 'scale4b20_branch2a', weights_kv)
load_convblock(self.res4b20_branch2b, 'res4b20_branch2b', 'bn4b20_branch2b', 'scale4b20_branch2b', weights_kv)
load_convblock(self.res4b20_branch2c, 'res4b20_branch2c', 'bn4b20_branch2c', 'scale4b20_branch2c', weights_kv)
load_convblock(self.res4b21_branch2a, 'res4b21_branch2a', 'bn4b21_branch2a', 'scale4b21_branch2a', weights_kv)
load_convblock(self.res4b21_branch2b, 'res4b21_branch2b', 'bn4b21_branch2b', 'scale4b21_branch2b', weights_kv)
load_convblock(self.res4b21_branch2c, 'res4b21_branch2c', 'bn4b21_branch2c', 'scale4b21_branch2c', weights_kv)
load_convblock(self.res4b22_branch2a, 'res4b22_branch2a', 'bn4b22_branch2a', 'scale4b22_branch2a', weights_kv)
load_convblock(self.res4b22_branch2b, 'res4b22_branch2b', 'bn4b22_branch2b', 'scale4b22_branch2b', weights_kv)
load_convblock(self.res4b22_branch2c, 'res4b22_branch2c', 'bn4b22_branch2c', 'scale4b22_branch2c', weights_kv)
load_convblock(self.res5a_branch1, 'res5a_branch1', 'bn5a_branch1', 'scale5a_branch1', weights_kv)
load_convblock(self.res5a_branch2a, 'res5a_branch2a', 'bn5a_branch2a', 'scale5a_branch2a', weights_kv)
load_convblock(self.res5a_branch2b, 'res5a_branch2b', 'bn5a_branch2b', 'scale5a_branch2b', weights_kv)
load_convblock(self.res5a_branch2c, 'res5a_branch2c', 'bn5a_branch2c', 'scale5a_branch2c', weights_kv)
load_convblock(self.res5b_branch2a, 'res5b_branch2a', 'bn5b_branch2a', 'scale5b_branch2a', weights_kv)
load_convblock(self.res5b_branch2b, 'res5b_branch2b', 'bn5b_branch2b', 'scale5b_branch2b', weights_kv)
load_convblock(self.res5b_branch2c, 'res5b_branch2c', 'bn5b_branch2c', 'scale5b_branch2c', weights_kv)
load_convblock(self.res5c_branch2a, 'res5c_branch2a', 'bn5c_branch2a', 'scale5c_branch2a', weights_kv)
load_convblock(self.res5c_branch2b, 'res5c_branch2b', 'bn5c_branch2b', 'scale5c_branch2b', weights_kv)
load_convblock(self.res5c_branch2c, 'res5c_branch2c', 'bn5c_branch2c', 'scale5c_branch2c', weights_kv)
self.rpn_conv_3x3.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][0]))
self.rpn_conv_3x3.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][1]))
self.rpn_cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][0]))
self.rpn_cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][1]))
self.rpn_bbox_pred.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][0]))
self.rpn_bbox_pred.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][1]))
self.cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][0]))
self.cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][1]))
# self.conv1.weight = nn.Parameter(torch.FloatTensor(weights[0]['weights'][0]))
# self.conv1.bias = nn.Parameter(torch.zeros_like(self.conv1.bias))
# self.bn_conv1.running_mean = nn.Parameter(torch.FloatTensor(weights[1]['weights'][0] / weights[1]['weights'][2]))
# self.bn_conv1.running_var = nn.Parameter(torch.FloatTensor(weights[1]['weights'][1] / weights[1]['weights'][2]))
# self.bn_conv1.weight = nn.Parameter(torch.FloatTensor(weights[2]['weights'][0]))
# self.bn_conv1.bias = nn.Parameter(torch.FloatTensor(weights[2]['weights'][1]))
#
def process_img(img):
mean = np.array([[[102.9801, 115.9465, 122.7717]]])
img = img - mean
im_shape = img.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
target_size = 600
max_size = 1000
im_scale = float(target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
return im, np.array([im.shape[0], im.shape[1], im_scale])
#img = cv2.imread('img2.jpg')
##print(img)
#ipdb.set_trace()
#shape = process_img(img)
#
#net = Net()
#net.load_weights_from_pkl(data2)
#net.eval()
#with torch.no_grad():
# output = net(img)
#ipdb.set_trace()
#print(residual)

105
detector/generate_anchors.py

@ -0,0 +1,105 @@
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
import numpy as np
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360
#array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""
base_anchor = np.array([1, 1, base_size, base_size]) - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
for i in range(ratio_anchors.shape[0])])
return anchors
def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1)))
return anchors
def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
#if __name__ == '__main__':
# import time
# t = time.time()
# a = generate_anchors()
# print time.time() - t
# print a
# from IPython import embed; embed()

136
detector/rpn.py

@ -0,0 +1,136 @@
import pickle
import numpy as np
import torch as t
import torch
from torch.nn import functional as F
from torch import nn
import torchvision
from torch import nn
from .generate_anchors import generate_anchors
from .bbox_transform import bbox_transform_inv, clip_boxes
class RegionProposalNetwork(nn.Module):
def __init__(self, pre_nms_topN,post_nms_topN, nms_thresh, min_size, anchor_scales, feat_stride):
super(RegionProposalNetwork, self).__init__()
self._anchors = generate_anchors(scales=np.array(anchor_scales))
self._num_anchors = self._anchors.shape[0]
self._feat_stride = feat_stride
self.pre_nms_topN = pre_nms_topN
self.post_nms_topN = post_nms_topN
self.nms_thresh = nms_thresh
self.min_size = min_size
self.anchor_scales = anchor_scales
self.feat_stride = feat_stride
def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size):
scores = rpn_cls_prob[:,self._num_anchors:, :, :]
bbox_deltas = rpn_bbox_pred
min_size = self.min_size
pre_nms_topN = self.pre_nms_topN
post_nms_topN = self.post_nms_topN
nms_thresh = self.nms_thresh
min_size = self.min_size
anchor_scales = self.anchor_scales
feat_stride = self.feat_stride
n, _, hh, ww = scores.shape
#n_anchor = anchor.shape[0] // (hh * ww)
anchors = self._enumerate_shifted_anchor(self._anchors, self._feat_stride, hh, ww)
bbox_deltas = bbox_deltas.permute((0, 2, 3, 1)).reshape((-1, 4))
bbox_deltas = bbox_deltas.cpu().detach().numpy()
proposals = bbox_transform_inv(anchors, bbox_deltas)
scores = scores.permute((0, 2, 3, 1)).reshape((-1, 1))
proposals = bbox_transform_inv(anchors, bbox_deltas)
proposals = clip_boxes(proposals, img_size[:2])
keep = _filter_boxes(proposals, min_size * img_size[2])
proposals = proposals[keep, :]
scores = scores[keep]
order = scores.ravel().argsort(descending=True)
proposals = t.FloatTensor(proposals, device = rpn_cls_prob.device)
if pre_nms_topN > 0:
order = order[:pre_nms_topN]
proposals = proposals[order, :]
scores = scores[order]
keep = torchvision.ops.nms(proposals, scores.ravel(), nms_thresh)
if post_nms_topN > 0:
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep]
batch_inds = t.zeros((proposals.shape[0], 1), dtype = proposals.dtype, device = proposals.device)
rois = t.hstack([batch_inds, proposals])
return rois
#keep = nms(np.hstack((proposals, scores)), nms_thresh)
#if post_nms_topN > 0:
# keep = keep[:post_nms_topN]
#proposals = proposals[keep, :]
#scores = scores[keep]
def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width):
# Enumerate all shifted anchors:
#
# add A anchors (1, A, 4) to
# cell K shifts (K, 1, 4) to get
# shift anchors (K, A, 4)
# reshape to (K*A, 4) shifted anchors
# return (K*A, 4)
# !TODO: add support for torch.CudaTensor
# xp = cuda.get_array_module(anchor_base)
# it seems that it can't be boosed using GPU
import numpy as xp
shift_y = xp.arange(0, height * feat_stride, feat_stride)
shift_x = xp.arange(0, width * feat_stride, feat_stride)
shift_x, shift_y = xp.meshgrid(shift_x, shift_y)
shift = xp.stack((shift_x.ravel(), shift_y.ravel(),
shift_x.ravel(), shift_y.ravel()), axis=1)
A = anchor_base.shape[0]
K = shift.shape[0]
anchor = anchor_base.reshape((1, A, 4)) + \
shift.reshape((1, K, 4)).transpose((1, 0, 2))
anchor = anchor.reshape((K * A, 4)).astype(np.float32)
return anchor
def _filter_boxes(boxes, min_size):
"""Remove all boxes with any side smaller than min_size."""
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws >= min_size) & (hs >= min_size))[0]
return keep
#if __name__ == '__main__':
# rpn_cls_prob_reshape = data_kv['rpn_cls_prob_reshape']
# rpn_bbox_pred = data_kv['rpn_bbox_pred']
#
# rpn = RegionProposalNetwork(pre_nms_topN = 6000, post_nms_topN = 300, nms_thresh = 0.7, min_size = 16, anchor_scales = (4, 8, 16, 32), feat_stride=16)
# im_size = np.array([600. , 600. , 2.6785715])
# ipdb.set_trace()
# rois = rpn.forward(rpn_cls_prob_reshape, rpn_bbox_pred, im_size)
#
# outputs = data_kv['res4b22']
# feats = torchvision.ops.roi_pool(outputs,rois, output_size=[14,14], spatial_scale=0.0625)
#
# print(rpn_cls_prob_reshape.shape, rpn_bbox_pred.shape)
# print(rpn)
# print('main')
#

BIN
dvl/__pycache__/const.cpython-38.pyc

Binary file not shown.

3
dvl/const.py

@ -0,0 +1,3 @@
IMG_DIM = 2048
IMG_LABEL_DIM = 1601
BUCKET_SIZE = 8192

366
dvl/data/itm.py

@ -0,0 +1,366 @@
import torch
import numpy as np
import itertools
from torch.nn.utils.rnn import pad_sequence
from uniter_model.data.itm import DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, get_ids_and_lens
from uniter_model.data.data import get_gather_index
from toolz.sandbox import unzip
from cytoolz import concat
from GLOBAL_VARIABLES import N_EXAMPLES_TEACHER
def pad_tensors(tensors, lens=None, pad=0):
"""B x [T, ...]"""
if lens is None:
lens = [t.size(0) for t in tensors]
max_len = max(lens)
bs = len(tensors)
hid = tensors[0].size(-1)
dtype = tensors[0].dtype
output = torch.zeros(bs, max_len, hid, dtype=dtype)
if pad:
output.data.fill_(pad)
for i, (t, l) in enumerate(zip(tensors, lens)):
output.data[i, :l, ...] = t.data
return output
# for ITM task
class ItmFastDataset(DetectFeatTxtTokDataset):
""" NOTE this Dataset handles distributed training itself
(for more efficient negative sampling) """
def __init__(self, txt_db, img_db, num_hard_negatives=0, img_meta=None, tokenizer=None):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
self.txt_lens, self.ids = get_ids_and_lens(txt_db)
self.ids_2_idx = {idx:i for i, idx in enumerate(self.ids)}
self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
self.num_hard_negatives = num_hard_negatives
self.img_meta = img_meta
self.tokenizer = tokenizer
self.train_imgs = None
self.neg_imgs = None
# self.new_epoch(hard_negatives)
def new_epoch(self, hard_negatives_img=None, hard_negatives_txt=None):
""" should be called every epoch for more randomness"""
self.lens = []
self.train_imgs, self.neg_imgs = [], []
self.train_txts, self.neg_txts = [], []
for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
img_fname = super().__getitem__(i)['img_fname']
self.train_imgs.append(img_fname)
self.train_txts.append(id_)
if hard_negatives_img is not None and self.num_hard_negatives > 0:
self.neg_imgs.append(hard_negatives_img[id_][:self.num_hard_negatives])
self.neg_txts.append(hard_negatives_txt[img_fname][:self.num_hard_negatives])
else:
self.neg_imgs.append(None)
self.neg_txts.append(None)
self.lens.append(tl + self.img_db.name2nbb[img_fname])
def __getitem__(self, i):
example = super().__getitem__(i)
# labels and negative images should be sampled every epoch
img_fname, hard_neg_imgs = self.train_imgs[i], self.neg_imgs[i]
txt_fname, hard_neg_txts = self.ids[i], self.neg_txts[i]
img_input_ids = torch.Tensor([101]).long()
img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
if hard_neg_imgs is not None:
# TODO: add hard negative here
neg_imgs = dict({'img_input_ids': [], 'img_feat': [], 'img_pos_feat': [], 'num_bb': [], 'attn_masks_img': [],
'caption_ids': [], 'attn_masks_captions': []})
for neg_id in hard_neg_imgs:
neg_imgs['img_input_ids'].append(torch.Tensor([101]).long())
t = self._get_img_feat(neg_id)
neg_imgs['img_feat'].append(t[0])
neg_imgs['img_pos_feat'].append(t[1])
neg_imgs['num_bb'].append(t[2])
neg_imgs['attn_masks_img'].append(torch.ones(t[2]+1, dtype=torch.long))
if self.img_meta is not None:
tmp = [self.tokenizer.encode(i, add_special_tokens=False) + [self.tokenizer.sep_token_id]
for i in self.img_meta[neg_id]['caption_multiple']]
neg_imgs['caption_ids'].append(torch.tensor([self.tokenizer.cls_token_id] + sum(tmp, []),
dtype=input_ids.dtype, device=input_ids.device))
neg_imgs['attn_masks_captions'].append(torch.ones(len(neg_imgs['caption_ids'][-1]), dtype=torch.long))
# debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']]
neg_txts = dict({'input_ids':[], 'position_ids':[], 'attention_mask':[]})
for neg_id in hard_neg_txts:
ei = super().__getitem__(self.ids_2_idx[neg_id])
input_ids_ei = ei['input_ids']
neg_txts['input_ids'].append(self.txt_db.combine_inputs(input_ids_ei))
neg_txts['attention_mask'].append(torch.ones(len(neg_txts['input_ids'][-1]), dtype=torch.long))
else:
neg_imgs = None
neg_txts = None
if self.img_meta is not None:
caption_ids = [self.tokenizer.encode(i, add_special_tokens=False) + [self.tokenizer.sep_token_id] for i in self.img_meta[img_fname]['caption_multiple']]
caption_ids = torch.tensor([self.tokenizer.cls_token_id] + sum(caption_ids, []), dtype=input_ids.dtype, device=input_ids.device)
attn_masks_captions = torch.ones(len(caption_ids), dtype=torch.long)
# debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']]
else:
caption_ids = None
attn_masks_captions = None
# target = torch.Tensor(1).long()
# target.data.fill_(ground_truth_label)
return input_ids, img_feat, img_pos_feat, img_input_ids, attn_masks, attn_masks_img, self.ids[i], img_fname, neg_imgs, neg_txts, caption_ids, attn_masks_captions
def itm_fast_collate_kd(inputs):
input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text, attn_masks_img, idx, img_fname, negs, caption_ids, attn_masks_captions = map(list, unzip(inputs))
# txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
captions_ids = pad_sequence(caption_ids, batch_first=True, padding_value=0) if caption_ids[0] is not None else None
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None
if not None in negs:
num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in negs]))
img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in negs]))
img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in negs]))
img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in negs]))
attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in negs]))
else:
num_bbs_neg = []
img_feats_neg = []
img_input_ids_neg = []
img_pos_feat_neg = []
attn_masks_img_neg = []
num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs)
img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
attn_masks_captions = pad_sequence(attn_masks_captions, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None
attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0)
sample_size = len(inputs[0])
assert all(sample_size == len(i) for i in inputs)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
img_feat_teacher = img_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
img_pos_feat_teacher = img_pos_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1)
attn_masks_img_teacher = attn_masks_img[:N_EXAMPLES_TEACHER].repeat(bs, 1)[:, 1:]
input_ids_teacher = input_ids.unsqueeze(1).repeat(1, 10, 1).view(bs*N_EXAMPLES_TEACHER, -1)
position_ids_teacher = position_ids
attn_masks_text_teacher = attn_masks_text.unsqueeze(1).repeat(1, 10, 1).view(bs*N_EXAMPLES_TEACHER, -1)
attn_masks_teacher = torch.cat([attn_masks_text_teacher, attn_masks_img_teacher], dim=1)
batch = {
'txt_ids': input_ids,
'img_ids': img_feat,
'caption_ids': captions_ids,
'txt_pos_ids': position_ids,
'img_pos_ids': img_pos_feat,
'caption_pos_ids': position_ids_captions,
'txt_attn_masks': attn_masks_text,
'img_attn_masks': attn_masks_img,
'caption_attn_masks': attn_masks_captions,
'img_txt_ids': img_input_ids,
'img_txt_pos_ids': img_position_ids,
'gather_index': gather_index,
'sample_size': sample_size,
'pos_ctx_indices': list(range(bs)),
'neg_ctx_indices': list(range(bs, len(num_bbs))),
'txt_index': idx,
'img_fname': img_fname,
'img_feat_teacher': img_feat_teacher,
'img_pos_feat_teacher': img_pos_feat_teacher,
'input_ids_teacher': input_ids_teacher,
'position_ids_teacher': position_ids_teacher,
'attn_masks_teacher': attn_masks_teacher
}
return batch
def itm_fast_collate(inputs):
input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text, attn_masks_img, idx, img_fname, neg_imgs, neg_txts, caption_ids, attn_masks_captions = map(list, unzip(inputs))
bs = len(input_ids)
# txt_lens = [i.size(0) for i in input_ids]
if not None in neg_imgs:
num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in neg_imgs]))
img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in neg_imgs]))
img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in neg_imgs]))
img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in neg_imgs]))
attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in neg_imgs]))
caption_ids_neg = list(itertools.chain(*[n['caption_ids'] for n in neg_imgs]))
attn_masks_captions_neg = list(itertools.chain(*[n['attn_masks_captions'] for n in neg_imgs]))
input_ids_neg = list(itertools.chain(*[n['input_ids'] for n in neg_txts]))
attn_masks_text_neg = list(itertools.chain(*[n['attention_mask'] for n in neg_txts]))
else:
num_bbs_neg = []
img_feats_neg = []
img_input_ids_neg = []
img_pos_feat_neg = []
attn_masks_img_neg = []
caption_ids_neg = []
attn_masks_captions_neg = []
input_ids_neg = []
attn_masks_text_neg = []
input_ids = pad_sequence(input_ids+input_ids_neg, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
captions_ids = pad_sequence(caption_ids+caption_ids_neg, batch_first=True, padding_value=0) if caption_ids[0] is not None else None
position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None
num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg
img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs)
img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks_text = pad_sequence(attn_masks_text+attn_masks_text_neg, batch_first=True, padding_value=0)
attn_masks_captions = pad_sequence(attn_masks_captions+attn_masks_captions_neg, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None
attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0)
sample_size = bs
# assert all(sample_size == len(i) for i in inputs)
max_tl = input_ids.shape[1]
out_size = attn_masks_img.size(1)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks_text,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': None,
'gather_index': gather_index
},
'caps': {
'input_ids': captions_ids,
'position_ids': position_ids_captions,
'attention_mask': attn_masks_captions,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'sample_size': sample_size,
'pos_ctx_indices': list(range(bs)),
'neg_ctx_indices': list(range(bs, len(num_bbs))),
'txt_index': idx,
'img_fname': img_fname
}
return batch
class ItmValDataset(DetectFeatTxtTokDataset):
""" For evaluating Image-Text-Retrieval task """
def __init__(self, db_dir, img_dir, mini_batch_size=400):
super().__init__(db_dir, img_dir)
del self.lens
self.txt2img = self.txt_db.txt2img
self.img2txts = self.txt_db.img2txts
self.all_img_ids = list(self.img2txts.keys())
assert len(self.img2txts) >= mini_batch_size > 0
self.bs = mini_batch_size
def _get_batch_ids(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
# sample fixed negatives for each gt image
i = self.all_img_ids.index(gt_img_id)
neg_st = i+1
neg_end = neg_st+self.bs-1
if neg_end > len(self.all_img_ids):
# warp around
neg_end -= len(self.all_img_ids)
neg_img_ids = (self.all_img_ids[neg_st:]
+ self.all_img_ids[:neg_end])
else:
neg_img_ids = self.all_img_ids[neg_st:neg_end]
assert len(neg_img_ids) == (self.bs - 1),\
"Did not sample enough neg samples"
return gt_img_id, neg_img_ids
def __getitem__(self, i):
""" this returns list of mini-batches """
gt_img_id, neg_img_ids = self._get_batch_ids(i)
# NOTE 1st one is gt img
batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
return batch
def get_batch(self, i, img_ids):
example = super().__getitem__(i)
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# process image features (gt always first)
img_feats, img_pos_feats, num_bbs = map(
list, unzip(map(self._get_img_feat, img_ids)))
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
tl = input_ids.size(1)
attn_masks_text = torch.ones(len(img_ids), tl).long()
# attn_masks_text = torch.ones(1, tl).long()
attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
for i, nbb in enumerate(num_bbs):
attn_masks_img.data[i, :nbb].fill_(1)
# out_size = attn_masks.size(1)
gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks_text': attn_masks_text,
'attn_masks_img': attn_masks_img,
'gather_index': gather_index}
return batch
# for VQA

592
dvl/data/itm_pre.py

@ -0,0 +1,592 @@
"""
Itm dataset
"""
from collections import defaultdict
import copy
import json
import random
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from toolz.sandbox import unzip
from cytoolz import concat
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
pad_tensors, get_gather_index, get_ids_and_lens)
from uniter_model.data.sampler import TokenBucketSampler
class TokenBucketSamplerForItm(TokenBucketSampler):
def __init__(self, dset, *args, **kwargs):
super().__init__(dset.lens, *args, **kwargs)
self.dset = dset
def __iter__(self):
it = super().__iter__()
self.dset.new_epoch()
self._lens = self.dset.lens
return it
def _has_overlap(la, lb):
if len(la) < len(lb):
la, lb = lb, la
s = set(la)
return any(b in s for b in lb)
def _sample_negative_rand(sample_pool, ground_truths, num_sample):
""" random and retry """
outputs = ground_truths[:1]
while _has_overlap(outputs, ground_truths):
outputs = random.sample(sample_pool, num_sample)
return outputs
def _sample_negative_extra(sample_pool, ground_truths, num_sample):
""" sample extra then remove """
tot_size = len(ground_truths) + num_sample
outputs = set(random.sample(sample_pool, tot_size))
for gt in ground_truths:
outputs.discard(gt)
outputs = list(outputs)[:num_sample]
return outputs
sample_negative = _sample_negative_rand # swith between 2 implementations
class ItmDataset(DetectFeatTxtTokDataset):
""" NOTE this Dataset handles distributed training itself
(for more efficient negative sampling) """
def __init__(self, txt_db, img_db, neg_sample_p=0.0):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
self.txt_lens, self.ids = get_ids_and_lens(txt_db)
self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
self.neg_sample_p = neg_sample_p
self.new_epoch()
def new_epoch(self):
""" should be called every epoch for more randomness"""
self.labels = np.random.choice(
[0, 1], size=len(self.ids),
p=[self.neg_sample_p, 1-self.neg_sample_p])
self.lens = []
self.train_imgs = []
for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
img_fname = super().__getitem__(i)['img_fname']
if self.labels[i] == 0:
img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
self.train_imgs.append(img_fname)
self.lens.append(tl + self.img_db.name2nbb[img_fname])
def __getitem__(self, i):
example = super().__getitem__(i)
# labels and negative images should be sampled every epoch
ground_truth_label = self.labels[i]
img_fname = self.train_imgs[i]
img_input_ids = torch.Tensor([101]).long()
img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
target = torch.Tensor(1).long()
target.data.fill_(ground_truth_label)
return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target
def itm_collate(inputs):
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets
) = map(list, unzip(inputs))
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': None,
'gather_index': gather_index
},
'pos_ctx_indices': list(range(bs)),
'neg_ctx_indices': list(range(bs, len(num_bbs))),
'targets': targets
}
return batch
def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
ot_scatter = torch.arange(0, joint_len, dtype=torch.long
).unsqueeze(0).repeat(len(txt_lens), 1)
for i, tl in enumerate(txt_lens):
max_ind = max_txt_len + (joint_len-tl)
ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
dtype=torch.long).data
return ot_scatter
def _compute_pad(lens, max_len):
pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
for i, l in enumerate(lens):
pad.data[i, l:].fill_(1)
return pad
def itm_ot_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
# OT inputs
max_tl = max(txt_lens)
max_nbb = max(num_bbs)
ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
txt_pad = _compute_pad(txt_lens, max_tl)
img_pad = _compute_pad(num_bbs, max_nbb)
ot_inputs = {'ot_scatter': ot_scatter,
'scatter_max': ot_scatter.max().item(),
'txt_pad': txt_pad,
'img_pad': img_pad}
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets,
'ot_inputs': ot_inputs}
return batch
class ItmRankDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDataset need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
# images partitioned by rank
self.img2txts = defaultdict(list)
for id_, img in self.txt2img.items():
self.img2txts[img].append(id_)
self.img_name_list = list(self.img2txts.keys())
assert neg_sample_size > 0
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
id_pairs = [(gt_txt_id, gt_img_fname)]
# sample negatives
neg_sample_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
neg_sample_txt_ids = sample_negative(
self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in neg_sample_img_ids] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in neg_sample_txt_ids])
inputs = self._collect_inputs(id_pairs)
assert len(inputs) == (1 + 2*self.neg_sample_size)
return inputs
def _collect_inputs(self, id_pairs):
# create input features
inputs = []
for txt_id, img_id in id_pairs:
example = self.txt_db[txt_id]
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
# mask
attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb, dtype=torch.long)
inputs.append((input_ids, img_feat, img_pos_feat, attn_masks_text, attn_masks_img))
return inputs
class ItmRankDatasetHardNeg(ItmRankDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
assert hard_neg_size > 0, \
"ItmRankDatasetHardNeg need at least 1 hard negative sample"
DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.img_name_list = list(self.img2txts.keys())
assert neg_sample_size > 0
self.neg_sample_size = neg_sample_size
self.hard_neg_size = hard_neg_size
def reload_hard_negs(self, hard_neg_dir):
self.txt2hardimgs = json.load(
open(f'{hard_neg_dir}/'
f'txt2hardimgs_rank{hvd.rank()}.json'))
self.img2hardtxts = json.load(
open(f'{hard_neg_dir}/img2hardtxts.json'))
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
id_pairs = [(gt_txt_id, gt_img_fname)]
# sample hard negatives
if self.hard_neg_size > 0:
hard_neg_img_samples = random.sample(
self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
hard_neg_txt_samples = random.sample(
self.img2hardtxts[gt_img_fname], self.hard_neg_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in hard_neg_img_samples] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in hard_neg_txt_samples])
# sample normal negatives
if self.neg_sample_size > 0:
neg_sample_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
neg_sample_txt_ids = sample_negative(
self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in neg_sample_img_ids] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in neg_sample_txt_ids])
inputs = self._collect_inputs(id_pairs)
assert len(inputs) == (1
+ 2*self.neg_sample_size
+ 2*self.hard_neg_size)
return inputs
def itm_rank_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
) = map(list, unzip(concat(i for i in inputs)))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
sample_size = len(inputs[0])
assert all(sample_size == len(i) for i in inputs)
bs, max_tl = input_ids.size()
# out_size = attn_masks.size(1)
gather_index = None # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks_text': attn_masks_text,
'attn_masks_img': attn_masks_img,
'gather_index': gather_index,
'sample_size': sample_size}
return batch
class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDatasetHardNegV2 need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.img_name_list = list(self.img2txts.keys())
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
input_ids = self.txt_db[gt_txt_id]['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
input_ids = input_ids.unsqueeze(0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
neg_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
img_ids = [gt_img_fname] + neg_img_ids
# process image features (gt always first)
img_feats, img_pos_feats, num_bbs = map(
list, unzip(map(self._get_img_feat, img_ids)))
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
tl = input_ids.size(1)
attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
for i, nbb in enumerate(num_bbs):
attn_masks.data[i, :tl+nbb].fill_(1)
out_size = attn_masks.size(1)
gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
len(img_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index}
return batch
class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDatasetHardNegV2 need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.txt_name_list = list(self.txt2img.keys())
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
gt_txt_ids = self.img2txts[gt_img_id]
# process image features (gt always first)
img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
img_feat = img_feat.unsqueeze(0)
img_pos_feat = img_pos_feat.unsqueeze(0)
# sample negative
neg_txt_ids = sample_negative(
self.txt_name_list, gt_txt_ids, self.neg_sample_size)
txt_ids = [gt_txt_id] + neg_txt_ids
# process text inputs
all_inputs = []
txt_lens = []
for txt_id in txt_ids:
input_ids = self.txt_db.combine_inputs(
self.txt_db[txt_id]['input_ids'])
all_inputs.append(input_ids)
txt_lens.append(len(input_ids))
input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
for i, tl in enumerate(txt_lens):
attn_masks.data[i, :tl+nbb].fill_(1)
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
len(txt_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index}
return batch
def itm_rank_hnv2_collate(inputs):
assert len(inputs) == 1
return inputs[0]
class ItmValDataset(DetectFeatTxtTokDataset):
""" For evaluating Image-Text-Retrieval task """
def __init__(self, db_dir, img_dir, mini_batch_size=400):
super().__init__(db_dir, img_dir)
del self.lens
self.txt2img = self.txt_db.txt2img
self.img2txts = self.txt_db.img2txts
self.all_img_ids = list(self.img2txts.keys())
assert len(self.img2txts) >= mini_batch_size > 0
self.bs = mini_batch_size
def _get_batch_ids(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
# sample fixed negatives for each gt image
i = self.all_img_ids.index(gt_img_id)
neg_st = i+1
neg_end = neg_st+self.bs-1
if neg_end > len(self.all_img_ids):
# warp around
neg_end -= len(self.all_img_ids)
neg_img_ids = (self.all_img_ids[neg_st:]
+ self.all_img_ids[:neg_end])
else:
neg_img_ids = self.all_img_ids[neg_st:neg_end]
assert len(neg_img_ids) == (self.bs - 1),\
"Did not sample enough neg samples"
return gt_img_id, neg_img_ids
def __getitem__(self, i):
""" this returns list of mini-batches """
gt_img_id, neg_img_ids = self._get_batch_ids(i)
# NOTE 1st one is gt img
batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
return batch
def get_batch(self, i, img_ids):
example = super().__getitem__(i)
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# process image features (gt always first)
img_feats, img_pos_feats, num_bbs = map(
list, unzip(map(self._get_img_feat, img_ids)))
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
tl = input_ids.size(1)
attn_masks_text = torch.ones(len(img_ids), tl).long()
# attn_masks_text = torch.ones(1, tl).long()
attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
for i, nbb in enumerate(num_bbs):
attn_masks_img.data[i, :nbb].fill_(1)
# out_size = attn_masks.size(1)
gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks_text': attn_masks_text,
'attn_masks_img': attn_masks_img,
'gather_index': gather_index}
return batch
def itm_val_collate(inputs):
assert len(inputs) == 1, "input batch size > 1"
return inputs[0]
class ItmHardNegDataset(ItmValDataset):
def _get_batch_ids(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
# sample fixed negatives for each gt image
i = self.all_img_ids.index(gt_img_id)
all_img_ids = copy.deepcopy(self.all_img_ids)
all_img_ids.remove(gt_img_id)
random.shuffle(all_img_ids)
neg_img_ids = all_img_ids[:self.bs]
assert len(neg_img_ids) == (self.bs),\
"Did not sample enough neg samples"
return gt_img_id, neg_img_ids
def __getitem__(self, i):
""" this returns list of mini-batches """
_, neg_img_ids = self._get_batch_ids(i)
batch = self.get_batch(i, neg_img_ids)
batch['gt_txt_id'] = self.ids[i]
batch['neg_img_ids'] = neg_img_ids
return batch
itm_hn_collate = itm_val_collate
class ItmEvalDataset(ItmValDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
key=lambda i: self.img_db.name2nbb[i])
def __getitem__(self, i):
mini_batches = []
for st in range(0, len(self.all_img_ids), self.bs):
mini_batches.append(
self.get_batch(i, self.all_img_ids[st:st+self.bs]))
return mini_batches
itm_eval_collate = itm_val_collate

390
dvl/data/mlm.py

@ -0,0 +1,390 @@
"""
MLM datasets
"""
import math
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, get_ids_and_lens, pad_tensors,
get_gather_index, get_gather_index_uniter)
def random_word(tokens, vocab_range, mask):
"""
Masking some random tokens for Language Model task with probabilities as in
the original BERT paper.
:param tokens: list of int, tokenized sentence.
:param vocab_range: for choosing a random word
:return: (list of int, list of int), masked tokens and related labels for
LM prediction
"""
output_label = []
for i, token in enumerate(tokens):
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
tokens[i] = mask
# 10% randomly change token to random token
elif prob < 0.9:
tokens[i] = random.choice(list(range(*vocab_range)))
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
output_label.append(token)
else:
# no masking token (will be ignored by loss function later)
output_label.append(-1)
if all(o == -1 for o in output_label):
# at least mask 1
output_label[0] = tokens[0]
tokens[0] = mask
return tokens, output_label
class MlmDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
"""
Return:
- input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- img_feat : (num_bb, d)
- img_pos_feat : (num_bb, 7)
- attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
- txt_labels : (L, ), [-1, -1, wid, -1, -1, -1]
0's padded so that (L + num_bb) % 8 == 0
"""
example = super().__getitem__(i)
# text input
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
# img input
img_input_ids = torch.Tensor([101]).long()
img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'])
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, txt_labels, attn_masks_teacher
def create_mlm_io(self, input_ids):
input_ids, txt_labels = random_word(input_ids,
self.txt_db.v_range,
self.txt_db.mask)
input_ids = torch.tensor([self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep])
txt_labels = torch.tensor([-1] + txt_labels + [-1])
return input_ids, txt_labels
def mlm_collate(inputs):
"""
Return:
:input_ids (n, max_L) padded with 0
:position_ids (n, max_L) padded with 0
:txt_lens list of [txt_len]
:img_feat (n, max_num_bb, feat_dim)
:img_pos_feat (n, max_num_bb, 7)
:num_bbs list of [num_bb]
:attn_masks (n, max_{L + num_bb}) padded with 0
:txt_labels (n, max_L) padded with -1
"""
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, txt_labels, attn_masks_teacher
) = map(list, unzip(inputs))
# text batches
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': None,
'gather_index': gather_index
},
'txt_labels': txt_labels,
'teacher': {
'txt_lens': txt_lens,
'num_bbs': num_bbs,
'bs': bs,
'max_tl': max_tl,
'out_size': out_size,
'gather_index': gather_index_teacher,
'attn_masks': attn_masks_teacher
}
}
return batch
class BlindMlmDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels
def mlm_blind_collate(inputs):
input_ids, attn_masks, txt_labels = map(list, unzip(inputs))
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'txt_labels': txt_labels}
return batch
def eval_mask(len_, num_samples=7):
""" build the mask for evaluating MLM
circularly mask 1 word out of every x words
"""
# build the random masks
if len_ <= num_samples:
masks = torch.eye(len_).bool()
num_samples = len_
else:
mask_inds = [list(range(i, len_, num_samples))
for i in range(num_samples)]
masks = torch.zeros(num_samples, len_).bool()
for i, indices in enumerate(mask_inds):
for j in indices:
masks.data[i, j] = 1
assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
assert masks.sum().item() == len_
return masks
def eval_gather_inds(len_, num_samples=7):
""" get the gather indices """
inds = torch.arange(0, num_samples, dtype=torch.long)
mul = math.ceil(len_ / num_samples)
output = inds.repeat(mul)[:len_]
return output
def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
"""N x [B_i, T, ...]"""
if ns is None:
ns = [t.size(0) for t in tensors]
if lens is None:
lens = [t.size(1) for t in tensors]
max_len = max(lens)
bs = sum(ns)
hid_dims = tensors[0].size()[2:]
dtype = tensors[0].dtype
output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
if pad:
output.data.fill_(pad)
i = 0
for t, l, n in zip(tensors, lens, ns):
output.data[i:i+n, :l, ...] = t.data
i += n
return output
def expand_tensors(tensors, ns):
return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim()))
for t, n in zip(tensors, ns)]
class MlmEvalDataset(DetectFeatTxtTokDataset):
""" For evaluating MLM training task """
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
example = super().__getitem__(i)
# text input
(input_ids, txt_labels, gather_inds
) = self.create_mlm_eval_io(example['input_ids'])
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat, attn_masks,
txt_labels, gather_inds)
def create_mlm_eval_io(self, input_ids):
txt_labels = torch.tensor(input_ids)
masks = eval_mask(len(input_ids))
n_mask = masks.size(0)
masks = torch.cat([torch.zeros(n_mask, 1).bool(),
masks,
torch.zeros(n_mask, 1).bool()],
dim=1)
input_ids = torch.tensor([[self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep]
for _ in range(n_mask)])
input_ids.data.masked_fill_(masks, self.txt_db.mask)
gather_inds = eval_gather_inds(len(txt_labels))
return input_ids, txt_labels, gather_inds
def _batch_gather_tgt(gather_inds, n_masks):
gather_tgts = []
offset = 0
for g, n in zip(gather_inds, n_masks):
gather_tgts.append(g + offset)
offset += n
gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
return gather_tgt
def mlm_eval_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
num_bbs, n_masks)
img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
num_bbs, n_masks)
bs, max_tl = input_ids.size()
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
out_size = attn_masks.size(1)
# repeat txt_lens, num_bbs
txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch
class BlindMlmEvalDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids = example['input_ids']
# text input
input_ids = example['input_ids']
(input_ids, txt_labels, gather_inds
) = self.txt_db.create_mlm_eval_io(input_ids)
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels, gather_inds
def mlm_blind_eval_collate(inputs):
(input_ids, position_ids, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch

263
dvl/data/mrm.py

@ -0,0 +1,263 @@
"""
MRM Datasets
"""
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index, get_gather_index_uniter
def _get_img_mask(mask_prob, num_bb):
img_mask = [random.random() < mask_prob for _ in range(num_bb)]
if not any(img_mask):
# at least mask 1
img_mask[random.choice(range(num_bb))] = True
img_mask = torch.tensor(img_mask)
return img_mask
def _get_img_tgt_mask(img_mask, txt_len):
z = torch.zeros(txt_len, dtype=torch.bool)
img_mask_tgt = torch.cat([z, img_mask], dim=0)
return img_mask_tgt
def _get_feat_target(img_feat, img_masks):
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d)
feat_dim = img_feat.size(-1)
feat_targets = img_feat[img_masks_ext].contiguous().view(
-1, feat_dim) # (s, d)
return feat_targets
def _mask_img_feat(img_feat, img_masks):
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
return img_feat_masked
class MrfrDataset(DetectFeatTxtTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
def __getitem__(self, i):
"""
Return:
- input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- img_feat : (num_bb, d)
- img_pos_feat : (num_bb, 7)
- attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
- img_mask : (num_bb, ) between {0, 1}
"""
example = super().__getitem__(i)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
# image input features
img_input_ids = torch.Tensor([101]).long()
img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname'])
img_mask = _get_img_mask(self.mask_prob, num_bb)
img_mask_tgt = _get_img_tgt_mask(img_mask, 1)
img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids))
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img,
img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher)
def mrfr_collate(inputs):
"""
Return:
- input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- position_ids : (n, max_L)
- txt_lens : list of [input_len]
- img_feat : (n, max_num_bb, d)
- img_pos_feat : (n, max_num_bb, 7)
- num_bbs : list of [num_bb]
- attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1]
- img_masks : (n, max_num_bb) between {0, 1}
"""
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_masks, img_mask_tgts,
attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
feat_targets = _get_feat_target(img_feat, img_masks)
img_feat = _mask_img_feat(img_feat, img_masks)
img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0)
img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': img_masks,
'gather_index': gather_index
},
'teacher': {
'txt_lens': txt_lens,
'num_bbs': num_bbs,
'bs': bs,
'max_tl': max_tl,
'out_size': out_size,
'gather_index': gather_index_teacher,
'attn_masks': attn_masks_teacher,
'img_mask_tgt': img_mask_tgt_teacher,
},
'feat_targets': feat_targets,
'img_mask_tgt': img_mask_tgt}
return batch
def _get_targets(img_masks, img_soft_label):
soft_label_dim = img_soft_label.size(-1)
img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
-1, soft_label_dim)
return label_targets
class MrcDataset(DetectFeatTxtTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
def _get_img_feat(self, fname):
img_dump = self.img_db.get_dump(fname)
num_bb = self.img_db.name2nbb[fname]
img_feat = torch.tensor(img_dump['features'])
bb = torch.tensor(img_dump['norm_bb'])
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
img_soft_label = torch.tensor(img_dump['soft_labels'])
return img_feat, img_bb, img_soft_label, num_bb
def __getitem__(self, i):
example = super().__getitem__(i)
img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
example['img_fname'])
# image input features
img_input_ids = torch.Tensor([101]).long()
img_mask = _get_img_mask(self.mask_prob, num_bb)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
img_mask_tgt = _get_img_tgt_mask(img_mask, 1)
img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids))
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img,
img_soft_labels, img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher)
def mrc_collate(inputs):
(input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_soft_labels,
img_masks, img_mask_tgts, attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
img_feat = pad_tensors(img_feats, num_bbs)
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
img_soft_label = pad_tensors(img_soft_labels, num_bbs)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
label_targets = _get_targets(img_masks, img_soft_label)
img_feat = _mask_img_feat(img_feat, img_masks)
img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0)
img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
# gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0)
gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1))
batch = {
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': img_masks,
'gather_index': gather_index
},
'teacher': {
'txt_lens': txt_lens,
'num_bbs': num_bbs,
'bs': bs,
'max_tl': max_tl,
'out_size': out_size,
'gather_index': gather_index_teacher,
'attn_masks': attn_masks_teacher,
'img_mask_tgt': img_mask_tgt_teacher,
},
'img_mask_tgt': img_mask_tgt,
'label_targets': label_targets}
return batch

145
dvl/data/vqa.py

@ -0,0 +1,145 @@
"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
def _get_vqa_target(example, num_answers):
target = torch.zeros(num_answers)
labels = example['target']['labels']
scores = example['target']['scores']
if labels and scores:
target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
return target
class VqaDataset(DetectFeatTxtTokDataset):
""" NOTE: This handels distributed inside """
def __init__(self, num_answers, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_answers = num_answers
def __getitem__(self, i):
example = super().__getitem__(i)
qid = self.ids[i]
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
img_input_ids = torch.Tensor([101]).long()
target = _get_vqa_target(example, self.num_answers)
attn_masks_txt = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb+1, dtype=torch.long)
return qid, input_ids, attn_masks_txt, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target
def vqa_collate(inputs):
(qids, input_ids, attn_masks_txt, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks_txt = pad_sequence(attn_masks_txt, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
targets = torch.stack(targets, dim=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0)
img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0)
bs, max_tl = input_ids.size()
out_size = attn_masks_img.size(1)
gather_index_teacher = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size)
batch = {'qids': qids,
'txts': {
'input_ids': input_ids,
'position_ids': position_ids,
'attention_mask': attn_masks_txt,
'img_feat': None,
'img_pos_feat': None,
'img_masks': None,
'gather_index': None
},
'imgs': {
'input_ids': img_input_ids,
'position_ids': img_position_ids,
'attention_mask': attn_masks_img,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'img_masks': None,
'gather_index': gather_index
},
'gather_index_teacher': gather_index_teacher,
'targets': targets}
return batch
class VqaEvalDataset(VqaDataset):
def __getitem__(self, i):
qid = self.ids[i]
example = DetectFeatTxtTokDataset.__getitem__(self, i)
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
if 'target' in example:
target = _get_vqa_target(example, self.num_answers)
else:
target = None
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return qid, input_ids, img_feat, img_pos_feat, attn_masks, target
def vqa_eval_collate(inputs):
(qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
if targets[0] is None:
targets = None
else:
targets = torch.stack(targets, dim=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'qids': qids,
'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets}
return batch

66
dvl/hn.py

@ -0,0 +1,66 @@
import random
import logging
import collections
import json
import os
import itertools
import numpy as np
from collections import ChainMap
from dvl.trainer import build_dataloader, _save_checkpoint, eval_model_on_dataloader, load_dataset
logger = logging.getLogger()
def random_hard_neg(fname2id, num_hard_negatives, id2set, set2id):
# num_hard_negatives must be very small
hard_negs = dict()
for i in fname2id:
while True:
hard_neg = random.choices(set2id[id2set[i]], k=num_hard_negatives)
if fname2id[i] not in hard_neg:
break
hard_negs[i] = hard_neg
return hard_negs
def get_img_txt_mappings(train_txt_dbs):
train_img2txt = dict(ChainMap(*[json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs]))
train_txt2img = dict(itertools.chain(*[[(v, k) for v in vals] for k, vals in train_img2txt.items()]))
train_json = [json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs]
train_img2set = dict(ChainMap(*[{k:v for k in tj } for tj, v in zip(train_json, train_txt_dbs)]))
train_txt2set = {txt_id: train_img2set[img_id] for txt_id, img_id in train_txt2img.items()}
train_set2img, train_set2txt = collections.defaultdict(list), collections.defaultdict(list)
for img_id, set_id in train_img2set.items():
train_set2img[set_id].append(img_id)
train_set2txt[set_id] += train_img2txt[img_id]
return train_img2txt, train_txt2img, train_img2set, train_txt2set, train_set2img, train_set2txt
def sampled_hard_negatives(all_img_dbs, args, collate_func, bi_encoder, train_img2txt, train_txt2img):
train_dataset_eval = load_dataset(all_img_dbs, args.train_txt_dbs, args.train_img_dbs, args, True)
hard_negs_txt_all, hard_negs_img_all = [], []
for dset in train_dataset_eval.datasets:
dset.new_epoch()
train_dataloader_hn = build_dataloader(dset, collate_func, True, args, args.valid_batch_size)
logger.info(f'eval for train dataloader len (for hn) = {len(train_dataloader_hn)}')
num_hard_sampled = min(max(args.num_hard_negatives * 2 + 10, 50), 1000)
loss_hard, correct_ratio_hard, indexer_hard, recall_hard, (hard_neg_img, hard_neg_txt) = \
eval_model_on_dataloader(bi_encoder, train_dataloader_hn, args, train_img2txt, num_hard_sampled)
[v.remove(train_txt2img[k]) for k, v in hard_neg_img.items() if train_txt2img[k] in v]
hard_neg_txt = {k: list(set(v) - set(train_img2txt[k])) for k, v in hard_neg_txt.items()}
# remove self in hard negatives as they are labels
hard_negs_txt_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_txt.items()})
hard_negs_img_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_img.items()})
hard_negs_txt_all = dict(collections.ChainMap(*hard_negs_txt_all))
hard_negs_img_all = dict(collections.ChainMap(*hard_negs_img_all))
return hard_negs_txt_all, hard_negs_img_all

154
dvl/indexer/faiss_indexers.py

@ -0,0 +1,154 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
FAISS-based index components for dense retriver
"""
import logging
import pickle
from typing import List, Tuple
import faiss
import numpy as np
logger = logging.getLogger()
class DenseIndexer(object):
def __init__(self, buffer_size: int = 50000):
self.buffer_size = buffer_size
self.index_id_to_db_id = []
self.index = None
def index_data(self, data: List[Tuple[object, np.array]]):
raise NotImplementedError
def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
raise NotImplementedError
def serialize(self, file: str):
logger.info('Serializing index to %s', file)
index_file = file + '.index.dpr'
meta_file = file + '.index_meta.dpr'
faiss.write_index(self.index, index_file)
with open(meta_file, mode='wb') as f:
pickle.dump(self.index_id_to_db_id, f)
def deserialize_from(self, file: str):
logger.info('Loading index from %s', file)
index_file = file + '.index.dpr'
meta_file = file + '.index_meta.dpr'
self.index = faiss.read_index(index_file)
logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)
with open(meta_file, "rb") as reader:
self.index_id_to_db_id = pickle.load(reader)
assert len(
self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
def _update_id_mapping(self, db_ids: List):
self.index_id_to_db_id.extend(db_ids)
class DenseFlatIndexer(DenseIndexer):
def __init__(self, vector_sz: int, buffer_size: int = 50000):
super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
self.index = faiss.IndexFlatIP(vector_sz)
def index_data(self, data: List[Tuple[object, np.array]]):
n = len(data)
# indexing in batches is beneficial for many faiss index types
for i in range(0, n, self.buffer_size):
db_ids = [t[0] for t in data[i:i + self.buffer_size]]
vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
vectors = np.concatenate(vectors, axis=0)
self._update_id_mapping(db_ids)
self.index.add(vectors)
indexed_cnt = len(self.index_id_to_db_id)
logger.info('Total data indexed %d', indexed_cnt)
def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
scores, indexes = self.index.search(query_vectors, top_docs)
# convert to external ids
db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
return result
class DenseHNSWFlatIndexer(DenseIndexer):
"""
Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage
"""
def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512
, ef_search: int = 128, ef_construction: int = 200):
super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
# IndexHNSWFlat supports L2 similarity only
# so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension
index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
index.hnsw.efSearch = ef_search
index.hnsw.efConstruction = ef_construction
self.index = index
self.phi = 0
def index_data(self, data: List[Tuple[object, np.array]]):
n = len(data)
# max norm is required before putting all vectors in the index to convert inner product similarity to L2
if self.phi > 0:
raise RuntimeError('DPR HNSWF index needs to index all data at once,'
'results will be unpredictable otherwise.')
phi = 0
for i, item in enumerate(data):
id, doc_vector = item
norms = (doc_vector ** 2).sum()
phi = max(phi, norms)
logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
self.phi = 0
# indexing in batches is beneficial for many faiss index types
for i in range(0, n, self.buffer_size):
db_ids = [t[0] for t in data[i:i + self.buffer_size]]
vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
aux_dims = [np.sqrt(phi - norm) for norm in norms]
hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
enumerate(vectors)]
hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
self._update_id_mapping(db_ids)
self.index.add(hnsw_vectors)
logger.info('data indexed %d', len(self.index_id_to_db_id))
indexed_cnt = len(self.index_id_to_db_id)
logger.info('Total data indexed %d', indexed_cnt)
def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
aux_dim = np.zeros(len(query_vectors), dtype='float32')
query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
# logger.info('query_hnsw_vectors %s', query_nhsw_vectors.shape)
scores, indexes = self.index.search(query_nhsw_vectors, top_docs)
# convert to external ids
db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
return result
def deserialize_from(self, file: str):
super(DenseHNSWFlatIndexer, self).deserialize_from(file)
# to trigger warning on subsequent indexing
self.phi = 1

0
dvl/models/__init__.py

BIN
dvl/models/__pycache__/__init__.cpython-38.pyc

Binary file not shown.

BIN
dvl/models/__pycache__/bi_encoder.cpython-38.pyc

Binary file not shown.

757
dvl/models/bi_encoder.py

@ -0,0 +1,757 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Pipeline to train DPR Biencoder
"""
import argparse
import glob
import logging
import math
import os
import random
import time
import json
import csv
import re
import torch
import numpy as np
from typing import Tuple
from collections import defaultdict
from torch import nn
from torch import Tensor as T
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers import BertModel, BertConfig, BertPreTrainedModel
from transformers.optimization import AdamW
from uniter_model.model.model import UniterPreTrainedModel, UniterModel, UniterConfig
from dvl.const import IMG_DIM
#from dvl.indexer.faiss_indexers import DenseIndexer
from uniter_model.model.layer import GELU, BertOnlyMLMHead, BertPooler
from uniter_model.model.model import RegionClassification, RegionFeatureRegression, pad_tensor_to_mul
from typing import List
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if (logger.hasHandlers()):
logger.handlers.clear()
console = logging.StreamHandler()
logger.addHandler(console)
def dot_product_scores(q_vectors: T, ctx_vectors: T, cosine=False) -> T:
"""
calculates q->ctx scores for every row in ctx_vector
:param q_vector:
:param ctx_vector:
:return:
"""
# q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
if cosine:
n1 = torch.norm(q_vectors, dim=-1)
n2 = torch.norm(ctx_vectors, dim=-1)
n_out = torch.ger(n1, n2)
return r / n_out
return r
def cosine_scores(q_vector: T, ctx_vectors: T):
# q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
return F.cosine_similarity(q_vector, ctx_vectors, dim=1)
class BertEncoder(BertPreTrainedModel):
def __init__(self, config, project_dim: int = 0):
super().__init__(config)
self.bert = BertModel(config)
assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
# self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
if project_dim > 0:
self.encode_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size * 2),
GELU(),
LayerNorm(config.hidden_size * 2, eps=1e-12),
nn.Linear(config.hidden_size * 2, project_dim)
)
else:
self.encode_proj = None
self.init_weights()
@classmethod
def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\
-> BertModel:
cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
if dropout != 0:
cfg.attention_probs_dropout_prob = dropout
cfg.hidden_dropout_prob = dropout
if checkpoint_path is not None and len(checkpoint_path) > 0:
#state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
state_dict = torch.load(checkpoint_path)
return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, state_dict=state_dict, **kwargs)
else:
return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, **kwargs)
def forward(self, input_ids, attention_mask, position_ids,
img_feat=None, img_pos_feat=None, img_masks=None, gather_index=None):
if self.config.output_hidden_states:
sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
token_type_ids=None,
attention_mask=attention_mask,
position_ids=position_ids)
else:
hidden_states = None
sequence_output, pooled_output = self.bert(input_ids=input_ids,
token_type_ids=None,
attention_mask=attention_mask,
position_ids=position_ids)
pooled_output = sequence_output[:, 0, :]
if self.encode_proj:
pooled_output = self.encode_proj(pooled_output)
return sequence_output, pooled_output, hidden_states
def get_out_size(self):
if self.encode_proj:
return self.encode_proj.out_features
return self.config.hidden_size
class UniterEncoder(UniterPreTrainedModel):
def __init__(self, config, project_dim: int = 0):
super().__init__(config)
self.bert = UniterModel(config, IMG_DIM)
assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
# self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None # Yen-Chun
if project_dim > 0:
self.encode_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size * 2),
GELU(),
LayerNorm(config.hidden_size * 2, eps=1e-12),
nn.Linear(config.hidden_size * 2, project_dim)
)
else:
self.encode_proj = None
self.apply(self.init_weights)
@classmethod
def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\
-> UniterModel:
cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
if dropout != 0:
cfg.attention_probs_dropout_prob = dropout
cfg.hidden_dropout_prob = dropout
if checkpoint_path is not None and len(checkpoint_path) > 0 and checkpoint_path.lower() != 'none':
logger.info(f'load from {checkpoint_path} for uniter encoder')
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
state_dict = torch.load(checkpoint_path)
#if 'model_dict' in state_dict:
# state_dict = state_dict['model_dict']
else:
logger.info('no checkpoint, random initialization for img encoder')
state_dict = dict()
return cls.from_pretrained(cfg_name, state_dict=state_dict, project_dim=project_dim, **kwargs)
def forward(self, input_ids, attention_mask, position_ids,
img_feat, img_pos_feat, img_masks, gather_index=None) -> Tuple[T, ...]:
if self.config.output_hidden_states:
sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
img_feat=img_feat,
img_pos_feat=img_pos_feat,
img_masks=img_masks,
img_type_ids=None,
gather_index=gather_index,
output_all_encoded_layers=True
)
else:
hidden_states = None
sequence_output = self.bert(input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
img_feat=img_feat,
img_pos_feat=img_pos_feat,
img_masks=img_masks,
img_type_ids=None,
gather_index=gather_index,
output_all_encoded_layers=False)
# pooled_output = self.bert.pooler(sequence_output)
pooled_output = sequence_output[:, 0, :]
if self.encode_proj:
pooled_output = self.encode_proj(pooled_output)
return sequence_output, pooled_output, hidden_states
def get_out_size(self):
if self.encode_proj:
return self.encode_proj.out_features
return self.config.hidden_size
class BiEncoder(nn.Module):
""" Bi-Encoder model component. Encapsulates query/question and context/passage encoders.
"""
def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False, project_dim: int = 0):
super(BiEncoder, self).__init__()
logger.info('*'*100)
logger.info('loading img model')
if args.img_model_type == 'uniter-base':
self.img_model = UniterEncoder.init_encoder(args.img_model_config, checkpoint_path=args.img_checkpoint, project_dim=project_dim)
else:
raise ValueError(f'image encoder does not support other types ({args.img_model_type}) for now')
logger.info('*' * 100)
logger.info('loading txt model')
if args.txt_model_type == 'bert-base':
self.txt_model = BertEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
elif args.txt_model_type == 'uniter-base':
self.txt_model = UniterEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
else:
raise ValueError(f'txt encoder does not support other types ({args.txt_model_type}) for now')
self.fix_img_encoder = fix_img_encoder
self.fix_txt_encoder = fix_txt_encoder
self.project_dim = project_dim
if fix_txt_encoder:
for param in self.txt_model.parameters():
param.requires_grad = False
if fix_img_encoder:
for param in self.img_model.parameters():
param.requires_grad = False
@staticmethod
def get_representation(sub_model, input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks,
gather_index=None, fix_encoder=False):
if fix_encoder:
with torch.no_grad():
sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
img_feat, img_pos_feat, img_masks,
gather_index)
else:
sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
img_feat, img_pos_feat, img_masks,
gather_index)
if sub_model.training:
sequence_output.requires_grad_(requires_grad=True)
pooled_output.requires_grad_(requires_grad=True)
return sequence_output, pooled_output, hidden_states
def forward(self, batch, output_all_encoded_layers=False):
# batch keys
# imgs
# txts
# caps
batch = defaultdict(lambda: None, batch)
if 'txts' in batch:
sb = batch['txts']
txt_seq, txt_pooled, txt_hidden = self.get_representation(self.txt_model, sb['input_ids'],
sb['attention_mask'], sb['position_ids'],
sb['img_feat'], sb['img_pos_feat'],
sb['img_masks'],
sb['gather_index'], self.fix_txt_encoder)
else:
txt_seq, txt_pooled = None, None
if 'imgs' in batch:
sb = batch['imgs']
img_seq, img_pooled, img_hidden = self.get_representation(self.img_model, sb['input_ids'],
sb['attention_mask'], sb['position_ids'],
sb['img_feat'], sb['img_pos_feat'],
sb['img_masks'],
sb['gather_index'], self.fix_txt_encoder)
else:
img_seq, img_pooled = None, None
if 'caps' in batch and batch['caps']['input_ids'] is not None:
sb = batch['caps']
cap_seq, cap_pooled, cap_hidden = self.get_representation(self.txt_model, sb['input_ids'],
sb['attention_mask'], sb['position_ids'],
sb['img_feat'], sb['img_pos_feat'],
sb['img_masks'],
sb['gather_index'], self.fix_txt_encoder)
else:
cap_seq, cap_pooled = None, None
if output_all_encoded_layers:
return txt_seq, img_seq, cap_seq
else:
return txt_pooled, img_pooled, cap_pooled
class BiEncoderForPretraining(nn.Module):
""" MLM + MRM """
def __init__(self, config_file, args, project_dim, img_dim, img_label_dim, nce_temp=1, ot_pos_only=False,
experiment=None):
super().__init__()
config = UniterConfig.from_json_file(config_file)
self.bert = BiEncoder(args, project_dim=project_dim)
self.cls = BertOnlyMLMHead(
config, self.bert.img_model.bert.embeddings.word_embeddings.weight) # ???
self.feat_regress = RegionFeatureRegression(
config.hidden_size, img_dim,
self.bert.img_model.bert.img_embeddings.img_linear.weight)
self.region_classifier = RegionClassification(
config.hidden_size, img_label_dim)
self.itm_output = nn.Linear(config.hidden_size, 2)
self.cls_concat = args.cls_concat
'''
self.nce_output = BertPredictionHeadTransform(config)
self.nce_output = nn.Sequential(BertPredictionHeadTransform(config),
nn.Linear(config.hidden_size, img_dim))
self.nce_norm = LayerNorm(config.hidden_size, eps=1e-12)
self.nce_temp = nce_temp # temperature
'''
self.ot_pos_only = ot_pos_only
# self.apply(self.init_weights)
self.vocab_pad = 0
self.experiment = experiment
def pad_vocab(self):
# FIXME better padding after integrating huggingface ???
emb_w = self.bert.embeddings.word_embeddings.weight.data
padded_emb_w, n_pad = pad_tensor_to_mul(emb_w)
padded_emb_w = nn.Parameter(padded_emb_w)
self.bert.embeddings.word_embeddings.weight = padded_emb_w
self.cls.predictions.decoder.weight = padded_emb_w
self.vocab_pad = n_pad
def forward(self, batch, task, compute_loss=True):
batch = defaultdict(lambda: None, batch)
if task == 'mlm':
txt_labels = batch['txt_labels']
return self.forward_mlm(batch, txt_labels, compute_loss)
elif task == 'mrfr':
img_mask_tgt = batch['img_mask_tgt']
img_masks = batch['img_masks']
mrfr_feat_target = batch['feat_targets']
return self.forward_mrfr(batch, img_masks, img_mask_tgt, mrfr_feat_target, compute_loss)
elif task == 'mrm-nce':
raise NotImplementedError('nce does not work')
img_mask_tgt = batch['img_mask_tgt']
img_masks = batch['img_masks']
img_masks_in = batch['img_masks_in']
feat_target = batch['feat_targets']
neg_feats = batch['neg_feats']
return self.forward_mrm_nce(batch,
img_masks_in, img_masks, img_mask_tgt,
feat_target, neg_feats, compute_loss)
elif task == 'itm':
targets = batch['targets']
ot_inputs = batch['ot_inputs']
return self.forward_itm(batch,
targets, ot_inputs, compute_loss)
elif task.startswith('mrc'):
img_mask_tgt = batch['img_mask_tgt']
img_masks = batch['img_masks']
mrc_label_target = batch['label_targets']
return self.forward_mrc(batch,
img_masks, img_mask_tgt,
mrc_label_target, task, compute_loss)
else:
raise ValueError('invalid task')
# MLM
def forward_mlm(self, batch, txt_labels, compute_loss=True):
txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
# get only the text part
img_cls = img_seq[:, 0:1, :].repeat(1, txt_seq.shape[1], 1)
if self.cls_concat == 'add':
sequence_output = txt_seq + img_cls
elif self.cls_concat == 'multiply':
sequence_output = txt_seq * img_cls
elif len(self.cls_concat) == 0:
sequence_output = txt_seq
else:
raise NotImplementedError(f'{self.cls_concat} not implemented yet')
# only compute masked tokens for better efficiency
masked_output = self._compute_masked_hidden(sequence_output,
txt_labels != -1)
prediction_scores = self._pad_layer_unpad(masked_output, self.cls)
if self.vocab_pad:
prediction_scores = prediction_scores[:, :-self.vocab_pad]
masked_lm_loss = F.cross_entropy(prediction_scores,
txt_labels[txt_labels != -1],
reduction='none')
return masked_lm_loss, prediction_scores
def _compute_masked_hidden(self, hidden, mask):
""" get only the masked region (don't compute unnecessary hiddens) """
mask = mask.unsqueeze(-1).expand_as(hidden)
hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
return hidden_masked
def _pad_layer_unpad(self, input_, layer):
input_, n_pad = pad_tensor_to_mul(input_)
output = layer(input_)
if n_pad:
output = output[:-n_pad, :]
return output
def mlm_eval(self, batch, gather_tgt):
raise ValueError('Do not use this')
sequence_output = self.bert(batch, output_all_encoded_layers=False)
# get only the text part (excluding [CLS], [SEP])
sequence_output = sequence_output[:, 1:input_ids.size(1)-1, :]
# only compute masked tokens for better efficiency
index = gather_tgt.unsqueeze(-1).expand(
-1, -1, self.config.hidden_size)
masked_output = torch.gather(sequence_output, dim=0, index=index)
prediction_scores = self.cls(masked_output)
if self.vocab_pad:
prediction_scores = prediction_scores[..., :-self.vocab_pad]
return prediction_scores
# MRFR
def forward_mrfr(self, batch, img_masks, img_mask_tgt,
feat_targets, compute_loss=True):
txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
if self.cls_concat == 'add':
sequence_output = img_seq + txt_cls
elif self.cls_concat == 'multiply':
sequence_output = img_seq * txt_cls
elif len(self.cls_concat) == 0:
sequence_output = img_seq
else:
raise NotImplementedError(f'{self.cls_concat} not implemented yet')
# only compute masked tokens for better efficiency
masked_output = self._compute_masked_hidden(sequence_output,
img_mask_tgt)
prediction_feat = self._pad_layer_unpad(masked_output,
self.feat_regress)
mrfr_loss = F.mse_loss(prediction_feat, feat_targets,
reduction='none')
return mrfr_loss, prediction_feat
# MRM-NCE
def forward_mrm_nce(self,batch,
img_masks_in, img_masks, img_mask_tgt,
feat_targets, neg_feats, compute_loss=True):
sequence_output = self.bert(batch,
output_all_encoded_layers=False,
img_masks=img_masks_in)
# only compute masked tokens for better efficiency
masked_output = self._compute_masked_hidden(sequence_output,
img_mask_tgt)
masked_output = self._pad_layer_unpad(masked_output, self.nce_output)
# neg within batch
batch_neg = self._compute_masked_hidden(img_feat, ~img_masks)
neg_feats, _ = pad_tensor_to_mul(
torch.cat([neg_feats, batch_neg], dim=0))
# shared image linear transform
neg_output = self.nce_norm(
self.bert.img_embeddings.img_linear(neg_feats))
pos_output = self._pad_layer_unpad(feat_targets,
self.bert.img_embeddings.img_linear)
pos_output = self.nce_norm(pos_output)
mrm_nce_loss = self.mrm_nce(masked_output, pos_output,
neg_output, compute_loss=True)
return mrm_nce_loss, masked_output
def mrm_nce(self, masked_output, pos_output, neg_output,
compute_loss=True):
# dot product of ground truth feature
masked_score = masked_output.matmul(pos_output.t())
# dot product of neative samples
neg_score = masked_output.matmul(neg_output.t())
logits = torch.cat([masked_score, neg_score], dim=1).float()
targets = torch.arange(0, masked_output.size(0),
dtype=torch.long, device=logits.device)
loss = F.cross_entropy(logits/self.nce_temp, targets,
reduction='none')
return loss, logits
def forward_itm(self, batch, targets, ot_inputs,
compute_loss=True):
txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=False)
# OT loss
if ot_inputs is not None:
ot_scatter = ot_inputs['ot_scatter']
b = sequence_output.size(0)
tl = input_ids.size(1)
il = img_feat.size(1)
max_l = max(ot_inputs['scatter_max'] + 1, tl+il)
ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output)
ctx_emb = torch.zeros(b, max_l, self.config.hidden_size,
dtype=sequence_output.dtype,
device=sequence_output.device
).scatter_(dim=1, index=ot_scatter,
src=sequence_output)
txt_emb = ctx_emb[:, :tl, :]
img_emb = ctx_emb[:, tl:tl+il, :]
txt_pad = ot_inputs['txt_pad']
img_pad = ot_inputs['img_pad']
ot_dist = optimal_transport_dist(txt_emb, img_emb,
txt_pad, img_pad)
if self.ot_pos_only:
ot_loss = ot_dist.masked_select(targets == 1)
else:
ot_pos_dist = ot_dist.masked_select(targets == 1)
ot_neg_dist = ot_dist.masked_select(targets == 0)
ot_loss = (ot_pos_dist, ot_neg_dist)
else:
ot_loss = None
loss_function = BiEncoderNllLoss()
itm_loss1, is_correct1, scores1 = loss_function.calc(txt_seq, img_seq, cap_seq,
batch['pos_ctx_indices'],
batch['neg_ctx_indices'],
0.0, self.experiment, 'none')
itm_loss2, is_correct2, scores2 = loss_function.calc(img_seq, txt_seq, cap_seq,
batch['pos_ctx_indices'],
batch['neg_ctx_indices'],
0.0, self.experiment, 'none')
if compute_loss:
return itm_loss1*0.5 + itm_loss2*0.5, ot_loss
else:
return itm_loss1*0.5 + itm_loss2*0.5, ot_loss, is_correct1*0.5 + is_correct2*0.5
# MRC
def forward_mrc(self, batch, img_masks, img_mask_tgt,
label_targets, task, compute_loss=True):
txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
if self.cls_concat == 'add':
sequence_output = img_seq + txt_cls
elif self.cls_concat == 'multiply':
sequence_output = img_seq * txt_cls
elif len(self.cls_concat) == 0:
sequence_output = img_seq
else:
raise NotImplementedError(f'{self.cls_concat} not implemented yet')
# sequence_output = torch.cat([txt_seq, img_seq], dim=1)
# only compute masked regions for better efficiency
masked_output = self._compute_masked_hidden(sequence_output, img_mask_tgt)
prediction_soft_label = self._pad_layer_unpad(masked_output,
self.region_classifier)
if "kl" in task:
prediction_soft_label = F.log_softmax(
prediction_soft_label, dim=-1)
mrc_loss = F.kl_div(
prediction_soft_label, label_targets, reduction='none')
else:
# background class should not be the target
label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1
mrc_loss = F.cross_entropy(
prediction_soft_label, label_targets,
ignore_index=0, reduction='none')
return mrc_loss, prediction_soft_label
def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8,
weight_decay: float = 0.0, ) -> torch.optim.Optimizer:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
return optimizer
def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1,
local_rank: int = -1,
fp16: bool = False,
fp16_opt_level: str = "O1",
teacher_model = None) -> (nn.Module, torch.optim.Optimizer):
model.to(device)
if teacher_model is not None:
teacher_model.to(device)
if fp16:
try:
import apex
from apex import amp
apex.amp.register_half_function(torch, "einsum")
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
if optimizer is None:
if teacher_model is None:
model = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
else:
model, teacher_model = amp.initialize([model, teacher_model], optimizer, opt_level=fp16_opt_level)
else:
model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
#if n_gpu > 1:
# model = torch.nn.DataParallel(model)
# if local_rank != -1:
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
# output_device=local_rank,
# find_unused_parameters=True)
return model, optimizer
class BiEncoderNllLoss(object):
def calc(self, q_vectors: T, ctx_vectors: T, caption_vectors: T, positive_idx_per_question: list,
hard_negatice_idx_per_question: list = None, caption_score_weight: float = 0.1,
experiment=None, reduction='mean'):
"""
Computes nll loss for the given lists of question and ctx vectors.
Note that although hard_negatice_idx_per_question in not currently in use, one can use it for the
loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
:return: a tuple of loss value and amount of correct predictions per batch
"""
scores_img = self.get_scores(q_vectors, ctx_vectors)
if caption_vectors is not None and caption_score_weight != 0:
scores_caption = self.get_scores(q_vectors, caption_vectors)
scores = (1 - caption_score_weight) * scores_img + caption_score_weight * scores_caption
else:
scores = scores_img
if experiment is not None:
experiment.log_metric('score_img_diag_mean', torch.diag(scores_img).mean().item())
experiment.log_metric('score_img_offdiag_mean', (scores_img.sum() - torch.diag(scores_img).sum()) /
(torch.numel(scores_img)-len(torch.diag(scores_img))))
experiment.log_metric('score_diag_mean', torch.diag(scores).mean().item())
experiment.log_metric('score_offdiag_mean', (scores.sum() - torch.diag(scores).sum()) /
(torch.numel(scores) - len(torch.diag(scores))))
if caption_vectors is not None and caption_score_weight != 0:
experiment.log_metric('score_caption_diag_mean', torch.diag(scores_caption).mean().item())
experiment.log_metric('score_caption_offdiag_mean', (scores_caption.sum() - torch.diag(scores_caption).sum()) /
(torch.numel(scores_caption) - len(torch.diag(scores_caption))))
if len(q_vectors.size()) > 1:
q_num = q_vectors.size(0)
scores = scores.view(q_num, -1)
softmax_scores = F.log_softmax(scores, dim=1)
loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device),
reduction=reduction)
max_score, max_idxs = torch.max(softmax_scores, 1)
correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
return loss, correct_predictions_count, scores
@staticmethod
def get_scores(q_vector: T, ctx_vectors: T) -> T:
f = BiEncoderNllLoss.get_similarity_function()
return f(q_vector, ctx_vectors)
@staticmethod
def get_similarity_function():
return dot_product_scores
def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
""" Create a schedule with a learning rate that decreases linearly after
linearly increasing during a warmup period.
"""
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return max(
0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
class BiEncoderForVisualQuestionAnswering(nn.Module):
""" Finetune multi-modal BERT for VQA
"""
def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False,
seperate_caption_encoder: bool = False,
project_dim: int = 0,
hidden_size: int = 0, num_answer: int = 0, intersection=False):
super(BiEncoderForVisualQuestionAnswering, self).__init__()
self.biencoder = BiEncoder(args, fix_img_encoder, fix_txt_encoder, project_dim)
self.intersection = intersection
if self.intersection:
hidden_size *= 2
self.vqa_output = nn.Sequential(
nn.Linear(hidden_size, hidden_size*2),
GELU(),
LayerNorm(hidden_size*2, eps=1e-12),
nn.Linear(hidden_size*2, num_answer)
)
self.init_weights(self.vqa_output)
def forward(self, batch, compute_loss=True, targets=None) -> Tuple[T, T]:
q_pooled, ctx_pooled, caption_pooled = self.biencoder(batch)
if self.intersection:
pooled_output = torch.cat([q_pooled, ctx_pooled, q_pooled*ctx_pooled, q_pooled + ctx_pooled], dim=1)
else:
pooled_output = torch.cat([q_pooled, ctx_pooled], dim=1)
answer_scores = self.vqa_output(pooled_output)
if compute_loss:
vqa_loss = F.binary_cross_entropy_with_logits(
answer_scores, targets, reduction='none')
return vqa_loss
else:
return answer_scores
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses
# truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0,
std=0.02)
elif isinstance(module, LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def load_biencoder_checkpoint(bi_encoder, biencoder_checkpoint):
if biencoder_checkpoint is not None and len(biencoder_checkpoint) > 0 and biencoder_checkpoint.lower() != 'none':
logger.info(f'loading ckpt from {biencoder_checkpoint}')
state_dict = torch.load(biencoder_checkpoint, map_location='cpu')
try:
bi_encoder.load_state_dict(state_dict['model_dict'])
except KeyError:
logger.info('loading from pre-trained model instead')
for k in list(state_dict.keys()):
if k.startswith('bert.'):
state_dict[k[5:]] = state_dict.pop(k)
else:
state_dict.pop(k)
bi_encoder.load_state_dict(state_dict, strict=True)
else:
logger.info('no checkpoint provided, pass')

176
dvl/options.py

@ -0,0 +1,176 @@
import argparse
import json
import sys
import os
import logging
import torch
import random
import socket
import numpy as np
logger = logging.getLogger()
def default_params(parser: argparse.ArgumentParser):
parser.add_argument('--txt_model_type', default='bert-base', type=str, help="")
parser.add_argument('--txt_model_config', default='bert-base', type=str, help="")
parser.add_argument('--txt_checkpoint', default=None, type=str, help="")
parser.add_argument('--img_model_type', default='uniter-base', type=str, help="")
parser.add_argument('--img_model_config', default='./config/img_base.json', type=str, help="")
parser.add_argument('--img_checkpoint', default=None, type=str, help="")
parser.add_argument('--biencoder_checkpoint', default=None, type=str, help="")
parser.add_argument('--seperate_caption_encoder', action='store_true', help="")
parser.add_argument('--train_batch_size', default=80, type=int, help="")
parser.add_argument('--valid_batch_size', default=80, type=int, help="")
parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help="")
parser.add_argument('--learning_rate', default=1e-5, type=float, help="")
parser.add_argument('--max_grad_norm', default=2.0, type=float, help="")
parser.add_argument('--warmup_steps', default=500, type=int, help="")
parser.add_argument('--valid_steps', default=500, type=int, help="")
parser.add_argument('--num_train_steps', default=5000, type=int, help="")
parser.add_argument('--num_train_epochs', default=0, type=int, help="")
parser.add_argument('--fp16', action='store_true', help="")
parser.add_argument('--seed', default=42, type=int, help="")
parser.add_argument('--output_dir', default='./', type=str, help="")
parser.add_argument('--max_txt_len', default=64, type=int, help="")
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
parser.add_argument('--config', default=None, type=str, help="")
parser.add_argument('--itm_global_file', default=None, type=str, help="")
parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
parser.add_argument('--n_workers', type=int, default=2, help="number of data workers")
parser.add_argument('--pin_mem', action='store_true', help="pin memory") # ???
parser.add_argument('--hnsw_index', action='store_true', help="")
parser.add_argument('--fp16_opt_level', type=str, default='O1', help="")
parser.add_argument('--img_meta', type=str, default=None, help="")
def add_itm_params(parser: argparse.ArgumentParser):
parser.add_argument('--conf_th', default=0.2, type=float, help="")
parser.add_argument('--caption_score_weight', default=0.0, type=float, help="")
parser.add_argument('--negative_size', default=10, type=int, help="")
parser.add_argument('--num_hard_negatives', default=0, type=int, help="")
parser.add_argument('--sample_init_hard_negatives', action='store_true', help="")
parser.add_argument('--hard_negatives_sampling', default='none', type=str,
choices=['none', 'random', 'top', 'top-random', '10-20', '20-30'], help="")
parser.add_argument('--max_bb', default=100, type=int, help="")
parser.add_argument('--min_bb', default=10, type=int, help="")
parser.add_argument('--num_bb', default=36, type=int, help="")
parser.add_argument('--train_txt_dbs', default=None, type=str, help="")
parser.add_argument('--train_img_dbs', default=None, type=str, help="")
parser.add_argument('--txt_db_mapping', default=None, type=str, help="")
parser.add_argument('--img_db_mapping', default=None, type=str, help="")
parser.add_argument('--pretrain_mapping', default=None, type=str, help="")
parser.add_argument('--val_txt_db', default=None, type=str, help="")
parser.add_argument('--val_img_db', default=None, type=str, help="")
parser.add_argument('--test_txt_db', default=None, type=str, help="")
parser.add_argument('--test_img_db', default=None, type=str, help="")
parser.add_argument('--steps_per_hard_neg', default=-1, type=int, help="")
parser.add_argument('--inf_minibatch_size', default=400, type=int, help="")
parser.add_argument('--project_dim', default=0, type=int, help='')
parser.add_argument('--cls_concat', default="", type=str, help='')
parser.add_argument('--fix_txt_encoder', action='store_true', help='')
parser.add_argument('--fix_img_encoder', action='store_true', help='')
parser.add_argument('--compressed_db', action='store_true', help='use compressed LMDB')
parser.add_argument('--retrieval_mode', default="both",
choices=['img_only', 'txt_only', 'both'], type=str, help="")
def add_logging_params(parser: argparse.ArgumentParser):
parser.add_argument('--log_result_step', default=4, type=int, help="")
parser.add_argument('--project_name', default='itm', type=str, help="")
parser.add_argument('--expr_name_prefix', default='', type=str, help="")
parser.add_argument('--save_all_epochs', action='store_true', help="")
def add_kd_params(parser: argparse.ArgumentParser):
parser.add_argument('--teacher_checkpoint', default=None, type=str, help="")
parser.add_argument('--T', default=1.0, type=float, help="")
parser.add_argument('--kd_loss_weight', default=1.0, type=float, help="")
def parse_with_config(parser, cmds=None):
if cmds is None:
args = parser.parse_args()
else:
args = parser.parse_args(cmds)
if args.config is not None:
config_args = json.load(open(args.config))
override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:]
if arg.startswith('--')}
for k, v in config_args.items():
if k not in override_keys:
setattr(args, k, v)
return args
def map_db_dirs(args):
# map img db
for k in args.__dict__:
if not isinstance(args.__dict__[k], str):
continue
if args.__dict__[k].startswith('/pretrain') and args.pretrain_mapping:
print('pretrain', k, args.__dict__[k])
args.__dict__[k] = args.__dict__[k].replace('/pretrain', args.pretrain_mapping)
if args.__dict__[k].startswith('/db') and args.txt_db_mapping:
print('db', k, args.__dict__[k])
args.__dict__[k] = args.__dict__[k].replace('/db', args.txt_db_mapping)
if args.__dict__[k].startswith('/img') and args.img_db_mapping:
print('img', k, args.__dict__[k])
args.__dict__[k] = args.__dict__[k].replace('/img', args.img_db_mapping)
if args.img_db_mapping:
for i in range(len(args.train_img_dbs)):
args.train_img_dbs[i] = args.train_img_dbs[i].replace('/img', args.img_db_mapping)
if args.txt_db_mapping:
for i in range(len(args.train_txt_dbs)):
args.train_txt_dbs[i] = args.train_txt_dbs[i].replace('/db', args.txt_db_mapping)
def print_args(args):
logger.info(" **************** CONFIGURATION **************** ")
for key, val in sorted(vars(args).items()):
keystr = "{}".format(key) + (" " * (30 - len(key)))
logger.info("%s --> %s", keystr, val)
logger.info(" **************** END CONFIGURATION **************** ")
def set_seed(args):
seed = args.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(seed)
def setup_args_gpu(args):
"""
Setup arguments CUDA, GPU & distributed training
"""
if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
else: # distributed mode
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
ws = os.environ.get('WORLD_SIZE')
args.distributed_world_size = int(ws) if ws else 1
logger.info(
'Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(),
args.local_rank, device,
args.n_gpu,
args.distributed_world_size)
logger.info("16-bits training: %s ", args.fp16)

209
dvl/trainer.py

@ -0,0 +1,209 @@
import collections
import os
import torch
import tqdm
import logging
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset, ChainDataset
from uniter_model.data.loader import PrefetchLoader
from dvl.data.itm import TxtTokLmdb, ItmFastDataset, ItmValDataset, itm_fast_collate
from dvl.models.bi_encoder import BiEncoderNllLoss
from dvl.utils import _calc_loss
from dvl.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer
logger = logging.getLogger()
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
'encoder_params'])
class BiEncoderTrainer:
def __init__(self, args):
pass
def build_dataloader(dataset, collate_fn, is_train, opts, batch_size=None):
if batch_size is None:
batch_size = opts.train_batch_size if is_train else opts.valid_batch_size
dataloader = DataLoader(dataset, batch_size=batch_size,
shuffle=is_train, drop_last=False,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem, collate_fn=collate_fn)
dataloader = PrefetchLoader(dataloader)
return dataloader
def get_model_obj(model: nn.Module):
return model.module if hasattr(model, 'module') else model
def _save_checkpoint(args, biencoder, optimizer, scheduler, epoch: int, offset: int, cp_name: str = None) -> str:
model_to_save = get_model_obj(biencoder)
if cp_name is None:
cp = os.path.join(args.output_dir, 'biencoder.' + str(epoch) + ('.' + str(offset) if offset > 0 else ''))
else:
cp = os.path.join(args.output_dir, 'biencoder.' + cp_name)
cp += '.pt'
meta_params = None
state = CheckpointState(model_to_save.state_dict(),
optimizer.state_dict(),
scheduler.state_dict(),
offset,
epoch, meta_params
)
torch.save(state._asdict(), cp)
logger.info('Saved checkpoint at %s', cp)
return cp
def load_saved_state(biencoder, optimizer=None, scheduler=None, saved_state: CheckpointState = ''):
epoch = saved_state.epoch
offset = saved_state.offset
if offset == 0: # epoch has been completed
epoch += 1
logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch)
model_to_load = get_model_obj(biencoder)
logger.info('Loading saved model state ...')
model_to_load.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection
if saved_state.optimizer_dict and optimizer is not None:
logger.info('Loading saved optimizer state ...')
optimizer.load_state_dict(saved_state.optimizer_dict)
if saved_state.scheduler_dict and scheduler is not None:
scheduler_state = saved_state.scheduler_dict
scheduler.load_state_dict(scheduler_state)
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location='cpu')
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)
def get_indexer(bi_encoder, eval_dataloader, args, hnsw_index, img_retrieval=True):
bi_encoder.eval()
img_embedding = dict()
if hnsw_index:
indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future
else:
indexer_img = DenseFlatIndexer(args.vector_size) # modify in future
for i, batch in enumerate(tqdm.tqdm(eval_dataloader)):
with torch.no_grad():
model_out = bi_encoder(batch)
local_q_vector, local_ctx_vectors, local_caption_vectors = model_out
if img_retrieval:
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
else:
img_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)})
indexer_img.index_data(list(img_embedding.items()))
return indexer_img
def eval_model_on_dataloader(bi_encoder, eval_dataloader, args, img2txt=None, num_tops=100, no_eval=False):
total_loss = 0.0
bi_encoder.eval()
total_correct_predictions = 0
batches, total_samples = 0, 0
labels_img_name = []
labels_txt_name = []
img_embedding = dict()
txt_embedding = dict()
if args.hnsw_index:
indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future
indexer_txt = DenseHNSWFlatIndexer(args.vector_size) # modify in future
else:
indexer_img = DenseFlatIndexer(args.vector_size) # modify in future
indexer_txt = DenseFlatIndexer(args.vector_size) # modify in future
query_txt, query_txt_id = [], []
query_img, query_img_id = [], []
for i, batch in enumerate(eval_dataloader):
with torch.no_grad():
model_out = bi_encoder(batch)
local_q_vector, local_ctx_vectors, local_caption_vectors = model_out
query_txt.extend([out.view(-1).detach().cpu().numpy() for out in local_q_vector])
query_txt_id.extend(batch['txt_index'])
query_img.extend([out.view(-1).detach().cpu().numpy() for out in local_ctx_vectors])
query_img_id.extend(batch['img_fname'])
loss_function = BiEncoderNllLoss()
loss, correct_cnt, score = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors,
list(range(len(local_q_vector))), None)
total_loss += loss.item()
total_correct_predictions += correct_cnt.sum().item()
batches += 1
total_samples += batch['txts']['input_ids'].shape[0]
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
txt_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)})
labels_img_name.extend(batch['img_fname'])
labels_txt_name.extend(batch['txt_index'])
total_loss = total_loss / batches
correct_ratio = total_correct_predictions / float(total_samples)
query_txt_np = np.array(query_txt)
indexer_img.index_data(list(img_embedding.items()))
query_img_np = np.array(query_img)
indexer_txt.index_data(list(txt_embedding.items()))
if no_eval:
return total_loss, correct_ratio, (indexer_img, indexer_txt), (None, None), (None, None)
else:
res_txt = indexer_img.search_knn(query_txt_np, num_tops)
rank_txt_res = {query_txt_id[i]: r[0] for i, r in enumerate(res_txt)}
res_img = indexer_txt.search_knn(query_img_np, num_tops)
rank_img_res = {query_img_id[i]: r[0] for i, r in enumerate(res_img)}
recall_txt = {1: 0, 5: 0, 10: 0}
for i, q in enumerate(query_txt_id):
for top in recall_txt:
recall_txt[top] += labels_img_name[i] in rank_txt_res[q][:top]
for top in recall_txt:
recall_txt[top] = recall_txt[top] / len(rank_txt_res)
recall_img = {1: 0, 5: 0, 10: 0}
for i, q in enumerate(np.unique(query_img_id)):
for top in recall_img:
# recall_img[top] += any([txt_id in rank_img_res[q][:top] for txt_id in img2txt[q]])
recall_img[top] += any([txt_id in rank_img_res[q][:top] for txt_id in img2txt[q]])
for top in recall_img:
recall_img[top] = recall_img[top] / len(rank_img_res)
return total_loss, correct_ratio, (indexer_img, indexer_txt), (recall_txt, recall_img), (rank_txt_res, rank_img_res)
def load_dataset(all_img_dbs, txt_dbs, img_dbs, args, is_train):
if is_train:
# train datasets
datasets = []
for txt_path, img_path in zip(txt_dbs, img_dbs):
img_db = all_img_dbs[img_path]
txt_db = TxtTokLmdb(txt_path, args.max_txt_len)
datasets.append(ItmFastDataset(txt_db, img_db, args.num_hard_negatives, args.img_meta, args.tokenizer))
datasets = ConcatDataset(datasets) #
else:
# eval or test
img_db = all_img_dbs[img_dbs]
txt_db = TxtTokLmdb(txt_dbs, -1)
datasets = ItmFastDataset(txt_db, img_db, args.inf_minibatch_size, args.img_meta, args.tokenizer)
return datasets

234
dvl/utils.py

@ -0,0 +1,234 @@
import logging
import random
import tqdm
import torch
import pickle
import torch.distributed as dist
from collections import defaultdict
from horovod import torch as hvd
from torch import Tensor as T
from typing import Tuple
logger = logging.getLogger()
def get_rank():
return hvd.rank()
def get_world_size():
return hvd.size()
def print_args(args):
logger.info(" **************** CONFIGURATION **************** ")
for key, val in sorted(vars(args).items()):
keystr = "{}".format(key) + (" " * (30 - len(key)))
logger.info("%s --> %s", keystr, val)
logger.info(" **************** CONFIGURATION **************** ")
def num_of_parameters(model, requires_grad=False):
if requires_grad:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
else:
return sum(p.numel() for p in model.parameters())
def get_default_group():
return dist.group.WORLD
def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
return dist.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
"""
SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + SIZE_STORAGE_BYTES > max_size:
raise ValueError(
'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size))
rank = get_rank()
world_size = get_world_size()
buffer_size = max_size * world_size
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
buffer.zero_()
cpu_buffer = all_gather_list._cpu_buffer
assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format(
256 ** SIZE_STORAGE_BYTES)
size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big')
cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes))
cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc))
start = rank * max_size
size = enc_size + SIZE_STORAGE_BYTES
buffer[start: start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
try:
result = []
for i in range(world_size):
out_buffer = buffer[i * max_size: (i + 1) * max_size]
size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big')
if size > 0:
result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist())))
return result
except pickle.UnpicklingError:
raise Exception(
'Unable to unpickle data from other workers. all_gather_list requires all '
'workers to enter the function together, so this error usually indicates '
'that the workers have fallen out of sync somehow. Workers can fall out of '
'sync if one of them runs out of memory, or if there are other conditions '
'in your training script that can cause one worker to finish an epoch '
'while other workers are still iterating over their portions of the data.'
)
def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, local_positive_idxs,
local_hard_negatives_idxs: list = None, experiment=None
):
"""
Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations
across all the nodes.
"""
distributed_world_size = 1 # args.distributed_world_size or 1
if distributed_world_size > 1:
# TODO: Add local_caption_vectors
q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()
global_question_ctx_vectors = all_gather_list(
[q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs],
max_size=args.global_loss_buf_sz)
global_q_vector = []
global_ctxs_vector = []
# ctxs_per_question = local_ctx_vectors.size(0)
positive_idx_per_question = []
hard_negatives_per_question = []
total_ctxs = 0
for i, item in enumerate(global_question_ctx_vectors):
q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item
if i != args.local_rank:
global_q_vector.append(q_vector.to(local_q_vector.device))
global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs])
else:
global_q_vector.append(local_q_vector)
global_ctxs_vector.append(local_ctx_vectors)
positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs])
hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs])
total_ctxs += ctx_vectors.size(0)
global_q_vector = torch.cat(global_q_vector, dim=0)
global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)
else:
global_q_vector = local_q_vector
global_ctxs_vector = local_ctx_vectors
global_caption_vector = local_caption_vectors
positive_idx_per_question = local_positive_idxs
hard_negatives_per_question = local_hard_negatives_idxs
loss, is_correct, scores = loss_function.calc(global_q_vector, global_ctxs_vector, global_caption_vector,
positive_idx_per_question, hard_negatives_per_question,
args.caption_score_weight, experiment)
return loss, is_correct, scores
def compare_models(model_1, model_2):
models_differ = 0
for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
if torch.equal(key_item_1[1], key_item_2[1]):
pass
else:
models_differ += 1
if (key_item_1[0] == key_item_2[0]):
print('Mismtach found at', key_item_1[0])
else:
raise Exception
if models_differ == 0:
print('Models match perfectly! :)')
def is_main_process():
return hvd.rank() == 0
def display_img(img_meta, name, img_only=False):
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img = mpimg.imread(img_meta[name]['img_file'])
plt.imshow(img)
plt.show()
if not img_only:
print('annotation')
print('\t' + '\n\t'.join(img_meta[name]['annotation']))
print('caption')
print('\t' + img_meta[name]['caption'][0])
def retrieve_query(model, query, indexer, args, top=10):
input_ids = args.tokenizer.encode(query)
input_ids = torch.LongTensor(input_ids).to(args.device).unsqueeze(0)
attn_mask = torch.ones(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
pos_ids = torch.arange(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
_, query_vector, _ = model.txt_model(input_ids=input_ids,attention_mask=attn_mask, position_ids=pos_ids)
res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100)
return res
def get_model_encoded_vecs(model, dataloader):
img_embedding, caption_embedding, query_embedding = dict(), dict(), defaultdict(list)
labels_img_name = []
# for i, batch in enumerate(dataloader):
for i, batch in enumerate(tqdm.tqdm(dataloader)):
with torch.no_grad():
model_out = model(batch)
local_q_vectors, local_ctx_vectors, local_caption_vectors = model_out
img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
caption_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['img_fname'], local_caption_vectors)})
query_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['txt_index'], local_q_vectors)})
labels_img_name.extend(batch['img_fname'])
return {
'img_embed': img_embedding,
'caption_embed': caption_embedding,
'txt_embed': query_embedding,
'img_name': labels_img_name
}

52
lightningdot.py

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import os
import json
@ -30,7 +31,7 @@ from .utils import Configs, get_gather_index
def arg_process(args):
dirname = os.path.dirname(__file__)
args.img_checkpoint = dirname + '/' + args.img_checkpoint
#args.img_checkpoint = dirname + '/' + args.img_checkpoint
args.img_model_config = dirname + '/' + args.img_model_config
return args
@ -39,19 +40,47 @@ class LightningDOT(NNOperator):
"""
CLIP multi-modal embedding operator
"""
def __init__(self, modality: str):
def __init__(self, model_name:str, modality: str):
logger = logging.getLogger()
sys.path.append(str(Path(__file__).parent))
from dvl.models.bi_encoder import BiEncoder
from detector.faster_rcnn import Net, process_img
from utils import download_file
full_path = os.path.dirname(__file__) + '/config/flickr30k_ft_config.json'
with open(full_path) as fw:
config_path = os.path.dirname(__file__) + self._configs()[model_name]['config']
model_url = self._configs()[model_name]['weights']
weight_name = os.path.basename(model_url)
weight_path = os.path.dirname(__file__) + '/data/model/' + weight_name
if os.path.exists(weight_path) is False:
download_file(model_url, os.path.dirname(__file__) + '/data/model/')
with open(config_path) as fw:
content = fw.read()
args = json.loads(content)
#args['img_checkpoint'] = './data/model/' + weight_name
args = Configs(args)
args = arg_process(args)
self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim)
self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
state_dict = torch.load(weight_path, map_location='cpu')
try:
if 'model_dict' in state_dict:
self.bi_encoder.load_state_dict(state_dict['model_dict'])
else:
self.bi_encoder.load_state_dict(state_dict)
except RuntimeError:
logger.info('loading from pre-trained model instead')
for k in list(state_dict.keys()):
if k.startswith('bert.'):
state_dict[k[5:]] = state_dict.pop(k)
else:
state_dict.pop(k)
bi_encoder.load_state_dict(state_dict, strict=True)
img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model
img_model.eval()
txt_model.eval()
@ -144,3 +173,18 @@ class LightningDOT(NNOperator):
gather_index, fix_txt_encoder)
return img_pooled
def _configs(self):
config = {}
config['lightningdot_base'] = {}
config['lightningdot_base']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/LightningDot.pt'
config['lightningdot_base']['config'] = '/config/pretrain-alldata-base.json'
config['lightningdot_coco_ft'] = {}
config['lightningdot_coco_ft']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/coco-ft.pt'
config['lightningdot_coco_ft']['config'] = '/config/coco_eval_config.json'
config['lightningdot_flickr_ft'] = {}
config['lightningdot_flickr_ft']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/flickr-ft.pt'
config['lightningdot_flickr_ft']['config'] = '/config/flickr30k_eval_config.json'
return config

3
requirements.txt

@ -2,5 +2,4 @@ torch>=1.9.0
torchvision>=0.10.0
transformers==2.3.0
Pillow
towhee
towhee

22
uniter_model/Dockerfile

@ -0,0 +1,22 @@
FROM nvcr.io/nvidia/pytorch:19.05-py3
COPY requirements.txt scripts/download_bert.py ./
RUN pip install -r requirements.txt &&\
python download_bert.py &&\
rm ./requirements.txt ./download_bert.py
################## v1 ##########################
COPY scripts/install_horovod.sh ./
RUN source install_horovod.sh &&\
rm ./install_horovod.sh
ENV OPENMPI_VERSION=4.0.0
# fix ssh permissions
RUN bash -c "chmod -R 600 /etc/ssh/ && chmod 600 /var/run/sshd/ && chmod 600 /root"
################## horovod, v2 ##########################
RUN bash -c "pip install lz4==2.1.9 lmdb==0.97"
################# LMDB ##########################

21
uniter_model/LICENSE

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 Yen-Chun Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

89
uniter_model/README.md

@ -0,0 +1,89 @@
# Universal-Image-Text-Transformer
Research code for pre-training universal vision and language models
## Requirements
nvidia driver (418.xx), docker(19.03+), nvidia-container-toolkit
```
docker pull convaicontainerregistry1.azurecr.io/img-txt
```
## lauching the environment
```
# can use CUDA_VISIBLE_DEVICES to seperate GPUs for each container
source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH
# TXT_DB: convaistorage2share2/TXT_DB_v3
# IMG_DIR: convaistorage2share2/Bottom-up-features/adaptive/npy_per_img_id
# OUTPUT: somewhere to store model checkpoint (can be on share storage)
# PRETRAIN: path to pretrained model
# when need to preprocessing
source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH --prepro
# this will make /db writable
# multi-node training
source launch_container_dist.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH
```
## Pretrain
```
# inside the docker container
horovodrun -np $N_GPU -H localhost:$N_GPU \
python pretrain.py --config config/config-pretrain-alltask.json
```
## finetune VQA
```
horovodrun -np 2 -H localhost:2 \
python train_vqa.py --config config/config-vqa-bert-2gpu-alldata.json
```
### VQA inference
```
# single node only
# please refer to code for commandline options
horovodrun -np $N_GPU -H localhost:$N_GPU \
python eval_vqa.py --txt_db /db/vqa_test_[base/large]-cased.db/ \
--img_dir /img/coco_test2015 --checkpoint [NUM] \
--output_dir /path/to/trained/vqa
```
### NLVR2 official evaluation
Use official script to get both acc (our validation matched this) and consistency
```
# concat all output files
cat $OUTPUT/result/[val/test]_results_$STEP_rank*.csv > $OUTPUT.csv
python eval/nlvr2.py $OUTPUT.csv ANNOTATION.json
```
### Referring Expression Comprehension: Finetuning and Evaluation
```
# train on gd-truth pairs of (ref, sent)
horovodrun -np $N_GPU -H localhost:$N_GPU \
python train_re.py --config config/hps-refcoco+.json
# evaluate multiple splits on gd-truth boxes
horovodrun -np $N_GPU -H localhost:$N_GPU \
python eval_re.py \
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
--img_dir /img/visual_grounding_coco_gt \
--output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \
--checkpoint 26
# evaluate multiple splits on detected boxes
horovodrun -np $N_GPU -H localhost:$N_GPU \
python eval_re.py \
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
--img_dir /img/visual_grounding_det_coco \
--output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \
--checkpoint 26
```
## Misc
1. w/o horovodrun it will run on single GPU
- useful for debugger (-m pdb)
2. try `--pin_mem` it might give a tiny performance improvement
3. `--img_format [lmdb/lmdb-compress]`
- trade-off between memory/CPU
- use `--n_workers $N_CPU` to specify data workers (default: 4)

36
uniter_model/config/config-vcr-bert-2gpu.json

@ -0,0 +1,36 @@
{
"train_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/",
"train_img_dir": "/img/vcr_gt_val/;/img/vcr_val/",
"val_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/",
"val_img_dir": "/img/vcr_gt_val/;/img/vcr_val/",
"checkpoint": "/storage/pretrain_vcr/mlm_qar-30k_steps-lr_3e-5-run1/ckpt/model_step_29000.pt",
"checkpoint_from": "vcr",
"task": "qa",
"cut_bert": -1,
"output_dir": "/storage/debug/qa_qar-bert_base-gt_proposed_img_feat-mlm-diff_type_id_for_ra-lr_2e-5-train_step_10k",
"max_txt_len": 220,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 4000,
"val_batch_size": 12,
"gradient_accumulation_steps": 10,
"learning_rate": 5e-5,
"valid_steps": 10,
"num_train_steps": 1000,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"warm_int": 500,
"decay_int": 2000,
"decay_st": 12000,
"decay_rate": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"warmup_steps": 1000,
"seed": 42,
"fp16": true,
"mcan": false
}

11
uniter_model/config/eval-itm-coco.json

@ -0,0 +1,11 @@
{
"txt_db": "/db/itm_coco_test_1k_4_base-cased.db/",
"img_dir": "/img/coco_val2014/",
"neg_sample_size": -1,
"eval_split": "itm_coco_test_1k_4_fp16",
"checkpoint": 12000,
"cut_bert": -1,
"output_dir": "/storage/itm/coco-bert_base-weak_420k-w_train-lr_2e-5-20k_steps-rank_loss-batch_size_20_acc_8_wd_0/",
"fp16": true,
"eval_mini_batch_size": 400
}

11
uniter_model/config/eval-itm-flickr.json

@ -0,0 +1,11 @@
{
"txt_db": "/db/itm_flickr30k_train_base-cased.db/",
"img_dir": "/img/flickr30k/",
"neg_sample_size": -1,
"eval_split": "itm_flickr_train",
"checkpoint": 6000,
"cut_bert": -1,
"output_dir": "/storage/itm/flickr30k-bert_base_weak_420k-hard_neg_finetune-lr_5e-5-train_step_5k",
"fp16": true,
"eval_mini_batch_size": 128
}

53
uniter_model/config/hps-itm.json

@ -0,0 +1,53 @@
{
"train_txt_db": ["/db/itm_coco_train_base-cased.db/",
"/db/itm_coco_restval_base-cased.db"],
"train_img_dir": ["/img/coco_train2014/",
"/img/coco_val2014/"],
"train_neg_sample_p": 0.5,
"neg_sample_from": "i",
"eval_method": "rank",
"val_txt_db": ["/db/itm_coco_val_1k_0_base-cased.db/",
"/db/itm_coco_val_1k_1_base-cased.db/",
"/db/itm_coco_val_1k_2_base-cased.db/",
"/db/itm_coco_val_1k_3_base-cased.db/",
"/db/itm_coco_val_1k_4_base-cased.db/"],
"val_img_dir": ["/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/"],
"test_txt_db": ["/db/itm_coco_test_1k_0_base-cased.db/",
"/db/itm_coco_test_1k_1_base-cased.db/",
"/db/itm_coco_test_1k_2_base-cased.db/",
"/db/itm_coco_test_1k_3_base-cased.db/",
"/db/itm_coco_test_1k_4_base-cased.db/"],
"test_img_dir": ["/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/",
"/img/coco_val2014/"],
"checkpoint": "/pretrain/mlm_caption_bert-base.pt",
"cut_bert": -1,
"output_dir": "/storage/itm_tr/coco_bert_base-mlm_caption-corrected_img_bb_num-w_train_rv-step_40000",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 2048,
"val_batch_size": 4096,
"val_minibatch_size":400,
"test_minibatch_size":300,
"gradient_accumulation_steps": 8,
"learning_rate": 0.001,
"valid_steps": 1000,
"num_train_steps": 40000,
"optim": "adamax",
"decay": "linear",
"dropout": 0.1,
"falseweight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 1000,
"seed": 42,
"fp16": true
}

25
uniter_model/config/hps-refcoco+.json

@ -0,0 +1,25 @@
{
"train_txt_db": "/db/refcoco+_train_base-cased.db",
"train_img_dir": "/img/visual_grounding_coco_gt",
"val_txt_db": "/db/refcoco+_val_base-cased.db",
"val_img_dir": "/img/visual_grounding_coco_gt",
"checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt",
"cut_bert": -1,
"output_dir": "/storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4",
"max_txt_len": 60,
"train_batch_size": 128,
"val_batch_size": 128,
"learning_rate": 1e-4,
"optim": "adamw",
"betas": [0.9, 0.98],
"weight_decay": 0.01,
"dropout": 0.1,
"grad_norm": 2.0,
"decay": "linear",
"num_train_steps": 24000,
"warmup_steps": 1500,
"gradient_accumulation_steps": 1,
"no_cuda": false,
"seed": 24,
"fp16": true
}

26
uniter_model/config/hps-refcoco+_conceptual.json

@ -0,0 +1,26 @@
{
"train_txt_db": "/db/refcoco+_train_base-cased.db",
"train_img_dir": "/img/visual_grounding_coco_gt",
"val_txt_db": "/db/refcoco+_val_base-cased.db",
"val_img_dir": "/img/visual_grounding_det_coco",
"checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_200000.pt",
"cut_bert": -1,
"output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5",
"max_txt_len": 60,
"train_batch_size": 128,
"val_batch_size": 128,
"learning_rate": 8e-5,
"valid_steps": 1000,
"optim": "adamw",
"betas": [0.9, 0.98],
"weight_decay": 0.01,
"dropout": 0.1,
"grad_norm": 2.0,
"decay": "linear",
"num_train_steps": 24000,
"warmup_steps": 1500,
"gradient_accumulation_steps": 1,
"no_cuda": false,
"seed": 24,
"fp16": true
}

26
uniter_model/config/hps-refcoco+_conceptual_large_weak.json

@ -0,0 +1,26 @@
{
"train_txt_db": "/db/refcoco+_train_large-cased.db",
"train_img_dir": "/img/visual_grounding_coco_gt",
"val_txt_db": "/db/refcoco+_val_large-cased.db",
"val_img_dir": "/img/visual_grounding_det_coco",
"checkpoint": "/pretrain/bert-large_weak/ckpt/model_step_50000.pt",
"cut_bert": -1,
"output_dir": "/storage/refcoco+/conceptual-bert-large_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5_b64g4",
"max_txt_len": 60,
"train_batch_size": 64,
"val_batch_size": 256,
"learning_rate": 8e-5,
"valid_steps": 1000,
"optim": "adamw",
"betas": [0.9, 0.98],
"weight_decay": 0.01,
"dropout": 0.1,
"grad_norm": 2.0,
"decay": "linear",
"num_train_steps": 24000,
"warmup_steps": 1500,
"gradient_accumulation_steps": 4,
"no_cuda": false,
"seed": 24,
"fp16": true
}

29
uniter_model/config/hps-refcoco+_conceptual_rank.json

@ -0,0 +1,29 @@
{
"train_txt_db": "/db/refcoco+_train_base-cased.db",
"train_img_dir": "/img/visual_grounding_coco_gt",
"val_txt_db": "/db/refcoco+_val_base-cased.db",
"val_img_dir": "/img/visual_grounding_coco_gt",
"checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_420000.pt",
"cut_bert": -1,
"output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4_rank_r0.2_m0.2_step30k",
"max_txt_len": 60,
"train_batch_size": 128,
"val_batch_size": 128,
"learning_rate": 1e-4,
"valid_steps": 1000,
"optim": "adamw",
"betas": [0.9, 0.98],
"weight_decay": 0.01,
"dropout": 0.1,
"grad_norm": 2.0,
"decay": "linear",
"num_train_steps": 30000,
"warmup_steps": 1500,
"gradient_accumulation_steps": 1,
"no_cuda": false,
"seed": 24,
"fp16": true,
"train_loss": "rank",
"hard_ratio": 0.2,
"margin": 0.2
}

26
uniter_model/config/hps-refcoco.json

@ -0,0 +1,26 @@
{
"train_txt_db": "/db/refcoco_train_base-cased.db",
"train_img_dir": "/img/visual_grounding_coco_gt",
"val_txt_db": "/db/refcoco_val_base-cased.db",
"val_img_dir": "/img/visual_grounding_coco_gt",
"checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt",
"cut_bert": -1,
"output_dir": "/storage/refcoco/bert-base_mlm+itm+mrfr_pretrain-refcoco_lr3e-4",
"max_txt_len": 60,
"train_batch_size": 128,
"val_batch_size": 128,
"learning_rate": 3e-4,
"valid_steps": 1000,
"optim": "adamw",
"betas": [0.9, 0.98],
"weight_decay": 0.01,
"dropout": 0.1,
"grad_norm": 2.0,
"decay": "linear",
"num_train_steps": 10000,
"warmup_steps": 1500,
"gradient_accumulation_steps": 1,
"no_cuda": false,
"seed": 24,
"fp16": true
}

31
uniter_model/config/hps-ve-large.json

@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_large-cased.db/",
"train_img_dir": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_large-cased.db/",
"test_img_dir": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_large-cased.db/",
"val_img_dir": "/img/flickr30k/",
"checkpoint": "/pretrain/bert-large_frkl_alldata.pt",
"cut_bert": -1,
"output_dir": "/storage/ve/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}

31
uniter_model/config/hps-ve.json

@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_base-cased.db/",
"train_img_dir": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_base-cased.db/",
"test_img_dir": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_base-cased.db/",
"val_img_dir": "/img/flickr30k/",
"checkpoint": "/pretrain/base_mlm_mrfr_mrckl_itm_alldata.pt",
"cut_bert": -1,
"output_dir": "/storage/ve/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}

30
uniter_model/config/hps-vqa.json

@ -0,0 +1,30 @@
{
"train_txt_db": "/db/vqa_train_base-cased.db/",
"train_img_dir": "/img/coco_train2014/",
"val_txt_db": "/db/vqa_val_base-cased.db/",
"val_img_dir": "/img/coco_val2014/",
"ans2label": "/db/ans2label.pkl",
"checkpoint": "/storage/mlm/caption-base_from-scratch_grad-acc-8_step-80k_val-step-5k_lr-2e-4_wu-0.1/ckpt/model_step_80000_final.pt",
"cut_bert": -1,
"output_dir": "/storage/vqa/bert_base-mlm_caption_nonblind_from_scratch-nowd-linear",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 2048,
"val_batch_size": 4096,
"gradient_accumulation_steps": 8,
"learning_rate": 0.001,
"valid_steps": 500,
"num_train_steps": 20000,
"optim": "adamax",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0,
"grad_norm": 2.0,
"warmup_steps": 1000,
"seed": 42,
"fp16": false,
"blind": false
}

47
uniter_model/config/itm-coco-base.json

@ -0,0 +1,47 @@
{
"compressed_db": false,
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"output_dir": "/storage/finetune/itm/coco_ot_alldata_base_hnv2",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8,
"negative_size": 399,
"hard_neg_size": 31,
"inf_minibatch_size": 400,
"margin": 0.2,
"learning_rate": 5e-05,
"valid_steps": 500,
"num_train_steps": 5000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 500,
"seed": 42,
"full_val": true,
"fp16": true,
"n_workers": 4,
"pin_mem": true,
"train_txt_dbs": [
"/db/itm_coco_train_base-cased.db",
"/db/itm_coco_restval_base-cased.db"
],
"train_img_dbs": [
"/img/coco_train2014/",
"/img/coco_val2014"
],
"val_txt_db": "/db/itm_coco_val_base-cased.db",
"val_img_db": "/img/coco_val2014/",
"test_txt_db": "/db/itm_coco_test_base-cased.db",
"test_img_db": "/img/coco_val2014/",
"model_config": "/src/config/uniter-base.json",
"rank": 0
}

45
uniter_model/config/itm-ot-base-16gpus.json

@ -0,0 +1,45 @@
{
"compressed_db": false,
"checkpoint": "/pretrain/bert-base-cased.pt",
"output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 80,
"gradient_accumulation_steps": 8,
"negative_size": 1,
"hard_neg_size": 0,
"inf_minibatch_size": 400,
"margin": 0.2,
"learning_rate": 5e-05,
"warmup_steps": 100,
"valid_steps": 500,
"num_train_steps": 1000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"seed": 42,
"full_val": false,
"fp16": true,
"n_workers": 4,
"pin_mem": true,
"train_txt_dbs": [
"/db/itm_flickr30k_train_base-cased.db"
],
"train_img_dbs": [
"/img/flickr30k/"
],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"model_config": "./config/uniter-base.json"
}

45
uniter_model/config/itm-ot-base-16gpus_philly.json

@ -0,0 +1,45 @@
{
"compressed_db": false,
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 40,
"gradient_accumulation_steps": 4,
"negative_size": 1,
"hard_neg_size": 0,
"inf_minibatch_size": 400,
"margin": 0.2,
"learning_rate": 5e-05,
"valid_steps": 500,
"num_train_steps": 5000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 500,
"seed": 42,
"full_val": false,
"fp16": true,
"n_workers": 4,
"pin_mem": true,
"train_txt_dbs": [
"/db/itm_flickr30k_train_base-cased.db"
],
"train_img_dbs": [
"/img/flickr30k/"
],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"model_config": "./config/uniter-base.json"
}

42
uniter_model/config/pretrain-gqa-alltask.json

@ -0,0 +1,42 @@
{
"train_datasets": [
{"name": "gqa",
"db": ["/db/pretrain_gqa_train_0_large-cased.db", "/db/pretrain_gqa_train_1_base-cased.db",
"/db/pretrain_gqa_train_2_base-cased.db", "/db/pretrain_gqa_train_3_base-cased.db",
"/db/pretrain_gqa_train_4_base-cased.db", "/db/pretrain_gqa_train_5_base-cased.db",
"/db/pretrain_gqa_train_6_base-cased.db", "/db/pretrain_gqa_train_7_base-cased.db",
"/db/pretrain_gqa_train_8_base-cased.db", "/db/pretrain_gqa_train_9_base-cased.db",
"/db/pretrain_gqa_val_base-cased.db"],
"img": ["/img/gqa/"],
"tasks": ["mlm", "mrm", "mrckl"],
"mix_ratio": [2, 1, 1]}
],
"val_datasets": [
{"name": "gqa",
"db": ["/db/pretrain_gqa_testdev_balanced_base-cased.db"],
"img": ["/img/gqa/"],
"tasks": ["mlm", "mrm", "mrckl"]}
],
"checkpoint": "/pretrain/bert-large_weak_alldata/ckpt/model_step_100000.pt",
"output_dir": "/storage/pretrain_gqa/bert_large_weak_alldata_100k-train_val_all-mlm_mrm_mrckl-train_batch_size_6144-500k_steps",
"mrm_prob": 0.15,
"max_txt_len": 220,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 6144,
"val_batch_size": 8000,
"gradient_accumulation_steps": 10,
"learning_rate": 3e-05,
"valid_steps": 10000,
"num_train_steps": 500000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": -1,
"warmup_steps": 50000,
"seed": 42,
"fp16": true
}

42
uniter_model/config/pretrain-mlm-coco.json

@ -0,0 +1,42 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["mlm"],
"mix_ratio": [1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["mlm"]}
],
"output_dir": "/storage/pretrain/mlm_coco",
"mrm_prob": 0.15,
"neg_size": 1024,
"itm_neg_prob": 0.5,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 100000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4,
"from_scratch": false
}

53
uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json

@ -0,0 +1,53 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["itm", "mlm", "mrfr", "mrckl"],
"mix_ratio": [2, 2, 1, 1]},
{"name": "vg_cap",
"db": ["/db/pretrain_caption_vg_train_base-cased.db/"],
"img": ["/img/vg/"],
"tasks": ["itm", "mlm", "mrfr", "mrckl"],
"mix_ratio": [2, 2, 1, 1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["itm", "mlm", "mrfr", "mrckl"]},
{"name": "vg_cap",
"db": ["/db/pretrain_caption_vg_val_base-cased.db/"],
"img": ["/img/vg/"],
"tasks": ["itm", "mlm", "mrfr", "mrckl"]}
],
"model_config": "/src/config/uniter-base.json",
"checkpoint": "/pretrain/bert-base-cased.pt",
"output_dir": "/storage/pretrain/alltask_ot_indomain_base",
"ans2label": "/db/pretrain_ans2label.pkl",
"mrm_prob": 0.15,
"itm_neg_prob": 0.5,
"itm_ot_lambda": 0.1,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 200000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 5.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4
}

42
uniter_model/config/pretrain-mrckl-coco.json

@ -0,0 +1,42 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["mrckl"],
"mix_ratio": [1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["mrckl"]}
],
"output_dir": "/storage/pretrain/mrckl_coco",
"mrm_prob": 0.15,
"neg_size": 1024,
"itm_neg_prob": 0.5,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 100000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4,
"from_scratch": false
}

42
uniter_model/config/pretrain-mrfr-coco.json

@ -0,0 +1,42 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["mrfr"],
"mix_ratio": [1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["mrfr"]}
],
"output_dir": "/storage/pretrain/mrfr_coco",
"mrm_prob": 0.15,
"neg_size": 1024,
"itm_neg_prob": 0.5,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 100000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4,
"from_scratch": false
}

43
uniter_model/config/pretrain-mrm-nce-coco.json

@ -0,0 +1,43 @@
{
"train_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_train_base-cased.db/",
"/db/pretrain_caption_coco_trainval_base-cased.db/"],
"img": ["/img/coco_train2014/", "/img/coco_val2014/"],
"tasks": ["mrm-nce"],
"mix_ratio": [1]}
],
"val_datasets": [
{"name": "coco_cap",
"db": ["/db/pretrain_caption_coco_val_base-cased.db/"],
"img": ["/img/coco_val2014/"],
"tasks": ["mrm-nce"]}
],
"output_dir": "/storage/pretrain/mrm_nce_coco",
"mrm_prob": 0.15,
"neg_size": 1024,
"nce_temp": 1.0,
"itm_neg_prob": 0.5,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8192,
"val_batch_size": 8192,
"gradient_accumulation_steps": 2,
"learning_rate": 5e-05,
"valid_steps": 5000,
"num_train_steps": 100000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 10000,
"seed": 42,
"fp16": true,
"pin_mem": true,
"n_workers": 4,
"from_scratch": false
}

38
uniter_model/config/pretrain-vcr-alltask.json

@ -0,0 +1,38 @@
{
"train_datasets": [
{"name": "vcr",
"db": ["/db/vcr_val_w_obj_ids_base-cased.db/"],
"img": ["/img/vcr_val/;/img/vcr_gt_val/"],
"tasks": ["mlm", "mrm", "mrckl"],
"mix_ratio": [2, 1, 1]}
],
"val_datasets": [
{"name": "vcr",
"db": ["/db/vcr_val_w_obj_ids_base-cased.db/"],
"img": ["/img/vcr_val/;/img/vcr_gt_val/"],
"tasks": ["mlm", "mrm", "mrckl"]}
],
"checkpoint": "/pretrain/bert-base_weak_w_mlm_itm_mrm_mrckl_4gpu/ckpt/model_step_500000.pt",
"vcr_task": ["qa", "qar"],
"output_dir": "/storage/debug/mlm_mrm_mrckl-qa_qar-gt_det",
"mrm_prob": 0.15,
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8000,
"val_batch_size": 8000,
"gradient_accumulation_steps": 5,
"learning_rate": 3e-05,
"valid_steps": 10,
"num_train_steps": 120000,
"optim": "adamw",
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": -1,
"warmup_steps": 12000,
"seed": 42,
"fp16": true
}

40
uniter_model/config/train-itm-debug.json

@ -0,0 +1,40 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_val_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/debug/itm/flickr_default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 20,
"negative_size": 1,
"hard_neg_size": 1,
"hard_neg_pool_size": 20,
"steps_per_hard_neg": 30,
"inf_minibatch_size": 40,
"gradient_accumulation_steps": 2,
"learning_rate": 1e-05,
"valid_steps": 40,
"num_train_steps": 50,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 1600,
"seed": 42,
"fp16": true,
"n_workers": 0,
"pin_mem": true
}

38
uniter_model/config/train-itm-flickr-base-hnv2.json

@ -0,0 +1,38 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/itm/flickr_ot_alldata_base_hnv2",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 8,
"negative_size": 399,
"hard_neg_size": 31,
"inf_minibatch_size": 400,
"learning_rate": 5e-05,
"valid_steps": 500,
"num_train_steps": 5000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 500,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true,
"full_val": true
}

40
uniter_model/config/train-itm-flickr-base.json

@ -0,0 +1,40 @@
{
"train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"],
"train_img_dbs": ["/img/flickr30k/"],
"val_txt_db": "/db/itm_flickr30k_val_base-cased.db",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/itm_flickr30k_test_base-cased.db",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/itm/flickr_ot_alldata_base",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 40,
"negative_size": 1,
"hard_neg_size": 0,
"hard_neg_pool_size": 20,
"steps_per_hard_neg": -1,
"inf_minibatch_size": 512,
"gradient_accumulation_steps": 4,
"learning_rate": 5e-05,
"valid_steps": 2000,
"num_train_steps": 20000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"margin": 0.2,
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 2000,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}

37
uniter_model/config/train-nlvr2-base-1gpu.json

@ -0,0 +1,37 @@
{
"train_txt_db": "/db/nlvr2_train_base-cased.db",
"train_img_db": "/img/nlvr2_train/",
"val_txt_db": "/db/nlvr2_dev_base-cased.db",
"val_img_db": "/img/nlvr2_dev/",
"test_txt_db": "/db/nlvr2_test1_base-cased.db",
"test_img_db": "/img/nlvr2_test/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"model": "paired-attn",
"use_img_type": true,
"output_dir": "/storage/nlvr2/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 1,
"learning_rate": 3e-05,
"valid_steps": 500,
"num_train_steps": 8000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 800,
"seed": 77,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}

31
uniter_model/config/train-ve-base-2gpu.json

@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_base-cased.db/",
"train_img_db": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_base-cased.db/",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_base-cased.db/",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/finetune/ve/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 4096,
"val_batch_size": 4096,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}

31
uniter_model/config/train-ve-large-2gpu.json

@ -0,0 +1,31 @@
{
"train_txt_db": "/db/ve_train_large-cased.db/",
"train_img_db": "/img/flickr30k/",
"val_txt_db": "/db/ve_dev_large-cased.db/",
"val_img_db": "/img/flickr30k/",
"test_txt_db": "/db/ve_test_large-cased.db/",
"test_img_db": "/img/flickr30k/",
"checkpoint": "/pretrain/alltask_ot_alldata_large.pt",
"model_config": "/src/config/uniter-large.json",
"output_dir": "/storage/finetune/ve/default_large",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 4096,
"val_batch_size": 4096,
"gradient_accumulation_steps": 4,
"learning_rate": 3e-5,
"valid_steps": 500,
"num_train_steps": 6000,
"warmup_steps": 600,
"optim": "adamw",
"betas": [0.9, 0.98],
"grad_norm": 2.0,
"decay": "linear",
"dropout": 0.1,
"weight_decay": 0.01,
"seed": 42,
"fp16": true
}

35
uniter_model/config/train-vqa-base-2gpu.json

@ -0,0 +1,35 @@
{
"train_txt_dbs": ["/db/vqa_train_base-cased.db",
"/db/vqa_trainval_base-cased.db",
"/db/vqa_vg_base-cased.db"],
"train_img_dbs": ["/img/coco_train2014/", "/img/coco_val2014", "/img/vg/"],
"val_txt_db": "/db/vqa_devval_base-cased.db",
"val_img_db": "/img/coco_val2014/",
"checkpoint": "/pretrain/uniter-base-iclr.pt",
"model_config": "/src/config/uniter-base.json",
"output_dir": "/storage/vqa/default",
"max_txt_len": 60,
"conf_th": 0.2,
"max_bb": 100,
"min_bb": 10,
"num_bb": 36,
"train_batch_size": 10240,
"val_batch_size": 10240,
"gradient_accumulation_steps": 5,
"learning_rate": 8e-05,
"valid_steps": 500,
"num_train_steps": 6000,
"optim": "adamw",
"betas": [
0.9,
0.98
],
"dropout": 0.1,
"weight_decay": 0.01,
"grad_norm": 2.0,
"warmup_steps": 600,
"seed": 42,
"fp16": true,
"n_workers": 4,
"pin_mem": true
}

14
uniter_model/config/uniter-base.json

@ -0,0 +1,14 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"num_hidden_layers_img": 1,
"type_vocab_size": 2,
"vocab_size": 28996
}

13
uniter_model/config/uniter-large.json

@ -0,0 +1,13 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 2,
"vocab_size": 28996
}

27
uniter_model/data/__init__.py

@ -0,0 +1,27 @@
#from .data import (TxtTokLmdb, DetectFeatLmdb,
# ConcatDatasetWithLens, ImageLmdbGroup)
#from .mlm import (MlmDataset, MlmEvalDataset,
# BlindMlmDataset, BlindMlmEvalDataset,
# mlm_collate, mlm_eval_collate,
# mlm_blind_collate, mlm_blind_eval_collate)
#from .mrm import (MrfrDataset, OnlyImgMrfrDataset,
# MrcDataset, OnlyImgMrcDataset,
# mrfr_collate, mrfr_only_img_collate,
# mrc_collate, mrc_only_img_collate)
from .itm import (TokenBucketSamplerForItm,
ItmDataset, itm_collate, itm_ot_collate,
ItmRankDataset, ItmRankDatasetHardNeg, itm_rank_collate,
ItmRankDatasetHardNegFromText,
ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate,
ItmHardNegDataset, itm_hn_collate,
ItmValDataset, itm_val_collate,
ItmEvalDataset, itm_eval_collate)
from .sampler import TokenBucketSampler, DistributedSampler
from .loader import MetaLoader, PrefetchLoader
from .vqa import VqaDataset, vqa_collate, VqaEvalDataset, vqa_eval_collate
from .nlvr2 import (Nlvr2PairedDataset, nlvr2_paired_collate,
Nlvr2PairedEvalDataset, nlvr2_paired_eval_collate,
Nlvr2TripletDataset, nlvr2_triplet_collate,
Nlvr2TripletEvalDataset, nlvr2_triplet_eval_collate)
from .ve import VeDataset, ve_collate, VeEvalDataset, ve_eval_collate

283
uniter_model/data/data.py

@ -0,0 +1,283 @@
"""
Dataset interfaces
"""
from collections import defaultdict
from contextlib import contextmanager
import io
import json
import lmdb
from os.path import exists
import numpy as np
import torch
from torch.utils.data import Dataset, ConcatDataset
from tqdm import tqdm
from lz4.frame import compress, decompress
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
def _fp16_to_fp32(feat_dict):
out = {k: arr.astype(np.float32)
if arr.dtype == np.float16 else arr
for k, arr in feat_dict.items()}
return out
def compute_num_bb(confs, conf_th, min_bb, max_bb):
num_bb = max(min_bb, (confs > conf_th).sum())
num_bb = min(max_bb, num_bb)
return num_bb
class DetectFeatLmdb(object):
def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
compress=True):
self.img_dir = img_dir
if conf_th == -1:
db_name = f'feat_numbb{num_bb}'
self.name2nbb = defaultdict(lambda: num_bb)
else:
db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}'
nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json'
if not exists(f'{img_dir}/{nbb}'):
# nbb is not pre-computed
self.name2nbb = None
else:
self.name2nbb = json.load(open(f'{img_dir}/{nbb}'))
self.compress = compress
if compress:
db_name += '_compressed'
if self.name2nbb is None:
if compress:
db_name = 'all_compressed'
else:
db_name = 'all'
# only read ahead on single node training
self.env = lmdb.open(f'{img_dir}/{db_name}',
readonly=True, create=False,
readahead=not _check_distributed())
self.txn = self.env.begin(buffers=True)
if self.name2nbb is None:
self.name2nbb = self._compute_nbb()
def _compute_nbb(self):
name2nbb = {}
fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8'))
for fname in tqdm(fnames, desc='reading images'):
dump = self.txn.get(fname.encode('utf-8'))
if self.compress:
with io.BytesIO(dump) as reader:
img_dump = np.load(reader, allow_pickle=True)
confs = img_dump['conf']
else:
img_dump = msgpack.loads(dump, raw=False)
confs = img_dump['conf']
name2nbb[fname] = compute_num_bb(confs, self.conf_th,
self.min_bb, self.max_bb)
return name2nbb
def __del__(self):
self.env.close()
def get_dump(self, file_name):
# hack for MRC
dump = self.txn.get(file_name.encode('utf-8'))
nbb = self.name2nbb[file_name]
if self.compress:
with io.BytesIO(dump) as reader:
img_dump = np.load(reader, allow_pickle=True)
img_dump = _fp16_to_fp32(img_dump)
else:
img_dump = msgpack.loads(dump, raw=False)
img_dump = _fp16_to_fp32(img_dump)
img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()}
return img_dump
def __getitem__(self, file_name):
dump = self.txn.get(file_name.encode('utf-8'))
nbb = self.name2nbb[file_name]
if self.compress:
with io.BytesIO(dump) as reader:
img_dump = np.load(reader, allow_pickle=True)
img_dump = {'features': img_dump['features'],
'norm_bb': img_dump['norm_bb']}
else:
img_dump = msgpack.loads(dump, raw=False)
img_feat = torch.tensor(img_dump['features'][:nbb, :]).float()
img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float()
return img_feat, img_bb
def __contains__(self, file_name):
return self.txn.get(file_name.encode('utf-8')) is not None
@contextmanager
def open_lmdb(db_dir, readonly=False):
db = TxtLmdb(db_dir, readonly)
try:
yield db
finally:
del db
class TxtLmdb(object):
def __init__(self, db_dir, readonly=True):
self.readonly = readonly
if readonly:
# training
self.env = lmdb.open(db_dir,
readonly=True, create=False,
readahead=not _check_distributed())
self.txn = self.env.begin(buffers=True)
self.write_cnt = None
else:
# prepro
self.env = lmdb.open(db_dir, readonly=False, create=True,
map_size=4 * 1024**4)
self.txn = self.env.begin(write=True)
self.write_cnt = 0
def __del__(self):
if self.write_cnt:
self.txn.commit()
self.env.close()
def __getitem__(self, key):
return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))),
raw=False)
def __setitem__(self, key, value):
# NOTE: not thread safe
if self.readonly:
raise ValueError('readonly text DB')
ret = self.txn.put(key.encode('utf-8'),
compress(msgpack.dumps(value, use_bin_type=True)))
self.write_cnt += 1
if self.write_cnt % 1000 == 0:
self.txn.commit()
self.txn = self.env.begin(write=True)
self.write_cnt = 0
return ret
def get_ids_and_lens(db):
assert isinstance(db, TxtTokLmdb)
lens = []
ids = []
for id_ in db.ids:
lens.append(db.id2len[id_])
ids.append(id_)
return lens, ids
class DetectFeatTxtTokDataset(Dataset):
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
txt_lens, self.ids = get_ids_and_lens(txt_db)
txt2img = txt_db.txt2img
self.lens = [tl + self.img_db.name2nbb[txt2img[id_]]
for tl, id_ in zip(txt_lens, self.ids)]
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
return example
def _get_img_feat(self, fname):
img_feat, bb = self.img_db[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
return img_feat, img_bb, num_bb
class ConcatDatasetWithLens(ConcatDataset):
""" A thin wrapper on pytorch concat dataset for lens batching """
def __init__(self, datasets):
super().__init__(datasets)
self.lens = [l for dset in datasets for l in dset.lens]
def __getattr__(self, name):
return self._run_method_on_all_dsets(name)
def _run_method_on_all_dsets(self, name):
def run_all(*args, **kwargs):
return [dset.__getattribute__(name)(*args, **kwargs)
for dset in self.datasets]
return run_all
def pad_tensors(tensors, lens=None, pad=0):
"""B x [T, ...]"""
if lens is None:
lens = [t.size(0) for t in tensors]
max_len = max(lens)
bs = len(tensors)
hid = tensors[0].size(-1)
dtype = tensors[0].dtype
output = torch.zeros(bs, max_len, hid, dtype=dtype)
if pad:
output.data.fill_(pad)
for i, (t, l) in enumerate(zip(tensors, lens)):
output.data[i, :l, ...] = t.data
return output
def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
# assert len(txt_lens) == len(num_bbs) == batch_size
gather_index = torch.arange(0, out_size, dtype=torch.long,
).unsqueeze(0).repeat(len(num_bbs), 1)
# for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
# gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
# dtype=torch.long).data
return gather_index
def get_gather_index_uniter(txt_lens, num_bbs, batch_size, max_len, out_size):
assert len(txt_lens) == len(num_bbs) == batch_size
gather_index = torch.arange(0, out_size, dtype=torch.long,
).unsqueeze(0).repeat(batch_size, 1)
for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
dtype=torch.long).data
return gather_index
def get_gather_index_img(txt_lens, num_bbs, batch_size, max_len, out_size):
gather_index = torch.zeros(batch_size, out_size, dtype=torch.long)
for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
gather_index.data[i, :nbb] = torch.arange(max_len, max_len+nbb,
dtype=torch.long).data
gather_index.data[i, nbb:nbb+tl] = torch.arange(0, tl,
dtype=torch.long).data
return gather_index
class ImageLmdbGroup(object):
def __init__(self, conf_th, max_bb, min_bb, num_bb, compress):
self.path2imgdb = {}
self.conf_th = conf_th
self.max_bb = max_bb
self.min_bb = min_bb
self.num_bb = num_bb
self.compress = compress
def __getitem__(self, path):
img_db = self.path2imgdb.get(path, None)
if img_db is None:
img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb,
self.min_bb, self.num_bb, self.compress)
return img_db

572
uniter_model/data/itm.py

@ -0,0 +1,572 @@
"""
Itm dataset
"""
from collections import defaultdict
import copy
import json
import random
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from toolz.sandbox import unzip
from cytoolz import concat
from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
pad_tensors, get_gather_index, get_ids_and_lens)
from .sampler import TokenBucketSampler
class TokenBucketSamplerForItm(TokenBucketSampler):
def __init__(self, dset, *args, **kwargs):
super().__init__(dset.lens, *args, **kwargs)
self.dset = dset
def __iter__(self):
it = super().__iter__()
self.dset.new_epoch()
self._lens = self.dset.lens
return it
def _has_overlap(la, lb):
if len(la) < len(lb):
la, lb = lb, la
s = set(la)
return any(b in s for b in lb)
def _sample_negative_rand(sample_pool, ground_truths, num_sample):
""" random and retry """
outputs = ground_truths[:1]
while _has_overlap(outputs, ground_truths):
outputs = random.sample(sample_pool, num_sample)
return outputs
def _sample_negative_extra(sample_pool, ground_truths, num_sample):
""" sample extra then remove """
tot_size = len(ground_truths) + num_sample
outputs = set(random.sample(sample_pool, tot_size))
for gt in ground_truths:
outputs.discard(gt)
outputs = list(outputs)[:num_sample]
return outputs
sample_negative = _sample_negative_rand # swith between 2 implementations
class ItmDataset(DetectFeatTxtTokDataset):
""" NOTE this Dataset handles distributed training itself
(for more efficient negative sampling) """
def __init__(self, txt_db, img_db, neg_sample_p=0.5):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
self.txt_lens, self.ids = get_ids_and_lens(txt_db)
self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
self.neg_sample_p = neg_sample_p
self.new_epoch()
def new_epoch(self):
""" should be called every epoch for more randomness"""
self.labels = np.random.choice(
[0, 1], size=len(self.ids),
p=[self.neg_sample_p, 1-self.neg_sample_p])
self.lens = []
self.train_imgs = []
for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
img_fname = super().__getitem__(i)['img_fname']
if self.labels[i] == 0:
img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
self.train_imgs.append(img_fname)
self.lens.append(tl + self.img_db.name2nbb[img_fname])
def __getitem__(self, i):
example = super().__getitem__(i)
# labels and negative images should be sampled every epoch
ground_truth_label = self.labels[i]
img_fname = self.train_imgs[i]
img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
target = torch.Tensor(1).long()
target.data.fill_(ground_truth_label)
return input_ids, img_feat, img_pos_feat, attn_masks, target
def itm_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets}
return batch
def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
ot_scatter = torch.arange(0, joint_len, dtype=torch.long
).unsqueeze(0).repeat(len(txt_lens), 1)
for i, tl in enumerate(txt_lens):
max_ind = max_txt_len + (joint_len-tl)
ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
dtype=torch.long).data
return ot_scatter
def _compute_pad(lens, max_len):
pad = torch.zeros(len(lens), max_len, dtype=torch.bool)
for i, l in enumerate(lens):
pad.data[i, l:].fill_(1)
return pad
def itm_ot_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
# OT inputs
max_tl = max(txt_lens)
max_nbb = max(num_bbs)
ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
txt_pad = _compute_pad(txt_lens, max_tl)
img_pad = _compute_pad(num_bbs, max_nbb)
ot_inputs = {'ot_scatter': ot_scatter,
'scatter_max': ot_scatter.max().item(),
'txt_pad': txt_pad,
'img_pad': img_pad}
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets,
'ot_inputs': ot_inputs}
return batch
class ItmRankDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDataset need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
# images partitioned by rank
self.img2txts = defaultdict(list)
for id_, img in self.txt2img.items():
self.img2txts[img].append(id_)
self.img_name_list = list(self.img2txts.keys())
assert neg_sample_size > 0
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
id_pairs = [(gt_txt_id, gt_img_fname)]
# sample negatives
neg_sample_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
neg_sample_txt_ids = sample_negative(
self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in neg_sample_img_ids] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in neg_sample_txt_ids])
inputs = self._collect_inputs(id_pairs)
assert len(inputs) == (1 + 2*self.neg_sample_size)
return inputs
def _collect_inputs(self, id_pairs):
# create input features
inputs = []
for txt_id, img_id in id_pairs:
example = self.txt_db[txt_id]
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
# mask
attn_masks_text = torch.ones(len(input_ids), dtype=torch.long)
attn_masks_img = torch.ones(num_bb, dtype=torch.long)
inputs.append((input_ids, img_feat, img_pos_feat, attn_masks_text, attn_masks_img))
return inputs
class ItmRankDatasetHardNeg(ItmRankDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1):
assert hard_neg_size > 0, \
"ItmRankDatasetHardNeg need at least 1 hard negative sample"
DetectFeatTxtTokDataset.__init__(self, txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.img_name_list = list(self.img2txts.keys())
assert neg_sample_size > 0
self.neg_sample_size = neg_sample_size
self.hard_neg_size = hard_neg_size
def reload_hard_negs(self, hard_neg_dir):
self.txt2hardimgs = json.load(
open(f'{hard_neg_dir}/'
f'txt2hardimgs_rank{hvd.rank()}.json'))
self.img2hardtxts = json.load(
open(f'{hard_neg_dir}/img2hardtxts.json'))
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
id_pairs = [(gt_txt_id, gt_img_fname)]
# sample hard negatives
if self.hard_neg_size > 0:
hard_neg_img_samples = random.sample(
self.txt2hardimgs[gt_txt_id], self.hard_neg_size)
hard_neg_txt_samples = random.sample(
self.img2hardtxts[gt_img_fname], self.hard_neg_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in hard_neg_img_samples] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in hard_neg_txt_samples])
# sample normal negatives
if self.neg_sample_size > 0:
neg_sample_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
neg_sample_txt_ids = sample_negative(
self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
id_pairs.extend([(gt_txt_id, neg_img_id)
for neg_img_id in neg_sample_img_ids] +
[(neg_txt_id, gt_img_fname)
for neg_txt_id in neg_sample_txt_ids])
inputs = self._collect_inputs(id_pairs)
assert len(inputs) == (1
+ 2*self.neg_sample_size
+ 2*self.hard_neg_size)
return inputs
def itm_rank_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img,
) = map(list, unzip(concat(i for i in inputs)))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0)
attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0)
sample_size = len(inputs[0])
assert all(sample_size == len(i) for i in inputs)
bs, max_tl = input_ids.size()
# out_size = attn_masks.size(1)
gather_index = None # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks_text': attn_masks_text,
'attn_masks_img': attn_masks_img,
'gather_index': gather_index,
'sample_size': sample_size}
return batch
class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDatasetHardNegV2 need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.img_name_list = list(self.img2txts.keys())
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_fname = self.txt2img[gt_txt_id]
input_ids = self.txt_db[gt_txt_id]['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
input_ids = input_ids.unsqueeze(0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
neg_img_ids = sample_negative(
self.img_name_list, [gt_img_fname], self.neg_sample_size)
img_ids = [gt_img_fname] + neg_img_ids
# process image features (gt always first)
img_feats, img_pos_feats, num_bbs = map(
list, unzip(map(self._get_img_feat, img_ids)))
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
tl = input_ids.size(1)
attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
for i, nbb in enumerate(num_bbs):
attn_masks.data[i, :tl+nbb].fill_(1)
out_size = attn_masks.size(1)
gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
len(img_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index}
return batch
class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, neg_sample_size=1):
assert neg_sample_size > 0, \
"ItmRankDatasetHardNegV2 need at least 1 negative sample"
super().__init__(txt_db, img_db)
txt2img = self.txt_db.txt2img
self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
self.img2txts = self.txt_db.img2txts
self.txt_name_list = list(self.txt2img.keys())
self.neg_sample_size = neg_sample_size
def __getitem__(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
gt_txt_ids = self.img2txts[gt_img_id]
# process image features (gt always first)
img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
img_feat = img_feat.unsqueeze(0)
img_pos_feat = img_pos_feat.unsqueeze(0)
# sample negative
neg_txt_ids = sample_negative(
self.txt_name_list, gt_txt_ids, self.neg_sample_size)
txt_ids = [gt_txt_id] + neg_txt_ids
# process text inputs
all_inputs = []
txt_lens = []
for txt_id in txt_ids:
input_ids = self.txt_db.combine_inputs(
self.txt_db[txt_id]['input_ids'])
all_inputs.append(input_ids)
txt_lens.append(len(input_ids))
input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
for i, tl in enumerate(txt_lens):
attn_masks.data[i, :tl+nbb].fill_(1)
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
len(txt_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index}
return batch
def itm_rank_hnv2_collate(inputs):
assert len(inputs) == 1
return inputs[0]
class ItmValDataset(DetectFeatTxtTokDataset):
""" For evaluating Image-Text-Retrieval task """
def __init__(self, db_dir, img_dir, mini_batch_size=400):
super().__init__(db_dir, img_dir)
del self.lens
self.txt2img = self.txt_db.txt2img
self.img2txts = self.txt_db.img2txts
self.all_img_ids = list(self.img2txts.keys())
assert len(self.img2txts) >= mini_batch_size > 0
self.bs = mini_batch_size
def _get_batch_ids(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
# sample fixed negatives for each gt image
i = self.all_img_ids.index(gt_img_id)
neg_st = i+1
neg_end = neg_st+self.bs-1
if neg_end > len(self.all_img_ids):
# warp around
neg_end -= len(self.all_img_ids)
neg_img_ids = (self.all_img_ids[neg_st:]
+ self.all_img_ids[:neg_end])
else:
neg_img_ids = self.all_img_ids[neg_st:neg_end]
assert len(neg_img_ids) == (self.bs - 1),\
"Did not sample enough neg samples"
return gt_img_id, neg_img_ids
def __getitem__(self, i):
""" this returns list of mini-batches """
gt_img_id, neg_img_ids = self._get_batch_ids(i)
# NOTE 1st one is gt img
batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
return batch
def get_batch(self, i, img_ids):
example = super().__getitem__(i)
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# process image features (gt always first)
img_feats, img_pos_feats, num_bbs = map(
list, unzip(map(self._get_img_feat, img_ids)))
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
tl = input_ids.size(1)
attn_masks_text = torch.ones(len(img_ids), tl).long()
# attn_masks_text = torch.ones(1, tl).long()
attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long()
for i, nbb in enumerate(num_bbs):
attn_masks_img.data[i, :nbb].fill_(1)
# out_size = attn_masks.size(1)
gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks_text': attn_masks_text,
'attn_masks_img': attn_masks_img,
'gather_index': gather_index}
return batch
def itm_val_collate(inputs):
assert len(inputs) == 1, "input batch size > 1"
return inputs[0]
class ItmHardNegDataset(ItmValDataset):
def _get_batch_ids(self, i):
gt_txt_id = self.ids[i]
gt_img_id = self.txt2img[gt_txt_id]
# sample fixed negatives for each gt image
i = self.all_img_ids.index(gt_img_id)
all_img_ids = copy.deepcopy(self.all_img_ids)
all_img_ids.remove(gt_img_id)
random.shuffle(all_img_ids)
neg_img_ids = all_img_ids[:self.bs]
assert len(neg_img_ids) == (self.bs),\
"Did not sample enough neg samples"
return gt_img_id, neg_img_ids
def __getitem__(self, i):
""" this returns list of mini-batches """
_, neg_img_ids = self._get_batch_ids(i)
batch = self.get_batch(i, neg_img_ids)
batch['gt_txt_id'] = self.ids[i]
batch['neg_img_ids'] = neg_img_ids
return batch
itm_hn_collate = itm_val_collate
class ItmEvalDataset(ItmValDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
key=lambda i: self.img_db.name2nbb[i])
def __getitem__(self, i):
mini_batches = []
for st in range(0, len(self.all_img_ids), self.bs):
mini_batches.append(
self.get_batch(i, self.all_img_ids[st:st+self.bs]))
return mini_batches
itm_eval_collate = itm_val_collate

138
uniter_model/data/loader.py

@ -0,0 +1,138 @@
"""
A meta data loader for sampling from different datasets / training tasks
A prefetch loader to speedup data loading
"""
import random
import torch
from torch.utils.data import DataLoader
from uniter_model.utils.distributed import any_broadcast
class MetaLoader(object):
""" wraps multiple data loader """
def __init__(self, loaders, accum_steps=1, distributed=False):
assert isinstance(loaders, dict)
self.name2loader = {}
self.name2iter = {}
self.sampling_pools = []
for n, l in loaders.items():
if isinstance(l, tuple):
l, r = l
elif isinstance(l, DataLoader):
r = 1
else:
raise ValueError()
self.name2loader[n] = l
self.name2iter[n] = iter(l)
self.sampling_pools.extend([n]*r)
self.accum_steps = accum_steps
self.distributed = distributed
self.step = 0
def __iter__(self):
""" this iterator will run indefinitely """
task = self.sampling_pools[0]
while True:
if self.step % self.accum_steps == 0:
task = random.choice(self.sampling_pools)
if self.distributed:
# make sure all process is training same task
task = any_broadcast(task, 0)
self.step += 1
iter_ = self.name2iter[task]
try:
batch = next(iter_)
except StopIteration:
iter_ = iter(self.name2loader[task])
batch = next(iter_)
self.name2iter[task] = iter_
yield task, batch
def move_to_cuda(batch):
if isinstance(batch, torch.Tensor):
return batch.cuda(non_blocking=True)
elif isinstance(batch, list):
new_batch = [move_to_cuda(t) for t in batch]
elif isinstance(batch, tuple):
new_batch = tuple(move_to_cuda(t) for t in batch)
elif isinstance(batch, dict):
new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
else:
return batch
return new_batch
def record_cuda_stream(batch):
if isinstance(batch, torch.Tensor):
batch.record_stream(torch.cuda.current_stream())
elif isinstance(batch, list) or isinstance(batch, tuple):
for t in batch:
record_cuda_stream(t)
elif isinstance(batch, dict):
for t in batch.values():
record_cuda_stream(t)
else:
pass
class PrefetchLoader(object):
"""
overlap compute and cuda data transfer
(copied and then modified from nvidia apex)
"""
def __init__(self, loader):
self.loader = loader
self.stream = torch.cuda.Stream()
def __iter__(self):
loader_it = iter(self.loader)
self.preload(loader_it)
batch = self.next(loader_it)
while batch is not None:
yield batch
batch = self.next(loader_it)
def __len__(self):
return len(self.loader)
def preload(self, it):
try:
self.batch = next(it)
except StopIteration:
self.batch = None
return
# if record_stream() doesn't work, another option is to make sure
# device inputs are created on the main stream.
# self.next_input_gpu = torch.empty_like(self.next_input,
# device='cuda')
# self.next_target_gpu = torch.empty_like(self.next_target,
# device='cuda')
# Need to make sure the memory allocated for next_* is not still in use
# by the main stream at the time we start copying to next_*:
# self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
self.batch = move_to_cuda(self.batch)
# more code for the alternative if record_stream() doesn't work:
# copy_ will record the use of the pinned source tensor in this
# side stream.
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
# self.next_input = self.next_input_gpu
# self.next_target = self.next_target_gpu
def next(self, it):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
if batch is not None:
record_cuda_stream(batch)
self.preload(it)
return batch
def __getattr__(self, name):
method = self.loader.__getattribute__(name)
return method

360
uniter_model/data/mlm.py

@ -0,0 +1,360 @@
"""
MLM datasets
"""
import math
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import (DetectFeatTxtTokDataset, TxtTokLmdb,
get_ids_and_lens, pad_tensors, get_gather_index)
def random_word(tokens, vocab_range, mask):
"""
Masking some random tokens for Language Model task with probabilities as in
the original BERT paper.
:param tokens: list of int, tokenized sentence.
:param vocab_range: for choosing a random word
:return: (list of int, list of int), masked tokens and related labels for
LM prediction
"""
output_label = []
for i, token in enumerate(tokens):
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
tokens[i] = mask
# 10% randomly change token to random token
elif prob < 0.9:
tokens[i] = random.choice(list(range(*vocab_range)))
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
output_label.append(token)
else:
# no masking token (will be ignored by loss function later)
output_label.append(-1)
if all(o == -1 for o in output_label):
# at least mask 1
output_label[0] = tokens[0]
tokens[0] = mask
return tokens, output_label
class MlmDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
"""
Return:
- input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- img_feat : (num_bb, d)
- img_pos_feat : (num_bb, 7)
- attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
- txt_labels : (L, ), [-1, -1, wid, -1, -1, -1]
0's padded so that (L + num_bb) % 8 == 0
"""
example = super().__getitem__(i)
# text input
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return input_ids, img_feat, img_pos_feat, attn_masks, txt_labels
def create_mlm_io(self, input_ids):
input_ids, txt_labels = random_word(input_ids,
self.txt_db.v_range,
self.txt_db.mask)
input_ids = torch.tensor([self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep])
txt_labels = torch.tensor([-1] + txt_labels + [-1])
return input_ids, txt_labels
def mlm_collate(inputs):
"""
Return:
:input_ids (n, max_L) padded with 0
:position_ids (n, max_L) padded with 0
:txt_lens list of [txt_len]
:img_feat (n, max_num_bb, feat_dim)
:img_pos_feat (n, max_num_bb, 7)
:num_bbs list of [num_bb]
:attn_masks (n, max_{L + num_bb}) padded with 0
:txt_labels (n, max_L) padded with -1
"""
(input_ids, img_feats, img_pos_feats, attn_masks, txt_labels
) = map(list, unzip(inputs))
# text batches
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'txt_labels': txt_labels}
return batch
class BlindMlmDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids, txt_labels = self.create_mlm_io(example['input_ids'])
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels
def mlm_blind_collate(inputs):
input_ids, attn_masks, txt_labels = map(list, unzip(inputs))
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'txt_labels': txt_labels}
return batch
def eval_mask(len_, num_samples=7):
""" build the mask for evaluating MLM
circularly mask 1 word out of every x words
"""
# build the random masks
if len_ <= num_samples:
masks = torch.eye(len_).bool()
num_samples = len_
else:
mask_inds = [list(range(i, len_, num_samples))
for i in range(num_samples)]
masks = torch.zeros(num_samples, len_).bool()
for i, indices in enumerate(mask_inds):
for j in indices:
masks.data[i, j] = 1
assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0
assert masks.sum().item() == len_
return masks
def eval_gather_inds(len_, num_samples=7):
""" get the gather indices """
inds = torch.arange(0, num_samples, dtype=torch.long)
mul = math.ceil(len_ / num_samples)
output = inds.repeat(mul)[:len_]
return output
def stack_pad_tensors(tensors, lens=None, ns=None, pad=0):
"""N x [B_i, T, ...]"""
if ns is None:
ns = [t.size(0) for t in tensors]
if lens is None:
lens = [t.size(1) for t in tensors]
max_len = max(lens)
bs = sum(ns)
hid_dims = tensors[0].size()[2:]
dtype = tensors[0].dtype
output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype)
if pad:
output.data.fill_(pad)
i = 0
for t, l, n in zip(tensors, lens, ns):
output.data[i:i+n, :l, ...] = t.data
i += n
return output
def expand_tensors(tensors, ns):
return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim()))
for t, n in zip(tensors, ns)]
class MlmEvalDataset(DetectFeatTxtTokDataset):
""" For evaluating MLM training task """
def __init__(self, txt_db, img_db):
assert isinstance(txt_db, TxtTokLmdb)
super().__init__(txt_db, img_db)
def __getitem__(self, i):
example = super().__getitem__(i)
# text input
(input_ids, txt_labels, gather_inds
) = self.create_mlm_eval_io(example['input_ids'])
# img input
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat, attn_masks,
txt_labels, gather_inds)
def create_mlm_eval_io(self, input_ids):
txt_labels = torch.tensor(input_ids)
masks = eval_mask(len(input_ids))
n_mask = masks.size(0)
masks = torch.cat([torch.zeros(n_mask, 1).bool(),
masks,
torch.zeros(n_mask, 1).bool()],
dim=1)
input_ids = torch.tensor([[self.txt_db.cls_]
+ input_ids
+ [self.txt_db.sep]
for _ in range(n_mask)])
input_ids.data.masked_fill_(masks, self.txt_db.mask)
gather_inds = eval_gather_inds(len(txt_labels))
return input_ids, txt_labels, gather_inds
def _batch_gather_tgt(gather_inds, n_masks):
gather_tgts = []
offset = 0
for g, n in zip(gather_inds, n_masks):
gather_tgts.append(g + offset)
offset += n
gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0)
return gather_tgt
def mlm_eval_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks),
num_bbs, n_masks)
img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks),
num_bbs, n_masks)
bs, max_tl = input_ids.size()
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
out_size = attn_masks.size(1)
# repeat txt_lens, num_bbs
txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)]
num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)]
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch
class BlindMlmEvalDataset(Dataset):
def __init__(self, txt_db):
assert isinstance(txt_db, TxtTokLmdb)
self.txt_db = txt_db
self.lens, self.ids = get_ids_and_lens(txt_db)
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
example = self.txt_db[id_]
input_ids = example['input_ids']
# text input
input_ids = example['input_ids']
(input_ids, txt_labels, gather_inds
) = self.txt_db.create_mlm_eval_io(input_ids)
attn_masks = torch.ones(len(input_ids), dtype=torch.long)
return input_ids, attn_masks, txt_labels, gather_inds
def mlm_blind_eval_collate(inputs):
(input_ids, position_ids, attn_masks, txt_labels, gather_inds
) = map(list, unzip(inputs))
# sizes
n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids))
# text batches
input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks),
None, n_masks)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
gather_tgt = _batch_gather_tgt(gather_inds, n_masks)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'attn_masks': attn_masks,
'gather_tgt': gather_tgt,
'txt_labels': txt_labels}
return batch

287
uniter_model/data/mrm.py

@ -0,0 +1,287 @@
"""
MRM Datasets
"""
import random
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
def _get_img_mask(mask_prob, num_bb):
img_mask = [random.random() < mask_prob for _ in range(num_bb)]
if not any(img_mask):
# at least mask 1
img_mask[random.choice(range(num_bb))] = True
img_mask = torch.tensor(img_mask)
return img_mask
def _get_img_tgt_mask(img_mask, txt_len):
z = torch.zeros(txt_len, dtype=torch.bool)
img_mask_tgt = torch.cat([z, img_mask], dim=0)
return img_mask_tgt
def _get_feat_target(img_feat, img_masks):
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d)
feat_dim = img_feat.size(-1)
feat_targets = img_feat[img_masks_ext].contiguous().view(
-1, feat_dim) # (s, d)
return feat_targets
def _mask_img_feat(img_feat, img_masks):
img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
return img_feat_masked
class MrfrDataset(DetectFeatTxtTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
def __getitem__(self, i):
"""
Return:
- input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- img_feat : (num_bb, d)
- img_pos_feat : (num_bb, 7)
- attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
- img_mask : (num_bb, ) between {0, 1}
"""
example = super().__getitem__(i)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
# image input features
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
img_mask = _get_img_mask(self.mask_prob, num_bb)
img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat,
attn_masks, img_mask, img_mask_tgt)
def mrfr_collate(inputs):
"""
Return:
- input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
- position_ids : (n, max_L)
- txt_lens : list of [input_len]
- img_feat : (n, max_num_bb, d)
- img_pos_feat : (n, max_num_bb, 7)
- num_bbs : list of [num_bb]
- attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1]
- img_masks : (n, max_num_bb) between {0, 1}
"""
(input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
# mask features
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
feat_targets = _get_feat_target(img_feat, img_masks)
img_feat = _mask_img_feat(img_feat, img_masks)
img_mask_tgt = pad_sequence(img_mask_tgts,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'feat_targets': feat_targets,
'img_masks': img_masks,
'img_mask_tgt': img_mask_tgt}
return batch
class OnlyImgMrfrDataset(Dataset):
""" an image-only MRM """
def __init__(self, mask_prob, img_db):
self.ids, self.lens = map(list, unzip(self.img_db.name2nbb.items()))
def __getitem__(self, i):
id_ = self.ids[i]
img_feat, img_pos_feat, num_bb = self._get_img_feat(id_)
attn_masks = torch.ones(num_bb, dtype=torch.long)
img_mask = _get_img_mask(self.mask_prob, num_bb)
return img_feat, img_pos_feat, attn_masks, img_mask
def _get_img_feat(self, fname):
img_feat, bb = self.img_db[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
return img_feat, img_bb, num_bb
def mrfr_only_img_collate(inputs):
img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs))
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
# mask features
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
feat_targets = _get_feat_target(img_feat, img_masks)
img_feat = _mask_img_feat(img_feat, img_masks)
batch = {'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'feat_targets': feat_targets,
'img_masks': img_masks,
'img_mask_tgt': img_masks}
return batch
def _get_targets(img_masks, img_soft_label):
soft_label_dim = img_soft_label.size(-1)
img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
-1, soft_label_dim)
return label_targets
class MrcDataset(DetectFeatTxtTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
def _get_img_feat(self, fname):
img_dump = self.img_db.get_dump(fname)
num_bb = self.img_db.name2nbb[fname]
img_feat = torch.tensor(img_dump['features'])
bb = torch.tensor(img_dump['norm_bb'])
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
img_soft_label = torch.tensor(img_dump['soft_labels'])
return img_feat, img_bb, img_soft_label, num_bb
def __getitem__(self, i):
example = super().__getitem__(i)
img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
example['img_fname'])
# image input features
img_mask = _get_img_mask(self.mask_prob, num_bb)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat,
img_soft_labels, attn_masks, img_mask, img_mask_tgt)
def mrc_collate(inputs):
(input_ids, img_feats, img_pos_feats, img_soft_labels,
attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_soft_label = pad_tensors(img_soft_labels, num_bbs)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
label_targets = _get_targets(img_masks, img_soft_label)
img_feat = _mask_img_feat(img_feat, img_masks)
img_mask_tgt = pad_sequence(img_mask_tgts,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'img_masks': img_masks,
'img_mask_tgt': img_mask_tgt,
'label_targets': label_targets}
return batch
class OnlyImgMrcDataset(OnlyImgMrfrDataset):
""" an image-only MRC """
def __getitem__(self, i):
id_ = self.ids[i]
(img_feat, img_pos_feat, img_soft_labels, num_bb
) = self._get_img_feat(id_)
attn_masks = torch.ones(num_bb, dtype=torch.long)
img_mask = _get_img_mask(self.mask_prob, num_bb)
return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask
def _get_img_feat(self, fname):
img_dump = self.img_db.get_dump(fname)
num_bb = self.img_db.name2nbb[fname]
img_feat = torch.tensor(img_dump['features'])
bb = torch.tensor(img_dump['norm_bb'])
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
img_soft_labels = torch.tensor(img_dump['soft_labels'])
return img_feat, img_bb, img_soft_labels, num_bb
def mrc_only_img_collate(inputs):
(img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks
) = map(list, unzip(inputs))
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
img_soft_label = pad_tensors(img_soft_labels, num_bbs)
label_targets = _get_targets(img_masks, img_soft_label)
# mask features
img_feat = _mask_img_feat(img_feat, img_masks)
batch = {'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'img_masks': img_masks,
'img_mask_tgt': img_masks,
'label_targets': label_targets}
return batch

136
uniter_model/data/mrm_nce.py

@ -0,0 +1,136 @@
"""
MRM Datasets (contrastive learning version)
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import curry
from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset,
pad_tensors, get_gather_index)
from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target
from .itm import sample_negative
# FIXME diff implementation from mrfr, mrc
def _mask_img_feat(img_feat, img_masks, neg_feats,
noop_prob=0.1, change_prob=0.1):
rand = torch.rand(*img_masks.size())
noop_mask = rand < noop_prob
change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks
img_masks_in = img_masks & ~noop_mask & ~change_mask
img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat)
img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
n_neg = change_mask.sum().item()
feat_dim = neg_feats.size(-1)
index = torch.arange(0, change_mask.numel(), dtype=torch.long
).masked_select(change_mask.view(-1))
index = index.unsqueeze(-1).expand(-1, feat_dim)
img_feat_out = img_feat_masked.view(-1, feat_dim).scatter(
dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size())
return img_feat_out, img_masks_in
class MrmNceDataset(DetectFeatTxtTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
def __getitem__(self, i):
example = super().__getitem__(i)
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
# image input features
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
img_mask = _get_img_mask(self.mask_prob, num_bb)
img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return (input_ids, img_feat, img_pos_feat,
attn_masks, img_mask, img_mask_tgt,
example['img_fname'])
class NegativeImageSampler(object):
def __init__(self, img_dbs, neg_size, size_mul=8):
if not isinstance(img_dbs, list):
assert isinstance(img_dbs, DetectFeatLmdb)
img_dbs = [img_dbs]
self.neg_size = neg_size
self.img_db = JoinedDetectFeatLmdb(img_dbs)
all_imgs = []
for db in img_dbs:
all_imgs.extend(db.name2nbb.keys())
self.all_imgs = all_imgs
def sample_negative_feats(self, pos_imgs):
neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size)
all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids],
dim=0)
# only use multiples of 8 for tensorcores
n_cut = all_neg_feats.size(0) % 8
if n_cut != 0:
return all_neg_feats[:-n_cut]
else:
return all_neg_feats
class JoinedDetectFeatLmdb(object):
def __init__(self, img_dbs):
assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs)
self.img_dbs = img_dbs
def __getitem__(self, file_name):
for db in self.img_dbs:
if file_name in db:
return db[file_name]
raise ValueError("image does not exists")
@curry
def mrm_nce_collate(neg_sampler, inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
positive_imgs) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
neg_feats = neg_sampler.sample_negative_feats(positive_imgs)
# mask features
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
feat_targets = _get_feat_target(img_feat, img_masks)
img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats)
img_mask_tgt = pad_sequence(img_mask_tgts,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'feat_targets': feat_targets,
'img_masks': img_masks,
'img_masks_in': img_masks_in,
'img_mask_tgt': img_mask_tgt,
'neg_feats': neg_feats}
return batch

218
uniter_model/data/nlvr2.py

@ -0,0 +1,218 @@
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
NLVR2 dataset
"""
import copy
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from cytoolz import concat
from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
get_ids_and_lens, pad_tensors, get_gather_index)
class Nlvr2PairedDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, use_img_type=True):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
txt_lens, self.ids = get_ids_and_lens(txt_db)
txt2img = txt_db.txt2img
self.lens = [2*tl + sum(self.img_db.name2nbb[img]
for img in txt2img[id_])
for tl, id_ in zip(txt_lens, self.ids)]
self.use_img_type = use_img_type
def __getitem__(self, i):
"""
[[txt, img1],
[txt, img2]]
"""
example = super().__getitem__(i)
target = example['target']
outs = []
for i, img in enumerate(example['img_fname']):
img_feat, img_pos_feat, num_bb = self._get_img_feat(img)
# text input
input_ids = copy.deepcopy(example['input_ids'])
input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
attn_masks = [1] * (len(input_ids) + num_bb)
input_ids = torch.tensor(input_ids)
attn_masks = torch.tensor(attn_masks)
if self.use_img_type:
img_type_ids = torch.tensor([i+1]*num_bb)
else:
img_type_ids = None
outs.append((input_ids, img_feat, img_pos_feat,
attn_masks, img_type_ids))
return tuple(outs), target
def nlvr2_paired_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks,
img_type_ids) = map(list, unzip(concat(outs for outs, _ in inputs)))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
if img_type_ids[0] is None:
img_type_ids = None
else:
img_type_ids = pad_sequence(img_type_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.Tensor([t for _, t in inputs]).long()
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'img_type_ids': img_type_ids,
'targets': targets}
return batch
class Nlvr2PairedEvalDataset(Nlvr2PairedDataset):
def __getitem__(self, i):
qid = self.ids[i]
outs, targets = super().__getitem__(i)
return qid, outs, targets
def nlvr2_paired_eval_collate(inputs):
qids, batch = [], []
for id_, *tensors in inputs:
qids.append(id_)
batch.append(tensors)
batch = nlvr2_paired_collate(batch)
batch['qids'] = qids
return batch
class Nlvr2TripletDataset(DetectFeatTxtTokDataset):
def __init__(self, txt_db, img_db, use_img_type=True):
assert isinstance(txt_db, TxtTokLmdb)
assert isinstance(img_db, DetectFeatLmdb)
self.txt_db = txt_db
self.img_db = img_db
txt_lens, self.ids = get_ids_and_lens(txt_db)
txt2img = txt_db.txt2img
self.lens = [tl + sum(self.img_db.name2nbb[img]
for img in txt2img[id_])
for tl, id_ in zip(txt_lens, self.ids)]
self.use_img_type = use_img_type
def __getitem__(self, i):
"""
[[txt, img1],
[txt, img2]]
"""
example = super().__getitem__(i)
target = example['target']
img_feats = []
img_pos_feats = []
num_bb = 0
img_type_ids = []
for i, img in enumerate(example['img_fname']):
feat, pos, nbb = self._get_img_feat(img)
img_feats.append(feat)
img_pos_feats.append(pos)
num_bb += nbb
if self.use_img_type:
img_type_ids.extend([i+1]*nbb)
img_feat = torch.cat(img_feats, dim=0)
img_pos_feat = torch.cat(img_pos_feats, dim=0)
if self.use_img_type:
img_type_ids = torch.tensor(img_type_ids)
else:
img_type_ids = None
# text input
input_ids = copy.deepcopy(example['input_ids'])
input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
attn_masks = [1] * (len(input_ids) + num_bb)
input_ids = torch.tensor(input_ids)
attn_masks = torch.tensor(attn_masks)
return (input_ids, img_feat, img_pos_feat, attn_masks,
img_type_ids, target)
def nlvr2_triplet_collate(inputs):
(input_ids, img_feats, img_pos_feats,
attn_masks, img_type_ids, targets) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
# image batches
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
if img_type_ids[0] is None:
img_type_ids = None
else:
img_type_ids = pad_sequence(img_type_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.Tensor(targets).long()
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'img_type_ids': img_type_ids,
'targets': targets}
return batch
class Nlvr2TripletEvalDataset(Nlvr2TripletDataset):
def __getitem__(self, i):
qid = self.ids[i]
tensors = super().__getitem__(i)
return (qid, *tensors)
def nlvr2_triplet_eval_collate(inputs):
qids, batch = [], []
for id_, *tensors in inputs:
qids.append(id_)
batch.append(tensors)
batch = nlvr2_triplet_collate(batch)
batch['qids'] = qids
return batch

319
uniter_model/data/re.py

@ -0,0 +1,319 @@
"""
Referring Expression Comprehension dataset
"""
import sys
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import TxtLmdb
class ReImageFeatDir(object):
def __init__(self, img_dir):
self.img_dir = img_dir
def __getitem__(self, file_name):
img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
img_feat = torch.tensor(img_dump['features'])
img_bb = torch.tensor(img_dump['norm_bb'])
return img_feat, img_bb
class ReDetectFeatDir(object):
def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
format_='npz'):
assert format_ == 'npz', 'only support npz for now.'
assert isinstance(img_dir, str), 'img_dir is path, not db.'
self.img_dir = img_dir
self.conf_th = conf_th
self.max_bb = max_bb
self.min_bb = min_bb
self.num_bb = num_bb
def _compute_num_bb(self, img_dump):
num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum())
num_bb = min(self.max_bb, num_bb)
return num_bb
def __getitem__(self, file_name):
# image input features
img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
num_bb = self._compute_num_bb(img_dump)
img_feat = torch.tensor(img_dump['features'][:num_bb, :])
img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :])
return img_feat, img_bb
class ReferringExpressionDataset(Dataset):
def __init__(self, db_dir, img_dir, max_txt_len=60):
assert isinstance(img_dir, ReImageFeatDir) or \
isinstance(img_dir, ReDetectFeatDir)
self.img_dir = img_dir
# load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}]
refs = json.load(open(f'{db_dir}/refs.json', 'r'))
self.ref_ids = [ref['ref_id'] for ref in refs]
self.Refs = {ref['ref_id']: ref for ref in refs}
# load annotations = [{id, area, bbox, image_id, category_id}]
anns = json.load(open(f'{db_dir}/annotations.json', 'r'))
self.Anns = {ann['id']: ann for ann in anns}
# load categories = [{id, name, supercategory}]
categories = json.load(open(f'{db_dir}/categories.json', 'r'))
self.Cats = {cat['id']: cat['name'] for cat in categories}
# load images = [{id, file_name, ann_ids, height, width}]
images = json.load(open(f'{db_dir}/images.json', 'r'))
self.Images = {img['id']: img for img in images}
# id2len: sent_id -> sent_len
id2len = json.load(open(f'{db_dir}/id2len.json', 'r'))
self.id2len = {int(_id): _len for _id, _len in id2len.items()}
self.max_txt_len = max_txt_len
self.sent_ids = self._get_sent_ids()
# db[str(sent_id)] =
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
self.db = TxtLmdb(db_dir, readonly=True)
# meta
meta = json.load(open(f'{db_dir}/meta.json', 'r'))
self.cls_ = meta['CLS']
self.sep = meta['SEP']
self.mask = meta['MASK']
self.v_range = meta['v_range']
def shuffle(self):
# we shuffle ref_ids and make sent_ids according to ref_ids
random.shuffle(self.ref_ids)
self.sent_ids = self._get_sent_ids()
def _get_sent_ids(self):
sent_ids = []
for ref_id in self.ref_ids:
for sent_id in self.Refs[ref_id]['sent_ids']:
sent_len = self.id2len[sent_id]
if self.max_txt_len == -1 or sent_len < self.max_txt_len:
sent_ids.append(sent_id)
return sent_ids
def _get_img_feat(self, fname):
img_feat, bb = self.img_dir[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
return img_feat, img_bb, num_bb
def __len__(self):
return len(self.sent_ids)
def __getitem__(self, i):
"""
Return:
:input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
:position_ids : range(L)
:img_feat : (num_bb, d)
:img_pos_feat : (num_bb, 7)
:attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
:obj_masks : (num_bb, ) all 0's
:target : (1, )
"""
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
sent_id = self.sent_ids[i]
txt_dump = self.db[str(sent_id)]
image_id = txt_dump['image_id']
fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
# text input
input_ids = txt_dump['input_ids']
input_ids = [self.cls_] + input_ids + [self.sep]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
# target bbox
img = self.Images[image_id]
assert len(img['ann_ids']) == num_bb, \
'Please use visual_grounding_coco_gt'
target = img['ann_ids'].index(txt_dump['ann_id'])
target = torch.tensor([target])
# obj_masks, to be padded with 1, for masking out non-object prob.
obj_masks = torch.tensor([0]*len(img['ann_ids'])).bool()
return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
obj_masks, target)
def re_collate(inputs):
"""
Return:
:input_ids : (n, max_L) padded with 0
:position_ids : (n, max_L) padded with 0
:txt_lens : list of [txt_len]
:img_feat : (n, max_num_bb, feat_dim)
:img_pos_feat : (n, max_num_bb, 7)
:num_bbs : list of [num_bb]
:attn_masks : (n, max_{L+num_bb}) padded with 0
:obj_masks : (n, max_num_bb) padded with 1
:targets : (n, )
"""
(input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
targets) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.cat(targets, dim=0)
obj_masks = pad_sequence(obj_masks,
batch_first=True, padding_value=1).bool()
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, obj_masks, targets)
class ReferringExpressionEvalDataset(ReferringExpressionDataset):
def __getitem__(self, i):
"""
Return:
:input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
:position_ids : range(L)
:img_feat : (num_bb, d)
:img_pos_feat : (num_bb, 7)
:attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
:obj_masks : (num_bb, ) all 0's
:tgt_box : ndarray (4, ) xywh
:obj_boxes : ndarray (num_bb, 4) xywh
:sent_id
"""
# {sent_id, sent, ref_id, ann_id, image_id,
# bbox, input_ids, toked_sent}
sent_id = self.sent_ids[i]
txt_dump = self.db[str(sent_id)]
image_id = txt_dump['image_id']
if isinstance(self.img_dir, ReImageFeatDir):
if '_gt' in self.img_dir.img_dir:
fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
elif '_det' in self.img_dir.img_dir:
fname = f'visual_grounding_det_coco_{int(image_id):012}.npz'
elif isinstance(self.img_dir, ReDetectFeatDir):
fname = f'coco_train2014_{int(image_id):012}.npz'
else:
sys.exit('%s not supported.' % self.img_dir)
img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
# image info
img = self.Images[image_id]
im_width, im_height = img['width'], img['height']
# object boxes, img_pos_feat (xyxywha) -> xywh
obj_boxes = np.stack([img_pos_feat[:, 0]*im_width,
img_pos_feat[:, 1]*im_height,
img_pos_feat[:, 4]*im_width,
img_pos_feat[:, 5]*im_height], axis=1)
obj_masks = torch.tensor([0]*num_bb).bool()
# target box
tgt_box = np.array(txt_dump['bbox']) # xywh
# text input
input_ids = txt_dump['input_ids']
input_ids = [self.cls_] + input_ids + [self.sep]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
obj_masks, tgt_box, obj_boxes, sent_id)
# IoU function
def computeIoU(self, box1, box2):
# each box is of [x1, y1, w, h]
inter_x1 = max(box1[0], box2[0])
inter_y1 = max(box1[1], box2[1])
inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
else:
inter = 0
union = box1[2]*box1[3] + box2[2]*box2[3] - inter
return float(inter)/union
def re_eval_collate(inputs):
"""
Return:
:input_ids : (n, max_L)
:position_ids : (n, max_L)
:txt_lens : list of [txt_len]
:img_feat : (n, max_num_bb, d)
:img_pos_feat : (n, max_num_bb, 7)
:num_bbs : list of [num_bb]
:attn_masks : (n, max{L+num_bb})
:obj_masks : (n, max_num_bb)
:tgt_box : list of n [xywh]
:obj_boxes : list of n [[xywh, xywh, ...]]
:sent_ids : list of n [sent_id]
"""
(input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
obj_masks = pad_sequence(obj_masks,
batch_first=True, padding_value=1).bool()
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids)

116
uniter_model/data/sampler.py

@ -0,0 +1,116 @@
""" sampler for length bucketing (batch by tokens) """
import math
import random
import torch
import horovod.torch as hvd
from torch.utils.data import Sampler
from cytoolz import partition_all
class TokenBucketSampler(Sampler):
def __init__(self, lens, bucket_size, batch_size,
droplast=False, size_multiple=8):
self._lens = lens
self._max_tok = batch_size
self._bucket_size = bucket_size
self._droplast = droplast
self._size_mul = size_multiple
def _create_ids(self):
return list(range(len(self._lens)))
def _sort_fn(self, i):
return self._lens[i]
def __iter__(self):
ids = self._create_ids()
random.shuffle(ids)
buckets = [sorted(ids[i:i+self._bucket_size],
key=self._sort_fn, reverse=True)
for i in range(0, len(ids), self._bucket_size)]
# fill batches until max_token (include padding)
batches = []
for bucket in buckets:
max_len = 0
batch_indices = []
for indices in partition_all(self._size_mul, bucket):
max_len = max(max_len, max(self._lens[i] for i in indices))
if (max_len * (len(batch_indices) + self._size_mul)
> self._max_tok):
if not batch_indices:
raise ValueError(
"max_tokens too small / max_seq_len too long")
assert len(batch_indices) % self._size_mul == 0
batches.append(batch_indices)
batch_indices = list(indices)
else:
batch_indices.extend(indices)
if not self._droplast and batch_indices:
batches.append(batch_indices)
random.shuffle(batches)
return iter(batches)
def __len__(self):
raise ValueError("NOT supported. "
"This has some randomness across epochs")
class DistributedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
shuffle (optional): If true (default), sampler will shuffle the indices
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
num_replicas = hvd.size()
if rank is None:
rank = hvd.rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset)
* 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
if self.shuffle:
shufle_ind = torch.randperm(len(indices), generator=g).tolist()
indices = [indices[i] for i in shufle_ind]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch

BIN
uniter_model/data/test_data/input0.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input1.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input2.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input3.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input4.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input5.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input6.txt

Binary file not shown.

BIN
uniter_model/data/test_data/input7.txt

Binary file not shown.

725
uniter_model/data/vcr.py

@ -0,0 +1,725 @@
"""
VCR dataset
"""
import json
import copy
import random
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from torch.utils.data import Dataset
from .data import DetectFeatLmdb, TxtLmdb, random_word
from .mrc import DetectFeatDir_for_mrc
class ImageTextDataset(Dataset):
def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
max_txt_len=120, task="qa"):
self.txt_lens = []
self.ids = []
self.task = task
for id_, len_ in json.load(open(f'{db_dir}/id2len_{task}.json')
).items():
if max_txt_len == -1 or len_ <= max_txt_len:
self.txt_lens.append(len_)
self.ids.append(id_)
self.db = TxtLmdb(db_dir, readonly=True)
self.img_dir = img_dir
self.img_dir_gt = img_dir_gt
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_ = self.ids[i]
txt_dump = self.db[id_]
img_dump_gt, img_dump = None, None
img_fname_gt, img_fname = txt_dump['img_fname']
if self.img_dump_gt:
img_dump_gt = self.img_dump_gt[img_fname_gt]
if self.img_dir:
img_dump = self.img_dir[img_fname]
return img_dump_gt, img_dump, txt_dump
class DetectFeatBertTokDataset(ImageTextDataset):
def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
max_txt_len=60, task="qa"):
assert not (img_dir_gt is None and img_dir is None),\
"image_dir_gt and img_dir cannot all be None"
assert task == "qa" or task == "qar",\
"VCR only allow two tasks: qa or qar"
assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)
super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
txt2img = json.load(open(f'{db_dir}/txt2img.json'))
if self.img_dir and self.img_dir_gt:
self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]] +
self.img_dir.name2nbb[txt2img[id_][1]]
for tl, id_ in zip(self.txt_lens, self.ids)]
elif self.img_dir:
self.lens = [tl+self.img_dir.name2nbb[txt2img[id_][1]]
for tl, id_ in zip(self.txt_lens, self.ids)]
else:
self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]]
for tl, id_ in zip(self.txt_lens, self.ids)]
meta = json.load(open(f'{db_dir}/meta.json', 'r'))
self.cls_ = meta['CLS']
self.sep = meta['SEP']
self.mask = meta['MASK']
self.v_range = meta['v_range']
def _get_img_feat(self, fname_gt, fname):
if self.img_dir and self.img_dir_gt:
img_feat_gt, bb_gt = self.img_dir_gt[fname_gt]
img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)
img_feat, bb = self.img_dir[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
num_bb = img_feat.size(0)
elif self.img_dir:
img_feat, bb = self.img_dir[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
elif self.img_dir_gt:
img_feat, bb = self.img_dir_gt[fname_gt]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
return img_feat, img_bb, num_bb
class VcrDataset(DetectFeatBertTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
del self.txt_lens
def _get_input_ids(self, txt_dump):
# text input
input_ids_q = txt_dump['input_ids']
type_ids_q = [0]*len(input_ids_q)
input_ids_as = txt_dump['input_ids_as']
if self.task == "qar":
input_ids_rs = txt_dump['input_ids_rs']
answer_label = txt_dump['qa_target']
assert answer_label >= 0, "answer_label < 0"
input_ids_gt_a = [self.sep] + copy.deepcopy(
input_ids_as[answer_label])
type_ids_gt_a = [2] * len(input_ids_gt_a)
type_ids_q += type_ids_gt_a
input_ids_q += input_ids_gt_a
input_ids_for_choices = input_ids_rs
else:
input_ids_for_choices = input_ids_as
return input_ids_q, input_ids_for_choices, type_ids_q
def __getitem__(self, i):
id_ = self.ids[i]
txt_dump = self.db[id_]
img_feat, img_pos_feat, num_bb = self._get_img_feat(
txt_dump['img_fname'][0], txt_dump['img_fname'][1])
object_targets = txt_dump["object_ids"]
input_ids_q, input_ids_for_choices, type_ids_q = self._get_input_ids(
txt_dump)
label = txt_dump['%s_target' % (self.task)]
choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
[], [], [])
(choice_txt_lens, choice_input_ids, choice_txt_type_ids,
choice_attn_masks, choice_position_ids, choice_targets) = (
[], [], [], [], [], [])
choice_obj_targets, choice_img_masks = ([], [])
for index, input_ids_a in enumerate(input_ids_for_choices):
if index == label:
target = torch.tensor([1]).long()
else:
target = torch.tensor([0]).long()
input_ids = [self.cls_] + copy.deepcopy(input_ids_q) +\
[self.sep] + input_ids_a + [self.sep]
type_id_for_choice = 3 if type_ids_q[-1] == 2 else 2
txt_type_ids = [0] + type_ids_q + [type_id_for_choice]*(
len(input_ids_a)+2)
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
txt_type_ids = torch.tensor(txt_type_ids)
choice_txt_lens.append(len(input_ids))
choice_input_ids.append(input_ids)
choice_attn_masks.append(attn_masks)
choice_position_ids.append(position_ids)
choice_txt_type_ids.append(txt_type_ids)
choice_num_bbs.append(num_bb)
choice_img_feats.append(img_feat)
choice_img_pos_feats.append(img_pos_feat)
choice_targets.append(target)
# mask image input features
num_gt_bb = len(object_targets)
num_det_bb = num_bb - num_gt_bb
# only mask gt features
img_mask = [random.random() < self.mask_prob
for _ in range(num_gt_bb)]
if not any(img_mask):
# at least mask 1
img_mask[0] = True
img_mask += [False]*num_det_bb
img_mask = torch.tensor(img_mask)
object_targets += [0]*num_det_bb
obj_targets = torch.tensor(object_targets)
choice_obj_targets.append(obj_targets)
choice_img_masks.append(img_mask)
return (choice_input_ids, choice_position_ids, choice_txt_lens,
choice_txt_type_ids,
choice_img_feats, choice_img_pos_feats, choice_num_bbs,
choice_attn_masks, choice_targets, choice_obj_targets,
choice_img_masks)
def vcr_collate(inputs):
(input_ids, position_ids, txt_lens, txt_type_ids, img_feats,
img_pos_feats, num_bbs, attn_masks, targets,
obj_targets, img_masks) = map(list, unzip(inputs))
all_num_bbs, all_img_feats, all_img_pos_feats = (
[], [], [])
all_txt_lens, all_input_ids, all_attn_masks,\
all_position_ids, all_txt_type_ids = (
[], [], [], [], [])
all_obj_targets = []
all_targets = []
# all_targets = targets
all_img_masks = []
for i in range(len(num_bbs)):
all_input_ids += input_ids[i]
all_position_ids += position_ids[i]
all_txt_lens += txt_lens[i]
all_txt_type_ids += txt_type_ids[i]
all_img_feats += img_feats[i]
all_img_pos_feats += img_pos_feats[i]
all_num_bbs += num_bbs[i]
all_attn_masks += attn_masks[i]
all_obj_targets += obj_targets[i]
all_img_masks += img_masks[i]
all_targets += targets[i]
all_input_ids = pad_sequence(all_input_ids,
batch_first=True, padding_value=0)
all_position_ids = pad_sequence(all_position_ids,
batch_first=True, padding_value=0)
all_txt_type_ids = pad_sequence(all_txt_type_ids,
batch_first=True, padding_value=0)
all_attn_masks = pad_sequence(all_attn_masks,
batch_first=True, padding_value=0)
all_img_masks = pad_sequence(all_img_masks,
batch_first=True, padding_value=0)
# all_targets = pad_sequence(all_targets,
# batch_first=True, padding_value=0)
all_targets = torch.stack(all_targets, dim=0)
batch_size = len(all_img_feats)
num_bb = max(all_num_bbs)
feat_dim = all_img_feats[0].size(1)
pos_dim = all_img_pos_feats[0].size(1)
all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
all_obj_target = torch.zeros(batch_size, num_bb)
for i, (im, pos, label) in enumerate(zip(
all_img_feats, all_img_pos_feats, all_obj_targets)):
len_ = im.size(0)
all_img_feat.data[i, :len_, :] = im.data
all_img_pos_feat.data[i, :len_, :] = pos.data
all_obj_target.data[i, :len_] = label.data
obj_targets = all_obj_target[all_img_masks].contiguous()
return (all_input_ids, all_position_ids, all_txt_lens,
all_txt_type_ids,
all_img_feat, all_img_pos_feat, all_num_bbs,
all_attn_masks, all_targets, obj_targets, all_img_masks)
class VcrEvalDataset(DetectFeatBertTokDataset):
def __init__(self, split, *args, **kwargs):
super().__init__(*args, **kwargs)
self.split = split
del self.txt_lens
def _get_input_ids(self, txt_dump):
# text input
input_ids_for_choices = []
type_ids_for_choices = []
input_ids_q = txt_dump['input_ids']
type_ids_q = [0]*len(input_ids_q)
input_ids_as = txt_dump['input_ids_as']
input_ids_rs = txt_dump['input_ids_rs']
for index, input_ids_a in enumerate(input_ids_as):
curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
[self.sep] + input_ids_a + [self.sep]
curr_type_ids_qa = [0] + type_ids_q + [2]*(
len(input_ids_a)+2)
input_ids_for_choices.append(curr_input_ids_qa)
type_ids_for_choices.append(curr_type_ids_qa)
for index, input_ids_a in enumerate(input_ids_as):
curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\
[self.sep] + input_ids_a + [self.sep]
curr_type_ids_qa = [0] + type_ids_q + [2]*(
len(input_ids_a)+1)
if (self.split == "val" and index == txt_dump["qa_target"]) or\
self.split == "test":
for input_ids_r in input_ids_rs:
curr_input_ids_qar = copy.deepcopy(curr_input_ids_qa) +\
input_ids_r + [self.sep]
curr_type_ids_qar = copy.deepcopy(curr_type_ids_qa) +\
[3]*(len(input_ids_r)+2)
input_ids_for_choices.append(curr_input_ids_qar)
type_ids_for_choices.append(curr_type_ids_qar)
return input_ids_for_choices, type_ids_for_choices
def __getitem__(self, i):
qid = self.ids[i]
id_ = self.ids[i]
txt_dump = self.db[id_]
img_feat, img_pos_feat, num_bb = self._get_img_feat(
txt_dump['img_fname'][0], txt_dump['img_fname'][1])
object_targets = txt_dump["object_ids"]
input_ids_for_choices, type_ids_for_choices = self._get_input_ids(
txt_dump)
qa_target = torch.tensor([int(txt_dump["qa_target"])])
qar_target = torch.tensor([int(txt_dump["qar_target"])])
choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
[], [], [])
(choice_txt_lens, choice_input_ids, choice_attn_masks,
choice_position_ids, choice_txt_type_ids) = (
[], [], [], [], [])
choice_obj_targets = []
for index, input_ids in enumerate(input_ids_for_choices):
txt_type_ids = type_ids_for_choices[index]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
txt_type_ids = torch.tensor(txt_type_ids)
choice_txt_lens.append(len(input_ids))
choice_input_ids.append(input_ids)
choice_attn_masks.append(attn_masks)
choice_position_ids.append(position_ids)
choice_txt_type_ids.append(txt_type_ids)
choice_num_bbs.append(num_bb)
choice_img_feats.append(img_feat)
choice_img_pos_feats.append(img_pos_feat)
obj_targets = torch.tensor(object_targets)
choice_obj_targets.append(obj_targets)
return (qid, choice_input_ids, choice_position_ids, choice_txt_lens,
choice_txt_type_ids,
choice_img_feats, choice_img_pos_feats, choice_num_bbs,
choice_attn_masks, qa_target, qar_target, choice_obj_targets)
def vcr_eval_collate(inputs):
(qids, input_ids, position_ids, txt_lens, txt_type_ids,
img_feats, img_pos_feats,
num_bbs, attn_masks, qa_targets, qar_targets,
obj_targets) = map(list, unzip(inputs))
all_num_bbs, all_img_feats, all_img_pos_feats = (
[], [], [])
all_txt_lens, all_input_ids, all_attn_masks, all_position_ids,\
all_txt_type_ids = (
[], [], [], [], [])
# all_qa_targets = qa_targets
# all_qar_targets = qar_targets
all_obj_targets = []
for i in range(len(num_bbs)):
all_input_ids += input_ids[i]
all_position_ids += position_ids[i]
all_txt_lens += txt_lens[i]
all_img_feats += img_feats[i]
all_img_pos_feats += img_pos_feats[i]
all_num_bbs += num_bbs[i]
all_attn_masks += attn_masks[i]
all_txt_type_ids += txt_type_ids[i]
all_obj_targets += obj_targets[i]
all_input_ids = pad_sequence(all_input_ids,
batch_first=True, padding_value=0)
all_position_ids = pad_sequence(all_position_ids,
batch_first=True, padding_value=0)
all_txt_type_ids = pad_sequence(all_txt_type_ids,
batch_first=True, padding_value=0)
all_attn_masks = pad_sequence(all_attn_masks,
batch_first=True, padding_value=0)
all_obj_targets = pad_sequence(all_obj_targets,
batch_first=True, padding_value=0)
all_qa_targets = torch.stack(qa_targets, dim=0)
all_qar_targets = torch.stack(qar_targets, dim=0)
batch_size = len(all_img_feats)
num_bb = max(all_num_bbs)
feat_dim = all_img_feats[0].size(1)
pos_dim = all_img_pos_feats[0].size(1)
all_img_feat = torch.zeros(batch_size, num_bb, feat_dim)
all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(
all_img_feats, all_img_pos_feats)):
len_ = im.size(0)
all_img_feat.data[i, :len_, :] = im.data
all_img_pos_feat.data[i, :len_, :] = pos.data
return (qids, all_input_ids, all_position_ids, all_txt_lens,
all_txt_type_ids,
all_img_feat, all_img_pos_feat, all_num_bbs,
all_attn_masks, all_qa_targets, all_qar_targets, all_obj_targets)
class MlmDatasetForVCR(DetectFeatBertTokDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
del self.txt_lens
def _get_input_ids(self, txt_dump, mask=True):
# text input
input_ids_q = txt_dump['input_ids']
type_ids_q = [0]*len(input_ids_q)
if mask:
input_ids_q, txt_labels_q = random_word(
input_ids_q, self.v_range, self.mask)
else:
txt_labels_q = input_ids_q
answer_label = txt_dump['qa_target']
assert answer_label >= 0, "answer_label < 0"
input_ids_a = txt_dump['input_ids_as'][answer_label]
type_ids_a = [2]*len(input_ids_a)
if mask:
input_ids_a, txt_labels_a = random_word(
input_ids_a, self.v_range, self.mask)
else:
txt_labels_a = input_ids_a
input_ids = input_ids_q + [self.sep] + input_ids_a
type_ids = type_ids_q + [0] + type_ids_a
txt_labels = txt_labels_q + [-1] + txt_labels_a
if self.task == "qar":
rationale_label = txt_dump['qar_target']
assert rationale_label >= 0, "rationale_label < 0"
input_ids_r = txt_dump['input_ids_rs'][rationale_label]
type_ids_r = [3]*len(input_ids_r)
if mask:
input_ids_r, txt_labels_r = random_word(
input_ids_r, self.v_range, self.mask)
else:
txt_labels_r = input_ids_r
input_ids += [self.sep] + input_ids_r
type_ids += [2] + type_ids_r
txt_labels += [-1] + txt_labels_r
return input_ids, type_ids, txt_labels
def __getitem__(self, i):
id_ = self.ids[i]
txt_dump = self.db[id_]
img_feat, img_pos_feat, num_bb = self._get_img_feat(
txt_dump['img_fname'][0], txt_dump['img_fname'][1])
# txt inputs
input_ids, type_ids, txt_labels = self._get_input_ids(txt_dump)
input_ids = [self.cls_] + input_ids + [self.sep]
txt_labels = [-1] + txt_labels + [-1]
type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
txt_labels = torch.tensor(txt_labels)
type_ids = torch.tensor(type_ids)
return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
attn_masks, txt_labels)
def mlm_collate_for_vcr(inputs):
(input_ids, position_ids, type_ids, img_feats, img_pos_feats, attn_masks,
txt_labels) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1)
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, type_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, txt_labels)
class MrmDatasetForVCR(DetectFeatBertTokDataset):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
del self.txt_lens
def _get_input_ids(self, txt_dump, mask=True):
# text input
input_ids_q = txt_dump['input_ids']
type_ids_q = [0]*len(input_ids_q)
answer_label = txt_dump['qa_target']
assert answer_label >= 0, "answer_label < 0"
input_ids_a = txt_dump['input_ids_as'][answer_label]
type_ids_a = [2]*len(input_ids_a)
input_ids = input_ids_q + [self.sep] + input_ids_a
type_ids = type_ids_q + [0] + type_ids_a
if self.task == "qar":
rationale_label = txt_dump['qar_target']
assert rationale_label >= 0, "rationale_label < 0"
input_ids_r = txt_dump['input_ids_rs'][rationale_label]
type_ids_r = [3]*len(input_ids_r)
input_ids += [self.sep] + input_ids_r
type_ids += [2] + type_ids_r
return input_ids, type_ids
def __getitem__(self, i):
id_ = self.ids[i]
txt_dump = self.db[id_]
img_feat, img_pos_feat, num_bb = self._get_img_feat(
txt_dump['img_fname'][0], txt_dump['img_fname'][1])
# image input features
img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
if not any(img_mask):
# at least mask 1
img_mask[0] = True
img_mask = torch.tensor(img_mask)
# text input
input_ids, type_ids = self._get_input_ids(txt_dump)
input_ids = [self.cls_] + input_ids + [self.sep]
type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
type_ids = torch.tensor(type_ids)
return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
attn_masks, img_mask)
def mrm_collate_for_vcr(inputs):
(input_ids, position_ids, type_ids, img_feats, img_pos_feats,
attn_masks, img_masks) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
return (input_ids, position_ids, type_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, img_masks)
class DetectFeatBertTokDataset_for_mrc_vcr(DetectFeatBertTokDataset):
def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
max_txt_len=60, task="qa"):
assert not (img_dir_gt is None and img_dir is None),\
"image_dir_gt and img_dir cannot all be None"
assert task == "qa" or task == "qar",\
"VCR only allow two tasks: qa or qar"
assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)
super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
if self.img_dir:
self.img_dir = DetectFeatDir_for_mrc(img_dir)
if self.img_dir_gt:
self.img_dir_gt = DetectFeatDir_for_mrc(img_dir_gt)
def _get_img_feat(self, fname_gt, fname):
if self.img_dir and self.img_dir_gt:
img_feat_gt, bb_gt,\
img_soft_labels_gt = self.img_dir_gt[fname_gt]
img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)
img_feat, bb, img_soft_labels = self.img_dir[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
img_soft_labels = torch.cat(
[img_soft_labels_gt, img_soft_labels], dim=0)
num_bb = img_feat.size(0)
elif self.img_dir:
img_feat, bb, img_soft_labels = self.img_dir[fname]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
elif self.img_dir_gt:
img_feat, bb, img_soft_labels = self.img_dir_gt[fname_gt]
img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
num_bb = img_feat.size(0)
return img_feat, img_bb, img_soft_labels, num_bb
class MrcDatasetForVCR(DetectFeatBertTokDataset_for_mrc_vcr):
def __init__(self, mask_prob, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mask_prob = mask_prob
del self.txt_lens
def _get_input_ids(self, txt_dump, mask=True):
# text input
input_ids_q = txt_dump['input_ids']
type_ids_q = [0]*len(input_ids_q)
answer_label = txt_dump['qa_target']
assert answer_label >= 0, "answer_label < 0"
input_ids_a = txt_dump['input_ids_as'][answer_label]
type_ids_a = [2]*len(input_ids_a)
input_ids = input_ids_q + [self.sep] + input_ids_a
type_ids = type_ids_q + [0] + type_ids_a
if self.task == "qar":
rationale_label = txt_dump['qar_target']
assert rationale_label >= 0, "rationale_label < 0"
input_ids_r = txt_dump['input_ids_rs'][rationale_label]
type_ids_r = [3]*len(input_ids_r)
input_ids += [self.sep] + input_ids_r
type_ids += [2] + type_ids_r
return input_ids, type_ids
def __getitem__(self, i):
id_ = self.ids[i]
txt_dump = self.db[id_]
img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
txt_dump['img_fname'][0], txt_dump['img_fname'][1])
# image input features
img_mask = [random.random() < self.mask_prob for _ in range(num_bb)]
if not any(img_mask):
# at least mask 1
img_mask[0] = True
img_mask = torch.tensor(img_mask)
# text input
input_ids, type_ids = self._get_input_ids(txt_dump)
input_ids = [self.cls_] + input_ids + [self.sep]
type_ids = [type_ids[0]] + type_ids + [type_ids[-1]]
attn_masks = [1] * len(input_ids)
position_ids = list(range(len(input_ids)))
attn_masks += [1] * num_bb
input_ids = torch.tensor(input_ids)
position_ids = torch.tensor(position_ids)
attn_masks = torch.tensor(attn_masks)
type_ids = torch.tensor(type_ids)
return (input_ids, position_ids, type_ids, img_feat, img_pos_feat,
img_soft_labels, attn_masks, img_mask)
def mrc_collate_for_vcr(inputs):
(input_ids, position_ids, type_ids, img_feats, img_pos_feats,
img_soft_labels, attn_masks, img_masks
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
num_bbs = [f.size(0) for f in img_feats]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids,
batch_first=True, padding_value=0)
type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
batch_size = len(img_feats)
num_bb = max(num_bbs)
feat_dim = img_feats[0].size(1)
soft_label_dim = img_soft_labels[0].size(1)
pos_dim = img_pos_feats[0].size(1)
img_feat = torch.zeros(batch_size, num_bb, feat_dim)
img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
img_soft_label = torch.zeros(batch_size, num_bb, soft_label_dim)
for i, (im, pos, label) in enumerate(zip(img_feats,
img_pos_feats,
img_soft_labels)):
len_ = im.size(0)
img_feat.data[i, :len_, :] = im.data
img_pos_feat.data[i, :len_, :] = pos.data
img_soft_label.data[i, :len_, :] = label.data
img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
-1, soft_label_dim)
return (input_ids, position_ids, type_ids, txt_lens,
img_feat, img_pos_feat, num_bbs,
attn_masks, (img_masks, label_targets))

19
uniter_model/data/ve.py

@ -0,0 +1,19 @@
"""
Visual entailment dataset
# NOTE: basically reuse VQA dataset
"""
from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate
class VeDataset(VqaDataset):
def __init__(self, *args, **kwargs):
super().__init__(3, *args, **kwargs)
class VeEvalDataset(VqaEvalDataset):
def __init__(self, *args, **kwargs):
super().__init__(3, *args, **kwargs)
ve_collate = vqa_collate
ve_eval_collate = vqa_eval_collate

124
uniter_model/data/vqa.py

@ -0,0 +1,124 @@
"""
VQA dataset
"""
import torch
from torch.nn.utils.rnn import pad_sequence
from toolz.sandbox import unzip
from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
def _get_vqa_target(example, num_answers):
target = torch.zeros(num_answers)
labels = example['target']['labels']
scores = example['target']['scores']
if labels and scores:
target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
return target
class VqaDataset(DetectFeatTxtTokDataset):
""" NOTE: This handels distributed inside """
def __init__(self, num_answers, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_answers = num_answers
def __getitem__(self, i):
example = super().__getitem__(i)
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
target = _get_vqa_target(example, self.num_answers)
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return input_ids, img_feat, img_pos_feat, attn_masks, target
def vqa_collate(inputs):
(input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
targets = torch.stack(targets, dim=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets}
return batch
class VqaEvalDataset(VqaDataset):
def __getitem__(self, i):
qid = self.ids[i]
example = DetectFeatTxtTokDataset.__getitem__(self, i)
img_feat, img_pos_feat, num_bb = self._get_img_feat(
example['img_fname'])
# text input
input_ids = example['input_ids']
input_ids = self.txt_db.combine_inputs(input_ids)
if 'target' in example:
target = _get_vqa_target(example, self.num_answers)
else:
target = None
attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
return qid, input_ids, img_feat, img_pos_feat, attn_masks, target
def vqa_eval_collate(inputs):
(qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
) = map(list, unzip(inputs))
txt_lens = [i.size(0) for i in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
).unsqueeze(0)
attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
if targets[0] is None:
targets = None
else:
targets = torch.stack(targets, dim=0)
num_bbs = [f.size(0) for f in img_feats]
img_feat = pad_tensors(img_feats, num_bbs)
img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
bs, max_tl = input_ids.size()
out_size = attn_masks.size(1)
gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
batch = {'qids': qids,
'input_ids': input_ids,
'position_ids': position_ids,
'img_feat': img_feat,
'img_pos_feat': img_pos_feat,
'attn_masks': attn_masks,
'gather_index': gather_index,
'targets': targets}
return batch

53
uniter_model/eval/itm.py

@ -0,0 +1,53 @@
""" Image Text Retrieval evaluation helper """
import torch
@torch.no_grad()
def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts):
# image retrieval
img2j = {i: j for j, i in enumerate(img_ids)}
_, rank_txt = score_matrix.topk(10, dim=1)
gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]]
for txt_id in txt_ids],
).to(rank_txt.device
).unsqueeze(1).expand_as(rank_txt)
rank = (rank_txt == gt_img_j).nonzero()
if rank.numel():
ir_r1 = (rank < 1).sum().item() / len(txt_ids)
ir_r5 = (rank < 5).sum().item() / len(txt_ids)
ir_r10 = (rank < 10).sum().item() / len(txt_ids)
else:
ir_r1, ir_r5, ir_r10 = 0, 0, 0
# text retrieval
txt2i = {t: i for i, t in enumerate(txt_ids)}
_, rank_img = score_matrix.topk(10, dim=0)
tr_r1, tr_r5, tr_r10 = 0, 0, 0
for j, img_id in enumerate(img_ids):
gt_is = [txt2i[t] for t in img2txts[img_id]]
ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is]
rank = min([10] + [r.item() for r in ranks if r.numel()])
if rank < 1:
tr_r1 += 1
if rank < 5:
tr_r5 += 1
if rank < 10:
tr_r10 += 1
tr_r1 /= len(img_ids)
tr_r5 /= len(img_ids)
tr_r10 /= len(img_ids)
tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3
ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3
r_mean = (tr_mean + ir_mean) / 2
eval_log = {'txt_r1': tr_r1,
'txt_r5': tr_r5,
'txt_r10': tr_r10,
'txt_r_mean': tr_mean,
'img_r1': ir_r1,
'img_r5': ir_r5,
'img_r10': ir_r10,
'img_r_mean': ir_mean,
'r_mean': r_mean}
return eval_log

62
uniter_model/eval/nlvr2.py

@ -0,0 +1,62 @@
"""
copied from official NLVR2 github
python eval/nlvr2.py <output.csv> <annotation.json>
"""
import json
import sys
# Load the predictions file. Assume it is a CSV.
predictions = { }
for line in open(sys.argv[1]).readlines():
if line:
splits = line.strip().split(",")
# We assume identifiers are in the format "split-####-#-#.png".
identifier = splits[0]
prediction = splits[1]
predictions[identifier] = prediction
# Load the labeled examples.
labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line]
# If not, identify the ones that are missing, and exit.
total_num = len(labeled_examples)
if len(predictions) < total_num:
print("Some predictions are missing!")
print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num))
for example in labeled_examples:
lookup = example["identifier"]
if not lookup in predictions:
print("Missing prediction for item " + str(lookup))
exit()
# Get the precision by iterating through the examples and checking the value
# that was predicted.
# Also update the "consistency" dictionary that keeps track of whether all
# predictions for a given sentence were correct.
num_correct = 0.
consistency_dict = { }
for example in labeled_examples:
anon_label = example["identifier"].split("-")
anon_label[2] = ''
anon_label = '-'.join(anon_label)
if not anon_label in consistency_dict:
consistency_dict[anon_label] = True
lookup = example["identifier"]
prediction = predictions[lookup]
if prediction.lower() == example["label"].lower():
num_correct += 1.
else:
consistency_dict[anon_label] = False
# Calculate consistency.
num_consistent = 0.
unique_sentence = len(consistency_dict)
for identifier, consistent in consistency_dict.items():
if consistent:
num_consistent += 1
# Report values.
print("accuracy=" + str(num_correct / total_num))
print("consistency=" + str(num_consistent / unique_sentence))

218
uniter_model/eval_re.py

@ -0,0 +1,218 @@
# coding=utf-8
# copied from hugginface github
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT for Referring Expression Comprehension Evaluation"""
import argparse
import json
import os
from os.path import exists
from time import time
import torch
from torch.utils.data import DataLoader
# to be deprecated once upgraded to 1.2
# from torch.utils.data.distributed import DistributedSampler
from data import DistributedSampler
from apex import amp
from horovod import torch as hvd
from data import (ReImageFeatDir, ReferringExpressionEvalDataset,
re_eval_collate, PrefetchLoader)
from model import BertForReferringExpressionComprehension
from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import Struct
def main(opts):
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
opts.rank = rank
LOGGER.info(f"device: {device}, n_gpu: {n_gpu}, rank: {hvd.rank()}, "
f"16-bits training: {opts.fp16}")
hps_file = f'{opts.output_dir}/log/hps.json'
model_opts = json.load(open(hps_file))
if 'mlp' not in model_opts:
model_opts['mlp'] = 1
model_opts = Struct(model_opts)
# Prepro txt_dbs
txt_dbs = opts.txt_db.split(':')
# Prepro model
if exists(opts.checkpoint):
ckpt_file = torch.load(opts.checkpoint)
else:
ckpt_file = f'{opts.output_dir}/ckpt/model_epoch_{opts.checkpoint}.pt'
checkpoint = torch.load(ckpt_file)
bert_model = json.load(open(f'{txt_dbs[0]}/meta.json'))['bert']
model = BertForReferringExpressionComprehension.from_pretrained(
bert_model, img_dim=2048, mlp=model_opts.mlp,
state_dict=checkpoint
)
if model_opts.cut_bert != -1:
# cut some layers of BERT
model.bert.encoder.layer = torch.nn.ModuleList(
model.bert.encoder.layer[:opts.cut_bert]
)
model.to(device)
if opts.fp16:
model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')
# load DBs and image dirs
eval_img_dir = ReImageFeatDir(opts.img_dir)
for txt_db in txt_dbs:
print(f'Evaluating {txt_db}')
eval_dataset = ReferringExpressionEvalDataset(txt_db, eval_img_dir,
max_txt_len=-1)
eval_sampler = DistributedSampler(eval_dataset, num_replicas=n_gpu,
rank=rank, shuffle=False)
eval_dataloader = DataLoader(eval_dataset,
sampler=eval_sampler,
batch_size=opts.batch_size,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=re_eval_collate)
eval_dataloader = PrefetchLoader(eval_dataloader)
# evaluate
val_log, results = validate(model, eval_dataloader)
# save
result_dir = f'{opts.output_dir}/results_test'
if not exists(result_dir) and rank == 0:
os.makedirs(result_dir)
# dummy sync
_ = None
all_gather_list(_)
db_split = txt_db.split('/')[-1].split('-')[0] # refcoco+_val_large
img_dir = opts.img_dir.split('/')[-1] # visual_grounding_coco_gt
if n_gpu > 1:
with open(f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
f'_rank{rank}.json',
'w') as f:
json.dump(results, f)
# dummy sync
_ = None
all_gather_list(_)
# join results
if n_gpu > 1:
results = []
for rank in range(n_gpu):
results.extend(json.load(open(
f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
f'_rank{rank}.json')))
if rank == 0:
with open(f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_{db_split}_on_{img_dir}'
f'_all.json', 'w') as f:
json.dump(results, f)
# print
print(f'{opts.output_dir}/results_test')
@torch.no_grad()
def validate(model, val_dataloader):
LOGGER.info(f"start running evaluation.")
model.eval()
tot_score = 0
n_ex = 0
st = time()
predictions = []
for i, batch in enumerate(val_dataloader):
# inputs
(*batch_inputs, tgt_box_list, obj_boxes_list, sent_ids) = batch
# scores (n, max_num_bb)
scores = model(*batch_inputs, targets=None, compute_loss=False)
ixs = torch.argmax(scores, 1).cpu().detach().numpy() # (n, )
# pred_boxes
for ix, obj_boxes, tgt_box, sent_id in \
zip(ixs, obj_boxes_list, tgt_box_list, sent_ids):
pred_box = obj_boxes[ix]
predictions.append({'sent_id': sent_id,
'pred_box': pred_box.tolist(),
'tgt_box': tgt_box.tolist()})
if (val_dataloader.loader.dataset.computeIoU(pred_box, tgt_box)
> .5):
tot_score += 1
n_ex += 1
tot_time = time()-st
tot_score = sum(all_gather_list(tot_score))
n_ex = sum(all_gather_list(n_ex))
val_acc = tot_score / n_ex
val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time}
model.train()
LOGGER.info(f"validation ({n_ex} sents) finished in "
f"{int(tot_time)} seconds"
f", accuracy: {val_acc*100:.2f}%")
# summarizae
results = {'acc': val_acc, 'predictions': predictions}
return val_log, results
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Requited parameters
parser.add_argument('--txt_db',
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument('--img_dir',
default=None, type=str,
help="The input train images.")
parser.add_argument('--checkpoint',
default=None, type=str,
help="pretrained model (can take 'google-bert')")
parser.add_argument('--batch_size',
default=256, type=int,
help="number of sentences per batch")
parser.add_argument('--output_dir',
default=None, type=str,
help="The output directory where the model contains "
"the model checkpoints will be written.")
# Device parameters
parser.add_argument('--fp16',
action='store_true',
help="whether to use fp-16 float precision instead of "
"32 bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
args = parser.parse_args()
main(args)

268
uniter_model/eval_vcr.py

@ -0,0 +1,268 @@
"""run inference of VCR for submission"""
import argparse
import json
import os
from os.path import exists
from time import time
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from apex import amp
from horovod import torch as hvd
from data import (DetectFeatLmdb, VcrEvalDataset, vcr_eval_collate,
PrefetchLoader)
from torch.utils.data.distributed import DistributedSampler
from model import BertForVisualCommonsenseReasoning
from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import NoOp, Struct
NUM_SPECIAL_TOKENS = 81
def load_img_feat(dir_list, path2imgdir, opts):
dir_ = dir_list.split(";")
assert len(dir_) <= 2, "More than two img_dirs found"
img_dir_gt, img_dir = None, None
gt_dir_path, dir_path = "", ""
for d in dir_:
if "gt" in d:
gt_dir_path = d
else:
dir_path = d
if gt_dir_path != "":
img_dir_gt = path2imgdir.get(gt_dir_path, None)
if img_dir_gt is None:
img_dir_gt = DetectFeatLmdb(gt_dir_path, -1,
opts.max_bb, opts.min_bb, 100,
opts.compressed_db)
path2imgdir[gt_dir_path] = img_dir_gt
if dir_path != "":
img_dir = path2imgdir.get(dir_path, None)
if img_dir is None:
img_dir = DetectFeatLmdb(dir_path, opts.conf_th,
opts.max_bb, opts.min_bb, opts.num_bb,
opts.compressed_db)
path2imgdir[dir_path] = img_dir
return img_dir, img_dir_gt, path2imgdir
def main(opts):
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
opts.rank = rank
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), opts.fp16))
hps_file = f'{opts.output_dir}/log/hps.json'
model_opts = Struct(json.load(open(hps_file)))
path2imgdir = {}
# load DBs and image dirs
val_img_dir, val_img_dir_gt, path2imgdir = load_img_feat(
opts.img_dir, path2imgdir, model_opts)
eval_dataset = VcrEvalDataset("test", opts.txt_db,
val_img_dir_gt, val_img_dir,
max_txt_len=-1)
# Prepare model
bert_model = json.load(open(f'{opts.txt_db}/meta.json'))['bert']
model = BertForVisualCommonsenseReasoning.from_pretrained(
bert_model, img_dim=2048, obj_cls=False,
state_dict={})
model.init_type_embedding()
model.init_word_embedding(NUM_SPECIAL_TOKENS)
if exists(opts.checkpoint):
ckpt_file = opts.checkpoint
else:
ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
checkpoint = torch.load(ckpt_file)
state_dict = checkpoint.get('model_state', checkpoint)
matched_state_dict = {}
unexpected_keys = set()
missing_keys = set()
for name, param in model.named_parameters():
missing_keys.add(name)
for key, data in state_dict.items():
if key in missing_keys:
matched_state_dict[key] = data
missing_keys.remove(key)
else:
unexpected_keys.add(key)
print("Unexpected_keys:", list(unexpected_keys))
print("Missing_keys:", list(missing_keys))
model.load_state_dict(matched_state_dict, strict=False)
if model_opts.cut_bert != -1:
# cut some layers of BERT
model.bert.encoder.layer = torch.nn.ModuleList(
model.bert.encoder.layer[:model_opts.cut_bert])
model.to(device)
if opts.fp16:
model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')
sampler = DistributedSampler(
eval_dataset, num_replicas=n_gpu, rank=rank)
eval_dataloader = DataLoader(eval_dataset,
batch_size=opts.batch_size,
sampler=sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=vcr_eval_collate)
eval_dataloader = PrefetchLoader(eval_dataloader)
val_log, results = evaluate(model, eval_dataloader)
result_dir = f'{opts.output_dir}/results_{opts.split}'
if not exists(result_dir) and rank == 0:
os.makedirs(result_dir)
# dummy sync
_ = None
all_gather_list(_)
if n_gpu > 1:
with open(f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_rank{rank}.json',
'w') as f:
json.dump(results, f)
# dummy sync
_ = None
all_gather_list(_)
# join results
if n_gpu > 1:
results = []
for rank in range(n_gpu):
results.extend(json.load(open(
f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_rank{rank}.json')))
if rank == 0:
with open(f'{opts.output_dir}/results_test/'
f'results_{opts.checkpoint}_all.json', 'w') as f:
json.dump(results, f)
def compute_accuracies(out_qa, labels_qa, out_qar, labels_qar):
outputs_qa = out_qa.max(dim=-1)[1]
outputs_qar = out_qar.max(dim=-1)[1]
matched_qa = outputs_qa.squeeze() == labels_qa.squeeze()
matched_qar = outputs_qar.squeeze() == labels_qar.squeeze()
matched_joined = matched_qa & matched_qar
n_correct_qa = matched_qa.sum().item()
n_correct_qar = matched_qar.sum().item()
n_correct_joined = matched_joined.sum().item()
return n_correct_qa, n_correct_qar, n_correct_joined
@torch.no_grad()
def evaluate(model, val_loader):
if hvd.rank() == 0:
val_pbar = tqdm(total=len(val_loader))
else:
val_pbar = NoOp()
LOGGER.info(f"start running evaluation ...")
model.eval()
val_qa_loss, val_qar_loss = 0, 0
tot_qa_score, tot_qar_score, tot_score = 0, 0, 0
n_ex = 0
st = time()
results = {}
for i, batch in enumerate(val_loader):
qids, *inputs, qa_targets, qar_targets, _ = batch
scores = model(
*inputs, targets=None, compute_loss=False)
scores = scores.view(len(qids), -1)
if torch.max(qa_targets) > -1:
vcr_qa_loss = F.cross_entropy(
scores[:, :4], qa_targets.squeeze(-1), reduction="sum")
if scores.shape[1] > 8:
qar_scores = []
for batch_id in range(scores.shape[0]):
answer_ind = qa_targets[batch_id].item()
qar_index = [4+answer_ind*4+i
for i in range(4)]
qar_scores.append(scores[batch_id, qar_index])
qar_scores = torch.stack(qar_scores, dim=0)
else:
qar_scores = scores[:, 4:]
vcr_qar_loss = F.cross_entropy(
qar_scores, qar_targets.squeeze(-1), reduction="sum")
val_qa_loss += vcr_qa_loss.item()
val_qar_loss += vcr_qar_loss.item()
curr_qa_score, curr_qar_score, curr_score = compute_accuracies(
scores[:, :4], qa_targets, qar_scores, qar_targets)
tot_qar_score += curr_qar_score
tot_qa_score += curr_qa_score
tot_score += curr_score
for qid, score in zip(qids, scores):
results[qid] = score.cpu().tolist()
n_ex += len(qids)
val_pbar.update(1)
val_qa_loss = sum(all_gather_list(val_qa_loss))
val_qar_loss = sum(all_gather_list(val_qar_loss))
tot_qa_score = sum(all_gather_list(tot_qa_score))
tot_qar_score = sum(all_gather_list(tot_qar_score))
tot_score = sum(all_gather_list(tot_score))
n_ex = sum(all_gather_list(n_ex))
tot_time = time()-st
val_qa_loss /= n_ex
val_qar_loss /= n_ex
val_qa_acc = tot_qa_score / n_ex
val_qar_acc = tot_qar_score / n_ex
val_acc = tot_score / n_ex
val_log = {f'valid/vcr_qa_loss': val_qa_loss,
f'valid/vcr_qar_loss': val_qar_loss,
f'valid/acc_qa': val_qa_acc,
f'valid/acc_qar': val_qar_acc,
f'valid/acc': val_acc,
f'valid/ex_per_s': n_ex/tot_time}
model.train()
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"score_qa: {val_qa_acc*100:.2f} "
f"score_qar: {val_qar_acc*100:.2f} "
f"score: {val_acc*100:.2f} ")
return val_log, results
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--txt_db",
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--img_dir",
default=None, type=str,
help="The input train images.")
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--split",
default="test", type=str,
help="The input split")
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument("--batch_size",
default=10, type=int,
help="number of tokens in a batch")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
# device parameters
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
args = parser.parse_args()
main(args)

180
uniter_model/eval_vqa.py

@ -0,0 +1,180 @@
"""run inference of VQA for submission"""
import argparse
import json
import os
from os.path import exists
from time import time
import torch
from torch.utils.data import DataLoader
from apex import amp
from horovod import torch as hvd
import numpy as np
from cytoolz import concat
from data import (TokenBucketSampler, PrefetchLoader,
DetectFeatLmdb, TxtTokLmdb, VqaEvalDataset, vqa_eval_collate)
from model import UniterForVisualQuestionAnswering
from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import Struct
from utils.const import BUCKET_SIZE, IMG_DIM
def main(opts):
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), opts.fp16))
hps_file = f'{opts.output_dir}/log/hps.json'
model_opts = Struct(json.load(open(hps_file)))
# train_examples = None
ans2label_file = f'{opts.output_dir}/ckpt/ans2label.json'
ans2label = json.load(open(ans2label_file))
label2ans = {label: ans for ans, label in ans2label.items()}
# load DBs and image dirs
eval_img_db = DetectFeatLmdb(opts.img_db,
model_opts.conf_th, model_opts.max_bb,
model_opts.min_bb, model_opts.num_bb,
opts.compressed_db)
eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db)
# Prepare model
if exists(opts.checkpoint):
ckpt_file = opts.checkpoint
else:
ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
checkpoint = torch.load(ckpt_file)
model = UniterForVisualQuestionAnswering.from_pretrained(
f'{opts.output_dir}/log/model.json', checkpoint,
img_dim=IMG_DIM, num_answer=len(ans2label))
model.to(device)
model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')
sampler = TokenBucketSampler(eval_dataset.lens, bucket_size=BUCKET_SIZE,
batch_size=opts.batch_size, droplast=False)
eval_dataloader = DataLoader(eval_dataset,
batch_sampler=sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=vqa_eval_collate)
eval_dataloader = PrefetchLoader(eval_dataloader)
val_log, results, logits = evaluate(model, eval_dataloader, label2ans,
opts.save_logits)
result_dir = f'{opts.output_dir}/results_test'
if not exists(result_dir) and rank == 0:
os.makedirs(result_dir)
all_results = list(concat(all_gather_list(results)))
if opts.save_logits:
all_logits = {}
for id2logit in all_gather_list(logits):
all_logits.update(id2logit)
if hvd.rank() == 0:
with open(f'{result_dir}/'
f'results_{opts.checkpoint}_all.json', 'w') as f:
json.dump(all_results, f)
if opts.save_logits:
np.savez(f'{result_dir}/logits_{opts.checkpoint}_all.npz',
**all_logits)
@torch.no_grad()
def evaluate(model, eval_loader, label2ans, save_logits=False):
LOGGER.info("start running evaluation...")
model.eval()
n_ex = 0
st = time()
results = []
logits = {}
for i, batch in enumerate(eval_loader):
qids = batch['qids']
scores = model(batch, compute_loss=False)
answers = [label2ans[i]
for i in scores.max(dim=-1, keepdim=False
)[1].cpu().tolist()]
for qid, answer in zip(qids, answers):
results.append({'answer': answer, 'question_id': int(qid)})
if save_logits:
scores = scores.cpu()
for i, qid in enumerate(qids):
logits[qid] = scores[i].half().numpy()
if i % 100 == 0 and hvd.rank() == 0:
n_results = len(results)
n_results *= hvd.size() # an approximation to avoid hangs
LOGGER.info(f'{n_results}/{len(eval_loader.dataset)} '
'answers predicted')
n_ex += len(qids)
n_ex = sum(all_gather_list(n_ex))
tot_time = time()-st
val_log = {'valid/ex_per_s': n_ex/tot_time}
model.train()
LOGGER.info(f"evaluation finished in {int(tot_time)} seconds "
f"at {int(n_ex/tot_time)} examples per second")
return val_log, results, logits
def compute_score_with_logits(logits, labels):
logits = torch.max(logits, 1)[1] # argmax
one_hots = torch.zeros(*labels.size(), device=labels.device)
one_hots.scatter_(1, logits.view(-1, 1), 1)
scores = (one_hots * labels)
return scores
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--txt_db",
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--img_db",
default=None, type=str,
help="The input train images.")
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument("--batch_size",
default=8192, type=int,
help="number of tokens in a batch")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
parser.add_argument("--save_logits", action='store_true',
help="Whether to save logits (for making ensemble)")
# Prepro parameters
# device parameters
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
args = parser.parse_args()
# options safe guard
# TODO
main(args)

71
uniter_model/experiments/ablation_refcoco+.sh

@ -0,0 +1,71 @@
# Supports ablation study of the follows:
# 1) scratch
# 2) bert
# 3) mrfr
# 4) mlm
# 5) itm
# 6) mlm_itm
# 7) mlm_mrfr_itm
# 8) mlm_mrc_itm
# 9) mlm_mrckl_itm
# 10) mlm_mrfr_mrc_itm
# 11) mlm_mrfr_mrckl_itm
# 12) mlm_mrfr_mrckl_itm_jrm
# 13) mlm_mrfr_mrckl_itm_jrm+
ablation_pretrained_model=$1
case $ablation_pretrained_model in
scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+)
echo running $ablation_pretrained_model ...;;
*)
echo "$ablation_pretrained_model" not supported.;
exit 1;
esac
if [ "$ablation_pretrained_model" == "mrfr" ]; then
cut_bert=1
else
cut_bert=-1
fi
case $ablation_pretrained_model in
scratch)
cut_bert=1;
checkpoint="scratch";;
bert)
cut_bert=1;
checkpoint="google-bert";;
mrfr)
cut_bert=1;
checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;;
*)
cut_bert=-1;
checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;;
esac
horovodrun -np 1 -H localhost:1 \
python train_re.py \
--train_txt_db /db/refcoco+_train_base-cased.db \
--train_img_dir /img/visual_grounding_coco_gt \
--val_txt_db /db/refcoco+_val_base-cased.db \
--val_img_dir /img/visual_grounding_det_coco \
--checkpoint ${checkpoint} \
--cut_bert ${cut_bert} \
--output_dir /storage/refcoco+/ablation_"${ablation_pretrained_model}" \
--max_txt_len 60 \
--train_batch_size 128 \
--val_batch_size 128 \
--learning_rate 8e-5 \
--optim adamw \
--betas 0.9 0.98 \
--weight_decay 0.01 \
--dropout 0.1 \
--grad_norm 2.0 \
--decay linear \
--num_train_steps 24000 \
--warmup_steps 1500 \
--gradient_accumulation_steps 1 \
--seed 24 \
--mlp 1 \
--fp16

38
uniter_model/experiments/eval_ablation_refcoco+.sh

@ -0,0 +1,38 @@
# Supports ablation study of the follows:
# 1) scratch
# 2) bert
# 3) mrfr
# 4) mlm
# 5) itm
# 6) mlm_itm
# 7) mlm_mrfr_itm
# 8) mlm_mrc_itm
# 9) mlm_mrckl_itm
# 10) mlm_mrfr_mrc_itm
# 11) mlm_mrfr_mrckl_itm
# 12) mlm_mrfr_mrckl_itm_jrm
# 13) mlm_mrfr_mrckl_itm_jrm+
ablation_pretrained_model=$1
case $ablation_pretrained_model in
scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+)
echo running $ablation_pretrained_model ...;;
*)
echo "$ablation_pretrained_model" not supported.;
exit 1;
esac
horovodrun -np 1 -H localhost:1 \
python eval_re.py \
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
--img_dir /img/visual_grounding_coco_gt \
--output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \
--checkpoint best
horovodrun -np 1 -H localhost:1 \
python eval_re.py \
--txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \
--img_dir /img/visual_grounding_det_coco \
--output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \
--checkpoint best

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save