From 4059059c1bf79890d5c23a82451e54b74022913f Mon Sep 17 00:00:00 2001 From: wxywb Date: Wed, 15 Jun 2022 20:57:23 +0800 Subject: [PATCH] update the operator. Signed-off-by: wxywb --- __init__.py | 4 +- config/bert_base.json | 19 + config/coco_eval_config.json | 20 + config/coco_ft_config.json | 43 + config/flickr30k_eval_config.json | 22 + config/flickr30k_ft_config.json | 38 + config/img_base.json | 16 + config/pretrain-alldata-base.json | 191 ++++ data/model/resnet101_faster_rcnn_final.pth | 3 + data/model/uniter-base.pt | 3 + detector/__init__.py | 0 detector/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 176 bytes .../__pycache__/bbox_transform.cpython-38.pyc | Bin 0 -> 2245 bytes .../__pycache__/faster_rcnn.cpython-38.pyc | Bin 0 -> 20477 bytes .../generate_anchors.cpython-38.pyc | Bin 0 -> 2219 bytes detector/__pycache__/rpn.cpython-38.pyc | Bin 0 -> 3064 bytes detector/bbox_transform.py | 75 ++ detector/faster_rcnn.py | 478 ++++++++++ detector/generate_anchors.py | 105 +++ detector/rpn.py | 136 +++ dvl/__pycache__/const.cpython-38.pyc | Bin 0 -> 233 bytes dvl/const.py | 3 + dvl/data/itm.py | 366 ++++++++ dvl/data/itm_pre.py | 592 +++++++++++++ dvl/data/mlm.py | 390 ++++++++ dvl/data/mrm.py | 263 ++++++ dvl/data/vqa.py | 145 +++ dvl/hn.py | 66 ++ dvl/indexer/faiss_indexers.py | 154 ++++ dvl/models/__init__.py | 0 .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 178 bytes .../__pycache__/bi_encoder.cpython-38.pyc | Bin 0 -> 22206 bytes dvl/models/bi_encoder.py | 757 ++++++++++++++++ dvl/options.py | 176 ++++ dvl/trainer.py | 209 +++++ dvl/utils.py | 234 +++++ lightningdot.py | 52 +- requirements.txt | 3 +- uniter_model/Dockerfile | 22 + uniter_model/LICENSE | 21 + uniter_model/README.md | 89 ++ uniter_model/config/config-vcr-bert-2gpu.json | 36 + uniter_model/config/eval-itm-coco.json | 11 + uniter_model/config/eval-itm-flickr.json | 11 + uniter_model/config/hps-itm.json | 53 ++ uniter_model/config/hps-refcoco+.json | 25 + .../config/hps-refcoco+_conceptual.json | 26 + .../hps-refcoco+_conceptual_large_weak.json | 26 + .../config/hps-refcoco+_conceptual_rank.json | 29 + uniter_model/config/hps-refcoco.json | 26 + uniter_model/config/hps-ve-large.json | 31 + uniter_model/config/hps-ve.json | 31 + uniter_model/config/hps-vqa.json | 30 + uniter_model/config/itm-coco-base.json | 47 + uniter_model/config/itm-ot-base-16gpus.json | 45 + .../config/itm-ot-base-16gpus_philly.json | 45 + uniter_model/config/pretrain-gqa-alltask.json | 42 + uniter_model/config/pretrain-mlm-coco.json | 42 + ...in-mlm_itmot_mrfr_mrckl-indomain-base.json | 53 ++ uniter_model/config/pretrain-mrckl-coco.json | 42 + uniter_model/config/pretrain-mrfr-coco.json | 42 + .../config/pretrain-mrm-nce-coco.json | 43 + uniter_model/config/pretrain-vcr-alltask.json | 38 + uniter_model/config/train-itm-debug.json | 40 + .../config/train-itm-flickr-base-hnv2.json | 38 + .../config/train-itm-flickr-base.json | 40 + .../config/train-nlvr2-base-1gpu.json | 37 + uniter_model/config/train-ve-base-2gpu.json | 31 + uniter_model/config/train-ve-large-2gpu.json | 31 + uniter_model/config/train-vqa-base-2gpu.json | 35 + uniter_model/config/uniter-base.json | 14 + uniter_model/config/uniter-large.json | 13 + uniter_model/data/__init__.py | 27 + uniter_model/data/data.py | 283 ++++++ uniter_model/data/itm.py | 572 ++++++++++++ uniter_model/data/loader.py | 138 +++ uniter_model/data/mlm.py | 360 ++++++++ uniter_model/data/mrm.py | 287 ++++++ uniter_model/data/mrm_nce.py | 136 +++ uniter_model/data/nlvr2.py | 218 +++++ 
uniter_model/data/re.py | 319 +++++++ uniter_model/data/sampler.py | 116 +++ uniter_model/data/test_data/input0.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input1.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input2.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input3.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input4.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input5.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input6.txt | Bin 0 -> 92849 bytes uniter_model/data/test_data/input7.txt | Bin 0 -> 92849 bytes uniter_model/data/vcr.py | 725 +++++++++++++++ uniter_model/data/ve.py | 19 + uniter_model/data/vqa.py | 124 +++ uniter_model/eval/itm.py | 53 ++ uniter_model/eval/nlvr2.py | 62 ++ uniter_model/eval_re.py | 218 +++++ uniter_model/eval_vcr.py | 268 ++++++ uniter_model/eval_vqa.py | 180 ++++ uniter_model/experiments/ablation_refcoco+.sh | 71 ++ .../experiments/eval_ablation_refcoco+.sh | 38 + uniter_model/experiments/eval_refcoco+.sh | 13 + .../eval_refcoco+_base_mlm_itm_mrfr_cc.sh | 15 + ...al_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh | 17 + ..._refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh | 33 + ...refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh | 36 + .../experiments/eval_refcoco+_conceptual.sh | 15 + .../experiments/eval_refcoco+_large.sh | 14 + ...eval_refcoco+_large_mlm_itm_mrfr_cocovg.sh | 13 + ...l_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh | 26 + .../eval_refer_base_mlm_itm_mrfr_mrckl_all.sh | 38 + ...eval_refer_large_mlm_itm_mrfr_mrckl_all.sh | 38 + uniter_model/experiments/train_refcoco+.sh | 30 + .../train_refcoco+_base_mlm_itm_mrfr_cc.sh | 27 + ...in_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh | 44 + ..._refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh | 73 ++ ...refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh | 64 ++ .../train_refcoco+_conceptual_rank.sh | 2 + ...rain_refcoco+_large_mlm_itm_mrfr_cocovg.sh | 28 + ...n_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh | 70 ++ uniter_model/experiments/train_refcoco.sh | 2 + ...train_refer_base_mlm_itm_mrfr_mrckl_all.sh | 67 ++ ...rain_refer_large_mlm_itm_mrfr_mrckl_all.sh | 67 ++ uniter_model/format_vcr_predictions.py | 54 ++ uniter_model/inf_itm.py | 155 ++++ uniter_model/launch_container.sh | 25 + uniter_model/launch_container_dist.sh | 28 + uniter_model/misc/ans2label.json | 1 + uniter_model/model/__init__.py | 8 + .../model/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 630 bytes .../__pycache__/attention.cpython-38.pyc | Bin 0 -> 11790 bytes .../model/__pycache__/itm.cpython-38.pyc | Bin 0 -> 6182 bytes .../model/__pycache__/layer.cpython-38.pyc | Bin 0 -> 8270 bytes .../model/__pycache__/model.cpython-38.pyc | Bin 0 -> 20954 bytes .../model/__pycache__/nlvr2.cpython-38.pyc | Bin 0 -> 5848 bytes .../model/__pycache__/ot.cpython-38.pyc | Bin 0 -> 2590 bytes .../model/__pycache__/ve.cpython-38.pyc | Bin 0 -> 715 bytes .../model/__pycache__/vqa.cpython-38.pyc | Bin 0 -> 1931 bytes uniter_model/model/attention.py | 401 +++++++++ uniter_model/model/gqa.py | 133 +++ uniter_model/model/itm.py | 195 ++++ uniter_model/model/layer.py | 235 +++++ uniter_model/model/model.py | 701 +++++++++++++++ uniter_model/model/nlvr2.py | 182 ++++ uniter_model/model/ot.py | 82 ++ uniter_model/model/re.py | 140 +++ uniter_model/model/vcr.py | 287 ++++++ uniter_model/model/ve.py | 11 + uniter_model/model/vqa.py | 49 + uniter_model/optim/__init__.py | 2 + uniter_model/optim/adamw.py | 103 +++ uniter_model/optim/misc.py | 32 + uniter_model/optim/sched.py | 52 ++ uniter_model/prepro.py | 750 ++++++++++++++++ 
uniter_model/pretrain.py | 834 ++++++++++++++++++ uniter_model/pretrain_vcr.py | 754 ++++++++++++++++ uniter_model/requirements.txt | 3 + uniter_model/scripts/compress_lmdb.py | 53 ++ uniter_model/scripts/compute_numbb.py | 71 ++ uniter_model/scripts/convert_gqa.py | 62 ++ uniter_model/scripts/convert_imgdir.py | 139 +++ uniter_model/scripts/download_bert.py | 12 + uniter_model/scripts/download_bert.sh | 9 + uniter_model/scripts/install_horovod.sh | 42 + uniter_model/scripts/map_iid_to_ann_ids.py | 118 +++ uniter_model/scripts/map_vg_vqa_img.py | 67 ++ uniter_model/scripts/prepro_all.sh | 186 ++++ uniter_model/scripts/prepro_gqa.sh | 12 + uniter_model/scripts/prepro_iid_to_dets.py | 45 + uniter_model/scripts/prepro_re.sh | 44 + uniter_model/scripts/split_annotations.py | 57 ++ .../scripts/split_coco_pretrain_vqa.py | 57 ++ uniter_model/scripts/split_vqa_val.py | 64 ++ uniter_model/tests/generate_test_data.py | 97 ++ uniter_model/tests/test_distributed_fa.py | 126 +++ uniter_model/tests/test_hvd_fa.py | 118 +++ uniter_model/train_itm.py | 599 +++++++++++++ uniter_model/train_itm_v2.py | 499 +++++++++++ uniter_model/train_nlvr2.py | 418 +++++++++ uniter_model/train_re.py | 460 ++++++++++ uniter_model/train_vcr.py | 604 +++++++++++++ uniter_model/train_ve.py | 413 +++++++++ uniter_model/train_vqa.py | 415 +++++++++ uniter_model/utils/__init__.py | 0 uniter_model/utils/const.py | 4 + uniter_model/utils/distributed.py | 230 +++++ uniter_model/utils/itm.py | 62 ++ uniter_model/utils/logger.py | 91 ++ uniter_model/utils/misc.py | 67 ++ uniter_model/utils/save.py | 76 ++ uniter_model/utils/visual_entailment.py | 46 + uniter_model/utils/vqa.py | 203 +++++ utils.py | 8 + 192 files changed, 21477 insertions(+), 8 deletions(-) create mode 100644 config/bert_base.json create mode 100644 config/coco_eval_config.json create mode 100644 config/coco_ft_config.json create mode 100644 config/flickr30k_eval_config.json create mode 100644 config/flickr30k_ft_config.json create mode 100644 config/img_base.json create mode 100644 config/pretrain-alldata-base.json create mode 100644 data/model/resnet101_faster_rcnn_final.pth create mode 100644 data/model/uniter-base.pt create mode 100644 detector/__init__.py create mode 100644 detector/__pycache__/__init__.cpython-38.pyc create mode 100644 detector/__pycache__/bbox_transform.cpython-38.pyc create mode 100644 detector/__pycache__/faster_rcnn.cpython-38.pyc create mode 100644 detector/__pycache__/generate_anchors.cpython-38.pyc create mode 100644 detector/__pycache__/rpn.cpython-38.pyc create mode 100644 detector/bbox_transform.py create mode 100644 detector/faster_rcnn.py create mode 100644 detector/generate_anchors.py create mode 100644 detector/rpn.py create mode 100644 dvl/__pycache__/const.cpython-38.pyc create mode 100644 dvl/const.py create mode 100644 dvl/data/itm.py create mode 100644 dvl/data/itm_pre.py create mode 100644 dvl/data/mlm.py create mode 100644 dvl/data/mrm.py create mode 100644 dvl/data/vqa.py create mode 100644 dvl/hn.py create mode 100644 dvl/indexer/faiss_indexers.py create mode 100644 dvl/models/__init__.py create mode 100644 dvl/models/__pycache__/__init__.cpython-38.pyc create mode 100644 dvl/models/__pycache__/bi_encoder.cpython-38.pyc create mode 100644 dvl/models/bi_encoder.py create mode 100644 dvl/options.py create mode 100644 dvl/trainer.py create mode 100644 dvl/utils.py create mode 100644 uniter_model/Dockerfile create mode 100644 uniter_model/LICENSE create mode 100644 uniter_model/README.md create mode 100644 
uniter_model/config/config-vcr-bert-2gpu.json create mode 100644 uniter_model/config/eval-itm-coco.json create mode 100644 uniter_model/config/eval-itm-flickr.json create mode 100644 uniter_model/config/hps-itm.json create mode 100644 uniter_model/config/hps-refcoco+.json create mode 100644 uniter_model/config/hps-refcoco+_conceptual.json create mode 100644 uniter_model/config/hps-refcoco+_conceptual_large_weak.json create mode 100644 uniter_model/config/hps-refcoco+_conceptual_rank.json create mode 100644 uniter_model/config/hps-refcoco.json create mode 100644 uniter_model/config/hps-ve-large.json create mode 100644 uniter_model/config/hps-ve.json create mode 100644 uniter_model/config/hps-vqa.json create mode 100644 uniter_model/config/itm-coco-base.json create mode 100644 uniter_model/config/itm-ot-base-16gpus.json create mode 100644 uniter_model/config/itm-ot-base-16gpus_philly.json create mode 100644 uniter_model/config/pretrain-gqa-alltask.json create mode 100644 uniter_model/config/pretrain-mlm-coco.json create mode 100644 uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json create mode 100644 uniter_model/config/pretrain-mrckl-coco.json create mode 100644 uniter_model/config/pretrain-mrfr-coco.json create mode 100644 uniter_model/config/pretrain-mrm-nce-coco.json create mode 100644 uniter_model/config/pretrain-vcr-alltask.json create mode 100644 uniter_model/config/train-itm-debug.json create mode 100644 uniter_model/config/train-itm-flickr-base-hnv2.json create mode 100644 uniter_model/config/train-itm-flickr-base.json create mode 100644 uniter_model/config/train-nlvr2-base-1gpu.json create mode 100644 uniter_model/config/train-ve-base-2gpu.json create mode 100644 uniter_model/config/train-ve-large-2gpu.json create mode 100644 uniter_model/config/train-vqa-base-2gpu.json create mode 100644 uniter_model/config/uniter-base.json create mode 100644 uniter_model/config/uniter-large.json create mode 100644 uniter_model/data/__init__.py create mode 100644 uniter_model/data/data.py create mode 100644 uniter_model/data/itm.py create mode 100644 uniter_model/data/loader.py create mode 100644 uniter_model/data/mlm.py create mode 100644 uniter_model/data/mrm.py create mode 100644 uniter_model/data/mrm_nce.py create mode 100644 uniter_model/data/nlvr2.py create mode 100644 uniter_model/data/re.py create mode 100644 uniter_model/data/sampler.py create mode 100644 uniter_model/data/test_data/input0.txt create mode 100644 uniter_model/data/test_data/input1.txt create mode 100644 uniter_model/data/test_data/input2.txt create mode 100644 uniter_model/data/test_data/input3.txt create mode 100644 uniter_model/data/test_data/input4.txt create mode 100644 uniter_model/data/test_data/input5.txt create mode 100644 uniter_model/data/test_data/input6.txt create mode 100644 uniter_model/data/test_data/input7.txt create mode 100644 uniter_model/data/vcr.py create mode 100644 uniter_model/data/ve.py create mode 100644 uniter_model/data/vqa.py create mode 100644 uniter_model/eval/itm.py create mode 100644 uniter_model/eval/nlvr2.py create mode 100644 uniter_model/eval_re.py create mode 100644 uniter_model/eval_vcr.py create mode 100644 uniter_model/eval_vqa.py create mode 100644 uniter_model/experiments/ablation_refcoco+.sh create mode 100644 uniter_model/experiments/eval_ablation_refcoco+.sh create mode 100644 uniter_model/experiments/eval_refcoco+.sh create mode 100644 uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_cc.sh create mode 100644 
uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh create mode 100644 uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh create mode 100644 uniter_model/experiments/eval_refcoco+_conceptual.sh create mode 100644 uniter_model/experiments/eval_refcoco+_large.sh create mode 100644 uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_cocovg.sh create mode 100644 uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/eval_refer_base_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/eval_refer_large_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/train_refcoco+.sh create mode 100644 uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_cc.sh create mode 100644 uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh create mode 100644 uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh create mode 100644 uniter_model/experiments/train_refcoco+_conceptual_rank.sh create mode 100644 uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_cocovg.sh create mode 100644 uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/train_refcoco.sh create mode 100644 uniter_model/experiments/train_refer_base_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/experiments/train_refer_large_mlm_itm_mrfr_mrckl_all.sh create mode 100644 uniter_model/format_vcr_predictions.py create mode 100644 uniter_model/inf_itm.py create mode 100644 uniter_model/launch_container.sh create mode 100644 uniter_model/launch_container_dist.sh create mode 100644 uniter_model/misc/ans2label.json create mode 100644 uniter_model/model/__init__.py create mode 100644 uniter_model/model/__pycache__/__init__.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/attention.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/itm.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/layer.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/model.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/nlvr2.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/ot.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/ve.cpython-38.pyc create mode 100644 uniter_model/model/__pycache__/vqa.cpython-38.pyc create mode 100644 uniter_model/model/attention.py create mode 100644 uniter_model/model/gqa.py create mode 100644 uniter_model/model/itm.py create mode 100644 uniter_model/model/layer.py create mode 100644 uniter_model/model/model.py create mode 100644 uniter_model/model/nlvr2.py create mode 100644 uniter_model/model/ot.py create mode 100644 uniter_model/model/re.py create mode 100644 uniter_model/model/vcr.py create mode 100644 uniter_model/model/ve.py create mode 100644 uniter_model/model/vqa.py create mode 100644 uniter_model/optim/__init__.py create mode 100644 uniter_model/optim/adamw.py create mode 100644 uniter_model/optim/misc.py create mode 100644 uniter_model/optim/sched.py create mode 100644 uniter_model/prepro.py create mode 100644 uniter_model/pretrain.py create mode 100644 uniter_model/pretrain_vcr.py create mode 100644 uniter_model/requirements.txt create mode 100644 uniter_model/scripts/compress_lmdb.py create mode 
100644 uniter_model/scripts/compute_numbb.py create mode 100644 uniter_model/scripts/convert_gqa.py create mode 100644 uniter_model/scripts/convert_imgdir.py create mode 100644 uniter_model/scripts/download_bert.py create mode 100644 uniter_model/scripts/download_bert.sh create mode 100644 uniter_model/scripts/install_horovod.sh create mode 100644 uniter_model/scripts/map_iid_to_ann_ids.py create mode 100644 uniter_model/scripts/map_vg_vqa_img.py create mode 100644 uniter_model/scripts/prepro_all.sh create mode 100644 uniter_model/scripts/prepro_gqa.sh create mode 100644 uniter_model/scripts/prepro_iid_to_dets.py create mode 100644 uniter_model/scripts/prepro_re.sh create mode 100644 uniter_model/scripts/split_annotations.py create mode 100644 uniter_model/scripts/split_coco_pretrain_vqa.py create mode 100644 uniter_model/scripts/split_vqa_val.py create mode 100644 uniter_model/tests/generate_test_data.py create mode 100644 uniter_model/tests/test_distributed_fa.py create mode 100644 uniter_model/tests/test_hvd_fa.py create mode 100644 uniter_model/train_itm.py create mode 100644 uniter_model/train_itm_v2.py create mode 100644 uniter_model/train_nlvr2.py create mode 100644 uniter_model/train_re.py create mode 100644 uniter_model/train_vcr.py create mode 100644 uniter_model/train_ve.py create mode 100644 uniter_model/train_vqa.py create mode 100644 uniter_model/utils/__init__.py create mode 100644 uniter_model/utils/const.py create mode 100644 uniter_model/utils/distributed.py create mode 100644 uniter_model/utils/itm.py create mode 100644 uniter_model/utils/logger.py create mode 100644 uniter_model/utils/misc.py create mode 100644 uniter_model/utils/save.py create mode 100644 uniter_model/utils/visual_entailment.py create mode 100644 uniter_model/utils/vqa.py diff --git a/__init__.py b/__init__.py index 4e32104..139a180 100644 --- a/__init__.py +++ b/__init__.py @@ -14,5 +14,5 @@ from .lightningdot import LightningDOT -def lightningdot(modality: str): - return LightningDOT(modality) +def lightningdot(model_name: str, modality: str): + return LightningDOT(model_name, modality) diff --git a/config/bert_base.json b/config/bert_base.json new file mode 100644 index 0000000..cac0c19 --- /dev/null +++ b/config/bert_base.json @@ -0,0 +1,19 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30522 +} \ No newline at end of file diff --git a/config/coco_eval_config.json b/config/coco_eval_config.json new file mode 100644 index 0000000..5272889 --- /dev/null +++ b/config/coco_eval_config.json @@ -0,0 +1,20 @@ +{ + "img_model_type": "uniter-base", + "txt_model_type": "bert-base", + "txt_model_config": "bert-base-cased", + "img_model_config": "./config/img_base.json", + "itm_global_file":"./data/meta/coco_meta.json", + "seed": 42, + "output_dir": "/storage/debug-eval", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "project_dim": 768, + "test_txt_db": "./data/db/itm_coco_test_base-cased.db", + "test_img_db": "./data/img/coco_val2014/", + "project_name": "itm-debug", + "n_workers": 4, + "fp16": true +} diff --git a/config/coco_ft_config.json b/config/coco_ft_config.json new file mode 100644 
index 0000000..b5e246e --- /dev/null +++ b/config/coco_ft_config.json @@ -0,0 +1,43 @@ +{ + "txt_model_type": "bert-base", + "txt_model_config": "bert-base-cased", + "img_model_type": "uniter-base", + "img_model_config": "./config/img_base.json", + "img_checkpoint": "./data/model/uniter-base.pt", + "itm_global_file":"./data/meta/coco_meta.json", + "train_batch_size": 64, + "val_batch_size": 256, + "gradient_accumulation_steps": 1, + "learning_rate": 2e-05, + "warmup_steps": 100, + "valid_steps": 500, + "num_train_epochs": 20, + "seed": 42, + "output_dir": "/storage/debug_coco", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "project_dim": 768, + "train_txt_dbs": [ + "./data/db/itm_coco_train_base-cased.db", + "./data/db/itm_coco_restval_base-cased.db" + ], + "train_img_dbs": [ + "./data/img/coco_train2014/", + "./data/img/coco_val2014" + ], + "val_txt_db": "./data/db/itm_coco_val_base-cased.db", + "val_img_db": "./data/img/coco_val2014/", + "test_txt_db": "./data/db/itm_coco_test_base-cased.db", + "test_img_db": "./data/img/coco_val2014/", + "project_name": "itm-debug", + "num_hard_negatives": 0, + "hard_negatives_sampling": "none", + "inf_minibatch_size": 0, + "n_workers": 0, + "fp16": true, + "compressed_db": false, + "pin_mem": true +} \ No newline at end of file diff --git a/config/flickr30k_eval_config.json b/config/flickr30k_eval_config.json new file mode 100644 index 0000000..4d2ec98 --- /dev/null +++ b/config/flickr30k_eval_config.json @@ -0,0 +1,22 @@ +{ + "img_model_type": "uniter-base", + "txt_model_type": "bert-base", + "txt_model_config": "bert-base-cased", + "img_model_config": "./config/img_base.json", + "itm_global_file":"./data/meta/flickr_meta.json", + "seed": 42, + "output_dir": "/storage/debug-eval", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "project_dim": 768, + "val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "./data/img/flickr30k/", + "test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "./data/img/flickr30k/", + "project_name": "itm-debug", + "n_workers": 4, + "fp16": true +} diff --git a/config/flickr30k_ft_config.json b/config/flickr30k_ft_config.json new file mode 100644 index 0000000..ec9e77a --- /dev/null +++ b/config/flickr30k_ft_config.json @@ -0,0 +1,38 @@ +{ + "txt_model_type": "bert-base", + "txt_model_config": "bert-base-cased", + "img_model_type": "uniter-base", + "img_model_config": "./config/img_base.json", + "img_checkpoint": "./data/model/uniter-base.pt", + "itm_global_file":"./data/meta/flickr_meta.json", + "train_batch_size": 64, + "gradient_accumulation_steps": 1, + "learning_rate": 2e-05, + "warmup_steps": 100, + "valid_steps": 500, + "num_train_epochs": 15, + "seed": 42, + "output_dir": "/storage/debug_flickr", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "project_dim": 768, + "train_txt_dbs": [ + "./data/db/itm_flickr30k_train_base-cased.db" + ], + "train_img_dbs": [ + "./data/img/flickr30k/" + ], + "val_txt_db": "./data/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "./data/img/flickr30k/", + "test_txt_db": "./data/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "./data/img/flickr30k/", + "project_name": "itm-debug", + "num_hard_negatives": 0, + "hard_negatives_sampling": "none", + "inf_minibatch_size": 0, + "n_workers": 0, + "fp16": true +} \ No newline at end of file diff --git a/config/img_base.json b/config/img_base.json new file mode 100644 
index 0000000..8d111ea --- /dev/null +++ b/config/img_base.json @@ -0,0 +1,16 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 28996, + "output_hidden_states": false +} diff --git a/config/pretrain-alldata-base.json b/config/pretrain-alldata-base.json new file mode 100644 index 0000000..deeead6 --- /dev/null +++ b/config/pretrain-alldata-base.json @@ -0,0 +1,191 @@ +{ + "compressed_db": false, + "txt_model_type": "bert-base", + "txt_model_config": "bert-base-cased", + "img_model_type": "uniter-base", + "img_model_config": "./config/img_base.json", + "model_config": "./config/img_base.json", + "output_dir": "/storage/pretrain/alltask_ot_alldata_base", + "project_dim": 768, + "mrm_prob": 0.15, + "neg_size": 128, + "nce_temp": 1.0, + "itm_neg_prob": 0.0, + "itm_ot_lambda": 0.0, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 10240, + "val_batch_size": 10240, + "gradient_accumulation_steps": 6, + "learning_rate": 5e-05, + "valid_steps": 10000, + "num_train_steps": 300000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 5.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "n_workers": 3, + "pin_mem": true, + "train_datasets": [ + { + "name": "coco_cap", + "db": [ + "./data/db/pretrain_caption_coco_train_base-cased.db/", + "./data/db/pretrain_caption_coco_trainval_base-cased.db/" + ], + "img": [ + "./data/img/coco_train2014/", + "./data/img/coco_val2014/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ], + "mix_ratio": [ + 16, + 8, + 4, + 4 + ] + }, + { + "name": "vg_cap", + "db": [ + "./data/db/pretrain_caption_vg_train_base-cased.db/" + ], + "img": [ + "./data/img/vg/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ], + "mix_ratio": [ + 16, + 12, + 6, + 6 + ] + }, + { + "name": "cc", + "db": [ + "./data/db/conceptual_caption_train_base-cased.db/" + ], + "img": [ + "./data/img/gcc_train/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ], + "mix_ratio": [ + 16, + 12, + 6, + 6 + ] + }, + { + "name": "sbu", + "db": [ + "./data/db/sbu_caption_train_base-cased.db/" + ], + "img": [ + "./data/img/sbu/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ], + "mix_ratio": [ + 16, + 8, + 4, + 4 + ] + } + ], + "val_datasets": [ + { + "name": "coco_cap", + "db": [ + "./data/db/pretrain_caption_coco_val_base-cased.db/" + ], + "img": [ + "./data/img/coco_val2014/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ] + }, + { + "name": "vg_cap", + "db": [ + "./data/db/pretrain_caption_vg_val_base-cased.db/" + ], + "img": [ + "./data/img/vg/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ] + }, + { + "name": "cc", + "db": [ + "./data/db/conceptual_caption_val_base-cased.db/" + ], + "img": [ + "./data/img/gcc_val/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ] + }, + { + "name": "sbu", + "db": [ + "./data/db/sbu_caption_val_base-cased.db/" + ], + "img": [ + "./data/img/sbu/" + ], + "tasks": [ + "itm", + "mlm", + "mrfr", + "mrckl" + ] + } + ], + "rank": 0 +} diff --git a/data/model/resnet101_faster_rcnn_final.pth b/data/model/resnet101_faster_rcnn_final.pth new file mode 100644 index 
0000000..98a75be
--- /dev/null
+++ b/data/model/resnet101_faster_rcnn_final.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa6462dae6e3226dedb3080bbdd84f744629a5acfeee23b38655157b522997e6
+size 203016449
diff --git a/data/model/uniter-base.pt b/data/model/uniter-base.pt
new file mode 100644
index 0000000..7436aa5
--- /dev/null
+++ b/data/model/uniter-base.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e74adade973703eb810d983d88c8669fa63ba8a59d803d6bd06c5465e3142c1e
+size 273600336
diff --git a/detector/__init__.py b/detector/__init__.py
new file mode 100644
index 0000000..e69de29
[GIT binary patches adding detector/__pycache__/__init__.cpython-38.pyc, bbox_transform.cpython-38.pyc, faster_rcnn.cpython-38.pyc, generate_anchors.cpython-38.pyc and rpn.cpython-38.pyc omitted]
diff --git a/detector/bbox_transform.py b/detector/bbox_transform.py
new file mode 100644
[index line and opening hunk lines garbled with the binary data above; surviving tail of the box-clipping function:]
+    # x1 >= 0
+    boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0)
+    # y1 >= 0
+    boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0)
+    # x2 < im_shape[1]
+    boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0)
+    # y2 < im_shape[0]
+    boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0)
+    return boxes
diff --git a/detector/faster_rcnn.py b/detector/faster_rcnn.py
new file mode 100644
index 0000000..dbfc060
--- /dev/null
+++ b/detector/faster_rcnn.py
@@ -0,0 +1,478 @@
+import pickle
+import ipdb
+import torch
+import numpy as np
+import cv2
+import torchvision
+from torch import nn
+from .rpn import RegionProposalNetwork
+
+class ConvBlock(nn.Module):
+    def __init__(self,i,o,k,s,p,d,use_relu = True):
+        super(ConvBlock, self).__init__()
+        self.conv = nn.Conv2d(i, o, k, s, p, d)
+        self.bn = nn.BatchNorm2d(o)
+        self.use_relu = use_relu
+        if self.use_relu == True:
+            self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+ x = self.bn(x) + if self.use_relu == True: + x = self.relu(x) + return x + +def load_convblock(block, convname, bnname, scalename, weights): + block.conv.weight = nn.Parameter(torch.FloatTensor(weights[convname][0])) + block.conv.bias = nn.Parameter(torch.zeros_like(block.conv.bias)) + block.bn.running_mean = nn.Parameter(torch.FloatTensor(weights[bnname][0] / weights[bnname][2])) + block.bn.running_var = nn.Parameter(torch.FloatTensor(weights[bnname][1] / weights[bnname][2])) + block.bn.weight = nn.Parameter(torch.FloatTensor(weights[scalename][0])) + block.bn.bias = nn.Parameter(torch.FloatTensor(weights[scalename][1])) + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = ConvBlock(3,64,7,2,3,1,True) + + self.pool1 = nn.MaxPool2d(3,2,0,ceil_mode=True) + + self.res2a_branch1 = ConvBlock(64,256,1,1,0,1,False) + self.res2a_branch2a = ConvBlock(64,64,1,1,0,1,True) + self.res2a_branch2b = ConvBlock(64,64,3,1,1,1,True) + self.res2a_branch2c = ConvBlock(64,256,1,1,0,1,False) + + self.res2b_branch2a = ConvBlock(256,64,1,1,0,1,True) + self.res2b_branch2b = ConvBlock(64,64,3,1,1,1,True) + self.res2b_branch2c = ConvBlock(64,256,1,1,0,1,False) + + self.res2c_branch2a = ConvBlock(256,64,1,1,0,1,True) + self.res2c_branch2b = ConvBlock(64,64,3,1,1,1,True) + self.res2c_branch2c = ConvBlock(64,256,1,1,0,1,False) + + self.res3a_branch1 = ConvBlock(256,512,1,2,0,1,False) + self.res3a_branch2a = ConvBlock(256,128,1,2,0,1,True) + self.res3a_branch2b = ConvBlock(128,128,3,1,1,1,True) + self.res3a_branch2c = ConvBlock(128,512,1,1,0,1,False) + + self.res3b1_branch2a = ConvBlock(512,128,1,1,0,1,True) + self.res3b1_branch2b = ConvBlock(128,128,3,1,1,1,True) + self.res3b1_branch2c = ConvBlock(128,512,1,1,0,1,False) + + self.res3b2_branch2a = ConvBlock(512,128,1,1,0,1,True) + self.res3b2_branch2b = ConvBlock(128,128,3,1,1,1,True) + self.res3b2_branch2c = ConvBlock(128,512,1,1,0,1,False) + + self.res3b3_branch2a = ConvBlock(512,128,1,1,0,1,True) + self.res3b3_branch2b = ConvBlock(128,128,3,1,1,1,True) + self.res3b3_branch2c = ConvBlock(128,512,1,1,0,1,False) + + self.res4a_branch1 = ConvBlock(512,1024,1,2,0,1,False) + self.res4a_branch2a = ConvBlock(512,256,1,2,0,1,True) + self.res4a_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4a_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b1_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b1_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b1_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b2_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b2_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b2_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b3_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b3_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b3_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b4_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b4_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b4_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b5_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b5_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b5_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b6_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b6_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b6_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b7_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b7_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b7_branch2c = 
ConvBlock(256,1024,1,1,0,1,False) + + self.res4b8_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b8_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b8_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b9_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b9_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b9_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b10_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b10_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b10_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b11_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b11_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b11_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b12_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b12_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b12_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b13_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b13_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b13_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b14_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b14_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b14_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b15_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b15_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b15_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b16_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b16_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b16_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b17_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b17_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b17_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b18_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b18_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b18_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b19_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b19_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b19_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b20_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b20_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b20_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b21_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b21_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b21_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res4b22_branch2a = ConvBlock(1024,256,1,1,0,1,True) + self.res4b22_branch2b = ConvBlock(256,256,3,1,1,1,True) + self.res4b22_branch2c = ConvBlock(256,1024,1,1,0,1,False) + + self.res5a_branch1 = ConvBlock(1024,2048,1,1,0,1,False) + self.res5a_branch2a = ConvBlock(1024,512,1,1,0,1,True) + self.res5a_branch2b = ConvBlock(512,512,3,1,2,2,True) + self.res5a_branch2c = ConvBlock(512,2048,1,1,0,1,False) + + self.res5b_branch2a = ConvBlock(2048,512,1,1,0,1,True) + self.res5b_branch2b = ConvBlock(512,512,3,1,2,2,True) + self.res5b_branch2c = ConvBlock(512,2048,1,1,0,1,False) + + self.res5c_branch2a = ConvBlock(2048,512,1,1,0,1,True) + self.res5c_branch2b = ConvBlock(512,512,3,1,2,2,True) + self.res5c_branch2c = ConvBlock(512,2048,1,1,0,1,False) + + self.rpn_conv_3x3 = nn.Conv2d(1024,512,3,1,1,1) + self.rpn_cls_score = nn.Conv2d(512,24,1,1,0,1) + self.rpn_bbox_pred = nn.Conv2d(512,48,1,1,0,1) + + self.rpn = RegionProposalNetwork(pre_nms_topN = 6000, post_nms_topN = 300, nms_thresh = 0.7, min_size = 16, 
anchor_scales = (4, 8, 16, 32), feat_stride=16) + + #self.pool5 = nn.MaxPool2d(3,2,1,ceil_mode=True) + self.cls_score = nn.Linear(2048, 1601) + + def infer_resblock(self, l, r, x): + xl = x + xr = x + + for b in l: + xl = b(xl) + for b in r: + xr = b(xr) + return xl + xr + + def forward(self, x, im_size): + x = self.conv1(x) + x = self.pool1(x) + + x = nn.functional.relu(self.infer_resblock([self.res2a_branch1], [self.res2a_branch2a,self.res2a_branch2b,self.res2a_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res2b_branch2a,self.res2b_branch2b,self.res2b_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res2c_branch2a,self.res2c_branch2b,self.res2c_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([self.res3a_branch1], [self.res3a_branch2a,self.res3a_branch2b,self.res3a_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b1_branch2a,self.res3b1_branch2b,self.res3b1_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b2_branch2a,self.res3b2_branch2b,self.res3b2_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res3b3_branch2a,self.res3b3_branch2b,self.res3b3_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([self.res4a_branch1], [self.res4a_branch2a,self.res4a_branch2b,self.res4a_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b1_branch2a,self.res4b1_branch2b,self.res4b1_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b2_branch2a,self.res4b2_branch2b,self.res4b2_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b3_branch2a,self.res4b3_branch2b,self.res4b3_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b4_branch2a,self.res4b4_branch2b,self.res4b4_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b5_branch2a,self.res4b5_branch2b,self.res4b5_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b6_branch2a,self.res4b6_branch2b,self.res4b6_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b7_branch2a,self.res4b7_branch2b,self.res4b7_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b8_branch2a,self.res4b8_branch2b,self.res4b8_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b9_branch2a,self.res4b9_branch2b,self.res4b9_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b10_branch2a,self.res4b10_branch2b,self.res4b10_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b11_branch2a,self.res4b11_branch2b,self.res4b11_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b12_branch2a,self.res4b12_branch2b,self.res4b12_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b13_branch2a,self.res4b13_branch2b,self.res4b13_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b14_branch2a,self.res4b14_branch2b,self.res4b14_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b15_branch2a,self.res4b15_branch2b,self.res4b15_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b16_branch2a,self.res4b16_branch2b,self.res4b16_branch2c],x)) + x = nn.functional.relu(self.infer_resblock([lambda x: x], 
[self.res4b17_branch2a,self.res4b17_branch2b,self.res4b17_branch2c],x))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b18_branch2a,self.res4b18_branch2b,self.res4b18_branch2c],x))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b19_branch2a,self.res4b19_branch2b,self.res4b19_branch2c],x))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b20_branch2a,self.res4b20_branch2b,self.res4b20_branch2c],x))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b21_branch2a,self.res4b21_branch2b,self.res4b21_branch2c],x))
+        #x = data_kv['res4b21']
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res4b22_branch2a,self.res4b22_branch2b,self.res4b22_branch2c],x))
+
+        x_rpn_output = nn.functional.relu(self.rpn_conv_3x3(x))
+        x_rpn_cls_score = self.rpn_cls_score(x_rpn_output)
+        x_rpn_bbox_pred = self.rpn_bbox_pred(x_rpn_output)
+
+        n, c, h, w = x_rpn_cls_score.shape
+
+        x_rpn_cls_score = x_rpn_cls_score.reshape(n,2,-1,w)
+        x_rpn_cls_prob = nn.functional.softmax(x_rpn_cls_score, 1)
+        x_rpn_cls_prob_reshape = x_rpn_cls_prob.reshape(n,24,-1,w)
+
+        #im_size = np.array([600. , 600. , 2.6785715])
+        #im_size = np.array([5.6200000e+02, 1.0000000e+03, 8.9285713e-01])
+
+        rois = self.rpn.forward(x_rpn_cls_prob_reshape, x_rpn_bbox_pred, im_size)
+        feats = torchvision.ops.roi_pool(x, rois, output_size=[14,14], spatial_scale=0.0625)
+
+
+        x = nn.functional.relu(self.infer_resblock([self.res5a_branch1], [self.res5a_branch2a,self.res5a_branch2b,self.res5a_branch2c],feats))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res5b_branch2a,self.res5b_branch2b,self.res5b_branch2c],x))
+        x = nn.functional.relu(self.infer_resblock([lambda x: x], [self.res5c_branch2a,self.res5c_branch2b,self.res5c_branch2c],x))
+
+        x = torch.nn.functional.adaptive_avg_pool2d(x, (1,1))
+
+        pool5_flat = x.reshape((x.shape[0], -1))
+
+        x_cls_score = self.cls_score(pool5_flat)
+        x_cls_prob = torch.nn.functional.softmax(x_cls_score, -1)
+        x_cls_boxes = rois[:, 1:5] / im_size[2]
+        max_conf, keep_boxes = self.post_process(rois, x_cls_boxes, x_cls_prob, 0.2)
+
+        MIN_BOXES = 10
+        MAX_BOXES = 100
+
+        if len(keep_boxes) < MIN_BOXES:
+            keep_boxes = torch.argsort(max_conf, 0, True)[:MIN_BOXES]
+        elif len(keep_boxes) > MAX_BOXES:
+            keep_boxes = torch.argsort(max_conf, 0, True)[:MAX_BOXES]
+
+        boxes = x_cls_boxes[keep_boxes]
+        features = pool5_flat[keep_boxes]
+        confidence = max_conf[keep_boxes]
+
+        return boxes, features, confidence
+
+
+    def post_process(self, rois, cls_boxes, cls_prob, conf_thresh = 0.2):
+        max_conf = torch.zeros((rois.shape[0]), device = rois.device )
+        for cls_ind in range(1, cls_prob.shape[1]):
+            #cls_scores = scores[:, cls_ind]
+            cls_scores = cls_prob[:, cls_ind]
+            dets = torch.hstack(
+                (cls_boxes, cls_scores[:, np.newaxis]))
+            keep = np.array(torchvision.ops.nms(dets[:,:4],dets[:,4], 0.3 ))
+            max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
+                    cls_scores[keep], max_conf[keep])
+        keep_boxes = torch.where(max_conf >= conf_thresh)[0]
+        return max_conf, keep_boxes
+
+    def load_weights_from_pkl(self, weights_kv):
+        with torch.no_grad():
+            load_convblock(self.conv1, 'conv1', 'bn_conv1', 'scale_conv1', weights_kv)
+            load_convblock(self.res2a_branch1, 'res2a_branch1', 'bn2a_branch1', 'scale2a_branch1', weights_kv)
+            load_convblock(self.res2a_branch2a, 'res2a_branch2a', 'bn2a_branch2a', 'scale2a_branch2a', weights_kv)
+            load_convblock(self.res2a_branch2b, 'res2a_branch2b', 'bn2a_branch2b', 'scale2a_branch2b',
weights_kv) + load_convblock(self.res2a_branch2c, 'res2a_branch2c', 'bn2a_branch2c', 'scale2a_branch2c', weights_kv) + + load_convblock(self.res2b_branch2a, 'res2b_branch2a', 'bn2b_branch2a', 'scale2b_branch2a', weights_kv) + load_convblock(self.res2b_branch2b, 'res2b_branch2b', 'bn2b_branch2b', 'scale2b_branch2b', weights_kv) + load_convblock(self.res2b_branch2c, 'res2b_branch2c', 'bn2b_branch2c', 'scale2b_branch2c', weights_kv) + + load_convblock(self.res2c_branch2a, 'res2c_branch2a', 'bn2c_branch2a', 'scale2c_branch2a', weights_kv) + load_convblock(self.res2c_branch2b, 'res2c_branch2b', 'bn2c_branch2b', 'scale2c_branch2b', weights_kv) + load_convblock(self.res2c_branch2c, 'res2c_branch2c', 'bn2c_branch2c', 'scale2c_branch2c', weights_kv) + + load_convblock(self.res3a_branch1, 'res3a_branch1', 'bn3a_branch1', 'scale3a_branch1', weights_kv) + load_convblock(self.res3a_branch2a, 'res3a_branch2a', 'bn3a_branch2a', 'scale3a_branch2a', weights_kv) + load_convblock(self.res3a_branch2b, 'res3a_branch2b', 'bn3a_branch2b', 'scale3a_branch2b', weights_kv) + load_convblock(self.res3a_branch2c, 'res3a_branch2c', 'bn3a_branch2c', 'scale3a_branch2c', weights_kv) + + load_convblock(self.res3b1_branch2a, 'res3b1_branch2a', 'bn3b1_branch2a', 'scale3b1_branch2a', weights_kv) + load_convblock(self.res3b1_branch2b, 'res3b1_branch2b', 'bn3b1_branch2b', 'scale3b1_branch2b', weights_kv) + load_convblock(self.res3b1_branch2c, 'res3b1_branch2c', 'bn3b1_branch2c', 'scale3b1_branch2c', weights_kv) + + load_convblock(self.res3b2_branch2a, 'res3b2_branch2a', 'bn3b2_branch2a', 'scale3b2_branch2a', weights_kv) + load_convblock(self.res3b2_branch2b, 'res3b2_branch2b', 'bn3b2_branch2b', 'scale3b2_branch2b', weights_kv) + load_convblock(self.res3b2_branch2c, 'res3b2_branch2c', 'bn3b2_branch2c', 'scale3b2_branch2c', weights_kv) + + load_convblock(self.res3b3_branch2a, 'res3b3_branch2a', 'bn3b3_branch2a', 'scale3b3_branch2a', weights_kv) + load_convblock(self.res3b3_branch2b, 'res3b3_branch2b', 'bn3b3_branch2b', 'scale3b3_branch2b', weights_kv) + load_convblock(self.res3b3_branch2c, 'res3b3_branch2c', 'bn3b3_branch2c', 'scale3b3_branch2c', weights_kv) + + load_convblock(self.res4a_branch1, 'res4a_branch1', 'bn4a_branch1', 'scale4a_branch1', weights_kv) + load_convblock(self.res4a_branch2a, 'res4a_branch2a', 'bn4a_branch2a', 'scale4a_branch2a', weights_kv) + load_convblock(self.res4a_branch2b, 'res4a_branch2b', 'bn4a_branch2b', 'scale4a_branch2b', weights_kv) + load_convblock(self.res4a_branch2c, 'res4a_branch2c', 'bn4a_branch2c', 'scale4a_branch2c', weights_kv) + + load_convblock(self.res4b1_branch2a, 'res4b1_branch2a', 'bn4b1_branch2a', 'scale4b1_branch2a', weights_kv) + load_convblock(self.res4b1_branch2b, 'res4b1_branch2b', 'bn4b1_branch2b', 'scale4b1_branch2b', weights_kv) + load_convblock(self.res4b1_branch2c, 'res4b1_branch2c', 'bn4b1_branch2c', 'scale4b1_branch2c', weights_kv) + + load_convblock(self.res4b2_branch2a, 'res4b2_branch2a', 'bn4b2_branch2a', 'scale4b2_branch2a', weights_kv) + load_convblock(self.res4b2_branch2b, 'res4b2_branch2b', 'bn4b2_branch2b', 'scale4b2_branch2b', weights_kv) + load_convblock(self.res4b2_branch2c, 'res4b2_branch2c', 'bn4b2_branch2c', 'scale4b2_branch2c', weights_kv) + + load_convblock(self.res4b3_branch2a, 'res4b3_branch2a', 'bn4b3_branch2a', 'scale4b3_branch2a', weights_kv) + load_convblock(self.res4b3_branch2b, 'res4b3_branch2b', 'bn4b3_branch2b', 'scale4b3_branch2b', weights_kv) + load_convblock(self.res4b3_branch2c, 'res4b3_branch2c', 'bn4b3_branch2c', 'scale4b3_branch2c', 
weights_kv) + + load_convblock(self.res4b4_branch2a, 'res4b4_branch2a', 'bn4b4_branch2a', 'scale4b4_branch2a', weights_kv) + load_convblock(self.res4b4_branch2b, 'res4b4_branch2b', 'bn4b4_branch2b', 'scale4b4_branch2b', weights_kv) + load_convblock(self.res4b4_branch2c, 'res4b4_branch2c', 'bn4b4_branch2c', 'scale4b4_branch2c', weights_kv) + + load_convblock(self.res4b5_branch2a, 'res4b5_branch2a', 'bn4b5_branch2a', 'scale4b5_branch2a', weights_kv) + load_convblock(self.res4b5_branch2b, 'res4b5_branch2b', 'bn4b5_branch2b', 'scale4b5_branch2b', weights_kv) + load_convblock(self.res4b5_branch2c, 'res4b5_branch2c', 'bn4b5_branch2c', 'scale4b5_branch2c', weights_kv) + + load_convblock(self.res4b6_branch2a, 'res4b6_branch2a', 'bn4b6_branch2a', 'scale4b6_branch2a', weights_kv) + load_convblock(self.res4b6_branch2b, 'res4b6_branch2b', 'bn4b6_branch2b', 'scale4b6_branch2b', weights_kv) + load_convblock(self.res4b6_branch2c, 'res4b6_branch2c', 'bn4b6_branch2c', 'scale4b6_branch2c', weights_kv) + + load_convblock(self.res4b7_branch2a, 'res4b7_branch2a', 'bn4b7_branch2a', 'scale4b7_branch2a', weights_kv) + load_convblock(self.res4b7_branch2b, 'res4b7_branch2b', 'bn4b7_branch2b', 'scale4b7_branch2b', weights_kv) + load_convblock(self.res4b7_branch2c, 'res4b7_branch2c', 'bn4b7_branch2c', 'scale4b7_branch2c', weights_kv) + + load_convblock(self.res4b8_branch2a, 'res4b8_branch2a', 'bn4b8_branch2a', 'scale4b8_branch2a', weights_kv) + load_convblock(self.res4b8_branch2b, 'res4b8_branch2b', 'bn4b8_branch2b', 'scale4b8_branch2b', weights_kv) + load_convblock(self.res4b8_branch2c, 'res4b8_branch2c', 'bn4b8_branch2c', 'scale4b8_branch2c', weights_kv) + + load_convblock(self.res4b9_branch2a, 'res4b9_branch2a', 'bn4b9_branch2a', 'scale4b9_branch2a', weights_kv) + load_convblock(self.res4b9_branch2b, 'res4b9_branch2b', 'bn4b9_branch2b', 'scale4b9_branch2b', weights_kv) + load_convblock(self.res4b9_branch2c, 'res4b9_branch2c', 'bn4b9_branch2c', 'scale4b9_branch2c', weights_kv) + + load_convblock(self.res4b10_branch2a, 'res4b10_branch2a', 'bn4b10_branch2a', 'scale4b10_branch2a', weights_kv) + load_convblock(self.res4b10_branch2b, 'res4b10_branch2b', 'bn4b10_branch2b', 'scale4b10_branch2b', weights_kv) + load_convblock(self.res4b10_branch2c, 'res4b10_branch2c', 'bn4b10_branch2c', 'scale4b10_branch2c', weights_kv) + + load_convblock(self.res4b11_branch2a, 'res4b11_branch2a', 'bn4b11_branch2a', 'scale4b11_branch2a', weights_kv) + load_convblock(self.res4b11_branch2b, 'res4b11_branch2b', 'bn4b11_branch2b', 'scale4b11_branch2b', weights_kv) + load_convblock(self.res4b11_branch2c, 'res4b11_branch2c', 'bn4b11_branch2c', 'scale4b11_branch2c', weights_kv) + + load_convblock(self.res4b12_branch2a, 'res4b12_branch2a', 'bn4b12_branch2a', 'scale4b12_branch2a', weights_kv) + load_convblock(self.res4b12_branch2b, 'res4b12_branch2b', 'bn4b12_branch2b', 'scale4b12_branch2b', weights_kv) + load_convblock(self.res4b12_branch2c, 'res4b12_branch2c', 'bn4b12_branch2c', 'scale4b12_branch2c', weights_kv) + + load_convblock(self.res4b13_branch2a, 'res4b13_branch2a', 'bn4b13_branch2a', 'scale4b13_branch2a', weights_kv) + load_convblock(self.res4b13_branch2b, 'res4b13_branch2b', 'bn4b13_branch2b', 'scale4b13_branch2b', weights_kv) + load_convblock(self.res4b13_branch2c, 'res4b13_branch2c', 'bn4b13_branch2c', 'scale4b13_branch2c', weights_kv) + + load_convblock(self.res4b14_branch2a, 'res4b14_branch2a', 'bn4b14_branch2a', 'scale4b14_branch2a', weights_kv) + load_convblock(self.res4b14_branch2b, 'res4b14_branch2b', 'bn4b14_branch2b', 
'scale4b14_branch2b', weights_kv) + load_convblock(self.res4b14_branch2c, 'res4b14_branch2c', 'bn4b14_branch2c', 'scale4b14_branch2c', weights_kv) + + load_convblock(self.res4b15_branch2a, 'res4b15_branch2a', 'bn4b15_branch2a', 'scale4b15_branch2a', weights_kv) + load_convblock(self.res4b15_branch2b, 'res4b15_branch2b', 'bn4b15_branch2b', 'scale4b15_branch2b', weights_kv) + load_convblock(self.res4b15_branch2c, 'res4b15_branch2c', 'bn4b15_branch2c', 'scale4b15_branch2c', weights_kv) + + load_convblock(self.res4b16_branch2a, 'res4b16_branch2a', 'bn4b16_branch2a', 'scale4b16_branch2a', weights_kv) + load_convblock(self.res4b16_branch2b, 'res4b16_branch2b', 'bn4b16_branch2b', 'scale4b16_branch2b', weights_kv) + load_convblock(self.res4b16_branch2c, 'res4b16_branch2c', 'bn4b16_branch2c', 'scale4b16_branch2c', weights_kv) + + load_convblock(self.res4b17_branch2a, 'res4b17_branch2a', 'bn4b17_branch2a', 'scale4b17_branch2a', weights_kv) + load_convblock(self.res4b17_branch2b, 'res4b17_branch2b', 'bn4b17_branch2b', 'scale4b17_branch2b', weights_kv) + load_convblock(self.res4b17_branch2c, 'res4b17_branch2c', 'bn4b17_branch2c', 'scale4b17_branch2c', weights_kv) + + load_convblock(self.res4b18_branch2a, 'res4b18_branch2a', 'bn4b18_branch2a', 'scale4b18_branch2a', weights_kv) + load_convblock(self.res4b18_branch2b, 'res4b18_branch2b', 'bn4b18_branch2b', 'scale4b18_branch2b', weights_kv) + load_convblock(self.res4b18_branch2c, 'res4b18_branch2c', 'bn4b18_branch2c', 'scale4b18_branch2c', weights_kv) + + load_convblock(self.res4b19_branch2a, 'res4b19_branch2a', 'bn4b19_branch2a', 'scale4b19_branch2a', weights_kv) + load_convblock(self.res4b19_branch2b, 'res4b19_branch2b', 'bn4b19_branch2b', 'scale4b19_branch2b', weights_kv) + load_convblock(self.res4b19_branch2c, 'res4b19_branch2c', 'bn4b19_branch2c', 'scale4b19_branch2c', weights_kv) + + load_convblock(self.res4b20_branch2a, 'res4b20_branch2a', 'bn4b20_branch2a', 'scale4b20_branch2a', weights_kv) + load_convblock(self.res4b20_branch2b, 'res4b20_branch2b', 'bn4b20_branch2b', 'scale4b20_branch2b', weights_kv) + load_convblock(self.res4b20_branch2c, 'res4b20_branch2c', 'bn4b20_branch2c', 'scale4b20_branch2c', weights_kv) + + load_convblock(self.res4b21_branch2a, 'res4b21_branch2a', 'bn4b21_branch2a', 'scale4b21_branch2a', weights_kv) + load_convblock(self.res4b21_branch2b, 'res4b21_branch2b', 'bn4b21_branch2b', 'scale4b21_branch2b', weights_kv) + load_convblock(self.res4b21_branch2c, 'res4b21_branch2c', 'bn4b21_branch2c', 'scale4b21_branch2c', weights_kv) + + load_convblock(self.res4b22_branch2a, 'res4b22_branch2a', 'bn4b22_branch2a', 'scale4b22_branch2a', weights_kv) + load_convblock(self.res4b22_branch2b, 'res4b22_branch2b', 'bn4b22_branch2b', 'scale4b22_branch2b', weights_kv) + load_convblock(self.res4b22_branch2c, 'res4b22_branch2c', 'bn4b22_branch2c', 'scale4b22_branch2c', weights_kv) + + load_convblock(self.res5a_branch1, 'res5a_branch1', 'bn5a_branch1', 'scale5a_branch1', weights_kv) + load_convblock(self.res5a_branch2a, 'res5a_branch2a', 'bn5a_branch2a', 'scale5a_branch2a', weights_kv) + load_convblock(self.res5a_branch2b, 'res5a_branch2b', 'bn5a_branch2b', 'scale5a_branch2b', weights_kv) + load_convblock(self.res5a_branch2c, 'res5a_branch2c', 'bn5a_branch2c', 'scale5a_branch2c', weights_kv) + + load_convblock(self.res5b_branch2a, 'res5b_branch2a', 'bn5b_branch2a', 'scale5b_branch2a', weights_kv) + load_convblock(self.res5b_branch2b, 'res5b_branch2b', 'bn5b_branch2b', 'scale5b_branch2b', weights_kv) + load_convblock(self.res5b_branch2c, 
'res5b_branch2c', 'bn5b_branch2c', 'scale5b_branch2c', weights_kv) + + load_convblock(self.res5c_branch2a, 'res5c_branch2a', 'bn5c_branch2a', 'scale5c_branch2a', weights_kv) + load_convblock(self.res5c_branch2b, 'res5c_branch2b', 'bn5c_branch2b', 'scale5c_branch2b', weights_kv) + load_convblock(self.res5c_branch2c, 'res5c_branch2c', 'bn5c_branch2c', 'scale5c_branch2c', weights_kv) + + self.rpn_conv_3x3.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][0])) + self.rpn_conv_3x3.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_conv/3x3'][1])) + + self.rpn_cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][0])) + self.rpn_cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_cls_score'][1])) + + self.rpn_bbox_pred.weight = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][0])) + self.rpn_bbox_pred.bias = nn.Parameter(torch.FloatTensor(weights_kv['rpn_bbox_pred'][1])) + + self.cls_score.weight = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][0])) + self.cls_score.bias = nn.Parameter(torch.FloatTensor(weights_kv['cls_score'][1])) + + +# self.conv1.weight = nn.Parameter(torch.FloatTensor(weights[0]['weights'][0])) +# self.conv1.bias = nn.Parameter(torch.zeros_like(self.conv1.bias)) +# self.bn_conv1.running_mean = nn.Parameter(torch.FloatTensor(weights[1]['weights'][0] / weights[1]['weights'][2])) +# self.bn_conv1.running_var = nn.Parameter(torch.FloatTensor(weights[1]['weights'][1] / weights[1]['weights'][2])) +# self.bn_conv1.weight = nn.Parameter(torch.FloatTensor(weights[2]['weights'][0])) +# self.bn_conv1.bias = nn.Parameter(torch.FloatTensor(weights[2]['weights'][1])) +# + +def process_img(img): + mean = np.array([[[102.9801, 115.9465, 122.7717]]]) + img = img - mean + + im_shape = img.shape + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + + target_size = 600 + max_size = 1000 + im_scale = float(target_size) / float(im_size_min) + if np.round(im_scale * im_size_max) > max_size: + im_scale = float(max_size) / float(im_size_max) + im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR) + return im, np.array([im.shape[0], im.shape[1], im_scale]) + +#img = cv2.imread('img2.jpg') +##print(img) +#ipdb.set_trace() +#shape = process_img(img) +# +#net = Net() +#net.load_weights_from_pkl(data2) +#net.eval() +#with torch.no_grad(): +# output = net(img) +#ipdb.set_trace() +#print(residual) diff --git a/detector/generate_anchors.py b/detector/generate_anchors.py new file mode 100644 index 0000000..561dcf3 --- /dev/null +++ b/detector/generate_anchors.py @@ -0,0 +1,105 @@ +# -------------------------------------------------------- +# Faster R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick and Sean Bell +# -------------------------------------------------------- + +import numpy as np + +# Verify that we compute the same anchors as Shaoqing's matlab implementation: +# +# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat +# >> anchors +# +# anchors = +# +# -83 -39 100 56 +# -175 -87 192 104 +# -359 -183 376 200 +# -55 -55 72 72 +# -119 -119 136 136 +# -247 -247 264 264 +# -35 -79 52 96 +# -79 -167 96 184 +# -167 -343 184 360 + +#array([[ -83., -39., 100., 56.], +# [-175., -87., 192., 104.], +# [-359., -183., 376., 200.], +# [ -55., -55., 72., 72.], +# [-119., -119., 136., 136.], +# [-247., -247., 264., 264.], +# [ -35., -79., 52., 96.], +# [ -79., -167., 96., 184.], +# 
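# A short usage sketch for process_img above in detector/faster_rcnn.py.
# Assumptions: OpenCV is installed and 'example.jpg' is a placeholder path,
# not a file shipped with this patch.
import cv2
import numpy as np

img = cv2.imread('example.jpg').astype(np.float32)   # BGR, H x W x 3
im, im_info = process_img(img)
height, width, im_scale = im_info
# The shorter side is scaled towards 600 px, capped so that the longer side
# stays within 1000 px; im_scale is the factor that was actually applied.
print(im.shape, im_scale)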
[-167., -343., 184., 360.]]) + +def generate_anchors(base_size=16, ratios=[0.5, 1, 2], + scales=2**np.arange(3, 6)): + """ + Generate anchor (reference) windows by enumerating aspect ratios X + scales wrt a reference (0, 0, 15, 15) window. + """ + + base_anchor = np.array([1, 1, base_size, base_size]) - 1 + ratio_anchors = _ratio_enum(base_anchor, ratios) + anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales) + for i in range(ratio_anchors.shape[0])]) + return anchors + +def _whctrs(anchor): + """ + Return width, height, x center, and y center for an anchor (window). + """ + + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr + +def _mkanchors(ws, hs, x_ctr, y_ctr): + """ + Given a vector of widths (ws) and heights (hs) around a center + (x_ctr, y_ctr), output a set of anchors (windows). + """ + + ws = ws[:, np.newaxis] + hs = hs[:, np.newaxis] + anchors = np.hstack((x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1))) + return anchors + +def _ratio_enum(anchor, ratios): + """ + Enumerate a set of anchors for each aspect ratio wrt an anchor. + """ + + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = np.round(np.sqrt(size_ratios)) + hs = np.round(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + +def _scale_enum(anchor, scales): + """ + Enumerate a set of anchors for each scale wrt an anchor. + """ + + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = w * scales + hs = h * scales + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + +#if __name__ == '__main__': +# import time +# t = time.time() +# a = generate_anchors() +# print time.time() - t +# print a +# from IPython import embed; embed() diff --git a/detector/rpn.py b/detector/rpn.py new file mode 100644 index 0000000..cc46bca --- /dev/null +++ b/detector/rpn.py @@ -0,0 +1,136 @@ +import pickle + +import numpy as np +import torch as t +import torch +from torch.nn import functional as F +from torch import nn +import torchvision +from torch import nn + +from .generate_anchors import generate_anchors +from .bbox_transform import bbox_transform_inv, clip_boxes + + +class RegionProposalNetwork(nn.Module): + def __init__(self, pre_nms_topN,post_nms_topN, nms_thresh, min_size, anchor_scales, feat_stride): + super(RegionProposalNetwork, self).__init__() + self._anchors = generate_anchors(scales=np.array(anchor_scales)) + self._num_anchors = self._anchors.shape[0] + self._feat_stride = feat_stride + self.pre_nms_topN = pre_nms_topN + self.post_nms_topN = post_nms_topN + self.nms_thresh = nms_thresh + self.min_size = min_size + self.anchor_scales = anchor_scales + self.feat_stride = feat_stride + + + def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size): + scores = rpn_cls_prob[:,self._num_anchors:, :, :] + bbox_deltas = rpn_bbox_pred + min_size = self.min_size + + pre_nms_topN = self.pre_nms_topN + post_nms_topN = self.post_nms_topN + nms_thresh = self.nms_thresh + min_size = self.min_size + anchor_scales = self.anchor_scales + feat_stride = self.feat_stride + + n, _, hh, ww = scores.shape + + #n_anchor = anchor.shape[0] // (hh * ww) + + anchors = self._enumerate_shifted_anchor(self._anchors, self._feat_stride, hh, ww) + bbox_deltas = bbox_deltas.permute((0, 2, 3, 1)).reshape((-1, 4)) + + bbox_deltas = bbox_deltas.cpu().detach().numpy() + proposals = bbox_transform_inv(anchors, bbox_deltas) + scores = 
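# A standalone check, not part of the patch: the default arguments reproduce
# the nine reference anchors for a 16 px stride (3 aspect ratios x 3 scales,
# centred on the (0, 0, 15, 15) window). The MATLAB values quoted in the
# header appear shifted by one pixel, presumably due to 1-based indexing.
import numpy as np

anchors = generate_anchors(base_size=16, ratios=[0.5, 1, 2],
                           scales=2 ** np.arange(3, 6))
print(anchors.shape)   # (9, 4)
print(anchors[3])      # [-56., -56., 71., 71.] -- the square 128 x 128 anchor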
scores.permute((0, 2, 3, 1)).reshape((-1, 1)) + + proposals = bbox_transform_inv(anchors, bbox_deltas) + + proposals = clip_boxes(proposals, img_size[:2]) + + keep = _filter_boxes(proposals, min_size * img_size[2]) + proposals = proposals[keep, :] + scores = scores[keep] + + order = scores.ravel().argsort(descending=True) + proposals = t.FloatTensor(proposals, device = rpn_cls_prob.device) + + if pre_nms_topN > 0: + order = order[:pre_nms_topN] + + proposals = proposals[order, :] + scores = scores[order] + + keep = torchvision.ops.nms(proposals, scores.ravel(), nms_thresh) + + if post_nms_topN > 0: + keep = keep[:post_nms_topN] + proposals = proposals[keep, :] + scores = scores[keep] + + batch_inds = t.zeros((proposals.shape[0], 1), dtype = proposals.dtype, device = proposals.device) + rois = t.hstack([batch_inds, proposals]) + + return rois + + #keep = nms(np.hstack((proposals, scores)), nms_thresh) + #if post_nms_topN > 0: + # keep = keep[:post_nms_topN] + #proposals = proposals[keep, :] + #scores = scores[keep] + + def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width): + # Enumerate all shifted anchors: + # + # add A anchors (1, A, 4) to + # cell K shifts (K, 1, 4) to get + # shift anchors (K, A, 4) + # reshape to (K*A, 4) shifted anchors + # return (K*A, 4) + + # !TODO: add support for torch.CudaTensor + # xp = cuda.get_array_module(anchor_base) + # it seems that it can't be boosed using GPU + import numpy as xp + shift_y = xp.arange(0, height * feat_stride, feat_stride) + shift_x = xp.arange(0, width * feat_stride, feat_stride) + shift_x, shift_y = xp.meshgrid(shift_x, shift_y) + shift = xp.stack((shift_x.ravel(), shift_y.ravel(), + shift_x.ravel(), shift_y.ravel()), axis=1) + + A = anchor_base.shape[0] + K = shift.shape[0] + anchor = anchor_base.reshape((1, A, 4)) + \ + shift.reshape((1, K, 4)).transpose((1, 0, 2)) + anchor = anchor.reshape((K * A, 4)).astype(np.float32) + + return anchor + +def _filter_boxes(boxes, min_size): + """Remove all boxes with any side smaller than min_size.""" + ws = boxes[:, 2] - boxes[:, 0] + 1 + hs = boxes[:, 3] - boxes[:, 1] + 1 + keep = np.where((ws >= min_size) & (hs >= min_size))[0] + return keep + +#if __name__ == '__main__': +# rpn_cls_prob_reshape = data_kv['rpn_cls_prob_reshape'] +# rpn_bbox_pred = data_kv['rpn_bbox_pred'] +# +# rpn = RegionProposalNetwork(pre_nms_topN = 6000, post_nms_topN = 300, nms_thresh = 0.7, min_size = 16, anchor_scales = (4, 8, 16, 32), feat_stride=16) +# im_size = np.array([600. , 600. 
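# A standalone sketch of the shifted-anchor enumeration performed by
# _enumerate_shifted_anchor above (the feature-map size is a toy value).
import numpy as np

base = generate_anchors(scales=np.array((4, 8, 16, 32)))   # (A, 4) with A = 12
feat_stride, height, width = 16, 3, 4
shift_x = np.arange(0, width * feat_stride, feat_stride)
shift_y = np.arange(0, height * feat_stride, feat_stride)
sx, sy = np.meshgrid(shift_x, shift_y)
shifts = np.stack((sx.ravel(), sy.ravel(), sx.ravel(), sy.ravel()), axis=1)
anchors = (base.reshape(1, -1, 4)
           + shifts.reshape(1, -1, 4).transpose((1, 0, 2))).reshape(-1, 4)
print(anchors.shape)   # (height * width * A, 4) == (144, 4)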
, 2.6785715]) +# ipdb.set_trace() +# rois = rpn.forward(rpn_cls_prob_reshape, rpn_bbox_pred, im_size) +# +# outputs = data_kv['res4b22'] +# feats = torchvision.ops.roi_pool(outputs,rois, output_size=[14,14], spatial_scale=0.0625) +# +# print(rpn_cls_prob_reshape.shape, rpn_bbox_pred.shape) +# print(rpn) +# print('main') +# diff --git a/dvl/__pycache__/const.cpython-38.pyc b/dvl/__pycache__/const.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e106867d430438914ecc5964e69f464a9738eb89 GIT binary patch literal 233 zcmWIL<>g`kf~QR@lgxqiV-N=!FakLaKwK;UBvKfn7*ZIc7*m*{m{OR788lfk88{di zG9B4~1cL$tgP$hzEp|^|_jngi-zZ)X!^hFd)d$Swb_#X&b`6OS_Kb2}$xy@uG!IPt z^3@M5PAw|dugc8H$*j`XE6FdB xD9Hn|Qu0gma}zW3^i#@m^po@Ric9ngDsOSv 0: + self.neg_imgs.append(hard_negatives_img[id_][:self.num_hard_negatives]) + self.neg_txts.append(hard_negatives_txt[img_fname][:self.num_hard_negatives]) + else: + self.neg_imgs.append(None) + self.neg_txts.append(None) + self.lens.append(tl + self.img_db.name2nbb[img_fname]) + + def __getitem__(self, i): + example = super().__getitem__(i) + # labels and negative images should be sampled every epoch + img_fname, hard_neg_imgs = self.train_imgs[i], self.neg_imgs[i] + txt_fname, hard_neg_txts = self.ids[i], self.neg_txts[i] + + img_input_ids = torch.Tensor([101]).long() + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + + if hard_neg_imgs is not None: + # TODO: add hard negative here + neg_imgs = dict({'img_input_ids': [], 'img_feat': [], 'img_pos_feat': [], 'num_bb': [], 'attn_masks_img': [], + 'caption_ids': [], 'attn_masks_captions': []}) + for neg_id in hard_neg_imgs: + neg_imgs['img_input_ids'].append(torch.Tensor([101]).long()) + t = self._get_img_feat(neg_id) + neg_imgs['img_feat'].append(t[0]) + neg_imgs['img_pos_feat'].append(t[1]) + neg_imgs['num_bb'].append(t[2]) + neg_imgs['attn_masks_img'].append(torch.ones(t[2]+1, dtype=torch.long)) + if self.img_meta is not None: + tmp = [self.tokenizer.encode(i, add_special_tokens=False) + [self.tokenizer.sep_token_id] + for i in self.img_meta[neg_id]['caption_multiple']] + neg_imgs['caption_ids'].append(torch.tensor([self.tokenizer.cls_token_id] + sum(tmp, []), + dtype=input_ids.dtype, device=input_ids.device)) + neg_imgs['attn_masks_captions'].append(torch.ones(len(neg_imgs['caption_ids'][-1]), dtype=torch.long)) + # debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']] + neg_txts = dict({'input_ids':[], 'position_ids':[], 'attention_mask':[]}) + for neg_id in hard_neg_txts: + ei = super().__getitem__(self.ids_2_idx[neg_id]) + input_ids_ei = ei['input_ids'] + neg_txts['input_ids'].append(self.txt_db.combine_inputs(input_ids_ei)) + neg_txts['attention_mask'].append(torch.ones(len(neg_txts['input_ids'][-1]), dtype=torch.long)) + else: + neg_imgs = None + neg_txts = None + + if self.img_meta is not None: + caption_ids = [self.tokenizer.encode(i, add_special_tokens=False) + [self.tokenizer.sep_token_id] for i in self.img_meta[img_fname]['caption_multiple']] + caption_ids = torch.tensor([self.tokenizer.cls_token_id] + sum(caption_ids, []), dtype=input_ids.dtype, device=input_ids.device) + attn_masks_captions = torch.ones(len(caption_ids), dtype=torch.long) + # debug = [self.tokenizer.encode(a) for a in self.img_meta[img_fname]['annotation']] + 
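# A minimal sketch of the multi-caption packing used just above. Assumptions:
# a BERT tokenizer from the `transformers` package; the checkpoint name and
# caption strings are placeholders, not values taken from this patch.
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
caption_multiple = ['a man rides a horse', 'someone on a brown horse']
pieces = [tokenizer.encode(c, add_special_tokens=False) + [tokenizer.sep_token_id]
          for c in caption_multiple]
caption_ids = torch.tensor([tokenizer.cls_token_id] + sum(pieces, []),
                           dtype=torch.long)
attn_masks_captions = torch.ones(len(caption_ids), dtype=torch.long)
# caption_ids == [CLS] cap_1 [SEP] cap_2 [SEP] ... , one joint sequence per image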
else: + caption_ids = None + attn_masks_captions = None + + # target = torch.Tensor(1).long() + # target.data.fill_(ground_truth_label) + return input_ids, img_feat, img_pos_feat, img_input_ids, attn_masks, attn_masks_img, self.ids[i], img_fname, neg_imgs, neg_txts, caption_ids, attn_masks_captions + + +def itm_fast_collate_kd(inputs): + input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text, attn_masks_img, idx, img_fname, negs, caption_ids, attn_masks_captions = map(list, unzip(inputs)) + + # txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + captions_ids = pad_sequence(caption_ids, batch_first=True, padding_value=0) if caption_ids[0] is not None else None + + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None + + if not None in negs: + num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in negs])) + img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in negs])) + img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in negs])) + img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in negs])) + attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in negs])) + else: + num_bbs_neg = [] + img_feats_neg = [] + img_input_ids_neg = [] + img_pos_feat_neg = [] + attn_masks_img_neg = [] + + num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg + img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs) + + img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + + attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0) + attn_masks_captions = pad_sequence(attn_masks_captions, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None + attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0) + sample_size = len(inputs[0]) + assert all(sample_size == len(i) for i in inputs) + + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + img_feat_teacher = img_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1) + img_pos_feat_teacher = img_pos_feat[:N_EXAMPLES_TEACHER].repeat(bs, 1, 1) + attn_masks_img_teacher = attn_masks_img[:N_EXAMPLES_TEACHER].repeat(bs, 1)[:, 1:] + + input_ids_teacher = input_ids.unsqueeze(1).repeat(1, 10, 1).view(bs*N_EXAMPLES_TEACHER, -1) + position_ids_teacher = position_ids + attn_masks_text_teacher = attn_masks_text.unsqueeze(1).repeat(1, 10, 1).view(bs*N_EXAMPLES_TEACHER, -1) + + attn_masks_teacher = torch.cat([attn_masks_text_teacher, attn_masks_img_teacher], dim=1) + + batch = { + 'txt_ids': input_ids, + 'img_ids': img_feat, + 'caption_ids': captions_ids, + 'txt_pos_ids': position_ids, + 'img_pos_ids': img_pos_feat, + 'caption_pos_ids': position_ids_captions, + 'txt_attn_masks': attn_masks_text, + 'img_attn_masks': attn_masks_img, + 'caption_attn_masks': attn_masks_captions, + 'img_txt_ids': img_input_ids, + 'img_txt_pos_ids': img_position_ids, + 'gather_index': gather_index, + 'sample_size': sample_size, + 'pos_ctx_indices': list(range(bs)), + 'neg_ctx_indices': list(range(bs, len(num_bbs))), + 'txt_index': idx, + 
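# A toy illustration of the repeat/view pattern used for the teacher inputs
# above. bs, L and N are made-up sizes; in the collate above N is
# N_EXAMPLES_TEACHER, and the hard-coded 10 is presumably meant to be the
# same constant.
import torch

bs, L, N = 2, 5, 3
input_ids = torch.arange(bs * L).view(bs, L)
input_ids_teacher = input_ids.unsqueeze(1).repeat(1, N, 1).view(bs * N, -1)
print(input_ids_teacher.shape)   # torch.Size([6, 5]): each text repeated N times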
'img_fname': img_fname, + + 'img_feat_teacher': img_feat_teacher, + 'img_pos_feat_teacher': img_pos_feat_teacher, + 'input_ids_teacher': input_ids_teacher, + 'position_ids_teacher': position_ids_teacher, + 'attn_masks_teacher': attn_masks_teacher + } + return batch + + +def itm_fast_collate(inputs): + input_ids, img_feats, img_pos_feats, img_input_ids, attn_masks_text, attn_masks_img, idx, img_fname, neg_imgs, neg_txts, caption_ids, attn_masks_captions = map(list, unzip(inputs)) + bs = len(input_ids) + # txt_lens = [i.size(0) for i in input_ids] + + if not None in neg_imgs: + num_bbs_neg = list(itertools.chain(*[n['num_bb'] for n in neg_imgs])) + img_feats_neg = list(itertools.chain(*[n['img_feat'] for n in neg_imgs])) + img_input_ids_neg = list(itertools.chain(*[n['img_input_ids'] for n in neg_imgs])) + img_pos_feat_neg = list(itertools.chain(*[n['img_pos_feat'] for n in neg_imgs])) + attn_masks_img_neg = list(itertools.chain(*[n['attn_masks_img'] for n in neg_imgs])) + caption_ids_neg = list(itertools.chain(*[n['caption_ids'] for n in neg_imgs])) + attn_masks_captions_neg = list(itertools.chain(*[n['attn_masks_captions'] for n in neg_imgs])) + + input_ids_neg = list(itertools.chain(*[n['input_ids'] for n in neg_txts])) + attn_masks_text_neg = list(itertools.chain(*[n['attention_mask'] for n in neg_txts])) + else: + num_bbs_neg = [] + img_feats_neg = [] + img_input_ids_neg = [] + img_pos_feat_neg = [] + attn_masks_img_neg = [] + caption_ids_neg = [] + attn_masks_captions_neg = [] + + input_ids_neg = [] + attn_masks_text_neg = [] + + input_ids = pad_sequence(input_ids+input_ids_neg, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + + captions_ids = pad_sequence(caption_ids+caption_ids_neg, batch_first=True, padding_value=0) if caption_ids[0] is not None else None + position_ids_captions = torch.arange(0, captions_ids.size(1), dtype=torch.long).unsqueeze(0) if caption_ids[0] is not None else None + + num_bbs = [f.size(0) for f in img_feats] + num_bbs_neg + img_feat = pad_tensors(img_feats+img_feats_neg, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats+img_pos_feat_neg, num_bbs) + + img_input_ids = pad_sequence(img_input_ids+img_input_ids_neg, batch_first=True, padding_value=0) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + + attn_masks_text = pad_sequence(attn_masks_text+attn_masks_text_neg, batch_first=True, padding_value=0) + attn_masks_captions = pad_sequence(attn_masks_captions+attn_masks_captions_neg, batch_first=True, padding_value=0) if attn_masks_captions[0] is not None else None + attn_masks_img = pad_sequence(attn_masks_img+attn_masks_img_neg, batch_first=True, padding_value=0) + sample_size = bs + # assert all(sample_size == len(i) for i in inputs) + + max_tl = input_ids.shape[1] + out_size = attn_masks_img.size(1) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + batch = { + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attn_masks_text, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': None, + 'gather_index': gather_index + }, + 'caps': { + 'input_ids': captions_ids, + 'position_ids': position_ids_captions, + 'attention_mask': attn_masks_captions, + 'img_feat': None, + 
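# A small sketch of the region-feature padding that pad_tensors performs for
# the collate above. This re-implements the behaviour only for illustration;
# the real helper lives in uniter_model.data.data and may differ in details.
import torch

def pad_region_feats(feats, lens):
    """feats: list of (num_bb_i, d) tensors -> (n, max(lens), d), zero padded."""
    n, d = len(feats), feats[0].size(-1)
    out = feats[0].new_zeros(n, max(lens), d)
    for i, (f, l) in enumerate(zip(feats, lens)):
        out[i, :l] = f
    return out

feats = [torch.randn(3, 2048), torch.randn(5, 2048)]
print(pad_region_feats(feats, [3, 5]).shape)   # torch.Size([2, 5, 2048])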
'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'sample_size': sample_size, + 'pos_ctx_indices': list(range(bs)), + 'neg_ctx_indices': list(range(bs, len(num_bbs))), + 'txt_index': idx, + 'img_fname': img_fname + } + return batch + + +class ItmValDataset(DetectFeatTxtTokDataset): + """ For evaluating Image-Text-Retrieval task """ + def __init__(self, db_dir, img_dir, mini_batch_size=400): + super().__init__(db_dir, img_dir) + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + + assert len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + neg_st = i+1 + neg_end = neg_st+self.bs-1 + if neg_end > len(self.all_img_ids): + # warp around + neg_end -= len(self.all_img_ids) + neg_img_ids = (self.all_img_ids[neg_st:] + + self.all_img_ids[:neg_end]) + else: + neg_img_ids = self.all_img_ids[neg_st:neg_end] + + assert len(neg_img_ids) == (self.bs - 1),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + gt_img_id, neg_img_ids = self._get_batch_ids(i) + # NOTE 1st one is gt img + batch = self.get_batch(i, [gt_img_id] + neg_img_ids) + return batch + + def get_batch(self, i, img_ids): + example = super().__getitem__(i) + + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone() + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks_text = torch.ones(len(img_ids), tl).long() + # attn_masks_text = torch.ones(1, tl).long() + attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long() + for i, nbb in enumerate(num_bbs): + attn_masks_img.data[i, :nbb].fill_(1) + + # out_size = attn_masks.size(1) + gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks_text': attn_masks_text, + 'attn_masks_img': attn_masks_img, + 'gather_index': gather_index} + return batch + + +# for VQA \ No newline at end of file diff --git a/dvl/data/itm_pre.py b/dvl/data/itm_pre.py new file mode 100644 index 0000000..aae769e --- /dev/null +++ b/dvl/data/itm_pre.py @@ -0,0 +1,592 @@ +""" +Itm dataset +""" +from collections import defaultdict +import copy +import json +import random + +import torch +from torch.nn.utils.rnn import pad_sequence +import numpy as np +from toolz.sandbox import unzip +from cytoolz import concat + +from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, + pad_tensors, get_gather_index, get_ids_and_lens) +from uniter_model.data.sampler import TokenBucketSampler + + +class TokenBucketSamplerForItm(TokenBucketSampler): + def __init__(self, dset, *args, **kwargs): + super().__init__(dset.lens, *args, **kwargs) + self.dset = dset + + def __iter__(self): + it = super().__iter__() + self.dset.new_epoch() + 
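# A toy run of the wrap-around negative selection in _get_batch_ids above
# (the image ids are made up). Each ground-truth image receives the next
# bs - 1 images in the list as fixed negatives, wrapping past the end.
all_img_ids = [f'img_{i}' for i in range(6)]
bs = 4
gt_img_id = 'img_4'
i = all_img_ids.index(gt_img_id)
neg_st, neg_end = i + 1, i + 1 + bs - 1
if neg_end > len(all_img_ids):
    neg_end -= len(all_img_ids)
    neg_img_ids = all_img_ids[neg_st:] + all_img_ids[:neg_end]
else:
    neg_img_ids = all_img_ids[neg_st:neg_end]
print(neg_img_ids)   # ['img_5', 'img_0', 'img_1']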
self._lens = self.dset.lens + return it + + +def _has_overlap(la, lb): + if len(la) < len(lb): + la, lb = lb, la + s = set(la) + return any(b in s for b in lb) + + +def _sample_negative_rand(sample_pool, ground_truths, num_sample): + """ random and retry """ + outputs = ground_truths[:1] + while _has_overlap(outputs, ground_truths): + outputs = random.sample(sample_pool, num_sample) + return outputs + + +def _sample_negative_extra(sample_pool, ground_truths, num_sample): + """ sample extra then remove """ + tot_size = len(ground_truths) + num_sample + outputs = set(random.sample(sample_pool, tot_size)) + for gt in ground_truths: + outputs.discard(gt) + outputs = list(outputs)[:num_sample] + return outputs + + +sample_negative = _sample_negative_rand # swith between 2 implementations + + +class ItmDataset(DetectFeatTxtTokDataset): + """ NOTE this Dataset handles distributed training itself + (for more efficient negative sampling) """ + def __init__(self, txt_db, img_db, neg_sample_p=0.0): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + + self.txt_db = txt_db + self.img_db = img_db + + self.txt_lens, self.ids = get_ids_and_lens(txt_db) + self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids)) + + self.neg_sample_p = neg_sample_p + self.new_epoch() + + def new_epoch(self): + """ should be called every epoch for more randomness""" + self.labels = np.random.choice( + [0, 1], size=len(self.ids), + p=[self.neg_sample_p, 1-self.neg_sample_p]) + + self.lens = [] + self.train_imgs = [] + for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)): + img_fname = super().__getitem__(i)['img_fname'] + if self.labels[i] == 0: + img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0] + self.train_imgs.append(img_fname) + self.lens.append(tl + self.img_db.name2nbb[img_fname]) + + def __getitem__(self, i): + example = super().__getitem__(i) + # labels and negative images should be sampled every epoch + ground_truth_label = self.labels[i] + img_fname = self.train_imgs[i] + img_input_ids = torch.Tensor([101]).long() + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + target = torch.Tensor(1).long() + target.data.fill_(ground_truth_label) + + return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target + + +def itm_collate(inputs): + (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets + ) = map(list, unzip(inputs)) + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + gather_index = 
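# A quick standalone illustration of the random negative sampling above
# (the image ids are made up; sample_negative is the _sample_negative_rand
# variant selected a few lines earlier).
import random

random.seed(0)
pool = [f'img_{i}' for i in range(10)]
ground_truths = ['img_3']
negs = sample_negative(pool, ground_truths, 2)
assert not set(negs) & set(ground_truths)   # negatives never overlap the GT
print(negs)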
get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + batch = { + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attn_masks, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': None, + 'gather_index': gather_index + }, + 'pos_ctx_indices': list(range(bs)), + 'neg_ctx_indices': list(range(bs, len(num_bbs))), + 'targets': targets + } + return batch + + +def _compute_ot_scatter(txt_lens, max_txt_len, joint_len): + ot_scatter = torch.arange(0, joint_len, dtype=torch.long + ).unsqueeze(0).repeat(len(txt_lens), 1) + for i, tl in enumerate(txt_lens): + max_ind = max_txt_len + (joint_len-tl) + ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind, + dtype=torch.long).data + return ot_scatter + + +def _compute_pad(lens, max_len): + pad = torch.zeros(len(lens), max_len, dtype=torch.bool) + for i, l in enumerate(lens): + pad.data[i, l:].fill_(1) + return pad + + +def itm_ot_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + # OT inputs + max_tl = max(txt_lens) + max_nbb = max(num_bbs) + ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1)) + txt_pad = _compute_pad(txt_lens, max_tl) + img_pad = _compute_pad(num_bbs, max_nbb) + ot_inputs = {'ot_scatter': ot_scatter, + 'scatter_max': ot_scatter.max().item(), + 'txt_pad': txt_pad, + 'img_pad': img_pad} + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets, + 'ot_inputs': ot_inputs} + return batch + + +class ItmRankDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDataset need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + # images partitioned by rank + self.img2txts = defaultdict(list) + for id_, img in self.txt2img.items(): + self.img2txts[img].append(id_) + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample negatives + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in 
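# A tiny check of the padding-mask helper above (the lengths are toy values):
# positions at or beyond each sequence length are marked True so the OT layer
# can ignore them.
pad = _compute_pad([2, 4], max_len=4)
print(pad)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])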
neg_sample_txt_ids]) + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + 2*self.neg_sample_size) + return inputs + + def _collect_inputs(self, id_pairs): + # create input features + inputs = [] + for txt_id, img_id in id_pairs: + example = self.txt_db[txt_id] + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id) + # mask + attn_masks_text = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb, dtype=torch.long) + + inputs.append((input_ids, img_feat, img_pos_feat, attn_masks_text, attn_masks_img)) + + return inputs + + +class ItmRankDatasetHardNeg(ItmRankDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1): + assert hard_neg_size > 0, \ + "ItmRankDatasetHardNeg need at least 1 hard negative sample" + DetectFeatTxtTokDataset.__init__(self, txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + self.hard_neg_size = hard_neg_size + + def reload_hard_negs(self, hard_neg_dir): + self.txt2hardimgs = json.load( + open(f'{hard_neg_dir}/' + f'txt2hardimgs_rank{hvd.rank()}.json')) + self.img2hardtxts = json.load( + open(f'{hard_neg_dir}/img2hardtxts.json')) + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample hard negatives + if self.hard_neg_size > 0: + hard_neg_img_samples = random.sample( + self.txt2hardimgs[gt_txt_id], self.hard_neg_size) + hard_neg_txt_samples = random.sample( + self.img2hardtxts[gt_img_fname], self.hard_neg_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in hard_neg_img_samples] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in hard_neg_txt_samples]) + # sample normal negatives + if self.neg_sample_size > 0: + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in neg_sample_txt_ids]) + + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + + 2*self.neg_sample_size + + 2*self.hard_neg_size) + return inputs + + +def itm_rank_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img, + ) = map(list, unzip(concat(i for i in inputs))) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + sample_size = len(inputs[0]) + assert all(sample_size == len(i) for i in inputs) + + bs, max_tl = input_ids.size() + # out_size = attn_masks.size(1) + gather_index = None # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': 
position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks_text': attn_masks_text, + 'attn_masks_img': attn_masks_img, + 'gather_index': gather_index, + 'sample_size': sample_size} + return batch + + +class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDatasetHardNegV2 need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + input_ids = self.txt_db[gt_txt_id]['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + neg_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + img_ids = [gt_img_fname] + neg_img_ids + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDatasetHardNegV2 need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.txt_name_list = list(self.txt2img.keys()) + self.neg_sample_size = neg_sample_size + + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + gt_txt_ids = self.img2txts[gt_img_id] + + # process image features (gt always first) + img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id) + img_feat = img_feat.unsqueeze(0) + img_pos_feat = img_pos_feat.unsqueeze(0) + + # sample negative + neg_txt_ids = sample_negative( + self.txt_name_list, gt_txt_ids, self.neg_sample_size) + txt_ids = [gt_txt_id] + neg_txt_ids + + # process text inputs + all_inputs = [] + txt_lens = [] + for txt_id in txt_ids: + input_ids = self.txt_db.combine_inputs( + self.txt_db[txt_id]['input_ids']) + all_inputs.append(input_ids) + txt_lens.append(len(input_ids)) + input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long() + for i, tl in enumerate(txt_lens): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids), + len(txt_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, 
+ 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +def itm_rank_hnv2_collate(inputs): + assert len(inputs) == 1 + return inputs[0] + + +class ItmValDataset(DetectFeatTxtTokDataset): + """ For evaluating Image-Text-Retrieval task """ + def __init__(self, db_dir, img_dir, mini_batch_size=400): + super().__init__(db_dir, img_dir) + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + + assert len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + neg_st = i+1 + neg_end = neg_st+self.bs-1 + if neg_end > len(self.all_img_ids): + # warp around + neg_end -= len(self.all_img_ids) + neg_img_ids = (self.all_img_ids[neg_st:] + + self.all_img_ids[:neg_end]) + else: + neg_img_ids = self.all_img_ids[neg_st:neg_end] + + assert len(neg_img_ids) == (self.bs - 1),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + gt_img_id, neg_img_ids = self._get_batch_ids(i) + # NOTE 1st one is gt img + batch = self.get_batch(i, [gt_img_id] + neg_img_ids) + return batch + + def get_batch(self, i, img_ids): + example = super().__getitem__(i) + + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone() + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks_text = torch.ones(len(img_ids), tl).long() + # attn_masks_text = torch.ones(1, tl).long() + attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long() + for i, nbb in enumerate(num_bbs): + attn_masks_img.data[i, :nbb].fill_(1) + + # out_size = attn_masks.size(1) + gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks_text': attn_masks_text, + 'attn_masks_img': attn_masks_img, + 'gather_index': gather_index} + return batch + + +def itm_val_collate(inputs): + assert len(inputs) == 1, "input batch size > 1" + return inputs[0] + + +class ItmHardNegDataset(ItmValDataset): + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + all_img_ids = copy.deepcopy(self.all_img_ids) + all_img_ids.remove(gt_img_id) + random.shuffle(all_img_ids) + neg_img_ids = all_img_ids[:self.bs] + + assert len(neg_img_ids) == (self.bs),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + _, neg_img_ids = self._get_batch_ids(i) + batch = self.get_batch(i, neg_img_ids) + batch['gt_txt_id'] = self.ids[i] + batch['neg_img_ids'] = neg_img_ids + return batch + + +itm_hn_collate = itm_val_collate + + +class ItmEvalDataset(ItmValDataset): + def 
__init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + + def __getitem__(self, i): + mini_batches = [] + for st in range(0, len(self.all_img_ids), self.bs): + mini_batches.append( + self.get_batch(i, self.all_img_ids[st:st+self.bs])) + return mini_batches + + +itm_eval_collate = itm_val_collate diff --git a/dvl/data/mlm.py b/dvl/data/mlm.py new file mode 100644 index 0000000..718b491 --- /dev/null +++ b/dvl/data/mlm.py @@ -0,0 +1,390 @@ +""" +MLM datasets +""" +import math +import random + +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip + +from uniter_model.data.data import (DetectFeatTxtTokDataset, TxtTokLmdb, get_ids_and_lens, pad_tensors, + get_gather_index, get_gather_index_uniter) + + +def random_word(tokens, vocab_range, mask): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. + :param vocab_range: for choosing a random word + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + output_label = [] + + for i, token in enumerate(tokens): + prob = random.random() + # mask token with 15% probability + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + tokens[i] = mask + + # 10% randomly change token to random token + elif prob < 0.9: + tokens[i] = random.choice(list(range(*vocab_range))) + + # -> rest 10% randomly keep current token + + # append current token to output (we will predict these later) + output_label.append(token) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + if all(o == -1 for o in output_label): + # at least mask 1 + output_label[0] = tokens[0] + tokens[0] = mask + + return tokens, output_label + + +class MlmDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + super().__init__(txt_db, img_db) + + def __getitem__(self, i): + """ + Return: + - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - img_feat : (num_bb, d) + - img_pos_feat : (num_bb, 7) + - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] + - txt_labels : (L, ), [-1, -1, wid, -1, -1, -1] + 0's padded so that (L + num_bb) % 8 == 0 + """ + example = super().__getitem__(i) + + # text input + input_ids, txt_labels = self.create_mlm_io(example['input_ids']) + + # img input + img_input_ids = torch.Tensor([101]).long() + img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname']) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, txt_labels, attn_masks_teacher + + def create_mlm_io(self, input_ids): + input_ids, txt_labels = random_word(input_ids, + self.txt_db.v_range, + self.txt_db.mask) + input_ids = torch.tensor([self.txt_db.cls_] + + input_ids + + [self.txt_db.sep]) + txt_labels = torch.tensor([-1] + txt_labels + [-1]) + return input_ids, txt_labels + + +def mlm_collate(inputs): + """ + Return: + :input_ids (n, max_L) padded with 0 + :position_ids (n, max_L) padded with 0 + :txt_lens list of [txt_len] + :img_feat (n, max_num_bb, feat_dim) + 
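# A standalone illustration of the BERT-style masking in random_word above.
# The wordpiece ids and vocab range are toy values; the real ids come from
# the text LMDB used elsewhere in this patch, and 103 is BERT's [MASK] id.
import random

random.seed(0)
tokens = [1996, 3899, 2001, 2652, 1012]
masked, labels = random_word(list(tokens), vocab_range=(1000, 30000), mask=103)
# labels[i] == -1 where the token was left untouched (ignored by the loss);
# otherwise it holds the original id the model must predict at position i.
print(masked, labels)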
:img_pos_feat (n, max_num_bb, 7) + :num_bbs list of [num_bb] + :attn_masks (n, max_{L + num_bb}) padded with 0 + :txt_labels (n, max_L) padded with -1 + """ + (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, txt_labels, attn_masks_teacher + ) = map(list, unzip(inputs)) + + # text batches + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0) + gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1)) + + batch = { + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attn_masks, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': None, + 'gather_index': gather_index + }, + 'txt_labels': txt_labels, + 'teacher': { + 'txt_lens': txt_lens, + 'num_bbs': num_bbs, + 'bs': bs, + 'max_tl': max_tl, + 'out_size': out_size, + 'gather_index': gather_index_teacher, + 'attn_masks': attn_masks_teacher + } + } + return batch + + +class BlindMlmDataset(Dataset): + def __init__(self, txt_db): + assert isinstance(txt_db, TxtTokLmdb) + self.txt_db = txt_db + self.lens, self.ids = get_ids_and_lens(txt_db) + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + input_ids, txt_labels = self.create_mlm_io(example['input_ids']) + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + + return input_ids, attn_masks, txt_labels + + +def mlm_blind_collate(inputs): + input_ids, attn_masks, txt_labels = map(list, unzip(inputs)) + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'attn_masks': attn_masks, + 'txt_labels': txt_labels} + return batch + + +def eval_mask(len_, num_samples=7): + """ build the mask for evaluating MLM + circularly mask 1 word out of every x words + """ + # build the random masks + if len_ <= num_samples: + masks = torch.eye(len_).bool() + num_samples = len_ + else: + mask_inds = [list(range(i, len_, num_samples)) + for i in range(num_samples)] + masks = 
torch.zeros(num_samples, len_).bool() + for i, indices in enumerate(mask_inds): + for j in indices: + masks.data[i, j] = 1 + assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0 + assert masks.sum().item() == len_ + return masks + + +def eval_gather_inds(len_, num_samples=7): + """ get the gather indices """ + inds = torch.arange(0, num_samples, dtype=torch.long) + mul = math.ceil(len_ / num_samples) + output = inds.repeat(mul)[:len_] + return output + + +def stack_pad_tensors(tensors, lens=None, ns=None, pad=0): + """N x [B_i, T, ...]""" + if ns is None: + ns = [t.size(0) for t in tensors] + if lens is None: + lens = [t.size(1) for t in tensors] + max_len = max(lens) + bs = sum(ns) + hid_dims = tensors[0].size()[2:] + dtype = tensors[0].dtype + output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype) + if pad: + output.data.fill_(pad) + i = 0 + for t, l, n in zip(tensors, lens, ns): + output.data[i:i+n, :l, ...] = t.data + i += n + return output + + +def expand_tensors(tensors, ns): + return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim())) + for t, n in zip(tensors, ns)] + + +class MlmEvalDataset(DetectFeatTxtTokDataset): + """ For evaluating MLM training task """ + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + super().__init__(txt_db, img_db) + + def __getitem__(self, i): + example = super().__getitem__(i) + + # text input + (input_ids, txt_labels, gather_inds + ) = self.create_mlm_eval_io(example['input_ids']) + + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long) + + return (input_ids, img_feat, img_pos_feat, attn_masks, + txt_labels, gather_inds) + + def create_mlm_eval_io(self, input_ids): + txt_labels = torch.tensor(input_ids) + masks = eval_mask(len(input_ids)) + n_mask = masks.size(0) + masks = torch.cat([torch.zeros(n_mask, 1).bool(), + masks, + torch.zeros(n_mask, 1).bool()], + dim=1) + input_ids = torch.tensor([[self.txt_db.cls_] + + input_ids + + [self.txt_db.sep] + for _ in range(n_mask)]) + input_ids.data.masked_fill_(masks, self.txt_db.mask) + gather_inds = eval_gather_inds(len(txt_labels)) + return input_ids, txt_labels, gather_inds + + +def _batch_gather_tgt(gather_inds, n_masks): + gather_tgts = [] + offset = 0 + for g, n in zip(gather_inds, n_masks): + gather_tgts.append(g + offset) + offset += n + gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0) + return gather_tgt + + +def mlm_eval_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds + ) = map(list, unzip(inputs)) + + # sizes + n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids)) + + # text batches + input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + gather_tgt = _batch_gather_tgt(gather_inds, n_masks) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks), + num_bbs, n_masks) + img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks), + num_bbs, n_masks) + + bs, max_tl = input_ids.size() + attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks), + None, n_masks) + out_size = attn_masks.size(1) + # repeat txt_lens, num_bbs + txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)] + num_bbs = [b for b, 
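# A quick check of the circular evaluation mask above (toy sizes): every
# position ends up masked in exactly one of the num_samples rows.
m = eval_mask(5, num_samples=2)
print(m.int())
# tensor([[1, 0, 1, 0, 1],
#         [0, 1, 0, 1, 0]])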
n in zip(num_bbs, n_masks) for _ in range(n)] + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'gather_tgt': gather_tgt, + 'txt_labels': txt_labels} + return batch + + +class BlindMlmEvalDataset(Dataset): + def __init__(self, txt_db): + assert isinstance(txt_db, TxtTokLmdb) + self.txt_db = txt_db + self.lens, self.ids = get_ids_and_lens(txt_db) + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + input_ids = example['input_ids'] + + # text input + input_ids = example['input_ids'] + (input_ids, txt_labels, gather_inds + ) = self.txt_db.create_mlm_eval_io(input_ids) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + + return input_ids, attn_masks, txt_labels, gather_inds + + +def mlm_blind_eval_collate(inputs): + (input_ids, position_ids, attn_masks, txt_labels, gather_inds + ) = map(list, unzip(inputs)) + + # sizes + n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids)) + + # text batches + input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks), + None, n_masks) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + gather_tgt = _batch_gather_tgt(gather_inds, n_masks) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'attn_masks': attn_masks, + 'gather_tgt': gather_tgt, + 'txt_labels': txt_labels} + return batch diff --git a/dvl/data/mrm.py b/dvl/data/mrm.py new file mode 100644 index 0000000..16516bd --- /dev/null +++ b/dvl/data/mrm.py @@ -0,0 +1,263 @@ +""" +MRM Datasets +""" +import random + +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index, get_gather_index_uniter + + +def _get_img_mask(mask_prob, num_bb): + img_mask = [random.random() < mask_prob for _ in range(num_bb)] + if not any(img_mask): + # at least mask 1 + img_mask[random.choice(range(num_bb))] = True + img_mask = torch.tensor(img_mask) + return img_mask + + +def _get_img_tgt_mask(img_mask, txt_len): + z = torch.zeros(txt_len, dtype=torch.bool) + img_mask_tgt = torch.cat([z, img_mask], dim=0) + return img_mask_tgt + + +def _get_feat_target(img_feat, img_masks): + img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) + feat_dim = img_feat.size(-1) + feat_targets = img_feat[img_masks_ext].contiguous().view( + -1, feat_dim) # (s, d) + return feat_targets + + +def _mask_img_feat(img_feat, img_masks): + img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) + img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) + return img_feat_masked + + +class MrfrDataset(DetectFeatTxtTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + + def __getitem__(self, i): + """ + Return: + - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - img_feat : (num_bb, d) + - img_pos_feat : (num_bb, 7) + - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] + - img_mask : (num_bb, ) between {0, 1} + """ + example = super().__getitem__(i) + # text input + 
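# A small standalone check of the image-feature masking helpers above
# (the feature dimension and box count are toy values).
import torch

img_feat = torch.arange(24, dtype=torch.float).view(1, 4, 6)   # (n, num_bb, d)
img_masks = torch.tensor([[True, False, True, False]])
feat_targets = _get_feat_target(img_feat, img_masks)   # rows 0 and 2 -> (2, 6)
masked = _mask_img_feat(img_feat, img_masks)           # masked rows zeroed out
print(feat_targets.shape, masked[0, 0].sum().item())   # torch.Size([2, 6]) 0.0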
input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + # image input features + img_input_ids = torch.Tensor([101]).long() + img_feat, img_pos_feat, num_bb = self._get_img_feat(example['img_fname']) + img_mask = _get_img_mask(self.mask_prob, num_bb) + img_mask_tgt = _get_img_tgt_mask(img_mask, 1) + img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids)) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, + img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher) + + +def mrfr_collate(inputs): + """ + Return: + - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - position_ids : (n, max_L) + - txt_lens : list of [input_len] + - img_feat : (n, max_num_bb, d) + - img_pos_feat : (n, max_num_bb, 7) + - num_bbs : list of [num_bb] + - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1] + - img_masks : (n, max_num_bb) between {0, 1} + """ + (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_masks, img_mask_tgts, + attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + feat_targets = _get_feat_target(img_feat, img_masks) + img_feat = _mask_img_feat(img_feat, img_masks) + img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0) + img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0) + gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1)) + + batch = { + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attn_masks, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': img_masks, + 'gather_index': gather_index + }, + 'teacher': { + 'txt_lens': txt_lens, + 'num_bbs': num_bbs, + 'bs': bs, + 'max_tl': max_tl, + 'out_size': out_size, + 'gather_index': gather_index_teacher, + 'attn_masks': attn_masks_teacher, + 'img_mask_tgt': img_mask_tgt_teacher, + }, + 'feat_targets': feat_targets, + 'img_mask_tgt': img_mask_tgt} + return 
batch + + +def _get_targets(img_masks, img_soft_label): + soft_label_dim = img_soft_label.size(-1) + img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label) + label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view( + -1, soft_label_dim) + return label_targets + + +class MrcDataset(DetectFeatTxtTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + + def _get_img_feat(self, fname): + img_dump = self.img_db.get_dump(fname) + num_bb = self.img_db.name2nbb[fname] + img_feat = torch.tensor(img_dump['features']) + bb = torch.tensor(img_dump['norm_bb']) + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + img_soft_label = torch.tensor(img_dump['soft_labels']) + return img_feat, img_bb, img_soft_label, num_bb + + def __getitem__(self, i): + example = super().__getitem__(i) + img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat( + example['img_fname']) + + # image input features + img_input_ids = torch.Tensor([101]).long() + img_mask = _get_img_mask(self.mask_prob, num_bb) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + img_mask_tgt = _get_img_tgt_mask(img_mask, 1) + img_mask_tgt_teacher = _get_img_tgt_mask(img_mask, len(input_ids)) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + attn_masks_teacher = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return (input_ids, attn_masks, img_input_ids, img_feat, img_pos_feat, attn_masks_img, + img_soft_labels, img_mask, img_mask_tgt, attn_masks_teacher, img_mask_tgt_teacher) + + +def mrc_collate(inputs): + (input_ids, attn_masks, img_input_ids, img_feats, img_pos_feats, attn_masks_img, img_soft_labels, + img_masks, img_mask_tgts, attn_masks_teacher, img_mask_tgt_teacher) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + img_feat = pad_tensors(img_feats, num_bbs) + img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + img_soft_label = pad_tensors(img_soft_labels, num_bbs) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + label_targets = _get_targets(img_masks, img_soft_label) + + img_feat = _mask_img_feat(img_feat, img_masks) + img_mask_tgt = pad_sequence(img_mask_tgts, batch_first=True, padding_value=0) + img_mask_tgt_teacher = pad_sequence(img_mask_tgt_teacher, batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + # gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + attn_masks_teacher = pad_sequence(attn_masks_teacher, batch_first=True, padding_value=0) + gather_index_teacher = get_gather_index_uniter(txt_lens, num_bbs, bs, max_tl, attn_masks_teacher.size(1)) + + batch = { + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': 
attn_masks, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': img_masks, + 'gather_index': gather_index + }, + 'teacher': { + 'txt_lens': txt_lens, + 'num_bbs': num_bbs, + 'bs': bs, + 'max_tl': max_tl, + 'out_size': out_size, + 'gather_index': gather_index_teacher, + 'attn_masks': attn_masks_teacher, + 'img_mask_tgt': img_mask_tgt_teacher, + }, + 'img_mask_tgt': img_mask_tgt, + 'label_targets': label_targets} + return batch + diff --git a/dvl/data/vqa.py b/dvl/data/vqa.py new file mode 100644 index 0000000..f4b18ca --- /dev/null +++ b/dvl/data/vqa.py @@ -0,0 +1,145 @@ +""" +VQA dataset +""" +import torch +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip + +from uniter_model.data.data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index + + +def _get_vqa_target(example, num_answers): + target = torch.zeros(num_answers) + labels = example['target']['labels'] + scores = example['target']['scores'] + if labels and scores: + target.scatter_(0, torch.tensor(labels), torch.tensor(scores)) + return target + + +class VqaDataset(DetectFeatTxtTokDataset): + """ NOTE: This handels distributed inside """ + def __init__(self, num_answers, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_answers = num_answers + + def __getitem__(self, i): + example = super().__getitem__(i) + qid = self.ids[i] + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + img_input_ids = torch.Tensor([101]).long() + + target = _get_vqa_target(example, self.num_answers) + + attn_masks_txt = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb+1, dtype=torch.long) + + return qid, input_ids, attn_masks_txt, img_input_ids, img_feat, img_pos_feat, attn_masks_img, target + + +def vqa_collate(inputs): + (qids, input_ids, attn_masks_txt, img_input_ids, img_feats, img_pos_feats, attn_masks_img, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks_txt = pad_sequence(attn_masks_txt, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + targets = torch.stack(targets, dim=0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_input_ids = pad_sequence(img_input_ids, batch_first=True, padding_value=0) + img_position_ids = torch.arange(0, img_input_ids.size(1), dtype=torch.long).unsqueeze(0) + + bs, max_tl = input_ids.size() + out_size = attn_masks_img.size(1) + gather_index_teacher = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + gather_index = get_gather_index([1]*bs, num_bbs, bs, 1, out_size) + + batch = {'qids': qids, + 'txts': { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attn_masks_txt, + 'img_feat': None, + 'img_pos_feat': None, + 'img_masks': None, + 'gather_index': None + }, + 'imgs': { + 'input_ids': img_input_ids, + 'position_ids': img_position_ids, + 'attention_mask': attn_masks_img, + 
'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'img_masks': None, + 'gather_index': gather_index + }, + 'gather_index_teacher': gather_index_teacher, + 'targets': targets} + return batch + + +class VqaEvalDataset(VqaDataset): + def __getitem__(self, i): + qid = self.ids[i] + example = DetectFeatTxtTokDataset.__getitem__(self, i) + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + if 'target' in example: + target = _get_vqa_target(example, self.num_answers) + else: + target = None + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return qid, input_ids, img_feat, img_pos_feat, attn_masks, target + + +def vqa_eval_collate(inputs): + (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + if targets[0] is None: + targets = None + else: + targets = torch.stack(targets, dim=0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'qids': qids, + 'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch diff --git a/dvl/hn.py b/dvl/hn.py new file mode 100644 index 0000000..d57e061 --- /dev/null +++ b/dvl/hn.py @@ -0,0 +1,66 @@ +import random +import logging +import collections +import json +import os +import itertools +import numpy as np +from collections import ChainMap + + +from dvl.trainer import build_dataloader, _save_checkpoint, eval_model_on_dataloader, load_dataset + + +logger = logging.getLogger() + + +def random_hard_neg(fname2id, num_hard_negatives, id2set, set2id): + # num_hard_negatives must be very small + hard_negs = dict() + for i in fname2id: + while True: + hard_neg = random.choices(set2id[id2set[i]], k=num_hard_negatives) + if fname2id[i] not in hard_neg: + break + hard_negs[i] = hard_neg + return hard_negs + + +def get_img_txt_mappings(train_txt_dbs): + train_img2txt = dict(ChainMap(*[json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs])) + train_txt2img = dict(itertools.chain(*[[(v, k) for v in vals] for k, vals in train_img2txt.items()])) + + train_json = [json.load(open(os.path.join(db_folder, 'img2txts.json'))) for db_folder in train_txt_dbs] + train_img2set = dict(ChainMap(*[{k:v for k in tj } for tj, v in zip(train_json, train_txt_dbs)])) + train_txt2set = {txt_id: train_img2set[img_id] for txt_id, img_id in train_txt2img.items()} + + train_set2img, train_set2txt = collections.defaultdict(list), collections.defaultdict(list) + for img_id, set_id in train_img2set.items(): + train_set2img[set_id].append(img_id) + train_set2txt[set_id] += train_img2txt[img_id] + + return train_img2txt, train_txt2img, train_img2set, train_txt2set, train_set2img, train_set2txt + + +def sampled_hard_negatives(all_img_dbs, args, collate_func, bi_encoder, train_img2txt, train_txt2img): + 
train_dataset_eval = load_dataset(all_img_dbs, args.train_txt_dbs, args.train_img_dbs, args, True) + hard_negs_txt_all, hard_negs_img_all = [], [] + for dset in train_dataset_eval.datasets: + dset.new_epoch() + train_dataloader_hn = build_dataloader(dset, collate_func, True, args, args.valid_batch_size) + logger.info(f'eval for train dataloader len (for hn) = {len(train_dataloader_hn)}') + + num_hard_sampled = min(max(args.num_hard_negatives * 2 + 10, 50), 1000) + loss_hard, correct_ratio_hard, indexer_hard, recall_hard, (hard_neg_img, hard_neg_txt) = \ + eval_model_on_dataloader(bi_encoder, train_dataloader_hn, args, train_img2txt, num_hard_sampled) + + [v.remove(train_txt2img[k]) for k, v in hard_neg_img.items() if train_txt2img[k] in v] + hard_neg_txt = {k: list(set(v) - set(train_img2txt[k])) for k, v in hard_neg_txt.items()} + + + # remove self in hard negatives as they are labels + hard_negs_txt_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_txt.items()}) + hard_negs_img_all.append({k: random.sample(v, args.num_hard_negatives) for k, v in hard_neg_img.items()}) + hard_negs_txt_all = dict(collections.ChainMap(*hard_negs_txt_all)) + hard_negs_img_all = dict(collections.ChainMap(*hard_negs_img_all)) + return hard_negs_txt_all, hard_negs_img_all diff --git a/dvl/indexer/faiss_indexers.py b/dvl/indexer/faiss_indexers.py new file mode 100644 index 0000000..8229cb1 --- /dev/null +++ b/dvl/indexer/faiss_indexers.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + FAISS-based index components for dense retriver +""" + +import logging +import pickle +from typing import List, Tuple + +import faiss +import numpy as np + +logger = logging.getLogger() + + +class DenseIndexer(object): + + def __init__(self, buffer_size: int = 50000): + self.buffer_size = buffer_size + self.index_id_to_db_id = [] + self.index = None + + def index_data(self, data: List[Tuple[object, np.array]]): + raise NotImplementedError + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + raise NotImplementedError + + def serialize(self, file: str): + logger.info('Serializing index to %s', file) + + index_file = file + '.index.dpr' + meta_file = file + '.index_meta.dpr' + + faiss.write_index(self.index, index_file) + with open(meta_file, mode='wb') as f: + pickle.dump(self.index_id_to_db_id, f) + + def deserialize_from(self, file: str): + logger.info('Loading index from %s', file) + + index_file = file + '.index.dpr' + meta_file = file + '.index_meta.dpr' + + self.index = faiss.read_index(index_file) + logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) + + with open(meta_file, "rb") as reader: + self.index_id_to_db_id = pickle.load(reader) + assert len( + self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' + + def _update_id_mapping(self, db_ids: List): + self.index_id_to_db_id.extend(db_ids) + + +class DenseFlatIndexer(DenseIndexer): + + def __init__(self, vector_sz: int, buffer_size: int = 50000): + super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) + self.index = faiss.IndexFlatIP(vector_sz) + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + # indexing in batches is beneficial for many faiss index types + 
for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i:i + self.buffer_size]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] + vectors = np.concatenate(vectors, axis=0) + self._update_id_mapping(db_ids) + self.index.add(vectors) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info('Total data indexed %d', indexed_cnt) + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + scores, indexes = self.index.search(query_vectors, top_docs) + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + +class DenseHNSWFlatIndexer(DenseIndexer): + """ + Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512 + , ef_search: int = 128, ef_construction: int = 200): + super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) + + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWFlat(vector_sz + 1, store_n) + index.hnsw.efSearch = ef_search + index.hnsw.efConstruction = ef_construction + self.index = index + self.phi = 0 + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + + # max norm is required before putting all vectors in the index to convert inner product similarity to L2 + if self.phi > 0: + raise RuntimeError('DPR HNSWF index needs to index all data at once,' + 'results will be unpredictable otherwise.') + phi = 0 + for i, item in enumerate(data): + id, doc_vector = item + norms = (doc_vector ** 2).sum() + phi = max(phi, norms) + logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi)) + self.phi = 0 + + # indexing in batches is beneficial for many faiss index types + for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i:i + self.buffer_size]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] + + norms = [(doc_vector ** 2).sum() for doc_vector in vectors] + aux_dims = [np.sqrt(phi - norm) for norm in norms] + hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in + enumerate(vectors)] + hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) + + self._update_id_mapping(db_ids) + self.index.add(hnsw_vectors) + logger.info('data indexed %d', len(self.index_id_to_db_id)) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info('Total data indexed %d', indexed_cnt) + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + + aux_dim = np.zeros(len(query_vectors), dtype='float32') + query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) + # logger.info('query_hnsw_vectors %s', query_nhsw_vectors.shape) + scores, indexes = self.index.search(query_nhsw_vectors, top_docs) + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def deserialize_from(self, file: str): + super(DenseHNSWFlatIndexer, self).deserialize_from(file) + # to trigger warning on subsequent indexing + self.phi = 1 diff --git a/dvl/models/__init__.py b/dvl/models/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/dvl/models/__pycache__/__init__.cpython-38.pyc b/dvl/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e9b7d32cf8cc325a759c60c1e1897b697540085 GIT binary patch literal 178 zcmWIL<>g`kf~QR@lR)%i5P=LBfgA@QE@lA|DGb33nv8xc8Hzx{2;x_`erR!OQL%ni zW=>9KmA+m{etAY}s(wakl741xVtT4>NoqxjZfb5)YD!9GUb=oxW_m_R9*~uiU!tFz tn3<=aQkJ8io1c=JQ>-5!pP83g5+AQuPADlV53$tl}OWG7J+J5dt5;&M`xN|j1Mt)?M+KOgZM6QM&jA#?$)d~D{-s2r!{RErqP=7 z_DYV^oNw*(_DMX~+~2y*yRCJ=JJ7n_yIu0!X1R5TcZbCD%{yCnd3UuAdIwu?^WG+T zh34I@d%SxjUTogms(6*wecpYo`@Q=muY_j~d52nuy~C{qZ$a`V@CHY`Bc>6gwvKKa z-m&1=)&oIm_3hgx@{VmixNZE5sl4FW>N~a`+Qw*6@13akFv{K)Jj`1D+k+YY`BPU@ z{5IsKSze7qTKX^JgiN9xp zXNP&qmVfV6$FKPJq4slV{eJ%tT0ay#jo#ktKja_w7cQIL`%rSkKZ=s0Nx!H3WBvn| z4exp1dcRS5XZKFG{9_GqeR_SZ z86cAPgXP+KGx8ht2uZ8mMwmGrw8M_NX2#B`T5HL#oq8F$xl^@Efm-aSRwWhZo(NR* zJQ{7rE~8I&+RKfVco*X@s^By`4}7VBN7Ik{wbm{68Hfv$n zSZ>s75mNCUoqIBZ#tvAu9eX0>gC;6V|s^^1x)KOua zuSXXLiHr)+PP?Asj3&PNJ*|$^M)<&vlsrN?4YN}`WDc`&;pegynOGk7HEju-?Jz zl$SfI9Gnl-rLyW=C}Z`9BbEmrS*xj9tK5I3M@BP;Rg-M}f;yp0yv=tdxQ^d7Gfw%9 zRwW&$F*Efuai&#^FmZ7XK-Uh}I$;poZ2(G~TEUWJ)f^sIdlAH$dM5;A#HIqQ9mDm2 zmu^=(>(Mns-HQr8`0$YrgnRdPz1V_%SFCA&LYAeCvD7Y9M4q8ir z?>E{jN1Bb5Gf^8czY`s40qTzU=bJ}bgv{{BQlpxH{=(X&cp6ozYpUa~*Q2VQvWL)R zSVCYpv!-p9@K-P$Q{9K0YgX*^r%DE22mT)X7*FGp#hye2i-^V63w?#lm0DE$kgsk- zP_bh3WIUx`z8W@Kjb=?XqD$(2*5Hflpn_Oj>{pXjeFj;f^lYgkh}UtA{RndWg)bur z4A2z=G^G@zeCw+DX=7Oc>vGYvgS0;pWbl^?vVr4gfTOa$b2+_ff|gABuAc{Y1w`bZ71G-YD%`o8~9Y9th!CTd63smBr<-hSW;U7~58q+cYA#pZ0A(eZ}h6 z@=>E_ZCl$mTF$pKzQKFX*eZ}P_}D+!GK(_s?j6Z%l+QGLKWJkqx&by+`<|%06qLJx z>U@`T{hg`*diWEc|J=Fv5Y&UUuu_cE;rd#jVy9Yd0EbnpiZ2=4OPHECKhj!UdOQRK z6X`syR7WYB_1yOd=hlOE)TlM%%qi@Eno?6JP+Yw@uc4H{xf&Ow{R;sBJG=&@7Y5B` zZX9A*^#JN)P`PTgP6UT>2=owe)VmNok1OPUGXTl1Ic+(B>uFq*=H%;6He+^YMpkp7 z|2E(G#n1Ein~z;f$4-5DrP{8wf_S2SCa9lX>onTfIJM}E^d4t@)mg)qRrBbj&P~w& z@ntf&-iIsfA`lE{szT3fV|BMpU^Lse&ZgAez;KzWB_S}PZ=B7kk41J5kbrU<@$^;( zYie#}d)!#&ve_dx%=)RTsb2bX)=!zhiX$byebwscbHyv^#UD&@aSkul!r<_FyN;VL zKu}33?ppmaZWsZR0*XQZtYJQo7WuY#Vp(-sSeHP^Y2e5EYY~VASC^}`6jqbTkW{=Y z0fe#4R5NI++u54k?9}{gPMoSYL-j1j@LYc&>b!t+Hf{<8*Y2vH#XcGN-dW593Ok_CoCVU$@_9@uYp`pI<M|Gt%qn7NU8izPvBxkjtN`i*g)Q#~ z7vqUgHgP>5!9ss=ag9ueKS+-EzhmjLAkSD`zlFZTG6Dm8Gh!t|{5a}`gcbv%bIfrV8?9I|^)|EQrcUCCioJ**uqe!=LbZBsz1B=hQb0F?%fNCG$F+j!OvjH?VWi^pGPY18xL17;4Xc2`DuTF}pu~F6 zZ1&Njyw6jtA_(#*V%$|B=gKi>CM`^}b-g5e2wB(NX~jBqTw}lfrU_p~_JGDt-jv2p z-IB&m2X2t}-DPX9;peZKLTS@p!7up}ptwb$xFw*jU2#$Kz{%R{c|3-$Z%t``5~VvS zZIs%yKnY1jAxwhS-WO%YXl)c_w;WtK){t77*1Djz`F=W~wMnfpT3bkHZ2|kc2&gNG z*ahsX-j3l^Ce*{Y$xhQYTpSI0IQs zS=Ir``xNpAB(HEYk~cdxodcpLpsC1U-9W`i)r5*|A@b*1xmy+(0s}vQ?mE3?DmWR}EM?Em#S8^N9 zfF{~E&_s1pnwVZZUCC)eMH;24)yEjCGI$k%C)BD>B~=Ys@!o_0)`0r-0{hh_#r5!{ zMjk)Hr^zquoKxb@FX0LwL@+{;G+ko&{Zf8Fn?`aoZbGFJnf@<&TvMAjjffDF&k|>LCT&{OoYd{q3XrYXP z+FICGDXc?*a0J3iNLH&11c0h-O5I1+K#U=R_xF9cAPNKxVbGOKEH^Gz$vXvFRMZ*K z#i*Yl9QPQi&EReM>M6c@0jx}9L=X=*tu2e9Ta?;@kSGX{F`z{g5;llg>a3-niZZKN zq$nunAXeB{pi-&2?dd4*3o!$t&3rLM(Q-#S*`%F<@1UK}Mn!3-gmxwd?M${GM7kh7 z{c6(F&Km5Aq_Z&W3CO~zP5~Jh)oGv&qq?h?Qhy3mGPyOgZ9HXs)ZIw?ZZBQkeeRn* z3mQxZW$q1SuZswiU+nt#{eSp%r${WILn#xh3wU(%np}wV1Oz{jQ@V#l4bz8#vfqK? zsojan5Rcb7Dk^u#Fqes*!t%Y{Tt%owyK`aK2_PcrBsUPyy7w_R0h~8!^kijLGZvo! z)aaetW$Z^0cqL)pss0f`LUx-Fj(BodqGf4acS!Cai;(fxO3e?}n(KiG@Lq`^EZMrr zqwHsLnAvZQIisx&YN$_gc$*AJBV$+eZ~+vdVO&&@DH>qLt1GJJS2Dtf$2O(FosXsE zv72!~zmHbKLkJ8guw40@v!)>m@^6k8QVYav6bXkz@2F! 
zJ;X-ZllOy2`l5bAiE+%jZ}%pBz|cPEgXP0MkV^VMY}iKx!~Irn8_N&NAI0B`Ow$n? zDuX;Uc?F+(yL^;iMVd@86#PY`pqRy*mY^{-w$%JCZGbC z9H}|UnrC6&hc0OfDuJoQ1+-ZEF)Xk6o0su>Qk&)Ac?>i`2~50ag7CeE3yube6W-{F!6>*?@ZUx&EYj$pPz22jr- z*s$a*vEZRb)F0Ad&@xUiZ-b8#gbQU@*~-C1Ai-u(K3vuy9{AV(bi6oG%6bo-4-$`jzSbkj`TcUO4E%!DF&Q!d*b>$QnT}ekvyO&Q zkVvlOAzvSsu~o}UbR!&$Umu|s8t3iOdy?WbFfg1IfC#+VVP9p<7827McoTPfo}7PJ$*O1|Ayz|9F-eV4-yxmz^kFR zH(5d+Rqp#HK_P+l-9d2F|3({4QWlDrMy`N^;RR?Hpge{-bKac2o=NrZGlQIIP@jV6 zGoi&yiQ#kiPSwoVx4W5j?F&$j76;*Z|uCMU;`g zhV1P_c;qTOl$W(e%%_Q3YvEQe0x!jrXmwGb5}d`3vq?XWGANsY&GeBZE}pu|fbXc- z?@b3p8okN^5hxu_IbZ4xu^C=3+;5pPwq?p3*GFYT%7skx0Wmgy+W54&Y>|V!?C?7! zSS8qcHfgv|b-U#!6|7f4Cino!8F*Fz%Lvk{mHVtM!p;+j^o-ZcvtY(J-c1OS zP*$6-0lGF#eoF-@D>7a~+1#)cPaq{V2!(@>cYF`c!PrFOf|AXS6O#$zuYz%0gB)4v zyUj`l3UPU+RD;PE)uR8W-gb2`7om7%CYTkb5FI?LF9y?eI?Nj;HOJtIRgbMUaD zNCZ>Au%C|-0N#xNGDR)KL94Y6{J!7AT!GF<7?a+2-K6zxjziWzUCOG*P`A!Ojs7^V z{E4Puz(n#~T1i+)oJ`}NZ(q#}T$yn?4Idi8Cq@ECjymVAn!-e;!9wbfi@4&1oX0#* z#zS9u(83QkLYxDCDd7+wZ-+QA3Lxm-Q$7jz*=QXER@{9KgN)Xi8w*M6qaH}-FfaB9& zFu@6W?aKTPj`|@4V8tJ1`~=tdU5tGKL7ZFaf9?rdMY!Ur5ISja_u^$dE(%G9in&$FNHvFjPcS`$YvdM3>hy zneL$hM!xm34?KcTFRC9yTVuG25;3>Jlej{KK=b$8=D-V%U$S^etdu;ziwpCMaF4Mf zpIGB*YaAA&6%)CAE2cQP!taenFnpk+o)VCNs%{)=lmX{iBcraDRoW!a`o!~Pd-iCbx<~#g&Zz1*}oBLL9`&c&aZK?=8-bQE3C|uV-_q8RJ)2X*|xg(zm?^J z5X%o--0+Mja;KRWZP z&-|zEUwI4&UzkLEal%h*W&h)`#Y!G0V({@Q=qBJ5ptlnh+Hvmb{@19)e_Z_vGjT#q z(Fi6V_)!klJ{Id)6`!!cr>&~!gPRPW&XSK{>|B=3%MjCH4Xa>l;Pmf3fMq! z&^r27)#JK`6-#|pPP&xHg zE;r;*ek1CukU~v-SAlZcW`JFA34)O`v(jIQP5I zd1ayV`Yk;wd0vhLL)x{pw2OnfN9m{rP{a2Z)K%ouxFpyfCLtJy_8N~G=g#-@p=8+-)d0jQJ^j$h+Wb1L%ga- z-MjlWr4o4j!@_hxQj|)fF~gVex@jJ$UHC5fIG8;kPEh!Xy8FjI>i{vY1I1b3A1d=i zHvm9npG)kSL{1LUYO~Qx9LNMa2RlO=%F)Em`+`fN7E=!suI;nIrBMB2R!yyS*6>ju zclhBoZ)QM`Mh?fy2P0zh>{m)*0SuL5APii3sF*#e@EuEj15sqh%Ay#5U_DVQP$ToZbt}yd4$Scz6<|D|;lc-i zVyhFvQ&_%*0bg_+u8D1)8iMaocs{mNpv6lHlK9M=)rLkBUMqN9whg{w(`nM$4I&)s zqkDvL#YK`W<;}@g^cF+~y#Pk9&rK^u>#Mg*7q{CBeakaGBIXX`iGy0dmSCK?Al7$a zBHYzcex-SrbAU^8}uAhGzNpAJKuyR7C2G9 z?Q+uxwHN9XpCK|+WYgu~lRhyu%j6yeaSkl%Dfu>tfFWE8kfu63wx3%(`2xPORtulS zBH*M}7&~YBv2>l%MqGS<1o@(sRDkUtqr%6TVi!*0Nv0Oyjz=}8}T(oA%2#q7~Jz$Vx=TXckgr(y| zbMl!q8;68%_@ikbaGm^F(}>|KC%j-mov{jV_v0xcQ4ObNrY6&q@0xsO`eV}{H$Dzg i2zEwFdztHnVuF25 T: + """ + calculates q->ctx scores for every row in ctx_vector + :param q_vector: + :param ctx_vector: + :return: + """ + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) + if cosine: + n1 = torch.norm(q_vectors, dim=-1) + n2 = torch.norm(ctx_vectors, dim=-1) + n_out = torch.ger(n1, n2) + return r / n_out + return r + + +def cosine_scores(q_vector: T, ctx_vectors: T): + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + return F.cosine_similarity(q_vector, ctx_vectors, dim=1) + + +class BertEncoder(BertPreTrainedModel): + def __init__(self, config, project_dim: int = 0): + super().__init__(config) + self.bert = BertModel(config) + assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' + # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None + if project_dim > 0: + self.encode_proj = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size * 2), + GELU(), + LayerNorm(config.hidden_size * 2, eps=1e-12), + nn.Linear(config.hidden_size * 2, project_dim) + ) + else: + self.encode_proj = None + self.init_weights() + + @classmethod + def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\ + -> BertModel: + + cfg = 
BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased') + if dropout != 0: + cfg.attention_probs_dropout_prob = dropout + cfg.hidden_dropout_prob = dropout + + if checkpoint_path is not None and len(checkpoint_path) > 0: + #state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state_dict = torch.load(checkpoint_path) + return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, state_dict=state_dict, **kwargs) + else: + return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, **kwargs) + + def forward(self, input_ids, attention_mask, position_ids, + img_feat=None, img_pos_feat=None, img_masks=None, gather_index=None): + if self.config.output_hidden_states: + sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids, + token_type_ids=None, + attention_mask=attention_mask, + position_ids=position_ids) + else: + hidden_states = None + sequence_output, pooled_output = self.bert(input_ids=input_ids, + token_type_ids=None, + attention_mask=attention_mask, + position_ids=position_ids) + pooled_output = sequence_output[:, 0, :] + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + return sequence_output, pooled_output, hidden_states + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.config.hidden_size + + +class UniterEncoder(UniterPreTrainedModel): + def __init__(self, config, project_dim: int = 0): + super().__init__(config) + self.bert = UniterModel(config, IMG_DIM) + assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' + # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None # Yen-Chun + if project_dim > 0: + self.encode_proj = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size * 2), + GELU(), + LayerNorm(config.hidden_size * 2, eps=1e-12), + nn.Linear(config.hidden_size * 2, project_dim) + ) + else: + self.encode_proj = None + self.apply(self.init_weights) + + @classmethod + def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1, **kwargs)\ + -> UniterModel: + cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased') + if dropout != 0: + cfg.attention_probs_dropout_prob = dropout + cfg.hidden_dropout_prob = dropout + if checkpoint_path is not None and len(checkpoint_path) > 0 and checkpoint_path.lower() != 'none': + logger.info(f'load from {checkpoint_path} for uniter encoder') + state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state_dict = torch.load(checkpoint_path) + #if 'model_dict' in state_dict: + # state_dict = state_dict['model_dict'] + else: + logger.info('no checkpoint, random initialization for img encoder') + state_dict = dict() + return cls.from_pretrained(cfg_name, state_dict=state_dict, project_dim=project_dim, **kwargs) + + def forward(self, input_ids, attention_mask, position_ids, + img_feat, img_pos_feat, img_masks, gather_index=None) -> Tuple[T, ...]: + if self.config.output_hidden_states: + sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + img_feat=img_feat, + img_pos_feat=img_pos_feat, + img_masks=img_masks, + img_type_ids=None, + gather_index=gather_index, + output_all_encoded_layers=True + ) + else: + hidden_states = None + sequence_output = self.bert(input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + img_feat=img_feat, 
+ img_pos_feat=img_pos_feat, + img_masks=img_masks, + img_type_ids=None, + gather_index=gather_index, + output_all_encoded_layers=False) + # pooled_output = self.bert.pooler(sequence_output) + pooled_output = sequence_output[:, 0, :] + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + return sequence_output, pooled_output, hidden_states + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.config.hidden_size + + +class BiEncoder(nn.Module): + """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders. + """ + + def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False, project_dim: int = 0): + super(BiEncoder, self).__init__() + logger.info('*'*100) + logger.info('loading img model') + if args.img_model_type == 'uniter-base': + self.img_model = UniterEncoder.init_encoder(args.img_model_config, checkpoint_path=args.img_checkpoint, project_dim=project_dim) + else: + raise ValueError(f'image encoder does not support other types ({args.img_model_type}) for now') + + logger.info('*' * 100) + logger.info('loading txt model') + if args.txt_model_type == 'bert-base': + self.txt_model = BertEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim) + elif args.txt_model_type == 'uniter-base': + self.txt_model = UniterEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim) + else: + raise ValueError(f'txt encoder does not support other types ({args.txt_model_type}) for now') + + self.fix_img_encoder = fix_img_encoder + self.fix_txt_encoder = fix_txt_encoder + self.project_dim = project_dim + if fix_txt_encoder: + for param in self.txt_model.parameters(): + param.requires_grad = False + if fix_img_encoder: + for param in self.img_model.parameters(): + param.requires_grad = False + + @staticmethod + def get_representation(sub_model, input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks, + gather_index=None, fix_encoder=False): + if fix_encoder: + with torch.no_grad(): + sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids, + img_feat, img_pos_feat, img_masks, + gather_index) + else: + sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids, + img_feat, img_pos_feat, img_masks, + gather_index) + + if sub_model.training: + sequence_output.requires_grad_(requires_grad=True) + pooled_output.requires_grad_(requires_grad=True) + + return sequence_output, pooled_output, hidden_states + + def forward(self, batch, output_all_encoded_layers=False): + # batch keys + # imgs + # txts + # caps + batch = defaultdict(lambda: None, batch) + + if 'txts' in batch: + sb = batch['txts'] + txt_seq, txt_pooled, txt_hidden = self.get_representation(self.txt_model, sb['input_ids'], + sb['attention_mask'], sb['position_ids'], + sb['img_feat'], sb['img_pos_feat'], + sb['img_masks'], + sb['gather_index'], self.fix_txt_encoder) + else: + txt_seq, txt_pooled = None, None + + if 'imgs' in batch: + sb = batch['imgs'] + img_seq, img_pooled, img_hidden = self.get_representation(self.img_model, sb['input_ids'], + sb['attention_mask'], sb['position_ids'], + sb['img_feat'], sb['img_pos_feat'], + sb['img_masks'], + sb['gather_index'], self.fix_txt_encoder) + else: + img_seq, img_pooled = None, None + + if 'caps' in batch and batch['caps']['input_ids'] is not None: + sb = batch['caps'] + cap_seq, cap_pooled, 
cap_hidden = self.get_representation(self.txt_model, sb['input_ids'], + sb['attention_mask'], sb['position_ids'], + sb['img_feat'], sb['img_pos_feat'], + sb['img_masks'], + sb['gather_index'], self.fix_txt_encoder) + else: + cap_seq, cap_pooled = None, None + + if output_all_encoded_layers: + return txt_seq, img_seq, cap_seq + else: + return txt_pooled, img_pooled, cap_pooled + + +class BiEncoderForPretraining(nn.Module): + """ MLM + MRM """ + def __init__(self, config_file, args, project_dim, img_dim, img_label_dim, nce_temp=1, ot_pos_only=False, + experiment=None): + super().__init__() + config = UniterConfig.from_json_file(config_file) + self.bert = BiEncoder(args, project_dim=project_dim) + self.cls = BertOnlyMLMHead( + config, self.bert.img_model.bert.embeddings.word_embeddings.weight) # ??? + self.feat_regress = RegionFeatureRegression( + config.hidden_size, img_dim, + self.bert.img_model.bert.img_embeddings.img_linear.weight) + self.region_classifier = RegionClassification( + config.hidden_size, img_label_dim) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.cls_concat = args.cls_concat + ''' + self.nce_output = BertPredictionHeadTransform(config) + self.nce_output = nn.Sequential(BertPredictionHeadTransform(config), + nn.Linear(config.hidden_size, img_dim)) + self.nce_norm = LayerNorm(config.hidden_size, eps=1e-12) + self.nce_temp = nce_temp # temperature + ''' + self.ot_pos_only = ot_pos_only + # self.apply(self.init_weights) + self.vocab_pad = 0 + self.experiment = experiment + + def pad_vocab(self): + # FIXME better padding after integrating huggingface ??? + emb_w = self.bert.embeddings.word_embeddings.weight.data + padded_emb_w, n_pad = pad_tensor_to_mul(emb_w) + padded_emb_w = nn.Parameter(padded_emb_w) + self.bert.embeddings.word_embeddings.weight = padded_emb_w + self.cls.predictions.decoder.weight = padded_emb_w + self.vocab_pad = n_pad + + def forward(self, batch, task, compute_loss=True): + batch = defaultdict(lambda: None, batch) + if task == 'mlm': + txt_labels = batch['txt_labels'] + return self.forward_mlm(batch, txt_labels, compute_loss) + elif task == 'mrfr': + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + mrfr_feat_target = batch['feat_targets'] + return self.forward_mrfr(batch, img_masks, img_mask_tgt, mrfr_feat_target, compute_loss) + elif task == 'mrm-nce': + raise NotImplementedError('nce does not work') + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + img_masks_in = batch['img_masks_in'] + feat_target = batch['feat_targets'] + neg_feats = batch['neg_feats'] + return self.forward_mrm_nce(batch, + img_masks_in, img_masks, img_mask_tgt, + feat_target, neg_feats, compute_loss) + elif task == 'itm': + targets = batch['targets'] + ot_inputs = batch['ot_inputs'] + return self.forward_itm(batch, + targets, ot_inputs, compute_loss) + elif task.startswith('mrc'): + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + mrc_label_target = batch['label_targets'] + return self.forward_mrc(batch, + img_masks, img_mask_tgt, + mrc_label_target, task, compute_loss) + else: + raise ValueError('invalid task') + + # MLM + def forward_mlm(self, batch, txt_labels, compute_loss=True): + txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True) + # get only the text part + + img_cls = img_seq[:, 0:1, :].repeat(1, txt_seq.shape[1], 1) + if self.cls_concat == 'add': + sequence_output = txt_seq + img_cls + elif self.cls_concat == 'multiply': + sequence_output = txt_seq * img_cls + elif 
len(self.cls_concat) == 0: + sequence_output = txt_seq + else: + raise NotImplementedError(f'{self.cls_concat} not implemented yet') + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + txt_labels != -1) + prediction_scores = self._pad_layer_unpad(masked_output, self.cls) + if self.vocab_pad: + prediction_scores = prediction_scores[:, :-self.vocab_pad] + + masked_lm_loss = F.cross_entropy(prediction_scores, + txt_labels[txt_labels != -1], + reduction='none') + return masked_lm_loss, prediction_scores + + def _compute_masked_hidden(self, hidden, mask): + """ get only the masked region (don't compute unnecessary hiddens) """ + mask = mask.unsqueeze(-1).expand_as(hidden) + hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1)) + return hidden_masked + + def _pad_layer_unpad(self, input_, layer): + input_, n_pad = pad_tensor_to_mul(input_) + output = layer(input_) + if n_pad: + output = output[:-n_pad, :] + return output + + def mlm_eval(self, batch, gather_tgt): + raise ValueError('Do not use this') + sequence_output = self.bert(batch, output_all_encoded_layers=False) + # get only the text part (excluding [CLS], [SEP]) + sequence_output = sequence_output[:, 1:input_ids.size(1)-1, :] + # only compute masked tokens for better efficiency + index = gather_tgt.unsqueeze(-1).expand( + -1, -1, self.config.hidden_size) + masked_output = torch.gather(sequence_output, dim=0, index=index) + prediction_scores = self.cls(masked_output) + if self.vocab_pad: + prediction_scores = prediction_scores[..., :-self.vocab_pad] + return prediction_scores + + # MRFR + def forward_mrfr(self, batch, img_masks, img_mask_tgt, + feat_targets, compute_loss=True): + txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True) + txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1) + if self.cls_concat == 'add': + sequence_output = img_seq + txt_cls + elif self.cls_concat == 'multiply': + sequence_output = img_seq * txt_cls + elif len(self.cls_concat) == 0: + sequence_output = img_seq + else: + raise NotImplementedError(f'{self.cls_concat} not implemented yet') + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + img_mask_tgt) + prediction_feat = self._pad_layer_unpad(masked_output, + self.feat_regress) + + mrfr_loss = F.mse_loss(prediction_feat, feat_targets, + reduction='none') + return mrfr_loss, prediction_feat + + # MRM-NCE + def forward_mrm_nce(self,batch, + img_masks_in, img_masks, img_mask_tgt, + feat_targets, neg_feats, compute_loss=True): + sequence_output = self.bert(batch, + output_all_encoded_layers=False, + img_masks=img_masks_in) + + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + img_mask_tgt) + + masked_output = self._pad_layer_unpad(masked_output, self.nce_output) + # neg within batch + batch_neg = self._compute_masked_hidden(img_feat, ~img_masks) + neg_feats, _ = pad_tensor_to_mul( + torch.cat([neg_feats, batch_neg], dim=0)) + + # shared image linear transform + neg_output = self.nce_norm( + self.bert.img_embeddings.img_linear(neg_feats)) + pos_output = self._pad_layer_unpad(feat_targets, + self.bert.img_embeddings.img_linear) + pos_output = self.nce_norm(pos_output) + + mrm_nce_loss = self.mrm_nce(masked_output, pos_output, + neg_output, compute_loss=True) + return mrm_nce_loss, masked_output + + def mrm_nce(self, masked_output, pos_output, neg_output, + compute_loss=True): + 
# dot product of ground truth feature + masked_score = masked_output.matmul(pos_output.t()) + # dot product of neative samples + neg_score = masked_output.matmul(neg_output.t()) + + logits = torch.cat([masked_score, neg_score], dim=1).float() + targets = torch.arange(0, masked_output.size(0), + dtype=torch.long, device=logits.device) + loss = F.cross_entropy(logits/self.nce_temp, targets, + reduction='none') + return loss, logits + + def forward_itm(self, batch, targets, ot_inputs, + compute_loss=True): + txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=False) + # OT loss + if ot_inputs is not None: + ot_scatter = ot_inputs['ot_scatter'] + + b = sequence_output.size(0) + tl = input_ids.size(1) + il = img_feat.size(1) + max_l = max(ot_inputs['scatter_max'] + 1, tl+il) + + ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output) + ctx_emb = torch.zeros(b, max_l, self.config.hidden_size, + dtype=sequence_output.dtype, + device=sequence_output.device + ).scatter_(dim=1, index=ot_scatter, + src=sequence_output) + txt_emb = ctx_emb[:, :tl, :] + img_emb = ctx_emb[:, tl:tl+il, :] + + txt_pad = ot_inputs['txt_pad'] + img_pad = ot_inputs['img_pad'] + ot_dist = optimal_transport_dist(txt_emb, img_emb, + txt_pad, img_pad) + if self.ot_pos_only: + ot_loss = ot_dist.masked_select(targets == 1) + else: + ot_pos_dist = ot_dist.masked_select(targets == 1) + ot_neg_dist = ot_dist.masked_select(targets == 0) + ot_loss = (ot_pos_dist, ot_neg_dist) + else: + ot_loss = None + + loss_function = BiEncoderNllLoss() + itm_loss1, is_correct1, scores1 = loss_function.calc(txt_seq, img_seq, cap_seq, + batch['pos_ctx_indices'], + batch['neg_ctx_indices'], + 0.0, self.experiment, 'none') + itm_loss2, is_correct2, scores2 = loss_function.calc(img_seq, txt_seq, cap_seq, + batch['pos_ctx_indices'], + batch['neg_ctx_indices'], + 0.0, self.experiment, 'none') + if compute_loss: + return itm_loss1*0.5 + itm_loss2*0.5, ot_loss + else: + return itm_loss1*0.5 + itm_loss2*0.5, ot_loss, is_correct1*0.5 + is_correct2*0.5 + + # MRC + def forward_mrc(self, batch, img_masks, img_mask_tgt, + label_targets, task, compute_loss=True): + txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True) + txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1) + if self.cls_concat == 'add': + sequence_output = img_seq + txt_cls + elif self.cls_concat == 'multiply': + sequence_output = img_seq * txt_cls + elif len(self.cls_concat) == 0: + sequence_output = img_seq + else: + raise NotImplementedError(f'{self.cls_concat} not implemented yet') + + # sequence_output = torch.cat([txt_seq, img_seq], dim=1) + # only compute masked regions for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, img_mask_tgt) + prediction_soft_label = self._pad_layer_unpad(masked_output, + self.region_classifier) + + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + mrc_loss = F.kl_div( + prediction_soft_label, label_targets, reduction='none') + else: + # background class should not be the target + label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1 + mrc_loss = F.cross_entropy( + prediction_soft_label, label_targets, + ignore_index=0, reduction='none') + return mrc_loss, prediction_soft_label + + +def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, + weight_decay: float = 0.0, ) -> torch.optim.Optimizer: + no_decay = ['bias', 'LayerNorm.weight'] + + optimizer_grouped_parameters = [ + {'params': [p for n, 
p in model.named_parameters() if not any(nd in n for nd in no_decay)], + 'weight_decay': weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps) + return optimizer + + +def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1, + local_rank: int = -1, + fp16: bool = False, + fp16_opt_level: str = "O1", + teacher_model = None) -> (nn.Module, torch.optim.Optimizer): + model.to(device) + if teacher_model is not None: + teacher_model.to(device) + if fp16: + try: + import apex + from apex import amp + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + if optimizer is None: + if teacher_model is None: + model = amp.initialize(model, optimizer, opt_level=fp16_opt_level) + else: + model, teacher_model = amp.initialize([model, teacher_model], optimizer, opt_level=fp16_opt_level) + else: + model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) + + #if n_gpu > 1: + # model = torch.nn.DataParallel(model) + + # if local_rank != -1: + # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], + # output_device=local_rank, + # find_unused_parameters=True) + return model, optimizer + + +class BiEncoderNllLoss(object): + + def calc(self, q_vectors: T, ctx_vectors: T, caption_vectors: T, positive_idx_per_question: list, + hard_negatice_idx_per_question: list = None, caption_score_weight: float = 0.1, + experiment=None, reduction='mean'): + """ + Computes nll loss for the given lists of question and ctx vectors. + Note that although hard_negatice_idx_per_question in not currently in use, one can use it for the + loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. 
+ :return: a tuple of loss value and amount of correct predictions per batch + """ + scores_img = self.get_scores(q_vectors, ctx_vectors) + if caption_vectors is not None and caption_score_weight != 0: + scores_caption = self.get_scores(q_vectors, caption_vectors) + scores = (1 - caption_score_weight) * scores_img + caption_score_weight * scores_caption + else: + scores = scores_img + + if experiment is not None: + experiment.log_metric('score_img_diag_mean', torch.diag(scores_img).mean().item()) + experiment.log_metric('score_img_offdiag_mean', (scores_img.sum() - torch.diag(scores_img).sum()) / + (torch.numel(scores_img)-len(torch.diag(scores_img)))) + + experiment.log_metric('score_diag_mean', torch.diag(scores).mean().item()) + experiment.log_metric('score_offdiag_mean', (scores.sum() - torch.diag(scores).sum()) / + (torch.numel(scores) - len(torch.diag(scores)))) + + if caption_vectors is not None and caption_score_weight != 0: + experiment.log_metric('score_caption_diag_mean', torch.diag(scores_caption).mean().item()) + experiment.log_metric('score_caption_offdiag_mean', (scores_caption.sum() - torch.diag(scores_caption).sum()) / + (torch.numel(scores_caption) - len(torch.diag(scores_caption)))) + + if len(q_vectors.size()) > 1: + q_num = q_vectors.size(0) + scores = scores.view(q_num, -1) + + softmax_scores = F.log_softmax(scores, dim=1) + + loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device), + reduction=reduction) + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum() + return loss, correct_predictions_count, scores + + @staticmethod + def get_scores(q_vector: T, ctx_vectors: T) -> T: + f = BiEncoderNllLoss.get_similarity_function() + return f(q_vector, ctx_vectors) + + @staticmethod + def get_similarity_function(): + return dot_product_scores + + +def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): + """ Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. 
+ """ + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +class BiEncoderForVisualQuestionAnswering(nn.Module): + """ Finetune multi-modal BERT for VQA + """ + def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False, + seperate_caption_encoder: bool = False, + project_dim: int = 0, + hidden_size: int = 0, num_answer: int = 0, intersection=False): + super(BiEncoderForVisualQuestionAnswering, self).__init__() + self.biencoder = BiEncoder(args, fix_img_encoder, fix_txt_encoder, project_dim) + self.intersection = intersection + if self.intersection: + hidden_size *= 2 + self.vqa_output = nn.Sequential( + nn.Linear(hidden_size, hidden_size*2), + GELU(), + LayerNorm(hidden_size*2, eps=1e-12), + nn.Linear(hidden_size*2, num_answer) + ) + self.init_weights(self.vqa_output) + + def forward(self, batch, compute_loss=True, targets=None) -> Tuple[T, T]: + + q_pooled, ctx_pooled, caption_pooled = self.biencoder(batch) + + if self.intersection: + pooled_output = torch.cat([q_pooled, ctx_pooled, q_pooled*ctx_pooled, q_pooled + ctx_pooled], dim=1) + else: + pooled_output = torch.cat([q_pooled, ctx_pooled], dim=1) + + answer_scores = self.vqa_output(pooled_output) + + if compute_loss: + vqa_loss = F.binary_cross_entropy_with_logits( + answer_scores, targets, reduction='none') + return vqa_loss + else: + return answer_scores + + def init_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses + # truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, + std=0.02) + elif isinstance(module, LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +def load_biencoder_checkpoint(bi_encoder, biencoder_checkpoint): + if biencoder_checkpoint is not None and len(biencoder_checkpoint) > 0 and biencoder_checkpoint.lower() != 'none': + logger.info(f'loading ckpt from {biencoder_checkpoint}') + state_dict = torch.load(biencoder_checkpoint, map_location='cpu') + try: + bi_encoder.load_state_dict(state_dict['model_dict']) + except KeyError: + logger.info('loading from pre-trained model instead') + for k in list(state_dict.keys()): + if k.startswith('bert.'): + state_dict[k[5:]] = state_dict.pop(k) + else: + state_dict.pop(k) + bi_encoder.load_state_dict(state_dict, strict=True) + else: + logger.info('no checkpoint provided, pass') diff --git a/dvl/options.py b/dvl/options.py new file mode 100644 index 0000000..7a76525 --- /dev/null +++ b/dvl/options.py @@ -0,0 +1,176 @@ +import argparse +import json +import sys +import os +import logging +import torch +import random +import socket +import numpy as np + + +logger = logging.getLogger() + + +def default_params(parser: argparse.ArgumentParser): + parser.add_argument('--txt_model_type', default='bert-base', type=str, help="") + parser.add_argument('--txt_model_config', default='bert-base', type=str, help="") + parser.add_argument('--txt_checkpoint', default=None, type=str, help="") + parser.add_argument('--img_model_type', default='uniter-base', type=str, help="") + parser.add_argument('--img_model_config', 
default='./config/img_base.json', type=str, help="") + parser.add_argument('--img_checkpoint', default=None, type=str, help="") + parser.add_argument('--biencoder_checkpoint', default=None, type=str, help="") + parser.add_argument('--seperate_caption_encoder', action='store_true', help="") + + parser.add_argument('--train_batch_size', default=80, type=int, help="") + parser.add_argument('--valid_batch_size', default=80, type=int, help="") + parser.add_argument('--gradient_accumulation_steps', default=1, type=int, help="") + parser.add_argument('--learning_rate', default=1e-5, type=float, help="") + parser.add_argument('--max_grad_norm', default=2.0, type=float, help="") + parser.add_argument('--warmup_steps', default=500, type=int, help="") + parser.add_argument('--valid_steps', default=500, type=int, help="") + parser.add_argument('--num_train_steps', default=5000, type=int, help="") + parser.add_argument('--num_train_epochs', default=0, type=int, help="") + + parser.add_argument('--fp16', action='store_true', help="") + parser.add_argument('--seed', default=42, type=int, help="") + parser.add_argument('--output_dir', default='./', type=str, help="") + parser.add_argument('--max_txt_len', default=64, type=int, help="") + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument('--config', default=None, type=str, help="") + parser.add_argument('--itm_global_file', default=None, type=str, help="") + parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") + parser.add_argument('--n_workers', type=int, default=2, help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', help="pin memory") # ??? + parser.add_argument('--hnsw_index', action='store_true', help="") + parser.add_argument('--fp16_opt_level', type=str, default='O1', help="") + parser.add_argument('--img_meta', type=str, default=None, help="") + + +def add_itm_params(parser: argparse.ArgumentParser): + parser.add_argument('--conf_th', default=0.2, type=float, help="") + parser.add_argument('--caption_score_weight', default=0.0, type=float, help="") + parser.add_argument('--negative_size', default=10, type=int, help="") + parser.add_argument('--num_hard_negatives', default=0, type=int, help="") + parser.add_argument('--sample_init_hard_negatives', action='store_true', help="") + parser.add_argument('--hard_negatives_sampling', default='none', type=str, + choices=['none', 'random', 'top', 'top-random', '10-20', '20-30'], help="") + parser.add_argument('--max_bb', default=100, type=int, help="") + parser.add_argument('--min_bb', default=10, type=int, help="") + parser.add_argument('--num_bb', default=36, type=int, help="") + parser.add_argument('--train_txt_dbs', default=None, type=str, help="") + parser.add_argument('--train_img_dbs', default=None, type=str, help="") + + parser.add_argument('--txt_db_mapping', default=None, type=str, help="") + parser.add_argument('--img_db_mapping', default=None, type=str, help="") + parser.add_argument('--pretrain_mapping', default=None, type=str, help="") + + parser.add_argument('--val_txt_db', default=None, type=str, help="") + parser.add_argument('--val_img_db', default=None, type=str, help="") + parser.add_argument('--test_txt_db', default=None, type=str, help="") + parser.add_argument('--test_img_db', default=None, type=str, help="") + parser.add_argument('--steps_per_hard_neg', default=-1, type=int, help="") + 
parser.add_argument('--inf_minibatch_size', default=400, type=int, help="") + parser.add_argument('--project_dim', default=0, type=int, help='') + parser.add_argument('--cls_concat', default="", type=str, help='') + parser.add_argument('--fix_txt_encoder', action='store_true', help='') + parser.add_argument('--fix_img_encoder', action='store_true', help='') + parser.add_argument('--compressed_db', action='store_true', help='use compressed LMDB') + parser.add_argument('--retrieval_mode', default="both", + choices=['img_only', 'txt_only', 'both'], type=str, help="") + + +def add_logging_params(parser: argparse.ArgumentParser): + parser.add_argument('--log_result_step', default=4, type=int, help="") + parser.add_argument('--project_name', default='itm', type=str, help="") + parser.add_argument('--expr_name_prefix', default='', type=str, help="") + parser.add_argument('--save_all_epochs', action='store_true', help="") + + +def add_kd_params(parser: argparse.ArgumentParser): + parser.add_argument('--teacher_checkpoint', default=None, type=str, help="") + parser.add_argument('--T', default=1.0, type=float, help="") + parser.add_argument('--kd_loss_weight', default=1.0, type=float, help="") + + +def parse_with_config(parser, cmds=None): + if cmds is None: + args = parser.parse_args() + else: + args = parser.parse_args(cmds) + + if args.config is not None: + config_args = json.load(open(args.config)) + override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] + if arg.startswith('--')} + for k, v in config_args.items(): + if k not in override_keys: + setattr(args, k, v) + return args + + +def map_db_dirs(args): + # map img db + for k in args.__dict__: + if not isinstance(args.__dict__[k], str): + continue + if args.__dict__[k].startswith('/pretrain') and args.pretrain_mapping: + print('pretrain', k, args.__dict__[k]) + args.__dict__[k] = args.__dict__[k].replace('/pretrain', args.pretrain_mapping) + if args.__dict__[k].startswith('/db') and args.txt_db_mapping: + print('db', k, args.__dict__[k]) + args.__dict__[k] = args.__dict__[k].replace('/db', args.txt_db_mapping) + if args.__dict__[k].startswith('/img') and args.img_db_mapping: + print('img', k, args.__dict__[k]) + args.__dict__[k] = args.__dict__[k].replace('/img', args.img_db_mapping) + + if args.img_db_mapping: + for i in range(len(args.train_img_dbs)): + args.train_img_dbs[i] = args.train_img_dbs[i].replace('/img', args.img_db_mapping) + if args.txt_db_mapping: + for i in range(len(args.train_txt_dbs)): + args.train_txt_dbs[i] = args.train_txt_dbs[i].replace('/db', args.txt_db_mapping) + + + + +def print_args(args): + logger.info(" **************** CONFIGURATION **************** ") + for key, val in sorted(vars(args).items()): + keystr = "{}".format(key) + (" " * (30 - len(key))) + logger.info("%s --> %s", keystr, val) + logger.info(" **************** END CONFIGURATION **************** ") + + +def set_seed(args): + seed = args.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + +def setup_args_gpu(args): + """ + Setup arguments CUDA, GPU & distributed training + """ + if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # distributed mode + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + 
args.n_gpu = 1 + args.device = device + ws = os.environ.get('WORLD_SIZE') + + args.distributed_world_size = int(ws) if ws else 1 + + logger.info( + 'Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(), + args.local_rank, device, + args.n_gpu, + args.distributed_world_size) + logger.info("16-bits training: %s ", args.fp16) diff --git a/dvl/trainer.py b/dvl/trainer.py new file mode 100644 index 0000000..2e0720b --- /dev/null +++ b/dvl/trainer.py @@ -0,0 +1,209 @@ +import collections +import os +import torch +import tqdm +import logging +import numpy as np +import torch.nn as nn +from torch.utils.data import DataLoader, ConcatDataset, ChainDataset +from uniter_model.data.loader import PrefetchLoader + +from dvl.data.itm import TxtTokLmdb, ItmFastDataset, ItmValDataset, itm_fast_collate +from dvl.models.bi_encoder import BiEncoderNllLoss +from dvl.utils import _calc_loss +from dvl.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer + + +logger = logging.getLogger() +CheckpointState = collections.namedtuple("CheckpointState", + ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', + 'encoder_params']) + + +class BiEncoderTrainer: + def __init__(self, args): + pass + + +def build_dataloader(dataset, collate_fn, is_train, opts, batch_size=None): + if batch_size is None: + batch_size = opts.train_batch_size if is_train else opts.valid_batch_size + + dataloader = DataLoader(dataset, batch_size=batch_size, + shuffle=is_train, drop_last=False, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def get_model_obj(model: nn.Module): + return model.module if hasattr(model, 'module') else model + + +def _save_checkpoint(args, biencoder, optimizer, scheduler, epoch: int, offset: int, cp_name: str = None) -> str: + model_to_save = get_model_obj(biencoder) + if cp_name is None: + cp = os.path.join(args.output_dir, 'biencoder.' + str(epoch) + ('.' + str(offset) if offset > 0 else '')) + else: + cp = os.path.join(args.output_dir, 'biencoder.' 
+ cp_name) + cp += '.pt' + + + meta_params = None + + state = CheckpointState(model_to_save.state_dict(), + optimizer.state_dict(), + scheduler.state_dict(), + offset, + epoch, meta_params + ) + torch.save(state._asdict(), cp) + logger.info('Saved checkpoint at %s', cp) + return cp + + +def load_saved_state(biencoder, optimizer=None, scheduler=None, saved_state: CheckpointState = ''): + epoch = saved_state.epoch + offset = saved_state.offset + if offset == 0: # epoch has been completed + epoch += 1 + logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch) + + model_to_load = get_model_obj(biencoder) + logger.info('Loading saved model state ...') + model_to_load.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection + + if saved_state.optimizer_dict and optimizer is not None: + logger.info('Loading saved optimizer state ...') + optimizer.load_state_dict(saved_state.optimizer_dict) + + if saved_state.scheduler_dict and scheduler is not None: + scheduler_state = saved_state.scheduler_dict + scheduler.load_state_dict(scheduler_state) + + +def load_states_from_checkpoint(model_file: str) -> CheckpointState: + logger.info('Reading saved model from %s', model_file) + state_dict = torch.load(model_file, map_location='cpu') + logger.info('model_state_dict keys %s', state_dict.keys()) + return CheckpointState(**state_dict) + + +def get_indexer(bi_encoder, eval_dataloader, args, hnsw_index, img_retrieval=True): + bi_encoder.eval() + img_embedding = dict() + + if hnsw_index: + indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future + else: + indexer_img = DenseFlatIndexer(args.vector_size) # modify in future + for i, batch in enumerate(tqdm.tqdm(eval_dataloader)): + with torch.no_grad(): + model_out = bi_encoder(batch) + local_q_vector, local_ctx_vectors, local_caption_vectors = model_out + if img_retrieval: + img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) + else: + img_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)}) + indexer_img.index_data(list(img_embedding.items())) + return indexer_img + + +def eval_model_on_dataloader(bi_encoder, eval_dataloader, args, img2txt=None, num_tops=100, no_eval=False): + total_loss = 0.0 + bi_encoder.eval() + total_correct_predictions = 0 + batches, total_samples = 0, 0 + labels_img_name = [] + labels_txt_name = [] + img_embedding = dict() + txt_embedding = dict() + if args.hnsw_index: + indexer_img = DenseHNSWFlatIndexer(args.vector_size) # modify in future + indexer_txt = DenseHNSWFlatIndexer(args.vector_size) # modify in future + else: + indexer_img = DenseFlatIndexer(args.vector_size) # modify in future + indexer_txt = DenseFlatIndexer(args.vector_size) # modify in future + query_txt, query_txt_id = [], [] + query_img, query_img_id = [], [] + for i, batch in enumerate(eval_dataloader): + with torch.no_grad(): + model_out = bi_encoder(batch) + local_q_vector, local_ctx_vectors, local_caption_vectors = model_out + + query_txt.extend([out.view(-1).detach().cpu().numpy() for out in local_q_vector]) + query_txt_id.extend(batch['txt_index']) + + query_img.extend([out.view(-1).detach().cpu().numpy() for out in local_ctx_vectors]) + query_img_id.extend(batch['img_fname']) + + loss_function = BiEncoderNllLoss() + + loss, correct_cnt, score = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, + 
list(range(len(local_q_vector))), None) + + total_loss += loss.item() + total_correct_predictions += correct_cnt.sum().item() + batches += 1 + total_samples += batch['txts']['input_ids'].shape[0] + + img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) + txt_embedding.update({img_id: txt_vec.detach().cpu().numpy() for img_id, txt_vec in zip(batch['txt_index'], local_q_vector)}) + labels_img_name.extend(batch['img_fname']) + labels_txt_name.extend(batch['txt_index']) + + total_loss = total_loss / batches + correct_ratio = total_correct_predictions / float(total_samples) + + query_txt_np = np.array(query_txt) + indexer_img.index_data(list(img_embedding.items())) + query_img_np = np.array(query_img) + indexer_txt.index_data(list(txt_embedding.items())) + + if no_eval: + return total_loss, correct_ratio, (indexer_img, indexer_txt), (None, None), (None, None) + else: + res_txt = indexer_img.search_knn(query_txt_np, num_tops) + rank_txt_res = {query_txt_id[i]: r[0] for i, r in enumerate(res_txt)} + + res_img = indexer_txt.search_knn(query_img_np, num_tops) + rank_img_res = {query_img_id[i]: r[0] for i, r in enumerate(res_img)} + + recall_txt = {1: 0, 5: 0, 10: 0} + for i, q in enumerate(query_txt_id): + for top in recall_txt: + recall_txt[top] += labels_img_name[i] in rank_txt_res[q][:top] + + for top in recall_txt: + recall_txt[top] = recall_txt[top] / len(rank_txt_res) + + recall_img = {1: 0, 5: 0, 10: 0} + for i, q in enumerate(np.unique(query_img_id)): + for top in recall_img: + # recall_img[top] += any([txt_id in rank_img_res[q][:top] for txt_id in img2txt[q]]) + recall_img[top] += any([txt_id in rank_img_res[q][:top] for txt_id in img2txt[q]]) + + for top in recall_img: + recall_img[top] = recall_img[top] / len(rank_img_res) + + return total_loss, correct_ratio, (indexer_img, indexer_txt), (recall_txt, recall_img), (rank_txt_res, rank_img_res) + + +def load_dataset(all_img_dbs, txt_dbs, img_dbs, args, is_train): + if is_train: + # train datasets + datasets = [] + for txt_path, img_path in zip(txt_dbs, img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, args.max_txt_len) + datasets.append(ItmFastDataset(txt_db, img_db, args.num_hard_negatives, args.img_meta, args.tokenizer)) + + datasets = ConcatDataset(datasets) # + else: + # eval or test + img_db = all_img_dbs[img_dbs] + txt_db = TxtTokLmdb(txt_dbs, -1) + datasets = ItmFastDataset(txt_db, img_db, args.inf_minibatch_size, args.img_meta, args.tokenizer) + + return datasets diff --git a/dvl/utils.py b/dvl/utils.py new file mode 100644 index 0000000..573133f --- /dev/null +++ b/dvl/utils.py @@ -0,0 +1,234 @@ +import logging +import random +import tqdm +import torch +import pickle + +import torch.distributed as dist + +from collections import defaultdict +from horovod import torch as hvd +from torch import Tensor as T +from typing import Tuple + + +logger = logging.getLogger() + + +def get_rank(): + return hvd.rank() + + +def get_world_size(): + return hvd.size() + + +def print_args(args): + logger.info(" **************** CONFIGURATION **************** ") + for key, val in sorted(vars(args).items()): + keystr = "{}".format(key) + (" " * (30 - len(key))) + logger.info("%s --> %s", keystr, val) + logger.info(" **************** CONFIGURATION **************** ") + + +def num_of_parameters(model, requires_grad=False): + if requires_grad: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + else: + return sum(p.numel() for p in 
model.parameters()) + + +def get_default_group(): + return dist.group.WORLD + + +def all_reduce(tensor, group=None): + if group is None: + group = get_default_group() + return dist.all_reduce(tensor, group=group) + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable. + Args: + data (Any): data from the local worker to be gathered on other workers + group (optional): group of the collective + """ + SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size + + enc = pickle.dumps(data) + enc_size = len(enc) + + if enc_size + SIZE_STORAGE_BYTES > max_size: + raise ValueError( + 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) + + rank = get_rank() + world_size = get_world_size() + buffer_size = max_size * world_size + + if not hasattr(all_gather_list, '_buffer') or \ + all_gather_list._buffer.numel() < buffer_size: + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( + 256 ** SIZE_STORAGE_BYTES) + + size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') + + cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) + cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) + + start = rank * max_size + size = enc_size + SIZE_STORAGE_BYTES + buffer[start: start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size: (i + 1) * max_size] + size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') + if size > 0: + result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) + return result + except pickle.UnpicklingError: + raise Exception( + 'Unable to unpickle data from other workers. all_gather_list requires all ' + 'workers to enter the function together, so this error usually indicates ' + 'that the workers have fallen out of sync somehow. Workers can fall out of ' + 'sync if one of them runs out of memory, or if there are other conditions ' + 'in your training script that can cause one worker to finish an epoch ' + 'while other workers are still iterating over their portions of the data.' + ) + + +def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, local_positive_idxs, + local_hard_negatives_idxs: list = None, experiment=None + ): + """ + Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations + across all the nodes. 
+ """ + distributed_world_size = 1 # args.distributed_world_size or 1 + if distributed_world_size > 1: + # TODO: Add local_caption_vectors + q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_() + ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_() + + global_question_ctx_vectors = all_gather_list( + [q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs], + max_size=args.global_loss_buf_sz) + + global_q_vector = [] + global_ctxs_vector = [] + + # ctxs_per_question = local_ctx_vectors.size(0) + positive_idx_per_question = [] + hard_negatives_per_question = [] + + total_ctxs = 0 + + for i, item in enumerate(global_question_ctx_vectors): + q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item + + if i != args.local_rank: + global_q_vector.append(q_vector.to(local_q_vector.device)) + global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device)) + positive_idx_per_question.extend([v + total_ctxs for v in positive_idx]) + hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs]) + else: + global_q_vector.append(local_q_vector) + global_ctxs_vector.append(local_ctx_vectors) + positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs]) + hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs]) + total_ctxs += ctx_vectors.size(0) + + global_q_vector = torch.cat(global_q_vector, dim=0) + global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0) + + else: + global_q_vector = local_q_vector + global_ctxs_vector = local_ctx_vectors + global_caption_vector = local_caption_vectors + positive_idx_per_question = local_positive_idxs + hard_negatives_per_question = local_hard_negatives_idxs + + loss, is_correct, scores = loss_function.calc(global_q_vector, global_ctxs_vector, global_caption_vector, + positive_idx_per_question, hard_negatives_per_question, + args.caption_score_weight, experiment) + + return loss, is_correct, scores + + +def compare_models(model_1, model_2): + models_differ = 0 + for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()): + if torch.equal(key_item_1[1], key_item_2[1]): + pass + else: + models_differ += 1 + if (key_item_1[0] == key_item_2[0]): + print('Mismtach found at', key_item_1[0]) + else: + raise Exception + if models_differ == 0: + print('Models match perfectly! 
:)') + + +def is_main_process(): + return hvd.rank() == 0 + + +def display_img(img_meta, name, img_only=False): + import matplotlib.pyplot as plt + import matplotlib.image as mpimg + img = mpimg.imread(img_meta[name]['img_file']) + plt.imshow(img) + plt.show() + if not img_only: + print('annotation') + print('\t' + '\n\t'.join(img_meta[name]['annotation'])) + print('caption') + print('\t' + img_meta[name]['caption'][0]) + + +def retrieve_query(model, query, indexer, args, top=10): + input_ids = args.tokenizer.encode(query) + input_ids = torch.LongTensor(input_ids).to(args.device).unsqueeze(0) + attn_mask = torch.ones(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0) + pos_ids = torch.arange(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0) + _, query_vector, _ = model.txt_model(input_ids=input_ids,attention_mask=attn_mask, position_ids=pos_ids) + res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100) + return res + + +def get_model_encoded_vecs(model, dataloader): + img_embedding, caption_embedding, query_embedding = dict(), dict(), defaultdict(list) + labels_img_name = [] + # for i, batch in enumerate(dataloader): + for i, batch in enumerate(tqdm.tqdm(dataloader)): + with torch.no_grad(): + model_out = model(batch) + local_q_vectors, local_ctx_vectors, local_caption_vectors = model_out + + img_embedding.update({img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)}) + caption_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['img_fname'], local_caption_vectors)}) + query_embedding.update({img_id: cap_vec.detach().cpu().numpy() for img_id, cap_vec in zip(batch['txt_index'], local_q_vectors)}) + + labels_img_name.extend(batch['img_fname']) + return { + 'img_embed': img_embedding, + 'caption_embed': caption_embedding, + 'txt_embed': query_embedding, + 'img_name': labels_img_name + } + diff --git a/lightningdot.py b/lightningdot.py index b314edc..fdc611e 100644 --- a/lightningdot.py +++ b/lightningdot.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import sys import os import json @@ -30,7 +31,7 @@ from .utils import Configs, get_gather_index def arg_process(args): dirname = os.path.dirname(__file__) - args.img_checkpoint = dirname + '/' + args.img_checkpoint + #args.img_checkpoint = dirname + '/' + args.img_checkpoint args.img_model_config = dirname + '/' + args.img_model_config return args @@ -39,19 +40,47 @@ class LightningDOT(NNOperator): """ CLIP multi-modal embedding operator """ - def __init__(self, modality: str): + def __init__(self, model_name: str, modality: str): + logger = logging.getLogger() sys.path.append(str(Path(__file__).parent)) from dvl.models.bi_encoder import BiEncoder from detector.faster_rcnn import Net, process_img + from utils import download_file - full_path = os.path.dirname(__file__) + '/config/flickr30k_ft_config.json' - with open(full_path) as fw: + config_path = os.path.dirname(__file__) + self._configs()[model_name]['config'] + model_url = self._configs()[model_name]['weights'] + weight_name = os.path.basename(model_url) + weight_path = os.path.dirname(__file__) + '/data/model/' + weight_name + + if not os.path.exists(weight_path): + download_file(model_url, os.path.dirname(__file__) + '/data/model/') + + with open(config_path) as fw: content = fw.read() args = json.loads(content) + + #args['img_checkpoint'] = './data/model/' + weight_name args = Configs(args) + args = arg_process(args) self.bi_encoder = BiEncoder(args, True, True, project_dim=args.project_dim) self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + state_dict = torch.load(weight_path, map_location='cpu') + try: + if 'model_dict' in state_dict: + self.bi_encoder.load_state_dict(state_dict['model_dict']) + else: + self.bi_encoder.load_state_dict(state_dict) + except RuntimeError: + logger.info('loading from pre-trained model instead') + for k in list(state_dict.keys()): + if k.startswith('bert.'): + state_dict[k[5:]] = state_dict.pop(k) + else: + state_dict.pop(k) + self.bi_encoder.load_state_dict(state_dict, strict=True) + img_model, txt_model = self.bi_encoder.img_model, self.bi_encoder.txt_model img_model.eval() txt_model.eval() @@ -144,3 +173,18 @@ class LightningDOT(NNOperator): gather_index, fix_txt_encoder) return img_pooled + + def _configs(self): + config = {} + config['lightningdot_base'] = {} + config['lightningdot_base']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/LightningDot.pt' + config['lightningdot_base']['config'] = '/config/pretrain-alldata-base.json' + + config['lightningdot_coco_ft'] = {} + config['lightningdot_coco_ft']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/coco-ft.pt' + config['lightningdot_coco_ft']['config'] = '/config/coco_eval_config.json' + + config['lightningdot_flickr_ft'] = {} + config['lightningdot_flickr_ft']['weights'] = 'https://convaisharables.blob.core.windows.net/lightningdot/flickr-ft.pt' + config['lightningdot_flickr_ft']['config'] = '/config/flickr30k_eval_config.json' + return config diff --git a/requirements.txt b/requirements.txt index 5168bbd..ce1a2a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,4 @@ torch>=1.9.0 torchvision>=0.10.0 transformers==2.3.0 Pillow - -towhee \ No newline at end of file +towhee diff --git a/uniter_model/Dockerfile b/uniter_model/Dockerfile new file mode 100644 index 0000000..5311c91 --- /dev/null +++ b/uniter_model/Dockerfile @@ -0,0 +1,22 @@ +FROM nvcr.io/nvidia/pytorch:19.05-py3 +COPY requirements.txt scripts/download_bert.py ./ +RUN pip install -r
requirements.txt &&\ + python download_bert.py &&\ + rm ./requirements.txt ./download_bert.py + +################## v1 ########################## + +COPY scripts/install_horovod.sh ./ +RUN source install_horovod.sh &&\ + rm ./install_horovod.sh +ENV OPENMPI_VERSION=4.0.0 + +# fix ssh permissions +RUN bash -c "chmod -R 600 /etc/ssh/ && chmod 600 /var/run/sshd/ && chmod 600 /root" + +################## horovod, v2 ########################## + + +RUN bash -c "pip install lz4==2.1.9 lmdb==0.97" + +################# LMDB ########################## diff --git a/uniter_model/LICENSE b/uniter_model/LICENSE new file mode 100644 index 0000000..8bde70e --- /dev/null +++ b/uniter_model/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Yen-Chun Chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
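For reference, the retrieval objective and learning-rate schedule added in dvl/models/bi_encoder.py above reduce to an in-batch softmax cross-entropy over dot-product scores plus a linear warmup/decay multiplier. The sketch below is a simplified, self-contained re-implementation for illustration only: caption-score weighting, metric logging, and distributed gathering are omitted, and the names `in_batch_nll_loss` and `warmup_linear_decay` do not appear in the patch.
```
# Simplified sketch of BiEncoderNllLoss.calc and get_schedule_linear from
# dvl/models/bi_encoder.py (caption scores, logging and DDP gathering omitted).
# Function names here are illustrative and are not part of the patch.
import torch
import torch.nn.functional as F


def in_batch_nll_loss(q_vectors, ctx_vectors):
    # q_vectors: (B, D) text embeddings; ctx_vectors: (B, D) image embeddings.
    # Row i of ctx_vectors is the positive for row i of q_vectors; every other
    # row in the batch acts as a negative.
    scores = q_vectors @ ctx_vectors.t()              # (B, B) dot-product scores
    log_probs = F.log_softmax(scores, dim=1)
    positives = torch.arange(scores.size(0), device=scores.device)
    loss = F.nll_loss(log_probs, positives, reduction='mean')
    correct = (log_probs.argmax(dim=1) == positives).sum()
    return loss, correct, scores


def warmup_linear_decay(step, warmup_steps, training_steps):
    # Same shape as the lr_lambda used by get_schedule_linear: ramp up to 1.0
    # over warmup_steps, then decay linearly to 0.0 at training_steps.
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    return max(0.0, float(training_steps - step) / float(max(1, training_steps - warmup_steps)))


if __name__ == '__main__':
    q = F.normalize(torch.randn(8, 768), dim=1)
    c = F.normalize(torch.randn(8, 768), dim=1)
    loss, correct, _ = in_batch_nll_loss(q, c)
    print(loss.item(), correct.item(), warmup_linear_decay(250, 500, 5000))
```
With in-batch negatives, the i-th image in a batch is the positive for the i-th caption and a negative for every other caption, so the positive index for row i is simply i; this matches the `list(range(len(local_q_vector)))` positives passed to `_calc_loss` in dvl/trainer.py.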
diff --git a/uniter_model/README.md b/uniter_model/README.md new file mode 100644 index 0000000..dda9919 --- /dev/null +++ b/uniter_model/README.md @@ -0,0 +1,89 @@ +# Universal-Image-Text-Transformer +Research code for pre-training universal vision and language models + + +## Requirements +nvidia driver (418.xx), docker(19.03+), nvidia-container-toolkit +``` +docker pull convaicontainerregistry1.azurecr.io/img-txt +``` + +## Launching the environment +``` +# can use CUDA_VISIBLE_DEVICES to separate GPUs for each container +source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH +# TXT_DB: convaistorage2share2/TXT_DB_v3 +# IMG_DIR: convaistorage2share2/Bottom-up-features/adaptive/npy_per_img_id +# OUTPUT: somewhere to store model checkpoints (can be on shared storage) +# PRETRAIN: path to pretrained model + +# when preprocessing is needed +source launch_container.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH --prepro +# this will make /db writable + + +# multi-node training +source launch_container_dist.sh $TXT_DB $IMG_DIR $OUTPUT $PRETRAIN_PATH +``` + +## Pretrain +``` +# inside the docker container +horovodrun -np $N_GPU -H localhost:$N_GPU \ + python pretrain.py --config config/config-pretrain-alltask.json +``` + +## Finetune VQA +``` +horovodrun -np 2 -H localhost:2 \ + python train_vqa.py --config config/config-vqa-bert-2gpu-alldata.json +``` +### VQA inference +``` +# single node only +# please refer to the code for command-line options +horovodrun -np $N_GPU -H localhost:$N_GPU \ + python eval_vqa.py --txt_db /db/vqa_test_[base/large]-cased.db/ \ + --img_dir /img/coco_test2015 --checkpoint [NUM] \ + --output_dir /path/to/trained/vqa +``` + +### NLVR2 official evaluation +Use the official script to get both acc (our validation matched this) and consistency +``` +# concat all output files +cat $OUTPUT/result/[val/test]_results_$STEP_rank*.csv > $OUTPUT.csv +python eval/nlvr2.py $OUTPUT.csv ANNOTATION.json +``` + +### Referring Expression Comprehension: Finetuning and Evaluation +``` +# train on gd-truth pairs of (ref, sent) +horovodrun -np $N_GPU -H localhost:$N_GPU \ + python train_re.py --config config/hps-refcoco+.json + +# evaluate multiple splits on gd-truth boxes +horovodrun -np $N_GPU -H localhost:$N_GPU \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ + --checkpoint 26 + +# evaluate multiple splits on detected boxes +horovodrun -np $N_GPU -H localhost:$N_GPU \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ + --checkpoint 26 +``` + +## Misc +1. w/o horovodrun it will run on a single GPU + - useful for the debugger (-m pdb) +2. try `--pin_mem`; it might give a tiny performance improvement +3.
`--img_format [lmdb/lmdb-compress]` + - trade-off between memory/CPU + - use `--n_workers $N_CPU` to specify data workers (default: 4) + diff --git a/uniter_model/config/config-vcr-bert-2gpu.json b/uniter_model/config/config-vcr-bert-2gpu.json new file mode 100644 index 0000000..2cd044c --- /dev/null +++ b/uniter_model/config/config-vcr-bert-2gpu.json @@ -0,0 +1,36 @@ +{ + "train_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/", + "train_img_dir": "/img/vcr_gt_val/;/img/vcr_val/", + "val_txt_db": "/db/vcr_val_w_obj_ids_base-cased.db/", + "val_img_dir": "/img/vcr_gt_val/;/img/vcr_val/", + "checkpoint": "/storage/pretrain_vcr/mlm_qar-30k_steps-lr_3e-5-run1/ckpt/model_step_29000.pt", + "checkpoint_from": "vcr", + "task": "qa", + "cut_bert": -1, + "output_dir": "/storage/debug/qa_qar-bert_base-gt_proposed_img_feat-mlm-diff_type_id_for_ra-lr_2e-5-train_step_10k", + "max_txt_len": 220, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 4000, + "val_batch_size": 12, + "gradient_accumulation_steps": 10, + "learning_rate": 5e-5, + "valid_steps": 10, + "num_train_steps": 1000, + "optim": "adamw", + "betas": [0.9, 0.98], + "grad_norm": 2.0, + "decay": "linear", + "warm_int": 500, + "decay_int": 2000, + "decay_st": 12000, + "decay_rate": 0.2, + "dropout": 0.1, + "weight_decay": 0.01, + "warmup_steps": 1000, + "seed": 42, + "fp16": true, + "mcan": false +} diff --git a/uniter_model/config/eval-itm-coco.json b/uniter_model/config/eval-itm-coco.json new file mode 100644 index 0000000..1cd8376 --- /dev/null +++ b/uniter_model/config/eval-itm-coco.json @@ -0,0 +1,11 @@ +{ + "txt_db": "/db/itm_coco_test_1k_4_base-cased.db/", + "img_dir": "/img/coco_val2014/", + "neg_sample_size": -1, + "eval_split": "itm_coco_test_1k_4_fp16", + "checkpoint": 12000, + "cut_bert": -1, + "output_dir": "/storage/itm/coco-bert_base-weak_420k-w_train-lr_2e-5-20k_steps-rank_loss-batch_size_20_acc_8_wd_0/", + "fp16": true, + "eval_mini_batch_size": 400 +} diff --git a/uniter_model/config/eval-itm-flickr.json b/uniter_model/config/eval-itm-flickr.json new file mode 100644 index 0000000..fa18efc --- /dev/null +++ b/uniter_model/config/eval-itm-flickr.json @@ -0,0 +1,11 @@ +{ + "txt_db": "/db/itm_flickr30k_train_base-cased.db/", + "img_dir": "/img/flickr30k/", + "neg_sample_size": -1, + "eval_split": "itm_flickr_train", + "checkpoint": 6000, + "cut_bert": -1, + "output_dir": "/storage/itm/flickr30k-bert_base_weak_420k-hard_neg_finetune-lr_5e-5-train_step_5k", + "fp16": true, + "eval_mini_batch_size": 128 +} diff --git a/uniter_model/config/hps-itm.json b/uniter_model/config/hps-itm.json new file mode 100644 index 0000000..df9fd92 --- /dev/null +++ b/uniter_model/config/hps-itm.json @@ -0,0 +1,53 @@ +{ + "train_txt_db": ["/db/itm_coco_train_base-cased.db/", + "/db/itm_coco_restval_base-cased.db"], + "train_img_dir": ["/img/coco_train2014/", + "/img/coco_val2014/"], + "train_neg_sample_p": 0.5, + "neg_sample_from": "i", + "eval_method": "rank", + "val_txt_db": ["/db/itm_coco_val_1k_0_base-cased.db/", + "/db/itm_coco_val_1k_1_base-cased.db/", + "/db/itm_coco_val_1k_2_base-cased.db/", + "/db/itm_coco_val_1k_3_base-cased.db/", + "/db/itm_coco_val_1k_4_base-cased.db/"], + "val_img_dir": ["/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/"], + "test_txt_db": ["/db/itm_coco_test_1k_0_base-cased.db/", + "/db/itm_coco_test_1k_1_base-cased.db/", + "/db/itm_coco_test_1k_2_base-cased.db/", + "/db/itm_coco_test_1k_3_base-cased.db/", + 
"/db/itm_coco_test_1k_4_base-cased.db/"], + "test_img_dir": ["/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/", + "/img/coco_val2014/"], + "checkpoint": "/pretrain/mlm_caption_bert-base.pt", + "cut_bert": -1, + "output_dir": "/storage/itm_tr/coco_bert_base-mlm_caption-corrected_img_bb_num-w_train_rv-step_40000", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 2048, + "val_batch_size": 4096, + "val_minibatch_size":400, + "test_minibatch_size":300, + "gradient_accumulation_steps": 8, + "learning_rate": 0.001, + "valid_steps": 1000, + "num_train_steps": 40000, + "optim": "adamax", + "decay": "linear", + "dropout": 0.1, + "falseweight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 1000, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/hps-refcoco+.json b/uniter_model/config/hps-refcoco+.json new file mode 100644 index 0000000..afa54df --- /dev/null +++ b/uniter_model/config/hps-refcoco+.json @@ -0,0 +1,25 @@ +{ + "train_txt_db": "/db/refcoco+_train_base-cased.db", + "train_img_dir": "/img/visual_grounding_coco_gt", + "val_txt_db": "/db/refcoco+_val_base-cased.db", + "val_img_dir": "/img/visual_grounding_coco_gt", + "checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt", + "cut_bert": -1, + "output_dir": "/storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4", + "max_txt_len": 60, + "train_batch_size": 128, + "val_batch_size": 128, + "learning_rate": 1e-4, + "optim": "adamw", + "betas": [0.9, 0.98], + "weight_decay": 0.01, + "dropout": 0.1, + "grad_norm": 2.0, + "decay": "linear", + "num_train_steps": 24000, + "warmup_steps": 1500, + "gradient_accumulation_steps": 1, + "no_cuda": false, + "seed": 24, + "fp16": true +} \ No newline at end of file diff --git a/uniter_model/config/hps-refcoco+_conceptual.json b/uniter_model/config/hps-refcoco+_conceptual.json new file mode 100644 index 0000000..eb449f0 --- /dev/null +++ b/uniter_model/config/hps-refcoco+_conceptual.json @@ -0,0 +1,26 @@ +{ + "train_txt_db": "/db/refcoco+_train_base-cased.db", + "train_img_dir": "/img/visual_grounding_coco_gt", + "val_txt_db": "/db/refcoco+_val_base-cased.db", + "val_img_dir": "/img/visual_grounding_det_coco", + "checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_200000.pt", + "cut_bert": -1, + "output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5", + "max_txt_len": 60, + "train_batch_size": 128, + "val_batch_size": 128, + "learning_rate": 8e-5, + "valid_steps": 1000, + "optim": "adamw", + "betas": [0.9, 0.98], + "weight_decay": 0.01, + "dropout": 0.1, + "grad_norm": 2.0, + "decay": "linear", + "num_train_steps": 24000, + "warmup_steps": 1500, + "gradient_accumulation_steps": 1, + "no_cuda": false, + "seed": 24, + "fp16": true +} \ No newline at end of file diff --git a/uniter_model/config/hps-refcoco+_conceptual_large_weak.json b/uniter_model/config/hps-refcoco+_conceptual_large_weak.json new file mode 100644 index 0000000..4b1aaba --- /dev/null +++ b/uniter_model/config/hps-refcoco+_conceptual_large_weak.json @@ -0,0 +1,26 @@ +{ + "train_txt_db": "/db/refcoco+_train_large-cased.db", + "train_img_dir": "/img/visual_grounding_coco_gt", + "val_txt_db": "/db/refcoco+_val_large-cased.db", + "val_img_dir": "/img/visual_grounding_det_coco", + "checkpoint": "/pretrain/bert-large_weak/ckpt/model_step_50000.pt", + "cut_bert": -1, + "output_dir": 
"/storage/refcoco+/conceptual-bert-large_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5_b64g4", + "max_txt_len": 60, + "train_batch_size": 64, + "val_batch_size": 256, + "learning_rate": 8e-5, + "valid_steps": 1000, + "optim": "adamw", + "betas": [0.9, 0.98], + "weight_decay": 0.01, + "dropout": 0.1, + "grad_norm": 2.0, + "decay": "linear", + "num_train_steps": 24000, + "warmup_steps": 1500, + "gradient_accumulation_steps": 4, + "no_cuda": false, + "seed": 24, + "fp16": true +} \ No newline at end of file diff --git a/uniter_model/config/hps-refcoco+_conceptual_rank.json b/uniter_model/config/hps-refcoco+_conceptual_rank.json new file mode 100644 index 0000000..f4502d0 --- /dev/null +++ b/uniter_model/config/hps-refcoco+_conceptual_rank.json @@ -0,0 +1,29 @@ +{ + "train_txt_db": "/db/refcoco+_train_base-cased.db", + "train_img_dir": "/img/visual_grounding_coco_gt", + "val_txt_db": "/db/refcoco+_val_base-cased.db", + "val_img_dir": "/img/visual_grounding_coco_gt", + "checkpoint": "/pretrain/bert-base_weak_conceptual/ckpt/model_step_420000.pt", + "cut_bert": -1, + "output_dir": "/storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4_rank_r0.2_m0.2_step30k", + "max_txt_len": 60, + "train_batch_size": 128, + "val_batch_size": 128, + "learning_rate": 1e-4, + "valid_steps": 1000, + "optim": "adamw", + "betas": [0.9, 0.98], + "weight_decay": 0.01, + "dropout": 0.1, + "grad_norm": 2.0, + "decay": "linear", + "num_train_steps": 30000, + "warmup_steps": 1500, + "gradient_accumulation_steps": 1, + "no_cuda": false, + "seed": 24, + "fp16": true, + "train_loss": "rank", + "hard_ratio": 0.2, + "margin": 0.2 +} \ No newline at end of file diff --git a/uniter_model/config/hps-refcoco.json b/uniter_model/config/hps-refcoco.json new file mode 100644 index 0000000..64df930 --- /dev/null +++ b/uniter_model/config/hps-refcoco.json @@ -0,0 +1,26 @@ +{ + "train_txt_db": "/db/refcoco_train_base-cased.db", + "train_img_dir": "/img/visual_grounding_coco_gt", + "val_txt_db": "/db/refcoco_val_base-cased.db", + "val_img_dir": "/img/visual_grounding_coco_gt", + "checkpoint": "/pretrain/bert-base_weak/ckpt/model_step_420000.pt", + "cut_bert": -1, + "output_dir": "/storage/refcoco/bert-base_mlm+itm+mrfr_pretrain-refcoco_lr3e-4", + "max_txt_len": 60, + "train_batch_size": 128, + "val_batch_size": 128, + "learning_rate": 3e-4, + "valid_steps": 1000, + "optim": "adamw", + "betas": [0.9, 0.98], + "weight_decay": 0.01, + "dropout": 0.1, + "grad_norm": 2.0, + "decay": "linear", + "num_train_steps": 10000, + "warmup_steps": 1500, + "gradient_accumulation_steps": 1, + "no_cuda": false, + "seed": 24, + "fp16": true +} \ No newline at end of file diff --git a/uniter_model/config/hps-ve-large.json b/uniter_model/config/hps-ve-large.json new file mode 100644 index 0000000..cc440de --- /dev/null +++ b/uniter_model/config/hps-ve-large.json @@ -0,0 +1,31 @@ +{ + "train_txt_db": "/db/ve_train_large-cased.db/", + "train_img_dir": "/img/flickr30k/", + "val_txt_db": "/db/ve_dev_large-cased.db/", + "test_img_dir": "/img/flickr30k/", + "test_txt_db": "/db/ve_test_large-cased.db/", + "val_img_dir": "/img/flickr30k/", + "checkpoint": "/pretrain/bert-large_frkl_alldata.pt", + "cut_bert": -1, + "output_dir": "/storage/ve/default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 4, + "learning_rate": 3e-5, + "valid_steps": 500, + "num_train_steps": 6000, + "warmup_steps": 600, + "optim": "adamw", + "betas": 
[0.9, 0.98], + "grad_norm": 2.0, + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/hps-ve.json b/uniter_model/config/hps-ve.json new file mode 100644 index 0000000..8c20672 --- /dev/null +++ b/uniter_model/config/hps-ve.json @@ -0,0 +1,31 @@ +{ + "train_txt_db": "/db/ve_train_base-cased.db/", + "train_img_dir": "/img/flickr30k/", + "val_txt_db": "/db/ve_dev_base-cased.db/", + "test_img_dir": "/img/flickr30k/", + "test_txt_db": "/db/ve_test_base-cased.db/", + "val_img_dir": "/img/flickr30k/", + "checkpoint": "/pretrain/base_mlm_mrfr_mrckl_itm_alldata.pt", + "cut_bert": -1, + "output_dir": "/storage/ve/default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 4, + "learning_rate": 3e-5, + "valid_steps": 500, + "num_train_steps": 6000, + "warmup_steps": 600, + "optim": "adamw", + "betas": [0.9, 0.98], + "grad_norm": 2.0, + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/hps-vqa.json b/uniter_model/config/hps-vqa.json new file mode 100644 index 0000000..e88f5b3 --- /dev/null +++ b/uniter_model/config/hps-vqa.json @@ -0,0 +1,30 @@ +{ + "train_txt_db": "/db/vqa_train_base-cased.db/", + "train_img_dir": "/img/coco_train2014/", + "val_txt_db": "/db/vqa_val_base-cased.db/", + "val_img_dir": "/img/coco_val2014/", + "ans2label": "/db/ans2label.pkl", + "checkpoint": "/storage/mlm/caption-base_from-scratch_grad-acc-8_step-80k_val-step-5k_lr-2e-4_wu-0.1/ckpt/model_step_80000_final.pt", + "cut_bert": -1, + "output_dir": "/storage/vqa/bert_base-mlm_caption_nonblind_from_scratch-nowd-linear", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 2048, + "val_batch_size": 4096, + "gradient_accumulation_steps": 8, + "learning_rate": 0.001, + "valid_steps": 500, + "num_train_steps": 20000, + "optim": "adamax", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0, + "grad_norm": 2.0, + "warmup_steps": 1000, + "seed": 42, + "fp16": false, + "blind": false +} diff --git a/uniter_model/config/itm-coco-base.json b/uniter_model/config/itm-coco-base.json new file mode 100644 index 0000000..ed930df --- /dev/null +++ b/uniter_model/config/itm-coco-base.json @@ -0,0 +1,47 @@ +{ + "compressed_db": false, + "checkpoint": "/pretrain/alltask_ot_alldata.pt", + "output_dir": "/storage/finetune/itm/coco_ot_alldata_base_hnv2", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "learning_rate": 5e-05, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "/db/itm_coco_train_base-cased.db", + "/db/itm_coco_restval_base-cased.db" + ], + "train_img_dbs": [ + "/img/coco_train2014/", + "/img/coco_val2014" + ], + "val_txt_db": "/db/itm_coco_val_base-cased.db", + "val_img_db": "/img/coco_val2014/", + "test_txt_db": "/db/itm_coco_test_base-cased.db", + "test_img_db": "/img/coco_val2014/", + "model_config": "/src/config/uniter-base.json", + "rank": 0 +} \ No newline at end of file diff 
--git a/uniter_model/config/itm-ot-base-16gpus.json b/uniter_model/config/itm-ot-base-16gpus.json new file mode 100644 index 0000000..21104a4 --- /dev/null +++ b/uniter_model/config/itm-ot-base-16gpus.json @@ -0,0 +1,45 @@ +{ + "compressed_db": false, + "checkpoint": "/pretrain/bert-base-cased.pt", + "output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 80, + "gradient_accumulation_steps": 8, + "negative_size": 1, + "hard_neg_size": 0, + "inf_minibatch_size": 400, + "margin": 0.2, + "learning_rate": 5e-05, + "warmup_steps": 100, + "valid_steps": 500, + "num_train_steps": 1000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "seed": 42, + "full_val": false, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "/db/itm_flickr30k_train_base-cased.db" + ], + "train_img_dbs": [ + "/img/flickr30k/" + ], + "val_txt_db": "/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "/img/flickr30k/", + "model_config": "./config/uniter-base.json" +} \ No newline at end of file diff --git a/uniter_model/config/itm-ot-base-16gpus_philly.json b/uniter_model/config/itm-ot-base-16gpus_philly.json new file mode 100644 index 0000000..d50cb5b --- /dev/null +++ b/uniter_model/config/itm-ot-base-16gpus_philly.json @@ -0,0 +1,45 @@ +{ + "compressed_db": false, + "checkpoint": "/pretrain/alltask_ot_alldata.pt", + "output_dir": "/ssd2/siqi/Projects/model_compression/outputs/debug", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 40, + "gradient_accumulation_steps": 4, + "negative_size": 1, + "hard_neg_size": 0, + "inf_minibatch_size": 400, + "margin": 0.2, + "learning_rate": 5e-05, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": false, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "/db/itm_flickr30k_train_base-cased.db" + ], + "train_img_dbs": [ + "/img/flickr30k/" + ], + "val_txt_db": "/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "/img/flickr30k/", + "model_config": "./config/uniter-base.json" +} \ No newline at end of file diff --git a/uniter_model/config/pretrain-gqa-alltask.json b/uniter_model/config/pretrain-gqa-alltask.json new file mode 100644 index 0000000..851bead --- /dev/null +++ b/uniter_model/config/pretrain-gqa-alltask.json @@ -0,0 +1,42 @@ +{ + "train_datasets": [ + {"name": "gqa", + "db": ["/db/pretrain_gqa_train_0_large-cased.db", "/db/pretrain_gqa_train_1_base-cased.db", + "/db/pretrain_gqa_train_2_base-cased.db", "/db/pretrain_gqa_train_3_base-cased.db", + "/db/pretrain_gqa_train_4_base-cased.db", "/db/pretrain_gqa_train_5_base-cased.db", + "/db/pretrain_gqa_train_6_base-cased.db", "/db/pretrain_gqa_train_7_base-cased.db", + "/db/pretrain_gqa_train_8_base-cased.db", "/db/pretrain_gqa_train_9_base-cased.db", + "/db/pretrain_gqa_val_base-cased.db"], + "img": ["/img/gqa/"], + "tasks": ["mlm", "mrm", "mrckl"], + "mix_ratio": [2, 1, 1]} + ], + "val_datasets": [ + {"name": "gqa", + "db": 
["/db/pretrain_gqa_testdev_balanced_base-cased.db"], + "img": ["/img/gqa/"], + "tasks": ["mlm", "mrm", "mrckl"]} + ], + "checkpoint": "/pretrain/bert-large_weak_alldata/ckpt/model_step_100000.pt", + "output_dir": "/storage/pretrain_gqa/bert_large_weak_alldata_100k-train_val_all-mlm_mrm_mrckl-train_batch_size_6144-500k_steps", + "mrm_prob": 0.15, + "max_txt_len": 220, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 6144, + "val_batch_size": 8000, + "gradient_accumulation_steps": 10, + "learning_rate": 3e-05, + "valid_steps": 10000, + "num_train_steps": 500000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": -1, + "warmup_steps": 50000, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/pretrain-mlm-coco.json b/uniter_model/config/pretrain-mlm-coco.json new file mode 100644 index 0000000..c48cc6a --- /dev/null +++ b/uniter_model/config/pretrain-mlm-coco.json @@ -0,0 +1,42 @@ +{ + "train_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_train_base-cased.db/", + "/db/pretrain_caption_coco_trainval_base-cased.db/"], + "img": ["/img/coco_train2014/", "/img/coco_val2014/"], + "tasks": ["mlm"], + "mix_ratio": [1]} + ], + "val_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_val_base-cased.db/"], + "img": ["/img/coco_val2014/"], + "tasks": ["mlm"]} + ], + "output_dir": "/storage/pretrain/mlm_coco", + "mrm_prob": 0.15, + "neg_size": 1024, + "itm_neg_prob": 0.5, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 2, + "learning_rate": 5e-05, + "valid_steps": 5000, + "num_train_steps": 100000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "pin_mem": true, + "n_workers": 4, + "from_scratch": false +} diff --git a/uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json b/uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json new file mode 100644 index 0000000..b360e0f --- /dev/null +++ b/uniter_model/config/pretrain-mlm_itmot_mrfr_mrckl-indomain-base.json @@ -0,0 +1,53 @@ +{ + "train_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_train_base-cased.db/", + "/db/pretrain_caption_coco_trainval_base-cased.db/"], + "img": ["/img/coco_train2014/", "/img/coco_val2014/"], + "tasks": ["itm", "mlm", "mrfr", "mrckl"], + "mix_ratio": [2, 2, 1, 1]}, + {"name": "vg_cap", + "db": ["/db/pretrain_caption_vg_train_base-cased.db/"], + "img": ["/img/vg/"], + "tasks": ["itm", "mlm", "mrfr", "mrckl"], + "mix_ratio": [2, 2, 1, 1]} + ], + "val_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_val_base-cased.db/"], + "img": ["/img/coco_val2014/"], + "tasks": ["itm", "mlm", "mrfr", "mrckl"]}, + {"name": "vg_cap", + "db": ["/db/pretrain_caption_vg_val_base-cased.db/"], + "img": ["/img/vg/"], + "tasks": ["itm", "mlm", "mrfr", "mrckl"]} + ], + "model_config": "/src/config/uniter-base.json", + "checkpoint": "/pretrain/bert-base-cased.pt", + "output_dir": "/storage/pretrain/alltask_ot_indomain_base", + "ans2label": "/db/pretrain_ans2label.pkl", + "mrm_prob": 0.15, + "itm_neg_prob": 0.5, + "itm_ot_lambda": 0.1, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 10240, + "val_batch_size": 10240, + 
"gradient_accumulation_steps": 2, + "learning_rate": 5e-05, + "valid_steps": 5000, + "num_train_steps": 200000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 5.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "pin_mem": true, + "n_workers": 4 +} diff --git a/uniter_model/config/pretrain-mrckl-coco.json b/uniter_model/config/pretrain-mrckl-coco.json new file mode 100644 index 0000000..c770798 --- /dev/null +++ b/uniter_model/config/pretrain-mrckl-coco.json @@ -0,0 +1,42 @@ +{ + "train_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_train_base-cased.db/", + "/db/pretrain_caption_coco_trainval_base-cased.db/"], + "img": ["/img/coco_train2014/", "/img/coco_val2014/"], + "tasks": ["mrckl"], + "mix_ratio": [1]} + ], + "val_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_val_base-cased.db/"], + "img": ["/img/coco_val2014/"], + "tasks": ["mrckl"]} + ], + "output_dir": "/storage/pretrain/mrckl_coco", + "mrm_prob": 0.15, + "neg_size": 1024, + "itm_neg_prob": 0.5, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 2, + "learning_rate": 5e-05, + "valid_steps": 5000, + "num_train_steps": 100000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "pin_mem": true, + "n_workers": 4, + "from_scratch": false +} diff --git a/uniter_model/config/pretrain-mrfr-coco.json b/uniter_model/config/pretrain-mrfr-coco.json new file mode 100644 index 0000000..20e7ec3 --- /dev/null +++ b/uniter_model/config/pretrain-mrfr-coco.json @@ -0,0 +1,42 @@ +{ + "train_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_train_base-cased.db/", + "/db/pretrain_caption_coco_trainval_base-cased.db/"], + "img": ["/img/coco_train2014/", "/img/coco_val2014/"], + "tasks": ["mrfr"], + "mix_ratio": [1]} + ], + "val_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_val_base-cased.db/"], + "img": ["/img/coco_val2014/"], + "tasks": ["mrfr"]} + ], + "output_dir": "/storage/pretrain/mrfr_coco", + "mrm_prob": 0.15, + "neg_size": 1024, + "itm_neg_prob": 0.5, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 2, + "learning_rate": 5e-05, + "valid_steps": 5000, + "num_train_steps": 100000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "pin_mem": true, + "n_workers": 4, + "from_scratch": false +} diff --git a/uniter_model/config/pretrain-mrm-nce-coco.json b/uniter_model/config/pretrain-mrm-nce-coco.json new file mode 100644 index 0000000..32b8534 --- /dev/null +++ b/uniter_model/config/pretrain-mrm-nce-coco.json @@ -0,0 +1,43 @@ +{ + "train_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_train_base-cased.db/", + "/db/pretrain_caption_coco_trainval_base-cased.db/"], + "img": ["/img/coco_train2014/", "/img/coco_val2014/"], + "tasks": ["mrm-nce"], + "mix_ratio": [1]} + ], + "val_datasets": [ + {"name": "coco_cap", + "db": ["/db/pretrain_caption_coco_val_base-cased.db/"], + "img": ["/img/coco_val2014/"], + "tasks": ["mrm-nce"]} + ], + "output_dir": "/storage/pretrain/mrm_nce_coco", + "mrm_prob": 0.15, + "neg_size": 
1024, + "nce_temp": 1.0, + "itm_neg_prob": 0.5, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8192, + "val_batch_size": 8192, + "gradient_accumulation_steps": 2, + "learning_rate": 5e-05, + "valid_steps": 5000, + "num_train_steps": 100000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 10000, + "seed": 42, + "fp16": true, + "pin_mem": true, + "n_workers": 4, + "from_scratch": false +} diff --git a/uniter_model/config/pretrain-vcr-alltask.json b/uniter_model/config/pretrain-vcr-alltask.json new file mode 100644 index 0000000..fb738e8 --- /dev/null +++ b/uniter_model/config/pretrain-vcr-alltask.json @@ -0,0 +1,38 @@ +{ + "train_datasets": [ + {"name": "vcr", + "db": ["/db/vcr_val_w_obj_ids_base-cased.db/"], + "img": ["/img/vcr_val/;/img/vcr_gt_val/"], + "tasks": ["mlm", "mrm", "mrckl"], + "mix_ratio": [2, 1, 1]} + ], + "val_datasets": [ + {"name": "vcr", + "db": ["/db/vcr_val_w_obj_ids_base-cased.db/"], + "img": ["/img/vcr_val/;/img/vcr_gt_val/"], + "tasks": ["mlm", "mrm", "mrckl"]} + ], + "checkpoint": "/pretrain/bert-base_weak_w_mlm_itm_mrm_mrckl_4gpu/ckpt/model_step_500000.pt", + "vcr_task": ["qa", "qar"], + "output_dir": "/storage/debug/mlm_mrm_mrckl-qa_qar-gt_det", + "mrm_prob": 0.15, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8000, + "val_batch_size": 8000, + "gradient_accumulation_steps": 5, + "learning_rate": 3e-05, + "valid_steps": 10, + "num_train_steps": 120000, + "optim": "adamw", + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": -1, + "warmup_steps": 12000, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/train-itm-debug.json b/uniter_model/config/train-itm-debug.json new file mode 100644 index 0000000..202470c --- /dev/null +++ b/uniter_model/config/train-itm-debug.json @@ -0,0 +1,40 @@ +{ + "train_txt_dbs": ["/db/itm_flickr30k_val_base-cased.db"], + "train_img_dbs": ["/img/flickr30k/"], + "val_txt_db": "/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "/img/flickr30k/", + "checkpoint": "/pretrain/uniter-base-iclr.pt", + "model_config": "/src/config/uniter-base.json", + "output_dir": "/debug/itm/flickr_default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 20, + "negative_size": 1, + "hard_neg_size": 1, + "hard_neg_pool_size": 20, + "steps_per_hard_neg": 30, + "inf_minibatch_size": 40, + "gradient_accumulation_steps": 2, + "learning_rate": 1e-05, + "valid_steps": 40, + "num_train_steps": 50, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "margin": 0.2, + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 1600, + "seed": 42, + "fp16": true, + "n_workers": 0, + "pin_mem": true +} diff --git a/uniter_model/config/train-itm-flickr-base-hnv2.json b/uniter_model/config/train-itm-flickr-base-hnv2.json new file mode 100644 index 0000000..b9e73af --- /dev/null +++ b/uniter_model/config/train-itm-flickr-base-hnv2.json @@ -0,0 +1,38 @@ +{ + "train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"], + "train_img_dbs": ["/img/flickr30k/"], + "val_txt_db": "/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "/img/flickr30k/", + "checkpoint": 
"/pretrain/alltask_ot_alldata.pt", + "model_config": "/src/config/uniter-base.json", + "output_dir": "/storage/finetune/itm/flickr_ot_alldata_base_hnv2", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 8, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "learning_rate": 5e-05, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "margin": 0.2, + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "full_val": true +} diff --git a/uniter_model/config/train-itm-flickr-base.json b/uniter_model/config/train-itm-flickr-base.json new file mode 100644 index 0000000..2148cc5 --- /dev/null +++ b/uniter_model/config/train-itm-flickr-base.json @@ -0,0 +1,40 @@ +{ + "train_txt_dbs": ["/db/itm_flickr30k_train_base-cased.db"], + "train_img_dbs": ["/img/flickr30k/"], + "val_txt_db": "/db/itm_flickr30k_val_base-cased.db", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/itm_flickr30k_test_base-cased.db", + "test_img_db": "/img/flickr30k/", + "checkpoint": "/pretrain/alltask_ot_alldata.pt", + "model_config": "/src/config/uniter-base.json", + "output_dir": "/storage/finetune/itm/flickr_ot_alldata_base", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 40, + "negative_size": 1, + "hard_neg_size": 0, + "hard_neg_pool_size": 20, + "steps_per_hard_neg": -1, + "inf_minibatch_size": 512, + "gradient_accumulation_steps": 4, + "learning_rate": 5e-05, + "valid_steps": 2000, + "num_train_steps": 20000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "margin": 0.2, + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 2000, + "seed": 42, + "fp16": true, + "n_workers": 4, + "pin_mem": true +} diff --git a/uniter_model/config/train-nlvr2-base-1gpu.json b/uniter_model/config/train-nlvr2-base-1gpu.json new file mode 100644 index 0000000..a77bad9 --- /dev/null +++ b/uniter_model/config/train-nlvr2-base-1gpu.json @@ -0,0 +1,37 @@ +{ + "train_txt_db": "/db/nlvr2_train_base-cased.db", + "train_img_db": "/img/nlvr2_train/", + "val_txt_db": "/db/nlvr2_dev_base-cased.db", + "val_img_db": "/img/nlvr2_dev/", + "test_txt_db": "/db/nlvr2_test1_base-cased.db", + "test_img_db": "/img/nlvr2_test/", + "checkpoint": "/pretrain/uniter-base-iclr.pt", + "model_config": "/src/config/uniter-base.json", + "model": "paired-attn", + "use_img_type": true, + "output_dir": "/storage/nlvr2/default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 10240, + "val_batch_size": 10240, + "gradient_accumulation_steps": 1, + "learning_rate": 3e-05, + "valid_steps": 500, + "num_train_steps": 8000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 800, + "seed": 77, + "fp16": true, + "n_workers": 4, + "pin_mem": true +} diff --git a/uniter_model/config/train-ve-base-2gpu.json b/uniter_model/config/train-ve-base-2gpu.json new file mode 100644 index 0000000..48c5e0e --- /dev/null +++ b/uniter_model/config/train-ve-base-2gpu.json @@ -0,0 +1,31 @@ +{ + "train_txt_db": "/db/ve_train_base-cased.db/", + "train_img_db": "/img/flickr30k/", + "val_txt_db": "/db/ve_dev_base-cased.db/", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/ve_test_base-cased.db/", + "test_img_db": 
"/img/flickr30k/", + "checkpoint": "/pretrain/alltask_ot_alldata.pt", + "model_config": "/src/config/uniter-base.json", + "output_dir": "/storage/finetune/ve/default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 4096, + "val_batch_size": 4096, + "gradient_accumulation_steps": 4, + "learning_rate": 3e-5, + "valid_steps": 500, + "num_train_steps": 6000, + "warmup_steps": 600, + "optim": "adamw", + "betas": [0.9, 0.98], + "grad_norm": 2.0, + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/train-ve-large-2gpu.json b/uniter_model/config/train-ve-large-2gpu.json new file mode 100644 index 0000000..2ac167b --- /dev/null +++ b/uniter_model/config/train-ve-large-2gpu.json @@ -0,0 +1,31 @@ +{ + "train_txt_db": "/db/ve_train_large-cased.db/", + "train_img_db": "/img/flickr30k/", + "val_txt_db": "/db/ve_dev_large-cased.db/", + "val_img_db": "/img/flickr30k/", + "test_txt_db": "/db/ve_test_large-cased.db/", + "test_img_db": "/img/flickr30k/", + "checkpoint": "/pretrain/alltask_ot_alldata_large.pt", + "model_config": "/src/config/uniter-large.json", + "output_dir": "/storage/finetune/ve/default_large", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 4096, + "val_batch_size": 4096, + "gradient_accumulation_steps": 4, + "learning_rate": 3e-5, + "valid_steps": 500, + "num_train_steps": 6000, + "warmup_steps": 600, + "optim": "adamw", + "betas": [0.9, 0.98], + "grad_norm": 2.0, + "decay": "linear", + "dropout": 0.1, + "weight_decay": 0.01, + "seed": 42, + "fp16": true +} diff --git a/uniter_model/config/train-vqa-base-2gpu.json b/uniter_model/config/train-vqa-base-2gpu.json new file mode 100644 index 0000000..f1bbdc9 --- /dev/null +++ b/uniter_model/config/train-vqa-base-2gpu.json @@ -0,0 +1,35 @@ +{ + "train_txt_dbs": ["/db/vqa_train_base-cased.db", + "/db/vqa_trainval_base-cased.db", + "/db/vqa_vg_base-cased.db"], + "train_img_dbs": ["/img/coco_train2014/", "/img/coco_val2014", "/img/vg/"], + "val_txt_db": "/db/vqa_devval_base-cased.db", + "val_img_db": "/img/coco_val2014/", + "checkpoint": "/pretrain/uniter-base-iclr.pt", + "model_config": "/src/config/uniter-base.json", + "output_dir": "/storage/vqa/default", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 10240, + "val_batch_size": 10240, + "gradient_accumulation_steps": 5, + "learning_rate": 8e-05, + "valid_steps": 500, + "num_train_steps": 6000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 600, + "seed": 42, + "fp16": true, + "n_workers": 4, + "pin_mem": true +} diff --git a/uniter_model/config/uniter-base.json b/uniter_model/config/uniter-base.json new file mode 100644 index 0000000..8da8c59 --- /dev/null +++ b/uniter_model/config/uniter-base.json @@ -0,0 +1,14 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "num_hidden_layers_img": 1, + "type_vocab_size": 2, + "vocab_size": 28996 +} diff --git a/uniter_model/config/uniter-large.json b/uniter_model/config/uniter-large.json new file mode 100644 index 0000000..961e7ca --- /dev/null +++ b/uniter_model/config/uniter-large.json @@ -0,0 
+1,13 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 512, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "type_vocab_size": 2, + "vocab_size": 28996 +} diff --git a/uniter_model/data/__init__.py b/uniter_model/data/__init__.py new file mode 100644 index 0000000..6ce372b --- /dev/null +++ b/uniter_model/data/__init__.py @@ -0,0 +1,27 @@ +#from .data import (TxtTokLmdb, DetectFeatLmdb, +# ConcatDatasetWithLens, ImageLmdbGroup) +#from .mlm import (MlmDataset, MlmEvalDataset, +# BlindMlmDataset, BlindMlmEvalDataset, +# mlm_collate, mlm_eval_collate, +# mlm_blind_collate, mlm_blind_eval_collate) +#from .mrm import (MrfrDataset, OnlyImgMrfrDataset, +# MrcDataset, OnlyImgMrcDataset, +# mrfr_collate, mrfr_only_img_collate, +# mrc_collate, mrc_only_img_collate) +from .itm import (TokenBucketSamplerForItm, + ItmDataset, itm_collate, itm_ot_collate, + ItmRankDataset, ItmRankDatasetHardNeg, itm_rank_collate, + ItmRankDatasetHardNegFromText, + ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate, + ItmHardNegDataset, itm_hn_collate, + ItmValDataset, itm_val_collate, + ItmEvalDataset, itm_eval_collate) +from .sampler import TokenBucketSampler, DistributedSampler +from .loader import MetaLoader, PrefetchLoader + +from .vqa import VqaDataset, vqa_collate, VqaEvalDataset, vqa_eval_collate +from .nlvr2 import (Nlvr2PairedDataset, nlvr2_paired_collate, + Nlvr2PairedEvalDataset, nlvr2_paired_eval_collate, + Nlvr2TripletDataset, nlvr2_triplet_collate, + Nlvr2TripletEvalDataset, nlvr2_triplet_eval_collate) +from .ve import VeDataset, ve_collate, VeEvalDataset, ve_eval_collate diff --git a/uniter_model/data/data.py b/uniter_model/data/data.py new file mode 100644 index 0000000..52c226e --- /dev/null +++ b/uniter_model/data/data.py @@ -0,0 +1,283 @@ +""" +Dataset interfaces +""" +from collections import defaultdict +from contextlib import contextmanager +import io +import json +import lmdb +from os.path import exists + +import numpy as np +import torch +from torch.utils.data import Dataset, ConcatDataset +from tqdm import tqdm +from lz4.frame import compress, decompress + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def _fp16_to_fp32(feat_dict): + out = {k: arr.astype(np.float32) + if arr.dtype == np.float16 else arr + for k, arr in feat_dict.items()} + return out + + +def compute_num_bb(confs, conf_th, min_bb, max_bb): + num_bb = max(min_bb, (confs > conf_th).sum()) + num_bb = min(max_bb, num_bb) + return num_bb + + +class DetectFeatLmdb(object): + def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, + compress=True): + self.img_dir = img_dir + if conf_th == -1: + db_name = f'feat_numbb{num_bb}' + self.name2nbb = defaultdict(lambda: num_bb) + else: + db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}' + nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json' + if not exists(f'{img_dir}/{nbb}'): + # nbb is not pre-computed + self.name2nbb = None + else: + self.name2nbb = json.load(open(f'{img_dir}/{nbb}')) + self.compress = compress + if compress: + db_name += '_compressed' + + if self.name2nbb is None: + if compress: + db_name = 'all_compressed' + else: + db_name = 'all' + # only read ahead on single node training + self.env = lmdb.open(f'{img_dir}/{db_name}', + readonly=True, create=False, + readahead=not _check_distributed()) + self.txn = self.env.begin(buffers=True) + if self.name2nbb is None: + 
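A small usage sketch of compute_num_bb from data.py above: the number of region features kept per image is the count of detections above conf_th, clamped into [min_bb, max_bb]. The function body is copied verbatim so the snippet runs on its own; the confidence values are toy numbers.

import numpy as np

def compute_num_bb(confs, conf_th, min_bb, max_bb):   # copied from data.py above
    num_bb = max(min_bb, (confs > conf_th).sum())
    num_bb = min(max_bb, num_bb)
    return num_bb

confs = np.array([0.9, 0.6, 0.3, 0.15, 0.05])         # 3 boxes clear conf_th=0.2
print(compute_num_bb(confs, conf_th=0.2, min_bb=10, max_bb=100))  # 10 (raised to min_bb)
print(compute_num_bb(confs, conf_th=0.2, min_bb=2, max_bb=3))     # 3  (capped at max_bb)
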
self.name2nbb = self._compute_nbb() + + def _compute_nbb(self): + name2nbb = {} + fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8')) + for fname in tqdm(fnames, desc='reading images'): + dump = self.txn.get(fname.encode('utf-8')) + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = np.load(reader, allow_pickle=True) + confs = img_dump['conf'] + else: + img_dump = msgpack.loads(dump, raw=False) + confs = img_dump['conf'] + name2nbb[fname] = compute_num_bb(confs, self.conf_th, + self.min_bb, self.max_bb) + + return name2nbb + + def __del__(self): + self.env.close() + + def get_dump(self, file_name): + # hack for MRC + dump = self.txn.get(file_name.encode('utf-8')) + nbb = self.name2nbb[file_name] + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = np.load(reader, allow_pickle=True) + img_dump = _fp16_to_fp32(img_dump) + else: + img_dump = msgpack.loads(dump, raw=False) + img_dump = _fp16_to_fp32(img_dump) + img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()} + return img_dump + + def __getitem__(self, file_name): + dump = self.txn.get(file_name.encode('utf-8')) + nbb = self.name2nbb[file_name] + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = np.load(reader, allow_pickle=True) + img_dump = {'features': img_dump['features'], + 'norm_bb': img_dump['norm_bb']} + else: + img_dump = msgpack.loads(dump, raw=False) + img_feat = torch.tensor(img_dump['features'][:nbb, :]).float() + img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float() + return img_feat, img_bb + + def __contains__(self, file_name): + return self.txn.get(file_name.encode('utf-8')) is not None + + +@contextmanager +def open_lmdb(db_dir, readonly=False): + db = TxtLmdb(db_dir, readonly) + try: + yield db + finally: + del db + + +class TxtLmdb(object): + def __init__(self, db_dir, readonly=True): + self.readonly = readonly + if readonly: + # training + self.env = lmdb.open(db_dir, + readonly=True, create=False, + readahead=not _check_distributed()) + self.txn = self.env.begin(buffers=True) + self.write_cnt = None + else: + # prepro + self.env = lmdb.open(db_dir, readonly=False, create=True, + map_size=4 * 1024**4) + self.txn = self.env.begin(write=True) + self.write_cnt = 0 + + def __del__(self): + if self.write_cnt: + self.txn.commit() + self.env.close() + + def __getitem__(self, key): + return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))), + raw=False) + + def __setitem__(self, key, value): + # NOTE: not thread safe + if self.readonly: + raise ValueError('readonly text DB') + ret = self.txn.put(key.encode('utf-8'), + compress(msgpack.dumps(value, use_bin_type=True))) + self.write_cnt += 1 + if self.write_cnt % 1000 == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + self.write_cnt = 0 + return ret + +def get_ids_and_lens(db): + assert isinstance(db, TxtTokLmdb) + lens = [] + ids = [] + for id_ in db.ids: + lens.append(db.id2len[id_]) + ids.append(id_) + return lens, ids + + +class DetectFeatTxtTokDataset(Dataset): + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + self.txt_db = txt_db + self.img_db = img_db + txt_lens, self.ids = get_ids_and_lens(txt_db) + + txt2img = txt_db.txt2img + self.lens = [tl + self.img_db.name2nbb[txt2img[id_]] + for tl, id_ in zip(txt_lens, self.ids)] + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + return example + + def _get_img_feat(self, fname): + 
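A minimal round-trip sketch of the value encoding TxtLmdb above relies on: records are msgpack-serialized with use_bin_type=True and lz4-compressed on write, then decompressed and msgpack-decoded with raw=False on read. The record below is a toy example, not real data.

import msgpack
from lz4.frame import compress, decompress

record = {"input_ids": [101, 2023, 102], "img_fname": "toy_image.npz"}  # toy record
blob = compress(msgpack.dumps(record, use_bin_type=True))               # what __setitem__ stores
restored = msgpack.loads(decompress(blob), raw=False)                   # what __getitem__ returns
assert restored == record
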
img_feat, bb = self.img_db[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + return img_feat, img_bb, num_bb + + +class ConcatDatasetWithLens(ConcatDataset): + """ A thin wrapper on pytorch concat dataset for lens batching """ + def __init__(self, datasets): + super().__init__(datasets) + self.lens = [l for dset in datasets for l in dset.lens] + + def __getattr__(self, name): + return self._run_method_on_all_dsets(name) + + def _run_method_on_all_dsets(self, name): + def run_all(*args, **kwargs): + return [dset.__getattribute__(name)(*args, **kwargs) + for dset in self.datasets] + return run_all + + +def pad_tensors(tensors, lens=None, pad=0): + """B x [T, ...]""" + if lens is None: + lens = [t.size(0) for t in tensors] + max_len = max(lens) + bs = len(tensors) + hid = tensors[0].size(-1) + dtype = tensors[0].dtype + output = torch.zeros(bs, max_len, hid, dtype=dtype) + if pad: + output.data.fill_(pad) + for i, (t, l) in enumerate(zip(tensors, lens)): + output.data[i, :l, ...] = t.data + return output + + +def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size): + # assert len(txt_lens) == len(num_bbs) == batch_size + gather_index = torch.arange(0, out_size, dtype=torch.long, + ).unsqueeze(0).repeat(len(num_bbs), 1) + + # for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): + # gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb, + # dtype=torch.long).data + return gather_index + + +def get_gather_index_uniter(txt_lens, num_bbs, batch_size, max_len, out_size): + assert len(txt_lens) == len(num_bbs) == batch_size + gather_index = torch.arange(0, out_size, dtype=torch.long, + ).unsqueeze(0).repeat(batch_size, 1) + + for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): + gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb, + dtype=torch.long).data + return gather_index + + +def get_gather_index_img(txt_lens, num_bbs, batch_size, max_len, out_size): + gather_index = torch.zeros(batch_size, out_size, dtype=torch.long) + + for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): + gather_index.data[i, :nbb] = torch.arange(max_len, max_len+nbb, + dtype=torch.long).data + gather_index.data[i, nbb:nbb+tl] = torch.arange(0, tl, + dtype=torch.long).data + return gather_index + + +class ImageLmdbGroup(object): + def __init__(self, conf_th, max_bb, min_bb, num_bb, compress): + self.path2imgdb = {} + self.conf_th = conf_th + self.max_bb = max_bb + self.min_bb = min_bb + self.num_bb = num_bb + self.compress = compress + + def __getitem__(self, path): + img_db = self.path2imgdb.get(path, None) + if img_db is None: + img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb, + self.min_bb, self.num_bb, self.compress) + return img_db diff --git a/uniter_model/data/itm.py b/uniter_model/data/itm.py new file mode 100644 index 0000000..448eb51 --- /dev/null +++ b/uniter_model/data/itm.py @@ -0,0 +1,572 @@ +""" +Itm dataset +""" +from collections import defaultdict +import copy +import json +import random + +import torch +from torch.nn.utils.rnn import pad_sequence +import numpy as np +from toolz.sandbox import unzip +from cytoolz import concat + +from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, + pad_tensors, get_gather_index, get_ids_and_lens) +from .sampler import TokenBucketSampler + + +class TokenBucketSamplerForItm(TokenBucketSampler): + def __init__(self, dset, *args, **kwargs): + super().__init__(dset.lens, *args, **kwargs) + self.dset = dset + + def __iter__(self): + it = super().__iter__() 
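A short usage sketch of pad_tensors above, assuming the repo and its dependencies are installed so that uniter_model.data.data is importable (otherwise the function can be copied as-is): variable-length region features are zero-padded into a single (batch, max_len, dim) tensor.

import torch
from uniter_model.data.data import pad_tensors

feats = [torch.randn(3, 4), torch.randn(5, 4)]   # two images with 3 and 5 boxes
batch = pad_tensors(feats, lens=[3, 5])
print(batch.shape)                               # torch.Size([2, 5, 4])
print(batch[0, 3:].abs().sum().item())           # 0.0 -> positions past each length are padding
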
+ self.dset.new_epoch() + self._lens = self.dset.lens + return it + + +def _has_overlap(la, lb): + if len(la) < len(lb): + la, lb = lb, la + s = set(la) + return any(b in s for b in lb) + + +def _sample_negative_rand(sample_pool, ground_truths, num_sample): + """ random and retry """ + outputs = ground_truths[:1] + while _has_overlap(outputs, ground_truths): + outputs = random.sample(sample_pool, num_sample) + return outputs + + +def _sample_negative_extra(sample_pool, ground_truths, num_sample): + """ sample extra then remove """ + tot_size = len(ground_truths) + num_sample + outputs = set(random.sample(sample_pool, tot_size)) + for gt in ground_truths: + outputs.discard(gt) + outputs = list(outputs)[:num_sample] + return outputs + + +sample_negative = _sample_negative_rand # swith between 2 implementations + + +class ItmDataset(DetectFeatTxtTokDataset): + """ NOTE this Dataset handles distributed training itself + (for more efficient negative sampling) """ + def __init__(self, txt_db, img_db, neg_sample_p=0.5): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + + self.txt_db = txt_db + self.img_db = img_db + + self.txt_lens, self.ids = get_ids_and_lens(txt_db) + self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids)) + + self.neg_sample_p = neg_sample_p + self.new_epoch() + + def new_epoch(self): + """ should be called every epoch for more randomness""" + self.labels = np.random.choice( + [0, 1], size=len(self.ids), + p=[self.neg_sample_p, 1-self.neg_sample_p]) + + self.lens = [] + self.train_imgs = [] + for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)): + img_fname = super().__getitem__(i)['img_fname'] + if self.labels[i] == 0: + img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0] + self.train_imgs.append(img_fname) + self.lens.append(tl + self.img_db.name2nbb[img_fname]) + + def __getitem__(self, i): + example = super().__getitem__(i) + # labels and negative images should be sampled every epoch + ground_truth_label = self.labels[i] + img_fname = self.train_imgs[i] + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + target = torch.Tensor(1).long() + target.data.fill_(ground_truth_label) + + return input_ids, img_feat, img_pos_feat, attn_masks, target + + +def itm_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch + + +def _compute_ot_scatter(txt_lens, max_txt_len, joint_len): + ot_scatter = torch.arange(0, joint_len, dtype=torch.long + 
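A self-contained sketch of the per-epoch sampling that ItmDataset.new_epoch above performs: every caption draws a fresh 0/1 label, and label-0 (negative) captions are paired with a randomly chosen non-matching image. The ids below are toy values; the real dataset reads them from the text/image LMDBs.

import random
import numpy as np

ids = ["txt0", "txt1", "txt2", "txt3"]                         # toy caption ids
txt2img = {"txt0": "img_a", "txt1": "img_b",
           "txt2": "img_c", "txt3": "img_a"}                   # toy ground-truth pairs
all_imgs = sorted(set(txt2img.values()))
neg_sample_p = 0.5

labels = np.random.choice([0, 1], size=len(ids),
                          p=[neg_sample_p, 1 - neg_sample_p])  # 0 = negative pair
train_imgs = []
for i, id_ in enumerate(ids):
    img = txt2img[id_]
    if labels[i] == 0:                                         # swap in a non-matching image
        img = random.choice([m for m in all_imgs if m != img])
    train_imgs.append(img)
print(list(zip(ids, train_imgs, labels.tolist())))
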
).unsqueeze(0).repeat(len(txt_lens), 1) + for i, tl in enumerate(txt_lens): + max_ind = max_txt_len + (joint_len-tl) + ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind, + dtype=torch.long).data + return ot_scatter + + +def _compute_pad(lens, max_len): + pad = torch.zeros(len(lens), max_len, dtype=torch.bool) + for i, l in enumerate(lens): + pad.data[i, l:].fill_(1) + return pad + + +def itm_ot_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + # OT inputs + max_tl = max(txt_lens) + max_nbb = max(num_bbs) + ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1)) + txt_pad = _compute_pad(txt_lens, max_tl) + img_pad = _compute_pad(num_bbs, max_nbb) + ot_inputs = {'ot_scatter': ot_scatter, + 'scatter_max': ot_scatter.max().item(), + 'txt_pad': txt_pad, + 'img_pad': img_pad} + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets, + 'ot_inputs': ot_inputs} + return batch + + +class ItmRankDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDataset need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + # images partitioned by rank + self.img2txts = defaultdict(list) + for id_, img in self.txt2img.items(): + self.img2txts[img].append(id_) + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample negatives + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in neg_sample_txt_ids]) + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + 2*self.neg_sample_size) + return inputs + + def _collect_inputs(self, id_pairs): + # create input features + inputs = [] + for txt_id, img_id in id_pairs: + example = self.txt_db[txt_id] + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id) + # mask + attn_masks_text = torch.ones(len(input_ids), dtype=torch.long) + attn_masks_img = torch.ones(num_bb, dtype=torch.long) + + inputs.append((input_ids, img_feat, img_pos_feat, attn_masks_text, attn_masks_img)) + + return inputs + + +class ItmRankDatasetHardNeg(ItmRankDataset): + def 
__init__(self, txt_db, img_db, neg_sample_size=1, hard_neg_size=1): + assert hard_neg_size > 0, \ + "ItmRankDatasetHardNeg need at least 1 hard negative sample" + DetectFeatTxtTokDataset.__init__(self, txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + self.hard_neg_size = hard_neg_size + + def reload_hard_negs(self, hard_neg_dir): + self.txt2hardimgs = json.load( + open(f'{hard_neg_dir}/' + f'txt2hardimgs_rank{hvd.rank()}.json')) + self.img2hardtxts = json.load( + open(f'{hard_neg_dir}/img2hardtxts.json')) + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample hard negatives + if self.hard_neg_size > 0: + hard_neg_img_samples = random.sample( + self.txt2hardimgs[gt_txt_id], self.hard_neg_size) + hard_neg_txt_samples = random.sample( + self.img2hardtxts[gt_img_fname], self.hard_neg_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in hard_neg_img_samples] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in hard_neg_txt_samples]) + # sample normal negatives + if self.neg_sample_size > 0: + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in neg_sample_txt_ids]) + + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + + 2*self.neg_sample_size + + 2*self.hard_neg_size) + return inputs + + +def itm_rank_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks_text, attn_masks_img, + ) = map(list, unzip(concat(i for i in inputs))) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks_text = pad_sequence(attn_masks_text, batch_first=True, padding_value=0) + attn_masks_img = pad_sequence(attn_masks_img, batch_first=True, padding_value=0) + sample_size = len(inputs[0]) + assert all(sample_size == len(i) for i in inputs) + + bs, max_tl = input_ids.size() + # out_size = attn_masks.size(1) + gather_index = None # get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks_text': attn_masks_text, + 'attn_masks_img': attn_masks_img, + 'gather_index': gather_index, + 'sample_size': sample_size} + return batch + + +class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDatasetHardNegV2 need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + 
gt_img_fname = self.txt2img[gt_txt_id] + + input_ids = self.txt_db[gt_txt_id]['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + neg_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + img_ids = [gt_img_fname] + neg_img_ids + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDatasetHardNegV2 need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.txt_name_list = list(self.txt2img.keys()) + self.neg_sample_size = neg_sample_size + + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + gt_txt_ids = self.img2txts[gt_img_id] + + # process image features (gt always first) + img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id) + img_feat = img_feat.unsqueeze(0) + img_pos_feat = img_pos_feat.unsqueeze(0) + + # sample negative + neg_txt_ids = sample_negative( + self.txt_name_list, gt_txt_ids, self.neg_sample_size) + txt_ids = [gt_txt_id] + neg_txt_ids + + # process text inputs + all_inputs = [] + txt_lens = [] + for txt_id in txt_ids: + input_ids = self.txt_db.combine_inputs( + self.txt_db[txt_id]['input_ids']) + all_inputs.append(input_ids) + txt_lens.append(len(input_ids)) + input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long() + for i, tl in enumerate(txt_lens): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids), + len(txt_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +def itm_rank_hnv2_collate(inputs): + assert len(inputs) == 1 + return inputs[0] + + +class ItmValDataset(DetectFeatTxtTokDataset): + """ For evaluating Image-Text-Retrieval task """ + def __init__(self, db_dir, img_dir, mini_batch_size=400): + super().__init__(db_dir, img_dir) + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + + assert len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed 
negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + neg_st = i+1 + neg_end = neg_st+self.bs-1 + if neg_end > len(self.all_img_ids): + # warp around + neg_end -= len(self.all_img_ids) + neg_img_ids = (self.all_img_ids[neg_st:] + + self.all_img_ids[:neg_end]) + else: + neg_img_ids = self.all_img_ids[neg_st:neg_end] + + assert len(neg_img_ids) == (self.bs - 1),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + gt_img_id, neg_img_ids = self._get_batch_ids(i) + # NOTE 1st one is gt img + batch = self.get_batch(i, [gt_img_id] + neg_img_ids) + return batch + + def get_batch(self, i, img_ids): + example = super().__getitem__(i) + + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone() + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks_text = torch.ones(len(img_ids), tl).long() + # attn_masks_text = torch.ones(1, tl).long() + attn_masks_img = torch.zeros(len(img_ids), max(num_bbs)).long() + for i, nbb in enumerate(num_bbs): + attn_masks_img.data[i, :nbb].fill_(1) + + # out_size = attn_masks.size(1) + gather_index = None #get_gather_index([tl]*len(img_ids), num_bbs, len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks_text': attn_masks_text, + 'attn_masks_img': attn_masks_img, + 'gather_index': gather_index} + return batch + + +def itm_val_collate(inputs): + assert len(inputs) == 1, "input batch size > 1" + return inputs[0] + + +class ItmHardNegDataset(ItmValDataset): + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + all_img_ids = copy.deepcopy(self.all_img_ids) + all_img_ids.remove(gt_img_id) + random.shuffle(all_img_ids) + neg_img_ids = all_img_ids[:self.bs] + + assert len(neg_img_ids) == (self.bs),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + _, neg_img_ids = self._get_batch_ids(i) + batch = self.get_batch(i, neg_img_ids) + batch['gt_txt_id'] = self.ids[i] + batch['neg_img_ids'] = neg_img_ids + return batch + + +itm_hn_collate = itm_val_collate + + +class ItmEvalDataset(ItmValDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + + def __getitem__(self, i): + mini_batches = [] + for st in range(0, len(self.all_img_ids), self.bs): + mini_batches.append( + self.get_batch(i, self.all_img_ids[st:st+self.bs])) + return mini_batches + + +itm_eval_collate = itm_val_collate diff --git a/uniter_model/data/loader.py b/uniter_model/data/loader.py new file mode 100644 index 0000000..3e521fb --- /dev/null +++ b/uniter_model/data/loader.py @@ -0,0 +1,138 @@ +""" +A meta data loader for sampling from different datasets / training tasks +A prefetch loader to speedup data loading +""" +import random + +import torch 
+from torch.utils.data import DataLoader + +from uniter_model.utils.distributed import any_broadcast + + +class MetaLoader(object): + """ wraps multiple data loader """ + def __init__(self, loaders, accum_steps=1, distributed=False): + assert isinstance(loaders, dict) + self.name2loader = {} + self.name2iter = {} + self.sampling_pools = [] + for n, l in loaders.items(): + if isinstance(l, tuple): + l, r = l + elif isinstance(l, DataLoader): + r = 1 + else: + raise ValueError() + self.name2loader[n] = l + self.name2iter[n] = iter(l) + self.sampling_pools.extend([n]*r) + + self.accum_steps = accum_steps + self.distributed = distributed + self.step = 0 + + def __iter__(self): + """ this iterator will run indefinitely """ + task = self.sampling_pools[0] + while True: + if self.step % self.accum_steps == 0: + task = random.choice(self.sampling_pools) + if self.distributed: + # make sure all process is training same task + task = any_broadcast(task, 0) + self.step += 1 + iter_ = self.name2iter[task] + try: + batch = next(iter_) + except StopIteration: + iter_ = iter(self.name2loader[task]) + batch = next(iter_) + self.name2iter[task] = iter_ + + yield task, batch + + +def move_to_cuda(batch): + if isinstance(batch, torch.Tensor): + return batch.cuda(non_blocking=True) + elif isinstance(batch, list): + new_batch = [move_to_cuda(t) for t in batch] + elif isinstance(batch, tuple): + new_batch = tuple(move_to_cuda(t) for t in batch) + elif isinstance(batch, dict): + new_batch = {n: move_to_cuda(t) for n, t in batch.items()} + else: + return batch + return new_batch + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class PrefetchLoader(object): + """ + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. 
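A self-contained sketch of the task mixing that MetaLoader above implements: each task name is repeated mix-ratio times in a sampling pool, and a task is re-drawn only at gradient-accumulation boundaries, so the ratios control how often each pretraining objective is visited. The ratios below are toy values in the spirit of the configs earlier in this patch.

import random

ratios = {"itm": 2, "mlm": 2, "mrfr": 1, "mrckl": 1}           # toy mix ratios
pool = [name for name, r in ratios.items() for _ in range(r)]  # sampling pool
accum_steps = 2

counts = dict.fromkeys(ratios, 0)
task = pool[0]
for step in range(6000):
    if step % accum_steps == 0:     # keep the same task within an accumulation window
        task = random.choice(pool)
    counts[task] += 1
print(counts)                       # roughly proportional to 2:2:1:1
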
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method diff --git a/uniter_model/data/mlm.py b/uniter_model/data/mlm.py new file mode 100644 index 0000000..a77cf99 --- /dev/null +++ b/uniter_model/data/mlm.py @@ -0,0 +1,360 @@ +""" +MLM datasets +""" +import math +import random + +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip + +from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, + get_ids_and_lens, pad_tensors, get_gather_index) + + +def random_word(tokens, vocab_range, mask): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. + :param vocab_range: for choosing a random word + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + output_label = [] + + for i, token in enumerate(tokens): + prob = random.random() + # mask token with 15% probability + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + tokens[i] = mask + + # 10% randomly change token to random token + elif prob < 0.9: + tokens[i] = random.choice(list(range(*vocab_range))) + + # -> rest 10% randomly keep current token + + # append current token to output (we will predict these later) + output_label.append(token) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + if all(o == -1 for o in output_label): + # at least mask 1 + output_label[0] = tokens[0] + tokens[0] = mask + + return tokens, output_label + + +class MlmDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + super().__init__(txt_db, img_db) + + def __getitem__(self, i): + """ + Return: + - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - img_feat : (num_bb, d) + - img_pos_feat : (num_bb, 7) + - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] + - txt_labels : (L, ), [-1, -1, wid, -1, -1, -1] + 0's padded so that (L + num_bb) % 8 == 0 + """ + example = super().__getitem__(i) + + # text input + input_ids, txt_labels = self.create_mlm_io(example['input_ids']) + + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return input_ids, img_feat, img_pos_feat, attn_masks, txt_labels + + def create_mlm_io(self, input_ids): + input_ids, txt_labels = random_word(input_ids, + self.txt_db.v_range, + self.txt_db.mask) + input_ids = torch.tensor([self.txt_db.cls_] + + input_ids + + [self.txt_db.sep]) + txt_labels = torch.tensor([-1] + txt_labels + [-1]) + return input_ids, txt_labels + + +def mlm_collate(inputs): + """ + Return: + :input_ids (n, max_L) padded with 0 + :position_ids (n, max_L) padded with 0 + :txt_lens list of [txt_len] + :img_feat (n, max_num_bb, feat_dim) + :img_pos_feat (n, max_num_bb, 7) + :num_bbs list of [num_bb] + :attn_masks (n, max_{L + num_bb}) padded with 0 + :txt_labels (n, max_L) padded with 
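A usage sketch of random_word above, assuming the repo and its dependencies are installed so that uniter_model.data.mlm is importable: roughly 15% of tokens are selected; of those, 80% become the mask id, 10% become a random id from vocab_range, and 10% stay unchanged, while every unselected position gets label -1 so the loss ignores it. The token ids and mask id below are toy values, not a real vocabulary.

import random
from uniter_model.data.mlm import random_word

random.seed(0)
tokens = [1037, 2158, 2003, 5559, 1037, 3586]          # toy wordpiece ids
masked, labels = random_word(list(tokens), vocab_range=(999, 28996), mask=103)
print(masked)   # selected positions replaced by the mask id or a random id (or kept)
print(labels)   # original id at selected positions, -1 everywhere else
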
-1 + """ + (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels + ) = map(list, unzip(inputs)) + + # text batches + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'txt_labels': txt_labels} + return batch + + +class BlindMlmDataset(Dataset): + def __init__(self, txt_db): + assert isinstance(txt_db, TxtTokLmdb) + self.txt_db = txt_db + self.lens, self.ids = get_ids_and_lens(txt_db) + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + input_ids, txt_labels = self.create_mlm_io(example['input_ids']) + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + + return input_ids, attn_masks, txt_labels + + +def mlm_blind_collate(inputs): + input_ids, attn_masks, txt_labels = map(list, unzip(inputs)) + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'attn_masks': attn_masks, + 'txt_labels': txt_labels} + return batch + + +def eval_mask(len_, num_samples=7): + """ build the mask for evaluating MLM + circularly mask 1 word out of every x words + """ + # build the random masks + if len_ <= num_samples: + masks = torch.eye(len_).bool() + num_samples = len_ + else: + mask_inds = [list(range(i, len_, num_samples)) + for i in range(num_samples)] + masks = torch.zeros(num_samples, len_).bool() + for i, indices in enumerate(mask_inds): + for j in indices: + masks.data[i, j] = 1 + assert (masks.sum(dim=0) != torch.ones(len_).long()).sum().item() == 0 + assert masks.sum().item() == len_ + return masks + + +def eval_gather_inds(len_, num_samples=7): + """ get the gather indices """ + inds = torch.arange(0, num_samples, dtype=torch.long) + mul = math.ceil(len_ / num_samples) + output = inds.repeat(mul)[:len_] + return output + + +def stack_pad_tensors(tensors, lens=None, ns=None, pad=0): + """N x [B_i, T, ...]""" + if ns is None: + ns = [t.size(0) for t in tensors] + if lens is None: + lens = [t.size(1) for t in tensors] + max_len = max(lens) + bs = sum(ns) + hid_dims = tensors[0].size()[2:] + dtype = tensors[0].dtype + output = torch.zeros(bs, max_len, *hid_dims, dtype=dtype) + if pad: + output.data.fill_(pad) + i = 0 + for t, l, n in zip(tensors, lens, ns): + output.data[i:i+n, :l, ...] 
= t.data + i += n + return output + + +def expand_tensors(tensors, ns): + return [t.unsqueeze(0).expand(n, *tuple([-1]*t.dim())) + for t, n in zip(tensors, ns)] + + +class MlmEvalDataset(DetectFeatTxtTokDataset): + """ For evaluating MLM training task """ + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + super().__init__(txt_db, img_db) + + def __getitem__(self, i): + example = super().__getitem__(i) + + # text input + (input_ids, txt_labels, gather_inds + ) = self.create_mlm_eval_io(example['input_ids']) + + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + attn_masks = torch.ones(input_ids.size(1) + num_bb, dtype=torch.long) + + return (input_ids, img_feat, img_pos_feat, attn_masks, + txt_labels, gather_inds) + + def create_mlm_eval_io(self, input_ids): + txt_labels = torch.tensor(input_ids) + masks = eval_mask(len(input_ids)) + n_mask = masks.size(0) + masks = torch.cat([torch.zeros(n_mask, 1).bool(), + masks, + torch.zeros(n_mask, 1).bool()], + dim=1) + input_ids = torch.tensor([[self.txt_db.cls_] + + input_ids + + [self.txt_db.sep] + for _ in range(n_mask)]) + input_ids.data.masked_fill_(masks, self.txt_db.mask) + gather_inds = eval_gather_inds(len(txt_labels)) + return input_ids, txt_labels, gather_inds + + +def _batch_gather_tgt(gather_inds, n_masks): + gather_tgts = [] + offset = 0 + for g, n in zip(gather_inds, n_masks): + gather_tgts.append(g + offset) + offset += n + gather_tgt = pad_sequence(gather_tgts, batch_first=True, padding_value=0) + return gather_tgt + + +def mlm_eval_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels, gather_inds + ) = map(list, unzip(inputs)) + + # sizes + n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids)) + + # text batches + input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + gather_tgt = _batch_gather_tgt(gather_inds, n_masks) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_feat = stack_pad_tensors(expand_tensors(img_feats, n_masks), + num_bbs, n_masks) + img_pos_feat = stack_pad_tensors(expand_tensors(img_pos_feats, n_masks), + num_bbs, n_masks) + + bs, max_tl = input_ids.size() + attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks), + None, n_masks) + out_size = attn_masks.size(1) + # repeat txt_lens, num_bbs + txt_lens = [l for l, n in zip(txt_lens, n_masks) for _ in range(n)] + num_bbs = [b for b, n in zip(num_bbs, n_masks) for _ in range(n)] + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'gather_tgt': gather_tgt, + 'txt_labels': txt_labels} + return batch + + +class BlindMlmEvalDataset(Dataset): + def __init__(self, txt_db): + assert isinstance(txt_db, TxtTokLmdb) + self.txt_db = txt_db + self.lens, self.ids = get_ids_and_lens(txt_db) + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + input_ids = example['input_ids'] + + # text input + input_ids = example['input_ids'] + (input_ids, txt_labels, gather_inds + ) = self.txt_db.create_mlm_eval_io(input_ids) + + attn_masks = torch.ones(len(input_ids), dtype=torch.long) + + return input_ids, 
attn_masks, txt_labels, gather_inds + + +def mlm_blind_eval_collate(inputs): + (input_ids, position_ids, attn_masks, txt_labels, gather_inds + ) = map(list, unzip(inputs)) + + # sizes + n_masks, txt_lens = map(list, unzip(i.size() for i in input_ids)) + + # text batches + input_ids = stack_pad_tensors(input_ids, txt_lens, n_masks) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = stack_pad_tensors(expand_tensors(attn_masks, n_masks), + None, n_masks) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + gather_tgt = _batch_gather_tgt(gather_inds, n_masks) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'attn_masks': attn_masks, + 'gather_tgt': gather_tgt, + 'txt_labels': txt_labels} + return batch diff --git a/uniter_model/data/mrm.py b/uniter_model/data/mrm.py new file mode 100644 index 0000000..7cde0dc --- /dev/null +++ b/uniter_model/data/mrm.py @@ -0,0 +1,287 @@ +""" +MRM Datasets +""" +import random + +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index + + +def _get_img_mask(mask_prob, num_bb): + img_mask = [random.random() < mask_prob for _ in range(num_bb)] + if not any(img_mask): + # at least mask 1 + img_mask[random.choice(range(num_bb))] = True + img_mask = torch.tensor(img_mask) + return img_mask + + +def _get_img_tgt_mask(img_mask, txt_len): + z = torch.zeros(txt_len, dtype=torch.bool) + img_mask_tgt = torch.cat([z, img_mask], dim=0) + return img_mask_tgt + + +def _get_feat_target(img_feat, img_masks): + img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) + feat_dim = img_feat.size(-1) + feat_targets = img_feat[img_masks_ext].contiguous().view( + -1, feat_dim) # (s, d) + return feat_targets + + +def _mask_img_feat(img_feat, img_masks): + img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) + img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) + return img_feat_masked + + +class MrfrDataset(DetectFeatTxtTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + + def __getitem__(self, i): + """ + Return: + - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - img_feat : (num_bb, d) + - img_pos_feat : (num_bb, 7) + - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] + - img_mask : (num_bb, ) between {0, 1} + """ + example = super().__getitem__(i) + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + # image input features + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + img_mask = _get_img_mask(self.mask_prob, num_bb) + img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return (input_ids, img_feat, img_pos_feat, + attn_masks, img_mask, img_mask_tgt) + + +def mrfr_collate(inputs): + """ + Return: + - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded + - position_ids : (n, max_L) + - txt_lens : list of [input_len] + - img_feat : (n, max_num_bb, d) + - img_pos_feat : (n, max_num_bb, 7) + - num_bbs : list of [num_bb] + - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1] + - img_masks : (n, max_num_bb) between {0, 1} + """ + (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts, + ) = 
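A self-contained sketch of the MRM masking helpers above (_get_feat_target and _mask_img_feat): masked region features are gathered as regression targets and zeroed out in the model input. Shapes and values are toy numbers.

import torch

img_feat = torch.arange(24, dtype=torch.float32).view(1, 4, 6)  # (batch=1, boxes=4, dim=6)
img_masks = torch.tensor([[False, True, False, True]])          # mask boxes 1 and 3

masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
feat_targets = img_feat[masks_ext].contiguous().view(-1, 6)     # (2, 6) regression targets
img_feat_masked = img_feat.data.masked_fill(masks_ext, 0)       # zero the masked boxes

print(feat_targets.shape)            # torch.Size([2, 6])
print(img_feat_masked[0, 1].sum())   # tensor(0.)
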
map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + # mask features + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + feat_targets = _get_feat_target(img_feat, img_masks) + img_feat = _mask_img_feat(img_feat, img_masks) + img_mask_tgt = pad_sequence(img_mask_tgts, + batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'feat_targets': feat_targets, + 'img_masks': img_masks, + 'img_mask_tgt': img_mask_tgt} + return batch + + +class OnlyImgMrfrDataset(Dataset): + """ an image-only MRM """ + def __init__(self, mask_prob, img_db): + self.ids, self.lens = map(list, unzip(self.img_db.name2nbb.items())) + + def __getitem__(self, i): + id_ = self.ids[i] + img_feat, img_pos_feat, num_bb = self._get_img_feat(id_) + attn_masks = torch.ones(num_bb, dtype=torch.long) + img_mask = _get_img_mask(self.mask_prob, num_bb) + + return img_feat, img_pos_feat, attn_masks, img_mask + + def _get_img_feat(self, fname): + img_feat, bb = self.img_db[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + return img_feat, img_bb, num_bb + + +def mrfr_only_img_collate(inputs): + img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs)) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + # mask features + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + feat_targets = _get_feat_target(img_feat, img_masks) + img_feat = _mask_img_feat(img_feat, img_masks) + + batch = {'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'feat_targets': feat_targets, + 'img_masks': img_masks, + 'img_mask_tgt': img_masks} + return batch + + +def _get_targets(img_masks, img_soft_label): + soft_label_dim = img_soft_label.size(-1) + img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label) + label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view( + -1, soft_label_dim) + return label_targets + + +class MrcDataset(DetectFeatTxtTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + + def _get_img_feat(self, fname): + img_dump = self.img_db.get_dump(fname) + num_bb = self.img_db.name2nbb[fname] + img_feat = torch.tensor(img_dump['features']) + bb = torch.tensor(img_dump['norm_bb']) + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + img_soft_label = torch.tensor(img_dump['soft_labels']) + return img_feat, img_bb, img_soft_label, num_bb + + def __getitem__(self, i): + example = super().__getitem__(i) + img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat( + example['img_fname']) + + # image input features + img_mask 
= _get_img_mask(self.mask_prob, num_bb) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return (input_ids, img_feat, img_pos_feat, + img_soft_labels, attn_masks, img_mask, img_mask_tgt) + + +def mrc_collate(inputs): + (input_ids, img_feats, img_pos_feats, img_soft_labels, + attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_soft_label = pad_tensors(img_soft_labels, num_bbs) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + label_targets = _get_targets(img_masks, img_soft_label) + + img_feat = _mask_img_feat(img_feat, img_masks) + img_mask_tgt = pad_sequence(img_mask_tgts, + batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'img_masks': img_masks, + 'img_mask_tgt': img_mask_tgt, + 'label_targets': label_targets} + return batch + + +class OnlyImgMrcDataset(OnlyImgMrfrDataset): + """ an image-only MRC """ + def __getitem__(self, i): + id_ = self.ids[i] + (img_feat, img_pos_feat, img_soft_labels, num_bb + ) = self._get_img_feat(id_) + attn_masks = torch.ones(num_bb, dtype=torch.long) + img_mask = _get_img_mask(self.mask_prob, num_bb) + + return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask + + def _get_img_feat(self, fname): + img_dump = self.img_db.get_dump(fname) + num_bb = self.img_db.name2nbb[fname] + img_feat = torch.tensor(img_dump['features']) + bb = torch.tensor(img_dump['norm_bb']) + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + img_soft_labels = torch.tensor(img_dump['soft_labels']) + return img_feat, img_bb, img_soft_labels, num_bb + + +def mrc_only_img_collate(inputs): + (img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks + ) = map(list, unzip(inputs)) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + num_bbs = [f.size(0) for f in img_feats] + + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + img_soft_label = pad_tensors(img_soft_labels, num_bbs) + label_targets = _get_targets(img_masks, img_soft_label) + + # mask features + img_feat = _mask_img_feat(img_feat, img_masks) + + batch = {'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'img_masks': img_masks, + 'img_mask_tgt': img_masks, + 'label_targets': label_targets} + return batch diff --git a/uniter_model/data/mrm_nce.py b/uniter_model/data/mrm_nce.py new file mode 100644 index 0000000..1bcdcd5 --- /dev/null +++ b/uniter_model/data/mrm_nce.py @@ -0,0 +1,136 @@ +""" +MRM Datasets (contrastive learning version) +""" +import torch +from 
torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from cytoolz import curry + +from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset, + pad_tensors, get_gather_index) +from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target +from .itm import sample_negative + + +# FIXME diff implementation from mrfr, mrc +def _mask_img_feat(img_feat, img_masks, neg_feats, + noop_prob=0.1, change_prob=0.1): + rand = torch.rand(*img_masks.size()) + noop_mask = rand < noop_prob + change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks + img_masks_in = img_masks & ~noop_mask & ~change_mask + + img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat) + img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) + + n_neg = change_mask.sum().item() + feat_dim = neg_feats.size(-1) + index = torch.arange(0, change_mask.numel(), dtype=torch.long + ).masked_select(change_mask.view(-1)) + index = index.unsqueeze(-1).expand(-1, feat_dim) + img_feat_out = img_feat_masked.view(-1, feat_dim).scatter( + dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size()) + + return img_feat_out, img_masks_in + + +class MrmNceDataset(DetectFeatTxtTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + + def __getitem__(self, i): + example = super().__getitem__(i) + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + # image input features + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + img_mask = _get_img_mask(self.mask_prob, num_bb) + img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return (input_ids, img_feat, img_pos_feat, + attn_masks, img_mask, img_mask_tgt, + example['img_fname']) + + +class NegativeImageSampler(object): + def __init__(self, img_dbs, neg_size, size_mul=8): + if not isinstance(img_dbs, list): + assert isinstance(img_dbs, DetectFeatLmdb) + img_dbs = [img_dbs] + self.neg_size = neg_size + self.img_db = JoinedDetectFeatLmdb(img_dbs) + all_imgs = [] + for db in img_dbs: + all_imgs.extend(db.name2nbb.keys()) + self.all_imgs = all_imgs + + def sample_negative_feats(self, pos_imgs): + neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size) + all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids], + dim=0) + # only use multiples of 8 for tensorcores + n_cut = all_neg_feats.size(0) % 8 + if n_cut != 0: + return all_neg_feats[:-n_cut] + else: + return all_neg_feats + + +class JoinedDetectFeatLmdb(object): + def __init__(self, img_dbs): + assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs) + self.img_dbs = img_dbs + + def __getitem__(self, file_name): + for db in self.img_dbs: + if file_name in db: + return db[file_name] + raise ValueError("image does not exists") + + +@curry +def mrm_nce_collate(neg_sampler, inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts, + positive_imgs) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + neg_feats = neg_sampler.sample_negative_feats(positive_imgs) + + # mask features + img_masks 
= pad_sequence(img_masks, batch_first=True, padding_value=0) + feat_targets = _get_feat_target(img_feat, img_masks) + img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats) + img_mask_tgt = pad_sequence(img_mask_tgts, + batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'feat_targets': feat_targets, + 'img_masks': img_masks, + 'img_masks_in': img_masks_in, + 'img_mask_tgt': img_mask_tgt, + 'neg_feats': neg_feats} + return batch diff --git a/uniter_model/data/nlvr2.py b/uniter_model/data/nlvr2.py new file mode 100644 index 0000000..e45f337 --- /dev/null +++ b/uniter_model/data/nlvr2.py @@ -0,0 +1,218 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +NLVR2 dataset +""" +import copy + +import torch +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from cytoolz import concat + +from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, + get_ids_and_lens, pad_tensors, get_gather_index) + + +class Nlvr2PairedDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, use_img_type=True): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + self.txt_db = txt_db + self.img_db = img_db + txt_lens, self.ids = get_ids_and_lens(txt_db) + + txt2img = txt_db.txt2img + self.lens = [2*tl + sum(self.img_db.name2nbb[img] + for img in txt2img[id_]) + for tl, id_ in zip(txt_lens, self.ids)] + + self.use_img_type = use_img_type + + def __getitem__(self, i): + """ + [[txt, img1], + [txt, img2]] + """ + example = super().__getitem__(i) + target = example['target'] + outs = [] + for i, img in enumerate(example['img_fname']): + img_feat, img_pos_feat, num_bb = self._get_img_feat(img) + + # text input + input_ids = copy.deepcopy(example['input_ids']) + + input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep] + attn_masks = [1] * (len(input_ids) + num_bb) + input_ids = torch.tensor(input_ids) + attn_masks = torch.tensor(attn_masks) + if self.use_img_type: + img_type_ids = torch.tensor([i+1]*num_bb) + else: + img_type_ids = None + + outs.append((input_ids, img_feat, img_pos_feat, + attn_masks, img_type_ids)) + return tuple(outs), target + + +def nlvr2_paired_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, + img_type_ids) = map(list, unzip(concat(outs for outs, _ in inputs))) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + if img_type_ids[0] is None: + img_type_ids = None + else: + img_type_ids = pad_sequence(img_type_ids, + batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.Tensor([t for _, t in inputs]).long() + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, 
+ 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'img_type_ids': img_type_ids, + 'targets': targets} + return batch + + +class Nlvr2PairedEvalDataset(Nlvr2PairedDataset): + def __getitem__(self, i): + qid = self.ids[i] + outs, targets = super().__getitem__(i) + return qid, outs, targets + + +def nlvr2_paired_eval_collate(inputs): + qids, batch = [], [] + for id_, *tensors in inputs: + qids.append(id_) + batch.append(tensors) + batch = nlvr2_paired_collate(batch) + batch['qids'] = qids + return batch + + +class Nlvr2TripletDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, use_img_type=True): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + self.txt_db = txt_db + self.img_db = img_db + txt_lens, self.ids = get_ids_and_lens(txt_db) + + txt2img = txt_db.txt2img + self.lens = [tl + sum(self.img_db.name2nbb[img] + for img in txt2img[id_]) + for tl, id_ in zip(txt_lens, self.ids)] + + self.use_img_type = use_img_type + + def __getitem__(self, i): + """ + [[txt, img1], + [txt, img2]] + """ + example = super().__getitem__(i) + target = example['target'] + img_feats = [] + img_pos_feats = [] + num_bb = 0 + img_type_ids = [] + for i, img in enumerate(example['img_fname']): + feat, pos, nbb = self._get_img_feat(img) + img_feats.append(feat) + img_pos_feats.append(pos) + num_bb += nbb + if self.use_img_type: + img_type_ids.extend([i+1]*nbb) + img_feat = torch.cat(img_feats, dim=0) + img_pos_feat = torch.cat(img_pos_feats, dim=0) + if self.use_img_type: + img_type_ids = torch.tensor(img_type_ids) + else: + img_type_ids = None + + # text input + input_ids = copy.deepcopy(example['input_ids']) + + input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep] + attn_masks = [1] * (len(input_ids) + num_bb) + input_ids = torch.tensor(input_ids) + attn_masks = torch.tensor(attn_masks) + + return (input_ids, img_feat, img_pos_feat, attn_masks, + img_type_ids, target) + + +def nlvr2_triplet_collate(inputs): + (input_ids, img_feats, img_pos_feats, + attn_masks, img_type_ids, targets) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # image batches + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + if img_type_ids[0] is None: + img_type_ids = None + else: + img_type_ids = pad_sequence(img_type_ids, + batch_first=True, padding_value=0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.Tensor(targets).long() + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'img_type_ids': img_type_ids, + 'targets': targets} + return batch + + +class Nlvr2TripletEvalDataset(Nlvr2TripletDataset): + def __getitem__(self, i): + qid = self.ids[i] + tensors = super().__getitem__(i) + return (qid, *tensors) + + +def nlvr2_triplet_eval_collate(inputs): + qids, batch = [], [] + for id_, *tensors in inputs: + qids.append(id_) + batch.append(tensors) + batch = nlvr2_triplet_collate(batch) + 
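+    # keep the original example ids alongside the collated tensors so the
+    # evaluation loop can map predictions back to NLVR2 examples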
batch['qids'] = qids + return batch diff --git a/uniter_model/data/re.py b/uniter_model/data/re.py new file mode 100644 index 0000000..4e20bdb --- /dev/null +++ b/uniter_model/data/re.py @@ -0,0 +1,319 @@ +""" +Referring Expression Comprehension dataset +""" +import sys +import json +import random +import numpy as np + +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip + +from .data import TxtLmdb + + +class ReImageFeatDir(object): + def __init__(self, img_dir): + self.img_dir = img_dir + + def __getitem__(self, file_name): + img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True) + img_feat = torch.tensor(img_dump['features']) + img_bb = torch.tensor(img_dump['norm_bb']) + return img_feat, img_bb + + +class ReDetectFeatDir(object): + def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, + format_='npz'): + assert format_ == 'npz', 'only support npz for now.' + assert isinstance(img_dir, str), 'img_dir is path, not db.' + self.img_dir = img_dir + self.conf_th = conf_th + self.max_bb = max_bb + self.min_bb = min_bb + self.num_bb = num_bb + + def _compute_num_bb(self, img_dump): + num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum()) + num_bb = min(self.max_bb, num_bb) + return num_bb + + def __getitem__(self, file_name): + # image input features + img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True) + num_bb = self._compute_num_bb(img_dump) + img_feat = torch.tensor(img_dump['features'][:num_bb, :]) + img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :]) + return img_feat, img_bb + + +class ReferringExpressionDataset(Dataset): + def __init__(self, db_dir, img_dir, max_txt_len=60): + assert isinstance(img_dir, ReImageFeatDir) or \ + isinstance(img_dir, ReDetectFeatDir) + self.img_dir = img_dir + + # load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}] + refs = json.load(open(f'{db_dir}/refs.json', 'r')) + self.ref_ids = [ref['ref_id'] for ref in refs] + self.Refs = {ref['ref_id']: ref for ref in refs} + + # load annotations = [{id, area, bbox, image_id, category_id}] + anns = json.load(open(f'{db_dir}/annotations.json', 'r')) + self.Anns = {ann['id']: ann for ann in anns} + + # load categories = [{id, name, supercategory}] + categories = json.load(open(f'{db_dir}/categories.json', 'r')) + self.Cats = {cat['id']: cat['name'] for cat in categories} + + # load images = [{id, file_name, ann_ids, height, width}] + images = json.load(open(f'{db_dir}/images.json', 'r')) + self.Images = {img['id']: img for img in images} + + # id2len: sent_id -> sent_len + id2len = json.load(open(f'{db_dir}/id2len.json', 'r')) + self.id2len = {int(_id): _len for _id, _len in id2len.items()} + self.max_txt_len = max_txt_len + self.sent_ids = self._get_sent_ids() + + # db[str(sent_id)] = + # {sent_id, sent, ref_id, ann_id, image_id, + # bbox, input_ids, toked_sent} + self.db = TxtLmdb(db_dir, readonly=True) + + # meta + meta = json.load(open(f'{db_dir}/meta.json', 'r')) + self.cls_ = meta['CLS'] + self.sep = meta['SEP'] + self.mask = meta['MASK'] + self.v_range = meta['v_range'] + + def shuffle(self): + # we shuffle ref_ids and make sent_ids according to ref_ids + random.shuffle(self.ref_ids) + self.sent_ids = self._get_sent_ids() + + def _get_sent_ids(self): + sent_ids = [] + for ref_id in self.ref_ids: + for sent_id in self.Refs[ref_id]['sent_ids']: + sent_len = self.id2len[sent_id] + if self.max_txt_len == -1 or sent_len < self.max_txt_len: + 
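+                    # keep only sentences shorter than max_txt_len
+                    # (or every sentence when the limit is -1)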
sent_ids.append(sent_id) + return sent_ids + + def _get_img_feat(self, fname): + img_feat, bb = self.img_dir[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + return img_feat, img_bb, num_bb + + def __len__(self): + return len(self.sent_ids) + + def __getitem__(self, i): + """ + Return: + :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0] + :position_ids : range(L) + :img_feat : (num_bb, d) + :img_pos_feat : (num_bb, 7) + :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1] + :obj_masks : (num_bb, ) all 0's + :target : (1, ) + """ + # {sent_id, sent, ref_id, ann_id, image_id, + # bbox, input_ids, toked_sent} + sent_id = self.sent_ids[i] + txt_dump = self.db[str(sent_id)] + image_id = txt_dump['image_id'] + fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz' + img_feat, img_pos_feat, num_bb = self._get_img_feat(fname) + + # text input + input_ids = txt_dump['input_ids'] + input_ids = [self.cls_] + input_ids + [self.sep] + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + + # target bbox + img = self.Images[image_id] + assert len(img['ann_ids']) == num_bb, \ + 'Please use visual_grounding_coco_gt' + target = img['ann_ids'].index(txt_dump['ann_id']) + target = torch.tensor([target]) + + # obj_masks, to be padded with 1, for masking out non-object prob. + obj_masks = torch.tensor([0]*len(img['ann_ids'])).bool() + + return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks, + obj_masks, target) + + +def re_collate(inputs): + """ + Return: + :input_ids : (n, max_L) padded with 0 + :position_ids : (n, max_L) padded with 0 + :txt_lens : list of [txt_len] + :img_feat : (n, max_num_bb, feat_dim) + :img_pos_feat : (n, max_num_bb, 7) + :num_bbs : list of [num_bb] + :attn_masks : (n, max_{L+num_bb}) padded with 0 + :obj_masks : (n, max_num_bb) padded with 1 + :targets : (n, ) + """ + (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks, + targets) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, + batch_first=True, padding_value=0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + obj_masks = pad_sequence(obj_masks, + batch_first=True, padding_value=1).bool() + + batch_size = len(img_feats) + num_bb = max(num_bbs) + feat_dim = img_feats[0].size(1) + pos_dim = img_pos_feats[0].size(1) + img_feat = torch.zeros(batch_size, num_bb, feat_dim) + img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): + len_ = im.size(0) + img_feat.data[i, :len_, :] = im.data + img_pos_feat.data[i, :len_, :] = pos.data + + return (input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, obj_masks, targets) + + +class ReferringExpressionEvalDataset(ReferringExpressionDataset): + def __getitem__(self, i): + """ + Return: + :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0] + :position_ids : range(L) + :img_feat : (num_bb, d) + :img_pos_feat : (num_bb, 7) + :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1] + :obj_masks : (num_bb, ) all 0's + :tgt_box : ndarray (4, ) xywh + :obj_boxes : ndarray (num_bb, 4) 
xywh + :sent_id + """ + # {sent_id, sent, ref_id, ann_id, image_id, + # bbox, input_ids, toked_sent} + sent_id = self.sent_ids[i] + txt_dump = self.db[str(sent_id)] + image_id = txt_dump['image_id'] + if isinstance(self.img_dir, ReImageFeatDir): + if '_gt' in self.img_dir.img_dir: + fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz' + elif '_det' in self.img_dir.img_dir: + fname = f'visual_grounding_det_coco_{int(image_id):012}.npz' + elif isinstance(self.img_dir, ReDetectFeatDir): + fname = f'coco_train2014_{int(image_id):012}.npz' + else: + sys.exit('%s not supported.' % self.img_dir) + img_feat, img_pos_feat, num_bb = self._get_img_feat(fname) + + # image info + img = self.Images[image_id] + im_width, im_height = img['width'], img['height'] + + # object boxes, img_pos_feat (xyxywha) -> xywh + obj_boxes = np.stack([img_pos_feat[:, 0]*im_width, + img_pos_feat[:, 1]*im_height, + img_pos_feat[:, 4]*im_width, + img_pos_feat[:, 5]*im_height], axis=1) + obj_masks = torch.tensor([0]*num_bb).bool() + + # target box + tgt_box = np.array(txt_dump['bbox']) # xywh + + # text input + input_ids = txt_dump['input_ids'] + input_ids = [self.cls_] + input_ids + [self.sep] + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + + return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks, + obj_masks, tgt_box, obj_boxes, sent_id) + + # IoU function + def computeIoU(self, box1, box2): + # each box is of [x1, y1, w, h] + inter_x1 = max(box1[0], box2[0]) + inter_y1 = max(box1[1], box2[1]) + inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1) + inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1) + + if inter_x1 < inter_x2 and inter_y1 < inter_y2: + inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) + else: + inter = 0 + union = box1[2]*box1[3] + box2[2]*box2[3] - inter + return float(inter)/union + + +def re_eval_collate(inputs): + """ + Return: + :input_ids : (n, max_L) + :position_ids : (n, max_L) + :txt_lens : list of [txt_len] + :img_feat : (n, max_num_bb, d) + :img_pos_feat : (n, max_num_bb, 7) + :num_bbs : list of [num_bb] + :attn_masks : (n, max{L+num_bb}) + :obj_masks : (n, max_num_bb) + :tgt_box : list of n [xywh] + :obj_boxes : list of n [[xywh, xywh, ...]] + :sent_ids : list of n [sent_id] + """ + (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks, + tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, + batch_first=True, padding_value=0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + obj_masks = pad_sequence(obj_masks, + batch_first=True, padding_value=1).bool() + + batch_size = len(img_feats) + num_bb = max(num_bbs) + feat_dim = img_feats[0].size(1) + pos_dim = img_pos_feats[0].size(1) + img_feat = torch.zeros(batch_size, num_bb, feat_dim) + img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): + len_ = im.size(0) + img_feat.data[i, :len_, :] = im.data + img_pos_feat.data[i, :len_, :] = pos.data + + return (input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids) diff --git a/uniter_model/data/sampler.py 
b/uniter_model/data/sampler.py new file mode 100644 index 0000000..4095f4c --- /dev/null +++ b/uniter_model/data/sampler.py @@ -0,0 +1,116 @@ +""" sampler for length bucketing (batch by tokens) """ +import math +import random + +import torch +import horovod.torch as hvd +from torch.utils.data import Sampler +from cytoolz import partition_all + + +class TokenBucketSampler(Sampler): + def __init__(self, lens, bucket_size, batch_size, + droplast=False, size_multiple=8): + self._lens = lens + self._max_tok = batch_size + self._bucket_size = bucket_size + self._droplast = droplast + self._size_mul = size_multiple + + def _create_ids(self): + return list(range(len(self._lens))) + + def _sort_fn(self, i): + return self._lens[i] + + def __iter__(self): + ids = self._create_ids() + random.shuffle(ids) + buckets = [sorted(ids[i:i+self._bucket_size], + key=self._sort_fn, reverse=True) + for i in range(0, len(ids), self._bucket_size)] + # fill batches until max_token (include padding) + batches = [] + for bucket in buckets: + max_len = 0 + batch_indices = [] + for indices in partition_all(self._size_mul, bucket): + max_len = max(max_len, max(self._lens[i] for i in indices)) + if (max_len * (len(batch_indices) + self._size_mul) + > self._max_tok): + if not batch_indices: + raise ValueError( + "max_tokens too small / max_seq_len too long") + assert len(batch_indices) % self._size_mul == 0 + batches.append(batch_indices) + batch_indices = list(indices) + else: + batch_indices.extend(indices) + if not self._droplast and batch_indices: + batches.append(batch_indices) + random.shuffle(batches) + return iter(batches) + + def __len__(self): + raise ValueError("NOT supported. " + "This has some randomness across epochs") + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. 
+ shuffle (optional): If true (default), sampler will shuffle the indices + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + num_replicas = hvd.size() + if rank is None: + rank = hvd.rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) + * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = list(range(len(self.dataset))) + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + + if self.shuffle: + shufle_ind = torch.randperm(len(indices), generator=g).tolist() + indices = [indices[i] for i in shufle_ind] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/uniter_model/data/test_data/input0.txt b/uniter_model/data/test_data/input0.txt new file mode 100644 index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598 GIT binary patch literal 92849 zcmbTfcf3|b*5!YR0*V=PAPdHugNckFB6>|2iI*Vp93D{At4LB*%wkp~2}V@R2@r64 zI?Z&N>3RF@p6N8x(`lyDOi!oZs$1*(d|paVqd)$rwb!nFs%ls5+I8x=&%G>Ze^7?pL<@v`J$swU($Z;FFL4f%_;_rKe(czVtziaG;hkh*%KDEqn%y#?W$x6Oljdga?;PJv%F^QbEfW^Zys2eD)}b})IADAyDHqI`mzT0mW7?1J zAU1PuDeF9@$%4U7&t1iMnpte&Y|6NmFfgZGC-XO>J#`T~l?|rERQJ<}b{; zj=AWtcJTPlYP&QkpWaf+4w<3x%;-EL>9DFsU1x{R=sKp0Iyx+L)MZ6>_?UJp zvL35ugtm?-Wk)t=M=hK@Bh16mt=TaHW_0}g#%dcH8>+Ko)mRwyaph4TA4h#cZ1BWV z*0VY5HRJO~RNYuzuMze3i28U$eZz?Q#RmG9vXh##0Upu7*6ieeGNOj+#_SZ2=+yFv zPKzTtJvMkoDXVPGstQI_Q(x0q*_c&(L^U2!Z5UBqY@oiBH8f|99#K}orR<94Y`jNwWovd-;fNX=o3v_IdqmfiM|5o*(RH!G38ideb2h1HL=6q~ zP4(Gik7$ZVG&PKy>_!23vqNyXsj)c zTH}arjSb#b%9b=|w}1YKlx`Xu6^u(gqGcY@@-U(mv4NGP?2hK_PLF6+Yj#)Rh#IOJ z8f&w=J)(QcBU&9tbZ>0%zEXC7bM`>dh(Zpm%pUZJ9`cAD4kLObHn66YJ=&Z-<`F&K znmtiCqS}T!jp#`=Hdu4?R9j~)Xm4UcGB7}1-tfwxN8+s)ZK9?`q4*?WZ}YHFxysL$T_h(0Kf=)*Xok79!# zm$L26*^bX0k(R8XrmDI!+vyQ~;t}l%BibDs*i*_rZO;CkNA#7}?B5rTNHbJln|;+I z`dWEJUymdDMr`n#rR+a6XaBKiM2*!=nxk)dMBny^z7t0D-PpkQO4)yE&i=DU^!?WC zzkL3P>KYru2~eB;i%0Z>@`(O5j_8N6!5@{f|Jt1Wx1td>)z)bD|JWn?iAVI)FruHu z27X@3{(E!wKRlvev}XVF^GBo!Z>rYb&3@?-{i-~oU&j&sCN}unQue=^v;SQ*qQ<8B z%F66_9?|bTqCbQY{V_K1r&9KRnzR4w5&gL}`+xsrL^XBQ*YA|#|HmV z$}8GcR5a)9R2Gb@uCl(mu0C%cEV_o*lDvb&IJ&%}L=7m5h{^&FYMZnp z>hnW`MfBh>iT2=diNS*&!r1B&l1llJDx348RN8~1CAQ|rs4VuNx~i!_xFA_aCK~n7NL`kK*r^@EMm&zg!Y8$F6EA!sLB6`qAqCMy3&Z|@w zcu=R;tj5Z`I#@&xY9!i&T8Y7fI$>O z!8H=g9$YIDJ-AL%>}!IgQa({-b3RFBkq6cFRSos|s1zcP*+pgq$J;izT+^H>oW0ptiBTN*>%S6ZI2v z(Jc}~Kk9X5-YOD3xK&c@>o!THe2L2D{C1U}?}6g8UME$)G+0CrmPxb+%OwU6RtRIO zDw&qW$EcBqdN$;xp zlQOXfPf09$@U%$u;2BA=uV*Ed^0g|P^XF6+c~DthRbQE}3l`CX^%CvD28qFgjl$UK zCP}6Id6muiW|j6}i^SIa1(n4f)Q7D7qD<_;OA^Z-yetwuctuj|>s3jme5=al{56$D z9#n@{@BH;(5j}WAqCMCqF?jH%Ft+-Zq*DI2%I5qXmG2Ot8C79s4ViJR-X`6=R1Q%^xzYT_F$L9;K6QT zY;}*MQvRvR=KSBOv-vd{y) zKWe7_MJ8T@A4n{(!M}<`4}K^q_Vpu4rTo9DY|j6i$^s9WGL=O&>6#F_z zlI{m8bw5y9nD(%715_LaNS?GaY zD;gW~V`X9wj+0pS;CPYf!3mOLUnfe^{XnJe2P%s^sH<+$&7*g)h#vHjXb<{I3?B3o z##Z}F()~cC?guLE!9a<+AE+$!pteeH!TBjNQ9t3W;#7&DAN9I2KTRZhaJr<}*BO#@ zKTxUrfyyEeYHM{M=;Jvc`sdT_3!*w-LQx*w?2{Xk`*2TeL_@}a>ZdN53)Js2)AcrZd3TOBD$ 
z_XCx>AE>kkqb2Hopt8_|`o;!D>lm4+pKzjGATjjQcA||Hi5^@iDfTr^lI{m8bw5y9 z3*P6_XCyo zV5UUf4^$R-5Z>ZL)}AF3^%M5NY>A;C^|~^jBN9EBD=GFhPm=BjDs?|lS>!>jUfXp) z2o}+U1rqH+Nn-F|p)j_(NRsXcDs?|lX%7}l)crtZfd~3_MwY z^x#%Wv9H@C>3*P6_XCwh9;jWN_Dh3B^kA7pd$3$$@L+{7wz^W1?guJ$KTv58R!P+T zKxLr^4K?B1xmzao;2w!(4_1pr5AKx|`?^n(?guJ$KTuiZK|@2G_Q8X}B6{$UM0@bC z#Nfdr!r1B>NxC1X)cruEJ$PKA?guK1J!q`f8aycz^%HjTQxZczZ9DmCk?6rQl44)a zO49v6rS1nRi#({-hrreOx?mAKSTE5YY>*f{*eHyxZjz+?flA#ERN8|r5_LaNS?EDs zI1us|WnvFrl34cOWs&H?E0SVguS(MWK&9>nDvLa*(FW7~AXr2X-jHYywn+>gyeW*W zz9mWb1C_cTsI&*~O4R*8WuXW9SgO7@e_tl{-~)+e4?Yx$9(*Jz_VuwO-49giexS0* zgR1)4s``9qu!tUfBGDe~k{CSLEsU-1k)->9O5G1s+Jmo1)crtZu?J0+dI$chOzgqe zB$hq+x=8fk8Pv+dhkO@v9BLV()~cC?guK1 zJgBR!(>?je!6JI_6N&cVrxJq)KNH4Qe=bS)1C_cTsI&*akf{5C%0dqstMmic{FgHE z8vIIPc@2Io5kQuhOuMIJQOHR!G4cflfh@Oz2&;13dm2Y(dCR{tbP z_XCx>AE>kkf0n5GfyzP;Y8u1Y{}-9qgTG2Fd+;}r=)vD5#lHR_N%sSNk3;tZl|>%Z zx845Q2aD)I2Z{Eeqr~7rCt+;0vn1UQRO)`9(jIh`sQZD+Vh`G)Xn&d5g99X%J?JhH zJvdNO?CT&&x*w?2{Xk`*2aQeo)bY?@5j{9eqCGfVV(_4cFt&PxB;5~G>VBZo9vm%E z_XCxM9@OjTtjv#a)R9WTJlB?gtV>Kk8NY1Ci*#>5^h! zXGqfhK&9>nDvLa*uh%y<^6Fp_J*bgr4{9X_59)-m)p|*~AE?y*K&3rslBoNE%0dsC z^zmk8ewIwU24_nwufaJY(Svg(#l8kf()~cC?guK1JZNmx59ITq!6JGvOrkv)E-`p8 zLKs^eDM|MOmAW6OvVBZI(1ZFqeaR*tBNKaYfyA;0V@09|7fOnKjgzGNflA#E zR2Fzpubt2Ukke{Xk`r2l~uWKTgfB zmWe&MMq=56Yek|5*GY9(Ss=x?ZH%u!GjiIY;~F> z-49giexTAG%#^76fyzP;s@mR7XURnUwA~LRhJMtm?gt{#gSnDoU-KmCexOqK1C<3H z)K^y3>T{etSVRvNNVEqfiNS+~!r1B}NxC1X)cruEJy!=oQ*HQJ(cLn!2lq%Ud$3w0dT_6#*w=lMbU#q3`+>?L4{B?}XK4=x zi|D~a679jm5`zbi2xF^jBtskOT6f<^RTy+nJkL1OS=qcFC*Ns{gdDs?|lX%DtY)crtZ zp$FP-daHO*CidVZiDeI77Kt9bA}RLuswCYHRO)`9vdDw*qjud7f<^S;4T<((o5bM3 zo5I-YTat7?P^tTYN_+6GMBNWm7JAUs5IlHaCh8~btji} zAE?y*KxLr^O4F73&R`Ke_(Y;T*d;M|uv-{g-6Ki&1C_cTsI&)Pk*NED%0dqs>-58- z{HrqY8hlM+c@4fU5kQuhOuMIKbuho3@yD_BGizAe!nd`DvN;Jd=u z>h~n+exOqK1C{pR`x13OP+8hoX9 z#B1;?iRCr;wMg{fHflA#ER2F*Bq}j=T7c8O&zn5qa{va`U@JC^6^-q#? 
zKTxUrfl7PuXNkHWs4ViJzNSjwPR#!z6R*KvC6?FVZz9oyze|dJ{X>%O2l{Te?guIh zJ!t!Ra{FKrJ?J3O9(0r#Jm@5ht#+2A`+-W`4^-NNt`c=WP+8JgH3KTxUr zfl7ODv_#zxR2F$qUmdVBZIz=OKVCVdMtuMQT`gBpqUpjKk=piUTDt(ThndgcI!oiJ>3$x-uUt z5VBZI&;xzoRhc&ji|D~6679jI5`zbq31h35OVa&7rS1nR?ZK52 zbw5y9VBZI$OH8f-hpSy#2(C+SoUC!Nc3Q?q}bOyNxC1X z)crtZkq7ma`hAT&4;Imb1rqH+Nn-F|p)j_(NRsXcDs?|lX%7}l)crtZp$FAf`tUlx zStj=27KvpKT1BD?L532P8 zy2|{)U=cldNTNM>SYq(t5n*g~jU?RAE+$yptkMJZe6g59;}yW4>m{)9&8lGRyRq~{XnJe z2P*Bs7Kyqas4VoLwz@{2D7+{WufakQuhOuMIO{Nh0iKp z4;ImbHzeAFZ4!eAZwh0pZ%NYqK&9>nD(%6$5_LaNS?EEv-tF}B1DV)^4kQuhOu_TU#1 zbw5y9=t0}B59PmAE+$ypst}&?<>Cx z7SV&>OSA`nkQhApqcFDmCrP>=sMP&Hr9JqwMBNWm7J5)wud7u47n!J^@NW87iJ_mi zchkR#L=XNhDfaabNxC2Cw?cJ4P+8n zD(yj6iMk)CEcBqhE_}~!f0?MC5Jd+_4E?l4QFoE(!GV%uUk6Fj{XnJe2P%s^sILp3 z*&P}zq6ddbvVBZo9t@PI z`+>?r4{94KHB+a^ME!&)I#pumr!9(36Nw(2E-ChPh9un&RO)`9vd{zlAh0s84i?da z8j1FxR$}m=P8eIQm!$iFO5G1s+Jh#Ex*w=4^gutR(7WkbGO-6|ODubEj!5+2TuHI7 zL6USoP^tTY$|4VHs`ML_`OshyJs2j@9t@WlJQyL2t&WtW`+-W`4^-NN(Gqn(P+91K zers6g&KQ~4g9{{}#AP-49giexS0@gW7P?GzW|5!6g#y!KD&|2bT$B ztCvgC{XnJe2P*Bsl@fJ7P+90fV}1C1_G+1^pRkj!kr?_>uPgIwMWP4SNs4_?L5A-!V z#plvs5j|KY(H<<97(7@ZjIFMer2BzN-49gSgH;lBKTuiZL48x(E%R=f*n@i{mOWT4 z5=Y;}_)-49giexTAGY>}w@fyzP;^aT{{VBZI$b-sy{g^R-Jy=8!-jHYywn+>gyeW*Wz9mWb1C_cTsI&*~O4R*8 zWuXW9tU{lxy)P55!3Pq{Yw)2+^xz{&v9FIM>3*P6_XCwh9<=>he!eqUL=QfZXb*Nt z3?A$j##Z-8()~cC?guLE!B-^eexS10gSIpIt1|H#d`)6`4ZbcCJ@|&C*w;5D>3*P6 z_XCwh9@N!@Z_IuxSVRxLEzur)M`G~cyTaJ&_ay0lpi=h(mG!>3zTd4k zyMK|1J@|pdvIqYv5DRUM9|w!*!A~UGgP%$a9{fxg zTm88t-49giexTAG{6eDc2P%s^&{ux-+fey0WnvG0C9&+muSKE%O2l}Hax*w=4@Svf(D$Gs$U=cm&AkiLllo&kdB#fAE+$wprN|$XYIX%Mf9MLM0?OzV(_4!Ft*xXlI{m8 zbw5yP4+cuq{Xk`r2U_d4-#?Iv*Wgr%BzkbVq}bOPl5{^%sr!M-0uS`Vq3|1@ z)xjcqP$SVE)JhB<)Cps&^^$ZyP^tTYN_)^GQTGFtg&u@p)%S~JVh_%iSoYu?k?6s> zl44(jBHyNbW4;q;cqPMpDM$j@Ht>_W!tARS@*q_@oz#Mm?*=~ zk+XwRW%yGi2dB#LM^Fw)mErFq9hxe`pOiT)QMUbcpu*xt5dm zQr5NoI$1U#QHI~x%?2jQ@TWDhlM`k5NoID+-pZ=*hZ(X{6J_{|1leh+Qs0cpPEVBK z&Xk>zD8ruv$SM=`$`> zhQB*@QKAgLb)8+DD8sMbWzBml+uk^{OA=-H70m3?R2g>1Wr;HQnO&YL!{Dz-mHKQk z8=ok{@7rcqCd%*?l&TCvMs!)C(7{aSJ{k28GctP zo4L2L?U#45>l0--^|M*2Qa^CdW+%$<&A4n%stiFpH&ur8J1+Q)T#*ky)Y) zAE{<}q71*clP%a=+4lAGtduCj+jX`uQHEa($rdHb@D`ihkSN24HQ9}cGJGQ~TfCPt zez`ciDN%+`O0t_%rQUe5TM}jX-gefyx3VgHjF#P+Dnl^bmMX*R!IDH7eiEJCo+!iD zSF@#iE8D)2kS$A<`oTrEJW+<9*=8#eW%%kxwlYzM?`URsq{{FLcW0{9&-}Ahdn?<% zG?(3#D#KqwygN~bFGgkeB+Br4dA2%L>Z@AWy@@h>hc3G>QMSF5-k&JL_XV>DQe}7( zcraCl<#=dsW!slCvWF98_$hw&NUGG=nX)yBGJM*QJ(?)P*M_pkQe}8GdOT6K1;Z1m zGR)GGsWPnQQ>ij+!>3ba_|tUHB+BraSN81QN*&K?b8VswpQmTfCCc#Cl5Ab7)bHqL z>l0=82rS!>D8r|o*~V0Wg7 z{!*&c=N8$^i8B1KGkYafhIgA+Q>CuL+15lEe$tq|wzsnF2b-JyEv30q)pK+0gb;LAEnhhBf+RZ)MxZ71^#t89oWi zcJHlh`)DTHlPbf6e3~fRj>9jU#Kq{0`1c; z^~@vubOhsJ3GbWmlnzZ44<&qG^oV^K6x5Hhev$EG$uXKb(n2_u93Jo#{5d-dgHOiON{tOzpq-)yk2I` zzFllyX#{S!Fy6?%VgJ5l{qe?0Mj+)Ux`O$sGJnB1ng@;7B%JlUC((amJ|oc=QN~L> z`2oVW`LE~CuejxRr_2IE|5DE~#3aTqu=c4+jiYlys!$_QUPV9+Ckhowyut!fjG!yjw`!+DPQ1j!F@ z+Tr;WX$Pd9ai@AX{iNfep8i4X7!OE)_&v%U4>WRaPckR2W|=QCHW`VlD)Z4s@+IRV zZVq())y5X%Nybx*Zx}Zk@r(Vk#rngHeT}<}?0epmr&!Ov?_th9!e9DjzR0h{JN>o# z`A>|`7=c$2j1L)4G@hK)rxbrvJ^7CLOY7k~5E5j5n1`ggOxKpB&llz<>Wj?Dzu*bh z<3If}kF57;$@4FoA8%yc*O=dJWSw~bxXyak>os%Yl6d5K-ito69pmJ@W}LhSk#C8= zA%2eYXM{QN#(o-ZJ?HUAbIylV=J;t^!Tb2JR`fHHw>O(l<$(l=m**1wR`dIfoGXc0 znZ8h#qHvq}WaGt2eM;*R1^3%TT}a0j*6dxu`gXLsQ-bU(;)Z%S&(Tg-Ki|&CIl}u2 z9Q(AL!G7*t<2vIsBXLK(Gtc|G{$nHi=2mm!kT_uf5J#+Mv+XAvPcV)&zGY+_SJep+3z%fdxQ_7U|U z`;B(d-(_K^u{#wBvTkYrm@%J>gZ_Hifc^QhIrG8%b6&pZdggnB`9$ME#`BEaYhE`e zu80HjNjKL&Z)Cijo8-ZDNjp!Nw>MsITy7*DFEJ<2@!S~eiKiRQ*`MUKTI<<|rWK1U 
zunz`PVZ7f+9E?cnx0v@e9-Pz@?1=bS_jtwPubl85(U zR<0%viK8{HInhWQb~k69`k1p1c`q7fJ@G^QPOzSNBn}?6p8htP)6ZJ-(MIC&aC7p} zYV+Yn=7D%6ep_5W%J_hh^PKpdZawc+OU*kN>4)_u-=E_8*+%A%eRrPqV~s}`0qn?& zce!Gok@>?uMe;fG!v5-SL;OjfE8@4i0zZ})8Gb=M@sOGmhnyQ{+JHE?(0rEh7UO&) z^UXS)V?FIM{^zWpXk@?NYEB+G#C*4ryu|oVvwoZLP9yR2jye6Y&cw+Du7BD{T!Rl; zPh663###T8@e?EM!?EK$Vjjt>UH#mtMu7Q%FSEct!EfTTABF{ZSfW2Z;cq28r6aBA zm{yn**9FyO^}@2Wx6p2HTK>5d{V`9hTW8x-5AS3>>rdP<|KztjY}ecPq!GBm0`Z4_ zhz0iVedaeAZ#MG2HQAi^g+Auz8ecJz=g7nBtmj@*WzIS?{%Y%|8J{-}FwQZ4Xe2LQ zW`36O7~=~@=6jVn?N2d3+Q@!mob2CYU4OrkahaB(hozpz+l_A+*@r{TQ+Uh@+V5%3 zK4kvW=irQ=bzwcI|FXDW!2_eLC(kT3=e*~AhWT3V`t`;Z<7-CtbB+1M#zDrhM&j^N z^LvfN8Tos$^{n@5^K*<7jI1;L%(s5Ck+@=iu;H-eGzCF(-jE2&{0Ch9LtyV#O^2a-S8xA6A1?_lJ4u#@%J zfjpnKM?2(Q+QE+aryY3O56_`bIqlIO%`h zCyeAx;+^MKy8cchd6M{}oikiN%-CdP{~l@%K45;lk#Vq(c<*3e(tnHXLHvB%`ZtW^ zN9J*X^*0-Vr!6pE-a8(*p8Vb0oPKzoyi#ig>p?x|4d*C&ko~;Mie*$7Ip3I>B*sN6GW+%}YiA|M9EQ6`W7w&2uB?9?w@=U|rCI%s1_GzEKa7zv19$3nNh!An!kE z^$shD!(HaA_nqdoM)tuo=3|WHgJI^xGw&_L3;F19+kIdpAD(S~yODK#&HOMU?OtV0 zemK#5u<=qO`+)pUTr72cwQ-}7b>#hlcx`rli*b;V_%p2_A9b>Vah+~XUY%}^UpxqGS9CG5PRu9gN1f}*FK?UQ zZM@t_-elfZTYtXsMkD72@k@R^#q~XmyN%ZwTN5PCPqqFMW22G$_qaLvVW2tt;V$!6 zjpU&l%&#y$Y-Hasu7|7#4>soQkF{ok@c`rX zM)sxYXQ16@G*Lmj+*49B^4SZnVBU$NtF8akNd93x$WyE%{miyK`TRihJB%ZY#Mu;c z){}fY)_U^XN^{1^_!e5vy7n}$H+D0!FXo%SV_a)wTszIl>+Fk8)-N}nX(Z2*f0-}r zY*s=4#07eg{8yCc-mu|j9!PLQ!iy;TIIS z^~audp&gr7urH`bPkZDG+M_@A1$y@5S~o!6gHN=c_b1|+yvzO~p3%SM=Qw}aUoTp} z#<;{tJT5mUj(eJu-!_=DpT?PwHy&(cUyz^ShZ^=J5`+4>id6w}r zUh-m{?OEr{e713$k#(psZ!tb=WL-I*$d{}O>%_X#{uVz!z?i~JE9M*7k6X`&CBMTN zCi{Z;nBMk% z*to~I&iJZvq>*?hAJ4UZl<_6wcH=ol+N1ss>+$Cj^Azr}qQ8;(XTHhv$E6RNe`q|_ z*kB}1ZZv0~nU-OirOS=I%QaEhp70{dJf!vWQpK0#RkkOdd2fId$L(yF?i0FWK6nn~ zIqYbkcGxf2r%37RZ&TcABS&+{CJo_66p&psuOlCM8@KddWp@UHbQ7@Lgr z$GNf2divX9evOfHfc?k$^S0}`ubNg6kF{3ZX2c$U##leVINLbGIMO)TxZXJ3IM7JG z*l2#e@d)E?BlAapR+r%&OFfLVgWu0ve_8p7e=hL8L%)^(-25}O=UcPcc+6)k{<-2V zS08C)-T(PhpRFBh)rH0`M&gD2M0|tfxwK#EN4XvHUtja*j3bQ9A92Y3;9NM|c0G-} z@8Ex<^=BBLHWCN3%<+f!LY^lcdfASBbAb6A<1*vXM%HbDIp^d?^Y@JSLB9*F-(`H= z$h=%^&VJtCyv6vkk^Rxre7+Gl(83NB2~IR8pInjDLwR3?7h%!)@n!c#+CK3?z9Nt8 zv?2c8X}-$1#7N#I9->>9dEz8<#>2XntY2)r*2sI$3Fe%qTg{mt z@)PTStm}!#rRF`2?2~8Bk1}pDl1J%>dAP~-j~a>B{^lP`GB>o4PuQjsYY4>gG z0rnM~al>aK7>6gE_E8^+P=FsK`ZRvor-@l1t#(ZpBlAf913S4s#nhg5+9mDsTnFnx zo@XB5w8MO&r(Jl8)RWiQU-!G;6OH#7*)N=TdBJ?Gb3OgDKgi>p7j?FK%=n6NfpL`aO(XM6JLg%? 
zx|)`Mj`Ji6BYCLSoV*Xm{5T7o59q;u2*zm%KiT}!L=Pno!c#onih*gx7v!lwJjVJk zkLkRzUf8Ak15P~C4skx#9Wj5z>Alu(HZrfDm`^pHZp44niftCw8b36)80r5K^Y@J# zjfWWDGV=T_<_{S!G~Qu6#yG?{(fF!yhmm=C#+-A4_L{9{|MfNRZoJZXrtwo_e`7D> z(?;f}oB8cVUJlxo5 zB;GDJKh20gTg_R=hs|lw{l+SZ3tU}&V@L|wr=p~nySzVNcdfcQ_;h4jS2 znpn}r_3ezz6X!gL9f*B~rE6F1kGZ!``zZZ;B^i_NQyM;hymoZIWn z$I&u|)u)_1lj$@a&aPnO{*S9w^E|AB`i^SbpR}408G7>+`Cwce@ z*WY98VZ7N$e%)gJijg>Ae;sZ8BqQrVytP|*_qOZj8aelu znD;XBzJ0hk^`Dv(PsHWD)-!I#!@To;_@eD*8i9i?JZF5&NZw_f?6)+IXWRZNBk$v; z6>AZUOA^j`$G&F1F=Kr24-kw87#}urz9eR4O5B3PRf^0X^T~Yk9DI~p9%LM6Bwx@D zdg7wRcI08x3gV-m6+4Wq3;q!g#M^e;F@LPzaO=tIoa+OvzsPuqk#mmr2UvfSk-T`G zc{k&JM&jXB^NB|G8_zR;V_Z)jy~iA=wm=?WeAI(yT3Cp}$otn?^DL={E<#9fLc)tE zc_FP&<2S8`vkn(nvB3De@oeLzM)C^dCJu>L&ZlPElfR~#FE+9s>~G@cIM)-0_;-Z$ zeT?-+=JihVn~jX8r}+vaaY|h6ww~w7M^9Q$etgCJM&qeQ&V`}opBjm?{S*7G=Jdxr z;2-;@(ssb@7H%`*C;nez{V_a{;4*XOmFkNTjGU{9y5ck|*w56Tk6`SXa47x5Q|w{| zd5iU9Jj`dikxDDBjcQ7{Z8XL;|?Rw5%*SC zuzrtNK^~*O8P;&D*3%Z1=v2(U!CxMQT7>j_zx#9ZnlE`wbq<`HO!oKYBlFQ zjdeQDdft~-nsYvZ8?3*@NWAtk=RIem`9vf4&-Lb~7?&Hj80iOl#z#KuXgk*bGV@Q2 z*BS3Ljxw_DFPTp=t~I8x!V2b#_#n^1x2F%9GcNKV`v`O52%h4DR(xn=zEis*&bo7+ z98hFfpntDGpQ@SPt6af45?_6-Coe2ACmx?RC;o_wKGxS6&o%4=eT+MdoP*5s!PYa+XPIv{;urC5bp_`a z{A3I3jMp018dEqiQ5m|psYE%!DgUg>{7SGO~=Uc_l<>(l3mXXc+g#r$#3 z6SuFsoef6v2J6Fqxy$v$Ta7vUew;b$FvFa<8e)E*@h0PNBYA4QIsQLlewC4RW1o{( ziEH*1?GShDuPJV4fpNQWe4^)lf%s#7$g{WE{$At##>vKAM)C#Yx4HtqnUA$p7<(Ey zhna`OtOEbj{eT_$w-!Z$tiyAOKBXV0iiC5%pe8QiAn{E+qbJT-cj9}L4R{~8+x#)( zR^vt^(9r_^4KgQRoojxjv6qo^k9NsxtS|9IeBk#$Kabzfn)fpfHV!kg4;deE%RXQq zK45$DYNa{x&Hg>pdiLpN^GU`O$hW)~kuOiMA^V*3D19DI{K3C0;@_8j4*N$D_JOQR zhs2I}Y-c@jNBmF^rtJ{7w42(|9d^Vw?a&{3_62&{Wqdq`J)C}c9zC4?;k3(h*rQMF zQcin}BW;Izo?|}o1OJ#Wp2sd7598u_p2H3fVoyEep?&(NU!G&U_{DP<*^u|E-sT|t z@ImW2PuS1TTi?^T(n$L!n$!Pu^Ye|{jl7p$XHNeU%sGFLFdu0of6g(##(0mBJTSxj z6XPo5#snvulP{Xg*Bcia-!_gho@QiTOe=1-um*(@f61rhH_m&^8DB3HM$XqE=FIOQ ziQQY~qYLyeSkFAu4(6=?d=v>X4_TszK4oEl6bbelyc>=B&fj=J<1^ zIdRPV4Yr>1m2nWqm$`nLv7d3P@d@McM*P3ZocZFt5Z>SQl}7Rb?`?I~A8b6|xYF2c zyw&)gkvzbDW4>Ko@v#Nwk2u|gfsys49jY5FY%#Kq)b9(4SKl*| z&g-o2VPw7UGABM+zbmZoY-B!`n$r&Zf;b}o9%Q@QjVp}gizm!S7@sySF*1MK%*n5d z%*n$?Cj3)#;*5RF_;S~?j<=b!Z;407Lw+A_yN8XFj603knU-OXCG1&;LDrKG$`2$0 z@y`4ATGy{OUYyu}Q9AUC8Wz~?Z#{7e5=ZS^PyDbB*ijD>&+Pwce*QEg@zf{LKW5H+ z5~sxBZr78~?luQlSK^}H6|Dbq^UcOl#*&eKUp6OhJDNXZoMRks+-hVz#0PPElIuTB zp1aI^i4i-Vf6)3Am>14b^1|rEkn=22SKtSEhWf<_#@dAMi(Z--FwR6>NN-4L$ZLtZ zkdi+zV11vng8016e1Z`_cbRiOTwzXLyw;p~Cl4KJJ^T6q^O43%;{YT3?R0bY72}y` zJ?G6s=Fb}0Pj{MMZ(L^l$Vgm}r(dw1d1C)zf2`}b8=1cvb6|-D_66&9yY=@Qdm2;V ze7?gK&nE`Vd!qgf#jmu2Jd^5E&ihfSPk9$Bcs~G%3+9>m;CbSW^`#x^_p?Ag<#~8& zM?2)#%WbgG$akOktN z{J?yC?E3Y_H13$k8LmIW2&DGJ3H7``kuS*;)L)EXJTT!$B>am~;*hwQ><3>pGM~)H zQP!_EK4Ro~{9a=HJx0z!=8toYc){PLwr74WG$&6mkFC}(F)}~gC*ZWV&2~>3$-nQK zZ#7Od5)bT~PpxNMz09kOLyf?83*`M~^W%&ojW-!nV4m0qA0!6&m8i><_-DS(Dc3}S zeU+#`LpNB{J*jyq;rpVuBnD~!U0hE-B~PMn?|L}Rk)O#Y#Pdfspnj)0^ZJze3C1&x z9~znO+2)Lwc%dEUf%P3_d-lu8=EUz}bIu#yBZgUjsj;{5Wh3i3)BG&saYo)l7nx5m zlK0pT$5_AC2r%Evtr$Rsk$ucMl6Tp+YZ80r0rm41t}(uEyvoS_G_7DfaPkLvX{9ST zXRrg;A{5|m>vt5??@N9DJ`9Lc;*axz$5{W=PvRJR;(3)HJi^HQ-)UZL>}w=HvJP)s z&-@QGZ!xl8PD=Ew?{@2NH1;y`{>yw4&m&xauW?@T9C3z!#QPJryTEvR@;rWSxBhP9 z9wYA&^h-PBQTn~a_NN-zKg>JhWqiyt?;Wf!ae0W_0oEcIiT8Ep2N)S=VqU>Ipug6_ zP~&i8n%A+<5sc)E0g1kd&Ppnn2l5f;+AiD8F!El7fA?F@dHR$&d7kmyC<&XqPqkbWqB-Zr!{#p;$@2@%HyD>1f#WT( z|K^&{HkORX8B@60iUmg2f%WD2_34A=>>H}_gZO&U6%#So2a*pkV}8#@u@79ak9w+E zpRQJrU#SPvyox%##@Zc$5ZArj68po`BEeEu-3e{ zvBCJLk-T79{{JsluG+Zz!lo zU2kEOk#lQ`IpZfE;5Ye&_sfm8hjV_OYdv}EX!8z6@(%l<#(KuT$ow%Q@ju=CZsSDb 
z%|_<=Lv!Afh`*<;r=3c3^8DrIbw=X*9`kpSdd79B^}w|VM)v7UbLRcd#18rag7Fq( zC*vw3^J7|uEoCVR#6Rl$5dME++CKiQaC-xdoC`;qHyW8A*3Ie)@*Vq&^|``l|Wa9(IRYumcr#W%QbL9Uj*S~4xyu8Gm z^`B|p%gDLGJUwXr^+sR>f{{2m+#J8|GiN?|jQN7UVS)J~Ur_&L@r5t@Kw?iEat<+1 zC)(~qWYgk5YOAq85i%FldM0)2#^oQSiyW! zPrD1OAfBl|$imS^_8ay1&pv6U!pM8WIp)MmVpgV{3n;F#Fw|HSueAHJ?RbuL!0!pJ zXTI5od9``PV-{$KbBH{9u0)!{*Yq>+gf)YV-Hm4(iIdglDdbi#Z@gz? zms#kCBEjAXXC4!EAw9IP=8IKyu^H>p-kkO7Xb!So*pWYoXV(8@KR3>}#>lzGIl{cc ziMJz@etB;oe?IDZ;-RNG<0D?!SLEN7wp(O8!bm=1KM%D2Oyl#${f(Q9#0~qer}gg{ zhZq^}67vm4_UUeOfcSgIik`*;jVBt(x6HrQ<-ZTgI=<@))`2|9dvKb^X=j*)tBuST z`xQIp37#VB$-Hh)3^pe`rMFtqVw{{Bm=n(+@j$$GvK@Njo_Qk9d7f@~4y2yv;q*g$ zquf9*V>cu7y3?HeLH^_X+wFSh{Sos$M$T*2b+q;LOZ&9b&-HzbFB&;dX_xbFqwD7w zcNu}JEZktc#z>s9FPQK7t{-e%VkD0-KXuk$YMf~7Z6wcf4&Grs=h3O=)ke-U@&NNV z&GqD=dh-?|e#|xp$OGgn#?{jXO-APD1s*isW+V@slGO9wOgaceQ{WdivoUM^8WG8{&-hs~Cb9=bH1r zb+kGA%CzDh3$GZ-7x;tS3JW_>7$+H*C;AH#ev3KzH>s|`fAs7#IOiLDUwGKQ43c`{ zkago+%x%Xx$GXk1p8ApI4;W83t}|X|tTQs7Z<&)ndG2cKi68oV)OzxCtN9Ve;YQ+` zJkGe^c0K!*bz)!das6?|rAFdxg!$#hLySX>eT~DCdd6*a#ViX$P#78iWb=cK#CKv| zk;W{pvF+YBZZcxm$9%AHl@VZE@SO-o_U~BpGmRUK^lMtd zy29CCyr(Rsq5xSxs@aF|hY$*oJlnP^BkQ1N7-(L z@e<=MBkMq3A&*|{dh+y2^LvdgM&gij5vL`6N$JzrgjwjqGFc$1~Qm5Br#t zSC*M`AE2Gvt)FTH8Z2})o=1g|yvTkbzhHKn1;#Vo{32tz9^{dMt{-E3(s-?LzLEDB z;&{CEJ&lYDzq?!ip7AXs^GCj+AL2Vp?AnGp*#;KPacP2 zA2=_mKggW-grs_(6o1KYaPl5`5k$}W(?R-NXIEek5~s8WcCbDB;#2e9M&f6!`358V z-L(Ao)Ge>mfx;CkYJ zw0U3S6-L&ZcsbYlH;m*t{DZr?;&ls08p+S>gZr(2)VR%_Y8+-}#8F~Z+8pJCi;Bwt-<&Url7e4~+lH_H4L;|3$~%lut#J^TG|bIzGn z<}VwU8G$_(*gyE2=E)k{?chPAM46KAwSHt-FeA#SC}*I&F1rs?At-+#OW#K*fWl8)*o$r z%s9<>my!3u_2%Te>&;nD_91b(&h?C!_sG+&?`LEl*qu{RuzE zxt=(~Kh6R2*2lIZUg?MRoag#x<7VT2M)C~n{G9ciQ^U;r7*8~IGQMb>WPIG%&dB~F z@3EdMT;J8$%lNjDxHPRe*#dqXXZ}0}#xyRdzS_dAM&2VI#K6dWCh7{F$3NyTwPSu$ zJ$Zq3OY7l#5R6?4kom=&dBnb>E5IUsXWQ|7hvd2TNj>eb{;n~<;~*~Yay|OP%!!+G%?BCJ zG!mDb0~cF=pOJk^T-|B?-NsjpCm2680*voAD>hJJB+sDUieNm?$i5$&)TflVO7-+J zAgxF^>yxM}jg=q_~FDe`8#t1-Hp#0S=VFD zxp%y4-pxo}-fTYFNITb>Pc$;$#3yn0l4$xko-fmE*Vj10Nc#=u?CXcjS+C{h zU5xnur1=!%d?W9pjIY-ETZ~T_$^YyxtIOYW7-q#RBk$GNEwb=13M2dAVDn{0#+jIZ zhQ4FXNF(#J@iUg6uc);8EhFo`^YhI=qrD-~j50sS$b2$C?9Yy5-iM|0wo_L=viTz{pJdD>!rjgfr?zs~x5j3bPzjpS3}eyR008W{)evEI+S zeywr0k@pqGd4%=smq*O!8}}FwGy>0AU>(Mov(Jvi(AX#8&<+R*o{@0YBT*Or|6;6X zYR~><9q1oD_N+JW4s zDiR!Oo<7e!Cgz3oprq#Qgzt;aOblugzAs9AV7IUSS%1!R{3o8!gTw`KkNqfj$atrl z^M1C){9YsX&ynV=%W>wf8&5RyzW2U4^K`TMIwQ}|FejebPsAnrWPjVyKk>=_B;K#G z9sA)WbNnWMZMVLU@igN$<2%N+#-2vb4dQLI^}tOQX!lJjjDw8qr^=+B^M!GNoalEX97;UGi8J)@_BJ3MLB`Gcv9Fo`Wq$5H<0j*BV^3qf zainpLk$Gi*5HDxDp1AK}US-^4yu(;BPBOCoP3FWc+II7c3F zL*tC(!Q0JijgyQ!jret^d4D7CH4mGQG;TA#Y9!ux@4)|eTu)x5KjQlW*OTu#4~AP$ z+!OD#kADZ-j&}K6<#g+rPwpqYUoemC8|H`kyUWir@APAJ`R8gqt+?Mvzo(iHF|uz> zD=xFZdB}L!PqSRldOW~`#+AmIiJpCddO5-u#9fIU`!rD(QsS9;XMdvSLE;-Eo|y;Y z`Eoxn(8#_d&-SH3BXP#M^4@)g>xoC^_gw2)=cVS1i~P#D&wP_7887pGv7cwXTFi-S z;$pt_(~OK4e`i?Fetp@zkCFYg$eew`JUwmwGsZ)W#ODF#lZ?y{{SpVqxPH0u6C-fE zg;~t2kq2|4f%meYW!}=$U^~Sf1-v&;rE$1%iIH&=hvcUfuID+z_9sV#Hti8EyS~BXQ5?EF(dIk*8DCb;~-DWwSI{%9lT&2{Ee zjI5(+8CuFx6mBwKWvoi-zbM6@FKVCKElm~XX&f^TnByn{ zKHwmE?L;d$ms>E{2X;;B-!#8G(HGHCNdVw_Cr~_^^?4oxDpPquoPox86t|W_@a{C+{%t%p2qQz;>+fd*+-s z#N$!c-)m$Zc+aLE&XLn?H`@rjfMBE@)UR21-iTk+r|-F+QeotMe2#eloDUe zOIpwTqKRLG!a6RjaIr9YvSs&t; zeXtvYec-@-)Kg6yGGFAY3vFA@tisNX}kGqV^1UTO+Uo>2-mZouAKn~cQIbo24X3C4Gg7aNBeiQggS>>K*I-+JPYean7Z=lW{n&Bolw zxEUAYVIR%49phzu?^sX&mzg&i&o-t|lPCt6?=;?*)bER)V*~QqgZp|Qsh^Q>@@i-6 z$&={mhdfJuZD)IsID5eQt;So7#2x#(kM(Pfqm0aNPjljSm3eO?`*Ml-L?ii={lomQ zU!J!eaq+e}`Q5Yvd*Y5bU_E(X-RK91!$x!BXskK$^tky$#^;QjBgE-Q>u)l$56Qpm 
z_l2&XX9Sq%3`K%(nJ+OC|Cq7PgHa^-m^u3=sjhg@irHy}IdN7{UH-l|=S6BqzD)JR z5qa}{8(e3+#JIna_e0iavGq?GiEsAZOzRIdE;Z5*d4Y94%Jt-R; zIltzZHyCOEE_2!$Y(C$}IuRe7YlpdhhLL?i|5lgbV@n;%nkZahew%SjQV)I2!ZsAf z7mU0oFz&>x;=ojq@C6A^Df64^;nNU|M?2(Ekhr2->{HJ3AocW5 zJ0R_o&lwN(7r6n}AATAQ8)q2LFcN1^n6p3X6T3?DH;mhj>@(t@{MO|9ZpKZ z4f_AqT)8h%Mm{C&yC%I3r!?+2{IKwG?SBwtN8uQRSTGJieH zmm67!9p=?W_5=18SkF3fF5G85>-3=cB;!k)_-WEKh|-w^{n@?=KYK}8?Q0GYa|{gn$r*asK5317#Zgv^To#W{d2JO zcNR(wg$JhC4;`FZkENAu1`>Ot(_w8JnuB+pZi zJ@s(vNBKF{m3XbQp7q^o&U1&D6X#EuFEf(ot~DnwzGJ@0c(xI{8Ro?4Sab3^`G7bY zB3s z{V?C0FZe-TU>>I>?b8lOUS&K7yMp;=d`I&@g0sv|N%TV#&iL>L+|PEzH*pB3p17wS z5Ig3D=h+W%>gjis`z5asPy1WH(a3sEGAG|(Z@$vVbEle<56Lsztv}E>*+?EFZZERF z!T6-{DdX)%;%~Hh599sD?nd%B`;qxP%=Modi5vXIfA;w>+f^DFFY!-%_{Dq42-~y& zX`gsv9aq?ndFFkEeDi?o-!rCAvVyq%#QY^A^O8OXPum?xh4IjYLy4m^ta#t}qLFcN z4xefLEaUmcSB&_7xcM^U5ym%-n~m2Q&ofRpE;bS`>&@9OubIy`PD?oB>S_J6i5>n~ zUD0fTcvx#r|7V-0ei08hxgL1i0`uBxzQB05k#mFmV0Fb+7CIQolgtzP^!=CTLG~;8 z6n(9QJ}3(C1M8~_>c?A89I-x(Ynkg886P!rPINOr%E-Qd+PtTc{Y{?OZv9TF``awqa zC2_yjdghJz++aOWLMTA?0r_QIfnAaM6jw|wGT4XygruH)ikf(0f9&T9?3gF);TKu) zxN)VCcw^ljww`$+e=#oh2XTS^LO=Je@j2sT#<51m&$?Y}{gFoU2Knv<>&dsoHRt4~ zuD{N>$vE6N*?5bQxT`Tg*SN*F+<3lmwh_NgD_G|ntYDrym@}`~Gj2HN7=90S#U&Ud z*gN6O6Y5703h=2!zcAsP8;QD1i3=3uUHBJ8;*j%?cx1jfUoLio#Q6i}2OEid^7(Y@ z+5emu6Ramr_LyH_ywbSC$bKhocu%1J=WX|nahmZ=Bl+}gbM_Ca=x$7(KiPWbgXdqiK>y7DGuG2S_PpQ0@33%yk^R2W{1Xg~HI9CttE(*}t7^hd#Aym$dW$HFloAIfhXjE=o0K?a`=J zv1*4Zf@;->imKXFh@FU$phi$aC@KV{q4um%Ge$L<@NfBYT|OLn;{9~JzjMxWKleE2 z+)opU1?W6zf3YuI#NGghL+jUlbv@40zT>@Zzp-zam*#T~<35G8(EjE<>UT#oe%G@Y z{S)qkuY&yqdH}o$?Z-ROo}b@y_IdN8f;jWI7;T;-)9e?1=eRHWC-LJ8EYktK2h0o4 z!!XzDKDb`}es7%D5NKUzVt*3+sY=&rAAiS#)|LC!341TN8QNDW(a)jr=8^WE_?N(Z zcpSRV_WjY=v!HqE`EJ3!9Qyq=HrTCi`))OU*YEy6!@eD^f{oC;cb_t_XG8mM1KRhD zL%Xl8_cM0iZ=Ck|_*X3E0h;gC==;!q)QYZwZ=rcM z3H=^gR}0bB-%Ip8cocSqd!hM#9c>;tzscCGqkiZiun9haJK;BY5SqvKodWE0p!My3 zdOr3E!qR5(BUb}-eeRp?dVd5S2-& z6q+~I{dw$TU^a9e7tq^bA#8>#VQ(m`-+37Fpx=eY`TPbAuok+`F6eh5&UrivcK6eJ z&N_Y_;*|CcwfBtYt38QxlN`Zr9vbg?4kONUD?$%|bD-yO3_TcDLGynN+Wp^z_Wp5w zH?X^p_Sb>fr@&Lt{O^VC4>!X)xE1>Sx(IFEWT78I>njKS82TRT(ffBF{=?97-GO%h z%yaX8JpNqx5jviE?t1s*Z-Ktw_hA0}|KA?K;J&yH$Mf8+FQ4z~4>&4l=NHU>rF~t) z8-Ku`L7RuWu)A*K?3fR5jQw$AOt_K9-*zGoxaefIpvV0XRt z6VKK2bY2sQzXHuC{f=w>oh8n`>ACnG$9JB;i9ZDULh*$kM>N21uojy4YeQTk`dqN@ zL7PwZKlj&g&nv$FmI3ew92)G(d;a(j+y2KN-R9wL9F@>|bv?7OyH4}M`_gmDBd!8| zhufg%WxV_E{a{|1H|DwhtPSHlgT>JJljwZ-9C~lr5A44k@Voxi=&!I2+Rx0V4D2QF zGOU8W*E+=f&%bLn?;X$mD&+y|z~>|DG0cPo(DRA$NxBI~%>ShLXkUsUD>_ot9Yb6u ZIvuuyp1<_`eB5!`hv%X_KIi+i{{hsO0d)WX literal 0 HcmV?d00001 diff --git a/uniter_model/data/test_data/input1.txt b/uniter_model/data/test_data/input1.txt new file mode 100644 index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598 GIT binary patch literal 92849 zcmbTfcf3|b*5!YR0*V=PAPdHugNckFB6>|2iI*Vp93D{At4LB*%wkp~2}V@R2@r64 zI?Z&N>3RF@p6N8x(`lyDOi!oZs$1*(d|paVqd)$rwb!nFs%ls5+I8x=&%G>Ze^7?pL<@v`J$swU($Z;FFL4f%_;_rKe(czVtziaG;hkh*%KDEqn%y#?W$x6Oljdga?;PJv%F^QbEfW^Zys2eD)}b})IADAyDHqI`mzT0mW7?1J zAU1PuDeF9@$%4U7&t1iMnpte&Y|6NmFfgZGC-XO>J#`T~l?|rERQJ<}b{; zj=AWtcJTPlYP&QkpWaf+4w<3x%;-EL>9DFsU1x{R=sKp0Iyx+L)MZ6>_?UJp zvL35ugtm?-Wk)t=M=hK@Bh16mt=TaHW_0}g#%dcH8>+Ko)mRwyaph4TA4h#cZ1BWV z*0VY5HRJO~RNYuzuMze3i28U$eZz?Q#RmG9vXh##0Upu7*6ieeGNOj+#_SZ2=+yFv zPKzTtJvMkoDXVPGstQI_Q(x0q*_c&(L^U2!Z5UBqY@oiBH8f|99#K}orR<94Y`jNwWovd-;fNX=o3v_IdqmfiM|5o*(RH!G38ideb2h1HL=6q~ zP4(Gik7$ZVG&PKy>_!23vqNyXsj)c zTH}arjSb#b%9b=|w}1YKlx`Xu6^u(gqGcY@@-U(mv4NGP?2hK_PLF6+Yj#)Rh#IOJ z8f&w=J)(QcBU&9tbZ>0%zEXC7bM`>dh(Zpm%pUZJ9`cAD4kLObHn66YJ=&Z-<`F&K 
znmtiCqS}T!jp#`=Hdu4?R9j~)Xm4UcGB7}1-tfwxN8+s)ZK9?`q4*?WZ}YHFxysL$T_h(0Kf=)*Xok79!# zm$L26*^bX0k(R8XrmDI!+vyQ~;t}l%BibDs*i*_rZO;CkNA#7}?B5rTNHbJln|;+I z`dWEJUymdDMr`n#rR+a6XaBKiM2*!=nxk)dMBny^z7t0D-PpkQO4)yE&i=DU^!?WC zzkL3P>KYru2~eB;i%0Z>@`(O5j_8N6!5@{f|Jt1Wx1td>)z)bD|JWn?iAVI)FruHu z27X@3{(E!wKRlvev}XVF^GBo!Z>rYb&3@?-{i-~oU&j&sCN}unQue=^v;SQ*qQ<8B z%F66_9?|bTqCbQY{V_K1r&9KRnzR4w5&gL}`+xsrL^XBQ*YA|#|HmV z$}8GcR5a)9R2Gb@uCl(mu0C%cEV_o*lDvb&IJ&%}L=7m5h{^&FYMZnp z>hnW`MfBh>iT2=diNS*&!r1B&l1llJDx348RN8~1CAQ|rs4VuNx~i!_xFA_aCK~n7NL`kK*r^@EMm&zg!Y8$F6EA!sLB6`qAqCMy3&Z|@w zcu=R;tj5Z`I#@&xY9!i&T8Y7fI$>O z!8H=g9$YIDJ-AL%>}!IgQa({-b3RFBkq6cFRSos|s1zcP*+pgq$J;izT+^H>oW0ptiBTN*>%S6ZI2v z(Jc}~Kk9X5-YOD3xK&c@>o!THe2L2D{C1U}?}6g8UME$)G+0CrmPxb+%OwU6RtRIO zDw&qW$EcBqdN$;xp zlQOXfPf09$@U%$u;2BA=uV*Ed^0g|P^XF6+c~DthRbQE}3l`CX^%CvD28qFgjl$UK zCP}6Id6muiW|j6}i^SIa1(n4f)Q7D7qD<_;OA^Z-yetwuctuj|>s3jme5=al{56$D z9#n@{@BH;(5j}WAqCMCqF?jH%Ft+-Zq*DI2%I5qXmG2Ot8C79s4ViJR-X`6=R1Q%^xzYT_F$L9;K6QT zY;}*MQvRvR=KSBOv-vd{y) zKWe7_MJ8T@A4n{(!M}<`4}K^q_Vpu4rTo9DY|j6i$^s9WGL=O&>6#F_z zlI{m8bw5y9nD(%715_LaNS?GaY zD;gW~V`X9wj+0pS;CPYf!3mOLUnfe^{XnJe2P%s^sH<+$&7*g)h#vHjXb<{I3?B3o z##Z}F()~cC?guLE!9a<+AE+$!pteeH!TBjNQ9t3W;#7&DAN9I2KTRZhaJr<}*BO#@ zKTxUrfyyEeYHM{M=;Jvc`sdT_3!*w-LQx*w?2{Xk`*2TeL_@}a>ZdN53)Js2)AcrZd3TOBD$ z_XCx>AE>kkqb2Hopt8_|`o;!D>lm4+pKzjGATjjQcA||Hi5^@iDfTr^lI{m8bw5y9 z3*P6_XCyo zV5UUf4^$R-5Z>ZL)}AF3^%M5NY>A;C^|~^jBN9EBD=GFhPm=BjDs?|lS>!>jUfXp) z2o}+U1rqH+Nn-F|p)j_(NRsXcDs?|lX%7}l)crtZfd~3_MwY z^x#%Wv9H@C>3*P6_XCwh9;jWN_Dh3B^kA7pd$3$$@L+{7wz^W1?guJ$KTv58R!P+T zKxLr^4K?B1xmzao;2w!(4_1pr5AKx|`?^n(?guJ$KTuiZK|@2G_Q8X}B6{$UM0@bC z#Nfdr!r1B>NxC1X)cruEJ$PKA?guK1J!q`f8aycz^%HjTQxZczZ9DmCk?6rQl44)a zO49v6rS1nRi#({-hrreOx?mAKSTE5YY>*f{*eHyxZjz+?flA#ERN8|r5_LaNS?EDs zI1us|WnvFrl34cOWs&H?E0SVguS(MWK&9>nDvLa*(FW7~AXr2X-jHYywn+>gyeW*W zz9mWb1C_cTsI&*~O4R*8WuXW9SgO7@e_tl{-~)+e4?Yx$9(*Jz_VuwO-49giexS0* zgR1)4s``9qu!tUfBGDe~k{CSLEsU-1k)->9O5G1s+Jmo1)crtZu?J0+dI$chOzgqe zB$hq+x=8fk8Pv+dhkO@v9BLV()~cC?guK1 zJgBR!(>?je!6JI_6N&cVrxJq)KNH4Qe=bS)1C_cTsI&*akf{5C%0dqstMmic{FgHE z8vIIPc@2Io5kQuhOuMIJQOHR!G4cflfh@Oz2&;13dm2Y(dCR{tbP z_XCx>AE>kkf0n5GfyzP;Y8u1Y{}-9qgTG2Fd+;}r=)vD5#lHR_N%sSNk3;tZl|>%Z zx845Q2aD)I2Z{Eeqr~7rCt+;0vn1UQRO)`9(jIh`sQZD+Vh`G)Xn&d5g99X%J?JhH zJvdNO?CT&&x*w?2{Xk`*2aQeo)bY?@5j{9eqCGfVV(_4cFt&PxB;5~G>VBZo9vm%E z_XCxM9@OjTtjv#a)R9WTJlB?gtV>Kk8NY1Ci*#>5^h! zXGqfhK&9>nDvLa*uh%y<^6Fp_J*bgr4{9X_59)-m)p|*~AE?y*K&3rslBoNE%0dsC z^zmk8ewIwU24_nwufaJY(Svg(#l8kf()~cC?guK1JZNmx59ITq!6JGvOrkv)E-`p8 zLKs^eDM|MOmAW6OvVBZI(1ZFqeaR*tBNKaYfyA;0V@09|7fOnKjgzGNflA#E zR2Fzpubt2Ukke{Xk`r2l~uWKTgfB zmWe&MMq=56Yek|5*GY9(Ss=x?ZH%u!GjiIY;~F> z-49giexTAG%#^76fyzP;s@mR7XURnUwA~LRhJMtm?gt{#gSnDoU-KmCexOqK1C<3H z)K^y3>T{etSVRvNNVEqfiNS+~!r1B}NxC1X)cruEJy!=oQ*HQJ(cLn!2lq%Ud$3w0dT_6#*w=lMbU#q3`+>?L4{B?}XK4=x zi|D~a679jm5`zbi2xF^jBtskOT6f<^RTy+nJkL1OS=qcFC*Ns{gdDs?|lX%DtY)crtZ zp$FP-daHO*CidVZiDeI77Kt9bA}RLuswCYHRO)`9vdDw*qjud7f<^S;4T<((o5bM3 zo5I-YTat7?P^tTYN_+6GMBNWm7JAUs5IlHaCh8~btji} zAE?y*KxLr^O4F73&R`Ke_(Y;T*d;M|uv-{g-6Ki&1C_cTsI&)Pk*NED%0dqs>-58- z{HrqY8hlM+c@4fU5kQuhOuMIKbuho3@yD_BGizAe!nd`DvN;Jd=u z>h~n+exOqK1C{pR`x13OP+8hoX9 z#B1;?iRCr;wMg{fHflA#ER2F*Bq}j=T7c8O&zn5qa{va`U@JC^6^-q#? 
[... base85-encoded GIT binary patch data for the preceding test data file omitted ...]

diff --git a/uniter_model/data/test_data/input2.txt b/uniter_model/data/test_data/input2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598
GIT binary patch
literal 92849
[... base85-encoded binary patch data omitted ...]

diff --git a/uniter_model/data/test_data/input3.txt b/uniter_model/data/test_data/input3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598
GIT binary patch
literal 92849
[... base85-encoded binary patch data omitted ...]

diff --git a/uniter_model/data/test_data/input4.txt b/uniter_model/data/test_data/input4.txt
new file mode 100644
index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598
GIT binary patch
literal 92849
[... base85-encoded binary patch data omitted ...]
z_l2&XX9Sq%3`K%(nJ+OC|Cq7PgHa^-m^u3=sjhg@irHy}IdN7{UH-l|=S6BqzD)JR z5qa}{8(e3+#JIna_e0iavGq?GiEsAZOzRIdE;Z5*d4Y94%Jt-R; zIltzZHyCOEE_2!$Y(C$}IuRe7YlpdhhLL?i|5lgbV@n;%nkZahew%SjQV)I2!ZsAf z7mU0oFz&>x;=ojq@C6A^Df64^;nNU|M?2(Ekhr2->{HJ3AocW5 zJ0R_o&lwN(7r6n}AATAQ8)q2LFcN1^n6p3X6T3?DH;mhj>@(t@{MO|9ZpKZ z4f_AqT)8h%Mm{C&yC%I3r!?+2{IKwG?SBwtN8uQRSTGJieH zmm67!9p=?W_5=18SkF3fF5G85>-3=cB;!k)_-WEKh|-w^{n@?=KYK}8?Q0GYa|{gn$r*asK5317#Zgv^To#W{d2JO zcNR(wg$JhC4;`FZkENAu1`>Ot(_w8JnuB+pZi zJ@s(vNBKF{m3XbQp7q^o&U1&D6X#EuFEf(ot~DnwzGJ@0c(xI{8Ro?4Sab3^`G7bY zB3s z{V?C0FZe-TU>>I>?b8lOUS&K7yMp;=d`I&@g0sv|N%TV#&iL>L+|PEzH*pB3p17wS z5Ig3D=h+W%>gjis`z5asPy1WH(a3sEGAG|(Z@$vVbEle<56Lsztv}E>*+?EFZZERF z!T6-{DdX)%;%~Hh599sD?nd%B`;qxP%=Modi5vXIfA;w>+f^DFFY!-%_{Dq42-~y& zX`gsv9aq?ndFFkEeDi?o-!rCAvVyq%#QY^A^O8OXPum?xh4IjYLy4m^ta#t}qLFcN z4xefLEaUmcSB&_7xcM^U5ym%-n~m2Q&ofRpE;bS`>&@9OubIy`PD?oB>S_J6i5>n~ zUD0fTcvx#r|7V-0ei08hxgL1i0`uBxzQB05k#mFmV0Fb+7CIQolgtzP^!=CTLG~;8 z6n(9QJ}3(C1M8~_>c?A89I-x(Ynkg886P!rPINOr%E-Qd+PtTc{Y{?OZv9TF``awqa zC2_yjdghJz++aOWLMTA?0r_QIfnAaM6jw|wGT4XygruH)ikf(0f9&T9?3gF);TKu) zxN)VCcw^ljww`$+e=#oh2XTS^LO=Je@j2sT#<51m&$?Y}{gFoU2Knv<>&dsoHRt4~ zuD{N>$vE6N*?5bQxT`Tg*SN*F+<3lmwh_NgD_G|ntYDrym@}`~Gj2HN7=90S#U&Ud z*gN6O6Y5703h=2!zcAsP8;QD1i3=3uUHBJ8;*j%?cx1jfUoLio#Q6i}2OEid^7(Y@ z+5emu6Ramr_LyH_ywbSC$bKhocu%1J=WX|nahmZ=Bl+}gbM_Ca=x$7(KiPWbgXdqiK>y7DGuG2S_PpQ0@33%yk^R2W{1Xg~HI9CttE(*}t7^hd#Aym$dW$HFloAIfhXjE=o0K?a`=J zv1*4Zf@;->imKXFh@FU$phi$aC@KV{q4um%Ge$L<@NfBYT|OLn;{9~JzjMxWKleE2 z+)opU1?W6zf3YuI#NGghL+jUlbv@40zT>@Zzp-zam*#T~<35G8(EjE<>UT#oe%G@Y z{S)qkuY&yqdH}o$?Z-ROo}b@y_IdN8f;jWI7;T;-)9e?1=eRHWC-LJ8EYktK2h0o4 z!!XzDKDb`}es7%D5NKUzVt*3+sY=&rAAiS#)|LC!341TN8QNDW(a)jr=8^WE_?N(Z zcpSRV_WjY=v!HqE`EJ3!9Qyq=HrTCi`))OU*YEy6!@eD^f{oC;cb_t_XG8mM1KRhD zL%Xl8_cM0iZ=Ck|_*X3E0h;gC==;!q)QYZwZ=rcM z3H=^gR}0bB-%Ip8cocSqd!hM#9c>;tzscCGqkiZiun9haJK;BY5SqvKodWE0p!My3 zdOr3E!qR5(BUb}-eeRp?dVd5S2-& z6q+~I{dw$TU^a9e7tq^bA#8>#VQ(m`-+37Fpx=eY`TPbAuok+`F6eh5&UrivcK6eJ z&N_Y_;*|CcwfBtYt38QxlN`Zr9vbg?4kONUD?$%|bD-yO3_TcDLGynN+Wp^z_Wp5w zH?X^p_Sb>fr@&Lt{O^VC4>!X)xE1>Sx(IFEWT78I>njKS82TRT(ffBF{=?97-GO%h z%yaX8JpNqx5jviE?t1s*Z-Ktw_hA0}|KA?K;J&yH$Mf8+FQ4z~4>&4l=NHU>rF~t) z8-Ku`L7RuWu)A*K?3fR5jQw$AOt_K9-*zGoxaefIpvV0XRt z6VKK2bY2sQzXHuC{f=w>oh8n`>ACnG$9JB;i9ZDULh*$kM>N21uojy4YeQTk`dqN@ zL7PwZKlj&g&nv$FmI3ew92)G(d;a(j+y2KN-R9wL9F@>|bv?7OyH4}M`_gmDBd!8| zhufg%WxV_E{a{|1H|DwhtPSHlgT>JJljwZ-9C~lr5A44k@Voxi=&!I2+Rx0V4D2QF zGOU8W*E+=f&%bLn?;X$mD&+y|z~>|DG0cPo(DRA$NxBI~%>ShLXkUsUD>_ot9Yb6u ZIvuuyp1<_`eB5!`hv%X_KIi+i{{hsO0d)WX literal 0 HcmV?d00001 diff --git a/uniter_model/data/test_data/input5.txt b/uniter_model/data/test_data/input5.txt new file mode 100644 index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598 GIT binary patch literal 92849 zcmbTfcf3|b*5!YR0*V=PAPdHugNckFB6>|2iI*Vp93D{At4LB*%wkp~2}V@R2@r64 zI?Z&N>3RF@p6N8x(`lyDOi!oZs$1*(d|paVqd)$rwb!nFs%ls5+I8x=&%G>Ze^7?pL<@v`J$swU($Z;FFL4f%_;_rKe(czVtziaG;hkh*%KDEqn%y#?W$x6Oljdga?;PJv%F^QbEfW^Zys2eD)}b})IADAyDHqI`mzT0mW7?1J zAU1PuDeF9@$%4U7&t1iMnpte&Y|6NmFfgZGC-XO>J#`T~l?|rERQJ<}b{; zj=AWtcJTPlYP&QkpWaf+4w<3x%;-EL>9DFsU1x{R=sKp0Iyx+L)MZ6>_?UJp zvL35ugtm?-Wk)t=M=hK@Bh16mt=TaHW_0}g#%dcH8>+Ko)mRwyaph4TA4h#cZ1BWV z*0VY5HRJO~RNYuzuMze3i28U$eZz?Q#RmG9vXh##0Upu7*6ieeGNOj+#_SZ2=+yFv zPKzTtJvMkoDXVPGstQI_Q(x0q*_c&(L^U2!Z5UBqY@oiBH8f|99#K}orR<94Y`jNwWovd-;fNX=o3v_IdqmfiM|5o*(RH!G38ideb2h1HL=6q~ zP4(Gik7$ZVG&PKy>_!23vqNyXsj)c zTH}arjSb#b%9b=|w}1YKlx`Xu6^u(gqGcY@@-U(mv4NGP?2hK_PLF6+Yj#)Rh#IOJ z8f&w=J)(QcBU&9tbZ>0%zEXC7bM`>dh(Zpm%pUZJ9`cAD4kLObHn66YJ=&Z-<`F&K 
znmtiCqS}T!jp#`=Hdu4?R9j~)Xm4UcGB7}1-tfwxN8+s)ZK9?`q4*?WZ}YHFxysL$T_h(0Kf=)*Xok79!# zm$L26*^bX0k(R8XrmDI!+vyQ~;t}l%BibDs*i*_rZO;CkNA#7}?B5rTNHbJln|;+I z`dWEJUymdDMr`n#rR+a6XaBKiM2*!=nxk)dMBny^z7t0D-PpkQO4)yE&i=DU^!?WC zzkL3P>KYru2~eB;i%0Z>@`(O5j_8N6!5@{f|Jt1Wx1td>)z)bD|JWn?iAVI)FruHu z27X@3{(E!wKRlvev}XVF^GBo!Z>rYb&3@?-{i-~oU&j&sCN}unQue=^v;SQ*qQ<8B z%F66_9?|bTqCbQY{V_K1r&9KRnzR4w5&gL}`+xsrL^XBQ*YA|#|HmV z$}8GcR5a)9R2Gb@uCl(mu0C%cEV_o*lDvb&IJ&%}L=7m5h{^&FYMZnp z>hnW`MfBh>iT2=diNS*&!r1B&l1llJDx348RN8~1CAQ|rs4VuNx~i!_xFA_aCK~n7NL`kK*r^@EMm&zg!Y8$F6EA!sLB6`qAqCMy3&Z|@w zcu=R;tj5Z`I#@&xY9!i&T8Y7fI$>O z!8H=g9$YIDJ-AL%>}!IgQa({-b3RFBkq6cFRSos|s1zcP*+pgq$J;izT+^H>oW0ptiBTN*>%S6ZI2v z(Jc}~Kk9X5-YOD3xK&c@>o!THe2L2D{C1U}?}6g8UME$)G+0CrmPxb+%OwU6RtRIO zDw&qW$EcBqdN$;xp zlQOXfPf09$@U%$u;2BA=uV*Ed^0g|P^XF6+c~DthRbQE}3l`CX^%CvD28qFgjl$UK zCP}6Id6muiW|j6}i^SIa1(n4f)Q7D7qD<_;OA^Z-yetwuctuj|>s3jme5=al{56$D z9#n@{@BH;(5j}WAqCMCqF?jH%Ft+-Zq*DI2%I5qXmG2Ot8C79s4ViJR-X`6=R1Q%^xzYT_F$L9;K6QT zY;}*MQvRvR=KSBOv-vd{y) zKWe7_MJ8T@A4n{(!M}<`4}K^q_Vpu4rTo9DY|j6i$^s9WGL=O&>6#F_z zlI{m8bw5y9nD(%715_LaNS?GaY zD;gW~V`X9wj+0pS;CPYf!3mOLUnfe^{XnJe2P%s^sH<+$&7*g)h#vHjXb<{I3?B3o z##Z}F()~cC?guLE!9a<+AE+$!pteeH!TBjNQ9t3W;#7&DAN9I2KTRZhaJr<}*BO#@ zKTxUrfyyEeYHM{M=;Jvc`sdT_3!*w-LQx*w?2{Xk`*2TeL_@}a>ZdN53)Js2)AcrZd3TOBD$ z_XCx>AE>kkqb2Hopt8_|`o;!D>lm4+pKzjGATjjQcA||Hi5^@iDfTr^lI{m8bw5y9 z3*P6_XCyo zV5UUf4^$R-5Z>ZL)}AF3^%M5NY>A;C^|~^jBN9EBD=GFhPm=BjDs?|lS>!>jUfXp) z2o}+U1rqH+Nn-F|p)j_(NRsXcDs?|lX%7}l)crtZfd~3_MwY z^x#%Wv9H@C>3*P6_XCwh9;jWN_Dh3B^kA7pd$3$$@L+{7wz^W1?guJ$KTv58R!P+T zKxLr^4K?B1xmzao;2w!(4_1pr5AKx|`?^n(?guJ$KTuiZK|@2G_Q8X}B6{$UM0@bC z#Nfdr!r1B>NxC1X)cruEJ$PKA?guK1J!q`f8aycz^%HjTQxZczZ9DmCk?6rQl44)a zO49v6rS1nRi#({-hrreOx?mAKSTE5YY>*f{*eHyxZjz+?flA#ERN8|r5_LaNS?EDs zI1us|WnvFrl34cOWs&H?E0SVguS(MWK&9>nDvLa*(FW7~AXr2X-jHYywn+>gyeW*W zz9mWb1C_cTsI&*~O4R*8WuXW9SgO7@e_tl{-~)+e4?Yx$9(*Jz_VuwO-49giexS0* zgR1)4s``9qu!tUfBGDe~k{CSLEsU-1k)->9O5G1s+Jmo1)crtZu?J0+dI$chOzgqe zB$hq+x=8fk8Pv+dhkO@v9BLV()~cC?guK1 zJgBR!(>?je!6JI_6N&cVrxJq)KNH4Qe=bS)1C_cTsI&*akf{5C%0dqstMmic{FgHE z8vIIPc@2Io5kQuhOuMIJQOHR!G4cflfh@Oz2&;13dm2Y(dCR{tbP z_XCx>AE>kkf0n5GfyzP;Y8u1Y{}-9qgTG2Fd+;}r=)vD5#lHR_N%sSNk3;tZl|>%Z zx845Q2aD)I2Z{Eeqr~7rCt+;0vn1UQRO)`9(jIh`sQZD+Vh`G)Xn&d5g99X%J?JhH zJvdNO?CT&&x*w?2{Xk`*2aQeo)bY?@5j{9eqCGfVV(_4cFt&PxB;5~G>VBZo9vm%E z_XCxM9@OjTtjv#a)R9WTJlB?gtV>Kk8NY1Ci*#>5^h! zXGqfhK&9>nDvLa*uh%y<^6Fp_J*bgr4{9X_59)-m)p|*~AE?y*K&3rslBoNE%0dsC z^zmk8ewIwU24_nwufaJY(Svg(#l8kf()~cC?guK1JZNmx59ITq!6JGvOrkv)E-`p8 zLKs^eDM|MOmAW6OvVBZI(1ZFqeaR*tBNKaYfyA;0V@09|7fOnKjgzGNflA#E zR2Fzpubt2Ukke{Xk`r2l~uWKTgfB zmWe&MMq=56Yek|5*GY9(Ss=x?ZH%u!GjiIY;~F> z-49giexTAG%#^76fyzP;s@mR7XURnUwA~LRhJMtm?gt{#gSnDoU-KmCexOqK1C<3H z)K^y3>T{etSVRvNNVEqfiNS+~!r1B}NxC1X)cruEJy!=oQ*HQJ(cLn!2lq%Ud$3w0dT_6#*w=lMbU#q3`+>?L4{B?}XK4=x zi|D~a679jm5`zbi2xF^jBtskOT6f<^RTy+nJkL1OS=qcFC*Ns{gdDs?|lX%DtY)crtZ zp$FP-daHO*CidVZiDeI77Kt9bA}RLuswCYHRO)`9vdDw*qjud7f<^S;4T<((o5bM3 zo5I-YTat7?P^tTYN_+6GMBNWm7JAUs5IlHaCh8~btji} zAE?y*KxLr^O4F73&R`Ke_(Y;T*d;M|uv-{g-6Ki&1C_cTsI&)Pk*NED%0dqs>-58- z{HrqY8hlM+c@4fU5kQuhOuMIKbuho3@yD_BGizAe!nd`DvN;Jd=u z>h~n+exOqK1C{pR`x13OP+8hoX9 z#B1;?iRCr;wMg{fHflA#ER2F*Bq}j=T7c8O&zn5qa{va`U@JC^6^-q#? 
zKTxUrfl7PuXNkHWs4ViJzNSjwPR#!z6R*KvC6?FVZz9oyze|dJ{X>%O2l{Te?guIh zJ!t!Ra{FKrJ?J3O9(0r#Jm@5ht#+2A`+-W`4^-NNt`c=WP+8JgH3KTxUr zfl7ODv_#zxR2F$qUmdVBZIz=OKVCVdMtuMQT`gBpqUpjKk=piUTDt(ThndgcI!oiJ>3$x-uUt z5VBZI&;xzoRhc&ji|D~6679jI5`zbq31h35OVa&7rS1nR?ZK52 zbw5y9VBZI$OH8f-hpSy#2(C+SoUC!Nc3Q?q}bOyNxC1X z)crtZkq7ma`hAT&4;Imb1rqH+Nn-F|p)j_(NRsXcDs?|lX%7}l)crtZp$FAf`tUlx zStj=27KvpKT1BD?L532P8 zy2|{)U=cldNTNM>SYq(t5n*g~jU?RAE+$yptkMJZe6g59;}yW4>m{)9&8lGRyRq~{XnJe z2P*Bs7Kyqas4VoLwz@{2D7+{WufakQuhOuMIO{Nh0iKp z4;ImbHzeAFZ4!eAZwh0pZ%NYqK&9>nD(%6$5_LaNS?EEv-tF}B1DV)^4kQuhOu_TU#1 zbw5y9=t0}B59PmAE+$ypst}&?<>Cx z7SV&>OSA`nkQhApqcFDmCrP>=sMP&Hr9JqwMBNWm7J5)wud7u47n!J^@NW87iJ_mi zchkR#L=XNhDfaabNxC2Cw?cJ4P+8n zD(yj6iMk)CEcBqhE_}~!f0?MC5Jd+_4E?l4QFoE(!GV%uUk6Fj{XnJe2P%s^sILp3 z*&P}zq6ddbvVBZo9t@PI z`+>?r4{94KHB+a^ME!&)I#pumr!9(36Nw(2E-ChPh9un&RO)`9vd{zlAh0s84i?da z8j1FxR$}m=P8eIQm!$iFO5G1s+Jh#Ex*w=4^gutR(7WkbGO-6|ODubEj!5+2TuHI7 zL6USoP^tTY$|4VHs`ML_`OshyJs2j@9t@WlJQyL2t&WtW`+-W`4^-NN(Gqn(P+91K zers6g&KQ~4g9{{}#AP-49giexS0@gW7P?GzW|5!6g#y!KD&|2bT$B ztCvgC{XnJe2P*Bsl@fJ7P+90fV}1C1_G+1^pRkj!kr?_>uPgIwMWP4SNs4_?L5A-!V z#plvs5j|KY(H<<97(7@ZjIFMer2BzN-49gSgH;lBKTuiZL48x(E%R=f*n@i{mOWT4 z5=Y;}_)-49giexTAGY>}w@fyzP;^aT{{VBZI$b-sy{g^R-Jy=8!-jHYywn+>gyeW*Wz9mWb1C_cTsI&*~O4R*8 zWuXW9tU{lxy)P55!3Pq{Yw)2+^xz{&v9FIM>3*P6_XCwh9<=>he!eqUL=QfZXb*Nt z3?A$j##Z-8()~cC?guLE!B-^eexS10gSIpIt1|H#d`)6`4ZbcCJ@|&C*w;5D>3*P6 z_XCwh9@N!@Z_IuxSVRxLEzur)M`G~cyTaJ&_ay0lpi=h(mG!>3zTd4k zyMK|1J@|pdvIqYv5DRUM9|w!*!A~UGgP%$a9{fxg zTm88t-49giexTAG{6eDc2P%s^&{ux-+fey0WnvG0C9&+muSKE%O2l}Hax*w=4@Svf(D$Gs$U=cm&AkiLllo&kdB#fAE+$wprN|$XYIX%Mf9MLM0?OzV(_4!Ft*xXlI{m8 zbw5yP4+cuq{Xk`r2U_d4-#?Iv*Wgr%BzkbVq}bOPl5{^%sr!M-0uS`Vq3|1@ z)xjcqP$SVE)JhB<)Cps&^^$ZyP^tTYN_)^GQTGFtg&u@p)%S~JVh_%iSoYu?k?6s> zl44(jBHyNbW4;q;cqPMpDM$j@Ht>_W!tARS@*q_@oz#Mm?*=~ zk+XwRW%yGi2dB#LM^Fw)mErFq9hxe`pOiT)QMUbcpu*xt5dm zQr5NoI$1U#QHI~x%?2jQ@TWDhlM`k5NoID+-pZ=*hZ(X{6J_{|1leh+Qs0cpPEVBK z&Xk>zD8ruv$SM=`$`> zhQB*@QKAgLb)8+DD8sMbWzBml+uk^{OA=-H70m3?R2g>1Wr;HQnO&YL!{Dz-mHKQk z8=ok{@7rcqCd%*?l&TCvMs!)C(7{aSJ{k28GctP zo4L2L?U#45>l0--^|M*2Qa^CdW+%$<&A4n%stiFpH&ur8J1+Q)T#*ky)Y) zAE{<}q71*clP%a=+4lAGtduCj+jX`uQHEa($rdHb@D`ihkSN24HQ9}cGJGQ~TfCPt zez`ciDN%+`O0t_%rQUe5TM}jX-gefyx3VgHjF#P+Dnl^bmMX*R!IDH7eiEJCo+!iD zSF@#iE8D)2kS$A<`oTrEJW+<9*=8#eW%%kxwlYzM?`URsq{{FLcW0{9&-}Ahdn?<% zG?(3#D#KqwygN~bFGgkeB+Br4dA2%L>Z@AWy@@h>hc3G>QMSF5-k&JL_XV>DQe}7( zcraCl<#=dsW!slCvWF98_$hw&NUGG=nX)yBGJM*QJ(?)P*M_pkQe}8GdOT6K1;Z1m zGR)GGsWPnQQ>ij+!>3ba_|tUHB+BraSN81QN*&K?b8VswpQmTfCCc#Cl5Ab7)bHqL z>l0=82rS!>D8r|o*~V0Wg7 z{!*&c=N8$^i8B1KGkYafhIgA+Q>CuL+15lEe$tq|wzsnF2b-JyEv30q)pK+0gb;LAEnhhBf+RZ)MxZ71^#t89oWi zcJHlh`)DTHlPbf6e3~fRj>9jU#Kq{0`1c; z^~@vubOhsJ3GbWmlnzZ44<&qG^oV^K6x5Hhev$EG$uXKb(n2_u93Jo#{5d-dgHOiON{tOzpq-)yk2I` zzFllyX#{S!Fy6?%VgJ5l{qe?0Mj+)Ux`O$sGJnB1ng@;7B%JlUC((amJ|oc=QN~L> z`2oVW`LE~CuejxRr_2IE|5DE~#3aTqu=c4+jiYlys!$_QUPV9+Ckhowyut!fjG!yjw`!+DPQ1j!F@ z+Tr;WX$Pd9ai@AX{iNfep8i4X7!OE)_&v%U4>WRaPckR2W|=QCHW`VlD)Z4s@+IRV zZVq())y5X%Nybx*Zx}Zk@r(Vk#rngHeT}<}?0epmr&!Ov?_th9!e9DjzR0h{JN>o# z`A>|`7=c$2j1L)4G@hK)rxbrvJ^7CLOY7k~5E5j5n1`ggOxKpB&llz<>Wj?Dzu*bh z<3If}kF57;$@4FoA8%yc*O=dJWSw~bxXyak>os%Yl6d5K-ito69pmJ@W}LhSk#C8= zA%2eYXM{QN#(o-ZJ?HUAbIylV=J;t^!Tb2JR`fHHw>O(l<$(l=m**1wR`dIfoGXc0 znZ8h#qHvq}WaGt2eM;*R1^3%TT}a0j*6dxu`gXLsQ-bU(;)Z%S&(Tg-Ki|&CIl}u2 z9Q(AL!G7*t<2vIsBXLK(Gtc|G{$nHi=2mm!kT_uf5J#+Mv+XAvPcV)&zGY+_SJep+3z%fdxQ_7U|U z`;B(d-(_K^u{#wBvTkYrm@%J>gZ_Hifc^QhIrG8%b6&pZdggnB`9$ME#`BEaYhE`e zu80HjNjKL&Z)Cijo8-ZDNjp!Nw>MsITy7*DFEJ<2@!S~eiKiRQ*`MUKTI<<|rWK1U 
zunz`PVZ7f+9E?cnx0v@e9-Pz@?1=bS_jtwPubl85(U zR<0%viK8{HInhWQb~k69`k1p1c`q7fJ@G^QPOzSNBn}?6p8htP)6ZJ-(MIC&aC7p} zYV+Yn=7D%6ep_5W%J_hh^PKpdZawc+OU*kN>4)_u-=E_8*+%A%eRrPqV~s}`0qn?& zce!Gok@>?uMe;fG!v5-SL;OjfE8@4i0zZ})8Gb=M@sOGmhnyQ{+JHE?(0rEh7UO&) z^UXS)V?FIM{^zWpXk@?NYEB+G#C*4ryu|oVvwoZLP9yR2jye6Y&cw+Du7BD{T!Rl; zPh663###T8@e?EM!?EK$Vjjt>UH#mtMu7Q%FSEct!EfTTABF{ZSfW2Z;cq28r6aBA zm{yn**9FyO^}@2Wx6p2HTK>5d{V`9hTW8x-5AS3>>rdP<|KztjY}ecPq!GBm0`Z4_ zhz0iVedaeAZ#MG2HQAi^g+Auz8ecJz=g7nBtmj@*WzIS?{%Y%|8J{-}FwQZ4Xe2LQ zW`36O7~=~@=6jVn?N2d3+Q@!mob2CYU4OrkahaB(hozpz+l_A+*@r{TQ+Uh@+V5%3 zK4kvW=irQ=bzwcI|FXDW!2_eLC(kT3=e*~AhWT3V`t`;Z<7-CtbB+1M#zDrhM&j^N z^LvfN8Tos$^{n@5^K*<7jI1;L%(s5Ck+@=iu;H-eGzCF(-jE2&{0Ch9LtyV#O^2a-S8xA6A1?_lJ4u#@%J zfjpnKM?2(Q+QE+aryY3O56_`bIqlIO%`h zCyeAx;+^MKy8cchd6M{}oikiN%-CdP{~l@%K45;lk#Vq(c<*3e(tnHXLHvB%`ZtW^ zN9J*X^*0-Vr!6pE-a8(*p8Vb0oPKzoyi#ig>p?x|4d*C&ko~;Mie*$7Ip3I>B*sN6GW+%}YiA|M9EQ6`W7w&2uB?9?w@=U|rCI%s1_GzEKa7zv19$3nNh!An!kE z^$shD!(HaA_nqdoM)tuo=3|WHgJI^xGw&_L3;F19+kIdpAD(S~yODK#&HOMU?OtV0 zemK#5u<=qO`+)pUTr72cwQ-}7b>#hlcx`rli*b;V_%p2_A9b>Vah+~XUY%}^UpxqGS9CG5PRu9gN1f}*FK?UQ zZM@t_-elfZTYtXsMkD72@k@R^#q~XmyN%ZwTN5PCPqqFMW22G$_qaLvVW2tt;V$!6 zjpU&l%&#y$Y-Hasu7|7#4>soQkF{ok@c`rX zM)sxYXQ16@G*Lmj+*49B^4SZnVBU$NtF8akNd93x$WyE%{miyK`TRihJB%ZY#Mu;c z){}fY)_U^XN^{1^_!e5vy7n}$H+D0!FXo%SV_a)wTszIl>+Fk8)-N}nX(Z2*f0-}r zY*s=4#07eg{8yCc-mu|j9!PLQ!iy;TIIS z^~audp&gr7urH`bPkZDG+M_@A1$y@5S~o!6gHN=c_b1|+yvzO~p3%SM=Qw}aUoTp} z#<;{tJT5mUj(eJu-!_=DpT?PwHy&(cUyz^ShZ^=J5`+4>id6w}r zUh-m{?OEr{e713$k#(psZ!tb=WL-I*$d{}O>%_X#{uVz!z?i~JE9M*7k6X`&CBMTN zCi{Z;nBMk% z*to~I&iJZvq>*?hAJ4UZl<_6wcH=ol+N1ss>+$Cj^Azr}qQ8;(XTHhv$E6RNe`q|_ z*kB}1ZZv0~nU-OirOS=I%QaEhp70{dJf!vWQpK0#RkkOdd2fId$L(yF?i0FWK6nn~ zIqYbkcGxf2r%37RZ&TcABS&+{CJo_66p&psuOlCM8@KddWp@UHbQ7@Lgr z$GNf2divX9evOfHfc?k$^S0}`ubNg6kF{3ZX2c$U##leVINLbGIMO)TxZXJ3IM7JG z*l2#e@d)E?BlAapR+r%&OFfLVgWu0ve_8p7e=hL8L%)^(-25}O=UcPcc+6)k{<-2V zS08C)-T(PhpRFBh)rH0`M&gD2M0|tfxwK#EN4XvHUtja*j3bQ9A92Y3;9NM|c0G-} z@8Ex<^=BBLHWCN3%<+f!LY^lcdfASBbAb6A<1*vXM%HbDIp^d?^Y@JSLB9*F-(`H= z$h=%^&VJtCyv6vkk^Rxre7+Gl(83NB2~IR8pInjDLwR3?7h%!)@n!c#+CK3?z9Nt8 zv?2c8X}-$1#7N#I9->>9dEz8<#>2XntY2)r*2sI$3Fe%qTg{mt z@)PTStm}!#rRF`2?2~8Bk1}pDl1J%>dAP~-j~a>B{^lP`GB>o4PuQjsYY4>gG z0rnM~al>aK7>6gE_E8^+P=FsK`ZRvor-@l1t#(ZpBlAf913S4s#nhg5+9mDsTnFnx zo@XB5w8MO&r(Jl8)RWiQU-!G;6OH#7*)N=TdBJ?Gb3OgDKgi>p7j?FK%=n6NfpL`aO(XM6JLg%? 
zx|)`Mj`Ji6BYCLSoV*Xm{5T7o59q;u2*zm%KiT}!L=Pno!c#onih*gx7v!lwJjVJk zkLkRzUf8Ak15P~C4skx#9Wj5z>Alu(HZrfDm`^pHZp44niftCw8b36)80r5K^Y@J# zjfWWDGV=T_<_{S!G~Qu6#yG?{(fF!yhmm=C#+-A4_L{9{|MfNRZoJZXrtwo_e`7D> z(?;f}oB8cVUJlxo5 zB;GDJKh20gTg_R=hs|lw{l+SZ3tU}&V@L|wr=p~nySzVNcdfcQ_;h4jS2 znpn}r_3ezz6X!gL9f*B~rE6F1kGZ!``zZZ;B^i_NQyM;hymoZIWn z$I&u|)u)_1lj$@a&aPnO{*S9w^E|AB`i^SbpR}408G7>+`Cwce@ z*WY98VZ7N$e%)gJijg>Ae;sZ8BqQrVytP|*_qOZj8aelu znD;XBzJ0hk^`Dv(PsHWD)-!I#!@To;_@eD*8i9i?JZF5&NZw_f?6)+IXWRZNBk$v; z6>AZUOA^j`$G&F1F=Kr24-kw87#}urz9eR4O5B3PRf^0X^T~Yk9DI~p9%LM6Bwx@D zdg7wRcI08x3gV-m6+4Wq3;q!g#M^e;F@LPzaO=tIoa+OvzsPuqk#mmr2UvfSk-T`G zc{k&JM&jXB^NB|G8_zR;V_Z)jy~iA=wm=?WeAI(yT3Cp}$otn?^DL={E<#9fLc)tE zc_FP&<2S8`vkn(nvB3De@oeLzM)C^dCJu>L&ZlPElfR~#FE+9s>~G@cIM)-0_;-Z$ zeT?-+=JihVn~jX8r}+vaaY|h6ww~w7M^9Q$etgCJM&qeQ&V`}opBjm?{S*7G=Jdxr z;2-;@(ssb@7H%`*C;nez{V_a{;4*XOmFkNTjGU{9y5ck|*w56Tk6`SXa47x5Q|w{| zd5iU9Jj`dikxDDBjcQ7{Z8XL;|?Rw5%*SC zuzrtNK^~*O8P;&D*3%Z1=v2(U!CxMQT7>j_zx#9ZnlE`wbq<`HO!oKYBlFQ zjdeQDdft~-nsYvZ8?3*@NWAtk=RIem`9vf4&-Lb~7?&Hj80iOl#z#KuXgk*bGV@Q2 z*BS3Ljxw_DFPTp=t~I8x!V2b#_#n^1x2F%9GcNKV`v`O52%h4DR(xn=zEis*&bo7+ z98hFfpntDGpQ@SPt6af45?_6-Coe2ACmx?RC;o_wKGxS6&o%4=eT+MdoP*5s!PYa+XPIv{;urC5bp_`a z{A3I3jMp018dEqiQ5m|psYE%!DgUg>{7SGO~=Uc_l<>(l3mXXc+g#r$#3 z6SuFsoef6v2J6Fqxy$v$Ta7vUew;b$FvFa<8e)E*@h0PNBYA4QIsQLlewC4RW1o{( ziEH*1?GShDuPJV4fpNQWe4^)lf%s#7$g{WE{$At##>vKAM)C#Yx4HtqnUA$p7<(Ey zhna`OtOEbj{eT_$w-!Z$tiyAOKBXV0iiC5%pe8QiAn{E+qbJT-cj9}L4R{~8+x#)( zR^vt^(9r_^4KgQRoojxjv6qo^k9NsxtS|9IeBk#$Kabzfn)fpfHV!kg4;deE%RXQq zK45$DYNa{x&Hg>pdiLpN^GU`O$hW)~kuOiMA^V*3D19DI{K3C0;@_8j4*N$D_JOQR zhs2I}Y-c@jNBmF^rtJ{7w42(|9d^Vw?a&{3_62&{Wqdq`J)C}c9zC4?;k3(h*rQMF zQcin}BW;Izo?|}o1OJ#Wp2sd7598u_p2H3fVoyEep?&(NU!G&U_{DP<*^u|E-sT|t z@ImW2PuS1TTi?^T(n$L!n$!Pu^Ye|{jl7p$XHNeU%sGFLFdu0of6g(##(0mBJTSxj z6XPo5#snvulP{Xg*Bcia-!_gho@QiTOe=1-um*(@f61rhH_m&^8DB3HM$XqE=FIOQ ziQQY~qYLyeSkFAu4(6=?d=v>X4_TszK4oEl6bbelyc>=B&fj=J<1^ zIdRPV4Yr>1m2nWqm$`nLv7d3P@d@McM*P3ZocZFt5Z>SQl}7Rb?`?I~A8b6|xYF2c zyw&)gkvzbDW4>Ko@v#Nwk2u|gfsys49jY5FY%#Kq)b9(4SKl*| z&g-o2VPw7UGABM+zbmZoY-B!`n$r&Zf;b}o9%Q@QjVp}gizm!S7@sySF*1MK%*n5d z%*n$?Cj3)#;*5RF_;S~?j<=b!Z;407Lw+A_yN8XFj603knU-OXCG1&;LDrKG$`2$0 z@y`4ATGy{OUYyu}Q9AUC8Wz~?Z#{7e5=ZS^PyDbB*ijD>&+Pwce*QEg@zf{LKW5H+ z5~sxBZr78~?luQlSK^}H6|Dbq^UcOl#*&eKUp6OhJDNXZoMRks+-hVz#0PPElIuTB zp1aI^i4i-Vf6)3Am>14b^1|rEkn=22SKtSEhWf<_#@dAMi(Z--FwR6>NN-4L$ZLtZ zkdi+zV11vng8016e1Z`_cbRiOTwzXLyw;p~Cl4KJJ^T6q^O43%;{YT3?R0bY72}y` zJ?G6s=Fb}0Pj{MMZ(L^l$Vgm}r(dw1d1C)zf2`}b8=1cvb6|-D_66&9yY=@Qdm2;V ze7?gK&nE`Vd!qgf#jmu2Jd^5E&ihfSPk9$Bcs~G%3+9>m;CbSW^`#x^_p?Ag<#~8& zM?2)#%WbgG$akOktN z{J?yC?E3Y_H13$k8LmIW2&DGJ3H7``kuS*;)L)EXJTT!$B>am~;*hwQ><3>pGM~)H zQP!_EK4Ro~{9a=HJx0z!=8toYc){PLwr74WG$&6mkFC}(F)}~gC*ZWV&2~>3$-nQK zZ#7Od5)bT~PpxNMz09kOLyf?83*`M~^W%&ojW-!nV4m0qA0!6&m8i><_-DS(Dc3}S zeU+#`LpNB{J*jyq;rpVuBnD~!U0hE-B~PMn?|L}Rk)O#Y#Pdfspnj)0^ZJze3C1&x z9~znO+2)Lwc%dEUf%P3_d-lu8=EUz}bIu#yBZgUjsj;{5Wh3i3)BG&saYo)l7nx5m zlK0pT$5_AC2r%Evtr$Rsk$ucMl6Tp+YZ80r0rm41t}(uEyvoS_G_7DfaPkLvX{9ST zXRrg;A{5|m>vt5??@N9DJ`9Lc;*axz$5{W=PvRJR;(3)HJi^HQ-)UZL>}w=HvJP)s z&-@QGZ!xl8PD=Ew?{@2NH1;y`{>yw4&m&xauW?@T9C3z!#QPJryTEvR@;rWSxBhP9 z9wYA&^h-PBQTn~a_NN-zKg>JhWqiyt?;Wf!ae0W_0oEcIiT8Ep2N)S=VqU>Ipug6_ zP~&i8n%A+<5sc)E0g1kd&Ppnn2l5f;+AiD8F!El7fA?F@dHR$&d7kmyC<&XqPqkbWqB-Zr!{#p;$@2@%HyD>1f#WT( z|K^&{HkORX8B@60iUmg2f%WD2_34A=>>H}_gZO&U6%#So2a*pkV}8#@u@79ak9w+E zpRQJrU#SPvyox%##@Zc$5ZArj68po`BEeEu-3e{ zvBCJLk-T79{{JsluG+Zz!lo zU2kEOk#lQ`IpZfE;5Ye&_sfm8hjV_OYdv}EX!8z6@(%l<#(KuT$ow%Q@ju=CZsSDb 
z%|_<=Lv!Afh`*<;r=3c3^8DrIbw=X*9`kpSdd79B^}w|VM)v7UbLRcd#18rag7Fq( zC*vw3^J7|uEoCVR#6Rl$5dME++CKiQaC-xdoC`;qHyW8A*3Ie)@*Vq&^|``l|Wa9(IRYumcr#W%QbL9Uj*S~4xyu8Gm z^`B|p%gDLGJUwXr^+sR>f{{2m+#J8|GiN?|jQN7UVS)J~Ur_&L@r5t@Kw?iEat<+1 zC)(~qWYgk5YOAq85i%FldM0)2#^oQSiyW! zPrD1OAfBl|$imS^_8ay1&pv6U!pM8WIp)MmVpgV{3n;F#Fw|HSueAHJ?RbuL!0!pJ zXTI5od9``PV-{$KbBH{9u0)!{*Yq>+gf)YV-Hm4(iIdglDdbi#Z@gz? zms#kCBEjAXXC4!EAw9IP=8IKyu^H>p-kkO7Xb!So*pWYoXV(8@KR3>}#>lzGIl{cc ziMJz@etB;oe?IDZ;-RNG<0D?!SLEN7wp(O8!bm=1KM%D2Oyl#${f(Q9#0~qer}gg{ zhZq^}67vm4_UUeOfcSgIik`*;jVBt(x6HrQ<-ZTgI=<@))`2|9dvKb^X=j*)tBuST z`xQIp37#VB$-Hh)3^pe`rMFtqVw{{Bm=n(+@j$$GvK@Njo_Qk9d7f@~4y2yv;q*g$ zquf9*V>cu7y3?HeLH^_X+wFSh{Sos$M$T*2b+q;LOZ&9b&-HzbFB&;dX_xbFqwD7w zcNu}JEZktc#z>s9FPQK7t{-e%VkD0-KXuk$YMf~7Z6wcf4&Grs=h3O=)ke-U@&NNV z&GqD=dh-?|e#|xp$OGgn#?{jXO-APD1s*isW+V@slGO9wOgaceQ{WdivoUM^8WG8{&-hs~Cb9=bH1r zb+kGA%CzDh3$GZ-7x;tS3JW_>7$+H*C;AH#ev3KzH>s|`fAs7#IOiLDUwGKQ43c`{ zkago+%x%Xx$GXk1p8ApI4;W83t}|X|tTQs7Z<&)ndG2cKi68oV)OzxCtN9Ve;YQ+` zJkGe^c0K!*bz)!das6?|rAFdxg!$#hLySX>eT~DCdd6*a#ViX$P#78iWb=cK#CKv| zk;W{pvF+YBZZcxm$9%AHl@VZE@SO-o_U~BpGmRUK^lMtd zy29CCyr(Rsq5xSxs@aF|hY$*oJlnP^BkQ1N7-(L z@e<=MBkMq3A&*|{dh+y2^LvdgM&gij5vL`6N$JzrgjwjqGFc$1~Qm5Br#t zSC*M`AE2Gvt)FTH8Z2})o=1g|yvTkbzhHKn1;#Vo{32tz9^{dMt{-E3(s-?LzLEDB z;&{CEJ&lYDzq?!ip7AXs^GCj+AL2Vp?AnGp*#;KPacP2 zA2=_mKggW-grs_(6o1KYaPl5`5k$}W(?R-NXIEek5~s8WcCbDB;#2e9M&f6!`358V z-L(Ao)Ge>mfx;CkYJ zw0U3S6-L&ZcsbYlH;m*t{DZr?;&ls08p+S>gZr(2)VR%_Y8+-}#8F~Z+8pJCi;Bwt-<&Url7e4~+lH_H4L;|3$~%lut#J^TG|bIzGn z<}VwU8G$_(*gyE2=E)k{?chPAM46KAwSHt-FeA#SC}*I&F1rs?At-+#OW#K*fWl8)*o$r z%s9<>my!3u_2%Te>&;nD_91b(&h?C!_sG+&?`LEl*qu{RuzE zxt=(~Kh6R2*2lIZUg?MRoag#x<7VT2M)C~n{G9ciQ^U;r7*8~IGQMb>WPIG%&dB~F z@3EdMT;J8$%lNjDxHPRe*#dqXXZ}0}#xyRdzS_dAM&2VI#K6dWCh7{F$3NyTwPSu$ zJ$Zq3OY7l#5R6?4kom=&dBnb>E5IUsXWQ|7hvd2TNj>eb{;n~<;~*~Yay|OP%!!+G%?BCJ zG!mDb0~cF=pOJk^T-|B?-NsjpCm2680*voAD>hJJB+sDUieNm?$i5$&)TflVO7-+J zAgxF^>yxM}jg=q_~FDe`8#t1-Hp#0S=VFD zxp%y4-pxo}-fTYFNITb>Pc$;$#3yn0l4$xko-fmE*Vj10Nc#=u?CXcjS+C{h zU5xnur1=!%d?W9pjIY-ETZ~T_$^YyxtIOYW7-q#RBk$GNEwb=13M2dAVDn{0#+jIZ zhQ4FXNF(#J@iUg6uc);8EhFo`^YhI=qrD-~j50sS$b2$C?9Yy5-iM|0wo_L=viTz{pJdD>!rjgfr?zs~x5j3bPzjpS3}eyR008W{)evEI+S zeywr0k@pqGd4%=smq*O!8}}FwGy>0AU>(Mov(Jvi(AX#8&<+R*o{@0YBT*Or|6;6X zYR~><9q1oD_N+JW4s zDiR!Oo<7e!Cgz3oprq#Qgzt;aOblugzAs9AV7IUSS%1!R{3o8!gTw`KkNqfj$atrl z^M1C){9YsX&ynV=%W>wf8&5RyzW2U4^K`TMIwQ}|FejebPsAnrWPjVyKk>=_B;K#G z9sA)WbNnWMZMVLU@igN$<2%N+#-2vb4dQLI^}tOQX!lJjjDw8qr^=+B^M!GNoalEX97;UGi8J)@_BJ3MLB`Gcv9Fo`Wq$5H<0j*BV^3qf zainpLk$Gi*5HDxDp1AK}US-^4yu(;BPBOCoP3FWc+II7c3F zL*tC(!Q0JijgyQ!jret^d4D7CH4mGQG;TA#Y9!ux@4)|eTu)x5KjQlW*OTu#4~AP$ z+!OD#kADZ-j&}K6<#g+rPwpqYUoemC8|H`kyUWir@APAJ`R8gqt+?Mvzo(iHF|uz> zD=xFZdB}L!PqSRldOW~`#+AmIiJpCddO5-u#9fIU`!rD(QsS9;XMdvSLE;-Eo|y;Y z`Eoxn(8#_d&-SH3BXP#M^4@)g>xoC^_gw2)=cVS1i~P#D&wP_7887pGv7cwXTFi-S z;$pt_(~OK4e`i?Fetp@zkCFYg$eew`JUwmwGsZ)W#ODF#lZ?y{{SpVqxPH0u6C-fE zg;~t2kq2|4f%meYW!}=$U^~Sf1-v&;rE$1%iIH&=hvcUfuID+z_9sV#Hti8EyS~BXQ5?EF(dIk*8DCb;~-DWwSI{%9lT&2{Ee zjI5(+8CuFx6mBwKWvoi-zbM6@FKVCKElm~XX&f^TnByn{ zKHwmE?L;d$ms>E{2X;;B-!#8G(HGHCNdVw_Cr~_^^?4oxDpPquoPox86t|W_@a{C+{%t%p2qQz;>+fd*+-s z#N$!c-)m$Zc+aLE&XLn?H`@rjfMBE@)UR21-iTk+r|-F+QeotMe2#eloDUe zOIpwTqKRLG!a6RjaIr9YvSs&t; zeXtvYec-@-)Kg6yGGFAY3vFA@tisNX}kGqV^1UTO+Uo>2-mZouAKn~cQIbo24X3C4Gg7aNBeiQggS>>K*I-+JPYean7Z=lW{n&Bolw zxEUAYVIR%49phzu?^sX&mzg&i&o-t|lPCt6?=;?*)bER)V*~QqgZp|Qsh^Q>@@i-6 z$&={mhdfJuZD)IsID5eQt;So7#2x#(kM(Pfqm0aNPjljSm3eO?`*Ml-L?ii={lomQ zU!J!eaq+e}`Q5Yvd*Y5bU_E(X-RK91!$x!BXskK$^tky$#^;QjBgE-Q>u)l$56Qpm 
z_l2&XX9Sq%3`K%(nJ+OC|Cq7PgHa^-m^u3=sjhg@irHy}IdN7{UH-l|=S6BqzD)JR z5qa}{8(e3+#JIna_e0iavGq?GiEsAZOzRIdE;Z5*d4Y94%Jt-R; zIltzZHyCOEE_2!$Y(C$}IuRe7YlpdhhLL?i|5lgbV@n;%nkZahew%SjQV)I2!ZsAf z7mU0oFz&>x;=ojq@C6A^Df64^;nNU|M?2(Ekhr2->{HJ3AocW5 zJ0R_o&lwN(7r6n}AATAQ8)q2LFcN1^n6p3X6T3?DH;mhj>@(t@{MO|9ZpKZ z4f_AqT)8h%Mm{C&yC%I3r!?+2{IKwG?SBwtN8uQRSTGJieH zmm67!9p=?W_5=18SkF3fF5G85>-3=cB;!k)_-WEKh|-w^{n@?=KYK}8?Q0GYa|{gn$r*asK5317#Zgv^To#W{d2JO zcNR(wg$JhC4;`FZkENAu1`>Ot(_w8JnuB+pZi zJ@s(vNBKF{m3XbQp7q^o&U1&D6X#EuFEf(ot~DnwzGJ@0c(xI{8Ro?4Sab3^`G7bY zB3s z{V?C0FZe-TU>>I>?b8lOUS&K7yMp;=d`I&@g0sv|N%TV#&iL>L+|PEzH*pB3p17wS z5Ig3D=h+W%>gjis`z5asPy1WH(a3sEGAG|(Z@$vVbEle<56Lsztv}E>*+?EFZZERF z!T6-{DdX)%;%~Hh599sD?nd%B`;qxP%=Modi5vXIfA;w>+f^DFFY!-%_{Dq42-~y& zX`gsv9aq?ndFFkEeDi?o-!rCAvVyq%#QY^A^O8OXPum?xh4IjYLy4m^ta#t}qLFcN z4xefLEaUmcSB&_7xcM^U5ym%-n~m2Q&ofRpE;bS`>&@9OubIy`PD?oB>S_J6i5>n~ zUD0fTcvx#r|7V-0ei08hxgL1i0`uBxzQB05k#mFmV0Fb+7CIQolgtzP^!=CTLG~;8 z6n(9QJ}3(C1M8~_>c?A89I-x(Ynkg886P!rPINOr%E-Qd+PtTc{Y{?OZv9TF``awqa zC2_yjdghJz++aOWLMTA?0r_QIfnAaM6jw|wGT4XygruH)ikf(0f9&T9?3gF);TKu) zxN)VCcw^ljww`$+e=#oh2XTS^LO=Je@j2sT#<51m&$?Y}{gFoU2Knv<>&dsoHRt4~ zuD{N>$vE6N*?5bQxT`Tg*SN*F+<3lmwh_NgD_G|ntYDrym@}`~Gj2HN7=90S#U&Ud z*gN6O6Y5703h=2!zcAsP8;QD1i3=3uUHBJ8;*j%?cx1jfUoLio#Q6i}2OEid^7(Y@ z+5emu6Ramr_LyH_ywbSC$bKhocu%1J=WX|nahmZ=Bl+}gbM_Ca=x$7(KiPWbgXdqiK>y7DGuG2S_PpQ0@33%yk^R2W{1Xg~HI9CttE(*}t7^hd#Aym$dW$HFloAIfhXjE=o0K?a`=J zv1*4Zf@;->imKXFh@FU$phi$aC@KV{q4um%Ge$L<@NfBYT|OLn;{9~JzjMxWKleE2 z+)opU1?W6zf3YuI#NGghL+jUlbv@40zT>@Zzp-zam*#T~<35G8(EjE<>UT#oe%G@Y z{S)qkuY&yqdH}o$?Z-ROo}b@y_IdN8f;jWI7;T;-)9e?1=eRHWC-LJ8EYktK2h0o4 z!!XzDKDb`}es7%D5NKUzVt*3+sY=&rAAiS#)|LC!341TN8QNDW(a)jr=8^WE_?N(Z zcpSRV_WjY=v!HqE`EJ3!9Qyq=HrTCi`))OU*YEy6!@eD^f{oC;cb_t_XG8mM1KRhD zL%Xl8_cM0iZ=Ck|_*X3E0h;gC==;!q)QYZwZ=rcM z3H=^gR}0bB-%Ip8cocSqd!hM#9c>;tzscCGqkiZiun9haJK;BY5SqvKodWE0p!My3 zdOr3E!qR5(BUb}-eeRp?dVd5S2-& z6q+~I{dw$TU^a9e7tq^bA#8>#VQ(m`-+37Fpx=eY`TPbAuok+`F6eh5&UrivcK6eJ z&N_Y_;*|CcwfBtYt38QxlN`Zr9vbg?4kONUD?$%|bD-yO3_TcDLGynN+Wp^z_Wp5w zH?X^p_Sb>fr@&Lt{O^VC4>!X)xE1>Sx(IFEWT78I>njKS82TRT(ffBF{=?97-GO%h z%yaX8JpNqx5jviE?t1s*Z-Ktw_hA0}|KA?K;J&yH$Mf8+FQ4z~4>&4l=NHU>rF~t) z8-Ku`L7RuWu)A*K?3fR5jQw$AOt_K9-*zGoxaefIpvV0XRt z6VKK2bY2sQzXHuC{f=w>oh8n`>ACnG$9JB;i9ZDULh*$kM>N21uojy4YeQTk`dqN@ zL7PwZKlj&g&nv$FmI3ew92)G(d;a(j+y2KN-R9wL9F@>|bv?7OyH4}M`_gmDBd!8| zhufg%WxV_E{a{|1H|DwhtPSHlgT>JJljwZ-9C~lr5A44k@Voxi=&!I2+Rx0V4D2QF zGOU8W*E+=f&%bLn?;X$mD&+y|z~>|DG0cPo(DRA$NxBI~%>ShLXkUsUD>_ot9Yb6u ZIvuuyp1<_`eB5!`hv%X_KIi+i{{hsO0d)WX literal 0 HcmV?d00001 diff --git a/uniter_model/data/test_data/input6.txt b/uniter_model/data/test_data/input6.txt new file mode 100644 index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598 GIT binary patch literal 92849 zcmbTfcf3|b*5!YR0*V=PAPdHugNckFB6>|2iI*Vp93D{At4LB*%wkp~2}V@R2@r64 zI?Z&N>3RF@p6N8x(`lyDOi!oZs$1*(d|paVqd)$rwb!nFs%ls5+I8x=&%G>Ze^7?pL<@v`J$swU($Z;FFL4f%_;_rKe(czVtziaG;hkh*%KDEqn%y#?W$x6Oljdga?;PJv%F^QbEfW^Zys2eD)}b})IADAyDHqI`mzT0mW7?1J zAU1PuDeF9@$%4U7&t1iMnpte&Y|6NmFfgZGC-XO>J#`T~l?|rERQJ<}b{; zj=AWtcJTPlYP&QkpWaf+4w<3x%;-EL>9DFsU1x{R=sKp0Iyx+L)MZ6>_?UJp zvL35ugtm?-Wk)t=M=hK@Bh16mt=TaHW_0}g#%dcH8>+Ko)mRwyaph4TA4h#cZ1BWV z*0VY5HRJO~RNYuzuMze3i28U$eZz?Q#RmG9vXh##0Upu7*6ieeGNOj+#_SZ2=+yFv zPKzTtJvMkoDXVPGstQI_Q(x0q*_c&(L^U2!Z5UBqY@oiBH8f|99#K}orR<94Y`jNwWovd-;fNX=o3v_IdqmfiM|5o*(RH!G38ideb2h1HL=6q~ zP4(Gik7$ZVG&PKy>_!23vqNyXsj)c zTH}arjSb#b%9b=|w}1YKlx`Xu6^u(gqGcY@@-U(mv4NGP?2hK_PLF6+Yj#)Rh#IOJ z8f&w=J)(QcBU&9tbZ>0%zEXC7bM`>dh(Zpm%pUZJ9`cAD4kLObHn66YJ=&Z-<`F&K 
literal 0
HcmV?d00001

diff --git a/uniter_model/data/test_data/input7.txt b/uniter_model/data/test_data/input7.txt
new file mode 100644
index 0000000000000000000000000000000000000000..619f9be9298fe03e442596509debf6de5e019598
GIT binary patch
literal 92849
[92849-byte base85 binary blob omitted]
diff --git a/uniter_model/data/vcr.py b/uniter_model/data/vcr.py
new file mode 100644
index 0000000..99e8986
--- /dev/null
+++ b/uniter_model/data/vcr.py
@@ -0,0 +1,725 @@
+"""
+VCR dataset
+"""
+import json
+import copy
+import random
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from toolz.sandbox import unzip
+from torch.utils.data import Dataset
+
+from .data import DetectFeatLmdb, TxtLmdb, random_word
+from .mrc import DetectFeatDir_for_mrc
+
+
+class ImageTextDataset(Dataset):
+    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
+                 max_txt_len=120, task="qa"):
+        self.txt_lens = []
+        self.ids = []
+        self.task = task
+        for id_, len_ in json.load(open(f'{db_dir}/id2len_{task}.json')
+                                   ).items():
+            if max_txt_len == -1 or len_ <= max_txt_len:
+                self.txt_lens.append(len_)
+                self.ids.append(id_)
+
+        self.db = TxtLmdb(db_dir, readonly=True)
+        self.img_dir = img_dir
+        self.img_dir_gt = img_dir_gt
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, i):
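+        # fetch the text example and, when available, the ground-truth and
+        # detected image feature dumps it points to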
+        id_ = self.ids[i]
+        txt_dump = self.db[id_]
+        img_dump_gt, img_dump = None, None
+        img_fname_gt, img_fname = txt_dump['img_fname']
+        if self.img_dir_gt:
+            img_dump_gt = self.img_dir_gt[img_fname_gt]
+        if self.img_dir:
+            img_dump = self.img_dir[img_fname]
+        return img_dump_gt, img_dump, txt_dump
+
+
+class DetectFeatBertTokDataset(ImageTextDataset):
+    def __init__(self, db_dir, img_dir_gt=None, img_dir=None,
+                 max_txt_len=60, task="qa"):
+        assert not (img_dir_gt is None and img_dir is None),\
+            "image_dir_gt and img_dir cannot all be None"
+        assert task == "qa" or task == "qar",\
+            "VCR only allow two tasks: qa or qar"
+        assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb)
+        assert img_dir is None or isinstance(img_dir, DetectFeatLmdb)
+
+        super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task)
+        txt2img = json.load(open(f'{db_dir}/txt2img.json'))
+        if self.img_dir and self.img_dir_gt:
+            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]] +
+                         self.img_dir.name2nbb[txt2img[id_][1]]
+                         for tl, id_ in zip(self.txt_lens, self.ids)]
+        elif self.img_dir:
+            self.lens = [tl+self.img_dir.name2nbb[txt2img[id_][1]]
+                         for tl, id_ in zip(self.txt_lens, self.ids)]
+        else:
+            self.lens = [tl+self.img_dir_gt.name2nbb[txt2img[id_][0]]
+                         for tl, id_ in zip(self.txt_lens, self.ids)]
+
+        meta = json.load(open(f'{db_dir}/meta.json', 'r'))
+        self.cls_ = meta['CLS']
+        self.sep = meta['SEP']
+        self.mask = meta['MASK']
+        self.v_range = meta['v_range']
+
+    def _get_img_feat(self, fname_gt, fname):
+        if self.img_dir and self.img_dir_gt:
+            img_feat_gt, bb_gt = self.img_dir_gt[fname_gt]
+            img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1)
+
+            img_feat, bb = self.img_dir[fname]
+            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
+
+            img_feat = torch.cat([img_feat_gt, img_feat], dim=0)
+            img_bb = torch.cat([img_bb_gt, img_bb], dim=0)
+            num_bb = img_feat.size(0)
+        elif self.img_dir:
+            img_feat, bb = self.img_dir[fname]
+            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
+            num_bb = img_feat.size(0)
+        elif self.img_dir_gt:
+            img_feat, bb = self.img_dir_gt[fname_gt]
+            img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
+            num_bb = img_feat.size(0)
+        return img_feat, img_bb, num_bb
+
+
+class VcrDataset(DetectFeatBertTokDataset):
+    def __init__(self, mask_prob, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.mask_prob = mask_prob
+        del self.txt_lens
+
+    def _get_input_ids(self, txt_dump):
+        # text input
+        input_ids_q = txt_dump['input_ids']
+        type_ids_q = [0]*len(input_ids_q)
+        input_ids_as = txt_dump['input_ids_as']
+        if self.task == "qar":
+            input_ids_rs = txt_dump['input_ids_rs']
+            answer_label = txt_dump['qa_target']
+            assert answer_label >= 0, "answer_label < 0"
+            input_ids_gt_a = [self.sep] + copy.deepcopy(
+                input_ids_as[answer_label])
+            type_ids_gt_a = [2] * len(input_ids_gt_a)
+            type_ids_q += type_ids_gt_a
+            input_ids_q += input_ids_gt_a
+            input_ids_for_choices = input_ids_rs
+        else:
+            input_ids_for_choices = input_ids_as
+        return input_ids_q, input_ids_for_choices, type_ids_q
+
+    def __getitem__(self, i):
+        id_ = self.ids[i]
+        txt_dump = self.db[id_]
+        img_feat, img_pos_feat, num_bb = self._get_img_feat(
+            txt_dump['img_fname'][0], txt_dump['img_fname'][1])
+        object_targets = txt_dump["object_ids"]
+        input_ids_q, input_ids_for_choices, type_ids_q = self._get_input_ids(
+            txt_dump)
+        label = txt_dump['%s_target' % (self.task)]
+
+        choice_num_bbs, choice_img_feats, choice_img_pos_feats = (
+            [], [], [])
+        (choice_txt_lens, choice_input_ids,
choice_txt_type_ids, + choice_attn_masks, choice_position_ids, choice_targets) = ( + [], [], [], [], [], []) + choice_obj_targets, choice_img_masks = ([], []) + + for index, input_ids_a in enumerate(input_ids_for_choices): + if index == label: + target = torch.tensor([1]).long() + else: + target = torch.tensor([0]).long() + input_ids = [self.cls_] + copy.deepcopy(input_ids_q) +\ + [self.sep] + input_ids_a + [self.sep] + type_id_for_choice = 3 if type_ids_q[-1] == 2 else 2 + txt_type_ids = [0] + type_ids_q + [type_id_for_choice]*( + len(input_ids_a)+2) + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + txt_type_ids = torch.tensor(txt_type_ids) + + choice_txt_lens.append(len(input_ids)) + choice_input_ids.append(input_ids) + choice_attn_masks.append(attn_masks) + choice_position_ids.append(position_ids) + choice_txt_type_ids.append(txt_type_ids) + + choice_num_bbs.append(num_bb) + choice_img_feats.append(img_feat) + choice_img_pos_feats.append(img_pos_feat) + choice_targets.append(target) + + # mask image input features + num_gt_bb = len(object_targets) + num_det_bb = num_bb - num_gt_bb + # only mask gt features + img_mask = [random.random() < self.mask_prob + for _ in range(num_gt_bb)] + if not any(img_mask): + # at least mask 1 + img_mask[0] = True + img_mask += [False]*num_det_bb + img_mask = torch.tensor(img_mask) + object_targets += [0]*num_det_bb + obj_targets = torch.tensor(object_targets) + + choice_obj_targets.append(obj_targets) + choice_img_masks.append(img_mask) + + return (choice_input_ids, choice_position_ids, choice_txt_lens, + choice_txt_type_ids, + choice_img_feats, choice_img_pos_feats, choice_num_bbs, + choice_attn_masks, choice_targets, choice_obj_targets, + choice_img_masks) + + +def vcr_collate(inputs): + (input_ids, position_ids, txt_lens, txt_type_ids, img_feats, + img_pos_feats, num_bbs, attn_masks, targets, + obj_targets, img_masks) = map(list, unzip(inputs)) + + all_num_bbs, all_img_feats, all_img_pos_feats = ( + [], [], []) + all_txt_lens, all_input_ids, all_attn_masks,\ + all_position_ids, all_txt_type_ids = ( + [], [], [], [], []) + all_obj_targets = [] + all_targets = [] + # all_targets = targets + all_img_masks = [] + for i in range(len(num_bbs)): + all_input_ids += input_ids[i] + all_position_ids += position_ids[i] + all_txt_lens += txt_lens[i] + all_txt_type_ids += txt_type_ids[i] + all_img_feats += img_feats[i] + all_img_pos_feats += img_pos_feats[i] + all_num_bbs += num_bbs[i] + all_attn_masks += attn_masks[i] + all_obj_targets += obj_targets[i] + all_img_masks += img_masks[i] + all_targets += targets[i] + + all_input_ids = pad_sequence(all_input_ids, + batch_first=True, padding_value=0) + all_position_ids = pad_sequence(all_position_ids, + batch_first=True, padding_value=0) + all_txt_type_ids = pad_sequence(all_txt_type_ids, + batch_first=True, padding_value=0) + all_attn_masks = pad_sequence(all_attn_masks, + batch_first=True, padding_value=0) + all_img_masks = pad_sequence(all_img_masks, + batch_first=True, padding_value=0) + # all_targets = pad_sequence(all_targets, + # batch_first=True, padding_value=0) + all_targets = torch.stack(all_targets, dim=0) + + batch_size = len(all_img_feats) + num_bb = max(all_num_bbs) + feat_dim = all_img_feats[0].size(1) + pos_dim = all_img_pos_feats[0].size(1) + all_img_feat = torch.zeros(batch_size, num_bb, feat_dim) + all_img_pos_feat = 
torch.zeros(batch_size, num_bb, pos_dim) + all_obj_target = torch.zeros(batch_size, num_bb) + for i, (im, pos, label) in enumerate(zip( + all_img_feats, all_img_pos_feats, all_obj_targets)): + len_ = im.size(0) + all_img_feat.data[i, :len_, :] = im.data + all_img_pos_feat.data[i, :len_, :] = pos.data + all_obj_target.data[i, :len_] = label.data + + obj_targets = all_obj_target[all_img_masks].contiguous() + return (all_input_ids, all_position_ids, all_txt_lens, + all_txt_type_ids, + all_img_feat, all_img_pos_feat, all_num_bbs, + all_attn_masks, all_targets, obj_targets, all_img_masks) + + +class VcrEvalDataset(DetectFeatBertTokDataset): + def __init__(self, split, *args, **kwargs): + super().__init__(*args, **kwargs) + self.split = split + del self.txt_lens + + def _get_input_ids(self, txt_dump): + # text input + input_ids_for_choices = [] + type_ids_for_choices = [] + input_ids_q = txt_dump['input_ids'] + type_ids_q = [0]*len(input_ids_q) + input_ids_as = txt_dump['input_ids_as'] + input_ids_rs = txt_dump['input_ids_rs'] + for index, input_ids_a in enumerate(input_ids_as): + curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\ + [self.sep] + input_ids_a + [self.sep] + curr_type_ids_qa = [0] + type_ids_q + [2]*( + len(input_ids_a)+2) + input_ids_for_choices.append(curr_input_ids_qa) + type_ids_for_choices.append(curr_type_ids_qa) + for index, input_ids_a in enumerate(input_ids_as): + curr_input_ids_qa = [self.cls_] + copy.deepcopy(input_ids_q) +\ + [self.sep] + input_ids_a + [self.sep] + curr_type_ids_qa = [0] + type_ids_q + [2]*( + len(input_ids_a)+1) + if (self.split == "val" and index == txt_dump["qa_target"]) or\ + self.split == "test": + for input_ids_r in input_ids_rs: + curr_input_ids_qar = copy.deepcopy(curr_input_ids_qa) +\ + input_ids_r + [self.sep] + curr_type_ids_qar = copy.deepcopy(curr_type_ids_qa) +\ + [3]*(len(input_ids_r)+2) + input_ids_for_choices.append(curr_input_ids_qar) + type_ids_for_choices.append(curr_type_ids_qar) + return input_ids_for_choices, type_ids_for_choices + + def __getitem__(self, i): + qid = self.ids[i] + id_ = self.ids[i] + txt_dump = self.db[id_] + img_feat, img_pos_feat, num_bb = self._get_img_feat( + txt_dump['img_fname'][0], txt_dump['img_fname'][1]) + object_targets = txt_dump["object_ids"] + input_ids_for_choices, type_ids_for_choices = self._get_input_ids( + txt_dump) + qa_target = torch.tensor([int(txt_dump["qa_target"])]) + qar_target = torch.tensor([int(txt_dump["qar_target"])]) + + choice_num_bbs, choice_img_feats, choice_img_pos_feats = ( + [], [], []) + (choice_txt_lens, choice_input_ids, choice_attn_masks, + choice_position_ids, choice_txt_type_ids) = ( + [], [], [], [], []) + choice_obj_targets = [] + for index, input_ids in enumerate(input_ids_for_choices): + txt_type_ids = type_ids_for_choices[index] + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + txt_type_ids = torch.tensor(txt_type_ids) + + choice_txt_lens.append(len(input_ids)) + choice_input_ids.append(input_ids) + choice_attn_masks.append(attn_masks) + choice_position_ids.append(position_ids) + choice_txt_type_ids.append(txt_type_ids) + + choice_num_bbs.append(num_bb) + choice_img_feats.append(img_feat) + choice_img_pos_feats.append(img_pos_feat) + + obj_targets = torch.tensor(object_targets) + choice_obj_targets.append(obj_targets) + + return (qid, choice_input_ids, 
choice_position_ids, choice_txt_lens, + choice_txt_type_ids, + choice_img_feats, choice_img_pos_feats, choice_num_bbs, + choice_attn_masks, qa_target, qar_target, choice_obj_targets) + + +def vcr_eval_collate(inputs): + (qids, input_ids, position_ids, txt_lens, txt_type_ids, + img_feats, img_pos_feats, + num_bbs, attn_masks, qa_targets, qar_targets, + obj_targets) = map(list, unzip(inputs)) + + all_num_bbs, all_img_feats, all_img_pos_feats = ( + [], [], []) + all_txt_lens, all_input_ids, all_attn_masks, all_position_ids,\ + all_txt_type_ids = ( + [], [], [], [], []) + # all_qa_targets = qa_targets + # all_qar_targets = qar_targets + all_obj_targets = [] + for i in range(len(num_bbs)): + all_input_ids += input_ids[i] + all_position_ids += position_ids[i] + all_txt_lens += txt_lens[i] + all_img_feats += img_feats[i] + all_img_pos_feats += img_pos_feats[i] + all_num_bbs += num_bbs[i] + all_attn_masks += attn_masks[i] + all_txt_type_ids += txt_type_ids[i] + all_obj_targets += obj_targets[i] + + all_input_ids = pad_sequence(all_input_ids, + batch_first=True, padding_value=0) + all_position_ids = pad_sequence(all_position_ids, + batch_first=True, padding_value=0) + all_txt_type_ids = pad_sequence(all_txt_type_ids, + batch_first=True, padding_value=0) + all_attn_masks = pad_sequence(all_attn_masks, + batch_first=True, padding_value=0) + all_obj_targets = pad_sequence(all_obj_targets, + batch_first=True, padding_value=0) + all_qa_targets = torch.stack(qa_targets, dim=0) + all_qar_targets = torch.stack(qar_targets, dim=0) + + batch_size = len(all_img_feats) + num_bb = max(all_num_bbs) + feat_dim = all_img_feats[0].size(1) + pos_dim = all_img_pos_feats[0].size(1) + all_img_feat = torch.zeros(batch_size, num_bb, feat_dim) + all_img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + for i, (im, pos) in enumerate(zip( + all_img_feats, all_img_pos_feats)): + len_ = im.size(0) + all_img_feat.data[i, :len_, :] = im.data + all_img_pos_feat.data[i, :len_, :] = pos.data + + return (qids, all_input_ids, all_position_ids, all_txt_lens, + all_txt_type_ids, + all_img_feat, all_img_pos_feat, all_num_bbs, + all_attn_masks, all_qa_targets, all_qar_targets, all_obj_targets) + + +class MlmDatasetForVCR(DetectFeatBertTokDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + del self.txt_lens + + def _get_input_ids(self, txt_dump, mask=True): + # text input + input_ids_q = txt_dump['input_ids'] + type_ids_q = [0]*len(input_ids_q) + if mask: + input_ids_q, txt_labels_q = random_word( + input_ids_q, self.v_range, self.mask) + else: + txt_labels_q = input_ids_q + + answer_label = txt_dump['qa_target'] + assert answer_label >= 0, "answer_label < 0" + + input_ids_a = txt_dump['input_ids_as'][answer_label] + type_ids_a = [2]*len(input_ids_a) + if mask: + input_ids_a, txt_labels_a = random_word( + input_ids_a, self.v_range, self.mask) + else: + txt_labels_a = input_ids_a + + input_ids = input_ids_q + [self.sep] + input_ids_a + type_ids = type_ids_q + [0] + type_ids_a + txt_labels = txt_labels_q + [-1] + txt_labels_a + + if self.task == "qar": + rationale_label = txt_dump['qar_target'] + assert rationale_label >= 0, "rationale_label < 0" + + input_ids_r = txt_dump['input_ids_rs'][rationale_label] + type_ids_r = [3]*len(input_ids_r) + if mask: + input_ids_r, txt_labels_r = random_word( + input_ids_r, self.v_range, self.mask) + else: + txt_labels_r = input_ids_r + + input_ids += [self.sep] + input_ids_r + type_ids += [2] + type_ids_r + txt_labels += [-1] + txt_labels_r + return input_ids, 
type_ids, txt_labels + + def __getitem__(self, i): + id_ = self.ids[i] + txt_dump = self.db[id_] + img_feat, img_pos_feat, num_bb = self._get_img_feat( + txt_dump['img_fname'][0], txt_dump['img_fname'][1]) + + # txt inputs + input_ids, type_ids, txt_labels = self._get_input_ids(txt_dump) + input_ids = [self.cls_] + input_ids + [self.sep] + txt_labels = [-1] + txt_labels + [-1] + type_ids = [type_ids[0]] + type_ids + [type_ids[-1]] + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + txt_labels = torch.tensor(txt_labels) + type_ids = torch.tensor(type_ids) + + return (input_ids, position_ids, type_ids, img_feat, img_pos_feat, + attn_masks, txt_labels) + + +def mlm_collate_for_vcr(inputs): + (input_ids, position_ids, type_ids, img_feats, img_pos_feats, attn_masks, + txt_labels) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, + batch_first=True, padding_value=0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) + + batch_size = len(img_feats) + num_bb = max(num_bbs) + feat_dim = img_feats[0].size(1) + pos_dim = img_pos_feats[0].size(1) + img_feat = torch.zeros(batch_size, num_bb, feat_dim) + img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): + len_ = im.size(0) + img_feat.data[i, :len_, :] = im.data + img_pos_feat.data[i, :len_, :] = pos.data + + return (input_ids, position_ids, type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, txt_labels) + + +class MrmDatasetForVCR(DetectFeatBertTokDataset): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + del self.txt_lens + + def _get_input_ids(self, txt_dump, mask=True): + # text input + input_ids_q = txt_dump['input_ids'] + type_ids_q = [0]*len(input_ids_q) + + answer_label = txt_dump['qa_target'] + assert answer_label >= 0, "answer_label < 0" + + input_ids_a = txt_dump['input_ids_as'][answer_label] + type_ids_a = [2]*len(input_ids_a) + + input_ids = input_ids_q + [self.sep] + input_ids_a + type_ids = type_ids_q + [0] + type_ids_a + + if self.task == "qar": + rationale_label = txt_dump['qar_target'] + assert rationale_label >= 0, "rationale_label < 0" + + input_ids_r = txt_dump['input_ids_rs'][rationale_label] + type_ids_r = [3]*len(input_ids_r) + + input_ids += [self.sep] + input_ids_r + type_ids += [2] + type_ids_r + return input_ids, type_ids + + def __getitem__(self, i): + id_ = self.ids[i] + txt_dump = self.db[id_] + img_feat, img_pos_feat, num_bb = self._get_img_feat( + txt_dump['img_fname'][0], txt_dump['img_fname'][1]) + + # image input features + img_mask = [random.random() < self.mask_prob for _ in range(num_bb)] + if not any(img_mask): + # at least mask 1 + img_mask[0] = True + img_mask = torch.tensor(img_mask) + + # text input + input_ids, type_ids = self._get_input_ids(txt_dump) + input_ids = [self.cls_] + input_ids + [self.sep] + type_ids = [type_ids[0]] + type_ids + [type_ids[-1]] + attn_masks = [1] * len(input_ids) + position_ids = 
list(range(len(input_ids))) + attn_masks += [1] * num_bb + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + type_ids = torch.tensor(type_ids) + + return (input_ids, position_ids, type_ids, img_feat, img_pos_feat, + attn_masks, img_mask) + + +def mrm_collate_for_vcr(inputs): + (input_ids, position_ids, type_ids, img_feats, img_pos_feats, + attn_masks, img_masks) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, + batch_first=True, padding_value=0) + type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + + batch_size = len(img_feats) + num_bb = max(num_bbs) + feat_dim = img_feats[0].size(1) + pos_dim = img_pos_feats[0].size(1) + img_feat = torch.zeros(batch_size, num_bb, feat_dim) + img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): + len_ = im.size(0) + img_feat.data[i, :len_, :] = im.data + img_pos_feat.data[i, :len_, :] = pos.data + + return (input_ids, position_ids, type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, img_masks) + + +class DetectFeatBertTokDataset_for_mrc_vcr(DetectFeatBertTokDataset): + def __init__(self, db_dir, img_dir_gt=None, img_dir=None, + max_txt_len=60, task="qa"): + assert not (img_dir_gt is None and img_dir is None),\ + "image_dir_gt and img_dir cannot all be None" + assert task == "qa" or task == "qar",\ + "VCR only allow two tasks: qa or qar" + assert img_dir_gt is None or isinstance(img_dir_gt, DetectFeatLmdb) + assert img_dir is None or isinstance(img_dir, DetectFeatLmdb) + super().__init__(db_dir, img_dir_gt, img_dir, max_txt_len, task) + if self.img_dir: + self.img_dir = DetectFeatDir_for_mrc(img_dir) + if self.img_dir_gt: + self.img_dir_gt = DetectFeatDir_for_mrc(img_dir_gt) + + def _get_img_feat(self, fname_gt, fname): + if self.img_dir and self.img_dir_gt: + img_feat_gt, bb_gt,\ + img_soft_labels_gt = self.img_dir_gt[fname_gt] + img_bb_gt = torch.cat([bb_gt, bb_gt[:, 4:5]*bb_gt[:, 5:]], dim=-1) + + img_feat, bb, img_soft_labels = self.img_dir[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + + img_feat = torch.cat([img_feat_gt, img_feat], dim=0) + img_bb = torch.cat([img_bb_gt, img_bb], dim=0) + img_soft_labels = torch.cat( + [img_soft_labels_gt, img_soft_labels], dim=0) + num_bb = img_feat.size(0) + elif self.img_dir: + img_feat, bb, img_soft_labels = self.img_dir[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + elif self.img_dir_gt: + img_feat, bb, img_soft_labels = self.img_dir_gt[fname_gt] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + return img_feat, img_bb, img_soft_labels, num_bb + + +class MrcDatasetForVCR(DetectFeatBertTokDataset_for_mrc_vcr): + def __init__(self, mask_prob, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mask_prob = mask_prob + del self.txt_lens + + def _get_input_ids(self, txt_dump, mask=True): + # text input + input_ids_q = txt_dump['input_ids'] + type_ids_q = [0]*len(input_ids_q) + + answer_label = txt_dump['qa_target'] + assert answer_label >= 0, "answer_label < 0" + + input_ids_a = 
txt_dump['input_ids_as'][answer_label] + type_ids_a = [2]*len(input_ids_a) + + input_ids = input_ids_q + [self.sep] + input_ids_a + type_ids = type_ids_q + [0] + type_ids_a + + if self.task == "qar": + rationale_label = txt_dump['qar_target'] + assert rationale_label >= 0, "rationale_label < 0" + + input_ids_r = txt_dump['input_ids_rs'][rationale_label] + type_ids_r = [3]*len(input_ids_r) + + input_ids += [self.sep] + input_ids_r + type_ids += [2] + type_ids_r + return input_ids, type_ids + + def __getitem__(self, i): + id_ = self.ids[i] + txt_dump = self.db[id_] + img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat( + txt_dump['img_fname'][0], txt_dump['img_fname'][1]) + + # image input features + img_mask = [random.random() < self.mask_prob for _ in range(num_bb)] + if not any(img_mask): + # at least mask 1 + img_mask[0] = True + img_mask = torch.tensor(img_mask) + + # text input + input_ids, type_ids = self._get_input_ids(txt_dump) + input_ids = [self.cls_] + input_ids + [self.sep] + type_ids = [type_ids[0]] + type_ids + [type_ids[-1]] + attn_masks = [1] * len(input_ids) + position_ids = list(range(len(input_ids))) + attn_masks += [1] * num_bb + input_ids = torch.tensor(input_ids) + position_ids = torch.tensor(position_ids) + attn_masks = torch.tensor(attn_masks) + type_ids = torch.tensor(type_ids) + + return (input_ids, position_ids, type_ids, img_feat, img_pos_feat, + img_soft_labels, attn_masks, img_mask) + + +def mrc_collate_for_vcr(inputs): + (input_ids, position_ids, type_ids, img_feats, img_pos_feats, + img_soft_labels, attn_masks, img_masks + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + num_bbs = [f.size(0) for f in img_feats] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, + batch_first=True, padding_value=0) + type_ids = pad_sequence(type_ids, batch_first=True, padding_value=0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) + + batch_size = len(img_feats) + num_bb = max(num_bbs) + feat_dim = img_feats[0].size(1) + soft_label_dim = img_soft_labels[0].size(1) + pos_dim = img_pos_feats[0].size(1) + img_feat = torch.zeros(batch_size, num_bb, feat_dim) + img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) + img_soft_label = torch.zeros(batch_size, num_bb, soft_label_dim) + for i, (im, pos, label) in enumerate(zip(img_feats, + img_pos_feats, + img_soft_labels)): + len_ = im.size(0) + img_feat.data[i, :len_, :] = im.data + img_pos_feat.data[i, :len_, :] = pos.data + img_soft_label.data[i, :len_, :] = label.data + + img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label) + label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view( + -1, soft_label_dim) + return (input_ids, position_ids, type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, (img_masks, label_targets)) diff --git a/uniter_model/data/ve.py b/uniter_model/data/ve.py new file mode 100644 index 0000000..ddd11e7 --- /dev/null +++ b/uniter_model/data/ve.py @@ -0,0 +1,19 @@ +""" +Visual entailment dataset +# NOTE: basically reuse VQA dataset +""" +from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate + + +class VeDataset(VqaDataset): + def __init__(self, *args, **kwargs): + super().__init__(3, *args, **kwargs) + + +class VeEvalDataset(VqaEvalDataset): + def __init__(self, *args, **kwargs): + super().__init__(3, *args, **kwargs) + + 
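+# VE is cast as 3-way classification (presumably entailment / neutral /
+# contradiction, following SNLI-VE), so the VQA collate functions are reused
+# unchanged below.  A rough usage sketch (paths are placeholders; constructor
+# arguments mirror the calls in eval_vqa.py):
+#     txt_db = TxtTokLmdb('/db/ve_train.db', 60)
+#     img_db = DetectFeatLmdb('/img/flickr30k', 0.2, 100, 10, 36, False)
+#     loader = DataLoader(VeDataset(txt_db, img_db), batch_size=8,
+#                         collate_fn=ve_collate)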
+ve_collate = vqa_collate +ve_eval_collate = vqa_eval_collate diff --git a/uniter_model/data/vqa.py b/uniter_model/data/vqa.py new file mode 100644 index 0000000..b3422c4 --- /dev/null +++ b/uniter_model/data/vqa.py @@ -0,0 +1,124 @@ +""" +VQA dataset +""" +import torch +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip + +from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index + + +def _get_vqa_target(example, num_answers): + target = torch.zeros(num_answers) + labels = example['target']['labels'] + scores = example['target']['scores'] + if labels and scores: + target.scatter_(0, torch.tensor(labels), torch.tensor(scores)) + return target + + +class VqaDataset(DetectFeatTxtTokDataset): + """ NOTE: This handels distributed inside """ + def __init__(self, num_answers, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_answers = num_answers + + def __getitem__(self, i): + example = super().__getitem__(i) + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + target = _get_vqa_target(example, self.num_answers) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return input_ids, img_feat, img_pos_feat, attn_masks, target + + +def vqa_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.stack(targets, dim=0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch + + +class VqaEvalDataset(VqaDataset): + def __getitem__(self, i): + qid = self.ids[i] + example = DetectFeatTxtTokDataset.__getitem__(self, i) + img_feat, img_pos_feat, num_bb = self._get_img_feat( + example['img_fname']) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + if 'target' in example: + target = _get_vqa_target(example, self.num_answers) + else: + target = None + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + return qid, input_ids, img_feat, img_pos_feat, attn_masks, target + + +def vqa_eval_collate(inputs): + (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + if targets[0] is None: + targets = None + else: + targets = torch.stack(targets, dim=0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + bs, max_tl = input_ids.size() + 
out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'qids': qids, + 'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch diff --git a/uniter_model/eval/itm.py b/uniter_model/eval/itm.py new file mode 100644 index 0000000..b8ea00d --- /dev/null +++ b/uniter_model/eval/itm.py @@ -0,0 +1,53 @@ +""" Image Text Retrieval evaluation helper """ +import torch + + +@torch.no_grad() +def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts): + # image retrieval + img2j = {i: j for j, i in enumerate(img_ids)} + _, rank_txt = score_matrix.topk(10, dim=1) + gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]] + for txt_id in txt_ids], + ).to(rank_txt.device + ).unsqueeze(1).expand_as(rank_txt) + rank = (rank_txt == gt_img_j).nonzero() + if rank.numel(): + ir_r1 = (rank < 1).sum().item() / len(txt_ids) + ir_r5 = (rank < 5).sum().item() / len(txt_ids) + ir_r10 = (rank < 10).sum().item() / len(txt_ids) + else: + ir_r1, ir_r5, ir_r10 = 0, 0, 0 + + # text retrieval + txt2i = {t: i for i, t in enumerate(txt_ids)} + _, rank_img = score_matrix.topk(10, dim=0) + tr_r1, tr_r5, tr_r10 = 0, 0, 0 + for j, img_id in enumerate(img_ids): + gt_is = [txt2i[t] for t in img2txts[img_id]] + ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is] + rank = min([10] + [r.item() for r in ranks if r.numel()]) + if rank < 1: + tr_r1 += 1 + if rank < 5: + tr_r5 += 1 + if rank < 10: + tr_r10 += 1 + tr_r1 /= len(img_ids) + tr_r5 /= len(img_ids) + tr_r10 /= len(img_ids) + + tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3 + ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3 + r_mean = (tr_mean + ir_mean) / 2 + + eval_log = {'txt_r1': tr_r1, + 'txt_r5': tr_r5, + 'txt_r10': tr_r10, + 'txt_r_mean': tr_mean, + 'img_r1': ir_r1, + 'img_r5': ir_r5, + 'img_r10': ir_r10, + 'img_r_mean': ir_mean, + 'r_mean': r_mean} + return eval_log diff --git a/uniter_model/eval/nlvr2.py b/uniter_model/eval/nlvr2.py new file mode 100644 index 0000000..e6c7f57 --- /dev/null +++ b/uniter_model/eval/nlvr2.py @@ -0,0 +1,62 @@ +""" +copied from official NLVR2 github +python eval/nlvr2.py +""" +import json +import sys + +# Load the predictions file. Assume it is a CSV. +predictions = { } +for line in open(sys.argv[1]).readlines(): + if line: + splits = line.strip().split(",") + # We assume identifiers are in the format "split-####-#-#.png". + identifier = splits[0] + prediction = splits[1] + predictions[identifier] = prediction + +# Load the labeled examples. +labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line] + +# If not, identify the ones that are missing, and exit. +total_num = len(labeled_examples) +if len(predictions) < total_num: + print("Some predictions are missing!") + print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num)) + + for example in labeled_examples: + lookup = example["identifier"] + if not lookup in predictions: + print("Missing prediction for item " + str(lookup)) + exit() + +# Get the precision by iterating through the examples and checking the value +# that was predicted. +# Also update the "consistency" dictionary that keeps track of whether all +# predictions for a given sentence were correct. +num_correct = 0. 
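+# consistency_dict maps each sentence (its identifier with the image-pair
+# index blanked out) to True only while every prediction for that sentence
+# is correct.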
+consistency_dict = { } + +for example in labeled_examples: + anon_label = example["identifier"].split("-") + anon_label[2] = '' + anon_label = '-'.join(anon_label) + if not anon_label in consistency_dict: + consistency_dict[anon_label] = True + lookup = example["identifier"] + prediction = predictions[lookup] + if prediction.lower() == example["label"].lower(): + num_correct += 1. + else: + consistency_dict[anon_label] = False + +# Calculate consistency. +num_consistent = 0. +unique_sentence = len(consistency_dict) +for identifier, consistent in consistency_dict.items(): + if consistent: + num_consistent += 1 + +# Report values. +print("accuracy=" + str(num_correct / total_num)) +print("consistency=" + str(num_consistent / unique_sentence)) diff --git a/uniter_model/eval_re.py b/uniter_model/eval_re.py new file mode 100644 index 0000000..68c42c8 --- /dev/null +++ b/uniter_model/eval_re.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT for Referring Expression Comprehension Evaluation""" +import argparse +import json +import os +from os.path import exists +from time import time + +import torch +from torch.utils.data import DataLoader + +# to be deprecated once upgraded to 1.2 +# from torch.utils.data.distributed import DistributedSampler +from data import DistributedSampler + +from apex import amp +from horovod import torch as hvd + +from data import (ReImageFeatDir, ReferringExpressionEvalDataset, + re_eval_collate, PrefetchLoader) +from model import BertForReferringExpressionComprehension + +from utils.logger import LOGGER +from utils.distributed import all_gather_list +from utils.misc import Struct + + +def main(opts): + + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info(f"device: {device}, n_gpu: {n_gpu}, rank: {hvd.rank()}, " + f"16-bits training: {opts.fp16}") + + hps_file = f'{opts.output_dir}/log/hps.json' + model_opts = json.load(open(hps_file)) + if 'mlp' not in model_opts: + model_opts['mlp'] = 1 + model_opts = Struct(model_opts) + + # Prepro txt_dbs + txt_dbs = opts.txt_db.split(':') + + # Prepro model + if exists(opts.checkpoint): + ckpt_file = torch.load(opts.checkpoint) + else: + ckpt_file = f'{opts.output_dir}/ckpt/model_epoch_{opts.checkpoint}.pt' + checkpoint = torch.load(ckpt_file) + bert_model = json.load(open(f'{txt_dbs[0]}/meta.json'))['bert'] + model = BertForReferringExpressionComprehension.from_pretrained( + bert_model, img_dim=2048, mlp=model_opts.mlp, + state_dict=checkpoint + ) + if model_opts.cut_bert != -1: + # cut some layers of BERT + model.bert.encoder.layer = torch.nn.ModuleList( + model.bert.encoder.layer[:opts.cut_bert] + ) + model.to(device) + + if opts.fp16: + model = amp.initialize(model, 
enabled=opts.fp16, opt_level='O2') + + # load DBs and image dirs + eval_img_dir = ReImageFeatDir(opts.img_dir) + for txt_db in txt_dbs: + print(f'Evaluating {txt_db}') + eval_dataset = ReferringExpressionEvalDataset(txt_db, eval_img_dir, + max_txt_len=-1) + eval_sampler = DistributedSampler(eval_dataset, num_replicas=n_gpu, + rank=rank, shuffle=False) + eval_dataloader = DataLoader(eval_dataset, + sampler=eval_sampler, + batch_size=opts.batch_size, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=re_eval_collate) + eval_dataloader = PrefetchLoader(eval_dataloader) + + # evaluate + val_log, results = validate(model, eval_dataloader) + + # save + result_dir = f'{opts.output_dir}/results_test' + if not exists(result_dir) and rank == 0: + os.makedirs(result_dir) + + # dummy sync + _ = None + all_gather_list(_) + db_split = txt_db.split('/')[-1].split('-')[0] # refcoco+_val_large + img_dir = opts.img_dir.split('/')[-1] # visual_grounding_coco_gt + if n_gpu > 1: + with open(f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_{db_split}_on_{img_dir}' + f'_rank{rank}.json', + 'w') as f: + json.dump(results, f) + # dummy sync + _ = None + all_gather_list(_) + + # join results + if n_gpu > 1: + results = [] + for rank in range(n_gpu): + results.extend(json.load(open( + f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_{db_split}_on_{img_dir}' + f'_rank{rank}.json'))) + if rank == 0: + with open(f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_{db_split}_on_{img_dir}' + f'_all.json', 'w') as f: + json.dump(results, f) + + # print + print(f'{opts.output_dir}/results_test') + + +@torch.no_grad() +def validate(model, val_dataloader): + LOGGER.info(f"start running evaluation.") + model.eval() + tot_score = 0 + n_ex = 0 + st = time() + predictions = [] + for i, batch in enumerate(val_dataloader): + # inputs + (*batch_inputs, tgt_box_list, obj_boxes_list, sent_ids) = batch + + # scores (n, max_num_bb) + scores = model(*batch_inputs, targets=None, compute_loss=False) + ixs = torch.argmax(scores, 1).cpu().detach().numpy() # (n, ) + + # pred_boxes + for ix, obj_boxes, tgt_box, sent_id in \ + zip(ixs, obj_boxes_list, tgt_box_list, sent_ids): + pred_box = obj_boxes[ix] + predictions.append({'sent_id': sent_id, + 'pred_box': pred_box.tolist(), + 'tgt_box': tgt_box.tolist()}) + if (val_dataloader.loader.dataset.computeIoU(pred_box, tgt_box) + > .5): + tot_score += 1 + n_ex += 1 + + tot_time = time()-st + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + val_acc = tot_score / n_ex + val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation ({n_ex} sents) finished in " + f"{int(tot_time)} seconds" + f", accuracy: {val_acc*100:.2f}%") + + # summarizae + results = {'acc': val_acc, 'predictions': predictions} + return val_log, results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + # Requited parameters + parser.add_argument('--txt_db', + default=None, type=str, + help="The input train corpus. 
(LMDB)") + parser.add_argument('--img_dir', + default=None, type=str, + help="The input train images.") + parser.add_argument('--checkpoint', + default=None, type=str, + help="pretrained model (can take 'google-bert')") + parser.add_argument('--batch_size', + default=256, type=int, + help="number of sentences per batch") + parser.add_argument('--output_dir', + default=None, type=str, + help="The output directory where the model contains " + "the model checkpoints will be written.") + + # Device parameters + parser.add_argument('--fp16', + action='store_true', + help="whether to use fp-16 float precision instead of " + "32 bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + args = parser.parse_args() + + main(args) diff --git a/uniter_model/eval_vcr.py b/uniter_model/eval_vcr.py new file mode 100644 index 0000000..14e543e --- /dev/null +++ b/uniter_model/eval_vcr.py @@ -0,0 +1,268 @@ +"""run inference of VCR for submission""" +import argparse +import json +import os +from os.path import exists +from time import time + +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from tqdm import tqdm +from apex import amp +from horovod import torch as hvd + +from data import (DetectFeatLmdb, VcrEvalDataset, vcr_eval_collate, + PrefetchLoader) +from torch.utils.data.distributed import DistributedSampler +from model import BertForVisualCommonsenseReasoning + +from utils.logger import LOGGER +from utils.distributed import all_gather_list +from utils.misc import NoOp, Struct +NUM_SPECIAL_TOKENS = 81 + + +def load_img_feat(dir_list, path2imgdir, opts): + dir_ = dir_list.split(";") + assert len(dir_) <= 2, "More than two img_dirs found" + img_dir_gt, img_dir = None, None + gt_dir_path, dir_path = "", "" + for d in dir_: + if "gt" in d: + gt_dir_path = d + else: + dir_path = d + if gt_dir_path != "": + img_dir_gt = path2imgdir.get(gt_dir_path, None) + if img_dir_gt is None: + img_dir_gt = DetectFeatLmdb(gt_dir_path, -1, + opts.max_bb, opts.min_bb, 100, + opts.compressed_db) + path2imgdir[gt_dir_path] = img_dir_gt + if dir_path != "": + img_dir = path2imgdir.get(dir_path, None) + if img_dir is None: + img_dir = DetectFeatLmdb(dir_path, opts.conf_th, + opts.max_bb, opts.min_bb, opts.num_bb, + opts.compressed_db) + path2imgdir[dir_path] = img_dir + return img_dir, img_dir_gt, path2imgdir + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + hps_file = f'{opts.output_dir}/log/hps.json' + model_opts = Struct(json.load(open(hps_file))) + + path2imgdir = {} + # load DBs and image dirs + val_img_dir, val_img_dir_gt, path2imgdir = load_img_feat( + opts.img_dir, path2imgdir, model_opts) + eval_dataset = VcrEvalDataset("test", opts.txt_db, + val_img_dir_gt, val_img_dir, + max_txt_len=-1) + + # Prepare model + bert_model = json.load(open(f'{opts.txt_db}/meta.json'))['bert'] + model = BertForVisualCommonsenseReasoning.from_pretrained( + bert_model, img_dim=2048, obj_cls=False, + state_dict={}) + model.init_type_embedding() + model.init_word_embedding(NUM_SPECIAL_TOKENS) + if exists(opts.checkpoint): + ckpt_file = opts.checkpoint + else: + ckpt_file = 
f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt' + checkpoint = torch.load(ckpt_file) + state_dict = checkpoint.get('model_state', checkpoint) + matched_state_dict = {} + unexpected_keys = set() + missing_keys = set() + for name, param in model.named_parameters(): + missing_keys.add(name) + for key, data in state_dict.items(): + if key in missing_keys: + matched_state_dict[key] = data + missing_keys.remove(key) + else: + unexpected_keys.add(key) + print("Unexpected_keys:", list(unexpected_keys)) + print("Missing_keys:", list(missing_keys)) + model.load_state_dict(matched_state_dict, strict=False) + if model_opts.cut_bert != -1: + # cut some layers of BERT + model.bert.encoder.layer = torch.nn.ModuleList( + model.bert.encoder.layer[:model_opts.cut_bert]) + model.to(device) + if opts.fp16: + model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') + + sampler = DistributedSampler( + eval_dataset, num_replicas=n_gpu, rank=rank) + eval_dataloader = DataLoader(eval_dataset, + batch_size=opts.batch_size, + sampler=sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=vcr_eval_collate) + eval_dataloader = PrefetchLoader(eval_dataloader) + + val_log, results = evaluate(model, eval_dataloader) + result_dir = f'{opts.output_dir}/results_{opts.split}' + if not exists(result_dir) and rank == 0: + os.makedirs(result_dir) + # dummy sync + _ = None + all_gather_list(_) + if n_gpu > 1: + with open(f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_rank{rank}.json', + 'w') as f: + json.dump(results, f) + # dummy sync + _ = None + all_gather_list(_) + # join results + if n_gpu > 1: + results = [] + for rank in range(n_gpu): + results.extend(json.load(open( + f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_rank{rank}.json'))) + if rank == 0: + with open(f'{opts.output_dir}/results_test/' + f'results_{opts.checkpoint}_all.json', 'w') as f: + json.dump(results, f) + + +def compute_accuracies(out_qa, labels_qa, out_qar, labels_qar): + outputs_qa = out_qa.max(dim=-1)[1] + outputs_qar = out_qar.max(dim=-1)[1] + matched_qa = outputs_qa.squeeze() == labels_qa.squeeze() + matched_qar = outputs_qar.squeeze() == labels_qar.squeeze() + matched_joined = matched_qa & matched_qar + n_correct_qa = matched_qa.sum().item() + n_correct_qar = matched_qar.sum().item() + n_correct_joined = matched_joined.sum().item() + return n_correct_qa, n_correct_qar, n_correct_joined + + +@torch.no_grad() +def evaluate(model, val_loader): + if hvd.rank() == 0: + val_pbar = tqdm(total=len(val_loader)) + else: + val_pbar = NoOp() + LOGGER.info(f"start running evaluation ...") + model.eval() + val_qa_loss, val_qar_loss = 0, 0 + tot_qa_score, tot_qar_score, tot_score = 0, 0, 0 + n_ex = 0 + st = time() + results = {} + for i, batch in enumerate(val_loader): + qids, *inputs, qa_targets, qar_targets, _ = batch + scores = model( + *inputs, targets=None, compute_loss=False) + scores = scores.view(len(qids), -1) + if torch.max(qa_targets) > -1: + vcr_qa_loss = F.cross_entropy( + scores[:, :4], qa_targets.squeeze(-1), reduction="sum") + if scores.shape[1] > 8: + qar_scores = [] + for batch_id in range(scores.shape[0]): + answer_ind = qa_targets[batch_id].item() + qar_index = [4+answer_ind*4+i + for i in range(4)] + qar_scores.append(scores[batch_id, qar_index]) + qar_scores = torch.stack(qar_scores, dim=0) + else: + qar_scores = scores[:, 4:] + vcr_qar_loss = F.cross_entropy( + qar_scores, qar_targets.squeeze(-1), reduction="sum") + val_qa_loss += vcr_qa_loss.item() + val_qar_loss 
+= vcr_qar_loss.item() + + curr_qa_score, curr_qar_score, curr_score = compute_accuracies( + scores[:, :4], qa_targets, qar_scores, qar_targets) + tot_qar_score += curr_qar_score + tot_qa_score += curr_qa_score + tot_score += curr_score + for qid, score in zip(qids, scores): + results[qid] = score.cpu().tolist() + n_ex += len(qids) + val_pbar.update(1) + val_qa_loss = sum(all_gather_list(val_qa_loss)) + val_qar_loss = sum(all_gather_list(val_qar_loss)) + tot_qa_score = sum(all_gather_list(tot_qa_score)) + tot_qar_score = sum(all_gather_list(tot_qar_score)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_qa_loss /= n_ex + val_qar_loss /= n_ex + val_qa_acc = tot_qa_score / n_ex + val_qar_acc = tot_qar_score / n_ex + val_acc = tot_score / n_ex + val_log = {f'valid/vcr_qa_loss': val_qa_loss, + f'valid/vcr_qar_loss': val_qar_loss, + f'valid/acc_qa': val_qa_acc, + f'valid/acc_qar': val_qar_acc, + f'valid/acc': val_acc, + f'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score_qa: {val_qa_acc*100:.2f} " + f"score_qar: {val_qar_acc*100:.2f} " + f"score: {val_acc*100:.2f} ") + return val_log, results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--txt_db", + default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--img_dir", + default=None, type=str, + help="The input train images.") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--split", + default="test", type=str, + help="The input split") + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--batch_size", + default=10, type=int, + help="number of tokens in a batch") + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + # device parameters + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + args = parser.parse_args() + + main(args) diff --git a/uniter_model/eval_vqa.py b/uniter_model/eval_vqa.py new file mode 100644 index 0000000..980bde4 --- /dev/null +++ b/uniter_model/eval_vqa.py @@ -0,0 +1,180 @@ +"""run inference of VQA for submission""" +import argparse +import json +import os +from os.path import exists +from time import time + +import torch +from torch.utils.data import DataLoader + +from apex import amp +from horovod import torch as hvd +import numpy as np +from cytoolz import concat + +from data import (TokenBucketSampler, PrefetchLoader, + DetectFeatLmdb, TxtTokLmdb, VqaEvalDataset, vqa_eval_collate) +from model import UniterForVisualQuestionAnswering + +from utils.logger import LOGGER +from utils.distributed import all_gather_list +from utils.misc import Struct +from utils.const import BUCKET_SIZE, IMG_DIM + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + hps_file = 
f'{opts.output_dir}/log/hps.json' + model_opts = Struct(json.load(open(hps_file))) + + # train_examples = None + ans2label_file = f'{opts.output_dir}/ckpt/ans2label.json' + ans2label = json.load(open(ans2label_file)) + label2ans = {label: ans for ans, label in ans2label.items()} + + # load DBs and image dirs + eval_img_db = DetectFeatLmdb(opts.img_db, + model_opts.conf_th, model_opts.max_bb, + model_opts.min_bb, model_opts.num_bb, + opts.compressed_db) + eval_txt_db = TxtTokLmdb(opts.txt_db, -1) + eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db) + + # Prepare model + if exists(opts.checkpoint): + ckpt_file = opts.checkpoint + else: + ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt' + checkpoint = torch.load(ckpt_file) + model = UniterForVisualQuestionAnswering.from_pretrained( + f'{opts.output_dir}/log/model.json', checkpoint, + img_dim=IMG_DIM, num_answer=len(ans2label)) + model.to(device) + model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') + + sampler = TokenBucketSampler(eval_dataset.lens, bucket_size=BUCKET_SIZE, + batch_size=opts.batch_size, droplast=False) + eval_dataloader = DataLoader(eval_dataset, + batch_sampler=sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=vqa_eval_collate) + eval_dataloader = PrefetchLoader(eval_dataloader) + + val_log, results, logits = evaluate(model, eval_dataloader, label2ans, + opts.save_logits) + result_dir = f'{opts.output_dir}/results_test' + if not exists(result_dir) and rank == 0: + os.makedirs(result_dir) + + all_results = list(concat(all_gather_list(results))) + if opts.save_logits: + all_logits = {} + for id2logit in all_gather_list(logits): + all_logits.update(id2logit) + if hvd.rank() == 0: + with open(f'{result_dir}/' + f'results_{opts.checkpoint}_all.json', 'w') as f: + json.dump(all_results, f) + if opts.save_logits: + np.savez(f'{result_dir}/logits_{opts.checkpoint}_all.npz', + **all_logits) + + +@torch.no_grad() +def evaluate(model, eval_loader, label2ans, save_logits=False): + LOGGER.info("start running evaluation...") + model.eval() + n_ex = 0 + st = time() + results = [] + logits = {} + for i, batch in enumerate(eval_loader): + qids = batch['qids'] + scores = model(batch, compute_loss=False) + answers = [label2ans[i] + for i in scores.max(dim=-1, keepdim=False + )[1].cpu().tolist()] + for qid, answer in zip(qids, answers): + results.append({'answer': answer, 'question_id': int(qid)}) + if save_logits: + scores = scores.cpu() + for i, qid in enumerate(qids): + logits[qid] = scores[i].half().numpy() + if i % 100 == 0 and hvd.rank() == 0: + n_results = len(results) + n_results *= hvd.size() # an approximation to avoid hangs + LOGGER.info(f'{n_results}/{len(eval_loader.dataset)} ' + 'answers predicted') + n_ex += len(qids) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_log = {'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"evaluation finished in {int(tot_time)} seconds " + f"at {int(n_ex/tot_time)} examples per second") + return val_log, results, logits + + +def compute_score_with_logits(logits, labels): + logits = torch.max(logits, 1)[1] # argmax + one_hots = torch.zeros(*labels.size(), device=labels.device) + one_hots.scatter_(1, logits.view(-1, 1), 1) + scores = (one_hots * labels) + return scores + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--txt_db", + default=None, type=str, + help="The input train corpus. 
(LMDB)") + parser.add_argument("--img_db", + default=None, type=str, + help="The input train images.") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--batch_size", + default=8192, type=int, + help="number of tokens in a batch") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + parser.add_argument("--save_logits", action='store_true', + help="Whether to save logits (for making ensemble)") + + # Prepro parameters + + # device parameters + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + args = parser.parse_args() + + # options safe guard + # TODO + + main(args) diff --git a/uniter_model/experiments/ablation_refcoco+.sh b/uniter_model/experiments/ablation_refcoco+.sh new file mode 100644 index 0000000..17be7b6 --- /dev/null +++ b/uniter_model/experiments/ablation_refcoco+.sh @@ -0,0 +1,71 @@ +# Supports ablation study of the follows: +# 1) scratch +# 2) bert +# 3) mrfr +# 4) mlm +# 5) itm +# 6) mlm_itm +# 7) mlm_mrfr_itm +# 8) mlm_mrc_itm +# 9) mlm_mrckl_itm +# 10) mlm_mrfr_mrc_itm +# 11) mlm_mrfr_mrckl_itm +# 12) mlm_mrfr_mrckl_itm_jrm +# 13) mlm_mrfr_mrckl_itm_jrm+ + +ablation_pretrained_model=$1 + +case $ablation_pretrained_model in + scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+) + echo running $ablation_pretrained_model ...;; + *) + echo "$ablation_pretrained_model" not supported.; + exit 1; +esac + +if [ "$ablation_pretrained_model" == "mrfr" ]; then + cut_bert=1 +else + cut_bert=-1 +fi + +case $ablation_pretrained_model in + scratch) + cut_bert=1; + checkpoint="scratch";; + bert) + cut_bert=1; + checkpoint="google-bert";; + mrfr) + cut_bert=1; + checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;; + *) + cut_bert=-1; + checkpoint=/pretrain/ablation/"${ablation_pretrained_model}".pt;; +esac + +horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert ${cut_bert} \ + --output_dir /storage/refcoco+/ablation_"${ablation_pretrained_model}" \ + --max_txt_len 60 \ + --train_batch_size 128 \ + --val_batch_size 128 \ + --learning_rate 8e-5 \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps 24000 \ + --warmup_steps 1500 \ + --gradient_accumulation_steps 1 \ + --seed 24 \ + --mlp 1 \ + --fp16 diff --git a/uniter_model/experiments/eval_ablation_refcoco+.sh b/uniter_model/experiments/eval_ablation_refcoco+.sh new file mode 100644 index 0000000..09aa1a4 --- /dev/null +++ b/uniter_model/experiments/eval_ablation_refcoco+.sh @@ -0,0 +1,38 @@ +# Supports ablation study of the follows: +# 1) scratch +# 2) bert +# 3) mrfr +# 4) mlm +# 5) itm +# 6) mlm_itm +# 7) mlm_mrfr_itm +# 8) mlm_mrc_itm +# 9) mlm_mrckl_itm +# 10) 
mlm_mrfr_mrc_itm +# 11) mlm_mrfr_mrckl_itm +# 12) mlm_mrfr_mrckl_itm_jrm +# 13) mlm_mrfr_mrckl_itm_jrm+ + +ablation_pretrained_model=$1 + +case $ablation_pretrained_model in + scratch|bert|mrfr|mlm|itm|mlm_itm|mlm_mrfr_itm|mlm_mrc_itm|mlm_mrckl_itm|mlm_mrfr_mrc_itm|mlm_mrfr_mrckl_itm|mlm_mrfr_mrckl_itm_jrm|mlm_mrfr_mrckl_itm_jrm+) + echo running $ablation_pretrained_model ...;; + *) + echo "$ablation_pretrained_model" not supported.; + exit 1; +esac + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \ + --checkpoint best + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/ablation_${ablation_pretrained_model} \ + --checkpoint best diff --git a/uniter_model/experiments/eval_refcoco+.sh b/uniter_model/experiments/eval_refcoco+.sh new file mode 100644 index 0000000..db0b488 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+.sh @@ -0,0 +1,13 @@ +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ + --checkpoint 26 + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ + --checkpoint 26 diff --git a/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_cc.sh b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_cc.sh new file mode 100644 index 0000000..a3c8599 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_cc.sh @@ -0,0 +1,15 @@ +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/bert-base_mlm_itm_mrfr_pretrain_cc-refcoco+lr8e-5 \ + --checkpoint 51 + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-base_mlm_itm_mrfr_pretrain_cc-refcoco+lr8e-5 \ + --checkpoint 51 + + diff --git a/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..fe4bc06 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,17 @@ +# This training is done by experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh +# det val: 74.52, testA: 79.76, testB: 64.43 +# gd val: 82.74, testA: 85.21, testB: 77.52 +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir 
/storage/refcoco+/bert-base_allweak_alldata-refcoco+_w1000_s10000_l5e-5_b128_g1_m1 \ + --checkpoint best + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-base_allweak_alldata-refcoco+_w1000_s10000_l5e-5_b128_g1_m1 \ + --checkpoint best + diff --git a/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh new file mode 100644 index 0000000..b4ad1db --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh @@ -0,0 +1,33 @@ +# det val: 72.49, testA: 79.20, testB: 63.22; gt val: 80.23, testA: 83.71, testB: 75.62 +# output_name=bert-base_allweak_ccsbu-refcoco+_w1000_s10000_l5e-5_b128_g1_m1 + +# det val: 72.63, testA: 79.18, testB: 63.37; gt val: 80.37, testA: 83.30, testB: 75.31 +# output_name=bert-base_allweak_ccsbu-refcoco+_w1200_s12000_l8e-5_b128_g1_m1 + + +# # det val: 72.63, testA: 78.83, testB: 63.76; gt val: 80.45, testA: 83.58, testB: 75.82 +# GPU=$1 +# output_name=bert-base_allweak_ccsbu-refcoco+_w1000_s10000_l6e-5_b128_g1_m1 + +# det val: 72.90, testA: 79.01, testB: 63.53; gt val: 80.82, testA: 83.65, testB: 76.46 +GPU=$1 +output_name=bert-base_allweak_ccsbu-refcoco+_w1000_s12000_l6e-5_b64_g1_m2 + +# print +echo $output_name + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/${output_name} \ + --checkpoint best + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/${output_name} \ + --checkpoint best + + diff --git a/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh new file mode 100644 index 0000000..92d9754 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh @@ -0,0 +1,36 @@ +# This training is done by experiments/ablation_refcoco+.sh mlm_mrfr_mrckl_itm +# det val: 74.52, testA: 79.76, testB: 64.43. 
+# gd val: 82.74, testA: 85.21, testB: 77.52; +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/ablation_mlm_mrfr_mrckl_itm \ + --checkpoint best + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/ablation_mlm_mrfr_mrckl_itm \ + --checkpoint best + + +# # This training is done by ./experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh +# # det val: 74.60, testA: 80.42, testB: 64.98 +# # gd val: 82.84, testA: 85.10, testB: 77.95 +# horovodrun -np 1 -H localhost:1 \ +# python eval_re.py \ +# --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ +# --img_dir /img/visual_grounding_coco_gt \ +# --output_dir /storage/refcoco+/bert-base_mlm_itm_mrfr_mrckl_itm_pretrain_cocovg-refcoco+_12k_mlp1 \ +# --checkpoint best + +# horovodrun -np 1 -H localhost:1 \ +# python eval_re.py \ +# --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ +# --img_dir /img/visual_grounding_det_coco \ +# --output_dir /storage/refcoco+/bert-base_mlm_itm_mrfr_mrckl_itm_pretrain_cocovg-refcoco+_12k_mlp1 \ +# --checkpoint best + + diff --git a/uniter_model/experiments/eval_refcoco+_conceptual.sh b/uniter_model/experiments/eval_refcoco+_conceptual.sh new file mode 100644 index 0000000..f574cd5 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_conceptual.sh @@ -0,0 +1,15 @@ +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5 \ + --checkpoint 51 + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_base-cased.db:/db/refcoco+_testA_base-cased.db:/db/refcoco+_testB_base-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/conceptual-bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5 \ + --checkpoint 51 + + diff --git a/uniter_model/experiments/eval_refcoco+_large.sh b/uniter_model/experiments/eval_refcoco+_large.sh new file mode 100644 index 0000000..018ba28 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_large.sh @@ -0,0 +1,14 @@ +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/bert-large_mlm+itm+mrfr+mrckl_pretrain_alldata-refcoco+_lr8e-5_2mlp \ + --checkpoint 40 + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-large_mlm+itm+mrfr+mrckl_pretrain_alldata-refcoco+_lr8e-5_2mlp \ + --checkpoint 40 + diff --git a/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_cocovg.sh b/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_cocovg.sh new file mode 100644 index 0000000..2053233 --- /dev/null +++ 
b/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_cocovg.sh @@ -0,0 +1,13 @@ +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/bert-large_mlm+itm+mrfr_pretrain_cocovg-refcoco+_lr8e-5_b64g4 \ + --checkpoint 52 + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/bert-large_mlm+itm+mrfr_pretrain_cocovg-refcoco+_lr8e-5_b64g4 \ + --checkpoint 52 diff --git a/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..91d4056 --- /dev/null +++ b/uniter_model/experiments/eval_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,26 @@ +# output_name=bert-large_mlm+itm+mrfr+mrckl_pretrain_alldata-refcoco+_lr8e-5_2mlp +# output_name=bert-large_allweak_alldata-refcoco+_w1000_s10000_l5e-5_b64_g2_m1 + +# det val: 74.94, testA: 81.24, testB: 65.06. +# gd val: 84.04, testA: 85.87, testB: 78.89; +# output_name=bert-large_allweak_alldata-refcoco+_w1000_s10000_l5e-5_b64_g2_m2 + +# +output_name=bert-large_allweak_alldata-refcoco+_w1000_s10000_l5e-5_b32_g2_m2 + +echo ${output_name} + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/refcoco+/${output_name} \ + --checkpoint best + +horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db /db/refcoco+_val_large-cased.db:/db/refcoco+_testA_large-cased.db:/db/refcoco+_testB_large-cased.db \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/refcoco+/${output_name} \ + --checkpoint best + diff --git a/uniter_model/experiments/eval_refer_base_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/eval_refer_base_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..2f070f2 --- /dev/null +++ b/uniter_model/experiments/eval_refer_base_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,38 @@ +# bert-base with all-tasks pre-trained on all data (COCO+VG+CC+SBU) + +REFER=$1 +GPU=$2 + +# parameters +warmup=1000 +steps=12000 # 12000 +lr=6e-5 +batch_size=64 # 64 +gradient_accumulation_steps=1 +mlp=2 + +# output name +output_name=bert-base_allweak_alldata-${REFER}_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +# Evaluate +case ${REFER} in + refcoco|refcoco+) + dbs=/db/${REFER}_val_base-cased.db:/db/${REFER}_testA_base-cased.db:/db/${REFER}_testB_base-cased.db;; + refcocog) + dbs=/db/${REFER}_val_base-cased.db:/db/${REFER}_test_base-cased.db;; +esac + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best \ No newline at end of file diff --git a/uniter_model/experiments/eval_refer_large_mlm_itm_mrfr_mrckl_all.sh 
b/uniter_model/experiments/eval_refer_large_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..562552e --- /dev/null +++ b/uniter_model/experiments/eval_refer_large_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,38 @@ +# bert-large with all-tasks pre-trained on all data (COCO+VG+CC+SBU) + +REFER=$1 +GPU=$2 + +# parameters +warmup=1000 +steps=12000 +lr=5e-5 +batch_size=32 #32 or 64 +gradient_accumulation_steps=2 +mlp=2 + +# output name +output_name=bert-large_allweak_alldata-${REFER}_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +# Evaluate +case ${REFER} in + refcoco|refcoco+) + dbs=/db/${REFER}_val_large-cased.db:/db/${REFER}_testA_large-cased.db:/db/${REFER}_testB_large-cased.db;; + refcocog) + dbs=/db/${REFER}_val_large-cased.db:/db/${REFER}_test_large-cased.db;; +esac + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best diff --git a/uniter_model/experiments/train_refcoco+.sh b/uniter_model/experiments/train_refcoco+.sh new file mode 100644 index 0000000..ad4a68a --- /dev/null +++ b/uniter_model/experiments/train_refcoco+.sh @@ -0,0 +1,30 @@ +# horovodrun -np 1 -H localhost:1 \ +# python train_re.py --config config/hps-refcoco+.json + +# horovodrun -np 2 -H localhost:2 \ +# python train_re.py --config config/hps-refcoco+.json + +horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_coco_gt \ + --checkpoint /pretrain/bert-base_weak/ckpt/model_step_420000.pt \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/bert-base_mlm+itm+mrfr_pretrain-refcoco+_lr1e-4 \ + --max_txt_len 60 \ + --train_batch_size 128 \ + --val_batch_size 128 \ + --learning_rate 1e-4 \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps 24000 \ + --warmup_steps 1500 \ + --gradient_accumulation_steps 1 \ + --seed 24 \ + --fp16 \ No newline at end of file diff --git a/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_cc.sh b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_cc.sh new file mode 100644 index 0000000..27c1eb3 --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_cc.sh @@ -0,0 +1,27 @@ +# horovodrun -np 1 -H localhost:1 \ +# python train_re.py --config config/hps-refcoco+_conceptual.json + +horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_coco_gt \ + --checkpoint /pretrain/bert-base_weak_conceptual/ckpt/model_step_200000.pt \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/bert-base_mlm_itm_mrfr_pretrain_cc-refcoco+lr8e-5 \ + --max_txt_len 60 \ + --train_batch_size 128 \ + --val_batch_size 128 \ + --learning_rate 8e-5 \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps 
24000 \ + --warmup_steps 1500 \ + --gradient_accumulation_steps 1 \ + --seed 24 \ + --fp16 diff --git a/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..4d1f9fc --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,44 @@ +# bert-base with all-tasks pre-trained on all data (COCO+VG+CC+SBU) + +GPU=$1 + +# pre-trained model +checkpoint=/pretrain/bert-base_weak_alldata_2nodes_16accum/ckpt/model_step_125000.pt + +# parameters +warmup=1000 +steps=10000 +lr=5e-5 +batch_size=128 +gradient_accumulation_steps=1 +mlp=1 + +# output name +output_name=bert-base_allweak_alldata-refcoco+_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/${output_name} \ + --max_txt_len 60 \ + --train_batch_size ${batch_size} \ + --val_batch_size 128 \ + --learning_rate ${lr} \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps ${steps} \ + --warmup_steps ${warmup} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --seed 24 \ + --mlp ${mlp} \ + --fp16 \ No newline at end of file diff --git a/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh new file mode 100644 index 0000000..a47c225 --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_ccsbu.sh @@ -0,0 +1,73 @@ +# # bert-base with all-tasks pre-trained on CC+SBU + +# checkpoint=/pretrain/bert-base_weak_conceptual_sbu_mlm_mrm_itm_mrckl_3xV100_run2/ckpt/model_step_100000.pt +# output_name=bert-base_mlm_itm_mrfr_mrckl_itm_pretrain_ccsbu-refcoco+_12k_mlp2 + +# horovodrun -np 1 -H localhost:1 \ +# python train_re.py \ +# --train_txt_db /db/refcoco+_train_base-cased.db \ +# --train_img_dir /img/visual_grounding_coco_gt \ +# --val_txt_db /db/refcoco+_val_base-cased.db \ +# --val_img_dir /img/visual_grounding_det_coco \ +# --checkpoint ${checkpoint} \ +# --cut_bert -1 \ +# --output_dir /storage/refcoco+/${output_name} \ +# --max_txt_len 60 \ +# --train_batch_size 128 \ +# --val_batch_size 128 \ +# --learning_rate 8e-5 \ +# --optim adamw \ +# --betas 0.9 0.98 \ +# --weight_decay 0.01 \ +# --dropout 0.1 \ +# --grad_norm 2.0 \ +# --decay linear \ +# --num_train_steps 12000 \ +# --warmup_steps 1500 \ +# --gradient_accumulation_steps 1 \ +# --seed 24 \ +# --mlp 2 \ +# --fp16 + +GPU=$1 + +# pre-trained model +checkpoint=/pretrain/bert-base_weak_conceptual_sbu_mlm_mrm_itm_mrckl_3xV100_run2/ckpt/model_step_100000.pt + +# parameters +warmup=1000 +steps=12000 +lr=6e-5 +batch_size=64 +gradient_accumulation_steps=1 +mlp=2 + +# output name +output_name=bert-base_allweak_ccsbu-refcoco+_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt 
\ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/${output_name} \ + --max_txt_len 60 \ + --train_batch_size ${batch_size} \ + --val_batch_size 128 \ + --learning_rate ${lr} \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps ${steps} \ + --warmup_steps ${warmup} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --seed 24 \ + --mlp ${mlp} \ + --fp16 \ No newline at end of file diff --git a/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh new file mode 100644 index 0000000..2f7f623 --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_base_mlm_itm_mrfr_mrckl_vgcoco.sh @@ -0,0 +1,64 @@ + +# bert-base with all-tasks pre-trained on COCO+VG +checkpoint=/pretrain/ablation/mlm_mrfr_mrckl_itm.pt +output_name=bert-base_mlm_itm_mrfr_mrckl_itm_pretrain_cocovg-refcoco+_12k_mlp1 + +horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_base-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/${output_name} \ + --max_txt_len 60 \ + --train_batch_size 128 \ + --val_batch_size 128 \ + --learning_rate 8e-5 \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps 12000 \ + --warmup_steps 1500 \ + --gradient_accumulation_steps 1 \ + --seed 24 \ + --mlp 1 \ + --fp16 + +######################### +# This one is even better +######################### +# ablation_pretrained_model=mlm_mrfr_mrckl_itm +# checkpoint=/pretrain/ablation/mlm_mrfr_mrckl_itm.pt; +# output_name=bert-base_mlm_itm_mrfr_mrckl_itm_pretrain_cocovg-refcoco+_step10k + +# horovodrun -np 1 -H localhost:1 \ +# python train_re.py \ +# --train_txt_db /db/refcoco+_train_base-cased.db \ +# --train_img_dir /img/visual_grounding_coco_gt \ +# --val_txt_db /db/refcoco+_val_base-cased.db \ +# --val_img_dir /img/visual_grounding_det_coco \ +# --checkpoint ${checkpoint} \ +# --cut_bert -1 \ +# --output_dir /storage/refcoco+/${output_name} \ +# --max_txt_len 60 \ +# --train_batch_size 128 \ +# --val_batch_size 128 \ +# --learning_rate 8e-5 \ +# --optim adamw \ +# --betas 0.9 0.98 \ +# --weight_decay 0.01 \ +# --dropout 0.1 \ +# --grad_norm 2.0 \ +# --decay linear \ +# --num_train_steps 10000 \ +# --warmup_steps 1500 \ +# --gradient_accumulation_steps 1 \ +# --seed 24 \ +# --mlp 1 \ +# --fp16 + diff --git a/uniter_model/experiments/train_refcoco+_conceptual_rank.sh b/uniter_model/experiments/train_refcoco+_conceptual_rank.sh new file mode 100644 index 0000000..67f82c6 --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_conceptual_rank.sh @@ -0,0 +1,2 @@ +horovodrun -np 1 -H localhost:1 \ + python train_re.py --config config/hps-refcoco+_conceptual_rank.json diff --git a/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_cocovg.sh b/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_cocovg.sh new file mode 100644 index 0000000..b1b83f1 --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_cocovg.sh @@ -0,0 +1,28 @@ +# horovodrun -np 1 -H localhost:1 \ +# python 
train_re.py --config config/hps-refcoco+_conceptual_large_weak.json + +horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_large-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_large-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint /pretrain/bert-large_weak/ckpt/model_step_50000.pt \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/conceptual-bert-large_mlm+itm+mrfr_pretrain-refcoco+_lr8e-5_b64g4 \ + --max_txt_len 60 \ + --train_batch_size 64 \ + --val_batch_size 256 \ + --learning_rate 8e-5 \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps 24000 \ + --warmup_steps 1500 \ + --gradient_accumulation_steps 4 \ + --seed 24 \ + --mlp 1 \ + --fp16 \ No newline at end of file diff --git a/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..aeeb34f --- /dev/null +++ b/uniter_model/experiments/train_refcoco+_large_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,70 @@ +# horovodrun -np 1 -H localhost:1 \ +# python train_re.py \ +# --train_txt_db /db/refcoco+_train_large-cased.db \ +# --train_img_dir /img/visual_grounding_coco_gt \ +# --val_txt_db /db/refcoco+_val_large-cased.db \ +# --val_img_dir /img/visual_grounding_det_coco \ +# --checkpoint /pretrain/bert-large_frkl_alldata.pt \ +# --cut_bert -1 \ +# --output_dir /storage/refcoco+/bert-large_mlm+itm+mrfr+mrckl_pretrain_alldata-refcoco+_lr8e-5_2mlp \ +# --max_txt_len 60 \ +# --train_batch_size 64 \ +# --val_batch_size 256 \ +# --learning_rate 8e-5 \ +# --optim adamw \ +# --betas 0.9 0.98 \ +# --weight_decay 0.01 \ +# --dropout 0.1 \ +# --grad_norm 2.0 \ +# --decay linear \ +# --num_train_steps 24000 \ +# --warmup_steps 1500 \ +# --gradient_accumulation_steps 4 \ +# --seed 24 \ +# --mlp 2 \ +# --fp16 + + +# bert-large with all-tasks pre-trained on all data (COCO+VG+CC+SBU) +GPU=$1 + +# pre-trained model +checkpoint=/pretrain/bert-large_frkl_alldata.pt + +# parameters +warmup=1000 +steps=10000 +lr=5e-5 +batch_size=32 +gradient_accumulation_steps=2 +mlp=1 + +# output name +output_name=bert-large_allweak_alldata-refcoco+_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/refcoco+_train_large-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/refcoco+_val_large-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/refcoco+/${output_name} \ + --max_txt_len 60 \ + --train_batch_size ${batch_size} \ + --val_batch_size 128 \ + --learning_rate ${lr} \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps ${steps} \ + --warmup_steps ${warmup} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --seed 24 \ + --mlp ${mlp} \ + --fp16 \ No newline at end of file diff --git a/uniter_model/experiments/train_refcoco.sh b/uniter_model/experiments/train_refcoco.sh new file mode 100644 index 0000000..d49c2e6 --- /dev/null +++ b/uniter_model/experiments/train_refcoco.sh @@ -0,0 +1,2 @@ +horovodrun -np 1 -H localhost:1 \ + python train_re.py --config config/hps-refcoco.json diff --git 
a/uniter_model/experiments/train_refer_base_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/train_refer_base_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..b3c9090 --- /dev/null +++ b/uniter_model/experiments/train_refer_base_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,67 @@ +# bert-base with all-tasks pre-trained on all data (COCO+VG+CC+SBU) +REFER=$1 +GPU=$2 + +# pre-trained model +checkpoint=/pretrain/bert-base_weak_alldata_2nodes_16accum/ckpt/model_step_125000.pt + +# parameters +warmup=1000 +steps=10000 +lr=6e-5 +batch_size=128 +gradient_accumulation_steps=1 +mlp=2 + +# output name +output_name=bert-base_allweak_alldata-${REFER}_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +# Training +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/${REFER}_train_base-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/${REFER}_val_base-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/${REFER}/${output_name} \ + --max_txt_len 60 \ + --train_batch_size ${batch_size} \ + --val_batch_size 128 \ + --learning_rate ${lr} \ + --optim adamw \ + --betas 0.9 0.98 \ + --weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps ${steps} \ + --warmup_steps ${warmup} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --seed 24 \ + --mlp ${mlp} \ + --fp16 + +# Evaluate +case ${REFER} in + refcoco|refcoco+) + dbs=/db/${REFER}_val_base-cased.db:/db/${REFER}_testA_base-cased.db:/db/${REFER}_testB_base-cased.db;; + refcocog) + dbs=/db/${REFER}_val_base-cased.db:/db/${REFER}_test_base-cased.db;; +esac + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best \ No newline at end of file diff --git a/uniter_model/experiments/train_refer_large_mlm_itm_mrfr_mrckl_all.sh b/uniter_model/experiments/train_refer_large_mlm_itm_mrfr_mrckl_all.sh new file mode 100644 index 0000000..6ce149b --- /dev/null +++ b/uniter_model/experiments/train_refer_large_mlm_itm_mrfr_mrckl_all.sh @@ -0,0 +1,67 @@ + +# bert-large with all-tasks pre-trained on all data (COCO+VG+CC+SBU) +REFER=$1 +GPU=$2 + +# pre-trained model +checkpoint=/pretrain/bert-large_frkl_alldata.pt + +# parameters +warmup=1000 +steps=10000 +lr=5.3e-5 +batch_size=32 +gradient_accumulation_steps=2 +mlp=2 + +# output name +output_name=bert-large_allweak_alldata-${REFER}_w${warmup}_s${steps}_l${lr}_b${batch_size}_g${gradient_accumulation_steps}_m${mlp} +echo ${output_name} + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python train_re.py \ + --train_txt_db /db/${REFER}_train_large-cased.db \ + --train_img_dir /img/visual_grounding_coco_gt \ + --val_txt_db /db/${REFER}_val_large-cased.db \ + --val_img_dir /img/visual_grounding_det_coco \ + --checkpoint ${checkpoint} \ + --cut_bert -1 \ + --output_dir /storage/${REFER}/${output_name} \ + --max_txt_len 60 \ + --train_batch_size ${batch_size} \ + --val_batch_size 128 \ + --learning_rate ${lr} \ + --optim adamw \ + --betas 0.9 0.98 \ + 
--weight_decay 0.01 \ + --dropout 0.1 \ + --grad_norm 2.0 \ + --decay linear \ + --num_train_steps ${steps} \ + --warmup_steps ${warmup} \ + --gradient_accumulation_steps ${gradient_accumulation_steps} \ + --seed 24 \ + --mlp ${mlp} \ + --fp16 + +# Evaluate +case ${REFER} in + refcoco|refcoco+) + dbs=/db/${REFER}_val_large-cased.db:/db/${REFER}_testA_large-cased.db:/db/${REFER}_testB_large-cased.db;; + refcocog) + dbs=/db/${REFER}_val_large-cased.db:/db/${REFER}_test_large-cased.db;; +esac + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_coco_gt \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best + +CUDA_VISIBLE_DEVICES=${GPU} horovodrun -np 1 -H localhost:1 \ + python eval_re.py \ + --txt_db ${dbs} \ + --img_dir /img/visual_grounding_det_coco \ + --output_dir /storage/${REFER}/${output_name} \ + --checkpoint best diff --git a/uniter_model/format_vcr_predictions.py b/uniter_model/format_vcr_predictions.py new file mode 100644 index 0000000..fa7f3bc --- /dev/null +++ b/uniter_model/format_vcr_predictions.py @@ -0,0 +1,54 @@ +import pandas as pd +import json +import os +import argparse +import numpy as np + + +def main(opts): + with open(os.path.join(opts.input_folder, opts.pred_file), "r") as f: + data = json.load(f) + probs_grp = [] + ids_grp = [] + ordered_data = sorted(data.items(), + key=lambda item: int(item[0].split("-")[1])) + for annot_id, scores in ordered_data: + ids_grp.append(annot_id) + probs_grp.append(np.array(scores).reshape(1, 5, 4)) + + # Double check the IDs are in the same order for everything + # assert [x == ids_grp[0] for x in ids_grp] + + probs_grp = np.stack(probs_grp, 1) + # essentially probs_grp is a [num_ex, 5, 4] array of probabilities. + # The 5 'groups' are + # [answer, rationale_conditioned_on_a0, rationale_conditioned_on_a1, + # rationale_conditioned_on_a2, rationale_conditioned_on_a3]. + # We will flatten this to a CSV file so it's easy to submit. 
+ group_names = ['answer'] + [f'rationale_conditioned_on_a{i}' + for i in range(4)] + probs_df = pd.DataFrame(data=probs_grp.reshape((-1, 20)), + columns=[f'{group_name}_{i}' + for group_name in group_names for i in range(4)]) + probs_df['annot_id'] = ids_grp + probs_df = probs_df.set_index('annot_id', drop=True) + probs_df.to_csv(os.path.join(opts.input_folder, opts.output_file)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--pred_file", + default=None, type=str, + help="The input JSON file.") + parser.add_argument("--output_file", + default=None, type=str, + help="The output CSV file.") + parser.add_argument( + "--input_folder", default=None, type=str, + help="The directory where the predicted JSON files are in") + + args = parser.parse_args() + + main(args) + diff --git a/uniter_model/inf_itm.py b/uniter_model/inf_itm.py new file mode 100644 index 0000000..eee51c0 --- /dev/null +++ b/uniter_model/inf_itm.py @@ -0,0 +1,155 @@ +"""run inference for Image Text Retrieval""" +import argparse +import json +import os +from os.path import exists +import pickle +from time import time + +import torch +from torch.utils.data import DataLoader + +from apex import amp +from horovod import torch as hvd + +from data import (PrefetchLoader, + DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate) +from model import UniterForImageTextRetrieval + +from utils.logger import LOGGER +from utils.distributed import all_gather_list +from utils.misc import Struct +from utils.const import IMG_DIM + +from eval.itm import itm_eval +from train_itm import inference # FIXME + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + hps_file = f'{opts.output_dir}/log/hps.json' + model_opts = Struct(json.load(open(hps_file))) + + # load DBs and image dirs + eval_img_db = DetectFeatLmdb(opts.img_db, + model_opts.conf_th, model_opts.max_bb, + model_opts.min_bb, model_opts.num_bb, + opts.compressed_db) + eval_txt_db = TxtTokLmdb(opts.txt_db, -1) + eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size) + + # Prepare model + if exists(opts.checkpoint): + ckpt_file = opts.checkpoint + else: + ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt' + checkpoint = torch.load(ckpt_file) + model = UniterForImageTextRetrieval.from_pretrained( + f'{opts.output_dir}/log/model.json', checkpoint, + img_dim=IMG_DIM) + if 'rank_output' not in checkpoint: + model.init_output() # zero shot setting + + model.to(device) + model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') + + eval_dataloader = DataLoader(eval_dataset, batch_size=1, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=itm_eval_collate) + eval_dataloader = PrefetchLoader(eval_dataloader) + + eval_log, results = evaluate(model, eval_dataloader) + if hvd.rank() == 0: + result_dir = f'{opts.output_dir}/itm_results_{opts.name}' + if not exists(result_dir) and rank == 0: + os.makedirs(result_dir) + out_file = f'{result_dir}/results_{opts.checkpoint}.bin' + if not exists(out_file): + with open(out_file, 'wb') as f: + pickle.dump(results, f) + with open(f'{result_dir}/scores_{opts.checkpoint}.json', + 'w') as f: + json.dump(eval_log, f) + LOGGER.info(f'evaluation finished') + LOGGER.info( + 
f"======================== {opts.name} =========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + LOGGER.info("========================================================") + + +@torch.no_grad() +def evaluate(model, eval_loader): + model.eval() + st = time() + LOGGER.info("start running Image/Text Retrieval evaluation ...") + score_matrix = inference(model, eval_loader) + dset = eval_loader.dataset + all_score = hvd.allgather(score_matrix) + all_txt_ids = [i for ids in all_gather_list(dset.ids) + for i in ids] + all_img_ids = dset.all_img_ids + assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) + if hvd.rank() != 0: + return {}, tuple() + # NOTE: only use rank0 to compute final scores + eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, + dset.txt2img, dset.img2txts) + + results = (all_score, all_txt_ids, all_img_ids) + tot_time = time()-st + LOGGER.info(f"evaluation finished in {int(tot_time)} seconds, ") + return eval_log, results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--txt_db", default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--img_db", default=None, type=str, + help="The input train images.") + parser.add_argument("--name", default='flickr_val', type=str, + help="affects output path") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--batch_size", + default=400, type=int, + help="number of tokens in a batch") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + + # device parameters + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + args = parser.parse_args() + + main(args) diff --git a/uniter_model/launch_container.sh b/uniter_model/launch_container.sh new file mode 100644 index 0000000..ce158ea --- /dev/null +++ b/uniter_model/launch_container.sh @@ -0,0 +1,25 @@ +TXT_DB=$1 +IMG_DIR=$2 +OUTPUT=$3 +PRETRAIN_DIR=$4 + +if [ -z $CUDA_VISIBLE_DEVICES ]; then + CUDA_VISIBLE_DEVICES='all' +fi + +if [ "$5" = "--prepro" ]; then + RO="" +else + RO=",readonly" +fi + + +docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$OUTPUT,dst=/storage,type=bind \ + --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ + --mount src=$TXT_DB,dst=/db,type=bind$RO \ + --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ + -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ + -e NCCL_IB_CUDA_SUPPORT=0 \ + -w /src convaicontainerregistry1.azurecr.io/img-txt diff --git a/uniter_model/launch_container_dist.sh b/uniter_model/launch_container_dist.sh new file mode 100644 index 0000000..9398d0a --- /dev/null +++ b/uniter_model/launch_container_dist.sh @@ -0,0 
+1,28 @@ +TXT_DB=$1 +IMG_DIR=$2 +OUTPUT=$3 +PRETRAIN_DIR=$4 + +if [ -z $CUDA_VISIBLE_DEVICES ]; then + CUDA_VISIBLE_DEVICES='all' +fi + +if [ "$5" = "--prepro" ]; then + RO="" +else + RO=",readonly" +fi + +sudo docker run --gpus "device=$CUDA_VISIBLE_DEVICES" --ipc=host --rm -it \ + --privileged=true \ + --network=host \ + -v /convaistorage3mmb0/horovod_cluster_keys:/root/.ssh \ + -v /convaistorage3mmb0:/convaistorage3mmb0 \ + -v /convaistorage3mmb1:/convaistorage3mmb1 \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$OUTPUT,dst=/storage,type=bind \ + --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ + --mount src=$TXT_DB,dst=/db,type=bind$RO \ + --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ + -e NCCL_IB_CUDA_SUPPORT=0 \ + -w /src convaicontainerregistry1.azurecr.io/img-txt diff --git a/uniter_model/misc/ans2label.json b/uniter_model/misc/ans2label.json new file mode 100644 index 0000000..9fc717b --- /dev/null +++ b/uniter_model/misc/ans2label.json @@ -0,0 +1 @@ +{"net": 0, "pitcher": 1, "orange": 2, "yes": 3, "white": 4, "skiing": 5, "red": 6, "frisbee": 7, "brushing teeth": 8, "no": 9, "black and white": 10, "skateboard": 11, "1": 12, "blue": 13, "green": 14, "motorcycle": 15, "gray": 16, "2": 17, "purse": 18, "skis": 19, "poles": 20, "surfboard": 21, "dog": 22, "on": 23, "office": 24, "large": 25, "very big": 26, "laptop": 27, "vent": 28, "computer": 29, "black": 30, "bear": 31, "3": 32, "wii": 33, "glasses": 34, "tree": 35, "eating": 36, "log": 37, "5": 38, "raft": 39, "left": 40, "living room": 41, "pink": 42, "right": 43, "railing": 44, "grass": 45, "wire": 46, "10 years": 47, "knife": 48, "cake": 49, "banana": 50, "chef": 51, "vanilla": 52, "4": 53, "outdoor": 54, "mustard": 55, "bun": 56, "clouds": 57, "dock": 58, "brown": 59, "silver": 60, "refrigerator": 61, "square": 62, "teddy": 63, "elm": 64, "stripes": 65, "baseball": 66, "catcher": 67, "beer": 68, "bottom": 69, "north": 70, "nike": 71, "yellow and white": 72, "morning": 73, "elephant": 74, "red and white": 75, "propeller": 76, "tan": 77, "wall": 78, "rolex": 79, "clock": 80, "table": 81, "0": 82, "wood": 83, "christmas": 84, "spinach": 85, "thick": 86, "bag": 87, "leaves": 88, "necklace": 89, "6": 90, "bathroom": 91, "shower": 92, "towel": 93, "solid": 94, "referee": 95, "wilson": 96, "8:00": 97, "e": 98, "24": 99, "hat": 100, "grazing": 101, "sheep": 102, "10": 103, "tag": 104, "spanish": 105, "hot dog": 106, "plate": 107, "lunch": 108, "butter": 109, "peppers": 110, "onions": 111, "very": 112, "mayonnaise": 113, "mayo": 114, "sweet potato": 115, "pig": 116, "sweet": 117, "flowers": 118, "floral": 119, "yellow": 120, "window": 121, "7": 122, "pizza": 123, "car": 124, "": 125, "cargo": 126, "stairs": 127, "abstract": 128, "rug": 129, "baseball cap": 130, "texting": 131, "pole": 132, "crosswalk": 133, "nothing": 134, "urban": 135, "bus": 136, "light": 137, "afternoon": 138, "boat": 139, "cheese": 140, "paper": 141, "real": 142, "sun": 143, "birthday": 144, "words": 145, "inside": 146, "shadows": 147, "tomato": 148, "evergreen": 149, "100 feet": 150, "shingles": 151, "trees": 152, "building": 153, "hay": 154, "ski pole": 155, "patterned": 156, "walking": 157, "ice": 158, "laundry": 159, "pepsi": 160, "good": 161, "1:50": 162, "purple": 163, "13": 164, "africa": 165, "teddy bears": 166, "socks": 167, "giraffe": 168, "soccer": 169, "blue and yellow": 170, "zebras": 171, "cupcake": 172, "broccoli": 173, "soldier": 174, "parking lot": 175, "cows": 176, "herding": 177, "on table": 178, "fish": 179, 
"nightstand": 180, "50": 181, "overcast": 182, "cross": 183, "toaster oven": 184, "tile": 185, "11:55": 186, "red and yellow": 187, "nowhere": 188, "hair dryer": 189, "truck": 190, "11": 191, "people": 192, "rectangle": 193, "hot dogs": 194, "party": 195, "12:55": 196, "apron": 197, "kitchen": 198, "cooking": 199, "ring": 200, "1 way": 201, "stop": 202, "neither": 203, "many": 204, "female": 205, "brushing": 206, "tie": 207, "tennis racket": 208, "knife and fork": 209, "restaurant": 210, "cat": 211, "bed": 212, "sand": 213, "ocean": 214, "cold": 215, "kites": 216, "cumulus": 217, "standing": 218, "male": 219, "star": 220, "tracks": 221, "chocolate": 222, "round": 223, "fork and knife": 224, "yankees": 225, "pictures": 226, "dots": 227, "bird": 228, "parrot": 229, "red white and blue": 230, "man": 231, "metal": 232, "fence": 233, "snowboarding": 234, "pine": 235, "snow": 236, "shorts": 237, "swim": 238, "wine": 239, "brick": 240, "no parking": 241, "children": 242, "beef": 243, "phone": 244, "english": 245, "cell phone": 246, "pink and yellow": 247, "clear": 248, "watermelon": 249, "bedroom": 250, "fork": 251, "cow": 252, "rackets": 253, "tennis rackets": 254, "8": 255, "collar": 256, "tennis": 257, "1950s": 258, "playing tennis": 259, "skirt": 260, "30": 261, "polka dot": 262, "beach": 263, "horse": 264, "grill": 265, "african american": 266, "down": 267, "street": 268, "in air": 269, "sweater": 270, "yellow and blue": 271, "park": 272, "backyard": 273, "spectators": 274, "parasailing": 275, "31": 276, "river": 277, "55": 278, "shadow": 279, "winter": 280, "chicken": 281, "tea": 282, "evening": 283, "dusk": 284, "ski resort": 285, "helmet": 286, "penne": 287, "bench": 288, "resting": 289, "elephants": 290, "southwest": 291, "usa": 292, "cars": 293, "town": 294, "bananas": 295, "umbrella": 296, "container": 297, "woman": 298, "on counter": 299, "salad": 300, "striped": 301, "motel": 302, "vertical": 303, "oranges": 304, "hot sauce": 305, "bottle": 306, "juice": 307, "eyes": 308, "ground": 309, "backpack": 310, "black and yellow": 311, "forward": 312, "jackets": 313, "1 on right": 314, "green and yellow": 315, "playing baseball": 316, "riding": 317, "sitting": 318, "carrot": 319, "basket": 320, "seagull": 321, "ski poles": 322, "p": 323, "parking": 324, "street light": 325, "mets": 326, "strap": 327, "bike": 328, "riding bike": 329, "poodle": 330, "shoes": 331, "carpet": 332, "lettuce": 333, "food": 334, "1 foot": 335, "roses": 336, "mountains": 337, "scissors": 338, "camera": 339, "beige": 340, "beard": 341, "cutting": 342, "baby": 343, "tape": 344, "watch": 345, "never": 346, "taking picture": 347, "eggs": 348, "syrup": 349, "sandwich": 350, "water skiing": 351, "microphone": 352, "back": 353, "bears": 354, "donuts": 355, "w": 356, "sky": 357, "double decker": 358, "england": 359, "surfing": 360, "running": 361, "shirt": 362, "barn": 363, "weather vane": 364, "white and blue": 365, "fishing": 366, "bridge": 367, "los angeles": 368, "open": 369, "red sox": 370, "bat": 371, "plane": 372, "white and green": 373, "transportation": 374, "sunny": 375, "bus stop": 376, "city": 377, "brown and white": 378, "bicycle": 379, "crow": 380, "magazines": 381, "daisy": 382, "14": 383, "old": 384, "curtains": 385, "jumped": 386, "snowboard": 387, "dinosaur": 388, "racing": 389, "asphalt": 390, "court": 391, "plastic": 392, "circle": 393, "red and blue": 394, "zebra": 395, "12": 396, "biplane": 397, "shallow": 398, "brazil": 399, "logo": 400, "2:20": 401, "electric": 402, "night time": 403, "motion": 404, 
"toothbrushes": 405, "orange and white": 406, "66": 407, "spoon": 408, "toyota": 409, "tennis shoes": 410, "46": 411, "second": 412, "no 1": 413, "iphone": 414, "friend": 415, "apple": 416, "carnation": 417, "15": 418, "tiger": 419, "glove": 420, "airplane": 421, "bow": 422, "air france": 423, "passengers": 424, "tv": 425, "on building": 426, "3:55": 427, "victorian": 428, "steeple": 429, "happy": 430, "skateboarding": 431, "fruit": 432, "cutting board": 433, "cantaloupe": 434, "kiwi": 435, "sliced": 436, "heart": 437, "water": 438, "rainy": 439, "carrots": 440, "giraffes": 441, "eat": 442, "ramp": 443, "lab": 444, "field": 445, "horizontal": 446, "birds": 447, "home": 448, "shrimp": 449, "12 feet": 450, "girl": 451, "modern": 452, "turtle": 453, "dell": 454, "boots": 455, "sunglasses": 456, "black and orange": 457, "yellow and black": 458, "gloves": 459, "hp": 460, "desk": 461, "both": 462, "sign": 463, "on street": 464, "2000": 465, "cirrus": 466, "to dry": 467, "ceiling": 468, "fluorescent": 469, "up": 470, "9": 471, "boys": 472, "playing soccer": 473, "american": 474, "passenger": 475, "turn": 476, "palm": 477, "no train": 478, "wedding": 479, "branch": 480, "parrots": 481, "air force": 482, "on tracks": 483, "small": 484, "tank": 485, "dirty": 486, "france": 487, "honda": 488, "2.00": 489, "whale": 490, "vase": 491, "flying": 492, "professional": 493, "driving": 494, "tissue": 495, "protest": 496, "corona": 497, "for balance": 498, "twin": 499, "clothes": 500, "t shirt": 501, "window sill": 502, "wild": 503, "noon": 504, "caution": 505, "spring": 506, "raining": 507, "cane": 508, "school": 509, "windsurfing": 510, "parachute": 511, "black and red": 512, "25": 513, "background": 514, "toaster": 515, "planes": 516, "yellow and red": 517, "spatula": 518, "10:10": 519, "ivory": 520, "train": 521, "welcome": 522, "highway": 523, "off": 524, "on track": 525, "electricity": 526, "italy": 527, "dinner": 528, "sink": 529, "squares": 530, "5 ft": 531, "parked": 532, "store": 533, "dress": 534, "signs": 535, "meow": 536, "football": 537, "rugby": 538, "stainless steel": 539, "la": 540, "dirt": 541, "blue and white": 542, "klm": 543, "house": 544, "unknown": 545, "ford": 546, "reading": 547, "chair": 548, "mountain": 549, "alive": 550, "water skis": 551, "picture": 552, "parade": 553, "slippers": 554, "trailer": 555, "boating": 556, "holding it": 557, "shade": 558, "cloth": 559, "6:20": 560, "candle": 561, "hose": 562, "hand": 563, "3:25": 564, "on sidewalk": 565, "poster": 566, "downhill": 567, "68": 568, "reflection": 569, "summer": 570, "pickles": 571, "halloween": 572, "bats": 573, "london": 574, "zoo": 575, "surfer": 576, "racket": 577, "flickr": 578, "cutting hair": 579, "strawberries": 580, "mushroom": 581, "teddy bear": 582, "big": 583, "suitcase": 584, "veggie": 585, "pepper": 586, "houses": 587, "70": 588, "toshiba": 589, "triangle": 590, "boxes": 591, "photograph": 592, "smoke": 593, "engine": 594, "camel": 595, "sidewalk": 596, "left 1": 597, "red and green": 598, "4:35": 599, "on couch": 600, "candy": 601, "minnie mouse": 602, "homemade": 603, "mouse": 604, "box": 605, "movie": 606, "45": 607, "strawberry": 608, "fridge": 609, "full": 610, "vegetables": 611, "bright": 612, "play": 613, "remote": 614, "pond": 615, "savannah": 616, "celery": 617, "concrete": 618, "semi": 619, "dump": 620, "scania": 621, "safety": 622, "posing": 623, "fabric": 624, "laying": 625, "couch": 626, "blueberries": 627, "handle": 628, "pipe": 629, "stick": 630, "parmesan": 631, "steak": 632, "chain link": 633, 
"catch": 634, "barbed wire": 635, "mozzarella": 636, "soda": 637, "fire hydrant": 638, "cat food": 639, "pepperoni": 640, "lot": 641, "licking": 642, "red and black": 643, "clay": 644, "tennis court": 645, "jumping": 646, "potatoes": 647, "toothbrush": 648, "kite": 649, "not at all": 650, "flying kite": 651, "broken": 652, "black and silver": 653, "lap": 654, "outside": 655, "44": 656, "delta": 657, "greyhound": 658, "ring finger": 659, "talking on phone": 660, "bad": 661, "kettle": 662, "35": 663, "motorcycles": 664, "produce": 665, "comfort": 666, "steering wheel": 667, "18": 668, "humans": 669, "coffee": 670, "white and brown": 671, "fall": 672, "bread": 673, "cherry": 674, "4:30": 675, "flag": 676, "night": 677, "lamp": 678, "cucumber": 679, "can't see": 680, "porcelain": 681, "oval": 682, "museum": 683, "rain": 684, "sprinkles": 685, "20": 686, "kids": 687, "bracelet": 688, "sneakers": 689, "mask": 690, "mickey mouse": 691, "twins": 692, "very high": 693, "costume": 694, "cabbage": 695, "paint": 696, "lighting": 697, "young": 698, "air conditioner": 699, "wooden": 700, "board": 701, "someone": 702, "beets": 703, "16": 704, "day time": 705, "4 inches": 706, "lights": 707, "ladder": 708, "glass": 709, "ferris wheel": 710, "fries": 711, "steamed": 712, "shepherd": 713, "cotton": 714, "suit": 715, "goatee": 716, "on his head": 717, "print": 718, "happy birthday": 719, "forks": 720, "travel": 721, "maple": 722, "200": 723, "oil": 724, "jeans": 725, "can": 726, "chopsticks": 727, "on wall": 728, "construction": 729, "mack": 730, "36": 731, "chinese": 732, "moped": 733, "festival": 734, "gas": 735, "throwing": 736, "circus": 737, "wires": 738, "not possible": 739, "plates": 740, "sugar": 741, "in": 742, "women's": 743, "door": 744, "no man": 745, "volleyball": 746, "serving": 747, "ponytail": 748, "business": 749, "decoration": 750, "santa": 751, "flat": 752, "barrel": 753, "12:15": 754, "candles": 755, "atv": 756, "free": 757, "hair": 758, "waffle": 759, "ball": 760, "stop sign": 761, "wetsuit": 762, "very deep": 763, "swimsuit": 764, "green and black": 765, "foreground": 766, "stands": 767, "china airlines": 768, "flower": 769, "300": 770, "lobster": 771, "on bench": 772, "plaster": 773, "phones": 774, "sailboat": 775, "apples": 776, "road": 777, "recently": 778, "cones": 779, "cactus": 780, "rice": 781, "vegetarian": 782, "donut": 783, "ketchup": 784, "police": 785, "mirror": 786, "rock": 787, "meat": 788, "blinds": 789, "cell phones": 790, "china": 791, "rust": 792, "7:25": 793, "stone": 794, "vans": 795, "middle": 796, "eagle": 797, "9:30": 798, "ping pong": 799, "microwave": 800, "gmc": 801, "umbrellas": 802, "wrist": 803, "cuddling": 804, "laughing": 805, "boy": 806, "next to toilet": 807, "tabby": 808, "petting": 809, "south": 810, "40": 811, "name tag": 812, "checkered": 813, "name": 814, "slow": 815, "cardboard": 816, "windows": 817, "croissant": 818, "plain": 819, "cookie": 820, "on ground": 821, "low": 822, "water bottle": 823, "goggles": 824, "turkey": 825, "pull": 826, "shut": 827, "kite flying": 828, "bowl": 829, "smile": 830, "in bowl": 831, "bush": 832, "cloudy": 833, "top left": 834, "skateboarder": 835, "coca cola": 836, "pan": 837, "drinking": 838, "short": 839, "floor": 840, "thanksgiving": 841, "radio": 842, "drink": 843, "on toilet": 844, "bike rack": 845, "bleachers": 846, "train tracks": 847, "horses": 848, "far": 849, "top": 850, "toilet": 851, "in water": 852, "private": 853, "nature": 854, "checkers": 855, "commercial": 856, "stroller": 857, "power": 858, "stuffed 
animals": 859, "uniforms": 860, "japan": 861, "liquor": 862, "faucet": 863, "green and orange": 864, "corn": 865, "sub": 866, "white and yellow": 867, "mercedes": 868, "in sky": 869, "tarp": 870, "indian": 871, "counter": 872, "multicolored": 873, "polar": 874, "go": 875, "now": 876, "no number": 877, "swimming": 878, "bridle": 879, "cowboy": 880, "union station": 881, "salt and pepper": 882, "olives": 883, "pizza cutter": 884, "british airways": 885, "nighttime": 886, "domestic": 887, "trolley": 888, "australia": 889, "tiles": 890, "pug": 891, "wicker": 892, "british": 893, "us airways express": 894, "burton": 895, "christmas tree": 896, "napkin": 897, "writing": 898, "rocks": 899, "hello kitty": 900, "lacoste": 901, "gold": 902, "fan": 903, "skateboards": 904, "day": 905, "on floor": 906, "2008": 907, "dark": 908, "flying kites": 909, "rural": 910, "olympics": 911, "bmw": 912, "34": 913, "factory": 914, "denim": 915, "typing": 916, "for fun": 917, "steel": 918, "watching tv": 919, "chevron": 920, "driver": 921, "baggage claim": 922, "grapes": 923, "f": 924, "angels": 925, "roof": 926, "handlebars": 927, "train station": 928, "public": 929, "oak": 930, "sleeping": 931, "canada": 932, "on runway": 933, "air canada": 934, "on top": 935, "tired": 936, "blonde": 937, "cups": 938, "little": 939, "adidas": 940, "10 feet": 941, "white and gray": 942, "leaf": 943, "fisheye": 944, "forest": 945, "war": 946, "octagon": 947, "raspberry": 948, "helmets": 949, "united states": 950, "29": 951, "noodles": 952, "van": 953, "long": 954, "traveling": 955, "luggage": 956, "airport": 957, "single": 958, "pitching": 959, "dugout": 960, "garbage": 961, "in street": 962, "happiness": 963, "cigarette": 964, "on tower": 965, "antelope": 966, "graffiti": 967, "skating": 968, "on road": 969, "curved": 970, "red light": 971, "washington": 972, "ski lift": 973, "athletics": 974, "brace": 975, "squatting": 976, "catching": 977, "batter": 978, "batting": 979, "game": 980, "towards": 981, "33": 982, "sliding": 983, "makeup": 984, "japanese": 985, "person": 986, "pirates": 987, "plaid": 988, "rose": 989, "daytime": 990, "keyboard": 991, "surfboards": 992, "hummingbird": 993, "ollie": 994, "11:30": 995, "clock tower": 996, "5:55": 997, "san francisco": 998, "stopping": 999, "tags": 1000, "samsung": 1001, "computers": 1002, "cabinets": 1003, "talking": 1004, "cage": 1005, "asparagus": 1006, "5 years": 1007, "hanger": 1008, "adult": 1009, "rabbit": 1010, "empty": 1011, "softball": 1012, "1st": 1013, "playing": 1014, "chairs": 1015, "farm": 1016, "cross country": 1017, "dump truck": 1018, "women": 1019, "snowboarder": 1020, "tall": 1021, "monkey": 1022, "mantle": 1023, "fire": 1024, "books": 1025, "quilt": 1026, "cessna": 1027, "chandelier": 1028, "dunkin donuts": 1029, "beans": 1030, "relish": 1031, "no flag": 1032, "parking meter": 1033, "spots": 1034, "ducks": 1035, "sandals": 1036, "doughnut": 1037, "lighthouse": 1038, "yacht": 1039, "german shepherd": 1040, "in middle": 1041, "raw": 1042, "chain": 1043, "2 feet": 1044, "pedestal": 1045, "sauerkraut": 1046, "bagels": 1047, "mutt": 1048, "dog and cat": 1049, "race": 1050, "poor": 1051, "cat and dog": 1052, "station": 1053, "printer": 1054, "daisies": 1055, "front": 1056, "gravel": 1057, "rear": 1058, "grassy": 1059, "pigeons": 1060, "dogs": 1061, "in car": 1062, "life": 1063, "wii remotes": 1064, "suv": 1065, "leather": 1066, "bottom right": 1067, "peace": 1068, "facebook": 1069, "blanket": 1070, "fountain": 1071, "frisbees": 1072, "12:30": 1073, "am": 1074, "scooter": 
1075, "going": 1076, "analog": 1077, "america": 1078, "pitbull": 1079, "relaxing": 1080, "paddle boarding": 1081, "white and pink": 1082, "shampoo": 1083, "alps": 1084, "ride": 1085, "side": 1086, "mane": 1087, "on desk": 1088, "on chair": 1089, "2012": 1090, "multi": 1091, "straight": 1092, "big ben": 1093, "closed": 1094, "frosted": 1095, "3 feet": 1096, "waves": 1097, "buoy": 1098, "life vest": 1099, "trash can": 1100, "medium": 1101, "boxer": 1102, "very tall": 1103, "yamaha": 1104, "sunlight": 1105, "hit ball": 1106, "dry": 1107, "coke": 1108, "gym": 1109, "orange and black": 1110, "center": 1111, "rope": 1112, "flip flops": 1113, "4th of july": 1114, "siamese": 1115, "crafts": 1116, "color": 1117, "italian": 1118, "playing frisbee": 1119, "skate park": 1120, "orange juice": 1121, "windowsill": 1122, "corgi": 1123, "thumb": 1124, "peanut butter": 1125, "pie": 1126, "toast": 1127, "no hat": 1128, "benches": 1129, "diamond": 1130, "blender": 1131, "avocado": 1132, "television": 1133, "speakers": 1134, "pony": 1135, "baseball field": 1136, "pavement": 1137, "sydney": 1138, "not there": 1139, "diamonds": 1140, "4 feet": 1141, "goalie": 1142, "soccer ball": 1143, "runway": 1144, "video game": 1145, "gaming": 1146, "casual": 1147, "green and white": 1148, "toilet brush": 1149, "working": 1150, "pickup": 1151, "girls": 1152, "remotes": 1153, "pasta": 1154, "hood": 1155, "braves": 1156, "skier": 1157, "motorola": 1158, "17": 1159, "b": 1160, "100": 1161, "diet coke": 1162, "hospital": 1163, "wagon": 1164, "milk": 1165, "ferry": 1166, "rainbow": 1167, "on bed": 1168, "toward": 1169, "1:30": 1170, "19": 1171, "security": 1172, "herself": 1173, "mercedes benz": 1174, "supreme": 1175, "thin": 1176, "platform": 1177, "gray and red": 1178, "thai": 1179, "storage": 1180, "thailand": 1181, "swan": 1182, "peach": 1183, "10:05": 1184, "dome": 1185, "chiquita": 1186, "2:00": 1187, "mountain dew": 1188, "23": 1189, "knives": 1190, "street sign": 1191, "on beach": 1192, "playing wii": 1193, "using laptop": 1194, "stickers": 1195, "yogurt": 1196, "on grass": 1197, "9:50": 1198, "9:45": 1199, "sweat": 1200, "gatorade": 1201, "umpire": 1202, "37": 1203, "transport": 1204, "desktop": 1205, "desserts": 1206, "main": 1207, "boston": 1208, "fell": 1209, "top right": 1210, "case": 1211, "asleep": 1212, "over": 1213, "9:55": 1214, "grapefruit": 1215, "breakfast": 1216, "headphones": 1217, "freight": 1218, "cup": 1219, "sweatband": 1220, "nobody": 1221, "lamps": 1222, "9:25": 1223, "scarf": 1224, "on fridge": 1225, "main st": 1226, "moving": 1227, "confused": 1228, "fresh": 1229, "kiting": 1230, "blue jay": 1231, "flats": 1232, "long time": 1233, "chihuahua": 1234, "ceramic": 1235, "mushrooms": 1236, "on plate": 1237, "human": 1238, "power lines": 1239, "hotel": 1240, "map": 1241, "earring": 1242, "boarding": 1243, "display": 1244, "warm": 1245, "napkins": 1246, "brown and black": 1247, "broom": 1248, "basketball": 1249, "papers": 1250, "holding baby": 1251, "sad": 1252, "kickstand": 1253, "60": 1254, "shoulder": 1255, "sleep": 1256, "footprints": 1257, "tunnel": 1258, "1990": 1259, "hats": 1260, "6 inches": 1261, "ham": 1262, "bacon": 1263, "church": 1264, "53": 1265, "pineapple": 1266, "at camera": 1267, "red bull": 1268, "pilot": 1269, "tattoo": 1270, "work": 1271, "polar bear": 1272, "taking off": 1273, "website": 1274, "22": 1275, "4:00": 1276, "coffee maker": 1277, "fast": 1278, "fur": 1279, "rubber": 1280, "tongs": 1281, "german": 1282, "germany": 1283, "3 inches": 1284, "toy": 1285, "3:20": 1286, "calm": 
1287, "pots": 1288, "balloons": 1289, "fruits": 1290, "9:20": 1291, "drawer": 1292, "oven": 1293, "soup": 1294, "stove": 1295, "heels": 1296, "wind": 1297, "island": 1298, "blood": 1299, "leg": 1300, "theater": 1301, "tennis racquet": 1302, "21": 1303, "gothic": 1304, "2:35": 1305, "wii remote": 1306, "turning": 1307, "20 feet": 1308, "pink and black": 1309, "ears": 1310, "fun": 1311, "wreath": 1312, "to right": 1313, "child": 1314, "fly": 1315, "head": 1316, "drywall": 1317, "shorter": 1318, "pier": 1319, "feeding giraffe": 1320, "in vase": 1321, "burger": 1322, "easter": 1323, "onion": 1324, "uniform": 1325, "remote control": 1326, "guitar": 1327, "time": 1328, "verizon": 1329, "tomatoes": 1330, "ship": 1331, "tulips": 1332, "glaze": 1333, "on suitcase": 1334, "tent": 1335, "1:45": 1336, "market": 1337, "bnsf": 1338, "bandana": 1339, "still": 1340, "don't know": 1341, "piano": 1342, "mouth": 1343, "run": 1344, "sparrow": 1345, "throw": 1346, "lines": 1347, "vest": 1348, "1950": 1349, "jet": 1350, "sepia": 1351, "2015": 1352, "busy": 1353, "lighter": 1354, "dessert": 1355, "bending": 1356, "75": 1357, "finch": 1358, "pastries": 1359, "outdoors": 1360, "bakery": 1361, "clean": 1362, "ipod": 1363, "tablecloth": 1364, "cigarettes": 1365, "looking at phone": 1366, "in front": 1367, "food truck": 1368, "face": 1369, "swinging": 1370, "safari": 1371, "500": 1372, "volkswagen": 1373, "2010": 1374, "shape": 1375, "shelves": 1376, "riding horses": 1377, "2016": 1378, "behind bus": 1379, "towels": 1380, "lemon": 1381, "straw": 1382, "bamboo": 1383, "5 feet": 1384, "hardwood": 1385, "oregon": 1386, "schnauzer": 1387, "organic": 1388, "h": 1389, "kid": 1390, "meter": 1391, "61": 1392, "charging": 1393, "bald": 1394, "caucasian": 1395, "man on left": 1396, "stand": 1397, "27": 1398, "dining room": 1399, "sandwiches": 1400, "32": 1401, "apartment": 1402, "tower": 1403, "virgin": 1404, "out": 1405, "white and red": 1406, "2:05": 1407, "i don't know": 1408, "chains": 1409, "legs": 1410, "age": 1411, "goats": 1412, "s": 1413, "congratulations": 1414, "dresser": 1415, "camper": 1416, "half": 1417, "silverware": 1418, "decorative": 1419, "hawaiian": 1420, "petting horse": 1421, "wheel": 1422, "florida": 1423, "reds": 1424, "washington dc": 1425, "moon": 1426, "conference": 1427, "screen": 1428, "controller": 1429, "robin": 1430, "men": 1431, "protection": 1432, "roll": 1433, "harley davidson": 1434, "coal": 1435, "mustache": 1436, "smiling": 1437, "pedestrians": 1438, "88": 1439, "me": 1440, "tray": 1441, "males": 1442, "monitor": 1443, "bell": 1444, "landscape": 1445, "club": 1446, "toothpick": 1447, "seagulls": 1448, "bowtie": 1449, "lake": 1450, "steam": 1451, "surf": 1452, "baseball glove": 1453, "blinders": 1454, "woods": 1455, "stuffed": 1456, "sunbathing": 1457, "shearing": 1458, "dad": 1459, "mixer": 1460, "pot": 1461, "blending": 1462, "identification": 1463, "owl": 1464, "wine glass": 1465, "on bike": 1466, "billabong": 1467, "new york": 1468, "yarn": 1469, "tube": 1470, "tennis ball": 1471, "2:55": 1472, "ice cream": 1473, "chevrolet": 1474, "shirt and tie": 1475, "taking selfie": 1476, "blue and green": 1477, "he isn't": 1478, "cutting cake": 1479, "east": 1480, "setting": 1481, "brewers": 1482, "riding bikes": 1483, "7 eleven": 1484, "stars": 1485, "jockey": 1486, "jacket": 1487, "standing still": 1488, "book": 1489, "gray and white": 1490, "pen": 1491, "red white blue": 1492, "above": 1493, "alaska": 1494, "tongue": 1495, "feathers": 1496, "k": 1497, "camping": 1498, "pasture": 1499, "corner": 
1500, "away": 1501, "ski": 1502, "texas": 1503, "fire truck": 1504, "sailboats": 1505, "jump": 1506, "walk": 1507, "spray paint": 1508, "loading": 1509, "united": 1510, "1000": 1511, "brushing his teeth": 1512, "roman numerals": 1513, "garlic": 1514, "surprise": 1515, "3rd": 1516, "first": 1517, "side of road": 1518, "dodgers": 1519, "airplanes": 1520, "unsure": 1521, "russian": 1522, "wet": 1523, "skyscraper": 1524, "5 star": 1525, "brushing her teeth": 1526, "blankets": 1527, "natural": 1528, "across street": 1529, "smartphone": 1530, "duck": 1531, "sausage": 1532, "paris": 1533, "newspaper": 1534, "pants": 1535, "spices": 1536, "pillow": 1537, "to left": 1538, "snowboards": 1539, "colgate": 1540, "on elephant": 1541, "string": 1542, "horns": 1543, "2:40": 1544, "men's": 1545, "cobblestone": 1546, "regular": 1547, "staring": 1548, "28": 1549, "barber shop": 1550, "linoleum": 1551, "grind": 1552, "cut": 1553, "x": 1554, "above sink": 1555, "above stove": 1556, "dishes": 1557, "dalmatian": 1558, "watching": 1559, "glazed": 1560, "5:25": 1561, "j": 1562, "messy": 1563, "wallet": 1564, "tuna": 1565, "toasted": 1566, "grilled": 1567, "french": 1568, "green and blue": 1569, "sunflowers": 1570, "to catch frisbee": 1571, "wool": 1572, "sprint": 1573, "no grass": 1574, "cabinet": 1575, "shell": 1576, "foil": 1577, "bottles": 1578, "bar": 1579, "king": 1580, "paper towels": 1581, "friends": 1582, "beagle": 1583, "school bus": 1584, "laptops": 1585, "snowing": 1586, "cement": 1587, "pc": 1588, "accident": 1589, "stuffed animal": 1590, "wakeboard": 1591, "balance": 1592, "in suitcase": 1593, "white and black": 1594, "nikon": 1595, "cleats": 1596, "on sink": 1597, "pool": 1598, "mom": 1599, "downtown": 1600, "asian": 1601, "heater": 1602, "bathing": 1603, "193": 1604, "against wall": 1605, "canopy": 1606, "jungle": 1607, "berries": 1608, "military": 1609, "pickle": 1610, "clams": 1611, "seafood": 1612, "in box": 1613, "boats": 1614, "tables": 1615, "lizard": 1616, "lemonade": 1617, "m": 1618, "soft": 1619, "illinois": 1620, "country": 1621, "for sale": 1622, "arm": 1623, "listening": 1624, "curly": 1625, "play tennis": 1626, "hands": 1627, "cereal": 1628, "blue and red": 1629, "robe": 1630, "around neck": 1631, "red and silver": 1632, "soap": 1633, "trains": 1634, "throwing frisbee": 1635, "smoking": 1636, "india": 1637, "headband": 1638, "not very": 1639, "westin": 1640, "serve": 1641, "bicycles": 1642, "can't tell": 1643, "to catch ball": 1644, "visibility": 1645, "ana": 1646, "reins": 1647, "rodeo": 1648, "boot": 1649, "on horse": 1650, "12:35": 1651, "riding motorcycle": 1652, "mexico": 1653, "mother": 1654, "african": 1655, "left and right": 1656, "button": 1657, "earrings": 1658, "blackberry": 1659, "cell": 1660, "10:00": 1661, "harness": 1662, "pillows": 1663, "vegetable": 1664, "tablet": 1665, "fern": 1666, "cats": 1667, "golden retriever": 1668, "goat": 1669, "tractor": 1670, "valentine's day": 1671, "hearts": 1672, "khaki": 1673, "man on right": 1674, "mcdonald's": 1675, "player": 1676, "arriving": 1677, "husky": 1678, "on skateboard": 1679, "vases": 1680, "coat": 1681, "beanie": 1682, "coming": 1683, "granite": 1684, "shopping cart": 1685, "it's raining": 1686, "sports": 1687, "leash": 1688, "balls": 1689, "blurry": 1690, "baseball bat": 1691, "team": 1692, "mango": 1693, "mug": 1694, "eiffel tower": 1695, "worms": 1696, "trash": 1697, "robot": 1698, "show": 1699, "terrier": 1700, "painting": 1701, "rooster": 1702, "42": 1703, "jones": 1704, "state farm": 1705, "balloon": 1706, "trunk": 
1707, "coach": 1708, "t": 1709, "playing game": 1710, "fireplace": 1711, "behind clouds": 1712, "uphill": 1713, "motocross": 1714, "sony": 1715, "magazine": 1716, "kitesurfing": 1717, "catching frisbee": 1718, "catch frisbee": 1719, "bud light": 1720, "drive": 1721, "fighting": 1722, "1 on left": 1723, "very old": 1724, "hallway": 1725, "lexus": 1726, "wii controller": 1727, "9:15": 1728, "fast food": 1729, "5:45": 1730, "catholic": 1731, "muffin": 1732, "traffic light": 1733, "band": 1734, "button up": 1735, "grocery": 1736, "shelf": 1737, "2:25": 1738, "honey": 1739, "plants": 1740, "oars": 1741, "foggy": 1742, "nathan's": 1743, "cord": 1744, "yard": 1745, "48": 1746, "donut shop": 1747, "chimney": 1748, "calico": 1749, "suits": 1750, "sideways": 1751, "animals": 1752, "black and blue": 1753, "bikini": 1754, "photographer": 1755, "700": 1756, "queen": 1757, "1:00": 1758, "12:05": 1759, "horseback riding": 1760, "awake": 1761, "bunny": 1762, "12:00": 1763, "continental": 1764, "flamingo": 1765, "rye": 1766, "family": 1767, "lots": 1768, "owner": 1769, "stew": 1770, "palm tree": 1771, "cruise ship": 1772, "56": 1773, "design": 1774, "ny": 1775, "far right": 1776, "tire": 1777, "younger": 1778, "biking": 1779, "at&t": 1780, "giants": 1781, "marshmallows": 1782, "caramel": 1783, "polo": 1784, "emirates": 1785, "salon": 1786, "focus": 1787, "on motorcycle": 1788, "magnets": 1789, "mat": 1790, "ivy": 1791, "cakes": 1792, "chrome": 1793, "bob": 1794, "asia": 1795, "graduation": 1796, "cauliflower": 1797, "in snow": 1798, "c": 1799, "rough": 1800, "vacation": 1801, "air": 1802, "windy": 1803, "victoria": 1804, "4:45": 1805, "trick": 1806, "coconut": 1807, "labrador": 1808, "on left": 1809, "yellow and green": 1810, "butterfly": 1811, "fake": 1812, "on napkin": 1813, "bricks": 1814, "wine glasses": 1815, "detroit": 1816, "man's": 1817, "parsley": 1818, "art": 1819, "subway": 1820, "wave": 1821, "placemat": 1822, "hydrant": 1823, "sofa": 1824, "pigeon": 1825, "riding elephant": 1826, "all": 1827, "branches": 1828, "plant": 1829, "to eat": 1830, "zucchini": 1831, "feta": 1832, "neon": 1833, "mouse pad": 1834, "cloud": 1835, "toilet paper": 1836, "pumpkin": 1837, "rowing": 1838, "toronto": 1839, "handicap": 1840, "seeds": 1841, "fly kite": 1842, "chicago": 1843, "marble": 1844, "frame": 1845, "150": 1846, "rocky": 1847, "give way": 1848, "sauce": 1849, "it's not": 1850, "control": 1851, "high chair": 1852, "playstation": 1853, "xbox": 1854, "not likely": 1855, "roman": 1856, "land": 1857, "1:35": 1858, "lifeguard": 1859, "on pizza": 1860, "size": 1861, "bull": 1862, "dandelions": 1863, "equestrian": 1864, "goose": 1865, "8 feet": 1866, "recessed": 1867, "statue": 1868, "index": 1869, "phillies": 1870, "strike": 1871, "mirrors": 1872, "pointing": 1873, "farmer": 1874, "collie": 1875, "motorbike": 1876, "lanes": 1877, "bikes": 1878, "biker": 1879, "arrows": 1880, "gas station": 1881, "logs": 1882, "smaller": 1883, "desert": 1884, "yield": 1885, "flags": 1886, "stool": 1887, "kitten": 1888, "doll": 1889, "daffodils": 1890, "letters": 1891, "dishwasher": 1892, "first base": 1893, "nuts": 1894, "2013": 1895, "persian": 1896, "swim trunks": 1897, "deep": 1898, "o": 1899, "doubles": 1900, "toothpicks": 1901, "in field": 1902, "wristband": 1903, "wheels": 1904, "baking": 1905, "4:15": 1906, "11:00": 1907, "ear": 1908, "2007": 1909, "51": 1910, "chevy": 1911, "using computer": 1912, "frog": 1913, "storm": 1914, "boogie board": 1915, "hungry": 1916, "by window": 1917, "ambulance": 1918, "pigtails": 1919, 
"audi": 1920, "microsoft": 1921, "on man": 1922, "cannot tell": 1923, "stained glass": 1924, "hugging": 1925, "laying down": 1926, "3:00": 1927, "taxi": 1928, "pedestrian": 1929, "landing": 1930, "numbers": 1931, "38": 1932, "stones": 1933, "on tree": 1934, "clocks": 1935, "new": 1936, "picnic": 1937, "fog": 1938, "buffalo": 1939, "under armour": 1940, "cocker spaniel": 1941, "orioles": 1942, "no sign": 1943, "telling time": 1944, "bags": 1945, "golden gate": 1946, "cover": 1947, "castle": 1948, "canoe": 1949, "selfie": 1950, "cream": 1951, "floating": 1952, "indoor": 1953, "antique": 1954, "aluminum": 1955, "silver and black": 1956, "cast iron": 1957, "peas": 1958, "sun hat": 1959, "on right": 1960, "swiss": 1961, "flour": 1962, "under sink": 1963, "fashion": 1964, "fedora": 1965, "shells": 1966, "1 hour": 1967, "puppy": 1968, "in stands": 1969, "not here": 1970, "motor": 1971, "thousands": 1972, "120": 1973, "sail": 1974, "butt": 1975, "mexican": 1976, "dead end": 1977, "paddle": 1978, "bathing suit": 1979, "shop": 1980, "onion rings": 1981, "boxing": 1982, "birthday cake": 1983, "chalk": 1984, "scenery": 1985, "style": 1986, "nissan": 1987, "sticker": 1988, "on rack": 1989, "1 4": 1990, "woman's": 1991, "surprised": 1992, "north face": 1993, "squash": 1994, "not sure": 1995, "email": 1996, "spotted": 1997, "seat": 1998, "himself": 1999, "circles": 2000, "san diego": 2001, "kia": 2002, "mattress": 2003, "obama": 2004, "lamb": 2005, "american flag": 2006, "climbing": 2007, "skull and crossbones": 2008, "roast beef": 2009, "visor": 2010, "herd": 2011, "double": 2012, "52": 2013, "high": 2014, "stagecoach": 2015, "cart": 2016, "feeding": 2017, "eaten": 2018, "cone": 2019, "11:15": 2020, "smoothie": 2021, "golf": 2022, "colorado": 2023, "electronics": 2024, "5:15": 2025, "bowling": 2026, "players": 2027, "ketchup and mustard": 2028, "styrofoam": 2029, "6 feet": 2030, "hawk": 2031, "cheddar": 2032, "12:28": 2033, "arabic": 2034, "12:25": 2035, "12:10": 2036, "shower curtain": 2037, "army": 2038, "salmon": 2039, "10:40": 2040, "hanging": 2041, "whole": 2042, "behind fence": 2043, "bars": 2044, "moss": 2045, "no dog": 2046, "traffic": 2047, "10:25": 2048, "r": 2049, "countryside": 2050, "machine": 2051, "directions": 2052, "cooked": 2053, "aa": 2054, "6:45": 2055, "4 way": 2056, "stripe": 2057, "brand": 2058, "baseball player": 2059, "bunk": 2060, "coleslaw": 2061, "fishing boat": 2062, "at table": 2063, "europe": 2064, "dead": 2065, "arch": 2066, "scrambled": 2067, "clothing": 2068, "closet": 2069, "egg": 2070, "suitcases": 2071, "indoors": 2072, "coffee pot": 2073, "tires": 2074, "lilies": 2075, "cafe": 2076, "9:35": 2077, "teal": 2078, "toothpaste": 2079, "in background": 2080, "tarmac": 2081, "painted": 2082, "sunset": 2083, "orange and yellow": 2084, "oar": 2085, "peaches": 2086, "zebra and giraffe": 2087, "ladybug": 2088, "20 ft": 2089, "sesame seeds": 2090, "hills": 2091, "2:30": 2092, "stucco": 2093, "tail": 2094, "couple": 2095, "kawasaki": 2096, "smooth": 2097, "powdered sugar": 2098, "pedestrian crossing": 2099, "french fries": 2100, "picnic table": 2101, "teeth": 2102, "ribbon": 2103, "saddle": 2104, "15 feet": 2105, "earbuds": 2106, "on train": 2107, "39": 2108, "curb": 2109, "tow": 2110, "shark": 2111, "white and orange": 2112, "6:25": 2113, "gravy": 2114, "fork and spoon": 2115, "pooping": 2116, "curtain": 2117, "lime": 2118, "skull": 2119, "crossing": 2120, "speed limit": 2121, "peacock": 2122, "boredom": 2123, "neck": 2124, "hit": 2125, "dragon": 2126, "tissues": 2127, "basil": 
2128, "waving": 2129, "blue team": 2130, "rectangles": 2131, "helicopter": 2132, "mud": 2133, "us": 2134, "balcony": 2135, "red and gray": 2136, "firefighter": 2137, "sunflower": 2138, "wallpaper": 2139, "best buy": 2140, "11:20": 2141, "public market center": 2142, "seattle": 2143, "bookshelf": 2144, "looking": 2145, "1 inch": 2146, "harley": 2147, "urinal": 2148, "cartoon": 2149, "t shirt and jeans": 2150, "navy": 2151, "fedex": 2152, "rays": 2153, "deck": 2154, "coaster": 2155, "1:20": 2156, "50 feet": 2157, "4:20": 2158, "us open": 2159, "looking at camera": 2160, "600": 2161, "national express": 2162, "white house": 2163, "5:00": 2164, "jp morgan": 2165, "palm trees": 2166, "tub": 2167, "pens": 2168, "soldiers": 2169, "2 people": 2170, "animal": 2171, "speaker": 2172, "hamburger": 2173, "spaghetti": 2174, "green beans": 2175, "it isn't": 2176, "10:20": 2177, "buildings": 2178, "on shelf": 2179, "baseball uniform": 2180, "tiled": 2181, "orange and blue": 2182, "90": 2183, "north america": 2184, "arrow": 2185, "news": 2186, "tropicana": 2187, "formal": 2188, "in grass": 2189, "thumbs up": 2190, "clip": 2191, "gate": 2192, "tennis player": 2193, "lilac": 2194, "pastry": 2195, "nose": 2196, "pacifier": 2197, "11:35": 2198, "different teams": 2199, "cardinals": 2200, "exhaust": 2201, "hauling": 2202, "on tray": 2203, "bagel": 2204, "huge": 2205, "out of focus": 2206, "cook": 2207, "wheat": 2208, "photo": 2209, "ghost": 2210, "sedan": 2211, "qatar": 2212, "zig zag": 2213, "lanyard": 2214, "pink and white": 2215, "sesame": 2216, "space": 2217, "no clock": 2218, "warning": 2219, "snowy": 2220, "tater tots": 2221, "tropical": 2222, "grandfather": 2223, "mac": 2224, "magnet": 2225, "photoshop": 2226, "pajamas": 2227, "350": 2228, "casserole": 2229, "4:55": 2230, "pelican": 2231, "2009": 2232, "clydesdale": 2233, "tow truck": 2234, "belt": 2235, "west": 2236, "omelet": 2237, "heavy": 2238, "crown": 2239, "in corner": 2240, "hexagon": 2241, "mound": 2242, "iris": 2243, "g": 2244, "12:45": 2245, "2:15": 2246, "3:10": 2247, "drawing": 2248, "only": 2249, "little girl": 2250, "washing": 2251, "nokia": 2252, "windsor": 2253, "2 men": 2254, "parmesan cheese": 2255, "on woman": 2256, "freezer": 2257, "icing": 2258, "venice": 2259, "dairy": 2260, "several": 2261, "concentration": 2262, "3:15": 2263, "no smoking": 2264, "kayak": 2265, "frosting": 2266, "jetblue": 2267, "thoroughbred": 2268, "parakeet": 2269, "shoe": 2270, "skeleton": 2271, "britain": 2272, "ties": 2273, "in sink": 2274, "patio": 2275, "bank": 2276, "camouflage": 2277, "privacy": 2278, "bib": 2279, "blue and gray": 2280, "looking out window": 2281, "falling": 2282, "bucket": 2283, "cupcakes": 2284, "throw ball": 2285, "garden": 2286, "almonds": 2287, "ducati": 2288, "ireland": 2289, "plastic wrap": 2290, "starbucks": 2291, "all way": 2292, "bark": 2293, "home plate": 2294, "base": 2295, "dog food": 2296, "toys": 2297, "blue and orange": 2298, "1 in front": 2299, "foot": 2300, "dc": 2301, "california": 2302, "towing": 2303, "cheesecake": 2304, "bushes": 2305, "bow tie": 2306, "millions": 2307, "down street": 2308, "2011": 2309, "police officer": 2310, "windmill": 2311, "taking pictures": 2312, "street name": 2313, "cleaning": 2314, "on pole": 2315, "russia": 2316, "main street": 2317, "catch ball": 2318, "mario": 2319, "pirate": 2320, "track": 2321, "garage": 2322, "7:10": 2323, "they aren't": 2324, "mother and child": 2325, "tents": 2326, "fancy": 2327, "tattoos": 2328, "alcohol": 2329, "2:45": 2330, "wheelchair": 2331, "money": 2332, 
"top hat": 2333, "willow": 2334, "cd": 2335, "brushing hair": 2336, "pancake": 2337, "80": 2338, "listening to music": 2339, "green and red": 2340, "barrier": 2341, "vests": 2342, "hiking": 2343, "tank top": 2344, "lufthansa": 2345, "student": 2346, "menu": 2347, "forehand": 2348, "wii controllers": 2349, "acer": 2350, "wall st": 2351, "hundreds": 2352, "water ski": 2353, "furniture": 2354, "paisley": 2355, "pizza hut": 2356, "baseball game": 2357, "hill": 2358, "prom": 2359, "1 world": 2360, "tiara": 2361, "students": 2362, "information": 2363, "hazy": 2364, "nasa": 2365, "canon": 2366, "bird feeder": 2367, "crane": 2368, "dr pepper": 2369, "logitech": 2370, "2:10": 2371, "all of them": 2372, "utensils": 2373, "telephone": 2374, "converse": 2375, "bone": 2376, "jeep": 2377, "nursing": 2378, "krispy kreme": 2379, "cameraman": 2380, "pee": 2381, "ranch": 2382, "polka dots": 2383, "railroad crossing": 2384, "shirts": 2385, "feeder": 2386, "above toilet": 2387, "unclear": 2388, "below": 2389, "43": 2390, "spoons": 2391, "calendar": 2392, "vaio": 2393, "fox": 2394, "mint": 2395, "after": 2396, "spiderman": 2397, "lg": 2398, "concert": 2399, "on rock": 2400, "fluffy": 2401, "gray and black": 2402, "coats": 2403, "lady": 2404, "dodge": 2405, "easyjet": 2406, "pearl": 2407, "bunt": 2408, "flat screen": 2409, "10:30": 2410, "music": 2411, "polar bears": 2412, "riding horse": 2413, "lift": 2414, "angry": 2415, "cookies": 2416, "3:45": 2417, "buttons": 2418, "hot": 2419, "cute": 2420, "behind": 2421, "dole": 2422, "in motion": 2423, "26": 2424, "pans": 2425, "love": 2426, "winnie pooh": 2427, "pear": 2428, "copyright": 2429, "2 hours": 2430, "snowsuit": 2431, "kissing": 2432, "backhand": 2433, "to get to other side": 2434, "metro": 2435, "swans": 2436, "very fast": 2437, "can't see it": 2438, "nintendo": 2439, "direction": 2440, "waiting": 2441, "mohawk": 2442, "st patrick's day": 2443, "rail": 2444, "hoodie": 2445, "feet": 2446, "swirls": 2447, "muffins": 2448, "4:05": 2449, "106": 2450, "10:55": 2451, "coins": 2452, "mitt": 2453, "game controller": 2454, "room": 2455, "adults": 2456, "urinals": 2457, "cameras": 2458, "marker": 2459, "upright": 2460, "brass": 2461, "sled": 2462, "teacher": 2463, "conductor": 2464, "farmers market": 2465, "toiletries": 2466, "blue and black": 2467, "soccer field": 2468, "banana peel": 2469, "sprite": 2470, "doughnuts": 2471, "bank of america": 2472, "on his face": 2473, "heat": 2474, "emergency": 2475, "ski slope": 2476, "hard": 2477, "41": 2478, "6:00": 2479, "in his hand": 2480, "cluttered": 2481, "dog show": 2482, "on boat": 2483, "grizzly": 2484, "drums": 2485, "not": 2486, "in hand": 2487, "easy": 2488, "400": 2489, "under table": 2490, "d": 2491, "hitting ball": 2492, "photography": 2493, "intersection": 2494, "backwards": 2495, "crocs": 2496, "marina": 2497, "chips": 2498, "bible": 2499, "harry potter": 2500, "hawaii": 2501, "fanta": 2502, "half full": 2503, "carriage": 2504, "curious": 2505, "12:50": 2506, "black white": 2507, "geese": 2508, "pork": 2509, "mailbox": 2510, "l": 2511, "sidecar": 2512, "poop": 2513, "wings": 2514, "penguin": 2515, "to see": 2516, "pocket": 2517, "steps": 2518, "cubs": 2519, "junk": 2520, "deer": 2521, "ottoman": 2522, "salt": 2523, "condiments": 2524, "1:55": 2525, "post": 2526, "bulldog": 2527, "notebook": 2528, "no cat": 2529, "champagne": 2530, "jets": 2531, "knee pads": 2532, "throw frisbee": 2533, "drinks": 2534, "leopard": 2535, "taller": 2536, "cooler": 2537, "bundt": 2538, "monday": 2539, "grape": 2540, "wine tasting": 
2541, "under": 2542, "baskets": 2543, "santa hat": 2544, "chest": 2545, "sewing": 2546, "on car": 2547, "sony ericsson": 2548, "peeing": 2549, "for photo": 2550, "tour": 2551, "few": 2552, "singapore": 2553, "fireman": 2554, "fire extinguisher": 2555, "wildebeest": 2556, "lemons": 2557, "peanuts": 2558, "babies": 2559, "wiimote": 2560, "guitar hero": 2561, "slide": 2562, "stopped": 2563, "library": 2564, "multi colored": 2565, "blue and pink": 2566, "choppy": 2567, "sailing": 2568, "brush": 2569, "grinding": 2570, "jelly": 2571, "dairy queen": 2572, "shaking hands": 2573, "ge": 2574, "tigers": 2575, "tokyo": 2576, "philadelphia": 2577, "ski boots": 2578, "buses": 2579, "11:45": 2580, "collage": 2581, "pink and blue": 2582, "jesus": 2583, "singles": 2584, "iron": 2585, "coffee table": 2586, "2 years": 2587, "don't walk": 2588, "classroom": 2589, "on water": 2590, "potato salad": 2591, "posts": 2592, "harbor": 2593, "residential": 2594, "joshua": 2595, "uk": 2596, "burgers": 2597, "deli": 2598, "kicking": 2599, "lace": 2600, "overalls": 2601, "vehicles": 2602, "ram": 2603, "dancing": 2604, "47": 2605, "shed": 2606, "lid": 2607, "he's not": 2608, "fans": 2609, "amtrak": 2610, "space shuttle": 2611, "ostrich": 2612, "bathtub": 2613, "kneeling": 2614, "2:50": 2615, "mall": 2616, "yellow and orange": 2617, "gazebo": 2618, "wax": 2619, "slow down": 2620, "lays": 2621, "hammer time": 2622, "octopus": 2623, "crib": 2624, "banana split": 2625, "broadway": 2626, "pottery": 2627, "wavy": 2628, "farmers": 2629, "holding phone": 2630, "on phone": 2631, "squirrel": 2632, "wax paper": 2633, "tusks": 2634, "dining": 2635, "packing": 2636, "kangaroo": 2637, "dawn": 2638, "defense": 2639, "powdered": 2640, "thomas": 2641, "budweiser": 2642, "back left": 2643, "stir fry": 2644, "beijing": 2645, "11:10": 2646, "tripod": 2647, "wide": 2648, "slope": 2649, "black and gray": 2650, "planter": 2651, "chili": 2652, "siblings": 2653, "kayaking": 2654, "captivity": 2655, "opaque": 2656, "rack": 2657, "panda": 2658, "doorway": 2659, "wheelie": 2660, "pelicans": 2661, "genetics": 2662, "not in service": 2663, "volvo": 2664, "dachshund": 2665, "v": 2666, "on laptop": 2667, "western": 2668, "gone": 2669, "birthday party": 2670, "parking garage": 2671, "tying tie": 2672, "blueberry": 2673, "scale": 2674, "notes": 2675, "train car": 2676, "man made": 2677, "stability": 2678, "lily": 2679, "lying down": 2680, "pacific": 2681, "high heels": 2682, "pare": 2683, "checkerboard": 2684, "partly cloudy": 2685, "cool": 2686, "n": 2687, "toilets": 2688, "tree branch": 2689, "copper": 2690, "cycling": 2691, "5:50": 2692, "870": 2693, "shopping": 2694, "7:05": 2695, "zipper": 2696, "holding umbrella": 2697, "batman": 2698, "lotion": 2699, "1:25": 2700, "black and brown": 2701, "playing video game": 2702, "girl on right": 2703, "legos": 2704, "drinking water": 2705, "burrito": 2706, "plow": 2707, "jet ski": 2708, "spiral": 2709, "ibm": 2710, "tools": 2711, "flashlight": 2712, "cherries": 2713, "maple leaf": 2714, "mountainous": 2715, "under tree": 2716, "vines": 2717, "sushi": 2718, "baker": 2719, "snake": 2720, "globe": 2721, "target": 2722, "john": 2723, "pomeranian": 2724, "tuxedo": 2725, "hockey": 2726, "sleeve": 2727, "leaning": 2728, "wireless": 2729, "11:05": 2730, "compaq": 2731, "do not enter": 2732, "radish": 2733, "1:05": 2734, "dim": 2735, "advertisement": 2736, "movement": 2737, "model": 2738, "hammock": 2739, "swing": 2740, "sheet": 2741, "google": 2742, "boardwalk": 2743, "right 1": 2744, "haircut": 2745, "ankle": 2746, 
"3:30": 2747, "exit": 2748, "csx": 2749, "tim hortons": 2750, "lego": 2751, "cucumbers": 2752, "angel": 2753, "12:20": 2754, "racquet": 2755, "behind woman": 2756, "potato": 2757, "egg salad": 2758, "controllers": 2759, "recliner": 2760, "upside down": 2761, "mosaic": 2762, "before": 2763, "antenna": 2764, "3:50": 2765, "10:15": 2766, "lion": 2767, "camo": 2768, "fighter": 2769, "silver and red": 2770, "dirt bike": 2771, "playing video games": 2772, "used": 2773, "crates": 2774, "horizontally": 2775, "plunger": 2776, "refrigerators": 2777, "radiator": 2778, "stork": 2779, "in basket": 2780, "cap": 2781, "living": 2782, "married": 2783, "briefcase": 2784, "bottom left": 2785, "30 mph": 2786, "ascending": 2787, "flip phone": 2788, "101": 2789, "11:50": 2790, "gun": 2791, "arizona": 2792, "foam": 2793, "serious": 2794, "y": 2795, "close up": 2796, "pancakes": 2797, "heineken": 2798, "paw": 2799, "cnn": 2800, "comforter": 2801, "sheets": 2802, "8:35": 2803, "driveway": 2804, "fair": 2805, "cleaner": 2806, "1 year": 2807, "delivery": 2808, "commuter": 2809, "apple and banana": 2810, "chase": 2811, "72": 2812, "safe": 2813, "trucks": 2814, "trunks": 2815, "spider": 2816, "64": 2817, "slacks": 2818, "meeting": 2819, "7:00": 2820, "skiers": 2821, "shaved": 2822, "carrot cake": 2823, "holding": 2824, "surfers": 2825, "giraffe and zebra": 2826, "7:45": 2827, "mississippi": 2828, "seaweed": 2829, "black and pink": 2830, "horse racing": 2831, "orchid": 2832, "rv": 2833, "tourist": 2834, "above door": 2835, "leaving": 2836, "pitch": 2837, "crest": 2838, "miami": 2839, "asics": 2840, "flood": 2841, "bus station": 2842, "take off": 2843, "amazon": 2844, "practice": 2845, "entering": 2846, "diesel": 2847, "pm": 2848, "wetsuits": 2849, "remodeling": 2850, "porch": 2851, "7:35": 2852, "tie dye": 2853, "baked": 2854, "life jacket": 2855, "cylinder": 2856, "grilled cheese": 2857, "meatballs": 2858, "paddling": 2859, "banana bread": 2860, "monster": 2861, "smiley face": 2862, "not high": 2863, "keys": 2864, "dreadlocks": 2865, "kitchenaid": 2866, "straight ahead": 2867, "badminton": 2868, "long sleeve": 2869, "sheepdog": 2870, "5:18": 2871, "end": 2872, "on shore": 2873, "scratching": 2874, "oriental": 2875, "5:05": 2876, "alligator": 2877, "city bus": 2878, "purple and white": 2879, "10:50": 2880, "each other": 2881, "weeds": 2882, "tinkerbell": 2883, "rottweiler": 2884, "apartments": 2885, "snowflakes": 2886, "stop light": 2887, "sweatshirt": 2888, "shore": 2889, "bidet": 2890, "switzerland": 2891, "stretching": 2892, "tv stand": 2893, "boundaries": 2894, "65": 2895, "bronze": 2896, "jar": 2897, "middle 1": 2898, "54": 2899, "skate": 2900, "easton": 2901, "turn right": 2902, "raspberries": 2903, "singing": 2904, "on bus": 2905, "carnations": 2906, "descending": 2907, "classic": 2908, "suspenders": 2909, "not long": 2910, "8:50": 2911, "father": 2912, "anniversary": 2913, "hsbc": 2914, "very long": 2915, "space needle": 2916, "skatepark": 2917, "fruit salad": 2918, "kenmore": 2919, "no water": 2920, "8:05": 2921, "db": 2922, "baby's breath": 2923, "shelter": 2924, "1980": 2925, "no left turn": 2926, "washington monument": 2927, "ham and cheese": 2928, "10 inches": 2929, "8:55": 2930, "savory": 2931, "6:35": 2932, "indians": 2933, "9:05": 2934, "fires": 2935, "pipes": 2936, "donkey": 2937, "cds": 2938, "mitsubishi": 2939, "tell time": 2940, "outfield": 2941, "christian": 2942, "puma": 2943, "parking meters": 2944, "cranes": 2945, "flip": 2946, "wine bottle": 2947, "stadium": 2948, "mouthwash": 2949, "heinz": 
2950, "distance": 2951, "macaroni": 2952, "on plane": 2953, "triumph": 2954, "more": 2955, "4:50": 2956, "single engine": 2957, "disney": 2958, "on stove": 2959, "shih tzu": 2960, "fried": 2961, "to hit ball": 2962, "in her hand": 2963, "sunrise": 2964, "2nd": 2965, "elmo": 2966, "kite string": 2967, "suzuki": 2968, "traffic lights": 2969, "blt": 2970, "i": 2971, "hitting": 2972, "htc": 2973, "healthy": 2974, "current": 2975, "star alliance": 2976, "stomach": 2977, "watch tv": 2978, "tulip": 2979, "5:10": 2980, "right side": 2981, "4:40": 2982, "ginger": 2983, "on sign": 2984, "cushion": 2985, "5:30": 2986, "learning": 2987, "pencil": 2988, "maroon": 2989, "food processor": 2990, "5:40": 2991, "dog bed": 2992, "michigan": 2993, "close": 2994, "license plate": 2995, "crows": 2996, "right hand": 2997, "normal": 2998, "green and brown": 2999, "1.00": 3000, "000": 3001, "1:40": 3002, "wing": 3003, "american airlines": 3004, "kodak": 3005, "mural": 3006, "sniffing": 3007, "1:15": 3008, "behind bench": 3009, "cardinal": 3010, "no light": 3011, "warmth": 3012, "paved": 3013, "skyscrapers": 3014, "swinging bat": 3015, "watermark": 3016, "in cup": 3017, "pizza box": 3018, "dough": 3019, "hiding": 3020, "goal": 3021, "no plate": 3022, "shower head": 3023, "ripe": 3024, "1:10": 3025, "1 in back": 3026, "older": 3027, "nest": 3028, "multiple": 3029, "cinnamon": 3030, "bin": 3031, "new orleans": 3032, "colored": 3033, "enclosure": 3034, "bride": 3035, "on dresser": 3036, "star wars": 3037, "in back": 3038, "triangles": 3039, "over easy": 3040, "cilantro": 3041, "statues": 3042, "sticks": 3043, "formica": 3044, "roundabout": 3045, "bowls": 3046, "ahead": 3047, "years": 3048, "drain": 3049, "veggies": 3050, "no shirt": 3051, "taking photo": 3052, "tugboat": 3053, "broke": 3054, "59": 3055, "cadillac": 3056, "prince": 3057, "left side": 3058, "1 in middle": 3059, "10:45": 3060, "drying": 3061, "11:25": 3062, "silk": 3063, "conference room": 3064, "buoys": 3065, "pockets": 3066, "daffodil": 3067, "6:40": 3068, "walgreens": 3069, "4 ft": 3070, "6:05": 3071, "virgin atlantic": 3072, "12:40": 3073, "digital": 3074, "ups": 3075, "westjet": 3076, "bikers": 3077, "us air force": 3078, "limes": 3079, "comcast": 3080, "dip": 3081, "7:55": 3082, "man in middle": 3083, "bus driver": 3084, "soon": 3085, "futon": 3086, "selling": 3087, "braid": 3088, "mariners": 3089, "wisconsin": 3090, "99": 3091, "citizen": 3092, "broccoli and carrots": 3093, "grocery store": 3094, "us airways": 3095, "49": 3096, "bored": 3097, "red velvet": 3098, "hotel room": 3099, "qantas": 3100, "tam": 3101, "korean air": 3102, "10:35": 3103, "whirlpool": 3104, "coffee cup": 3105, "hilly": 3106, "9:12": 3107, "whipped cream": 3108, "video": 3109, "finger": 3110, "competition": 3111, "hollywood": 3112, "sas": 3113, "backward": 3114, "beads": 3115, "cosmo": 3116, "10:08": 3117, "jal": 3118, "6:30": 3119, "100 year party ct": 3120, "hispanic": 3121, "in cabbage town": 3122, "opponent": 3123, "woodpecker": 3124, "visilab": 3125, "mt airy": 3126, "crosstown": 3127, "freightliner": 3128} \ No newline at end of file diff --git a/uniter_model/model/__init__.py b/uniter_model/model/__init__.py new file mode 100644 index 0000000..83a0b46 --- /dev/null +++ b/uniter_model/model/__init__.py @@ -0,0 +1,8 @@ +from .model import UniterForPretraining, UniterConfig +from .vqa import UniterForVisualQuestionAnswering +from .nlvr2 import (UniterForNlvr2PairedAttn, UniterForNlvr2Paired, + UniterForNlvr2Triplet) +from .itm import (UniterForImageTextRetrieval, + 
UniterForImageTextRetrievalFast, +                  UniterForImageTextRetrievalHardNeg) +from .ve import UniterForVisualEntailment
diff --git a/uniter_model/model/__pycache__/__init__.cpython-38.pyc b/uniter_model/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93001cbf03b0a0ad1561435487d791e47605b8ed
GIT binary patch
literal 630
[base85-encoded compiled bytecode omitted]
diff --git a/uniter_model/model/__pycache__/itm.cpython-38.pyc b/uniter_model/model/__pycache__/itm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f130db6a7467d2f75df98ec475b884ee14ce7e2
GIT binary patch
literal 6182
[base85-encoded compiled bytecode omitted]
diff --git a/uniter_model/model/__pycache__/layer.cpython-38.pyc b/uniter_model/model/__pycache__/layer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a54c73125c15978fee84221da6d36a92022febc
GIT binary patch
literal 8270
[base85-encoded compiled bytecode omitted]
diff --git a/uniter_model/model/__pycache__/model.cpython-38.pyc b/uniter_model/model/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e29d65c06067f83f5170dac164db42257a999bc
GIT binary patch
literal 20954
[base85-encoded compiled bytecode omitted]
diff --git a/uniter_model/model/__pycache__/nlvr2.cpython-38.pyc b/uniter_model/model/__pycache__/nlvr2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8571634ffd1cfce25fa7bc82f2c5f6ba99af1297
GIT binary patch
literal 5848
[base85-encoded compiled bytecode omitted]
zt%2HjCfcMgt9h}Gqx*)dS@F-hxS3Yk)oc|^3*Y3xw9aSe!L(KP`p3R!_k8<#_0#w!AJh&Ihz&3M^zxqr9kUM$$JH^Q60n zcK)5mQJb_D$AWTBP9O>@J>|r#pgxBQ;!634Jp$F7O9}R*Muf}S>) m%vD+^3L%}>^D_BO^1g(=BbgR`#wo*IHN9-q8nu@f7yk>PXk(3Ngj&F&7$z|wZh%$TO@@ez-5}UqAq=6_XsU_thzX)aykI z<(GeWar6&ppHk!X;-c{mzVc@ff{mHr!kO_gNA1p>xjXha6Z^s!!E-+LMJQU&**K7V z6KqRes`3!FGlapH8Min5I5GpW|`BF$78 z7mwrBLA;J{9i@75EX6R2QGRppAbu}F$Kj)$zP9@9yG;{4@2wvxj$>aqBwnHnUBef~ z71`Xla#3DT?=VypD6{u20YLHmsZxb9UYVUr!ylM-UMOHdtze$&`Eq93y1}EA!_e@R z;cL~RIf7L}&f{R+)KPQ1q|D~y&Q7PJzF_Ja{Qv#y(MJUO=ro&Q{?Q#>e0(hB=y-WF z0_;r`n?9D)3HGb9 zRnja?TP*h>m^t4|Hlbi#7NU)h&r;`iP#Ws|)~$K%)PTNnL5MdkA)FJp zanOx+-JdzqI{mr?@Xwp=aJTqJ2}dh@h=B1KZlnVe_?By*ijm&LL)-m^@8 zY}|4=9|nrzY50-hbHj7v(I_lLN3ak#VnBPi1wnx>zXrjaJ>KO5?!5HqL(5@ShYTF+ zJ86^OuFwqBwt$=fs#k!t>!|K%e&ST^Dgp#9J`XT>xY2a~_)S)Kcm16&c=&Zs^G$%6 zLb|roeU5p3F&Dy&*2(xL^-nLx$ z1J?A_SCAwA*SfExDm>$d?9sFOYPDDO2CVL5KOH(B&g0@VQPr{@Z`W4^s95zk-KqfJ4XBok zinsxWt`_PFY26(QA0t5|#?giGk$I+4eVs<^K@10m-yi#{WF~W}7O5Ef>!!7)mT8|9 zS+1?>9Z8*V|GM(X1)Hi9Sw@7>o!uVx2ou`I>X!yHh6hqjylheE7*mvYAeg%ccDKN6BpSC1uH)zK^L_U^zwUH-ixdB2{wC5;m#ZIQ#Ah^& zvA_Q0;`^v<^d^*azaEheF=GqQXB*is0_l*|D!&TO+PHpg z^hH&N`nUXy*PtVa(>i_Yk~oJu&OD??ia2yhA4?a?T^Iks1J$H5BF+Q#RmdT?+!%Ij zini=f-+{eh&^JgsAYlnaL8P3h)5)>TAPqr!Ce$mP;1w}pHlH>MS%@6g0!4k7ruiNT zioWsa`A{RMV>_DIWqqNjACO0j9y0ou@?|1iFPc9s$?lOUZ*7(JQaIBq+<; z0GKc@lBvS+C^mND*LF)!u%g6jT5sd&T+ZZOx~L_YJhuhDM!i-i;1OLm#$NsxT6>8r literal 0 HcmV?d00001 diff --git a/uniter_model/model/__pycache__/ve.cpython-38.pyc b/uniter_model/model/__pycache__/ve.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebb34202150f444b879b494bec2f7c388f8d401 GIT binary patch literal 715 zcmZ`%O^?$s5VhmN*gZIrLaoL}I*b$njdCG3B*e`==9lmv`pPL>xgii#&w9b(5_(kCRsEWJd{w*sevDvz`Q6j;&jCWe zSmomIK=~9%eFco63^B|cflDWIu$5immR{y{vM2m9$U=~P7Ko_qWxY-gMO-FXQl?po z5%>1`S^pkl5$l~G*5k>so2{{U8x51iwfGK~dN8Tw;NgSYgR*8^#LwW*XoyY1_=Gk> zvm)1^I#mTco-|dii@KuX8D5_A?%ln=J9Ny;yH%mN{6g{_NsEfJ&+T(2SuS@jwM z!x(9&#jSJfw1~LcJN1Er7Vm|Y7`}$i9334a-COC_fhW%Dls(b9XU>@i_v03wCN0`a z?Yj(tOWxKq2mUv`^6UDyKmEA>2?X2ab2aQ6Uo~?sO+-im5hg@NaL#yRNS(BHns}S< zH@wmX6(+b3PE%=;S-}{uh$3h@i+UF)Qz4u+W2%n7l$V8-ag>P zbf(qNGoIpNVuHM`CdJf*ML8v`fSXlQ657T1DK$QFxv|$S+$Kog9QO~ZQI&JR^AKApa zZ46SBb&e8lvM3Xnqzrc64oRblI{ z(;^F#38xz11QP5mwU}yMk&>#f;l`Be8JB=%#vg40%C9ehNT?8mS2;8dMudf`GNGn& z%C$0HRaYF&_9SOb*MEkAwB;gg22C)pYr&;_$trl$48}!8~d)6OQozF~~*ClwwVOJ?RfF1dUWp7&cmSt~S_6{&}Np;1R^Hl-$yPvav z)Ma%LmDiFhdBfg`Y`t}F|C^#NG2wZT?Xvm+7{Y5fdhTNQ=rp$Od4E?WXO3()NPx2N$gA;(LzyrE|AzsVHlg=cW#H4xr0U;QS5RNC#N} literal 0 HcmV?d00001 diff --git a/uniter_model/model/attention.py b/uniter_model/model/attention.py new file mode 100644 index 0000000..8c320d5 --- /dev/null +++ b/uniter_model/model/attention.py @@ -0,0 +1,401 @@ +""" +copy multi-head attention code from pytorch (newer version) +""" +import warnings + +import torch +from torch.nn import Module, Parameter, Linear +from torch.nn.init import xavier_normal_, xavier_uniform_, constant_ +from torch.nn.functional import linear, softmax, dropout + + +def multi_head_attention_forward(query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + embed_dim_to_check, # type: int + num_heads, # type: int + in_proj_weight, # type: Tensor + in_proj_bias, # type: Tensor + bias_k, # type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + 
attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None # type: Optional[Tensor] + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: mask that prevents attention to certain positions. This is an additive mask + (i.e. the values will be added to the attention layer). + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in differnt forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
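+
+    Example (illustrative sketch only; the sizes L=S=5, N=2, E=16,
+    num_heads=4 and the randomly initialised projection weights below are
+    assumptions for demonstration, not part of the original code)::
+
+        >>> import torch
+        >>> L, N, E, H = 5, 2, 16, 4
+        >>> x = torch.randn(L, N, E)
+        >>> out, attn = multi_head_attention_forward(
+        ...     x, x, x, E, H,
+        ...     torch.randn(3 * E, E), torch.zeros(3 * E),  # in_proj weight / bias
+        ...     None, None, False, 0.0,  # bias_k, bias_v, add_zero_attn, dropout_p
+        ...     torch.randn(E, E), torch.zeros(E),  # out_proj weight / bias
+        ...     training=False)
+        >>> out.shape, attn.shape
+        (torch.Size([5, 2, 16]), torch.Size([2, 5, 5]))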
+ """ + + qkv_same = torch.equal(query, key) and torch.equal(key, value) + kv_same = torch.equal(key, value) + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert key.size() == value.size() + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if use_separate_proj_weight is not True: + if qkv_same: + # self-attention + q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif kv_same: + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = linear(query, q_proj_weight_non_opt, in_proj_bias) + k = linear(key, k_proj_weight_non_opt, in_proj_bias) + v = linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, + torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + else: + assert static_k is None, "bias cannot be added to static key." 
+ assert static_v is None, "bias cannot be added to static value." + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = softmax( + attn_output_weights, dim=-1) + attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in key. Default: None. 
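+            (vdim is the feature size of the value inputs; both kdim and
+            vdim are projected back to embed_dim internally.)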
+ + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: mask that prevents attention to certain positions. This is an additive mask + (i.e. the values will be added to the attention layer). + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. 
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + if not hasattr(self, '_qkv_same_embed_dim'): + warnings.warn('A new version of MultiheadAttention module has been implemented. \ + Please re-train your model with the new module', + UserWarning) + + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) diff --git a/uniter_model/model/gqa.py b/uniter_model/model/gqa.py new file mode 100644 index 0000000..b0d11f6 --- /dev/null +++ b/uniter_model/model/gqa.py @@ -0,0 +1,133 @@ +""" +Bert for VCR model +""" +from torch import nn +from torch.nn import functional as F +from pytorch_pretrained_bert.modeling import ( + BertOnlyMLMHead) +from .model import (BertForImageTextPretraining, + _get_image_hidden, + mask_img_feat, + RegionFeatureRegression, + mask_img_feat_for_mrc, + RegionClassification) +import torch +import random + + +class BertForImageTextPretrainingForGQA(BertForImageTextPretraining): + def init_type_embedding(self): + new_emb = nn.Embedding(3, self.bert.config.hidden_size) + new_emb.apply(self.init_bert_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] + new_emb.weight.data[2, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def forward(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, labels, task, compute_loss=True): + if task == 'mlm': + txt_labels = labels + return self.forward_mlm(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, txt_labels, compute_loss) + elif task == 'mrm': + img_mask = labels + return self.forward_mrm(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_mask, compute_loss) + elif task.startswith('mrc'): + img_mask, mrc_label_target = labels + return self.forward_mrc(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_mask, + mrc_label_target, task, compute_loss) + else: + raise ValueError('invalid task') + + # MLM + def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, txt_labels, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the text part + sequence_output = 
sequence_output[:, :input_ids.size(1), :] + # only compute masked tokens for better efficiency + prediction_scores = self.masked_compute_scores( + sequence_output, txt_labels != -1) + if self.vocab_pad: + prediction_scores = prediction_scores[:, :-self.vocab_pad] + + if compute_loss: + masked_lm_loss = F.cross_entropy(prediction_scores, + txt_labels[txt_labels != -1], + reduction='none') + return masked_lm_loss + else: + return prediction_scores + + # MRM + def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_masks, compute_loss=True): + img_feat, feat_targets = mask_img_feat(img_feat, img_masks) + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the text part + sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) + # only compute masked tokens for better efficiency + prediction_feat = self.masked_compute_feat( + sequence_output, img_masks) + + if compute_loss: + mrm_loss = F.mse_loss(prediction_feat, feat_targets, + reduction='none') + return mrm_loss + else: + return prediction_feat + + # MRC + def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_masks, + label_targets, task, compute_loss=True): + img_feat = mask_img_feat_for_mrc(img_feat, img_masks) + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the image part + sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) + # only compute masked tokens for better efficiency + prediction_soft_label = self.masked_predict_labels( + sequence_output, img_masks) + + if compute_loss: + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + mrc_loss = F.kl_div( + prediction_soft_label, label_targets, reduction='none') + else: + label_targets = torch.max( + label_targets, -1)[1] # argmax + mrc_loss = F.cross_entropy( + prediction_soft_label, label_targets, + ignore_index=0, reduction='none') + return mrc_loss + else: + return prediction_soft_label diff --git a/uniter_model/model/itm.py b/uniter_model/model/itm.py new file mode 100644 index 0000000..358f924 --- /dev/null +++ b/uniter_model/model/itm.py @@ -0,0 +1,195 @@ +""" +UNITER for ITM model +""" +import copy +from collections import defaultdict + +import torch +from torch import nn +from .model import UniterPreTrainedModel, UniterModel + + +class UniterForImageTextRetrieval(UniterPreTrainedModel): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.rank_output = nn.Linear(config.hidden_size, 1) + self.margin = margin + self.apply(self.init_weights) + + def init_output(self): + """ need to be called after from pretrained """ + self.rank_output.weight.data = self.itm_output.weight.data[1:, :] + self.rank_output.bias.data = self.itm_output.bias.data[1:] + + def forward(self, batch, compute_loss=True): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = 
batch['attn_masks'] + gather_index = batch['gather_index'] + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False) + pooled_output = self.bert.pooler(sequence_output) + rank_scores = self.rank_output(pooled_output) + + if compute_loss: + # triplet loss + rank_scores_sigmoid = torch.sigmoid(rank_scores) + sample_size = batch['sample_size'] + scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) + pos = scores[:, :1] + neg = scores[:, 1:] + rank_loss = torch.clamp(self.margin + neg - pos, 0) + return rank_loss + else: + return rank_scores + + +class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2, hard_size=16): + super().__init__(config, img_dim, margin) + self.hard_size = hard_size + + def forward(self, batch, sample_from='t', compute_loss=True): + # expect same input_ids for all pairs + batch_size = batch['attn_masks'].size(0) + input_ids = batch['input_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + if sample_from == 't': + if input_ids.size(0) == 1: + batch['input_ids'] = input_ids.expand(batch_size, -1) + elif sample_from == 'i': + if img_feat.size(0) == 1: + batch['img_feat'] = img_feat.expand(batch_size, -1, -1) + if img_pos_feat.size(0) == 1: + batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) + else: + raise ValueError() + + if self.training and compute_loss: + with torch.no_grad(): + self.eval() + scores = super().forward(batch, compute_loss=False) + hard_batch = self._get_hard_batch(batch, scores, sample_from) + self.train() + return super().forward(hard_batch, compute_loss=True) + else: + return super().forward(batch, compute_loss) + + def _get_hard_batch(self, batch, scores, sample_from='t'): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + hard_indices = scores.squeeze(-1)[1:].topk( + self.hard_size, sorted=False)[1] + 1 + indices = torch.cat([torch.zeros(1, dtype=torch.long, + device=hard_indices.device), + hard_indices]) + + attention_mask = attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:self.hard_size+1] + + if sample_from == 't': + # cut to minimum padding + max_len = attention_mask.sum(dim=1).max().item() + max_i = max_len - input_ids.size(1) + attention_mask = attention_mask[:, :max_len] + gather_index = gather_index[:, :max_len] + img_feat = img_feat.index_select(0, indices)[:, :max_i, :] + img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] + # expect same input_ids for all pairs + input_ids = input_ids[:self.hard_size+1] + elif sample_from == 'i': + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:self.hard_size+1] + img_pos_feat = img_pos_feat[:self.hard_size+1] + else: + raise ValueError() + + hard_batch['input_ids'] = input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + 
hard_batch['gather_index'] = gather_index + + return hard_batch + + +class UniterForImageTextRetrievalFast(UniterPreTrainedModel): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + config_img = copy.deepcopy(config) + config_img.num_hidden_layers = config_img.num_hidden_layers_img + self.img_bert = UniterModel(config_img, img_dim) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.rank_output = nn.Linear(config.hidden_size, 1) + self.margin = margin + self.apply(self.init_weights) + + def init_output(self): + """ need to be called after from pretrained """ + self.rank_output.weight.data = self.itm_output.weight.data[1:, :] + self.rank_output.bias.data = self.itm_output.bias.data[1:] + + def forward(self, batch, compute_loss=True): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask_text = batch['attn_masks_text'] + attention_mask_img = batch['attn_masks_img'] + gather_index = batch['gather_index'] + sequence_output_text = self.bert(input_ids, position_ids, + None, img_pos_feat, + attention_mask_text, gather_index, + output_all_encoded_layers=False) + pooled_output_text = self.bert.pooler(sequence_output_text) + + sequence_output_img = self.img_bert(None, position_ids, + img_feat, img_pos_feat, + attention_mask_img, gather_index, + output_all_encoded_layers=False) + pooled_output_img = self.img_bert.pooler(sequence_output_img) + + # rank_scores = (pooled_output_text * pooled_output_img).sum(-1) + # rank_scores = self.rank_output(pooled_output) + rank_scores = torch.nn.CosineSimilarity()(pooled_output_text, pooled_output_img) + + if compute_loss: + # triplet loss + rank_scores_sigmoid = torch.sigmoid(rank_scores) + sample_size = batch['sample_size'] + scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) + pos = scores[:, :1] + neg = scores[:, 1:] + rank_loss = torch.clamp(self.margin + neg - pos, 0) + return rank_loss + else: + return rank_scores + diff --git a/uniter_model/model/layer.py b/uniter_model/model/layer.py new file mode 100644 index 0000000..fa3b3cc --- /dev/null +++ b/uniter_model/model/layer.py @@ -0,0 +1,235 @@ +""" +BERT layers from the huggingface implementation +(https://github.com/huggingface/transformers) +""" +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math + +import torch +from torch import nn +#from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm + +BertLayerNorm = torch.nn.LayerNorm + + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. 
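+    This is the exact (erf-based) form, 0.5 * x * (1.0 + erf(x / sqrt(2))),
+    matching the return statement below.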
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class GELU(nn.Module): + def forward(self, input_): + output = gelu(input_) + return output + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
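+        # (for BERT/UNITER-style inputs, the first position is typically the [CLS] token)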
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter( + torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores diff --git a/uniter_model/model/model.py b/uniter_model/model/model.py new file mode 100644 index 0000000..f2a5f02 --- /dev/null +++ b/uniter_model/model/model.py @@ -0,0 +1,701 @@ +""" +Pytorch modules +""" +from collections import defaultdict +import copy +import json +import logging +from io import open + +import torch +from torch import nn +from torch.nn import functional as F +# from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm +from torch.nn import LayerNorm + +from .layer import GELU, BertLayer, BertPooler, BertOnlyMLMHead +from .ot import optimal_transport_dist + + +logger = logging.getLogger(__name__) + + +class UniterConfig(object): + """Configuration class to store the configuration of a `UniterModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_hidden_layers_img=1, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs UniterConfig. + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in + `UniterModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer + encoder. + num_attention_heads: Number of attention heads for each attention + layer in the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e. + feed-forward) layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) + in the encoder and pooler. 
If string, "gelu", "relu" and + "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully + connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this + model might ever be used with. Typically set this to something + large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed + into `UniterModel`. + initializer_range: The sttdev of the truncated_normal_initializer + for initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open(vocab_size_or_config_json_file, + "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_hidden_layers_img = num_hidden_layers_img + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError("First argument must be either a vocabulary size " + "(int) or the path to a pretrained model config " + "file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `UniterConfig` from a + Python dictionary of parameters.""" + config = UniterConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `UniterConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +class UniterPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, UniterConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of " + "class `UniterConfig`. To create a model from a Google " + "pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_weights(self, module): + """ Initialize the weights. 
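+        Linear and Embedding weights are drawn from a normal distribution with
+        std initializer_range; LayerNorm is reset to weight=1.0 / bias=0.0;
+        Linear biases are zeroed.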
+ """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses + # truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, + std=self.config.initializer_range) + elif isinstance(module, LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, config_file, state_dict, *inputs, **kwargs): + """ + Instantiate a UniterPreTrainedModel from a pre-trained model file or a + pytorch state dict. + Params: + config_file: config json file + state_dict: an state dictionnary + *inputs, **kwargs: additional input for the specific Uniter class + """ + # Load config + config = UniterConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = ({} if metadata is None + else metadata.get(prefix[:-1], {})) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, + unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + start_prefix = '' + if not hasattr(model, 'bert') and any(s.startswith('bert.') + for s in state_dict.keys()): + start_prefix = 'bert.' 
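+        # If the checkpoint keys are prefixed with 'bert.' but this model has
+        # no .bert submodule, load with that prefix so the keys line up.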
+ load(model, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from " + "pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in " + "{}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for ' + '{}:\n\t{}'.format( + model.__class__.__name__, + "\n\t".join(error_msgs))) + return model + + +class UniterTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, position_ids, token_type_ids=None): + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = (words_embeddings + + position_embeddings + + token_type_embeddings) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class UniterImageEmbeddings(nn.Module): + def __init__(self, config, img_dim): + super().__init__() + self.img_linear = nn.Linear(img_dim, config.hidden_size) + self.img_layer_norm = LayerNorm(config.hidden_size, eps=1e-12) + self.pos_layer_norm = LayerNorm(config.hidden_size, eps=1e-12) + self.pos_linear = nn.Linear(7, config.hidden_size) + self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0) + + # tf naming convention for layer norm + self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, img_feat, img_pos_feat, type_embeddings, img_masks=None): + if img_masks is not None: + self.mask_embedding.weight.data[0, :].fill_(0) + mask = self.mask_embedding(img_masks.long()) + img_feat = img_feat + mask + + transformed_im = self.img_layer_norm(self.img_linear(img_feat)) + transformed_pos = self.pos_layer_norm(self.pos_linear(img_pos_feat)) + embeddings = transformed_im + transformed_pos + type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class UniterEncoder(nn.Module): + def __init__(self, config): + super().__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) + for _ in range(config.num_hidden_layers)]) + + def forward(self, input_, attention_mask, + output_all_encoded_layers=True): + all_encoder_layers = [] + hidden_states = input_ + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +def pad_tensor_to_mul(tensor, dim=0, mul=8): + """ pad tensor to multiples (8 for tensor cores) """ + # TODO find out whether this helps speed 
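+    # NOTE: the early return below disables padding entirely (the tensor is
+    # handed back unchanged with n_pad=0); the code after it never runs.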
+ return tensor, 0 + t_size = list(tensor.size()) + n_pad = mul - t_size[dim] % mul + if n_pad == mul: + n_pad = 0 + padded_tensor = tensor + else: + t_size[dim] = n_pad + pad = torch.zeros(*t_size, dtype=tensor.dtype, device=tensor.device) + padded_tensor = torch.cat([tensor, pad], dim=dim) + return padded_tensor, n_pad + + +class UniterModel(UniterPreTrainedModel): + """ Modification for Joint Vision-Language Encoding + """ + def __init__(self, config, img_dim): + super().__init__(config) + self.embeddings = UniterTextEmbeddings(config) + self.img_embeddings = UniterImageEmbeddings(config, img_dim) + self.encoder = UniterEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_weights) + + def _compute_txt_embeddings(self, input_ids, position_ids, + txt_type_ids=None): + output = self.embeddings(input_ids, position_ids, txt_type_ids) + return output + + def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None, + img_type_ids=None): + if img_type_ids is None: + img_type_ids = torch.ones_like(img_feat[:, :, 0].long()) + img_type_embeddings = self.embeddings.token_type_embeddings( + img_type_ids) + output = self.img_embeddings(img_feat, img_pos_feat, + img_type_embeddings, img_masks) + return output + + def _compute_img_txt_embeddings(self, input_ids, position_ids, + img_feat, img_pos_feat, + gather_index, img_masks=None, + txt_type_ids=None, img_type_ids=None): + txt_emb = self._compute_txt_embeddings( + input_ids, position_ids, txt_type_ids) + img_emb = self._compute_img_embeddings( + img_feat, img_pos_feat, img_masks, img_type_ids) + # align back to most compact input + if gather_index is None: + embedding_output = torch.cat([txt_emb, img_emb], dim=1) + else: + gather_index = gather_index.unsqueeze(-1).expand( + -1, -1, self.config.hidden_size) + embedding_output = torch.gather(torch.cat([txt_emb, img_emb], dim=1), + dim=1, index=gather_index) + return embedding_output + + def forward(self, input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index=None, img_masks=None, + output_all_encoded_layers=True, + txt_type_ids=None, img_type_ids=None): + # compute self-attention mask + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # embedding layer + if input_ids is None: + # image only + embedding_output = self._compute_img_embeddings( + img_feat, img_pos_feat, img_masks, img_type_ids) + elif img_feat is None: + # text only + embedding_output = self._compute_txt_embeddings( + input_ids, position_ids, txt_type_ids) + else: + embedding_output = self._compute_img_txt_embeddings( + input_ids, position_ids, + img_feat, img_pos_feat, + gather_index, img_masks, txt_type_ids, img_type_ids) + + encoded_layers = self.encoder( + embedding_output, extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers + + +class RegionFeatureRegression(nn.Module): + def __init__(self, hidden_size, feat_dim, img_linear_weight): + super().__init__() + self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), + GELU(), + LayerNorm(hidden_size, eps=1e-12)) + + self.weight = img_linear_weight + self.bias = nn.Parameter(torch.zeros(feat_dim)) + + def forward(self, input_): + hidden = self.net(input_) + output = F.linear(hidden, 
self.weight.t(), self.bias) + return output + + +class RegionClassification(nn.Module): + def __init__(self, hidden_size, label_dim): + super().__init__() + self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), + GELU(), + LayerNorm(hidden_size, eps=1e-12), + nn.Linear(hidden_size, label_dim)) + + def forward(self, input_): + output = self.net(input_) + return output + + +class UniterForPretraining(UniterPreTrainedModel): + """ MLM + MRM """ + def __init__(self, config, img_dim, img_label_dim, + nce_temp=1, ot_pos_only=False): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.cls = BertOnlyMLMHead( + config, self.bert.embeddings.word_embeddings.weight) + self.feat_regress = RegionFeatureRegression( + config.hidden_size, img_dim, + self.bert.img_embeddings.img_linear.weight) + self.region_classifier = RegionClassification( + config.hidden_size, img_label_dim) + self.itm_output = nn.Linear(config.hidden_size, 2) + ''' + self.nce_output = BertPredictionHeadTransform(config) + self.nce_output = nn.Sequential(BertPredictionHeadTransform(config), + nn.Linear(config.hidden_size, img_dim)) + self.nce_norm = LayerNorm(config.hidden_size, eps=1e-12) + self.nce_temp = nce_temp # temperature + ''' + self.ot_pos_only = ot_pos_only + self.apply(self.init_weights) + self.vocab_pad = 0 + + def pad_vocab(self): + # FIXME better padding after integrating huggingface + emb_w = self.bert.embeddings.word_embeddings.weight.data + padded_emb_w, n_pad = pad_tensor_to_mul(emb_w) + padded_emb_w = nn.Parameter(padded_emb_w) + self.bert.embeddings.word_embeddings.weight = padded_emb_w + self.cls.predictions.decoder.weight = padded_emb_w + self.vocab_pad = n_pad + + def forward(self, batch, task, compute_loss=True): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + if task == 'mlm': + txt_labels = batch['txt_labels'] + return self.forward_mlm(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + txt_labels, compute_loss) + elif task == 'mrfr': + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + mrfr_feat_target = batch['feat_targets'] + return self.forward_mrfr(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + img_masks, img_mask_tgt, + mrfr_feat_target, compute_loss) + elif task == 'mrm-nce': + raise NotImplementedError('nce does not work') + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + img_masks_in = batch['img_masks_in'] + feat_target = batch['feat_targets'] + neg_feats = batch['neg_feats'] + return self.forward_mrm_nce(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + img_masks_in, img_masks, img_mask_tgt, + feat_target, neg_feats, compute_loss) + elif task == 'itm': + targets = batch['targets'] + ot_inputs = batch['ot_inputs'] + return self.forward_itm(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + targets, ot_inputs, compute_loss) + elif task.startswith('mrc'): + img_mask_tgt = batch['img_mask_tgt'] + img_masks = batch['img_masks'] + mrc_label_target = batch['label_targets'] + return self.forward_mrc(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + img_masks, img_mask_tgt, + mrc_label_target, task, compute_loss) + else: + raise ValueError('invalid 
task') + + # MLM + def forward_mlm(self, input_ids, position_ids, img_feat, img_pos_feat, + attention_mask, gather_index, + txt_labels, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False) + # get only the text part + sequence_output = sequence_output[:, :input_ids.size(1), :] + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + txt_labels != -1) + prediction_scores = self._pad_layer_unpad(masked_output, self.cls) + if self.vocab_pad: + prediction_scores = prediction_scores[:, :-self.vocab_pad] + + masked_lm_loss = F.cross_entropy(prediction_scores, + txt_labels[txt_labels != -1], + reduction='none') + return masked_lm_loss, prediction_scores + + def _compute_masked_hidden(self, hidden, mask): + """ get only the masked region (don't compute unnecessary hiddens) """ + mask = mask.unsqueeze(-1).expand_as(hidden) + hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1)) + return hidden_masked + + def _pad_layer_unpad(self, input_, layer): + input_, n_pad = pad_tensor_to_mul(input_) + output = layer(input_) + if n_pad: + output = output[:-n_pad, :] + return output + + def mlm_eval(self, input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, gather_tgt): + raise ValueError('Do not use this') + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False) + # get only the text part (excluding [CLS], [SEP]) + sequence_output = sequence_output[:, 1:input_ids.size(1)-1, :] + # only compute masked tokens for better efficiency + index = gather_tgt.unsqueeze(-1).expand( + -1, -1, self.config.hidden_size) + masked_output = torch.gather(sequence_output, dim=0, index=index) + prediction_scores = self.cls(masked_output) + if self.vocab_pad: + prediction_scores = prediction_scores[..., :-self.vocab_pad] + return prediction_scores + + # MRFR + def forward_mrfr(self, input_ids, position_ids, img_feat, img_pos_feat, + attention_mask, gather_index, img_masks, img_mask_tgt, + feat_targets, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False, + img_masks=img_masks) + + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + img_mask_tgt) + prediction_feat = self._pad_layer_unpad(masked_output, + self.feat_regress) + + mrfr_loss = F.mse_loss(prediction_feat, feat_targets, + reduction='none') + return mrfr_loss, prediction_feat + + # MRM-NCE + def forward_mrm_nce(self, input_ids, position_ids, img_feat, img_pos_feat, + attention_mask, gather_index, + img_masks_in, img_masks, img_mask_tgt, + feat_targets, neg_feats, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False, + img_masks=img_masks_in) + + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + img_mask_tgt) + + masked_output = self._pad_layer_unpad(masked_output, self.nce_output) + # neg within batch + batch_neg = self._compute_masked_hidden(img_feat, ~img_masks) + neg_feats, _ = pad_tensor_to_mul( + torch.cat([neg_feats, batch_neg], dim=0)) + + # shared image linear transform + neg_output = self.nce_norm( + 
self.bert.img_embeddings.img_linear(neg_feats)) + pos_output = self._pad_layer_unpad(feat_targets, + self.bert.img_embeddings.img_linear) + pos_output = self.nce_norm(pos_output) + + mrm_nce_loss = self.mrm_nce(masked_output, pos_output, + neg_output, compute_loss=True) + return mrm_nce_loss, masked_output # ??? + + def mrm_nce(self, masked_output, pos_output, neg_output, + compute_loss=True): + # dot product of ground truth feature + masked_score = masked_output.matmul(pos_output.t()) + # dot product of neative samples + neg_score = masked_output.matmul(neg_output.t()) + + logits = torch.cat([masked_score, neg_score], dim=1).float() + targets = torch.arange(0, masked_output.size(0), + dtype=torch.long, device=logits.device) + loss = F.cross_entropy(logits/self.nce_temp, targets, + reduction='none') + return loss, logits + + def forward_itm(self, input_ids, position_ids, img_feat, img_pos_feat, + attention_mask, gather_index, targets, ot_inputs, + compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False) + pooled_output = self.bert.pooler(sequence_output) + rank_scores = self.itm_output(pooled_output) + + # OT loss + if ot_inputs is not None: + ot_scatter = ot_inputs['ot_scatter'] + + b = sequence_output.size(0) + tl = input_ids.size(1) + il = img_feat.size(1) + max_l = max(ot_inputs['scatter_max'] + 1, tl+il) + + ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output) + ctx_emb = torch.zeros(b, max_l, self.config.hidden_size, + dtype=sequence_output.dtype, + device=sequence_output.device + ).scatter_(dim=1, index=ot_scatter, + src=sequence_output) + txt_emb = ctx_emb[:, :tl, :] + img_emb = ctx_emb[:, tl:tl+il, :] + + txt_pad = ot_inputs['txt_pad'] + img_pad = ot_inputs['img_pad'] + ot_dist = optimal_transport_dist(txt_emb, img_emb, + txt_pad, img_pad) + if self.ot_pos_only: + ot_loss = ot_dist.masked_select(targets == 1) + else: + ot_pos_dist = ot_dist.masked_select(targets == 1) + ot_neg_dist = ot_dist.masked_select(targets == 0) + ot_loss = (ot_pos_dist, ot_neg_dist) + else: + ot_loss = None + + if compute_loss: + itm_loss = F.cross_entropy(rank_scores, targets, reduction='none') + return itm_loss, ot_loss + else: + return rank_scores, ot_loss + + # MRC + def forward_mrc(self, input_ids, position_ids, img_feat, img_pos_feat, + attention_mask, gather_index, img_masks, img_mask_tgt, + label_targets, task, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False, + img_masks=img_masks) + + # only compute masked regions for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, + img_mask_tgt) + prediction_soft_label = self._pad_layer_unpad(masked_output, + self.region_classifier) + + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + mrc_loss = F.kl_div( + prediction_soft_label, label_targets, reduction='none') + else: + # background class should not be the target + label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1 + mrc_loss = F.cross_entropy( + prediction_soft_label, label_targets, + ignore_index=0, reduction='none') + return mrc_loss, prediction_soft_label diff --git a/uniter_model/model/nlvr2.py b/uniter_model/model/nlvr2.py new file mode 100644 index 0000000..6837d46 --- /dev/null +++ b/uniter_model/model/nlvr2.py @@ -0,0 +1,182 @@ +""" +Copyright (c) Microsoft Corporation. 
+Licensed under the MIT license. + +Uniter for NLVR2 model +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .layer import GELU +from .model import UniterPreTrainedModel, UniterModel +from .attention import MultiheadAttention + + +class UniterForNlvr2Paired(UniterPreTrainedModel): + """ Finetune UNITER for NLVR2 (paired format) + """ + def __init__(self, config, img_dim): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.nlvr2_output = nn.Linear(config.hidden_size*2, 2) + self.apply(self.init_weights) + + def init_type_embedding(self): + new_emb = nn.Embedding(3, self.bert.config.hidden_size) + new_emb.apply(self.init_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings\ + .weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + new_emb.weight.data[2, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def forward(self, input_ids, position_ids, img_feat, img_pos_feat, + attn_masks, gather_index, + img_type_ids, targets, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False, + img_type_ids=img_type_ids) + pooled_output = self.bert.pooler(sequence_output) + # concat CLS of the pair + n_pair = pooled_output.size(0) // 2 + reshaped_output = pooled_output.contiguous().view(n_pair, -1) + answer_scores = self.nlvr2_output(reshaped_output) + + if compute_loss: + nlvr2_loss = F.cross_entropy( + answer_scores, targets, reduction='none') + return nlvr2_loss + else: + return answer_scores + + +class UniterForNlvr2Triplet(UniterPreTrainedModel): + """ Finetune UNITER for NLVR2 (triplet format) + """ + def __init__(self, config, img_dim): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.nlvr2_output = nn.Linear(config.hidden_size, 2) + self.apply(self.init_weights) + + def init_type_embedding(self): + new_emb = nn.Embedding(3, self.bert.config.hidden_size) + new_emb.apply(self.init_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings\ + .weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + new_emb.weight.data[2, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def forward(self, input_ids, position_ids, img_feat, img_pos_feat, + attn_masks, gather_index, + img_type_ids, targets, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False, + img_type_ids=img_type_ids) + pooled_output = self.bert.pooler(sequence_output) + answer_scores = self.nlvr2_output(pooled_output) + + if compute_loss: + nlvr2_loss = F.cross_entropy( + answer_scores, targets, reduction='none') + return nlvr2_loss + else: + return answer_scores + + +class AttentionPool(nn.Module): + """ attention pooling layer """ + def __init__(self, hidden_size, drop=0.0): + super().__init__() + self.fc = nn.Sequential(nn.Linear(hidden_size, 1), GELU()) + self.dropout = nn.Dropout(drop) + + def forward(self, input_, mask=None): + """input: [B, T, D], mask = [B, T]""" + score = self.fc(input_).squeeze(-1) + if mask is not None: + mask = mask.to(dtype=input_.dtype) * -1e4 + score = score + mask + norm_score = self.dropout(F.softmax(score, dim=1)) + output = norm_score.unsqueeze(1).matmul(input_).squeeze(1) + return output + + +class UniterForNlvr2PairedAttn(UniterPreTrainedModel): + """ Finetune UNITER for NLVR2 + (paired format with additional 
attention layer) + """ + def __init__(self, config, img_dim): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.attn1 = MultiheadAttention(config.hidden_size, + config.num_attention_heads, + config.attention_probs_dropout_prob) + self.attn2 = MultiheadAttention(config.hidden_size, + config.num_attention_heads, + config.attention_probs_dropout_prob) + self.fc = nn.Sequential( + nn.Linear(2*config.hidden_size, config.hidden_size), + GELU(), + nn.Dropout(config.hidden_dropout_prob)) + self.attn_pool = AttentionPool(config.hidden_size, + config.attention_probs_dropout_prob) + self.nlvr2_output = nn.Linear(2*config.hidden_size, 2) + self.apply(self.init_weights) + + def init_type_embedding(self): + new_emb = nn.Embedding(3, self.bert.config.hidden_size) + new_emb.apply(self.init_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings\ + .weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + new_emb.weight.data[2, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def forward(self, input_ids, position_ids, img_feat, img_pos_feat, + attn_masks, gather_index, + img_type_ids, targets, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False, + img_type_ids=img_type_ids) + # separate left image and right image + bs, tl, d = sequence_output.size() + left_out, right_out = sequence_output.contiguous().view( + bs//2, tl*2, d).chunk(2, dim=1) + # bidirectional attention + mask = attn_masks == 0 + left_mask, right_mask = mask.contiguous().view(bs//2, tl*2 + ).chunk(2, dim=1) + left_out = left_out.transpose(0, 1) + right_out = right_out.transpose(0, 1) + l2r_attn, _ = self.attn1(left_out, right_out, right_out, + key_padding_mask=right_mask) + r2l_attn, _ = self.attn2(right_out, left_out, left_out, + key_padding_mask=left_mask) + left_out = self.fc(torch.cat([l2r_attn, left_out], dim=-1) + ).transpose(0, 1) + right_out = self.fc(torch.cat([r2l_attn, right_out], dim=-1) + ).transpose(0, 1) + # attention pooling and final prediction + left_out = self.attn_pool(left_out, left_mask) + right_out = self.attn_pool(right_out, right_mask) + answer_scores = self.nlvr2_output( + torch.cat([left_out, right_out], dim=-1)) + + if compute_loss: + nlvr2_loss = F.cross_entropy( + answer_scores, targets, reduction='none') + return nlvr2_loss + else: + return answer_scores diff --git a/uniter_model/model/ot.py b/uniter_model/model/ot.py new file mode 100644 index 0000000..46c6571 --- /dev/null +++ b/uniter_model/model/ot.py @@ -0,0 +1,82 @@ +""" +Wasserstein Distance (Optimal Transport) +""" +import torch +from torch.nn import functional as F + + +def cost_matrix_cosine(x, y, eps=1e-5): + """ Compute cosine distnace across every pairs of x, y (batched) + [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]""" + assert x.dim() == y.dim() + assert x.size(0) == y.size(0) + assert x.size(2) == y.size(2) + x_norm = F.normalize(x, p=2, dim=-1, eps=eps) + y_norm = F.normalize(y, p=2, dim=-1, eps=eps) + cosine_sim = x_norm.matmul(y_norm.transpose(1, 2)) + cosine_dist = 1 - cosine_sim + return cosine_dist + + +def trace(x): + """ compute trace of input tensor (batched) """ + b, m, n = x.size() + assert m == n + mask = torch.eye(n, dtype=torch.bool, device=x.device + ).unsqueeze(0).expand_as(x) + trace = x.masked_select(mask).contiguous().view( + b, n).sum(dim=-1, keepdim=False) + return trace + + +@torch.no_grad() +def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, 
beta, iteration, k): + """ [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]""" + b, m, n = C.size() + sigma = torch.ones(b, m, dtype=C.dtype, device=C.device + ) / x_len.unsqueeze(1) + T = torch.ones(b, n, m, dtype=C.dtype, device=C.device) + A = torch.exp(-C.transpose(1, 2)/beta) + + # mask padded positions + sigma.masked_fill_(x_pad, 0) + joint_pad = joint_pad.transpose(1, 2) + T.masked_fill_(joint_pad, 0) + A.masked_fill_(joint_pad, 0) + + # broadcastable lengths + x_len = x_len.unsqueeze(1).unsqueeze(2) + y_len = y_len.unsqueeze(1).unsqueeze(2) + + # mask to zero out padding in delta and sigma + x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1) + y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1) + + for _ in range(iteration): + Q = A * T # bs * n * m + sigma = sigma.view(b, m, 1) + for _ in range(k): + delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask) + sigma = 1 / (x_len * delta.matmul(Q) + x_mask) + T = delta.view(b, n, 1) * Q * sigma + T.masked_fill_(joint_pad, 0) + return T + + +def optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad, + beta=0.5, iteration=50, k=1): + """ [B, M, D], [B, N, D], [B, M], [B, N]""" + cost = cost_matrix_cosine(txt_emb, img_emb) + # mask the padded inputs + joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2) + cost.masked_fill_(joint_pad, 0) + + txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False) + ).to(dtype=cost.dtype) + img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False) + ).to(dtype=cost.dtype) + + T = ipot(cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad, + beta, iteration, k) + distance = trace(cost.matmul(T.detach())) + return distance diff --git a/uniter_model/model/re.py b/uniter_model/model/re.py new file mode 100644 index 0000000..64a1880 --- /dev/null +++ b/uniter_model/model/re.py @@ -0,0 +1,140 @@ +""" +Bert for Referring Expression Comprehension +""" +import sys +import torch +import torch.nn as nn +from torch.nn import functional as F +from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm + +from .model import BertVisionLanguageEncoder + +import numpy as np +import random + + +class BertForReferringExpressionComprehension(BertPreTrainedModel): + """Finetune multi-model BERT for Referring Expression Comprehension + """ + def __init__(self, config, img_dim, loss="cls", + margin=0.2, hard_ratio=0.3, mlp=1): + super().__init__(config) + self.bert = BertVisionLanguageEncoder(config, img_dim) + if mlp == 1: + self.re_output = nn.Linear(config.hidden_size, 1) + elif mlp == 2: + self.re_output = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size), + nn.ReLU(), + BertLayerNorm(config.hidden_size, eps=1e-12), + nn.Linear(config.hidden_size, 1) + ) + else: + sys.exit("MLP restricted to be 1 or 2 layers.") + self.loss = loss + assert self.loss in ['cls', 'rank'] + if self.loss == 'rank': + self.margin = margin + self.hard_ratio = hard_ratio + else: + self.crit = nn.CrossEntropyLoss(reduction='none') + # initialize + self.apply(self.init_bert_weights) + + def forward(self, input_ids, position_ids, txt_lens, img_feat, + img_pos_feat, num_bbs, attn_masks, obj_masks, + targets, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attn_masks, + output_all_encoded_layers=False) + # get only the region part + sequence_output = self._get_image_hidden(sequence_output, txt_lens, + num_bbs) + + # re score (n, max_num_bb) + scores = self.re_output(sequence_output).squeeze(2) + scores = 
scores.masked_fill(obj_masks, -1e4) # mask out non-objects + + # loss + if compute_loss: + if self.loss == 'cls': + ce_loss = self.crit(scores, targets) # (n, ) as no reduction + return ce_loss + else: + # ranking + _n = len(num_bbs) + # positive (target) + pos_ix = targets + pos_sc = scores.gather(1, pos_ix.view(_n, 1)) # (n, 1) + pos_sc = torch.sigmoid(pos_sc).view(-1) # (n, ) sc[0, 1] + # negative + neg_ix = self.sample_neg_ix(scores, targets, num_bbs) + neg_sc = scores.gather(1, neg_ix.view(_n, 1)) # (n, 1) + neg_sc = torch.sigmoid(neg_sc).view(-1) # (n, ) sc[0, 1] + # ranking + mm_loss = torch.clamp(self.margin + neg_sc - pos_sc, 0) # (n, ) + return mm_loss + else: + # (n, max_num_bb) + return scores + + def sample_neg_ix(self, scores, targets, num_bbs): + """ + Inputs: + :scores (n, max_num_bb) + :targets (n, ) + :num_bbs list of [num_bb] + return: + :neg_ix (n, ) easy/hard negative (!= target) + """ + neg_ix = [] + cand_ixs = torch.argsort(scores, dim=-1, descending=True) # (n, num_bb) + for i in range(len(num_bbs)): + num_bb = num_bbs[i] + if np.random.uniform(0, 1, 1) < self.hard_ratio: + # sample hard negative, w/ highest score + for ix in cand_ixs[i].tolist(): + if ix != targets[i]: + assert ix < num_bb, f'ix={ix}, num_bb={num_bb}' + neg_ix.append(ix) + break + else: + # sample easy negative, i.e., random one + ix = random.randint(0, num_bb-1) # [0, num_bb-1] + while ix == targets[i]: + ix = random.randint(0, num_bb-1) + neg_ix.append(ix) + neg_ix = torch.tensor(neg_ix).type(targets.type()) + assert neg_ix.numel() == targets.numel() + return neg_ix + + def _get_image_hidden(self, sequence_output, txt_lens, num_bbs): + """ + Extracting the img_hidden part from sequence_output. + Inputs: + - sequence_output: (n, txt_len+num_bb, hid_size) + - txt_lens : [txt_len] + - num_bbs : [num_bb] + Output: + - img_hidden : (n, max_num_bb, hid_size) + """ + outputs = [] + max_bb = max(num_bbs) + hid_size = sequence_output.size(-1) + for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0), + txt_lens, num_bbs): + img_hid = seq_out[:, len_:len_+nbb, :] + if nbb < max_bb: + img_hid = torch.cat( + [img_hid, self._get_pad(img_hid, max_bb-nbb, hid_size)], + dim=1) + outputs.append(img_hid) + + img_hidden = torch.cat(outputs, dim=0) + return img_hidden + + def _get_pad(self, t, len_, hidden_size): + pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device) + return pad + diff --git a/uniter_model/model/vcr.py b/uniter_model/model/vcr.py new file mode 100644 index 0000000..ae13f4e --- /dev/null +++ b/uniter_model/model/vcr.py @@ -0,0 +1,287 @@ +""" +Bert for VCR model +""" +from torch import nn +from torch.nn import functional as F +from pytorch_pretrained_bert.modeling import ( + BertPreTrainedModel, BertEmbeddings, BertEncoder, BertLayerNorm, + BertPooler, BertOnlyMLMHead) +from .model import (BertTextEmbeddings, BertImageEmbeddings, + BertForImageTextMaskedLM, + BertVisionLanguageEncoder, + BertForImageTextPretraining, + _get_image_hidden, + mask_img_feat, + RegionFeatureRegression, + mask_img_feat_for_mrc, + RegionClassification) +import torch +import random + + +class BertVisionLanguageEncoderForVCR(BertVisionLanguageEncoder): + """ Modification for Joint Vision-Language Encoding + """ + def __init__(self, config, img_dim, num_region_toks): + BertPreTrainedModel.__init__(self, config) + self.embeddings = BertTextEmbeddings(config) + self.img_embeddings = BertImageEmbeddings(config, img_dim) + self.num_region_toks = num_region_toks + self.region_token_embeddings = nn.Embedding( 
+ num_region_toks, + config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, output_all_encoded_layers=True, + txt_type_ids=None, img_type_ids=None, region_tok_ids=None): + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self._compute_img_txt_embeddings( + input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, attention_mask.size(1), + txt_type_ids, img_type_ids) + if region_tok_ids is not None: + region_tok_embeddings = self.region_token_embeddings( + region_tok_ids) + embedding_output += region_tok_embeddings + embedding_output = self.LayerNorm(embedding_output) + embedding_output = self.dropout(embedding_output) + encoded_layers = self.encoder( + embedding_output, extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers + + +class BertForVisualCommonsenseReasoning(BertPreTrainedModel): + """ Finetune multi-modal BERT for ITM + """ + def __init__(self, config, img_dim, obj_cls=True, img_label_dim=81): + super().__init__(config, img_dim) + self.bert = BertVisionLanguageEncoder( + config, img_dim) + # self.vcr_output = nn.Linear(config.hidden_size, 1) + # self.vcr_output = nn.Linear(config.hidden_size, 2) + self.vcr_output = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size*2), + nn.ReLU(), + BertLayerNorm(config.hidden_size*2, eps=1e-12), + nn.Linear(config.hidden_size*2, 2) + ) + self.apply(self.init_bert_weights) + self.obj_cls = obj_cls + if self.obj_cls: + self.region_classifier = RegionClassification( + config.hidden_size, img_label_dim) + + def init_type_embedding(self): + new_emb = nn.Embedding(4, self.bert.config.hidden_size) + new_emb.apply(self.init_bert_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] + new_emb.weight.data[2, :].copy_(emb) + new_emb.weight.data[3, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def init_word_embedding(self, num_special_tokens): + orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0) + new_emb = nn.Embedding( + orig_word_num + num_special_tokens, self.bert.config.hidden_size) + new_emb.apply(self.init_bert_weights) + emb = self.bert.embeddings.word_embeddings.weight.data + new_emb.weight.data[:orig_word_num, :].copy_(emb) + self.bert.embeddings.word_embeddings = new_emb + + def masked_predict_labels(self, sequence_output, mask): + # only compute masked outputs + mask = mask.unsqueeze(-1).expand_as(sequence_output) + sequence_output_masked = sequence_output[mask].contiguous().view( + -1, self.config.hidden_size) + prediction_soft_label = self.region_classifier(sequence_output_masked) + + return prediction_soft_label + + def forward(self, input_ids, position_ids, txt_lens, txt_type_ids, + img_feat, img_pos_feat, num_bbs, + attention_mask, targets, obj_targets=None, img_masks=None, + 
region_tok_ids=None, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + pooled_output = self.bert.pooler(sequence_output) + rank_scores = self.vcr_output(pooled_output) + # rank_scores = rank_scores.reshape((-1, 4)) + + if self.obj_cls and img_masks is not None: + img_feat = mask_img_feat_for_mrc(img_feat, img_masks) + masked_sequence_output = self.bert( + input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the image part + img_sequence_output = _get_image_hidden( + masked_sequence_output, txt_lens, num_bbs) + # only compute masked tokens for better efficiency + predicted_obj_label = self.masked_predict_labels( + img_sequence_output, img_masks) + + if compute_loss: + vcr_loss = F.cross_entropy( + rank_scores, targets.squeeze(-1), + reduction='mean') + if self.obj_cls: + obj_cls_loss = F.cross_entropy( + predicted_obj_label, obj_targets.long(), + ignore_index=0, reduction='mean') + else: + obj_cls_loss = torch.tensor([0.], device=vcr_loss.device) + return vcr_loss, obj_cls_loss + else: + rank_scores = rank_scores[:, 1:] + return rank_scores + + +class BertForImageTextPretrainingForVCR(BertForImageTextPretraining): + def init_type_embedding(self): + new_emb = nn.Embedding(4, self.bert.config.hidden_size) + new_emb.apply(self.init_bert_weights) + for i in [0, 1]: + emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] + new_emb.weight.data[i, :].copy_(emb) + emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] + new_emb.weight.data[2, :].copy_(emb) + new_emb.weight.data[3, :].copy_(emb) + self.bert.embeddings.token_type_embeddings = new_emb + + def init_word_embedding(self, num_special_tokens): + orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0) + new_emb = nn.Embedding( + orig_word_num + num_special_tokens, self.bert.config.hidden_size) + new_emb.apply(self.init_bert_weights) + emb = self.bert.embeddings.word_embeddings.weight.data + new_emb.weight.data[:orig_word_num, :].copy_(emb) + self.bert.embeddings.word_embeddings = new_emb + self.cls = BertOnlyMLMHead( + self.bert.config, self.bert.embeddings.word_embeddings.weight) + + def forward(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, labels, task, compute_loss=True): + if task == 'mlm': + txt_labels = labels + return self.forward_mlm(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, txt_labels, compute_loss) + elif task == 'mrm': + img_mask = labels + return self.forward_mrm(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_mask, compute_loss) + elif task.startswith('mrc'): + img_mask, mrc_label_target = labels + return self.forward_mrc(input_ids, position_ids, txt_type_ids, + txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_mask, + mrc_label_target, task, compute_loss) + else: + raise ValueError('invalid task') + + # MLM + def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, txt_labels, compute_loss=True): + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + 
txt_type_ids=txt_type_ids) + # get only the text part + sequence_output = sequence_output[:, :input_ids.size(1), :] + # only compute masked tokens for better efficiency + prediction_scores = self.masked_compute_scores( + sequence_output, txt_labels != -1) + if self.vocab_pad: + prediction_scores = prediction_scores[:, :-self.vocab_pad] + + if compute_loss: + masked_lm_loss = F.cross_entropy(prediction_scores, + txt_labels[txt_labels != -1], + reduction='none') + return masked_lm_loss + else: + return prediction_scores + + # MRM + def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_masks, compute_loss=True): + img_feat, feat_targets = mask_img_feat(img_feat, img_masks) + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the text part + sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) + # only compute masked tokens for better efficiency + prediction_feat = self.masked_compute_feat( + sequence_output, img_masks) + + if compute_loss: + mrm_loss = F.mse_loss(prediction_feat, feat_targets, + reduction='none') + return mrm_loss + else: + return prediction_feat + + # MRC + def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, img_masks, + label_targets, task, compute_loss=True): + img_feat = mask_img_feat_for_mrc(img_feat, img_masks) + sequence_output = self.bert(input_ids, position_ids, txt_lens, + img_feat, img_pos_feat, num_bbs, + attention_mask, + output_all_encoded_layers=False, + txt_type_ids=txt_type_ids) + # get only the image part + sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) + # only compute masked tokens for better efficiency + prediction_soft_label = self.masked_predict_labels( + sequence_output, img_masks) + + if compute_loss: + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + mrc_loss = F.kl_div( + prediction_soft_label, label_targets, reduction='none') + else: + label_targets = torch.max( + label_targets, -1)[1] # argmax + mrc_loss = F.cross_entropy( + prediction_soft_label, label_targets, + ignore_index=0, reduction='none') + return mrc_loss + else: + return prediction_soft_label diff --git a/uniter_model/model/ve.py b/uniter_model/model/ve.py new file mode 100644 index 0000000..c709e0b --- /dev/null +++ b/uniter_model/model/ve.py @@ -0,0 +1,11 @@ +""" +UNITER for VE model +""" +from .vqa import UniterForVisualQuestionAnswering + + +class UniterForVisualEntailment(UniterForVisualQuestionAnswering): + """ Finetune multi-modal BERT for VE + """ + def __init__(self, config, img_dim): + super().__init__(config, img_dim, 3) diff --git a/uniter_model/model/vqa.py b/uniter_model/model/vqa.py new file mode 100644 index 0000000..bdaaffd --- /dev/null +++ b/uniter_model/model/vqa.py @@ -0,0 +1,49 @@ +""" +Bert for VQA model +""" +from collections import defaultdict + +from torch import nn +from torch.nn import functional as F + +from .layer import GELU +from .model import UniterPreTrainedModel, UniterModel + +LayerNorm = nn.LayerNorm + +class UniterForVisualQuestionAnswering(UniterPreTrainedModel): + """ Finetune multi-modal BERT for VQA + """ + def __init__(self, config, img_dim, num_answer): + super().__init__(config) + self.bert = UniterModel(config, img_dim) + self.vqa_output = nn.Sequential( + 
nn.Linear(config.hidden_size, config.hidden_size*2), + GELU(), + LayerNorm(config.hidden_size*2, eps=1e-12), + nn.Linear(config.hidden_size*2, num_answer) + ) + self.apply(self.init_weights) + + def forward(self, batch, compute_loss=True): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attn_masks = batch['attn_masks'] + gather_index = batch['gather_index'] + sequence_output = self.bert(input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False) + pooled_output = self.bert.pooler(sequence_output) + answer_scores = self.vqa_output(pooled_output) + + if compute_loss: + targets = batch['targets'] + vqa_loss = F.binary_cross_entropy_with_logits( + answer_scores, targets, reduction='none') + return vqa_loss + else: + return answer_scores diff --git a/uniter_model/optim/__init__.py b/uniter_model/optim/__init__.py new file mode 100644 index 0000000..a97aad3 --- /dev/null +++ b/uniter_model/optim/__init__.py @@ -0,0 +1,2 @@ +from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched +from .adamw import AdamW diff --git a/uniter_model/optim/adamw.py b/uniter_model/optim/adamw.py new file mode 100644 index 0000000..e8472d0 --- /dev/null +++ b/uniter_model/optim/adamw.py @@ -0,0 +1,103 @@ +""" +AdamW optimizer (weight decay fix) +copied from hugginface +""" +import math + +import torch +from torch.optim import Optimizer + + +class AdamW(Optimizer): + """ Implements Adam algorithm with weight decay fix. + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). + Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias + in Adam (e.g. like in Bert TF repository). Default True. + """ + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0.0, correct_bias=True): + if lr < 0.0: + raise ValueError( + "Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter: {} - " + "should be in [0.0, 1.0[".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter: {} - " + "should be in [0.0, 1.0[".format(betas[1])) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - " + "should be >= 0.0".format(eps)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + correct_bias=correct_bias) + super(AdamW, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
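+        Example (illustrative sketch only; `model`, `batch`, and the
+        hyper-parameter values are assumptions, not part of this repo):
+            >>> optimizer = AdamW(model.parameters(), lr=5e-5,
+            ...                   weight_decay=0.01)
+            >>> loss = model(batch).mean()   # assumed scalar training loss
+            >>> loss.backward()
+            >>> optimizer.step()
+            >>> optimizer.zero_grad()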
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse ' + 'gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(1.0 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state['step'] + bias_correction2 = 1.0 - beta2 ** state['step'] + step_size = (step_size * math.sqrt(bias_correction2) + / bias_correction1) + + p.data.addcdiv_(-step_size, exp_avg, denom) + + # Just adding the square of the weights to the loss function is + # *not* the correct way of using L2 regularization/weight decay + # with Adam, since that will interact with the m and v + # parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't + # interact with the m/v parameters. This is equivalent to + # adding the square of the weights to the loss with plain + # (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group['weight_decay'] > 0.0: + p.data.add_(-group['lr'] * group['weight_decay'], p.data) + + return loss diff --git a/uniter_model/optim/misc.py b/uniter_model/optim/misc.py new file mode 100644 index 0000000..0a1388a --- /dev/null +++ b/uniter_model/optim/misc.py @@ -0,0 +1,32 @@ +""" +Misc lr helper +""" +from torch.optim import Adam, Adamax + +from .adamw import AdamW + + +def build_optimizer(model, opts): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + # currently Adam only + if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + return optimizer diff --git a/uniter_model/optim/sched.py b/uniter_model/optim/sched.py new file mode 100644 index 0000000..ee5f9a9 --- /dev/null +++ b/uniter_model/optim/sched.py @@ -0,0 +1,52 @@ +""" +optimizer learning rate scheduling helpers +""" +from math import ceil + + +def noam_schedule(step, warmup_step=4000): + if step <= warmup_step: + return step / warmup_step + return (warmup_step ** 0.5) * (step ** -0.5) + + +def warmup_linear(step, warmup_step, tot_step): + if step < warmup_step: + return step / warmup_step + return max(0, (tot_step-step)/(tot_step-warmup_step)) + + +def vqa_schedule(step, warmup_interval, 
decay_interval, + decay_start, decay_rate): + """ VQA schedule from MCAN """ + if step < warmup_interval: + return 1/4 + elif step < 2 * warmup_interval: + return 2/4 + elif step < 3 * warmup_interval: + return 3/4 + elif step >= decay_start: + num_decay = ceil((step - decay_start) / decay_interval) + return decay_rate ** num_decay + else: + return 1 + + +def get_lr_sched(global_step, opts): + # learning rate scheduling + if opts.decay == 'linear': + lr_this_step = opts.learning_rate * warmup_linear( + global_step, opts.warmup_steps, opts.num_train_steps) + elif opts.decay == 'invsqrt': + lr_this_step = opts.learning_rate * noam_schedule( + global_step, opts.warmup_steps) + elif opts.decay == 'constant': + lr_this_step = opts.learning_rate + elif opts.decay == 'vqa': + lr_this_step = opts.learning_rate * vqa_schedule( + global_step, opts.warm_int, opts.decay_int, + opts.decay_st, opts.decay_rate) + if lr_this_step <= 0: + # save guard for possible miscalculation of train steps + lr_this_step = 1e-8 + return lr_this_step diff --git a/uniter_model/prepro.py b/uniter_model/prepro.py new file mode 100644 index 0000000..ebc8380 --- /dev/null +++ b/uniter_model/prepro.py @@ -0,0 +1,750 @@ +""" +preprocess COCO annotations into LMDB +""" +import argparse +from collections import defaultdict +import json +import os +from os.path import basename, exists +import pickle +import re + +from cytoolz import curry +from tqdm import tqdm +from pytorch_pretrained_bert import BertTokenizer + +from utils.vqa import compute_target +from utils.visual_entailment import compute_target as compute_target_ve +from data.data import open_lmdb + + +IN_WORD = '@@' + + +@curry +def bert_tokenize(tokenizer, text): + """ reconstructable tokenization for possible generation """ + if text == ('this house is leaning out to wards ' + 'the road taken in cambridge@ @@@@'): + # SBU special case + text = text.replace('@@', '') + assert IN_WORD not in text + ids = [] + words = [] + for word in text.strip().split(): + ws = tokenizer.tokenize(word) + if not ws: + # some special char in conceptual caption + continue + words.append(ws[0]) + for w in ws[1:]: + words.append(f'{IN_WORD}{w}') + ids.extend(tokenizer.convert_tokens_to_ids(ws)) + return ids, words + + +@curry +def bert_tokenize_for_vcr(tokenizer, special_tokens, text, txt_region_tokens): + """ reconstructable tokenization for possible generation """ + assert IN_WORD not in text + ids = [] + words = [] + special_tokens_dict = {val: ind for ind, val in enumerate(special_tokens)} + toked_txt_region_tokens = [] + index = 0 + for word in text.strip().split(): + if word in special_tokens_dict: + words.append(word) + ids.extend([len(tokenizer.vocab)+special_tokens_dict[word]]) + toked_txt_region_tokens.append(txt_region_tokens[index]) + else: + ws = tokenizer.tokenize(word) + words.append(ws[0]) + toked_txt_region_tokens.append(txt_region_tokens[index]) + for w in ws[1:]: + words.append(f'{IN_WORD}{w}') + toked_txt_region_tokens.append(txt_region_tokens[index]) + ids.extend(tokenizer.convert_tokens_to_ids(ws)) + index += 1 + return ids, words, toked_txt_region_tokens + + +def _norm_text(text): + norm_text = re.sub(r"([.,'!?\"()*#:;])", '', text.lower() + ).replace('-', ' ').replace('/', ' ') + return norm_text + + +def make_word2id(texts): + word2id = {'PAD': 0, 'UNK': 1} + for text in texts: + for w in _norm_text(text).split(): + if w not in word2id: + word2id[w] = len(word2id) + return word2id + + +def gen_vqa_texts(annotation): + questions = 
json.load(open(annotation))['questions'] + for q in questions: + yield q['question'] + + +def gen_ve_texts(annotation): + contents = open(annotation, "r").read() + hypotheses = [json.loads(str(item)) + for item in contents.strip().split('\n')] + for h in hypotheses: + yield h['sentence2'] + + +def gen_itm_texts(annotation): + data = json.load(open(annotation)) + for q in data: + for s in q["sentences"]: + yield s['raw'] + + +@curry +def _get_coco_fname(id_, split): + fname = f'coco_{split}_{id_:012}.npz' + return fname + + +def _get_vg_fname(id_): + fname = f'vg_{int(id_):012}.npz' + return fname + + +def _get_gqa_fname(id_): + if "n" not in id_: + fname = f'gqa_{int(id_):012}.npz' + else: + fname = f'gqa_{id_}.npz' + return fname + + +def _get_flickr_fname(id_): + fname = f'flickr30k_{id_:012}.npz' + return fname + + +def _get_vcr_fname(id_, split): + fname_gt = f'vcr_gt_{split}_{id_}.npz' + fname = f'vcr_{split}_{id_}.npz' + return fname_gt, fname + + +def process_vqa(questions, answers, ans2label, db, tokenizer, split): + """ + Inputs: + - questions : [{image_id, question, question_id}] + - answers : [{answers, image_id, question_id, + question_type, answer_type}] + - ans2label : ans -> ans_id + - db + - tokenizer + - split + Return: + - id2len : qid -> tokenized question length + - txt2img : qid -> img(feature) filename + - img2txts : img(feature) filename -> [qid] + Besides, we write into db[qid]: + - toked_question : [tokens] + - input_ids : [wd_ids] + - img_fname : img(feature) filename + - target : {labels, scores} + """ + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + if split == 'vg': + get_img_fname = _get_vg_fname + elif split == 'gqa': + get_img_fname = _get_gqa_fname + else: + get_img_fname = _get_coco_fname(split=split) + for q in tqdm(questions, desc='processing VQA questions'): + qid = str(q['question_id']) + input_ids, toked_question = tokenizer(q['question']) + id2len[qid] = len(input_ids) + img_fname = get_img_fname(q['image_id']) + txt2img[qid] = img_fname + img2txts[img_fname].append(qid) + q['toked_question'] = toked_question + q['input_ids'] = input_ids + q['img_fname'] = img_fname + db[qid] = q + if answers is not None: + for a in tqdm(answers, desc='processing VQA answers'): + qid = str(a['question_id']) + q = db[qid] + assert q['question_id'] == a['question_id'] + assert q['image_id'] == a['image_id'] + for k, v in a.items(): + q[k] = v + q['target'] = compute_target(a['answers'], ans2label) + db[qid] = q + return id2len, txt2img, img2txts + + +def process_referring_expressions(refs, instances, iid_to_ann_ids, + db, tokenizer, split): + """ + Inputs: + - refs: [ref_id, ann_id, image_id, split, sent_ids, sentences] + - instances: {images, annotations, categories} + - iid_to_ann_ids: image_id -> ann_ids ordered by extracted butd features + Return: + - id2len : sent_id -> tokenized question length + - images : [{id, file_name, ann_ids, height, width} ] + - annotations: [{id, area, bbox, image_id, category_id, iscrowd}] + - categories : [{id, name, supercategory}] + """ + # images within split + image_set = set([ref['image_id'] for ref in refs if ref['split'] == split]) + images = [] + for img in instances['images']: + if img['id'] in image_set: + images.append({'id': img['id'], 'file_name': img['file_name'], + 'ann_ids': iid_to_ann_ids[str(img['id'])], + 'height': img['height'], 'width': img['width']}) + # anns within split + annotations = [] + for ann in instances['annotations']: + if ann['image_id'] in image_set: + annotations.append({ + 'id': 
ann['id'], 'area': ann['area'], 'bbox': ann['bbox'], + 'image_id': ann['image_id'], 'category_id': ann['category_id'], + 'iscrowd': ann['iscrowd'] + }) + Anns = {ann['id']: ann for ann in annotations} + # category info + categories = instances['categories'] + # refs within split + refs = [ref for ref in refs if ref['split'] == split] + id2len = {} + for ref in tqdm(refs, desc='processing referring expressions'): + ref_id = ref['ref_id'] + ann_id = ref['ann_id'] + image_id = ref['image_id'] + for sent in ref['sentences']: + sent_id = sent['sent_id'] + input_ids, toked_sent = tokenizer(sent['sent']) + id2len[str(sent_id)] = len(input_ids) + db[str(sent_id)] = { + 'sent_id': sent_id, 'sent': sent['sent'], + 'ref_id': ref_id, 'ann_id': ann_id, 'image_id': image_id, + 'bbox': Anns[ann_id]['bbox'], + 'input_ids': input_ids, 'toked_sent': toked_sent} + return id2len, images, annotations, categories, refs + + +def process_gqa(questions, db, tokenizer, split): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + get_img_fname = _get_gqa_fname + for qid, q in tqdm(questions.items(), + desc=f'processing GQA_{split} questions'): + input_ids, toked_question = tokenizer(q['question']) + id2len[qid] = len(input_ids) + img_fname = get_img_fname(q['imageId']) + txt2img[qid] = img_fname + img2txts[img_fname].append(qid) + q['toked_question'] = toked_question + q['input_ids'] = input_ids + q['img_fname'] = img_fname + input_ids_a, toked_a = tokenizer(q['fullAnswer']) + id2len[qid] += len(input_ids_a) + q['input_ids_a'] = input_ids_a + q['toked_answers'] = toked_a + db[qid] = q + return id2len, txt2img, img2txts + + +def process_nlvr2(jsonl, db, tokenizer, imgs=None): + id2len = {} + txt2img = {} # not sure if useful + img2txts = defaultdict(list) # not sure if useful + for line in tqdm(jsonl, desc='processing NLVR2'): + example = json.loads(line) + id_ = example['identifier'] + img_id = '-'.join(id_.split('-')[:-1]) + img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz') + if imgs is not None: + if not all(img in imgs for img in img_fname): + continue + input_ids, toked_question = tokenizer(example['sentence']) + target = 1 if example['label'] == 'True' else 0 + id2len[id_] = len(input_ids) + txt2img[id_] = img_fname + for fname in img_fname: + img2txts[fname].append(id_) + example['toked_question'] = toked_question + example['input_ids'] = input_ids + example['img_fname'] = img_fname + example['target'] = target + db[id_] = example + return id2len, txt2img, img2txts + + +def process_visual_entailment(hypotheses, ans2label, db, tokenizer): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + for h in tqdm(hypotheses, desc='processing visaul entailment hypotheses'): + hid = h['pairID'] + h['image_id'] = int(h["Flikr30kID"].split(".")[0]) + input_ids, toked_hypothesis = tokenizer(h['sentence2']) + id2len[hid] = len(input_ids) + img_fname = _get_flickr_fname(h['image_id']) + txt2img[hid] = img_fname + img2txts[img_fname].append(hid) + h['toked_hypothesis'] = toked_hypothesis + h['input_ids'] = input_ids + h['target'] = compute_target_ve([h['gold_label']], ans2label) + h['img_fname'] = img_fname + db[hid] = h + + return id2len, txt2img, img2txts + + +def process_caption(data, db, tokenizer, split): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + for q in tqdm(data['annotations'], desc='processing COCO captions'): + id_ = str(q['id']) + input_ids, toked_caption = tokenizer(q['caption']) + id2len[id_] = len(input_ids) + img_fname = _get_coco_fname(q['image_id'], 
split) + txt2img[id_] = img_fname + img2txts[img_fname].append(id_) + q['toked_caption'] = toked_caption + q['input_ids'] = input_ids + q['img_fname'] = img_fname + db[id_] = q + return id2len, txt2img, img2txts + + +def process_conceptual_caption(tsv, imgs, db, tokenizer, split): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + for line in tqdm(tsv, desc='processing conceptual captions'): + fields = line.strip().split('\t') + assert len(fields) == 4 + id_, _, caption, success = fields + if success == 'fail': + continue + assert success == 'success' + input_ids, toked_caption = tokenizer(caption) + assert input_ids # safeguard for empty text + img_fname = f'gcc_{split}_{int(id_):012}.npz' + if img_fname not in imgs: + continue + id2len[id_] = len(input_ids) + txt2img[id_] = img_fname + img2txts[img_fname].append(id_) + db[id_] = {'id': id_, + 'toked_caption': toked_caption, + 'input_ids': input_ids, + 'img_fname': img_fname} + return id2len, txt2img, img2txts + + +def process_sbu_caption(data, db, tokenizer): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + for ex in tqdm(data, desc='processing SBU captions'): + if ex['file_path'] == '0347/565.jpg': + # special case for corrupted image + continue + id_ = ex['iid'] + input_ids, toked_caption = tokenizer(ex['sent']) + assert input_ids # safeguard for empty text + try: + # FIXME sbu feature extraction bug + id_ = str(int(id_)) + except ValueError: + pass + img_fname = f'sbu_{id_}.npz' + id2len[id_] = len(input_ids) + txt2img[id_] = img_fname + img2txts[img_fname].append(id_) + db[id_] = {'id': id_, + 'toked_caption': toked_caption, + 'input_ids': input_ids, + 'img_fname': img_fname} + return id2len, txt2img, img2txts + + +def process_image_text_retrieval(data, db, tokenizer, dataset, split): + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + if dataset == 'coco': + _get_img_fname = _get_coco_fname(split=split) + elif dataset == 'flickr': + _get_img_fname = _get_flickr_fname + else: + raise ValueError('unrecognized data') + for q in tqdm(data, desc=f'processing image_text_retrieval for {split}'): + filename = q["filename"].split(".jpg")[0] + image_id = (int(filename.split("_")[-1]) if re.search('[a-zA-Z]', + filename) + else int(filename)) + img_fname = _get_img_fname(image_id) + for s in q["sentences"]: + s['image_id'] = image_id + id_ = str(s['sentid']) + txt2img[id_] = img_fname + img2txts[img_fname].append(id_) + input_ids, toked_caption = tokenizer(s['raw']) + id2len[id_] = len(input_ids) + s['toked_caption'] = toked_caption + s['input_ids'] = input_ids + s['img_fname'] = img_fname + db[id_] = s + return id2len, txt2img, img2txts + + +def process_caption_licheng_cleaned(data, db, tokenizer, split="COCO"): + """ + Inputs: + - data : [{id, dataset, split, sent, bbox, + dataset_image_id, file_path}] + - db + - tokenizer + - split + Return: + - id2len : id -> tokenized caption length + - txt2img : id -> img(feature) filenamee + - img2txts : img(feature) filename -> id(s) + We will also write to db[id]: + - image_id + - toked_caption : [tokens] + - input_ids : [wd_ids] + - img_fname : img(feature) filename + """ + id2len = {} + txt2img = {} + img2txts = defaultdict(list) + for q in tqdm(data, desc='processing licheng collected captions ' + f'for split: {split}'): + id_ = str(q['id']) + input_ids, toked_caption = tokenizer(q['sent']) + id2len[id_] = len(input_ids) + if q['dataset'] == 'vg': + img_fname = _get_vg_fname(q['dataset_image_id']) + else: + assert q['dataset'] == 'coco' + img_split = 
basename(q['file_path']).split('_')[1] + img_fname = _get_coco_fname(q['dataset_image_id'], img_split) + txt2img[id_] = img_fname + img2txts[img_fname].append(id_) + q['image_id'] = q['dataset_image_id'] + q['toked_caption'] = toked_caption + q['input_ids'] = input_ids + q['img_fname'] = img_fname + db[id_] = q + return id2len, txt2img, img2txts + + +def process_vcr_text(tokened_txt, objects, special_tokens): + text_region_tokens = [] + image_region_tokens = [0]*len(objects) + words = [] + for w in tokened_txt: + if isinstance(w, str): + word_splits = w.split(" ") + for splited_w in word_splits: + words.append(splited_w) + text_region_tokens.append(0) + else: + for index in w: + text_region_tokens.append(index+1) + image_region_tokens[index] = index+1 + object_name = objects[index] + if "person" in object_name: + object_name = f"{object_name}_{index}" + if object_name not in special_tokens: + special_tokens.append(object_name) + words.append(object_name) + return " ".join(words), image_region_tokens, text_region_tokens + + +def process_vcr_obj_categories(objects, object2ids): + output_ids = [] + for obj in objects: + output_ids.append(object2ids[obj]+1) + return output_ids + + +def process_vcr(data, db, tokenizer, split, object2ids): + id2len_qa = {} + id2len_qar = {} + txt2img = {} + img2txts = defaultdict(list) + special_tokens = [f"person_{i}" for i in range(81)] + for q in tqdm(data, desc='processing VCR %s questions' % split): + filename, file_extension = os.path.splitext( + q["img_fn"].split("/")[-1]) + q["image_id"] = filename + q['qa_target'] = q["answer_label"] if "answer_label" in q else -1 + q["qar_target"] = q["rationale_label"] \ + if "rationale_label" in q else -1 + qid = str(q['annot_id']) + q["raw_q"], image_region_tokens, txt_region_tokens = process_vcr_text( + q["question"], q["objects"], special_tokens) + q["image_region_tokens"] = image_region_tokens + input_ids, toked_question, toked_txt_region_tokens = tokenizer( + special_tokens, q["raw_q"], txt_region_tokens) + object_ids = process_vcr_obj_categories(q["objects"], object2ids) + q["object_ids"] = object_ids + q['toked_question'] = toked_question + q['input_ids'] = input_ids + q['toked_txt_region_tokens'] = toked_txt_region_tokens + q["raw_as"] = [] + q["raw_rs"] = [] + img_fname_gt, img_fname = _get_vcr_fname(q['image_id'], split) + txt2img[qid] = [img_fname_gt, img_fname] + img2txts[img_fname].append(qid) + img2txts[img_fname_gt].append(qid) + + input_ids_as = [] + toked_as = [] + input_ids_rs = [] + toked_rs = [] + toked_txt_region_tokens_a = [] + toked_txt_region_tokens_r = [] + max_qa_len = 0 + for ans in q["answer_choices"]: + raw_ans, _, txt_region_tokens = process_vcr_text( + ans, q["objects"], special_tokens) + q["raw_as"].append(raw_ans) + input_ids_a, toked_a, toked_txt_region_tokens = tokenizer( + special_tokens, raw_ans, txt_region_tokens) + if len(input_ids_a) > max_qa_len: + max_qa_len = len(input_ids_a) + input_ids_as.append(input_ids_a) + toked_as.append(toked_a) + toked_txt_region_tokens_a.append(toked_txt_region_tokens) + id2len_qa[qid] = (len(input_ids)+max_qa_len)*4 + + max_r_len = 0 + for r in q["rationale_choices"]: + raw_r, _, txt_region_tokens = process_vcr_text( + r, q["objects"], special_tokens) + q["raw_rs"].append(raw_r) + input_ids_r, toked_r, toked_txt_region_tokens = tokenizer( + special_tokens, raw_r, txt_region_tokens) + if len(input_ids_r) > max_r_len: + max_r_len = len(input_ids_r) + input_ids_rs.append(input_ids_r) + toked_rs.append(toked_r) + 
toked_txt_region_tokens_r.append(toked_txt_region_tokens) + id2len_qar[qid] = id2len_qa[qid]+max_r_len + q['img_fname'] = [img_fname_gt, img_fname] + q['toked_as'] = toked_as + q['toked_txt_region_tokens_a'] = toked_txt_region_tokens_a + q['input_ids_as'] = input_ids_as + q['toked_rs'] = toked_rs + q['input_ids_rs'] = input_ids_rs + q['toked_txt_region_tokens_r'] = toked_txt_region_tokens_r + db[qid] = q + return id2len_qa, id2len_qar, txt2img, img2txts, special_tokens + + +def _get_img_split(annotation): + for split in ['train2014', 'val2014', 'test2015', 'test-dev2015']: + if split in annotation: + img_split = split + break + else: + if ('vg' in annotation.lower() + or 'genome' in annotation.lower()): + img_split = 'vg' + elif 'gqa' in annotation.lower(): + if ('test' in annotation.lower() + or 'submission' in annotation.lower()): + img_split = 'gqa' + else: + img_split = 'vg' + elif 'val' in annotation.lower(): + img_split = 'val2014' + elif 'train' in annotation.lower(): + img_split = 'train2014' + else: + raise ValueError('cannot identify split') + if img_split == 'test-dev2015': + img_split = 'test2015' + return img_split + + +def main(opts): + if not exists(opts.output): + os.makedirs(opts.output) + else: + raise ValueError('Found existing DB. Please explicitly remove ' + 'for re-processing') + meta = vars(opts) + toker = BertTokenizer.from_pretrained( + opts.bert, do_lower_case='uncased' in opts.bert) + tokenizer = bert_tokenize(toker) + meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0] + meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0] + meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0] + meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0] + meta['v_range'] = (toker.convert_tokens_to_ids('!')[0], + len(toker.vocab)) + with open(f'{opts.output}/meta.json', 'w') as f: + json.dump(vars(opts), f, indent=4) + + output_field_name = ['id2len', 'txt2img', 'img2txts'] + with open_lmdb(opts.output, readonly=False) as db: + if opts.task == 'vqa': + questions = json.load(open(opts.annotations[0]))['questions'] + if len(opts.annotations) == 3: + answers = json.load(open(opts.annotations[1]))['annotations'] + ans2label = pickle.load(open(opts.annotations[2], 'rb')) + with open(f'{opts.output}/ans2label.pkl', 'wb') as f: + pickle.dump(ans2label, f) + else: + answers = None + ans2label = None + + # train2014, val2014 + img_split = _get_img_split(opts.annotations[0]) + jsons = process_vqa(questions, answers, ans2label, + db, tokenizer, img_split) + elif opts.task == 've': + contents = open(opts.annotations[0], "r").read() + hypotheses = [json.loads(str(item)) + for item in contents.strip().split('\n')] + from utils.misc import VE_ENT2IDX + ans2label = VE_ENT2IDX + jsons = process_visual_entailment( + hypotheses, ans2label, db, tokenizer) + elif opts.task == 'caption': + data = json.load(open(opts.annotations[0])) + img_split = _get_img_split(opts.annotations[0]) + jsons = process_caption(data, db, tokenizer, img_split) + elif opts.task == 'conceptual': + split = 'train' if 'train' in opts.annotations[0] else 'val' + imgs = set(json.load(open(opts.annotations[1]))) + with open(opts.annotations[0]) as tsv: + jsons = process_conceptual_caption(tsv, imgs, + db, tokenizer, split) + elif opts.task == 'sbu': + data = json.load(open(opts.annotations[0])) + jsons = process_sbu_caption(data, db, tokenizer) + elif opts.task == 'itm': + data = json.load(open(opts.annotations[0])) + if 'coco' in opts.annotations[0].lower(): + dataset = 'coco' + if 'train' in 
opts.annotations[0].lower(): + split = 'train2014' + elif ('val' in opts.annotations[0].lower() + or 'test' in opts.annotations[0].lower()): + split = 'val2014' + else: + raise ValueError() + elif 'flickr' in opts.annotations[0].lower(): + dataset = 'flickr' + split = None + else: + raise ValueError() + jsons = process_image_text_retrieval( + data, db, tokenizer, dataset, split) + elif opts.task == 'licheng_cleaned': + data = json.load(open(opts.annotations[0])) + jsons = process_caption_licheng_cleaned( + data, db, tokenizer, + split=opts.annotations[0].split(".")[0].split("/")[-1]) + elif opts.task == 'gqa': + data = json.load(open(opts.annotations[0])) + data_split = opts.annotations[0].split(".")[0].split("/")[-1] + data_split = data_split.split("_")[0] + jsons = process_gqa( + data, db, tokenizer, + split=data_split) + elif opts.task == 'vcr': + data = [] + with open(opts.annotations[0], "r") as f: + for line in f: + data.append(json.loads(line)) + img_split = opts.annotations[0].split("/")[-1].split(".")[0] + tokenizer = bert_tokenize_for_vcr(toker) + ann_folder = "/".join(opts.annotations[0].split("/")[:-1]) + object_categories_path = ann_folder+"/object_categories.json" + object_categories = json.load(open(object_categories_path, "r")) + jsons = process_vcr(data, db, tokenizer, + img_split, object_categories) + output_field_name = ['id2len_qa', 'id2len_qar', 'txt2img', + 'img2txts', 'special_tokens'] + elif opts.task == 'nlvr2': + with open(opts.annotations[0]) as ann: + if len(opts.annotations) == 2: + imgs = set(json.load(open(opts.annotations[1]))) + else: + imgs = None + jsons = process_nlvr2(ann, db, tokenizer, imgs) + elif opts.task == 're': + data = [] + refs = pickle.load(open(opts.annotations[0], 'rb')) + instances = json.load(open(opts.annotations[1], 'r')) + iid_to_ann_ids = json.load(open(opts.annotations[2], + 'r'))['iid_to_ann_ids'] + # dirs/refcoco_testA_bert-base-cased.db -> testA + img_split = opts.output.split('/')[-1].split('_')[1] + jsons = process_referring_expressions( + refs, instances, iid_to_ann_ids, db, tokenizer, img_split) + output_field_name = ['id2len', 'images', 'annotations', + 'categories', 'refs'] + else: + raise ValueError() + + for dump, name in zip(jsons, output_field_name): + with open(f'{opts.output}/{name}.json', 'w') as f: + json.dump(dump, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--annotations', required=True, nargs='+', + help='annotation JSON') + parser.add_argument('--output', required=True, + help='output dir of DB') + parser.add_argument('--task', required=True, + choices=['vqa', 'caption', + 've', "itm", "licheng_cleaned", + 'vcr', 'nlvr2', 're', 'gqa', + 'conceptual', 'sbu']) + parser.add_argument('--bert', default='bert-base-cased') + args = parser.parse_args() + if args.task == 'vqa': + assert len(args.annotations) == 3 or len(args.annotations) == 1 + elif args.task == 'gqa': + assert len(args.annotations) == 1 + elif args.task == 've': + assert len(args.annotations) == 1 + elif args.task == 'itm': + assert len(args.annotations) == 1 + elif args.task == 'licheng_cleaned': + assert len(args.annotations) == 1 + elif args.task == 'caption': + assert len(args.annotations) == 1 + elif args.task == 'vcr': + assert len(args.annotations) == 1 + elif args.task == 'nlvr2': + assert len(args.annotations) == 1 or len(args.annotations) == 2 + elif args.task == 'conceptual': + assert len(args.annotations) == 2 or len(args.annotations) == 1 + elif args.task == 'sbu': + assert 
len(args.annotations) == 1 + elif args.task == 're': + assert len(args.annotations) == 3 + main(args) diff --git a/uniter_model/pretrain.py b/uniter_model/pretrain.py new file mode 100644 index 0000000..82dec82 --- /dev/null +++ b/uniter_model/pretrain.py @@ -0,0 +1,834 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""UNITER pre-training runner.""" +import argparse +from collections import defaultdict +import json +import math +import os +from os.path import exists, join +from time import time + +import torch +from torch.utils.data import DataLoader +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ + +from apex import amp +from horovod import torch as hvd + +from tqdm import tqdm + +from data import (TokenBucketSampler, TokenBucketSamplerForItm, + MetaLoader, PrefetchLoader, + TxtTokLmdb, ImageLmdbGroup, ConcatDatasetWithLens, + MlmDataset, MlmEvalDataset, + BlindMlmDataset, BlindMlmEvalDataset, + MrfrDataset, OnlyImgMrfrDataset, + MrcDataset, OnlyImgMrcDataset, + mlm_collate, mlm_eval_collate, + mlm_blind_collate, mlm_blind_eval_collate, + mrfr_collate, mrfr_only_img_collate, + mrc_collate, mrc_only_img_collate, + ItmDataset, itm_collate, itm_ot_collate) +from data.mrm_nce import NegativeImageSampler, MrmNceDataset, mrm_nce_collate + +from model import UniterForPretraining +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM, IMG_LABEL_DIM, BUCKET_SIZE + + +WARM_STEP = 500 + + +def build_dataloader(dataset, collate_fn, is_train, opts): + if is_train: + batch_size = opts.train_batch_size + else: + batch_size = opts.val_batch_size + sampler = TokenBucketSampler(dataset.lens, bucket_size=BUCKET_SIZE, + batch_size=batch_size, droplast=is_train) + loader = DataLoader(dataset, batch_sampler=sampler, + num_workers=opts.n_workers, pin_memory=opts.pin_mem, + collate_fn=collate_fn) + return loader + + +def build_dataloader_itm(dataset, collate_fn, is_train, opts): + if is_train: + batch_size = opts.train_batch_size + else: + batch_size = opts.val_batch_size + sampler = TokenBucketSamplerForItm( + dataset, bucket_size=BUCKET_SIZE, + batch_size=batch_size, droplast=is_train) + loader = DataLoader(dataset, batch_sampler=sampler, + num_workers=opts.n_workers, pin_memory=opts.pin_mem, + collate_fn=collate_fn) + return loader + + +def build_mlm_dataset(txt_db, img_db, blind, is_train, opts): + if is_train: + if blind: + collate_fn = mlm_blind_collate + datasets = [BlindMlmDataset(t) for t in txt_db] + else: + 
collate_fn = mlm_collate + datasets = [MlmDataset(t, i) for t, i in zip(txt_db, img_db)] + dataset = ConcatDatasetWithLens(datasets) + else: + if blind: + collate_fn = mlm_blind_collate + dataset = BlindMlmDataset(txt_db) + else: + collate_fn = mlm_collate + dataset = MlmDataset(txt_db, img_db) + + return dataset, collate_fn + + +def build_mrfr_dataset(txt_db, img_db, only_i, is_train, opts): + collate_fn = (mrfr_only_img_collate if only_i + else mrfr_collate) + if is_train: + if only_i: + datasets = [OnlyImgMrfrDataset(opts.mrm_prob, i) for i in img_db] + else: + datasets = [MrfrDataset(opts.mrm_prob, t, i) + for t, i in zip(txt_db, img_db)] + dataset = ConcatDatasetWithLens(datasets) + else: + if only_i: + dataset = OnlyImgMrfrDataset(opts.mrm_prob, img_db) + else: + dataset = MrfrDataset(opts.mrm_prob, txt_db, img_db) + + return dataset, collate_fn + + +def build_mrm_nce_dataset(txt_db, img_db, only_i, is_train, opts): + assert not only_i + neg_sampler = NegativeImageSampler(img_db, opts.neg_size) + collate_fn = mrm_nce_collate(neg_sampler) + if is_train: + datasets = [MrmNceDataset(opts.mrm_prob, t, i) + for t, i in zip(txt_db, img_db)] + dataset = ConcatDatasetWithLens(datasets) + else: + dataset = MrmNceDataset(opts.mrm_prob, txt_db, img_db) + + return dataset, collate_fn + + +def build_mrc_dataset(txt_db, img_db, only_i, is_train, opts): + collate_fn = (mrc_only_img_collate if only_i + else mrc_collate) + if is_train: + if only_i: + datasets = [OnlyImgMrcDataset(opts.mrm_prob, i) for i in img_db] + else: + datasets = [MrcDataset(opts.mrm_prob, t, i) + for t, i in zip(txt_db, img_db)] + dataset = ConcatDatasetWithLens(datasets) + else: + if only_i: + dataset = OnlyImgMrcDataset(opts.mrm_prob, img_db) + else: + dataset = MrcDataset(opts.mrm_prob, txt_db, img_db) + + return dataset, collate_fn + + +def build_itm_dataset(txt_db, img_db, is_train, opts): + if is_train: + datasets = [ItmDataset(t, i, opts.itm_neg_prob) + for t, i in zip(txt_db, img_db)] + dataset = ConcatDatasetWithLens(datasets) + else: + dataset = ItmDataset(txt_db, img_db, opts.itm_neg_prob) + collate_fn = itm_ot_collate if opts.itm_ot_lambda > 0 else itm_collate + return dataset, collate_fn + + +def create_dataloaders(datasets, is_train, opts, all_img_dbs=None): + if all_img_dbs is None: + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + dataloaders = {} + for dset in datasets: + if is_train: + assert len(dset['db']) == len(dset['img']) + assert len(dset['tasks']) == len(dset['mix_ratio']) + img_db = [all_img_dbs[path] for path in dset['img']] + else: + assert len(dset['db']) == len(dset['img']) == 1 + img_db = all_img_dbs[dset['img'][0]] + + for i, t in enumerate(dset['tasks']): + task = f'{t}_{dset["name"]}' + + if is_train: + LOGGER.info(f"Loading {task} train dataset " + f"{dset['db']}, {[img.img_dir for img in img_db]}") + txt_db = [TxtTokLmdb(path, opts.max_txt_len) + for path in dset['db']] + else: + LOGGER.info(f"Loading {task} validation dataset, " + f"{dset['db']}, {img_db.img_dir}") + txt_db = TxtTokLmdb(dset['db'][0], -1) + + if task.startswith('mlm'): + blind = 'blind' in task + dataset = build_mlm_dataset(txt_db, img_db, + blind, is_train, opts) + elif task.startswith('mrfr'): + only_i = 'only_i' in task + dataset = build_mrfr_dataset(txt_db, img_db, + only_i, is_train, opts) + elif task.startswith('mrm-nce'): + only_i = 'only_i' in task + dataset = build_mrm_nce_dataset(txt_db, img_db, + only_i, is_train, opts) + elif task.startswith('mrc'): + 
only_i = 'only_i' in task + dataset = build_mrc_dataset(txt_db, img_db, + only_i, is_train, opts) + elif task.startswith('itm'): + dataset = build_itm_dataset(txt_db, img_db, is_train, opts) + else: + raise ValueError(f'Undefined task {task}') + + LOGGER.info(f"{len(dataset[0])*hvd.size()} samples loaded") + if task.startswith('itm'): + # itm handles distributed training in dset not sampler + loader = build_dataloader_itm(*dataset, is_train, opts) + else: + loader = build_dataloader(*dataset, is_train, opts) + if is_train: + ratio = dset['mix_ratio'][i] + dataloaders[task] = (loader, ratio) + else: + dataloaders[task] = PrefetchLoader(loader) + return dataloaders, all_img_dbs + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + set_random_seed(opts.seed) + + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(args.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets] + for dset in datasets for db in dset['db']] + + tokenizer = json.load(open(f'{all_dbs[0]}/meta.json'))['bert'] + assert all(tokenizer == json.load(open(f'{db}/meta.json'))['bert'] + for db in all_dbs) + + # build data loaders + train_dataloaders, all_img_dbs = create_dataloaders( + opts.train_datasets, True, opts) + val_dataloaders, _ = create_dataloaders( + opts.val_datasets, False, opts, all_img_dbs) + meta_loader = MetaLoader(train_dataloaders, + accum_steps=opts.gradient_accumulation_steps, + distributed=n_gpu > 1) + meta_loader = PrefetchLoader(meta_loader) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + model = UniterForPretraining.from_pretrained( + opts.model_config, checkpoint, + img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM, + nce_temp=opts.nce_temp, ot_pos_only=opts.ot_pos_only) + model.pad_vocab() # tensor core padding for vocabulary + model.to(device) + model.train() + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())} + model, optimizer = amp.initialize(model, optimizer, + num_losses=len(task2scaler), + enabled=opts.fp16, opt_level='O2') + + global_step = 0 + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + # to compute training statistics + task2loss = {task: RunningMeter(f'loss/{task}') + for task in train_dataloaders.keys()} + # ITM w/ OT + if opts.itm_ot_lambda > 0: + for task in train_dataloaders.keys(): + if task.startswith('itm'): + task2loss[f'{task}_xe'] = 
RunningMeter(f'loss/{task}_xe') + task2loss[f'{task}_ot'] = RunningMeter(f'loss/{task}_ot') + if not opts.ot_pos_only: + task2loss[f'{task}_ot_pos'] = RunningMeter( + f'loss/{task}_ot_pos') + task2loss[f'{task}_ot_neg'] = RunningMeter( + f'loss/{task}_ot_neg') + + n_examples = defaultdict(int) + n_in_units = defaultdict(int) + n_loss_units = defaultdict(int) + n_neg_nce = defaultdict(int) + grad_norm = 0 + + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + for step, (name, batch) in enumerate(meta_loader): + # forward pass + assert all(name == n for n in all_gather_list(name)) + n_examples[name] += batch['input_ids'].size(0) + n_in_units[name] += (batch['attn_masks'] == 1).sum().item() + if 'nce' in name: + n_neg_nce[name] += batch['neg_feats'].size(0) + task = name.split('_')[0] + loss = model(batch, task=task, compute_loss=True) + if task.startswith('itm'): + # OT + itm_loss, ot_loss = loss + n_loss_units[name] += itm_loss.size(0) + itm_loss = itm_loss.mean() + if ot_loss is not None: + if not opts.ot_pos_only: + ot_pos, ot_neg = ot_loss + ot_loss = (ot_pos.sum() - ot_neg.sum() + ) / (ot_pos.size(0) + ot_neg.size(0)) + + # NOTE: be ware of empty tensor + ot_pos = ot_pos.mean().item() + if not math.isnan(ot_pos): + task2loss[f'{name}_ot_pos'](ot_pos) + ot_neg = ot_neg.mean().item() + if not math.isnan(ot_neg): + task2loss[f'{name}_ot_neg'](ot_neg) + else: + ot_loss = ot_loss.mean() + loss = itm_loss + opts.itm_ot_lambda * ot_loss + task2loss[f'{name}_xe'](itm_loss.item()) + task2loss[f'{name}_ot'](ot_loss.item()) + else: + loss = itm_loss + else: + n_loss_units[name] += loss.size(0) + loss = loss.mean() # loss is not normalized in model + + # backward pass + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale, + loss_id=task2scaler[name]) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + task2loss[name](loss.item()) + + # optimizer update and logging + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + for t, l in task2loss.items(): + loss = sum(v for v in all_gather_list(l.val) + if v is not None) / hvd.size() + task2loss[t] = RunningMeter(f'loss/{t}', loss) + TB_LOGGER.log_scaler_dict({l.name: l.val + for l in task2loss.values() + if l.val is not None}) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + ''' + if global_step % 10 == 0 and not opts.fp16: + bias = model.bert.img_embeddings.img_linear.bias + weight = model.bert.img_embeddings.img_linear.weight + print(f"bnorm: {bias.norm()}") + print(f"wnorm: {weight.norm()}") + print(f"bgnorm: {bias.grad.norm()}") + print(f"wgnorm: {weight.grad.norm()}") + + mask = model.bert.img_embeddings.mask_embedding.weight + print(f"mnorm: {mask.norm()}") + print(f"mgnorm: {mask.grad.norm()}") + + print([(n, p.grad.norm().item()) + for n, p in model.named_parameters() + if p.grad is not None + and p.grad.norm().item() > grad_norm/10]) + ''' + grad_norm = 
clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'==============Step {global_step}===============') + for t in train_dataloaders.keys(): + assert all(tt == t for tt in all_gather_list(t)) + tot_ex = sum(all_gather_list(n_examples[t])) + ex_per_sec = int(tot_ex / (time()-start)) + tot_in = sum(all_gather_list(n_in_units[t])) + in_per_sec = int(tot_in / (time()-start)) + tot_l = sum(all_gather_list(n_loss_units[t])) + l_per_sec = int(tot_l / (time()-start)) + LOGGER.info(f'{t}: {tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec, + global_step) + TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec, + global_step) + TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec, + global_step) + if 'nce' in t: + avg_neg = sum(all_gather_list(n_neg_nce[t]) + ) / hvd.size() // step + LOGGER.info(f'{t}: averaging ' + f'{avg_neg} negative samples') + LOGGER.info(f'===============================================') + + if global_step % opts.valid_steps == 0: + LOGGER.info(f'Step {global_step}: start validation') + validate(model, val_dataloaders) + model_saver.save(model, global_step, optimizer) + if global_step >= opts.num_train_steps: + break + if global_step % opts.valid_steps != 0: + LOGGER.info(f'Step {global_step}: start validation') + validate(model, val_dataloaders) + model_saver.save(model, global_step) + + +def validate(model, val_dataloaders): + model.eval() + for task, loader in val_dataloaders.items(): + LOGGER.info(f"validate on {task} task") + if task.startswith('mlm'): + val_log = validate_mlm(model, loader) + elif task.startswith('mrfr'): + val_log = validate_mrfr(model, loader) + elif task.startswith('mrm-nce'): + val_log = validate_mrm_nce(model, loader) + elif task.startswith('mrc'): + val_log = validate_mrc(model, loader, task) + elif task.startswith('itm'): + val_log = validate_itm(model, loader) + else: + raise ValueError(f'Undefined task {task}') + val_log = {f'{task}_{k}': v for k, v in val_log.items()} + TB_LOGGER.log_scaler_dict( + {f'valid_{task}/{k}': v for k, v in val_log.items()}) + model.train() + + +@torch.no_grad() +def validate_mlm(model, val_loader): + LOGGER.info("start running MLM validation...") + val_loss = 0 + n_correct = 0 + n_word = 0 + st = time() + for i, batch in enumerate(val_loader): + scores = model(batch, task='mlm', compute_loss=False) + labels = batch['txt_labels'] + labels = labels[labels != -1] + loss = F.cross_entropy(scores, labels, reduction='sum') + val_loss += loss.item() + n_correct += (scores.max(dim=-1)[1] == labels).sum().item() + n_word += labels.numel() + val_loss = sum(all_gather_list(val_loss)) + n_correct = sum(all_gather_list(n_correct)) + n_word = sum(all_gather_list(n_word)) + tot_time = time()-st + val_loss /= n_word + acc = n_correct / n_word + val_log = {'loss': val_loss, + 'acc': acc, + 'tok_per_s': n_word/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"acc: {acc*100:.2f}") + return val_log + + +@torch.no_grad() +def validate_mlm_old(model, val_loader): + LOGGER.info("start running MLM validation...") + val_loss = 0 + n_correct = 0 + n_word = 0 + st = time() + for i, batch in enumerate(val_loader): + scores = model.forward(batch, task='mlm', compute_loss=False) + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, + 
reduction='sum') + scores = scores.contiguous().view(-1, model.config.vocab_size) + labels = batch['txt_labels'].contiguous().view(-1) + loss = loss_fct(scores, labels) + val_loss += loss.item() + n_correct += accuracy_count(scores, labels) + n_word += batch['txt_labels'].numel() + val_loss = sum(all_gather_list(val_loss)) + n_correct = sum(all_gather_list(n_correct)) + n_word = sum(all_gather_list(n_word)) + tot_time = time()-st + val_loss /= n_word + acc = n_correct / n_word + val_log = {'loss': val_loss, + 'acc': acc, + 'tok_per_s': n_word/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"acc: {acc*100:.2f}") + return val_log + + +def accuracy_count(out, labels): + outputs = out.max(dim=-1)[1] + mask = labels != -1 + n_correct = (outputs == labels).masked_select(mask).sum().item() + return n_correct + + +@torch.no_grad() +def validate_mrfr(model, val_loader): + LOGGER.info("start running MRFR validation...") + val_loss = 0 + n_feat = 0 + st = time() + for i, batch in enumerate(val_loader): + loss = model(batch, task='mrfr', compute_loss=True) + val_loss += loss.sum().item() / IMG_DIM + n_feat += batch['img_mask_tgt'].sum().item() + val_loss = sum(all_gather_list(val_loss)) + n_feat = sum(all_gather_list(n_feat)) + tot_time = time()-st + val_loss /= n_feat + val_log = {'loss': val_loss, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"loss: {val_loss:.2f}") + return val_log + + +@torch.no_grad() +def validate_mrm_nce(model, val_loader): + LOGGER.info("start running MRM-NCE validation...") + val_loss = 0 + val_l2 = 0 + n_correct = 0 + cosine = 0 + n_feat = 0 + n_neg = 0 + st = time() + for i, batch in enumerate(val_loader): + feats, pos_feats, neg_feats = model(batch, task='mrm-nce', + compute_loss=False) + logits = model.mrm_nce(feats, pos_feats, neg_feats, + compute_loss=False) + targets = torch.arange(0, logits.size(0), + dtype=torch.long, device=logits.device) + val_loss += F.cross_entropy(logits, targets, reduction='sum').item() + val_l2 += F.mse_loss(feats, pos_feats, reduction='sum' + ).item() / feats.size(-1) + n_correct += (logits.max(dim=-1)[1] == targets).sum().item() + cosine += F.cosine_similarity(feats, pos_feats, dim=-1).sum().item() + nf = batch['img_mask_tgt'].sum().item() + n_feat += nf + n_neg += neg_feats.size(0) * nf + val_loss = sum(all_gather_list(val_loss)) + val_l2 = sum(all_gather_list(val_l2)) + n_correct = sum(all_gather_list(n_correct)) + cosine = sum(all_gather_list(cosine)) + n_feat = sum(all_gather_list(n_feat)) + n_neg = sum(all_gather_list(n_neg)) + tot_time = time()-st + val_loss /= n_feat + val_acc = n_correct / n_feat + val_log = {'loss': val_loss, + 'acc': val_acc, + 'l2': val_l2 / n_feat, + 'cosine': cosine / n_feat, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"loss: {val_loss:.2f}, acc: {val_acc*100:.2f} " + f"(average {n_neg/n_feat:.0f} negatives)") + return val_log + + +@torch.no_grad() +def validate_mrc(model, val_loader, task): + LOGGER.info("start running MRC validation...") + val_loss = 0 + n_feat = 0 + st = time() + tot_score = 0 + for i, batch in enumerate(val_loader): + prediction_soft_label = model( + batch, task=task, compute_loss=False) + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + label_targets = batch['label_targets'] + loss = F.kl_div( + prediction_soft_label, label_targets, reduction='sum') + tot_score += compute_accuracy_for_soft_targets( + 
prediction_soft_label, label_targets) + else: + # background class should not be the target + cls_label_targets = label_targets[:, 1:].max(dim=-1)[1] + 1 + loss = F.cross_entropy( + prediction_soft_label, cls_label_targets, + ignore_index=0, reduction='sum') + tot_score += compute_accuracy_for_soft_targets( + prediction_soft_label[:, 1:], label_targets[:, 1:]) + val_loss += loss.item() + n_feat += batch['img_mask_tgt'].sum().item() + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_feat = sum(all_gather_list(n_feat)) + tot_time = time()-st + val_loss /= n_feat + val_acc = tot_score / n_feat + val_log = {'loss': val_loss, + 'acc': val_acc, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log + + +def compute_accuracy_for_soft_targets(out, labels): + outputs = out.max(dim=-1)[1] + labels = labels.max(dim=-1)[1] # argmax + n_correct = (outputs == labels).sum().item() + return n_correct + + +@torch.no_grad() +def validate_itm(model, val_loader): + LOGGER.info("start running ITM validation...") + val_loss = 0 + tot_ot_loss = 0 + tot_ot_pos = 0 + tot_ot_neg = 0 + tot_score = 0 + n_ex = 0 + st = time() + for i, batch in enumerate(val_loader): + scores, ot_loss = model(batch, task='itm', compute_loss=False) + if ot_loss is not None: + if isinstance(ot_loss, tuple): + ot_pos, ot_neg = ot_loss + ot_pos = ot_pos.sum().item() + ot_neg = ot_neg.sum().item() + tot_ot_pos += ot_pos + tot_ot_neg += ot_neg + tot_ot_loss += ot_pos - ot_neg + else: + tot_ot_loss += ot_loss.sum().item() + targets = batch['targets'] + loss = F.cross_entropy(scores, targets, reduction='sum') + val_loss += loss.item() + + tot_score += (scores.max(dim=-1)[1] == targets).sum().item() + n_ex += len(targets) + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_loss /= n_ex + val_acc = tot_score / n_ex + val_log = {'valid/loss': val_loss, + 'valid/acc': val_acc, + 'valid/ex_per_s': n_ex/tot_time} + + if ot_loss is not None: + tot_ot_loss = sum(all_gather_list(tot_ot_loss)) + tot_ot_pos = sum(all_gather_list(tot_ot_pos)) + tot_ot_neg = sum(all_gather_list(tot_ot_neg)) + val_log['valid/ot_loss'] = tot_ot_loss / n_ex + val_log['valid/ot_pos'] = tot_ot_pos / n_ex + val_log['valid/ot_neg'] = tot_ot_neg / n_ex + + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + # NOTE: train tasks and val tasks cannot take command line arguments + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + + parser.add_argument("--model_config", type=str, + help="path to model structure config json") + parser.add_argument("--checkpoint", default=None, type=str, + help="path to model checkpoint (*.pt)") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + parser.add_argument('--mrm_prob', default=0.15, type=float, + help='probability to mask in MRM training') + parser.add_argument('--neg_size', default=128, type=int, + help='negative image size for NCE') + parser.add_argument('--nce_temp', default=1.0, type=float, + help='softmax temperature for NCE') + parser.add_argument('--itm_neg_prob', default=0.5, type=float, + help='probability 
to make negative examples' + 'in ITM training') + parser.add_argument('--itm_ot_lambda', default=0.0, type=float, + help='weight of OT (optimal transport) loss') + parser.add_argument('--ot_pos_only', action='store_true', + help='use OT distance of positive pairs only') + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", default=4096, type=int, + help="Total batch size for validation. " + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', type=int, default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", default=3e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", default=1000, type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", default=100000, type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adamw', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt'], + help="learning rate decay method") + parser.add_argument("--dropout", default=0.1, type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", default=0.01, type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", default=2.0, type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", default=10000, type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. 
(invsqrt decay)") + + # device parameters + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', help="pin memory") + + # can use config files + parser.add_argument('--config', required=True, help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + main(args) diff --git a/uniter_model/pretrain_vcr.py b/uniter_model/pretrain_vcr.py new file mode 100644 index 0000000..0b4bb09 --- /dev/null +++ b/uniter_model/pretrain_vcr.py @@ -0,0 +1,754 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""BERT pre-training runner.""" +import argparse +import json +import os +from os.path import exists, join +import random +from time import time + +import torch +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Adam, Adamax +from torch.utils.data import DataLoader +from data.data import ConcatDetectFeatBertTokDataset as ConcatDataset + +from apex import amp +from horovod import torch as hvd + +import numpy as np +from tqdm import tqdm + +from data import (DistributedTokenBucketSampler, + DetectFeatLmdb, MlmDatasetForVCR, mlm_collate_for_vcr, + MrmDatasetForVCR, mrm_collate_for_vcr, + MrcDatasetForVCR, mrc_collate_for_vcr, + MetaLoader, PrefetchLoader) +from model import BertForImageTextPretrainingForVCR +from optim import warmup_linear, noam_schedule, vqa_schedule, AdamW + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config +NUM_SPECIAL_TOKENS = 81 +IMG_DIM = 2048 +IMG_LABEL_DIM = 1601 + + +def parse_tasks(datasets): + task_names = [] + dset_paths = [] + mix_ratio = [] + for i, dset in enumerate(datasets): + assert len(dset['db']) == len(dset['img']) + if 'mix_ratio' in dset: + assert len(dset['tasks']) == len(dset['mix_ratio']) + mix_ratio.extend(dset['mix_ratio']) + task_names.extend(f'{t}_{dset["name"]}' for t in dset['tasks']) + n_task = len(dset['tasks']) + dset_paths.extend([(dset['db'], dset['img'])] * n_task) + + assert len(task_names) == len(set(task_names)) == len(dset_paths) + if mix_ratio: + assert len(task_names) == len(mix_ratio) + return task_names, dset_paths, mix_ratio + else: + return task_names, dset_paths + + +def build_sampler(lens, batch_size, eval_, bucket_size=8192): + droplast = not eval_ + sampler = DistributedTokenBucketSampler( + hvd.size(), hvd.rank(), lens, + bucket_size=bucket_size, batch_size=batch_size, droplast=droplast) + return sampler + + +def build_mlm_train_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MLM Train Dataset {txt_db}, " + f"{[i.img_dir for i in img_dir]}" + f"{[i.img_dir for i in img_dir_gt]}") + train_datasets = [MlmDatasetForVCR( + db, dir_gt_, dir_, opts.max_txt_len, task=t) + for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir) + for t in opts.vcr_task] + train_dataset = ConcatDataset(train_datasets) + train_sampler = build_sampler(train_dataset.lens, + opts.train_batch_size, eval_=False) + train_dataloader = DataLoader(train_dataset, + batch_sampler=train_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=mlm_collate_for_vcr) + LOGGER.info(f"{len(train_dataset)} samples loaded") + return train_dataloader + + +def build_mrm_train_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MRM Train Dataset {txt_db}, " + f"{[i.img_dir for i in img_dir]}" + f"{[i.img_dir for i in img_dir_gt]}") + + train_datasets = [MrmDatasetForVCR( + opts.mrm_prob, db, dir_gt_, + dir_, opts.max_txt_len, task=t) + for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir) + for t in opts.vcr_task] + train_dataset = ConcatDataset(train_datasets) + train_sampler = build_sampler(train_dataset.lens, + opts.train_batch_size, eval_=False) + train_dataloader = DataLoader(train_dataset, + batch_sampler=train_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + 
collate_fn=mrm_collate_for_vcr) + LOGGER.info(f"{len(train_dataset)} samples loaded") + return train_dataloader + + +def build_mrc_train_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MRC Train Dataset {txt_db}, " + f"{[i.img_dir for i in img_dir]}" + f"{[i.img_dir for i in img_dir_gt]}") + train_datasets = [MrcDatasetForVCR( + opts.mrc_prob, db, dir_gt_, + dir_, opts.max_txt_len, task=t) + for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir) + for t in opts.vcr_task] + train_dataset = ConcatDataset(train_datasets) + train_sampler = build_sampler(train_dataset.lens, + opts.train_batch_size, eval_=False) + train_dataloader = DataLoader(train_dataset, + batch_sampler=train_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=mrc_collate_for_vcr) + LOGGER.info(f"{len(train_dataset)} samples loaded") + return train_dataloader + + +def build_mlm_val_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MLM Val Dataset {txt_db}, " + f"{img_dir_gt.img_dir}, {img_dir.img_dir}") + val_datasets = [MlmDatasetForVCR( + txt_db, img_dir_gt, img_dir, -1, task=t) + for t in opts.vcr_task] + val_dataset = ConcatDataset(val_datasets) + val_sampler = build_sampler(val_dataset.lens, + opts.val_batch_size, eval_=True) + val_dataloader = DataLoader(val_dataset, + batch_sampler=val_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=mlm_collate_for_vcr) + LOGGER.info(f"{len(val_dataset)} samples loaded") + return val_dataloader + + +def build_mrm_val_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MRM Val Dataset {txt_db}, " + f"{img_dir_gt.img_dir}, {img_dir.img_dir}") + val_datasets = [MrmDatasetForVCR( + opts.mrm_prob, txt_db, img_dir_gt, + img_dir, -1, task=t) + for t in opts.vcr_task] + val_dataset = ConcatDataset(val_datasets) + val_sampler = build_sampler(val_dataset.lens, + opts.val_batch_size, eval_=True) + val_dataloader = DataLoader(val_dataset, + batch_sampler=val_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=mrm_collate_for_vcr) + LOGGER.info(f"{len(val_dataset)} samples loaded") + return val_dataloader + + +def build_mrc_val_dataloader(txt_db, img_dir_gt, img_dir, + n_gpu, opts): + LOGGER.info(f"Loading MRC Val Dataset {txt_db}, " + f"{img_dir_gt.img_dir}, {img_dir.img_dir}") + val_datasets = [MrcDatasetForVCR( + opts.mrc_prob, txt_db, img_dir_gt, + img_dir, -1, task=t) + for t in opts.vcr_task] + val_dataset = ConcatDataset(val_datasets) + val_sampler = build_sampler(val_dataset.lens, + opts.val_batch_size, eval_=True) + val_dataloader = DataLoader(val_dataset, + batch_sampler=val_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=mrc_collate_for_vcr) + LOGGER.info(f"{len(val_dataset)} samples loaded") + return val_dataloader + + +def load_img_feat(dir_list, path2imgdir, opts): + dir_ = dir_list.split(";") + assert len(dir_) <= 2, "More than two img_dirs found" + img_dir_gt, img_dir = None, None + gt_dir_path, dir_path = "", "" + for d in dir_: + if "gt" in d: + gt_dir_path = d + else: + dir_path = d + if gt_dir_path != "": + img_dir_gt = path2imgdir.get(gt_dir_path, None) + if img_dir_gt is None: + img_dir_gt = DetectFeatLmdb(gt_dir_path, -1, + opts.max_bb, opts.min_bb, 100, + opts.compressed_db) + path2imgdir[gt_dir_path] = img_dir_gt + if dir_path != "": + img_dir = path2imgdir.get(dir_path, None) + if img_dir is None: + img_dir = DetectFeatLmdb(dir_path, opts.conf_th, + opts.max_bb, 
opts.min_bb, opts.num_bb, + opts.compressed_db) + path2imgdir[dir_path] = img_dir + return img_dir, img_dir_gt, path2imgdir + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + random.seed(opts.seed) + np.random.seed(opts.seed) + torch.manual_seed(opts.seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(opts.seed) + + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(args.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets] + for dset in datasets for db in dset['db']] + bert_model = json.load(open(f'{all_dbs[0]}/meta.json'))['bert'] + assert all(bert_model == json.load(open(f'{db}/meta.json'))['bert'] + for db in all_dbs) + + train_tasks, train_data_paths, mix_ratio = parse_tasks(opts.train_datasets) + train_dataloaders = [] + path2imgdir = {} + for (dbs, dirs), task in zip(train_data_paths, train_tasks): + img_dirs = [] + img_gt_dirs = [] + for db, dir_list in zip(dbs, dirs): + img_dir, img_dir_gt, path2imgdir = load_img_feat( + dir_list, path2imgdir, opts) + img_dirs.append(img_dir) + img_gt_dirs.append(img_dir_gt) + if task.startswith('mlm'): + loader = build_mlm_train_dataloader(dbs, img_gt_dirs, img_dirs, + n_gpu, opts) + elif task.startswith('mrm'): + loader = build_mrm_train_dataloader(dbs, img_gt_dirs, img_dirs, + n_gpu, opts) + elif task.startswith('mrc'): + loader = build_mrc_train_dataloader(dbs, img_gt_dirs, img_dirs, + n_gpu, opts) + else: + raise ValueError(f'Undefined task {task}') + train_dataloaders.append(loader) + val_tasks, val_data_paths = parse_tasks(opts.val_datasets) + val_dataloaders = [] + for (db, dir_), task in zip(val_data_paths, val_tasks): + assert len(db) == len(dir_) == 1 + db = db[0] + dir_ = dir_[0] + img_dir, img_dir_gt, path2imgdir = load_img_feat( + dir_, path2imgdir, opts) + if task.startswith('mlm'): + loader = build_mlm_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts) + elif task.startswith('mrm'): + loader = build_mrm_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts) + elif task.startswith('mrc'): + loader = build_mrc_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts) + else: + raise ValueError(f'Undefined task {task}') + val_dataloaders.append(PrefetchLoader(loader)) + meta_loader = MetaLoader(train_dataloaders, + mix_ratio=mix_ratio, names=train_tasks, + accum_steps=opts.gradient_accumulation_steps, + distributed=n_gpu > 1) + meta_loader = PrefetchLoader(meta_loader) + named_val_loaders = list(zip(val_tasks, val_dataloaders)) + + # Prepare model + + if opts.checkpoint: + if opts.checkpoint == 'google-bert': + checkpoint = None + else: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + model = BertForImageTextPretrainingForVCR.from_pretrained( + bert_model, img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM, + state_dict=checkpoint) + model.init_type_embedding() + model.init_word_embedding(NUM_SPECIAL_TOKENS) 
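+    # The two init_* calls above extend the embedding tables: the VCR
+    # preprocessing (see process_vcr in prepro.py) rewrites region references
+    # into special tokens such as person_0 ... person_80, so init_word_embedding
+    # reserves NUM_SPECIAL_TOKENS (= 81) entries for them before the parameters
+    # are broadcast from rank 0 to the other workers.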
+ model.pad_vocab() # tensor core padding for vocabulary + if opts.cut_bert != -1: + # cut some layers of BERT + model.bert.encoder.layer = torch.nn.ModuleList( + model.bert.encoder.layer[:opts.cut_bert]) + + for name, module in model.named_modules(): + # we might want to tune dropout for smaller dataset + if isinstance(module, torch.nn.Dropout): + if module.p != opts.dropout: + module.p = opts.dropout + LOGGER.info(f'{name} set to {opts.dropout}') + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + + # Prepare optimizer + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + global_step = 0 + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + os.makedirs(join(opts.output_dir, 'results')) # store VQA predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + task2loss = {task: RunningMeter(f'loss/{task}') for task in train_tasks} + model.train() + n_examples = 0 + n_epoch = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for step, (name, batch) in enumerate(meta_loader): + input_ids, *_ = batch + n_examples += input_ids.size(0) + task = name.split('_')[0] + loss = model(*batch, task=task, compute_loss=True) + loss = loss.mean() # loss is not normalized + if task == 'mrckl': + # MRCkl normalization; safeguard fp16 overflow + loss = loss.float() * IMG_LABEL_DIM + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + task2loss[name](loss.item()) + + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + if opts.decay == 'linear': + lr_this_step = opts.learning_rate * warmup_linear( + global_step, opts.warmup_steps, opts.num_train_steps) + elif opts.decay == 'invsqrt': + lr_this_step = opts.learning_rate * noam_schedule( + global_step, opts.warmup_steps) + elif opts.decay == 
'constant': + lr_this_step = opts.learning_rate + elif opts.decay == 'vqa': + lr_this_step = opts.learning_rate * vqa_schedule( + global_step, opts.warm_int, opts.decay_int, + opts.decay_st, opts.decay_rate) + if lr_this_step < 0: + # save guard for possible miscalculation of train steps + lr_this_step = 1e-8 + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + for t, l in task2loss.items(): + loss = sum(v for v in all_gather_list(l.val) + if v is not None) / hvd.size() + task2loss[t] = RunningMeter(f'loss/{t}', loss) + TB_LOGGER.log_scaler_dict({l.name: l.val + for l in task2loss.values() + if l.val is not None}) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 5 == 0: + torch.cuda.empty_cache() + if global_step % 100 == 0: + # monitor training throughput + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + + if global_step % opts.valid_steps == 0: + validate(model, named_val_loaders) + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step % opts.valid_steps != 0: + validate(model, named_val_loaders) + model_saver.save(model, global_step) + + +def validate(model, named_val_loaders): + model.eval() + for task, loader in named_val_loaders: + LOGGER.info(f"validate on {task} task") + if task.startswith('mlm'): + val_log = validate_mlm(model, loader) + elif task.startswith('mrm'): + val_log = validate_mrm(model, loader) + elif task.startswith('mrc'): + val_log = validate_mrc(model, loader, task) + else: + raise ValueError(f'Undefined task {task}') + val_log = {f'{task}_{k}': v for k, v in val_log.items()} + TB_LOGGER.log_scaler_dict( + {f'valid_{task}/{k}': v for k, v in val_log.items()}) + model.train() + + +@torch.no_grad() +def validate_mrc(model, val_loader, task): + LOGGER.info("start running MRC validation...") + val_loss = 0 + n_feat = 0 + st = time() + tot_score = 0 + for i, batch in enumerate(val_loader): + *_, label = batch + feat_mask, label_targets = label + prediction_soft_label = model( + *batch, task=task, compute_loss=False) + if "kl" in task: + prediction_soft_label = F.log_softmax( + prediction_soft_label, dim=-1) + loss = F.kl_div( + prediction_soft_label, label_targets, reduction='sum') + tot_score += compute_accuracy_for_mrc( + prediction_soft_label, label_targets) + else: + cls_label_targets = label_targets.max(dim=-1)[1] # argmax + loss = F.cross_entropy( + prediction_soft_label, cls_label_targets, + ignore_index=0, reduction='sum') + tot_score += compute_accuracy_for_mrc( + prediction_soft_label[:, 1:], label_targets[:, 1:]) + val_loss += loss.item() + n_feat += feat_mask.sum().item() + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_feat = sum(all_gather_list(n_feat)) + tot_time = time()-st + val_loss /= n_feat + val_acc = tot_score / n_feat + val_log = {'loss': val_loss, + 'acc': val_acc, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log + + +@torch.no_grad() +def 
validate_mrm(model, val_loader): + LOGGER.info("start running MRM validation...") + val_loss = 0 + n_feat = 0 + st = time() + for i, batch in enumerate(val_loader): + *_, feat_mask = batch + loss = model(*batch, task='mrm', compute_loss=True) + val_loss += loss.sum().item() + n_feat += feat_mask.sum().item() + val_loss = sum(all_gather_list(val_loss)) + n_feat = sum(all_gather_list(n_feat)) + tot_time = time()-st + val_loss /= (n_feat * IMG_DIM) + val_log = {'loss': val_loss, + 'feat_per_s': n_feat/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"loss: {val_loss:.2f}") + return val_log + + +@torch.no_grad() +def validate_mlm(model, val_loader): + LOGGER.info(f"start running MLM validation ...") + val_loss = 0 + n_correct = 0 + n_word = 0 + st = time() + for i, batch in enumerate(val_loader): + *inputs, txt_labels = batch + loss = model.forward(*batch, task='mlm', compute_loss=True) + # loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, + # reduction='sum') + # loss = loss_fct(scores, txt_labels) + loss = loss.mean() + val_loss += loss.item() + # n_correct += accuracy_count(scores, txt_labels) + n_word += txt_labels.numel() + val_loss = sum(all_gather_list(val_loss)) + n_correct = sum(all_gather_list(n_correct)) + n_word = sum(all_gather_list(n_word)) + tot_time = time()-st + val_loss /= n_word + acc = n_correct / n_word + val_log = {'loss': val_loss, + 'acc': acc, + 'tok_per_s': n_word/tot_time} + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"acc: {acc*100:.2f}" + f"loss: {val_loss}") + return val_log + + +def compute_accuracy_for_mrc(out, labels): + outputs = out.max(dim=-1)[1] + labels = labels.max(dim=-1)[1] # argmax + n_correct = (outputs == labels).sum().item() + return n_correct + + +def accuracy_count(out, labels): + outputs = out.max(dim=-1)[1] + mask = labels != -1 + n_correct = (outputs == labels).masked_select(mask).sum().item() + return n_correct + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + # NOTE: train tasks and val tasks cannot take command line arguments + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--vcr_task", + default=["qar"], type=str, nargs='+', + choices=['qa', 'qar'], + help="VCR tasks: qa or qar") + parser.add_argument('--tasks', default=None, type=str, nargs='+', + help="specify pretraining tasks") + parser.add_argument('--mrm_prob', default=0.15, type=float, + help='probability to mask in MRM training') + parser.add_argument('--mrc_prob', default=0.15, type=float, + help='probability to mask in MRC training') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--cut_bert", default=-1, type=int, + help="reduce BERT layers (-1 for original depth)") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + 
help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", + default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", + default=4096, type=int, + help="Total batch size for validation. " + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of update steps to accumulate before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", + default=1000, + type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", + default=100000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument('--mask_prob', default=0.15, type=float, + help='probability to mask in MLM training') + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant', 'vqa'], + help="learning rate decay method") + parser.add_argument("--decay_int", default=2000, type=int, + help="interval between VQA lr decay") + parser.add_argument("--warm_int", default=2000, type=int, + help="interval for VQA lr warmup") + parser.add_argument("--decay_st", default=20000, type=int, + help="when to start decay") + parser.add_argument("--decay_rate", default=0.2, type=float, + help="ratio of lr decay") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. 
(invsqrt decay)") + + # device parameters + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + # TODO + + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + assert len(args.vcr_task) > 0, "Must choose at least one vcr task" + main(args) diff --git a/uniter_model/requirements.txt b/uniter_model/requirements.txt new file mode 100644 index 0000000..bed0694 --- /dev/null +++ b/uniter_model/requirements.txt @@ -0,0 +1,3 @@ +pytorch_pretrained_bert +tensorboardX +ipdb diff --git a/uniter_model/scripts/compress_lmdb.py b/uniter_model/scripts/compress_lmdb.py new file mode 100644 index 0000000..54e610c --- /dev/null +++ b/uniter_model/scripts/compress_lmdb.py @@ -0,0 +1,53 @@ +""" +compress processed LMDB +""" +import argparse +import io +import multiprocessing as mp + +import numpy as np +import lmdb +from tqdm import tqdm + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def compress_dump(item): + key, dump = item + img_dump = {k.decode('utf-8'): v for k, v in msgpack.loads(dump).items()} + with io.BytesIO() as writer: + np.savez_compressed(writer, **img_dump, allow_pickle=True) + return key, writer.getvalue() + + +def main(opts): + if opts.db[-1] == '/': + opts.db = opts.db[:-1] + out_name = f'{opts.db}_compressed' + env = lmdb.open(opts.db, readonly=True) + txn = env.begin() + out_env = lmdb.open(out_name, map_size=1024**4) + out_txn = out_env.begin(write=True) + with mp.Pool(opts.nproc) as pool, tqdm(total=txn.stat()['entries']) as pbar: + for i, (key, value) in enumerate( + pool.imap_unordered(compress_dump, txn.cursor(), + chunksize=128)): + out_txn.put(key=key, value=value) + if i % 1000 == 0: + out_txn.commit() + out_txn = out_env.begin(write=True) + pbar.update(1) + out_txn.commit() + out_env.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--db", default=None, type=str, + help="processed LMDB") + parser.add_argument('--nproc', type=int, + help='number of cores used') + args = parser.parse_args() + main(args) diff --git a/uniter_model/scripts/compute_numbb.py b/uniter_model/scripts/compute_numbb.py new file mode 100644 index 0000000..fb18942 --- /dev/null +++ b/uniter_model/scripts/compute_numbb.py @@ -0,0 +1,71 @@ +""" +compute adaptive number of bounding boxes +""" +import argparse +import glob +import json +from os.path import basename +import multiprocessing as mp + +import numpy as np +from tqdm import tqdm +from cytoolz import curry + + +def _compute_nbb(img_dump, conf_th, max_bb, min_bb): + num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum()) + num_bb = min(max_bb, num_bb) + return int(num_bb) + + +@curry +def _compute_item(conf_th, max_bb, min_bb, fname): + name = basename(fname) + try: + nbb = _compute_nbb(np.load(fname, allow_pickle=True), + conf_th, max_bb, min_bb) + except 
OSError: + # some corrupted files in conceptual caption + nbb = None + return name, nbb + + +def _compute_all_nbb(img_dir, conf_th, max_bb, min_bb, nproc): + files = glob.glob(f'{img_dir}/*.npz') + with mp.Pool(nproc) as pool: + fname2nbb = dict( + pool.imap_unordered(_compute_item(conf_th, max_bb, min_bb), + tqdm(files), chunksize=2048)) + + return fname2nbb + + +def main(opts): + n2bb = _compute_all_nbb(opts.img_dir, opts.conf_th, + opts.max_bb, opts.min_bb, + opts.nproc) + with open(f'{opts.img_dir}/' + f'nbb_th{opts.conf_th}_max{opts.max_bb}_min{opts.min_bb}.json', + 'w') as f: + json.dump(n2bb, f) + corrupts = [f for f, n in n2bb.items() if n is None] + if corrupts: + with open(f'{opts.img_dir}/corrupted.json', 'w') as f: + json.dump(corrupts, f, indent=4) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--img_dir", default=None, type=str, + help="The input images.") + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--nproc', type=int, + help='number of cores used') + args = parser.parse_args() + main(args) diff --git a/uniter_model/scripts/convert_gqa.py b/uniter_model/scripts/convert_gqa.py new file mode 100644 index 0000000..0a401d8 --- /dev/null +++ b/uniter_model/scripts/convert_gqa.py @@ -0,0 +1,62 @@ +""" +convert GQA jsons into VQA format +""" +import json + +from toolz.sandbox import unzip + +ANNOTATION = '/ssd2/yenchun/ANNOTATIONS/' +SPLITS = ['train', 'val', 'testdev'] +VERSIONS = ['all', 'balanced'] + + +def convert(item): + qid, example = item + q = {'image_id': example['imageId'], + 'question_id': qid, + 'question': example['question']} + if 'answer' in example: + a = {'image_id': example['imageId'], + 'question_id': qid, + 'answers': [{"answer": example['answer']}]} + else: + a = None + return q, a + + +def convert_all(data): + questions, answers = unzip(map(convert, data.items())) + return questions, answers + + +def main(): + for split in SPLITS: + for ver in VERSIONS: + if split == 'train' and ver == 'all': + data = {} + for i in range(10): + for qid, ex in json.load(open( + f'{ANNOTATION}/GQA/train_all_questions/' + f'train_all_questions_{i}.json')).items(): + for key in list(ex.keys()): + if key not in ['imageId', 'question', 'answer']: + del ex[key] + data[qid] = ex + else: + data = json.load(open(f'{ANNOTATION}/GQA/' + f'{split}_{ver}_questions.json')) + questions, answers = convert_all(data) + json.dump({'questions': list(questions)}, + open(f'{ANNOTATION}/GQA/' + f'gqa_{split}_{ver}_questions.vqa.json', 'w')) + json.dump({'annotations': list(answers)}, + open(f'{ANNOTATION}/GQA/' + f'gqa_{split}_{ver}_annotations.vqa.json', 'w')) + data = json.load(open(f'{ANNOTATION}/GQA/submission_all_questions.json')) + questions, _ = convert_all(data) + json.dump({'questions': list(questions)}, + open(f'{ANNOTATION}/GQA/gqa_submission_questions.vqa.json', 'w')) + + +if __name__ == '__main__': + main() diff --git a/uniter_model/scripts/convert_imgdir.py b/uniter_model/scripts/convert_imgdir.py new file mode 100644 index 0000000..8cbfefc --- /dev/null +++ b/uniter_model/scripts/convert_imgdir.py @@ -0,0 +1,139 @@ +""" +convert image npz to LMDB +""" +import argparse +import glob +import io +import json +import multiprocessing as mp +import os +from 
os.path import basename, exists + +from cytoolz import curry +import numpy as np +from tqdm import tqdm +import lmdb + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb): + num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum()) + num_bb = min(max_bb, num_bb) + return int(num_bb) + + +@curry +def load_npz(conf_th, max_bb, min_bb, num_bb, fname, keep_all=False): + try: + img_dump = np.load(fname, allow_pickle=True) + if keep_all: + nbb = None + else: + nbb = _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb) + dump = {} + for key, arr in img_dump.items(): + if arr.dtype == np.float32: + arr = arr.astype(np.float16) + if arr.ndim == 2: + dump[key] = arr[:nbb, :] + elif arr.ndim == 1: + dump[key] = arr[:nbb] + else: + raise ValueError('wrong ndim') + except Exception as e: + # corrupted file + print(f'corrupted file {fname}', e) + dump = {} + nbb = 0 + + name = basename(fname) + return name, dump, nbb + + +def dumps_npz(dump, compress=False): + with io.BytesIO() as writer: + if compress: + np.savez_compressed(writer, **dump, allow_pickle=True) + else: + np.savez(writer, **dump, allow_pickle=True) + return writer.getvalue() + + +def dumps_msgpack(dump): + return msgpack.dumps(dump, use_bin_type=True) + + +def main(opts): + if opts.img_dir[-1] == '/': + opts.img_dir = opts.img_dir[:-1] + split = basename(opts.img_dir) + if opts.keep_all: + db_name = 'all' + else: + if opts.conf_th == -1: + db_name = f'feat_numbb{opts.num_bb}' + else: + db_name = (f'feat_th{opts.conf_th}_max{opts.max_bb}' + f'_min{opts.min_bb}') + if opts.compress: + db_name += '_compressed' + if not exists(f'{opts.output}/{split}'): + os.makedirs(f'{opts.output}/{split}') + env = lmdb.open(f'{opts.output}/{split}/{db_name}', map_size=1024**4) + txn = env.begin(write=True) + files = glob.glob(f'{opts.img_dir}/*.npz') + load = load_npz(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, + keep_all=opts.keep_all) + name2nbb = {} + with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar: + for i, (fname, features, nbb) in enumerate( + pool.imap_unordered(load, files, chunksize=128)): + if not features: + continue # corrupted feature + if opts.compress: + dump = dumps_npz(features, compress=True) + else: + dump = dumps_msgpack(features) + txn.put(key=fname.encode('utf-8'), value=dump) + if i % 1000 == 0: + txn.commit() + txn = env.begin(write=True) + name2nbb[fname] = nbb + pbar.update(1) + txn.put(key=b'__keys__', + value=json.dumps(list(name2nbb.keys())).encode('utf-8')) + txn.commit() + env.close() + if opts.conf_th != -1 and not opts.keep_all: + with open(f'{opts.output}/{split}/' + f'nbb_th{opts.conf_th}_' + f'max{opts.max_bb}_min{opts.min_bb}.json', 'w') as f: + json.dump(name2nbb, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--img_dir", default=None, type=str, + help="The input images.") + parser.add_argument("--output", default=None, type=str, + help="output lmdb") + parser.add_argument('--nproc', type=int, default=8, + help='number of cores used') + parser.add_argument('--compress', action='store_true', + help='compress the tensors') + parser.add_argument('--keep_all', action='store_true', + help='keep all features, overrides all following args') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + 
parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=100, + help='number of bounding boxes (fixed)') + args = parser.parse_args() + main(args) diff --git a/uniter_model/scripts/download_bert.py b/uniter_model/scripts/download_bert.py new file mode 100644 index 0000000..c59a8f7 --- /dev/null +++ b/uniter_model/scripts/download_bert.py @@ -0,0 +1,12 @@ +""" +Download and extract PyTorch pretrained BERT model +python scripts/download_bert.py bert-base-cased /pretrain/bert-base-cased.pt +""" +import sys + +import torch +from pytorch_pretrained_bert import BertForPreTraining + +bert, output = sys.argv[1:] +model = BertForPreTraining.from_pretrained(bert) +torch.save(model.state_dict(), output) diff --git a/uniter_model/scripts/download_bert.sh b/uniter_model/scripts/download_bert.sh new file mode 100644 index 0000000..ed0f3ab --- /dev/null +++ b/uniter_model/scripts/download_bert.sh @@ -0,0 +1,9 @@ +BERT=$1 +PRETRAIN_DIR=$2 + + +docker run --rm \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind \ + convaicontainerregistry1.azurecr.io/uniter \ + python scripts/download_bert.py $BERT /pretrain/$BERT.pt diff --git a/uniter_model/scripts/install_horovod.sh b/uniter_model/scripts/install_horovod.sh new file mode 100644 index 0000000..6a39e71 --- /dev/null +++ b/uniter_model/scripts/install_horovod.sh @@ -0,0 +1,42 @@ +# for building docker image + +# Update OpenMPI to avoid bug +rm -r /usr/local/mpi + +wget https://download.open-mpi.org/release/open-mpi/v4.0/openmpi-4.0.0.tar.gz +gunzip -c openmpi-4.0.0.tar.gz | tar xf - +cd openmpi-4.0.0 +./configure --prefix=/usr/local/mpi --enable-orterun-prefix-by-default \ + --disable-getpwuid +make -j$(nproc) all && make install +ldconfig + +cd - +rm -r openmpi-4.0.0 +rm openmpi-4.0.0.tar.gz + +export OPENMPI_VERSION=4.0.0 + + +# missing libnccl_static.a (solve by upgrading NCCL) +echo "deb http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" \ + > /etc/apt/sources.list.d/nvidia-ml.list +apt update +apt install libnccl2=2.4.7-1+cuda10.1 libnccl-dev=2.4.7-1+cuda10.1 + +export PATH=/usr/local/mpi/bin:$PATH +HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_WITH_PYTORCH=1 \ + pip install --no-cache-dir horovod +ldconfig + +# Install OpenSSH for MPI to communicate between containers +# apt-get install -y --no-install-recommends \ +# openssh-client openssh-server && \ +# mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +# cat /etc/ssh/ssh_config | \ +# grep -v StrictHostKeyChecking > \ +# /etc/ssh/ssh_config.new && \ +# echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ +# mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config diff --git a/uniter_model/scripts/map_iid_to_ann_ids.py b/uniter_model/scripts/map_iid_to_ann_ids.py new file mode 100644 index 0000000..90812ce --- /dev/null +++ b/uniter_model/scripts/map_iid_to_ann_ids.py @@ -0,0 +1,118 @@ +""" +We load Linjie's features from: datasets/npy_per_img_id.visual_grounding_coco_gt +Each feature is named as: visual_grounding_coco_000000581857.npz +containing {norm_bb, features, conf, soft_labels} + +The order of extracted bbox and features should align with ann_ids for each +img_id. + +We save this order for the use of REFER dataloader. 
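+
+The npz 'norm_bb' array stores (x1, y1, x2, y2, w, h) normalized by image width
+and height; the script denormalizes it back to xywh boxes, e.g.
+    x1, y1 = norm_bb[:, 0] * im_width, norm_bb[:, 1] * im_height
+    w,  h  = norm_bb[:, 4] * im_width, norm_bb[:, 5] * im_height
+and matches each denormalized box against the COCO annotation boxes by minimum
+L1 distance to recover the per-image ann_id order.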
+""" +import time +import pickle +import numpy as np +from pprint import pprint +from tqdm import tqdm +import json +import os.path as osp +import argparse + + +def recover_ann_ids(denorm_bb, raw_bb, raw_ann_ids): + """ + Inputs: + - denorm_bb : [xywh], extracted from BUTD detectors. + - raw_bb : [xywh] + - raw_ann_ids + Return: + - ordered_ann_ids: ordered by denorm_bb + """ + assert denorm_bb.shape[0] == raw_bb.shape[0] + num_bb = denorm_bb.shape[0] + ordered_ann_ids = [] + for i in range(num_bb): + ref_bb = denorm_bb[i] + min_err, ix = 1e5, None + for j in range(num_bb): + if np.sum(np.abs(ref_bb - raw_bb[j])) < min_err: + min_err, ix = np.sum(np.abs(ref_bb-raw_bb[j])), j + ordered_ann_ids.append(raw_ann_ids[ix]) + return ordered_ann_ids + +def main(args): + + # Load all instances from refcoco, refcoco+ and refcocog + tic = time.time() + iid_to_ann_ids = {} + warning_img_ids = set() + for dataset in ['refcoco', 'refcoco+', 'refcocog']: + print('Checking %s...' % dataset) + instances = json.load(open(osp.join(args.refer_dir, dataset, + 'instances.json'))) + Anns, Imgs, iid_to_raw_ann_ids = {}, {}, {} + for ann in instances['annotations']: + Anns[ann['id']] = ann + iid_to_raw_ann_ids[ann['image_id']] = iid_to_raw_ann_ids.get( + ann['image_id'], []) + [ann['id']] + for img in instances['images']: + Imgs[img['id']] = img + + # Make iid_to_ann_ids for this dataset + img_ids = list(Imgs.keys()) + for img_id in tqdm(img_ids): + if img_id in iid_to_ann_ids: + continue + raw_ann_ids = iid_to_raw_ann_ids[img_id] + # raw_gd_bb + raw_gd_bb = np.array([Anns[ann_id]['bbox'] + for ann_id in raw_ann_ids]) # (n, 4) xywh + # denorm_bb + im_width = Imgs[img_id]['width'] + im_height = Imgs[img_id]['height'] + img_feat = np.load(osp.join(args.feats_dir, + f'visual_grounding_coco_gt_{int(img_id):012}.npz')) + norm_bb = img_feat['norm_bb'] + x1, x2 = norm_bb[:, 0] * im_width, norm_bb[:, 2] * im_width + y1, y2 = norm_bb[:, 1] * im_height, norm_bb[:, 3] * im_height + w, h = norm_bb[:, 4] * im_width, norm_bb[:, 5] * im_height + denorm_bb = np.stack([x1, y1, w, h], axis=1) # (n,4) + # re-order ann_ids + ordered_ann_ids = recover_ann_ids(denorm_bb, raw_gd_bb, raw_ann_ids) + # check difference + ordered_gd_bb = np.array([Anns[ann_id]['bbox'] + for ann_id in ordered_ann_ids]) # (n, 4) + for i in range(denorm_bb.shape[0]): + assert np.sum(np.abs(denorm_bb[i]-ordered_gd_bb[i])) < 0.01, \ + '%s, %s' %(denorm_bb[i], ordered_gd_bb[i]) + # check ann_ids set + if set(ordered_ann_ids) != set(raw_ann_ids): + print('Please check img_id[%s]'%img_id) + warning_img_ids.add(img_id) + # check length of ann_ids + assert len(ordered_ann_ids) == len(raw_ann_ids) + # add to iid_to_ann_ids + iid_to_ann_ids[img_id] = ordered_ann_ids + + print('%s images contain dupicated bounding boxes.' % len(warning_img_ids)) + pprint(list(warning_img_ids)) + + # save + output_file = osp.join(args.output_dir, 'iid_to_ann_ids.json') + with open(output_file, 'w') as f: + json.dump({'iid_to_ann_ids': iid_to_ann_ids}, f) + print('%s iid_to_ann_ids saved in %s.' 
% (len(iid_to_ann_ids), output_file)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--refer_dir', + default='datasets/refer', + help='folder saving all downloaded refer datasets') + parser.add_argument('--feats_dir', + default='datasets/npy_per_img_id/visual_grounding_coco_gt', + help='folder saving butd features.') + parser.add_argument('--output_dir', + default='index', + help='output folder saving img_id --> [ann_id]') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/uniter_model/scripts/map_vg_vqa_img.py b/uniter_model/scripts/map_vg_vqa_img.py new file mode 100644 index 0000000..c09cae9 --- /dev/null +++ b/uniter_model/scripts/map_vg_vqa_img.py @@ -0,0 +1,67 @@ +""" +MCAN VG annotation image ids are COCO ids and need to be mapped back to VG ids +""" +import json + + +ANNOTATION = '/ssd2/yenchun/ANNOTATIONS' +# karpathy 5k test split +TEST_5K = f'{ANNOTATION}/Image-Text-Matching/coco_test.json' + +VG_QUESTION = f'{ANNOTATION}/VQA/VG_questions.json' +VG_ANSWER = f'{ANNOTATION}/VQA/VG_annotations.json' +VG_IMG_META = f'{ANNOTATION}/VQA/image_data.json' + + +def _get_img_id(img_name): + img_name = img_name[:-4] + id_ = int(img_name.split('_')[-1]) + return id_ + + +def _get_test_ids(): + data = json.load(open(TEST_5K)) + ids = {_get_img_id(d['filename']) for d in data} + return ids + + +def _get_coco2vg(): + data = json.load(open(VG_IMG_META)) + coco2vg = {d['coco_id']: d['image_id'] for d in data} + return coco2vg + + +def filter_data(data, test_ids): + filtered = (d for d in data if d['image_id'] not in test_ids) + return filtered + + +def map_data(data, coco2vg): + def gen_mapped(): + for d in data: + coco_id = d['image_id'] + d['image_id'] = coco2vg[coco_id] + yield d + return gen_mapped() + + +def main(): + test_ids = _get_test_ids() + coco2vg = _get_coco2vg() + + # process questions + questions = json.load(open(VG_QUESTION))['questions'] + mapped_qs = list(map_data(filter_data(questions, test_ids), coco2vg)) + qname = f'{VG_QUESTION}.mapped' + json.dump({'questions': mapped_qs}, open(qname, 'w')) + del questions, mapped_qs + + # process answers + answers = json.load(open(VG_ANSWER))['annotations'] + mapped_as = list(map_data(filter_data(answers, test_ids), coco2vg)) + aname = f'{VG_ANSWER}.mapped' + json.dump({'annotations': mapped_as}, open(aname, 'w')) + + +if __name__ == '__main__': + main() diff --git a/uniter_model/scripts/prepro_all.sh b/uniter_model/scripts/prepro_all.sh new file mode 100644 index 0000000..81245c0 --- /dev/null +++ b/uniter_model/scripts/prepro_all.sh @@ -0,0 +1,186 @@ +TOKER=$1 +TXT_DB=$2 +FORMAT=$3 +#TXT_DB='/ssd2/yenchun/TXT_DB_test' + +ANNOTATIONS='/ssd2/yenchun/ANNOTATIONS' +VQA_ANN=$ANNOTATIONS/VQA/ +CAP_ANN=$ANNOTATIONS/COCO_annotation/ +CONCEPT_ANN=$ANNOTATIONS/conceptual_captions/ +SBU_ANN=$ANNOTATIONS/sbu_caption/ +PRETRAIN_ANN=$ANNOTATIONS/latest_cleaned/ +ITM_ANN=$ANNOTATIONS/Image-Text-Matching +VE_ANN=$ANNOTATIONS/visual_entailment/ +GQA_ANN=$ANNOTATIONS/GQA/ +VCR_ANN=$ANNOTATIONS/VCR/ +NLVR2_ANN=$ANNOTATIONS/NLVR2/ + + +# process licheng's split +#python scripts/split_annotations.py --format $FORMAT \ +# $PRETRAIN_ANN/collected\(coco+vg\).json $PRETRAIN_ANN + + +if [ $TOKER = 'bert-large-cased' ]; then + SUFFIX='large-cased' +elif [ $TOKER = 'bert-base-cased' ]; then + SUFFIX='base-cased' +else + echo "invalid tokenizer specified" + exit 1 +fi + +# Image Text Retrieval +for DSET in 'flickr30k' 'coco'; do + for SPLIT in 'train' 'val' 'test'; do + python prepro.py --task itm --bert 
$TOKER --format $FORMAT \ + --annotations $ITM_ANN/${DSET}_$SPLIT.json \ + --output $TXT_DB/itm_${DSET}_${SPLIT}_$SUFFIX.db + + done +done +# coco 1k splits +for SPLIT in 'val' 'test'; do + for i in 0 1 2 3 4; do + python prepro.py --task itm --bert $TOKER --format $FORMAT \ + --annotations $ITM_ANN/coco_${SPLIT}_1k_$i.json \ + --output $TXT_DB/itm_coco_${SPLIT}_1k_${i}_$SUFFIX.db + done +done +# coco val rest +python prepro.py --task itm --bert $TOKER --format $FORMAT \ + --annotations $ITM_ANN/coco_restval.json \ + --output $TXT_DB/itm_coco_restval_$SUFFIX.db + + +# COCO +for SPLIT in 'train' 'val'; do + # VQA + python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $VQA_ANN/v2_OpenEnded_mscoco_${SPLIT}2014_questions.json \ + $VQA_ANN/v2_mscoco_${SPLIT}2014_annotations.json \ + $VQA_ANN/ans2label.pkl \ + --output $TXT_DB/vqa_${SPLIT}_$SUFFIX.db + if [ $SPLIT = 'val' ]; then + for SP in 'train' 'dev'; do + python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $VQA_ANN/v2_OpenEnded_mscoco_${SP}val2014_questions.json \ + $VQA_ANN/v2_mscoco_${SP}val2014_annotations.json \ + $VQA_ANN/ans2label.pkl \ + --output $TXT_DB/vqa_${SP}val_$SUFFIX.db + done + fi + + # Caption + python prepro.py --task caption --bert $TOKER --format $FORMAT \ + --annotations $CAP_ANN/captions_${SPLIT}2014.json \ + --output $TXT_DB/caption_${SPLIT}_$SUFFIX.db +done + +# COCO VQA test +python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $VQA_ANN/v2_OpenEnded_mscoco_test2015_questions.json \ + --output $TXT_DB/vqa_test_$SUFFIX.db + +# VG VQA +python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $VQA_ANN/VG_questions.json.mapped \ + $VQA_ANN/VG_annotations.json.mapped \ + $VQA_ANN/ans2label.pkl \ + --output $TXT_DB/vqa_vg_$SUFFIX.db + +# all pretraining + +# coco trainval +python prepro.py --task licheng_cleaned --bert $TOKER --format $FORMAT \ + --annotations $PRETRAIN_ANN/pretrain_caption_coco_trainval.json \ + --output $TXT_DB/pretrain_caption_coco_trainval_$SUFFIX.db + +for DSET in 'coco' 'vg'; do + for SPLIT in 'val' 'train'; do + python prepro.py --task licheng_cleaned --bert $TOKER --format $FORMAT \ + --annotations $PRETRAIN_ANN/pretrain_caption_${DSET}_$SPLIT.json \ + --output $TXT_DB/pretrain_caption_${DSET}_${SPLIT}_$SUFFIX.db + done +done + +# pretrain VQA +for DSET in 'genome_vqa' 'gqa'; do + if [ $DSET = 'genome_vqa' ]; then + DS='vg' + else + DS='gqa' + fi + for SPLIT in 'val' 'train'; do + python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $PRETRAIN_ANN/${DSET}_${SPLIT}_questions.json \ + $PRETRAIN_ANN/${DSET}_${SPLIT}_annotations.json \ + $PRETRAIN_ANN/ans2label.pkl \ + --output $TXT_DB/pretrain_vqa_${DS}_${SPLIT}_$SUFFIX.db + done +done +# Pretrain VQA COCO +for SPLIT in 'val' 'trainsplit' 'valsplit' ; do + python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $PRETRAIN_ANN/coco_vqa_${SPLIT}_questions.json \ + $PRETRAIN_ANN/coco_vqa_${SPLIT}_annotations.json \ + $PRETRAIN_ANN/ans2label.pkl \ + --output $TXT_DB/pretrain_vqa_coco_${SPLIT}_$SUFFIX.db +done + + +# Visual Entailment +for SPLIT in 'train' 'dev' 'test'; do + python prepro.py --task ve --bert $TOKER --format $FORMAT \ + --annotations $VE_ANN/snli_ve_$SPLIT.jsonl \ + --output $TXT_DB/ve_${SPLIT}_$SUFFIX.db +done + +# GQA +for SPLIT in 'train' 'val' 'testdev'; do + for VER in 'all' 'balanced'; do + python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations 
$GQA_ANN/gqa_${SPLIT}_${VER}_questions.vqa.json \ + $GQA_ANN/gqa_${SPLIT}_${VER}_annotations.vqa.json \ + $GQA_ANN/ans2label.pkl \ + --output $TXT_DB/gqa_${SPLIT}_${VER}_$SUFFIX.db + done +done +# GQA test +python prepro.py --task vqa --bert $TOKER --format $FORMAT \ + --annotations $GQA_ANN/gqa_submission_questions.vqa.json \ + --output $TXT_DB/gqa_submission_$SUFFIX.db + + +# Conceptual Captions +for SPLIT in 'train' 'val'; do + python prepro.py --task conceptual --bert $TOKER --format $FORMAT \ + --annotations $CONCEPT_ANN/${SPLIT}_imageId2Ann.tsv \ + $CONCEPT_ANN/${SPLIT}_imgs.json \ + --output $TXT_DB/conceptual_caption_${SPLIT}_$SUFFIX.db +done + +# SBU captions +for SPLIT in 'train' 'val'; do + python prepro.py --task sbu --bert $TOKER --format $FORMAT \ + --annotations $SBU_ANN/sbu_${SPLIT}_captions.json \ + --output $TXT_DB/sbu_caption_${SPLIT}_$SUFFIX.db +done + +# VCR +for SPLIT in 'train' 'val'; do + python prepro.py --task vcr --bert $TOKER --format $FORMAT \ + --annotations $VCR_ANN/$SPLIT.jsonl \ + --output $TXT_DB/vcr_${SPLIT}_$SUFFIX.db +done + +# NLVR2 +for SPLIT in 'dev' 'test1'; do + python prepro.py --task nlvr2 --bert $TOKER --format $FORMAT \ + --annotations $NLVR2_ANN/$SPLIT.json \ + --output $TXT_DB/nlvr2_${SPLIT}_$SUFFIX.db +done +# some corrupted train features +python prepro.py --task nlvr2 --bert $TOKER --format $FORMAT \ + --annotations $NLVR2_ANN/train.json $NLVR2_ANN/train_imgs.json \ + --output $TXT_DB/nlvr2_train_$SUFFIX.db diff --git a/uniter_model/scripts/prepro_gqa.sh b/uniter_model/scripts/prepro_gqa.sh new file mode 100644 index 0000000..44782e7 --- /dev/null +++ b/uniter_model/scripts/prepro_gqa.sh @@ -0,0 +1,12 @@ +GQA_ANN='/db/raw_data/GQA/questions1.2/train_all_questions/' +SUFFIX="base-cased" +TXT_DB='/db/TXT_DB_v3' +SPLIT="train" + +for i in 2 3 4 5 6 7 8 9; do + python prepro.py --task gqa \ + --annotations $GQA_ANN/${SPLIT}_all_questions_$i.json \ + --output $TXT_DB/pretrain_gqa_${SPLIT}_${i}_$SUFFIX.db +done + + diff --git a/uniter_model/scripts/prepro_iid_to_dets.py b/uniter_model/scripts/prepro_iid_to_dets.py new file mode 100644 index 0000000..295aafa --- /dev/null +++ b/uniter_model/scripts/prepro_iid_to_dets.py @@ -0,0 +1,45 @@ +""" +This code converts all RefCOCO(+/g) detections from Mask R-CNN +(https://github.com/lichengunc/MAttNet) +to image_id -> [box], where each box is {box, category_id, category_name, score} +""" +import json +import os +import os.path as osp + +dets_dir = 'datasets/refer/detections' +image_set = set() +dataset_names = ['refcoco_unc', 'refcoco+_unc', 'refcocog_umd'] +Detections = {} +for dataset_name in dataset_names: + dets_file = osp.join(dets_dir, dataset_name, + 'res101_coco_minus_refer_notime_dets.json') + detections = json.load(open(dets_file, 'r')) + for det in detections: + image_set.add(det['image_id']) + Detections[dataset_name] = detections +num_images = len(image_set) + +iid_to_dets = {} +for dataset_name in dataset_names: + detections = Detections[dataset_name] + for det in detections: + image_id = det['image_id'] + if image_id in image_set: + box = {'box': det['box'], + 'category_id': det['category_id'], + 'category_name': det['category_name'], + 'score': det['score']} + iid_to_dets[image_id] = iid_to_dets.get(image_id, []) + [box] + for det in detections: + image_id = det['image_id'] + if image_id in image_set: + image_set.remove(image_id) + +num_dets = sum([len(dets) for dets in iid_to_dets.values()]) +print(f'{num_dets} detections in {num_images} images for {dataset_names}.') + +# save +with 
open('index/iid_to_dets.json', 'w') as f: + json.dump(iid_to_dets, f) + diff --git a/uniter_model/scripts/prepro_re.sh b/uniter_model/scripts/prepro_re.sh new file mode 100644 index 0000000..90862c3 --- /dev/null +++ b/uniter_model/scripts/prepro_re.sh @@ -0,0 +1,44 @@ +TOKER='bert-base-cased' +TXT_DB='datasets/TXT_DB_v3' + +ANNOTATIONS='datasets' +RE_ANN=$ANNOTATIONS/refer + +if [ $TOKER = 'bert-large-cased' ]; then + SUFFIX='large-cased' +elif [ $TOKER = 'bert-base-cased' ]; then + SUFFIX='base-cased' +else + echo "invalid tokenizer specified" + # exit(1) +fi + +# refcoco, refcoco+ +for DATASET in 'refcoco' 'refcoco+'; do + for SPLIT in 'train' 'val' 'testA' 'testB'; do + python prepro.py --task re --bert $TOKER \ + --annotations $RE_ANN/${DATASET}/'refs(unc).p' \ + $RE_ANN/${DATASET}/instances.json \ + index/iid_to_ann_ids.json \ + --output $TXT_DB/${DATASET}_${SPLIT}_$SUFFIX.db + done +done + +# refcocog +DATASET='refcocog' +for SPLIT in 'train' 'val' 'test'; do + python prepro.py --task re --bert $TOKER \ + --annotations $RE_ANN/${DATASET}/'refs(umd).p' \ + $RE_ANN/${DATASET}/instances.json \ + index/iid_to_ann_ids.json \ + --output $TXT_DB/${DATASET}_${SPLIT}_$SUFFIX.db +done + + +# DATASET='refcoco' +# SPLIT='train' +# python prepro.py --task re --bert $TOKER \ +# --annotations $RE_ANN/${DATASET}/'refs(unc).p' \ +# $RE_ANN/${DATASET}/instances.json \ +# index/iid_to_ann_ids.json \ +# --output $TXT_DB/${DATASET}_${SPLIT}_$SUFFIX.db \ No newline at end of file diff --git a/uniter_model/scripts/split_annotations.py b/uniter_model/scripts/split_annotations.py new file mode 100644 index 0000000..692992e --- /dev/null +++ b/uniter_model/scripts/split_annotations.py @@ -0,0 +1,57 @@ +import json +from os.path import join +import sys + + +def save_coco_train_val(data, output_dir): + current_data = [] + rest_data = [] + for d in data: + if not d['sent'].strip(): + # filter out empty sentence + continue + if (d['dataset'] == 'coco' + and d['split'] == 'train' + and 'val' in d['file_path']): + current_data.append(d) + else: + rest_data.append(d) + fileName = "pretrain_caption_coco_trainval.json" + json.dump(current_data, open(join(output_dir, fileName), "w")) + return rest_data + + +def save_by_dataset_and_split(data, dataset, split, output_dir): + current_data = [] + rest_data = [] + for d in data: + if not d['sent'].strip(): + # filter out empty sentence + continue + if split == 'trainval': + if (d['dataset'] == 'coco' + and d['split'] == 'train' + and 'val' in d['file_path']): + current_data.append(d) + else: + rest_data.append(d) + elif d["dataset"] == dataset and d["split"] == split: + current_data.append(d) + else: + rest_data.append(d) + fileName = f"pretrain_caption_{dataset}_{split}.json" + json.dump(current_data, open(join(output_dir, fileName), "w")) + return rest_data + + +def main(): + input_file, output_dir = sys.argv[1:] + data = json.load(open(input_file, "r")) + data = save_coco_train_val(data, output_dir) + for dataset in ["coco", "vg"]: + for split in ["train", "val", "test"]: + data = save_by_dataset_and_split(data, dataset, split, output_dir) + + +if __name__ == '__main__': + main() diff --git a/uniter_model/scripts/split_coco_pretrain_vqa.py b/uniter_model/scripts/split_coco_pretrain_vqa.py new file mode 100644 index 0000000..f3b7c74 --- /dev/null +++ b/uniter_model/scripts/split_coco_pretrain_vqa.py @@ -0,0 +1,57 @@ +""" +split pretraining COCO VQA according to image directory +""" +import json +from os.path import dirname + +ANNOTATION = '/ssd2/yenchun/ANNOTATIONS' 
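+# COCO image ids to exclude: main() drops Flickr30k-overlap and RefCOCO val/test
+# images from the train split, and Flickr30k-overlap and Karpathy
+# minival/minitest images from the val split.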
+EXCLUDE_IID = f'{dirname(__file__)}/../index/excluded_coco_vg_iids.json' + +COCO_TRAIN_QUESTION = (f'{ANNOTATION}/VQA/' + 'v2_OpenEnded_mscoco_train2014_questions.json') +COCO_TRAIN_ANSWER = f'{ANNOTATION}/VQA/v2_mscoco_train2014_annotations.json' + +COCO_VAL_QUESTION = (f'{ANNOTATION}/VQA/' + 'v2_OpenEnded_mscoco_val2014_questions.json') +COCO_VAL_ANSWER = f'{ANNOTATION}/VQA/v2_mscoco_val2014_annotations.json' + +OUT_DIR = f'{ANNOTATION}/latest_cleaned/' + + +def _filter_data(examples, exclude_iids): + filtered = (ex for ex in examples if ex['image_id'] not in exclude_iids) + return filtered + + +def main(): + ids = json.load(open(EXCLUDE_IID)) + train_exclude_iids = set(ids['flickr30k_coco_iids'] + + ids['refer_val_coco_iids'] + + ids['refer_test_coco_iids']) + val_exclude_iids = set(ids['flickr30k_coco_iids'] + + ids['karpathy_minival_iids'] + + ids['karpathy_minitest_iids']) + + # split train + questions = json.load(open(COCO_TRAIN_QUESTION))['questions'] + train_qs = _filter_data(questions, train_exclude_iids) + with open(f'{OUT_DIR}/coco_vqa_trainsplit_questions.json', 'w') as f: + json.dump({'questions': list(train_qs)}, f) + answers = json.load(open(COCO_TRAIN_ANSWER))['annotations'] + train_as = _filter_data(answers, train_exclude_iids) + with open(f'{OUT_DIR}/coco_vqa_trainsplit_annotations.json', 'w') as f: + json.dump({'annotations': list(train_as)}, f) + + # split val + questions = json.load(open(COCO_VAL_QUESTION))['questions'] + val_qs = _filter_data(questions, val_exclude_iids) + with open(f'{OUT_DIR}/coco_vqa_valsplit_questions.json', 'w') as f: + json.dump({'questions': list(val_qs)}, f) + answers = json.load(open(COCO_VAL_ANSWER))['annotations'] + val_as = _filter_data(answers, val_exclude_iids) + with open(f'{OUT_DIR}/coco_vqa_valsplit_annotations.json', 'w') as f: + json.dump({'annotations': list(val_as)}, f) + + +if __name__ == '__main__': + main() diff --git a/uniter_model/scripts/split_vqa_val.py b/uniter_model/scripts/split_vqa_val.py new file mode 100644 index 0000000..d0099f5 --- /dev/null +++ b/uniter_model/scripts/split_vqa_val.py @@ -0,0 +1,64 @@ +""" +split vqa val set for data augmentation +""" +import json + + +ANNOTATION = '/ssd2/yenchun/ANNOTATIONS' +# karpathy 5k test split +TEST_5K = f'{ANNOTATION}/Image-Text-Matching/coco_test.json' + +# original VQA val data +VAL_QUESTION = f'{ANNOTATION}/VQA/v2_OpenEnded_mscoco_val2014_questions.json' +VAL_ANSWER = f'{ANNOTATION}/VQA/v2_mscoco_val2014_annotations.json' + + +def _get_img_id(img_name): + img_name = img_name[:-4] + id_ = int(img_name.split('_')[-1]) + return id_ + + +def _get_test_ids(): + data = json.load(open(TEST_5K)) + ids = {_get_img_id(d['filename']) for d in data} + return ids + + +def main(): + dev_ids = _get_test_ids() + + # process questions + val_questions = json.load(open(VAL_QUESTION))['questions'] + dev_qs = [] + train_qs = [] + for q in val_questions: + if q['image_id'] in dev_ids: + dev_qs.append(q) + else: + train_qs.append(q) + assert len(val_questions) == len(dev_qs) + len(train_qs) + dev_q_name = VAL_QUESTION.replace('val', 'devval') + json.dump({'questions': dev_qs}, open(dev_q_name, 'w')) + train_q_name = VAL_QUESTION.replace('val', 'trainval') + json.dump({'questions': train_qs}, open(train_q_name, 'w')) + + # process answers + val_answers = json.load(open(VAL_ANSWER))['annotations'] + dev_as = [] + train_as = [] + for a in val_answers: + if a['image_id'] in dev_ids: + dev_as.append(a) + else: + train_as.append(a) + assert len(dev_as) == len(dev_qs) + assert len(train_as) == 
len(train_qs) + dev_a_name = VAL_ANSWER.replace('val', 'devval') + json.dump({'annotations': dev_as}, open(dev_a_name, 'w')) + train_a_name = VAL_ANSWER.replace('val', 'trainval') + json.dump({'annotations': train_as}, open(train_a_name, 'w')) + + +if __name__ == '__main__': + main() diff --git a/uniter_model/tests/generate_test_data.py b/uniter_model/tests/generate_test_data.py new file mode 100644 index 0000000..bce2f3e --- /dev/null +++ b/uniter_model/tests/generate_test_data.py @@ -0,0 +1,97 @@ +""" +minimal running script of distributed training +""" +import argparse +import random + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.rnn import pad_sequence +from torch import optim + +# communication operations +from utils.distributed import all_reduce_and_rescale_tensors, all_gather_list + + +class DataLoader(object): + def __init__(self, vocab_size, n_class, batch_size=8, lengths=(5, 10)): + self.vsize = vocab_size + self.ncls = n_class + self.bs = batch_size + self.lengths = lengths + + def __iter__(self): + while True: + input_, target = self._random_batch() + yield input_, target + + def _random_batch(self): + inputs = [] + targets = [] + for _ in range(self.bs): + i, t = self._random_inputs() + inputs.append(i) + targets.append(t) + input_ = pad_sequence(inputs) + targets = torch.LongTensor(targets) + return input_, targets + + def _random_inputs(self): + len_ = random.randint(*self.lengths) + inputs = [random.randint(0, self.vsize-1) for _ in range(len_)] + target = random.randint(0, self.ncls-1) + return torch.LongTensor(inputs), target + + +class Model(nn.Module): + def __init__(self, vsize, ncls): + super().__init__() + self.emb = nn.Embedding(vsize, 100) + self.rnn = nn.LSTM(100, 100, 1) + self.proj = nn.Linear(100, ncls) + + def forward(self, input_): + emb_out = self.emb(input_) + _, (h, c) = self.rnn(emb_out) + output = self.proj(h[-1]) + return output + +class InputExample(object): + def __init__(self, input, target): + self.input = input + self.target = target + +def main(): + vsize = 200 + ncls = 10 + accum = 4 + total_step = 100 + seed = 777 + total_step = 100 + + random.seed(seed) + torch.manual_seed(seed) + global_step = 0 + loader = DataLoader(vsize, ncls) + examples = [] + print ("example generating") + for step, (input_, target) in enumerate(loader): + print ("example appended" + str(step)) + examples.append(InputExample(input=input_, target = target)) + global_step += 1 + if global_step >= total_step: + break + print ("saving torch.save") + torch.save(examples, 'data/test_data/input0.txt') + + examples = torch.load('data/test_data/input0.txt') + for step, ie in enumerate(examples): + print (step) + print (ie.input) + print (ie.target) + +if __name__ == '__main__': + main() + + diff --git a/uniter_model/tests/test_distributed_fa.py b/uniter_model/tests/test_distributed_fa.py new file mode 100644 index 0000000..bd39889 --- /dev/null +++ b/uniter_model/tests/test_distributed_fa.py @@ -0,0 +1,126 @@ +""" +minimal running script of distributed training +""" +import argparse +import random + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.rnn import pad_sequence +from torch import optim + +# communication operations +from utils.distributed import all_reduce_and_rescale_tensors, all_gather_list + + +class DataLoader(object): + def __init__(self, vocab_size, n_class, batch_size=8, lengths=(5, 10)): + self.vsize = vocab_size + self.ncls = n_class + self.bs = batch_size + 
self.lengths = lengths + + def __iter__(self): + while True: + input_, target = self._random_batch() + yield input_, target + + def _random_batch(self): + inputs = [] + targets = [] + for _ in range(self.bs): + i, t = self._random_inputs() + inputs.append(i) + targets.append(t) + input_ = pad_sequence(inputs) + targets = torch.LongTensor(targets) + return input_, targets + + def _random_inputs(self): + len_ = random.randint(*self.lengths) + inputs = [random.randint(0, self.vsize-1) for _ in range(len_)] + target = random.randint(0, self.ncls-1) + return torch.LongTensor(inputs), target + + +class Model(nn.Module): + def __init__(self, vsize, ncls): + super().__init__() + self.emb = nn.Embedding(vsize, 100) + self.rnn = nn.LSTM(100, 100, 1) + self.proj = nn.Linear(100, ncls) + + def forward(self, input_): + emb_out = self.emb(input_) + _, (h, c) = self.rnn(emb_out) + output = self.proj(h[-1]) + return output + +class InputExample(object): + def __init__(self, input, target): + self.input = input + self.target = target + +def main(local_rank): + vsize = 200 + ncls = 10 + accum = 4 + total_step = 100 + seed = 777 + + # distributed initialization + if local_rank == -1: + device = torch.device("cuda") + n_gpu = 1 + else: + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + # Initializes the distributed backend which will take care of + # sychronizing nodes/GPUs + torch.distributed.init_process_group(backend='nccl') + n_gpu = torch.distributed.get_world_size() + print("device: {} n_gpu: {}, distributed training: {}".format( + device, n_gpu, bool(local_rank != -1))) + + random.seed(seed) + torch.manual_seed(seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + #loader = DataLoader(vsize, ncls) + model = Model(vsize, ncls).to(device) + optimizer = optim.Adam(model.parameters(), lr=1e-4) + global_step = 0 + + print ("local_rank" + str(local_rank)) + examples = torch.load('data/test_data/input'+str(local_rank)+'.txt') + + for step, ie in enumerate(examples): + input_ = ie.input + target = ie.target + input_ = input_.to(device) + target = target.to(device) + logit = model(input_) + loss = F.cross_entropy(logit, target, reduction='sum') + losses = all_gather_list(loss.item()) + loss.backward() + if (step+1) % accum == 0: + if local_rank != -1: + grads = [p.grad.data for p in model.parameters() + if p is not None and p.requires_grad] + all_reduce_and_rescale_tensors(grads, 1) + optimizer.step() + optimizer.zero_grad() + global_step += 1 + if local_rank <= 0: + print(f'step: {global_step}; loss: {sum(losses)}') + if global_step >= total_step: + break + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', type=int, default=-1) + args = parser.parse_args() + main(args.local_rank) diff --git a/uniter_model/tests/test_hvd_fa.py b/uniter_model/tests/test_hvd_fa.py new file mode 100644 index 0000000..da5ef31 --- /dev/null +++ b/uniter_model/tests/test_hvd_fa.py @@ -0,0 +1,118 @@ +""" +minimal running script of distributed training +""" +import argparse +import random + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.rnn import pad_sequence +from torch import optim + +from horovod import torch as hvd + +# communication operations +from utils.distributed import all_reduce_and_rescale_tensors, all_gather_list + + +class DataLoader(object): + def __init__(self, vocab_size, n_class, batch_size=8, lengths=(5, 10)): + self.vsize = vocab_size + self.ncls = n_class + self.bs = 
batch_size + self.lengths = lengths + + def __iter__(self): + while True: + input_, target = self._random_batch() + yield input_, target + + def _random_batch(self): + inputs = [] + targets = [] + for _ in range(self.bs): + i, t = self._random_inputs() + inputs.append(i) + targets.append(t) + input_ = pad_sequence(inputs) + targets = torch.LongTensor(targets) + return input_, targets + + def _random_inputs(self): + len_ = random.randint(*self.lengths) + inputs = [random.randint(0, self.vsize-1) for _ in range(len_)] + target = random.randint(0, self.ncls-1) + return torch.LongTensor(inputs), target + + +class Model(nn.Module): + def __init__(self, vsize, ncls): + super().__init__() + self.emb = nn.Embedding(vsize, 100) + self.rnn = nn.LSTM(100, 100, 1) + self.proj = nn.Linear(100, ncls) + + def forward(self, input_): + emb_out = self.emb(input_) + _, (h, c) = self.rnn(emb_out) + output = self.proj(h[-1]) + return output + +class InputExample(object): + def __init__(self, input, target): + self.input = input + self.target = target + +def main(): + vsize = 200 + ncls = 10 + accum = 4 + total_step = 100 + seed = 777 + + # distributed initialization + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + local_rank = hvd.rank() + + random.seed(seed) + torch.manual_seed(seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + #loader = DataLoader(vsize, ncls) + model = Model(vsize, ncls).to(device) + optimizer = optim.Adam(model.parameters(), lr=1e-4) + global_step = 0 + + print ("local_rank" + str(local_rank)) + examples = torch.load('data/test_data/input'+str(local_rank)+'.txt') + + for step, ie in enumerate(examples): + input_ = ie.input + target = ie.target + input_ = input_.to(device) + target = target.to(device) + logit = model(input_) + loss = F.cross_entropy(logit, target, reduction='sum') + losses = all_gather_list(loss.item()) + #losses = [loss.item()] + loss.backward() + if (step+1) % accum == 0: + if local_rank != -1: + grads = [p.grad.data for p in model.parameters() + if p is not None and p.requires_grad] + all_reduce_and_rescale_tensors(grads, 1) + optimizer.step() + optimizer.zero_grad() + global_step += 1 + if local_rank <= 0: + print(f'step: {global_step}; loss: {sum(losses)}') + if global_step >= total_step: + break + + +if __name__ == '__main__': + main() diff --git a/uniter_model/train_itm.py b/uniter_model/train_itm.py new file mode 100644 index 0000000..9157230 --- /dev/null +++ b/uniter_model/train_itm.py @@ -0,0 +1,599 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER finetuning for Image-Text Retrieval +""" +import argparse +from collections import defaultdict +import json +import os +from os.path import exists, join +from time import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader, ConcatDataset +from apex import amp +from horovod import torch as hvd +from toolz.sandbox import unzip +from tqdm import tqdm + +from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, + ItmRankDataset, ItmRankDatasetHardNeg, itm_rank_collate, + ItmHardNegDataset, itm_hn_collate, + ItmValDataset, itm_val_collate, + ItmEvalDataset, itm_eval_collate) +from model import UniterForImageTextRetrieval, UniterForImageTextRetrievalFast +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors, any_broadcast) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM +from eval.itm import itm_eval + + +def build_dataloader(dataset, collate_fn, is_train, opts): + batch_size = opts.train_batch_size if is_train else 1 + dataloader = DataLoader(dataset, batch_size=batch_size, + shuffle=is_train, drop_last=is_train, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def compute_hard_neg(model, loader, dataset, hard_negative_num, hard_neg_dir): + txt2hardimgs, img2hardtxts = get_hard_negs(model, loader, + hard_negative_num) + with open(f'{hard_neg_dir}/' + f'txt2hardimgs_rank{hvd.rank()}.json', + 'w') as f: + json.dump(txt2hardimgs, f) + if hvd.rank() == 0: + with open(f'{hard_neg_dir}/img2hardtxts.json', 'w') as f: + json.dump(img2hardtxts, f) + all_gather_list(None) # dummy sync to wait for writing + if isinstance(dataset, ConcatDataset): + for dset in dataset.datasets: + dset.reload_hard_negs(hard_neg_dir) + else: + dataset.reload_hard_negs(hard_neg_dir) + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + set_random_seed(opts.seed) + + if hvd.rank() == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + # store ITM predictions + os.makedirs(join(opts.output_dir, 'results_val')) + os.makedirs(join(opts.output_dir, 'results_test')) + os.makedirs(join(opts.output_dir, 'results_train')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " + f"{opts.train_img_dbs}") + # check multiple DBs + assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ + "train txt_db and img_db have different length" + + # load DBs and image dirs + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + 
opts.num_bb, opts.compressed_db) + # train + LOGGER.info(f"Loading Train Dataset " + f"{opts.train_txt_dbs}, {opts.train_img_dbs}") + train_datasets = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + if opts.hard_neg_size > 0: + train_datasets.append( + ItmRankDatasetHardNeg(txt_db, img_db, + opts.negative_size, opts.hard_neg_size)) + else: + train_datasets.append(ItmRankDataset(txt_db, img_db, + opts.negative_size)) + train_dataset = ConcatDataset(train_datasets) + + # hard negative + hn_datasets = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + hn_datasets.append(ItmHardNegDataset(txt_db, img_db, + opts.inf_minibatch_size)) + hn_dataset = ConcatDataset(hn_datasets) + hn_dataloader = build_dataloader(hn_dataset, itm_hn_collate, False, opts) + hard_neg_dir = f'{opts.output_dir}/results_train/' + + # val + LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") + val_img_db = all_img_dbs[opts.val_img_db] + val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) + val_dataset = ItmValDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + val_dataloader = build_dataloader(val_dataset, itm_val_collate, + False, opts) + # eval + LOGGER.info(f"Loading val, test Dataset for full evaluation: " + f"{opts.val_txt_db}, {opts.val_img_db}" + f"{opts.test_txt_db}, {opts.test_img_db}") + eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate, + False, opts) + test_img_db = all_img_dbs[opts.test_img_db] + test_txt_db = TxtTokLmdb(opts.test_txt_db, -1) + eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db, + opts.inf_minibatch_size) + eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate, + False, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + model = UniterForImageTextRetrievalFast.from_pretrained( + opts.model_config, state_dict=checkpoint, + img_dim=IMG_DIM, margin=opts.margin) + model.init_output() # pretrain ITM head is different from ranking head + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + global_step = 0 + LOGGER.info(f"***** Running training on {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", len(train_dataset) * hvd.size()) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + running_loss = RunningMeter('loss') + model.train() + + if opts.steps_per_hard_neg != -1: + compute_hard_neg(model, hn_dataloader, train_dataset, + opts.hard_neg_pool_size, hard_neg_dir) + + n_examples = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + train_dataloader = build_dataloader( + train_dataset, itm_rank_collate, True, opts) + for step, batch in enumerate(train_dataloader): + n_examples += batch['input_ids'].size(0) + loss = model(batch, compute_loss=True) + loss = loss.mean() 
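+            # gradients are accumulated over gradient_accumulation_steps
+            # micro-batches; amp unscaling is deferred (delay_unscale=True)
+            # until the micro-batch that triggers the optimizer update, so all
+            # accumulated gradients share a single loss scale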
+ delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'============Step {global_step}=============') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + LOGGER.info(f'===========================================') + + if global_step % opts.valid_steps == 0: + if opts.full_val: + val_log = evaluate(model, eval_loader_val) + TB_LOGGER.log_scaler_dict( + {f"valid/{k}": v for k, v in val_log.items()}) + else: + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step) + + if (opts.steps_per_hard_neg != -1 + and global_step % opts.steps_per_hard_neg == 0): + # sample hard negatives for training + compute_hard_neg(model, hn_dataloader, train_dataset, + opts.hard_neg_pool_size, hard_neg_dir) + # break to reconstruct loader + # for potential multi-worker issue (not sure) + break + + if global_step >= opts.num_train_steps: + break + + if global_step >= opts.num_train_steps: + break + # NOTE can no longer count epochs + + pbar.close() + # final validation + model_saver.save(model, f'{global_step}_final') + + # evaluation + for split, loader in [('val', eval_loader_val), + ('test', eval_loader_test)]: + eval_log = evaluate(model, loader) + TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v + for k, v in eval_log.items()}) + if hvd.rank() != 0: + continue + LOGGER.info( + f"========================= {split} ===========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + LOGGER.info("=========================================================") + + +@torch.no_grad() +def get_hard_negs(model, loader, hard_negative_num=20): + LOGGER.info("start running hard negative extraction") + st = time() + if hvd.rank() == 0: + pbar = tqdm(total=len(loader)) + else: + pbar = NoOp() + model.eval() + + txt2hardimgs = 
{} + img_to_score_txts = defaultdict(list) + for batch in loader: + scores = model(batch, compute_loss=False).squeeze(-1) + txt = batch['gt_txt_id'] + imgs = batch['neg_img_ids'] + # record hard images + hard_indices = scores.topk(hard_negative_num, sorted=False)[1].tolist() + txt2hardimgs[txt] = [imgs[i] for i in hard_indices] + # record img2txts + for i, img in enumerate(imgs): + img_to_score_txts[img].append((scores[i].item(), txt)) + pbar.update(1) + pbar.close() + + LOGGER.info("start computing hard texts from images...") + n_less_neg = 0 + tot_text = 0 + img2hardtxts = {} + # need to gather hard texts from all GPUs + all_img_ids = [i for dset in loader.dataset.datasets + for i in dset.all_img_ids] + all_img_ids = any_broadcast(all_img_ids, 0) + for img in all_img_ids: + score_txts = img_to_score_txts[img] + scores, txts = map(list, unzip( + pair for pairs in all_gather_list(score_txts) + for pair in pairs)) + if hvd.rank() != 0: + # only rank 0 needs to compute + continue + tot_text += len(txts) + if len(txts) < hard_negative_num: + # not enough negatives + hard_indices = range(len(txts)) + n_less_neg += 1 + else: + hard_indices = torch.tensor(scores).topk(hard_negative_num, + sorted=False)[1].tolist() + img2hardtxts[img] = [txts[i] for i in hard_indices] + + n_less_neg = sum(all_gather_list(n_less_neg)) + if n_less_neg: + LOGGER.info(f"Warning: {n_less_neg} images did not " + f"sample enough negatives") + LOGGER.info(f"hard negative extraction finished " + f"in {int(time() - st)} seconds " + f"({tot_text//len(img_to_score_txts)} texts per images)") + + model.train() + return txt2hardimgs, img2hardtxts + + +@torch.no_grad() +def validate(model, val_loader): + if hvd.rank() == 0: + pbar = tqdm(total=len(val_loader)) + else: + pbar = NoOp() + LOGGER.info("start running Image Retrieval validation ...") + model.eval() + n_ex = 0 + st = time() + + recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0 + for batch in val_loader: + scores = model(batch, compute_loss=False) + _, indices = scores.topk(10, dim=0) + rank = (indices == 0).nonzero() + if rank.numel(): + rank = rank.item() + if rank < 1: + recall_at_1 += 1 + if rank < 5: + recall_at_5 += 1 + if rank < 10: + recall_at_10 += 1 + n_ex += 1 + pbar.update(1) + n_ex = sum(all_gather_list(n_ex)) + recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex + recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex + recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex + tot_time = time()-st + val_log = {'valid/ex_per_s': n_ex/tot_time, + 'valid/recall_1': recall_at_1, + 'valid/recall_5': recall_at_5, + 'valid/recall_10': recall_at_10} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"recall_1: {recall_at_1*100:.2f}, " + f"recall_5: {recall_at_5*100:.2f}, " + f"recall_10: {recall_at_10*100:.2f}") + pbar.close() + return val_log + + +@torch.no_grad() +def evaluate(model, eval_loader): + st = time() + LOGGER.info("start running Image/Text Retrieval evaluation ...") + score_matrix = inference(model, eval_loader) + dset = eval_loader.dataset + all_score = hvd.allgather(score_matrix) + all_txt_ids = [i for ids in all_gather_list(dset.ids) + for i in ids] + all_img_ids = dset.all_img_ids + assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) + if hvd.rank() != 0: + return {} + + # NOTE: only use rank0 to compute final scores + # TODO store score_matrix and ids + eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, + dset.txt2img, dset.img2txts) + + tot_time = time()-st + LOGGER.info(f"evaluation 
finished in {int(tot_time)} seconds, ") + return eval_log + + +@torch.no_grad() +def inference(model, eval_loader): + model.eval() + if hvd.rank() == 0: + pbar = tqdm(total=len(eval_loader)) + else: + pbar = NoOp() + score_matrix = torch.zeros(len(eval_loader.dataset), + len(eval_loader.dataset.all_img_ids), + device=torch.device("cuda"), + dtype=torch.float16) + for i, mini_batches in enumerate(eval_loader): + j = 0 + for batch in mini_batches: + scores = model(batch, compute_loss=False) + bs = scores.size(0) + # score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half() + score_matrix.data[i, j:j+bs] = scores.data.half() + j += bs + assert j == score_matrix.size(1) + pbar.update(1) + model.train() + pbar.close() + return score_matrix + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained MLM") + + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model " + "checkpoints will be written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", + default=128, type=int, + help="Total batch size for training. " + "(batch by examples)") + + parser.add_argument("--negative_size", + default=1, type=int, + help="Number of negative samples per positive sample") + parser.add_argument("--hard_neg_size", + default=0, type=int, + help="Number of hard negative samples " + "per positive sample") + + parser.add_argument("--hard_neg_pool_size", + default=20, type=int, + help="Size of hard negative pool") + parser.add_argument("--steps_per_hard_neg", + default=-1, type=int, + help="Run hard neg sampling every X steps") + + parser.add_argument("--inf_minibatch_size", + default=400, type=int, + help="batch size for running inference. 
" + "(used for validation, evaluation," + " and hard negative sampling)") + + parser.add_argument("--margin", + default=0.2, type=float, + help="margin of ranking loss") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", + default=1000, + type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", + default=100000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant'], + help="learning rate decay method") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + # FIXME check weight decay + parser.add_argument("--weight_decay", + default=0.01, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. (invsqrt decay)") + + # device parameters + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--full_val', action='store_true', + help="Always run full evaluation during training") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + # if exists(args.output_dir) and os.listdir(args.output_dir): + # raise ValueError("Output directory ({}) already exists and is not " + # "empty.".format(args.output_dir)) + + # options safe guard + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + assert (args.hard_neg_size + <= args.hard_neg_pool_size + <= args.inf_minibatch_size) + if args.steps_per_hard_neg != -1: + assert args.hard_neg_size > 0 + + main(args) diff --git a/uniter_model/train_itm_v2.py b/uniter_model/train_itm_v2.py new file mode 100644 index 0000000..a22b9b9 --- /dev/null +++ b/uniter_model/train_itm_v2.py @@ -0,0 +1,499 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER finetuning for Image-Text Retrieval +""" +import argparse +import os +from os.path import exists, join +from time import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader, ConcatDataset +from apex import amp +from horovod import torch as hvd +from tqdm import tqdm + +from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, + ItmRankDatasetHardNegFromText, + ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate, + ItmValDataset, itm_val_collate, + ItmEvalDataset, itm_eval_collate) +from model import UniterForImageTextRetrievalHardNeg +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM +from eval.itm import itm_eval + + +def build_dataloader(dataset, collate_fn, is_train, opts): + dataloader = DataLoader(dataset, batch_size=1, + shuffle=is_train, drop_last=is_train, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + set_random_seed(opts.seed) + + if hvd.rank() == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + # store ITM predictions + os.makedirs(join(opts.output_dir, 'results_val')) + os.makedirs(join(opts.output_dir, 'results_test')) + os.makedirs(join(opts.output_dir, 'results_train')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " + f"{opts.train_img_dbs}") + # check multiple DBs + assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ + "train txt_db and img_db have different length" + + # load DBs and image dirs + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + # train + LOGGER.info(f"Loading Train Dataset " + f"{opts.train_txt_dbs}, {opts.train_img_dbs}") + train_datasets_t = [] + train_datasets_i = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + train_datasets_t.append( + ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size)) + train_datasets_i.append( + ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size)) + train_dataset_t = ConcatDataset(train_datasets_t) + train_dataset_i = ConcatDataset(train_datasets_i) + train_dataloader_t = build_dataloader( + train_dataset_t, itm_rank_hnv2_collate, True, opts) + train_dataloader_i = build_dataloader( + train_dataset_i, itm_rank_hnv2_collate, True, opts) + + # val + LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") + val_img_db = all_img_dbs[opts.val_img_db] + val_txt_db = 
TxtTokLmdb(opts.val_txt_db, -1) + val_dataset = ItmValDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + val_dataloader = build_dataloader(val_dataset, itm_val_collate, + False, opts) + # eval + LOGGER.info(f"Loading val, test Dataset for full evaluation: " + f"{opts.val_txt_db}, {opts.val_img_db}" + f"{opts.test_txt_db}, {opts.test_img_db}") + eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate, + False, opts) + test_img_db = all_img_dbs[opts.test_img_db] + test_txt_db = TxtTokLmdb(opts.test_txt_db, -1) + eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db, + opts.inf_minibatch_size) + eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate, + False, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + model = UniterForImageTextRetrievalHardNeg.from_pretrained( + opts.model_config, state_dict=checkpoint, + img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size) + model.init_output() # pretrain ITM head is different from ranking head + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + LOGGER.info(f"***** Running training on {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", + sum(all_gather_list(len(train_dataset_t)))) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + running_loss = RunningMeter('loss') + model.train() + + global_step = 0 + step = 0 + n_examples = 0 + n_hard_ex = 0 + n_epoch = 0 + start = time() + train_iter_i = iter(train_dataloader_i) + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for batch in train_dataloader_t: + + # hard text from image + try: + batch_i = next(train_iter_i) + except StopIteration: + train_iter_i = iter(train_dataloader_i) + batch_i = next(train_iter_i) + n_examples += batch_i['attn_masks'].size(0) + loss = model(batch_i, sample_from='i', compute_loss=True) + n_hard_ex += loss.numel() + loss = loss.mean() / opts.train_batch_size + with amp.scale_loss(loss, optimizer, delay_unscale=True + ) as scaled_loss: + scaled_loss.backward() + + # hard image from text + n_examples += batch['attn_masks'].size(0) + loss = model(batch, sample_from='t', compute_loss=True) + n_hard_ex += loss.numel() + # NOTE we use gradient accumulation to implemented train_batch_size + loss = loss.mean() / opts.train_batch_size + + step += 1 + delay_unscale = step % opts.train_batch_size != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + if step % opts.train_batch_size == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + 
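For reference, the multiplier behind get_lr_sched above matches how the schedules are invoked elsewhere in this patch: warmup_linear(global_step, opts.warmup_steps, opts.num_train_steps) for the 'linear' decay option and noam_schedule(global_step, opts.warmup_steps) for 'invsqrt'. The sketch below assumes the standard implementations of those two helpers rather than quoting optim/sched.py:

def warmup_linear(step, warmup_step, tot_step):
    # linear ramp to the peak LR, then linear decay to zero at tot_step
    if step < warmup_step:
        return step / warmup_step
    return max(0.0, (tot_step - step) / (tot_step - warmup_step))

def noam_schedule(step, warmup_step=4000):
    # linear ramp, then inverse-square-root decay ('invsqrt')
    if step <= warmup_step:
        return step / warmup_step
    return (warmup_step / step) ** 0.5

# e.g. lr_this_step = opts.learning_rate * warmup_linear(
#          global_step, opts.warmup_steps, opts.num_train_steps)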
TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'------------Step {global_step}-------------') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + tot_hn = sum(all_gather_list(n_hard_ex)) + hn_per_sec = int(tot_hn / (time()-start)) + LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) ' + f'trained at {ex_per_sec} ({hn_per_sec}) ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + TB_LOGGER.add_scalar('perf/hn_per_s', + hn_per_sec, global_step) + LOGGER.info(f'-------------------------------------------') + + if global_step % opts.valid_steps == 0: + if opts.full_val: + LOGGER.info( + f"========================== Step {global_step} " + f"==========================") + val_log = evaluate(model, eval_loader_val) + TB_LOGGER.log_scaler_dict( + {f"valid/{k}": v for k, v in val_log.items()}) + if hvd.rank() == 0: + LOGGER.info( + f"image retrieval R1: " + f"{val_log['img_r1']*100:.2f},\n" + f"image retrieval R5: " + f"{val_log['img_r5']*100:.2f},\n" + f"image retrieval R10: " + f"{val_log['img_r10']*100:.2f}\n" + f"text retrieval R1: " + f"{val_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: " + f"{val_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: " + f"{val_log['txt_r10']*100:.2f}") + LOGGER.info("=================================" + "=================================") + else: + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step) + + if global_step >= opts.num_train_steps: + break + + if global_step >= opts.num_train_steps: + break + n_epoch += 1 + LOGGER.info(f"finished {n_epoch} epochs") + + pbar.close() + # final validation + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + # evaluation + for split, loader in [('val', eval_loader_val), + ('test', eval_loader_test)]: + eval_log = evaluate(model, loader) + TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v + for k, v in eval_log.items()}) + if hvd.rank() != 0: + continue + LOGGER.info( + f"========================= {split} ===========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + LOGGER.info("=========================================================") + + +@torch.no_grad() +def validate(model, val_loader): + if hvd.rank() == 0: + pbar = tqdm(total=len(val_loader)) + else: + pbar = NoOp() + LOGGER.info("start running Image Retrieval validation ...") + model.eval() + n_ex = 0 + st = time() + + recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0 + for batch in val_loader: + scores = model(batch, compute_loss=False) + _, indices = scores.squeeze(1).topk(10, dim=0) + rank = (indices == 
0).nonzero() + if rank.numel(): + rank = rank.item() + if rank < 1: + recall_at_1 += 1 + if rank < 5: + recall_at_5 += 1 + if rank < 10: + recall_at_10 += 1 + n_ex += 1 + pbar.update(1) + n_ex = sum(all_gather_list(n_ex)) + recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex + recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex + recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex + tot_time = time()-st + val_log = {'valid/ex_per_s': n_ex/tot_time, + 'valid/recall_1': recall_at_1, + 'valid/recall_5': recall_at_5, + 'valid/recall_10': recall_at_10} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"recall_1: {recall_at_1*100:.2f}, " + f"recall_5: {recall_at_5*100:.2f}, " + f"recall_10: {recall_at_10*100:.2f}") + pbar.close() + return val_log + + +@torch.no_grad() +def evaluate(model, eval_loader): + st = time() + LOGGER.info("start running Image/Text Retrieval evaluation ...") + score_matrix = inference(model, eval_loader) + dset = eval_loader.dataset + all_score = hvd.allgather(score_matrix) + all_txt_ids = [i for ids in all_gather_list(dset.ids) + for i in ids] + all_img_ids = dset.all_img_ids + assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) + if hvd.rank() != 0: + return {} + + # NOTE: only use rank0 to compute final scores + # TODO store score_matrix and ids + eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, + dset.txt2img, dset.img2txts) + + tot_time = time()-st + LOGGER.info(f"evaluation finished in {int(tot_time)} seconds") + return eval_log + + +@torch.no_grad() +def inference(model, eval_loader): + model.eval() + if hvd.rank() == 0: + pbar = tqdm(total=len(eval_loader)) + else: + pbar = NoOp() + score_matrix = torch.zeros(len(eval_loader.dataset), + len(eval_loader.dataset.all_img_ids), + device=torch.device("cuda"), + dtype=torch.float16) + for i, mini_batches in enumerate(eval_loader): + j = 0 + for batch in mini_batches: + scores = model(batch, compute_loss=False) + bs = scores.size(0) + score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half() + j += bs + assert j == score_matrix.size(1) + pbar.update(1) + model.train() + pbar.close() + return score_matrix + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained MLM") + + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model " + "checkpoints will be written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", default=32, type=int, + help="batch size (# positive examples) for training. 
" + "(implemented with gradient accumulation)") + + parser.add_argument("--negative_size", default=511, type=int, + help="Number of negative samples per positive sample" + "(forward only)") + parser.add_argument("--hard_neg_size", default=31, type=int, + help="Number of hard negative samples " + "per positive sample (acutally used to train)") + + parser.add_argument("--inf_minibatch_size", default=512, type=int, + help="batch size for running inference. " + "(used for validation and evaluation)") + + parser.add_argument("--margin", default=0.2, type=float, + help="margin of ranking loss") + parser.add_argument("--learning_rate", default=3e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", default=1000, type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", default=100000, type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant'], + help="learning rate decay method") + parser.add_argument("--dropout", default=0.1, type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", default=0.01, type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", default=0.25, type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", default=4000, type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. (invsqrt decay)") + + # device parameters + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--full_val', action='store_true', + help="Always run full evaluation during training") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + # for tensor core + assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0 + + main(args) diff --git a/uniter_model/train_nlvr2.py b/uniter_model/train_nlvr2.py new file mode 100644 index 0000000..7d583a5 --- /dev/null +++ b/uniter_model/train_nlvr2.py @@ -0,0 +1,418 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER finetuning for NLVR2 +""" +import argparse +from collections import defaultdict +import json +import os +from os.path import exists, join +from time import time + +import torch +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader + +from apex import amp +from horovod import torch as hvd + +from tqdm import tqdm + +from data import (TokenBucketSampler, DetectFeatLmdb, TxtTokLmdb, + Nlvr2PairedDataset, Nlvr2PairedEvalDataset, + Nlvr2TripletDataset, Nlvr2TripletEvalDataset, + nlvr2_paired_collate, nlvr2_paired_eval_collate, + nlvr2_triplet_collate, nlvr2_triplet_eval_collate, + PrefetchLoader) +from model.nlvr2 import (UniterForNlvr2Paired, UniterForNlvr2Triplet, + UniterForNlvr2PairedAttn) +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM, BUCKET_SIZE + + +def create_dataloader(img_path, txt_path, batch_size, is_train, + dset_cls, collate_fn, opts): + img_db = DetectFeatLmdb(img_path, opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len if is_train else -1) + dset = dset_cls(txt_db, img_db, opts.use_img_type) + sampler = TokenBucketSampler(dset.lens, bucket_size=BUCKET_SIZE, + batch_size=batch_size, droplast=is_train) + loader = DataLoader(dset, batch_sampler=sampler, + num_workers=opts.n_workers, pin_memory=opts.pin_mem, + collate_fn=collate_fn) + return PrefetchLoader(loader) + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + set_random_seed(opts.seed) + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, " + f"{opts.train_img_dir}") + if 'paired' in opts.model: + DatasetCls = Nlvr2PairedDataset + EvalDatasetCls = Nlvr2PairedEvalDataset + collate_fn = nlvr2_paired_collate + eval_collate_fn = nlvr2_paired_eval_collate + if opts.model == 'paired': + ModelCls = UniterForNlvr2Paired + elif opts.model == 'paired-attn': + ModelCls = UniterForNlvr2PairedAttn + else: + raise ValueError('unrecognized model type') + elif opts.model == 'triplet': + DatasetCls = Nlvr2TripletDataset + EvalDatasetCls = Nlvr2TripletEvalDataset + ModelCls = UniterForNlvr2Triplet + collate_fn = nlvr2_triplet_collate + eval_collate_fn = nlvr2_triplet_eval_collate + else: + raise ValueError('unrecognized model type') + + # data loaders + train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db, + opts.train_batch_size, True, + DatasetCls, collate_fn, opts) + val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db, + opts.val_batch_size, False, + EvalDatasetCls, eval_collate_fn, opts) + test_dataloader = create_dataloader(opts.test_img_db, opts.test_txt_db, + opts.val_batch_size, False, + EvalDatasetCls, eval_collate_fn, 
opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + model = ModelCls.from_pretrained(opts.model_config, state_dict=checkpoint, + img_dim=IMG_DIM) + model.init_type_embedding() + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + global_step = 0 + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + os.makedirs(join(opts.output_dir, 'results')) # store val predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", len(train_dataloader.dataset)) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + running_loss = RunningMeter('loss') + model.train() + n_examples = 0 + n_epoch = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for step, batch in enumerate(train_dataloader): + batch = defaultdict(lambda: None, batch) + targets = batch['targets'] + n_examples += targets.size(0) + + loss = model(**batch, compute_loss=True) + loss = loss.mean() + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'============Step {global_step}=============') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info( + f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + LOGGER.info(f'===========================================') + + if global_step % opts.valid_steps == 0: + for split, loader in [('val', val_dataloader), + ('test', 
test_dataloader)]: + LOGGER.info(f"Step {global_step}: start running " + f"validation on {split} split...") + log, results = validate(model, loader, split) + with open(f'{opts.output_dir}/results/' + f'{split}_results_{global_step}_' + f'rank{rank}.csv', 'w') as f: + for id_, ans in results: + f.write(f'{id_},{ans}\n') + TB_LOGGER.log_scaler_dict(log) + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step >= opts.num_train_steps: + break + n_epoch += 1 + LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs") + for split, loader in [('val', val_dataloader), ('test', test_dataloader)]: + LOGGER.info(f"Step {global_step}: start running " + f"validation on {split} split...") + log, results = validate(model, loader, split) + with open(f'{opts.output_dir}/results/' + f'{split}_results_{global_step}_' + f'rank{rank}_final.csv', 'w') as f: + for id_, ans in results: + f.write(f'{id_},{ans}\n') + TB_LOGGER.log_scaler_dict(log) + model_saver.save(model, f'{global_step}_final') + + +@torch.no_grad() +def validate(model, val_loader, split): + model.eval() + val_loss = 0 + tot_score = 0 + n_ex = 0 + st = time() + results = [] + for i, batch in enumerate(val_loader): + qids = batch['qids'] + targets = batch['targets'] + del batch['targets'] + del batch['qids'] + scores = model(**batch, targets=None, compute_loss=False) + loss = F.cross_entropy(scores, targets, reduction='sum') + val_loss += loss.item() + tot_score += (scores.max(dim=-1, keepdim=False)[1] == targets + ).sum().item() + answers = ['True' if i == 1 else 'False' + for i in scores.max(dim=-1, keepdim=False + )[1].cpu().tolist()] + results.extend(zip(qids, answers)) + n_ex += len(qids) + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_loss /= n_ex + val_acc = tot_score / n_ex + val_log = {f'valid/{split}_loss': val_loss, + f'valid/{split}_acc': val_acc, + f'valid/{split}_ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log, results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--train_txt_db", + default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--train_img_dir", + default=None, type=str, + help="The input train images.") + parser.add_argument("--val_txt_db", + default=None, type=str, + help="The input validation corpus. (LMDB)") + parser.add_argument("--val_img_dir", + default=None, type=str, + help="The input validation images.") + parser.add_argument("--test_txt_db", + default=None, type=str, + help="The input test corpus. 
(LMDB)") + parser.add_argument("--test_img_dir", + default=None, type=str, + help="The input test images.") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--model_config", + default=None, type=str, + help="json file for model architecture") + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model") + parser.add_argument("--model", default='paired', + choices=['paired', 'triplet', 'paired-attn'], + help="choose from 2 model architecture") + parser.add_argument('--use_img_type', action='store_true', + help="expand the type embedding for 2 image types") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", + default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", + default=4096, type=int, + help="Total batch size for validation. " + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", + default=1000, + type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", + default=100000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', type=float, + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt'], + help="learning rate decay method") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for.") + + # device parameters + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if 
exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + main(args) diff --git a/uniter_model/train_re.py b/uniter_model/train_re.py new file mode 100644 index 0000000..0d8ac37 --- /dev/null +++ b/uniter_model/train_re.py @@ -0,0 +1,460 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT for Referring Expression Comprehension""" +import argparse +import json +import os +from os.path import exists, join +import random +from time import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Adam, Adamax +from torch.utils.data import DataLoader + +# to be deprecated once upgraded to 1.2 +# from torch.utils.data.distributed import DistributedSampler +from data import DistributedSampler + +from apex import amp +from horovod import torch as hvd + +import numpy as np +from tqdm import tqdm + +from data import (ReImageFeatDir, ReferringExpressionDataset, + ReferringExpressionEvalDataset, re_collate, re_eval_collate, + PrefetchLoader) +from model import BertForReferringExpressionComprehension +from optim import warmup_linear, noam_schedule, AdamW + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config + + +def main(opts): + + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info(f"device: {device}, n_gpu: {n_gpu}, rank: {hvd.rank()}, " + f"16-bits training: {opts.fp16}") + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + random.seed(opts.seed) + np.random.seed(opts.seed) + torch.manual_seed(opts.seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(opts.seed) + + # train_samples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, " + f"{opts.train_img_dir}") + + # load DBs and image dirs + train_img_dir = ReImageFeatDir(opts.train_img_dir) + train_dataset = ReferringExpressionDataset( + opts.train_txt_db, train_img_dir, + max_txt_len=opts.max_txt_len) + val_img_dir = ReImageFeatDir(opts.val_img_dir) + val_dataset = ReferringExpressionEvalDataset( + opts.val_txt_db, val_img_dir, + max_txt_len=opts.max_txt_len) + + # Prepro model + if opts.checkpoint and opts.checkpoint != 'scratch': + if opts.checkpoint == 'google-bert': + # from 
google-bert + checkpoint = None + else: + checkpoint = torch.load(opts.checkpoint) + else: + # from scratch + checkpoint = {} + bert_model = json.load(open(f'{opts.train_txt_db}/meta.json'))['bert'] + model = BertForReferringExpressionComprehension.from_pretrained( + bert_model, img_dim=2048, + loss=opts.train_loss, + margin=opts.margin, + hard_ratio=opts.hard_ratio, + mlp=opts.mlp, + state_dict=checkpoint + ) + if opts.cut_bert != -1: + # cut some layers of BERT + model.bert.encoder.layer = torch.nn.ModuleList( + model.bert.encoder.layer[:opts.cut_bert] + ) + del checkpoint + for name, module in model.named_modules(): + # we may want to tune dropout for smaller dataset + if isinstance(module, torch.nn.Dropout): + if module.p != opts.dropout: + module.p = opts.dropout + LOGGER.info(f'{name} set to {opts.dropout}') + model.to(device) + + # make sure every process has same model params in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + + # Prepare optimizer + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + # currently Adam only + if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, + opt_level='O2') + + global_step = 0 + LOGGER.info("***** Running training *****") + LOGGER.info(" Num examples = %d", len(train_dataset)) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + train_sampler = DistributedSampler( + train_dataset, num_replicas=n_gpu, rank=rank, shuffle=False) + train_dataloader = DataLoader(train_dataset, + sampler=train_sampler, + batch_size=opts.train_batch_size, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=re_collate) + train_dataloader = PrefetchLoader(train_dataloader) + + val_sampler = DistributedSampler( + val_dataset, num_replicas=n_gpu, rank=rank, shuffle=False) + val_dataloader = DataLoader(val_dataset, + sampler=val_sampler, + batch_size=opts.val_batch_size, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=re_eval_collate) + val_dataloader = PrefetchLoader(val_dataloader) + + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt'), 'model_epoch') + os.makedirs(join(opts.output_dir, 'results')) # store ITM predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + running_loss = RunningMeter(opts.train_loss) + n_examples = 0 + n_epoch = 0 + best_val_acc, best_epoch = None, None + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + + while True: + model.train() + for step, batch in enumerate(train_dataloader): + if global_step >= opts.num_train_steps: + break + + *_, 
targets = batch + n_examples += targets.size(0) + loss = model(*batch, compute_loss=True) + loss = loss.sum() # sum over vectorized loss TODO: investigate + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + if opts.decay == 'linear': + lr_this_step = opts.learning_rate * warmup_linear( + global_step, opts.warmup_steps, opts.num_train_steps) + elif opts.decay == 'invsqrt': + lr_this_step = opts.learning_rate * noam_schedule( + global_step, opts.warmup_steps) + elif opts.decay == 'constant': + lr_this_step = opts.learning_rate + if lr_this_step < 0: + # save guard for possible miscalculation of train steps + lr_this_step = 1e-8 + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + opts.train_loss, sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss_'+opts.train_loss, running_loss.val, + global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 5 == 0: + torch.cuda.empty_cache() + if global_step % 100 == 0: + # monitor training throughput + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + # evaluate after each epoch + val_log, _ = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + + # save model + n_epoch += 1 + model_saver.save(model, n_epoch) + LOGGER.info(f"finished {n_epoch} epochs") + + # save best model + if best_val_acc is None or val_log['valid/acc'] > best_val_acc: + best_val_acc = val_log['valid/acc'] + best_epoch = n_epoch + model_saver.save(model, 'best') + + # shuffle training data for the next epoch + train_dataloader.loader.dataset.shuffle() + + # is training finished? 
+ if global_step >= opts.num_train_steps: + break + + val_log, results = validate(model, val_dataloader) + with open(f'{opts.output_dir}/results/' + f'results_{global_step}_' + f'rank{rank}_final.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + # print best model + LOGGER.info(f'best_val_acc = {best_val_acc*100:.2f}% ' + f'at epoch {best_epoch}.') + + +@torch.no_grad() +def validate(model, val_dataloader): + LOGGER.info(f"start running evaluation.") + model.eval() + tot_score = 0 + n_ex = 0 + st = time() + predictions = {} + for i, batch in enumerate(val_dataloader): + # inputs + (*batch_inputs, tgt_box_list, obj_boxes_list, sent_ids) = batch + + # scores (n, max_num_bb) + scores = model(*batch_inputs, targets=None, compute_loss=False) + ixs = torch.argmax(scores, 1).cpu().detach().numpy() # (n, ) + + # pred_boxes (keyed by each sentence id, not the literal string) + for ix, obj_boxes, tgt_box, sent_id in \ + zip(ixs, obj_boxes_list, tgt_box_list, sent_ids): + pred_box = obj_boxes[ix] + predictions[sent_id] = {'pred_box': pred_box.tolist(), + 'tgt_box': tgt_box.tolist()} + if (val_dataloader.loader.dataset.computeIoU(pred_box, tgt_box) + > .5): + tot_score += 1 + n_ex += 1 + + tot_time = time()-st + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + val_acc = tot_score / n_ex + val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation ({n_ex} sents) finished in " + f"{int(tot_time)} seconds" + f", accuracy: {val_acc*100:.2f}%") + return val_log, predictions + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--train_txt_db", + default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--train_img_dir", + default=None, type=str, + help="The input train images.") + parser.add_argument("--val_txt_db", + default=None, type=str, + help="The input validation corpus. (LMDB)") + parser.add_argument("--val_img_dir", + default=None, type=str, + help="The input validation images.") + parser.add_argument('--img_format', default='npz', + choices=['npz', 'lmdb', 'lmdb-compress'], + help='format of image feature') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--cut_bert", default=-1, type=int, + help="reduce BERT layers (-1 for original depth)") + parser.add_argument("--mlp", default=1, type=int, + help="number of MLP layers for RE output") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + + # training parameters + parser.add_argument("--train_batch_size", + default=128, type=int, + help="Total batch size for training. " + "(batch by examples)") + parser.add_argument("--val_batch_size", + default=256, type=int, + help="Total batch size for validation. 
" + "(batch by tokens)") + parser.add_argument("--train_loss", + default="cls", type=str, + choices=['cls', 'rank'], + help="loss to used during training") + parser.add_argument("--margin", + default=0.2, type=float, + help="margin of ranking loss") + parser.add_argument("--hard_ratio", + default=0.3, type=float, + help="sampling ratio of hard negatives") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--num_train_steps", + default=32000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', type=float, + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant'], + help="learning rate decay method") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. (invsqrt decay)") + + # device parameters + parser.add_argument('--seed', + type=int, + default=24, + help="random seed for initialization") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + main(args) diff --git a/uniter_model/train_vcr.py b/uniter_model/train_vcr.py new file mode 100644 index 0000000..1c44ee0 --- /dev/null +++ b/uniter_model/train_vcr.py @@ -0,0 +1,604 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""BERT pre-training runner.""" +import argparse +import json +import os +from os.path import exists, join +import random +from time import time + +import torch +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ +from torch.optim import Adam, Adamax +from torch.utils.data import DataLoader, ConcatDataset + +from apex import amp +from horovod import torch as hvd + +import numpy as np +from tqdm import tqdm + +from data import (DistributedTokenBucketSampler, + DetectFeatLmdb, VcrDataset, VcrEvalDataset, + vcr_collate, vcr_eval_collate, + PrefetchLoader) +from model import BertForVisualCommonsenseReasoning +from optim import warmup_linear, noam_schedule, vqa_schedule, AdamW +from torch.utils.data.distributed import DistributedSampler + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config +NUM_SPECIAL_TOKENS = 81 + + +def load_img_feat(dir_list, path2imgdir, opts): + dir_ = dir_list.split(";") + assert len(dir_) <= 2, "More than two img_dirs found" + img_dir_gt, img_dir = None, None + gt_dir_path, dir_path = "", "" + for d in dir_: + if "gt" in d: + gt_dir_path = d + else: + dir_path = d + if gt_dir_path != "": + img_dir_gt = path2imgdir.get(gt_dir_path, None) + if img_dir_gt is None: + img_dir_gt = DetectFeatLmdb(gt_dir_path, -1, + opts.max_bb, opts.min_bb, 100, + opts.compressed_db) + path2imgdir[gt_dir_path] = img_dir_gt + if dir_path != "": + img_dir = path2imgdir.get(dir_path, None) + if img_dir is None: + img_dir = DetectFeatLmdb(dir_path, opts.conf_th, + opts.max_bb, opts.min_bb, opts.num_bb, + opts.compressed_db) + path2imgdir[dir_path] = img_dir + return img_dir, img_dir_gt, path2imgdir + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + random.seed(opts.seed) + np.random.seed(opts.seed) + torch.manual_seed(opts.seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(opts.seed) + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, " + f"{opts.train_img_dir}") + + # load DBs and image dirs + train_txt_dbs = opts.train_txt_db.split(':') + train_img_dirs = opts.train_img_dir.split(':') + path2imgdir = {} + train_datasets = [] + for db, dir_list in zip(train_txt_dbs, train_img_dirs): + img_dir, img_dir_gt, path2imgdir = load_img_feat( + dir_list, path2imgdir, opts) + train_datasets.append(VcrDataset(opts.mask_prob, db, img_dir_gt, + img_dir, + opts.max_txt_len, task="qa")) + train_datasets.append(VcrDataset(opts.mask_prob, db, img_dir_gt, + img_dir, + opts.max_txt_len, task="qar")) + train_dataset = ConcatDataset(train_datasets) + train_lens = [l for dset in train_datasets for l in dset.lens] + val_img_dir, val_img_dir_gt, path2imgdir = load_img_feat( + opts.val_img_dir, path2imgdir, opts) + val_dataset = VcrEvalDataset("val", opts.val_txt_db, + val_img_dir_gt, val_img_dir, + max_txt_len=-1) + val_final_dataset = VcrEvalDataset("test", opts.val_txt_db, + val_img_dir_gt, 
val_img_dir, + max_txt_len=-1) + + # Prepare model + train_txt_db = train_txt_dbs[0] + emb_file = f'{train_txt_db}/embedding.pt' + + if opts.checkpoint and opts.checkpoint_from == "pretrain": + if opts.checkpoint == 'google-bert': + checkpoint = None + else: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + bert_model = json.load(open(f'{train_txt_db}/meta.json'))['bert'] + if 'bert' not in bert_model: + bert_model = 'bert-large-cased' # quick hack for glove exp + model = BertForVisualCommonsenseReasoning.from_pretrained( + bert_model, img_dim=2048, obj_cls=False, + state_dict=checkpoint) + model.init_type_embedding() + model.init_word_embedding(NUM_SPECIAL_TOKENS) + if opts.checkpoint_from == "vcr": + checkpoint = torch.load(opts.checkpoint) + state_dict = checkpoint.get('model_state', checkpoint) + matched_state_dict = {} + unexpected_keys = set() + missing_keys = set() + for name, param in model.named_parameters(): + missing_keys.add(name) + for key, data in state_dict.items(): + if key in missing_keys: + matched_state_dict[key] = data + missing_keys.remove(key) + else: + unexpected_keys.add(key) + print("Unexpected_keys:", list(unexpected_keys)) + print("Missing_keys:", list(missing_keys)) + model.load_state_dict(matched_state_dict, strict=False) + if opts.cut_bert != -1: + # cut some layers of BERT + model.bert.encoder.layer = torch.nn.ModuleList( + model.bert.encoder.layer[:opts.cut_bert]) + if exists(emb_file) and not opts.checkpoint: + glove = torch.load(f'{train_txt_db}/embedding.pt') + vsize = glove.size(0) + hid_size = model.config.hidden_size + model.bert.embeddings.word_embeddings = torch.nn.Embedding( + vsize, hid_size) + mul_ = hid_size // 300 + 1 + model.bert.embeddings.word_embeddings.weight.data = glove.repeat( + 1, mul_)[:, :hid_size] + LOGGER.info('using GloVe for BERT') + del checkpoint + for name, module in model.named_modules(): + # we might want to tune dropout for smaller dataset + if isinstance(module, torch.nn.Dropout): + if module.p != opts.dropout: + module.p = opts.dropout + LOGGER.info(f'{name} set to {opts.dropout}') + model.to(device) + if rank != -1: + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + + # Prepare optimizer + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + train_sampler = DistributedTokenBucketSampler( + n_gpu, rank, train_lens, bucket_size=8192, + batch_size=opts.train_batch_size, droplast=True) + val_sampler = DistributedSampler( + val_dataset, num_replicas=n_gpu, rank=rank) + val_final_sampler = DistributedSampler( + val_final_dataset, num_replicas=n_gpu, rank=rank) + train_dataloader = DataLoader(train_dataset, + batch_sampler=train_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=vcr_collate) + 
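Because --train_batch_size counts tokens rather than examples ("batch by tokens"), the DistributedTokenBucketSampler constructed above groups examples of similar length and emits batches whose padded token count stays under that budget. The sampler itself is defined in data/sampler.py and is not shown in this hunk; the function below is only a simplified, single-process sketch of the packing idea (the real sampler additionally shuffles buckets and shards them across Horovod ranks):

def token_bucket_batches(lens, batch_size_tokens=4096, bucket_size=8192,
                         droplast=True):
    # sort indices inside coarse buckets so each batch holds similarly
    # sized examples, then pack until max_len * n_examples would exceed
    # the token budget
    ids = list(range(len(lens)))
    buckets = [sorted(ids[i:i + bucket_size], key=lambda j: lens[j])
               for i in range(0, len(ids), bucket_size)]
    batches, batch, max_len = [], [], 0
    for bucket in buckets:
        for j in bucket:
            max_len = max(max_len, lens[j])
            if batch and max_len * (len(batch) + 1) > batch_size_tokens:
                batches.append(batch)
                batch, max_len = [], lens[j]
            batch.append(j)
    if batch and not droplast:
        batches.append(batch)
    return batches

# e.g. (names from the surrounding code): token_bucket_batches(train_lens,
# opts.train_batch_size) yields index batches of roughly equal token cost.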
train_dataloader = PrefetchLoader(train_dataloader) + val_dataloader = DataLoader(val_dataset, + batch_size=opts.val_batch_size*3, + sampler=val_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=vcr_eval_collate) + val_final_dataloader = DataLoader(val_final_dataset, + batch_size=opts.val_batch_size, + sampler=val_final_sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=vcr_eval_collate) + val_dataloader = PrefetchLoader(val_dataloader) + val_final_dataloader = PrefetchLoader(val_final_dataloader) + + global_step = 0 + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + os.makedirs(join(opts.output_dir, 'results')) # store VQA predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", len(train_dataset)) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + running_vcr_loss = RunningMeter('vcr_loss') + running_obj_loss = RunningMeter('obj_cls_loss') + running_loss = RunningMeter('loss') + model.train() + n_examples = 0 + n_epoch = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for step, batch in enumerate(train_dataloader): + *_, targets = batch + n_examples += targets.size(0) + + vcr_loss, obj_cls_loss = model(*batch, compute_loss=True) + # loss = loss.mean() + loss = vcr_loss + obj_cls_loss + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + running_vcr_loss(vcr_loss.item()) + running_obj_loss(obj_cls_loss.item()) + + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + if opts.decay == 'linear': + lr_this_step = opts.learning_rate * warmup_linear( + global_step, opts.warmup_steps, opts.num_train_steps) + elif opts.decay == 'invsqrt': + lr_this_step = opts.learning_rate * noam_schedule( + global_step, opts.warmup_steps) + elif opts.decay == 'constant': + lr_this_step = opts.learning_rate + elif opts.decay == 'vqa': + lr_this_step = opts.learning_rate * vqa_schedule( + global_step, opts.warm_int, opts.decay_int, + opts.decay_st, opts.decay_rate) + if lr_this_step < 0: + # save guard for possible miscalculation of train steps + lr_this_step = 1e-8 + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + + vcr_losses = all_gather_list(running_vcr_loss) + running_vcr_loss = RunningMeter( + 'vcr_loss', sum(l.val for l in 
vcr_losses)/len(vcr_losses)) + TB_LOGGER.add_scalar('vcr_loss', running_vcr_loss.val, + global_step) + + obj_losses = all_gather_list(running_obj_loss) + running_obj_loss = RunningMeter( + 'obj_cls_loss', + sum(l.val for l in obj_losses)/len(obj_losses)) + TB_LOGGER.add_scalar('obj_cls_loss', running_obj_loss.val, + global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 5 == 0: + torch.cuda.empty_cache() + if global_step % 100 == 0: + # monitor training throughput + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + if global_step % opts.valid_steps == 0: + val_log, results = validate( + model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step >= opts.num_train_steps: + break + n_epoch += 1 + LOGGER.info(f"finished {n_epoch} epochs") + val_log, results = validate( + model, val_final_dataloader) + with open(f'{opts.output_dir}/results/' + f'results_{global_step}_' + f'rank{rank}.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + +def compute_accuracies(out_qa, labels_qa, out_qar, labels_qar): + outputs_qa = out_qa.max(dim=-1)[1] + outputs_qar = out_qar.max(dim=-1)[1] + matched_qa = outputs_qa.squeeze() == labels_qa.squeeze() + matched_qar = outputs_qar.squeeze() == labels_qar.squeeze() + matched_joined = matched_qa & matched_qar + n_correct_qa = matched_qa.sum().item() + n_correct_qar = matched_qar.sum().item() + n_correct_joined = matched_joined.sum().item() + return n_correct_qa, n_correct_qar, n_correct_joined + + +@torch.no_grad() +def validate(model, val_loader): + if hvd.rank() == 0: + val_pbar = tqdm(total=len(val_loader)) + else: + val_pbar = NoOp() + LOGGER.info(f"start running evaluation ...") + model.eval() + val_qa_loss, val_qar_loss = 0, 0 + tot_qa_score, tot_qar_score, tot_score = 0, 0, 0 + n_ex = 0 + st = time() + results = {} + for i, batch in enumerate(val_loader): + qids, *inputs, qa_targets, qar_targets, _ = batch + scores = model( + *inputs, targets=None, compute_loss=False) + scores = scores.view(len(qids), -1) + vcr_qa_loss = F.cross_entropy( + scores[:, :4], qa_targets.squeeze(-1), reduction="sum") + if scores.shape[1] > 8: + qar_index = [4+answer_ind.item()*4+i for answer_ind in qa_targets + for i in range(4)] + qar_scores = scores[:, qar_index] + else: + qar_scores = scores[:, 4:] + vcr_qar_loss = F.cross_entropy( + qar_scores, qar_targets.squeeze(-1), reduction="sum") + val_qa_loss += vcr_qa_loss.item() + val_qar_loss += vcr_qar_loss.item() + curr_qa_score, curr_qar_score, curr_score = compute_accuracies( + scores[:, :4], qa_targets, qar_scores, qar_targets) + tot_qar_score += curr_qar_score + tot_qa_score += curr_qa_score + tot_score += curr_score + for qid, score in zip(qids, scores): + results[qid] = score.cpu().tolist() + n_ex += len(qids) + val_pbar.update(1) + val_qa_loss = sum(all_gather_list(val_qa_loss)) + val_qar_loss = sum(all_gather_list(val_qar_loss)) + tot_qa_score = sum(all_gather_list(tot_qa_score)) + tot_qar_score = 
sum(all_gather_list(tot_qar_score)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_qa_loss /= n_ex + val_qar_loss /= n_ex + val_qa_acc = tot_qa_score / n_ex + val_qar_acc = tot_qar_score / n_ex + val_acc = tot_score / n_ex + val_log = {f'valid/vcr_qa_loss': val_qa_loss, + f'valid/vcr_qar_loss': val_qar_loss, + f'valid/acc_qa': val_qa_acc, + f'valid/acc_qar': val_qar_acc, + f'valid/acc': val_acc, + f'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score_qa: {val_qa_acc*100:.2f} " + f"score_qar: {val_qar_acc*100:.2f} " + f"score: {val_acc*100:.2f} ") + return val_log, results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--task", + default="qa", type=str, + choices=['qa', 'qar'], + help="VCR tasks: qa or qar") + parser.add_argument("--train_txt_db", + default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--train_img_dir", + default=None, type=str, + help="The input train images.") + parser.add_argument("--val_txt_db", + default=None, type=str, + help="The input validation corpus. (LMDB)") + parser.add_argument("--val_img_dir", + default=None, type=str, + help="The input validation images.") + parser.add_argument('--img_format', default='npz', + choices=['npz', 'lmdb', 'lmdb-compress'], + help='format of image feature') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + parser.add_argument("--checkpoint_from", + default='pretrain', type=str, + choices=['pretrain', 'vcr'], + help="which setting is checkpoint from") + parser.add_argument("--cut_bert", default=-1, type=int, + help="reduce BERT layers (-1 for original depth)") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", + default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", + default=4096, type=int, + help="Total batch size for validation. 
" + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", + default=1000, + type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", + default=100000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument('--mask_prob', default=0.15, type=float, + help='probability to mask in MRC training') + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant', 'vqa'], + help="learning rate decay method") + parser.add_argument("--decay_int", default=2000, type=int, + help="interval between VQA lr decy") + parser.add_argument("--warm_int", default=2000, type=int, + help="interval for VQA lr warmup") + parser.add_argument("--decay_st", default=20000, type=int, + help="when to start decay") + parser.add_argument("--decay_rate", default=0.2, type=float, + help="ratio of lr decay") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. (invsqrt decay)") + + # device parameters + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + # TODO + + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + main(args) diff --git a/uniter_model/train_ve.py b/uniter_model/train_ve.py new file mode 100644 index 0000000..ab44349 --- /dev/null +++ b/uniter_model/train_ve.py @@ -0,0 +1,413 @@ +# coding=utf-8 +# copied from hugginface github +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT pre-training runner.""" +import argparse +import json +import os +from os.path import exists, join +import pickle +from time import time + +import torch +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader + +from apex import amp +from horovod import torch as hvd + +from tqdm import tqdm + +from data import (TokenBucketSampler, PrefetchLoader, + DetectFeatLmdb, TxtTokLmdb, + VeDataset, VeEvalDataset, + ve_collate, ve_eval_collate) +from model import UniterForVisualEntailment +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.misc import VE_ENT2IDX as ans2label +from utils.misc import VE_IDX2ENT as label2ans +from utils.const import IMG_DIM, BUCKET_SIZE + + +def create_dataloader(img_path, txt_path, batch_size, is_train, + dset_cls, collate_fn, opts): + img_db = DetectFeatLmdb(img_path, opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len if is_train else -1) + dset = dset_cls(txt_db, img_db) + sampler = TokenBucketSampler(dset.lens, bucket_size=BUCKET_SIZE, + batch_size=batch_size, droplast=is_train) + loader = DataLoader(dset, batch_sampler=sampler, + num_workers=opts.n_workers, pin_memory=opts.pin_mem, + collate_fn=collate_fn) + return PrefetchLoader(loader) + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + set_random_seed(opts.seed) + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, " + f"{opts.train_img_db}") + train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db, + opts.train_batch_size, True, + VeDataset, ve_collate, opts) + val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db, + opts.val_batch_size, False, + VeEvalDataset, ve_eval_collate, opts) + test_dataloader = create_dataloader(opts.test_img_db, opts.test_txt_db, + opts.val_batch_size, False, + VeEvalDataset, ve_eval_collate, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + bert_model = json.load(open(f'{opts.train_txt_db}/meta.json'))['bert'] + if 'bert' not in bert_model: + bert_model = 'bert-large-cased' # quick hack for glove exp + model = UniterForVisualEntailment.from_pretrained( + opts.model_config, state_dict=checkpoint, img_dim=IMG_DIM) + 
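The training loop further down follows the same mixed-precision recipe as the other fine-tuning scripts in this patch: apex amp scales the loss, gradients are accumulated for --gradient_accumulation_steps micro-batches with delay_unscale=True, and only on the last micro-batch are they all-reduced across Horovod workers, clipped, and applied. The function below is a condensed sketch of that recipe with hypothetical model/loader names; it assumes amp.initialize(model, optimizer, ...) has already been called, as it is a few lines below:

from apex import amp
from torch.nn.utils import clip_grad_norm_


def accumulation_steps(model, loader, optimizer, accum_steps=16,
                       grad_norm=0.25):
    # hedged sketch of the shared fine-tuning loop; the loss computation
    # and the cross-worker gradient reduction are setup specific
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        loss = model(batch, compute_loss=True).mean()
        # keep the loss scale fixed while gradients are still accumulating
        delay_unscale = (step + 1) % accum_steps != 0
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
        if delay_unscale:
            continue  # not yet time for an optimizer update
        # (the real loops all-reduce p.grad across Horovod workers here,
        #  before unscaling, via all_reduce_and_rescale_tensors)
        if grad_norm != -1:
            clip_grad_norm_(amp.master_params(optimizer), grad_norm)
        optimizer.step()
        optimizer.zero_grad()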
model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + + global_step = 0 + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + pickle.dump(ans2label, + open(join(opts.output_dir, 'ckpt', 'ans2label.pkl'), 'wb')) + os.makedirs(join(opts.output_dir, 'results')) # store VQA predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", len(train_dataloader.dataset)) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + running_loss = RunningMeter('loss') + model.train() + n_examples = 0 + n_epoch = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for step, batch in enumerate(train_dataloader): + n_examples += batch['input_ids'].size(0) + + loss = model(batch, compute_loss=True) + loss = loss.mean() * batch['targets'].size(1) # instance-leval bce + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'============Step {global_step}=============') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + LOGGER.info(f'===========================================') + + if global_step % opts.valid_steps == 0: + for split, loader in [("val", val_dataloader), + ("test", test_dataloader)]: + LOGGER.info(f"Step {global_step}: start running " + f"validation on {split} split...") + val_log, results = validate( + model, loader, label2ans, split) + with 
open(f'{opts.output_dir}/results/' + f'{split}_results_{global_step}_' + f'rank{rank}.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step >= opts.num_train_steps: + break + n_epoch += 1 + LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs") + for split, loader in [("val", val_dataloader), + ("test", test_dataloader)]: + LOGGER.info(f"Step {global_step}: start running " + f"validation on {split} split...") + val_log, results = validate(model, loader, label2ans, split) + with open(f'{opts.output_dir}/results/' + f'{split}_results_{global_step}_' + f'rank{rank}_final.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + +@torch.no_grad() +def validate(model, val_loader, label2ans, split='val'): + model.eval() + val_loss = 0 + tot_score = 0 + n_ex = 0 + st = time() + results = {} + for i, batch in enumerate(val_loader): + scores = model(batch, compute_loss=False) + targets = batch['targets'] + loss = F.binary_cross_entropy_with_logits( + scores, targets, reduction='sum') + val_loss += loss.item() + tot_score += compute_score_with_logits(scores, targets).sum().item() + answers = [label2ans[i] + for i in scores.max(dim=-1, keepdim=False + )[1].cpu().tolist()] + qids = batch['qids'] + for qid, answer in zip(qids, answers): + results[qid] = answer + n_ex += len(qids) + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_loss /= n_ex + val_acc = tot_score / n_ex + val_log = {f'valid/{split}_loss': val_loss, + f'valid/{split}_acc': val_acc, + f'valid/{split}_ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log, results + + +def compute_score_with_logits(logits, labels): + logits = torch.max(logits, 1)[1] # argmax + one_hots = torch.zeros(*labels.size(), device=labels.device) + one_hots.scatter_(1, logits.view(-1, 1), 1) + scores = (one_hots * labels) + return scores + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--train_txt_db", + default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--train_img_db", + default=None, type=str, + help="The input train images.") + parser.add_argument("--val_txt_db", + default=None, type=str, + help="The input validation corpus. (LMDB)") + parser.add_argument("--val_img_db", + default=None, type=str, + help="The input validation images.") + parser.add_argument("--test_txt_db", + default=None, type=str, + help="The input test corpus. 
(LMDB)") + parser.add_argument("--test_img_db", + default=None, type=str, + help="The input test images.") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--model_config", + default=None, type=str, + help="json file for model architecture") + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model (can take 'google-bert') ") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", + default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", + default=4096, type=int, + help="Total batch size for validation. " + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", + default=3e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", + default=1000, + type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", + default=100000, + type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant'], + help="learning rate decay method") + parser.add_argument("--dropout", + default=0.1, + type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", + default=0.0, + type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", + default=0.25, + type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", + default=4000, + type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. 
(invsqrt decay)") + + # device parameters + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + # TODO + + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + main(args) diff --git a/uniter_model/train_vqa.py b/uniter_model/train_vqa.py new file mode 100644 index 0000000..749749c --- /dev/null +++ b/uniter_model/train_vqa.py @@ -0,0 +1,415 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +UNITER finetuning for VQA +""" +import argparse +import json +import os +from os.path import abspath, dirname, exists, join +from time import time + +import torch +from torch.nn import functional as F +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader +from torch.optim import Adam, Adamax + +from apex import amp +from horovod import torch as hvd + +from tqdm import tqdm + +from data import (TokenBucketSampler, PrefetchLoader, + TxtTokLmdb, ImageLmdbGroup, ConcatDatasetWithLens, + VqaDataset, VqaEvalDataset, + vqa_collate, vqa_eval_collate) +from model import UniterForVisualQuestionAnswering +from optim import AdamW, get_lr_sched + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import BUCKET_SIZE, IMG_DIM + + +def build_dataloader(dataset, collate_fn, is_train, opts): + batch_size = (opts.train_batch_size if is_train + else opts.val_batch_size) + sampler = TokenBucketSampler(dataset.lens, bucket_size=BUCKET_SIZE, + batch_size=batch_size, droplast=is_train) + dataloader = DataLoader(dataset, batch_sampler=sampler, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def build_optimizer(model, opts): + """ vqa linear may get larger learning rate """ + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + param_optimizer = [(n, p) for n, p in model.named_parameters() + if 'vqa_output' not in n] + param_top = [(n, p) for n, p in model.named_parameters() + if 'vqa_output' in n] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_top + if not any(nd in n for nd in no_decay)], + 'lr': opts.learning_rate, + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_top + if any(nd in n for nd in no_decay)], + 'lr': opts.learning_rate, + 'weight_decay': 0.0}, + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + # currently Adam only + 
if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + return optimizer + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, " + "should be >= 1".format( + opts.gradient_accumulation_steps)) + + set_random_seed(opts.seed) + + ans2label = json.load(open(f'{dirname(abspath(__file__))}' + f'/misc/ans2label.json')) + label2ans = {label: ans for ans, label in ans2label.items()} + + # load DBs and image dirs + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + # train + LOGGER.info(f"Loading Train Dataset " + f"{opts.train_txt_dbs}, {opts.train_img_dbs}") + train_datasets = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db)) + train_dataset = ConcatDatasetWithLens(train_datasets) + train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts) + # val + LOGGER.info(f"Loading Train Dataset {opts.val_txt_db}, {opts.val_img_db}") + val_img_db = all_img_dbs[opts.val_img_db] + val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) + val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db) + val_dataloader = build_dataloader(val_dataset, vqa_eval_collate, + False, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + all_dbs = opts.train_txt_dbs + [opts.val_txt_db] + toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert'] + assert all(toker == json.load(open(f'{db}/meta.json'))['bert'] + for db in all_dbs) + model = UniterForVisualQuestionAnswering.from_pretrained( + opts.model_config, checkpoint, + img_dim=IMG_DIM, num_answer=len(ans2label)) + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + model, optimizer = amp.initialize(model, optimizer, + enabled=opts.fp16, opt_level='O2') + global_step = 0 + if rank == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + json.dump(ans2label, + open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w')) + os.makedirs(join(opts.output_dir, 'results')) # store VQA predictions + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", len(train_dataset) * hvd.size()) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + 
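build_optimizer above puts the freshly initialized vqa_output head into the first two parameter groups so that, inside the loop below, its learning rate can be scaled by --lr_mul while the pretrained UNITER trunk (groups 2 and 3) keeps the base rate. The snippet below is a minimal standalone illustration of that head-multiplier idea using plain torch.optim.Adam and a hypothetical model; the real code additionally splits each group by weight decay, which is why the loop addresses four groups:

import torch


def build_grouped_optimizer(model, lr=3e-5, lr_mul=1.0):
    # separate the answer head so its learning rate can be scaled
    # independently of the pretrained trunk
    head = [p for n, p in model.named_parameters() if 'vqa_output' in n]
    trunk = [p for n, p in model.named_parameters() if 'vqa_output' not in n]
    return torch.optim.Adam([{'params': head, 'lr': lr * lr_mul},
                             {'params': trunk, 'lr': lr}])


def set_lr(optimizer, lr_this_step, lr_mul=1.0):
    # mirrors the per-group update in the training loop: head group first,
    # scaled by lr_mul; trunk group at the base schedule value
    optimizer.param_groups[0]['lr'] = lr_this_step * lr_mul
    optimizer.param_groups[1]['lr'] = lr_this_step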
running_loss = RunningMeter('loss') + model.train() + n_examples = 0 + n_epoch = 0 + start = time() + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + while True: + for step, batch in enumerate(train_dataloader): + n_examples += batch['input_ids'].size(0) + + loss = model(batch, compute_loss=True) + loss = loss.mean() * batch['targets'].size(1) # instance-leval bce + delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + running_loss(loss.item()) + + if (step + 1) % opts.gradient_accumulation_steps == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for i, param_group in enumerate(optimizer.param_groups): + if i == 0 or i == 1: + param_group['lr'] = lr_this_step * opts.lr_mul + elif i == 2 or i == 3: + param_group['lr'] = lr_this_step + else: + raise ValueError() + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + losses = all_gather_list(running_loss) + running_loss = RunningMeter( + 'loss', sum(l.val for l in losses)/len(losses)) + TB_LOGGER.add_scalar('loss', running_loss.val, global_step) + TB_LOGGER.step() + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 100 == 0: + # monitor training throughput + LOGGER.info(f'============Step {global_step}=============') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + LOGGER.info(f'{tot_ex} examples trained at ' + f'{ex_per_sec} ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + LOGGER.info(f'===========================================') + + if global_step % opts.valid_steps == 0: + val_log, results = validate( + model, val_dataloader, label2ans) + with open(f'{opts.output_dir}/results/' + f'results_{global_step}_' + f'rank{rank}.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step) + if global_step >= opts.num_train_steps: + break + if global_step >= opts.num_train_steps: + break + n_epoch += 1 + LOGGER.info(f"finished {n_epoch} epochs") + val_log, results = validate(model, val_dataloader, label2ans) + with open(f'{opts.output_dir}/results/' + f'results_{global_step}_' + f'rank{rank}_final.json', 'w') as f: + json.dump(results, f) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + +@torch.no_grad() +def validate(model, val_loader, label2ans): + LOGGER.info("start running validation...") + model.eval() + val_loss = 0 + tot_score = 0 + n_ex = 0 + st = time() + results = {} + for i, batch in enumerate(val_loader): + scores = model(batch, compute_loss=False) + targets = batch['targets'] + loss = F.binary_cross_entropy_with_logits( + scores, targets, reduction='sum') + val_loss += loss.item() + tot_score += compute_score_with_logits(scores, targets).sum().item() + answers = [label2ans[i] + for i in scores.max(dim=-1, keepdim=False + 
)[1].cpu().tolist()] + for qid, answer in zip(batch['qids'], answers): + results[qid] = answer + n_ex += len(batch['qids']) + val_loss = sum(all_gather_list(val_loss)) + tot_score = sum(all_gather_list(tot_score)) + n_ex = sum(all_gather_list(n_ex)) + tot_time = time()-st + val_loss /= n_ex + val_acc = tot_score / n_ex + val_log = {'valid/loss': val_loss, + 'valid/acc': val_acc, + 'valid/ex_per_s': n_ex/tot_time} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"score: {val_acc*100:.2f}") + return val_log, results + + +def compute_score_with_logits(logits, labels): + logits = torch.max(logits, 1)[1] # argmax + one_hots = torch.zeros(*labels.size(), device=labels.device) + one_hots.scatter_(1, logits.view(-1, 1), 1) + scores = (one_hots * labels) + return scores + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + # TODO datasets + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--model_config", + default=None, type=str, + help="json file for model architecture") + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained model") + + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be " + "written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", default=4096, type=int, + help="Total batch size for training. " + "(batch by tokens)") + parser.add_argument("--val_batch_size", default=4096, type=int, + help="Total batch size for validation. 
" + "(batch by tokens)") + parser.add_argument('--gradient_accumulation_steps', type=int, default=16, + help="Number of updates steps to accumualte before " + "performing a backward/update pass.") + parser.add_argument("--learning_rate", default=3e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--lr_mul", default=1.0, type=float, + help="multiplier for top layer lr") + parser.add_argument("--valid_steps", default=1000, type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", default=100000, type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--decay", default='linear', + choices=['linear', 'invsqrt', 'constant', 'vqa'], + help="learning rate decay method") + parser.add_argument("--decay_int", default=2000, type=int, + help="interval between VQA lr decy") + parser.add_argument("--warm_int", default=2000, type=int, + help="interval for VQA lr warmup") + parser.add_argument("--decay_st", default=20000, type=int, + help="when to start decay") + parser.add_argument("--decay_rate", default=0.2, type=float, + help="ratio of lr decay") + parser.add_argument("--dropout", default=0.1, type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", default=0.0, type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", default=2.0, type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", default=4000, type=int, + help="Number of training steps to perform linear " + "learning rate warmup for. 
(invsqrt decay)") + + # device parameters + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + if exists(args.output_dir) and os.listdir(args.output_dir): + raise ValueError("Output directory ({}) already exists and is not " + "empty.".format(args.output_dir)) + + # options safe guard + # TODO + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + main(args) diff --git a/uniter_model/utils/__init__.py b/uniter_model/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/uniter_model/utils/const.py b/uniter_model/utils/const.py new file mode 100644 index 0000000..1286e45 --- /dev/null +++ b/uniter_model/utils/const.py @@ -0,0 +1,4 @@ +""" constants """ +IMG_DIM = 2048 +IMG_LABEL_DIM = 1601 +BUCKET_SIZE = 8192 diff --git a/uniter_model/utils/distributed.py b/uniter_model/utils/distributed.py new file mode 100644 index 0000000..6264815 --- /dev/null +++ b/uniter_model/utils/distributed.py @@ -0,0 +1,230 @@ +""" +distributed API using Horovod +""" +import math +import pickle + +import torch +from horovod import torch as hvd + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def all_reduce_and_rescale_tensors(tensors, rescale_denom): + """All-reduce and rescale tensors at once (as a flattened tensor) + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + """ + # buffer size in bytes, determine equiv. # of elements based on data type + sz = sum(t.numel() for t in tensors) + buffer_t = tensors[0].new(sz).zero_() + + # copy tensors into buffer_t + offset = 0 + for t in tensors: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + hvd.allreduce_(buffer_t[:offset]) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in tensors: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + +def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, + buffer_size=10485760): + """All-reduce and rescale tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + buffer_size: all-reduce chunk size in bytes + """ + # buffer size in bytes, determine equiv. 
# of elements based on data type + buffer_t = tensors[0].new( + math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def all_reduce_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + hvd.allreduce_(buffer_t[:offset]) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, all-reduce and rescale directly + hvd.allreduce_(t) + t.div_(rescale_denom) + elif filled + sz > buffer_size: + # buffer is full, all-reduce and replace buffer with grad + all_reduce_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + all_reduce_buffer() + + +def broadcast_tensors(tensors, root_rank, buffer_size=10485760): + """broadcast tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to broadcast + root_rank: rank to broadcast + buffer_size: broadcast chunk size in bytes + """ + # buffer size in bytes, determine equiv. # of elements based on data type + buffer_t = tensors[0].new( + math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def broadcast_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # broadcast + hvd.broadcast_(buffer_t[:offset], root_rank) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, broadcast directly + hvd.broadcast_(t, root_rank) + elif filled + sz > buffer_size: + # buffer is full, broadcast and replace buffer with tensor + broadcast_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + broadcast_buffer() + + +def _encode(enc, max_size, buffer_=None): + enc_size = len(enc) + enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) + if buffer_ is None or len(buffer_) < enc_size + enc_byte: + buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) + remainder = enc_size + for i in range(enc_byte): + base = 256 ** (enc_byte-i-1) + buffer_[i] = remainder // base + remainder %= base + buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) + return buffer_, enc_byte + + +def _decode(buffer_, enc_byte): + size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() + for i in range(enc_byte)) + bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) + shift = size + enc_byte + return bytes_list, shift + + +_BUFFER_SIZE = 4096 + + +def all_gather_list(data): + """Gathers arbitrary data from all nodes into a list.""" + if not hasattr(all_gather_list, '_buffer'): + # keeps small buffer to avoid re-allocate every call + all_gather_list._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE) + try: + enc = msgpack.dumps(data, use_bin_type=True) + msgpack_success = True + except TypeError: + enc = pickle.dumps(data) + msgpack_success = False + + enc_size = len(enc) + max_size = 
hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() + buffer_ = all_gather_list._buffer + in_buffer, enc_byte = _encode(enc, max_size, buffer_) + + out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) + + results = [] + for _ in range(hvd.size()): + bytes_list, shift = _decode(out_buffer, enc_byte) + out_buffer = out_buffer[shift:] + + if msgpack_success: + result = msgpack.loads(bytes_list, raw=False) + else: + result = pickle.loads(bytes_list) + results.append(result) + return results + + +def any_broadcast(data, root_rank): + """broadcast arbitrary data from root_rank to all nodes.""" + if not hasattr(any_broadcast, '_buffer'): + # keeps small buffer to avoid re-allocate every call + any_broadcast._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE) + try: + enc = msgpack.dumps(data, use_bin_type=True) + msgpack_success = True + except TypeError: + enc = pickle.dumps(data) + msgpack_success = False + + max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() + buffer_ = any_broadcast._buffer + buffer_, enc_byte = _encode(enc, max_size, buffer_) + + hvd.broadcast_(buffer_, root_rank) + + bytes_list, _ = _decode(buffer_, enc_byte) + if msgpack_success: + result = msgpack.loads(bytes_list, raw=False) + else: + result = pickle.loads(bytes_list) + return result diff --git a/uniter_model/utils/itm.py b/uniter_model/utils/itm.py new file mode 100644 index 0000000..6a14e82 --- /dev/null +++ b/uniter_model/utils/itm.py @@ -0,0 +1,62 @@ +import numpy as np + + +def i2t(sims, npts=None, return_ranks=False): + """ + Images->Text (Image Annotation) + sims: (N, 5N) matrix of similarity im-cap + """ + npts = sims.shape[0] + ranks = np.zeros(npts) + top1 = np.zeros(npts) + for index in range(npts): + inds = np.argsort(sims[index])[::-1] + # Score + rank = 1e20 + for i in range(5 * index, 5 * index + 5, 1): + tmp = np.where(inds == i)[0][0] + if tmp < rank: + rank = tmp + ranks[index] = rank + top1[index] = inds[0] + + # Compute metrics + r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + medr = np.floor(np.median(ranks)) + 1 + meanr = ranks.mean() + 1 + if return_ranks: + return (r1, r5, r10, medr, meanr), (ranks, top1) + else: + return (r1, r5, r10, medr, meanr) + + +def t2i(sims, npts=None, return_ranks=False): + """ + Text->Images (Image Search) + sims: (N, 5N) matrix of similarity im-cap + """ + npts = sims.shape[0] + ranks = np.zeros(5 * npts) + top1 = np.zeros(5 * npts) + + # --> (5N(caption), N(image)) + sims = sims.T + + for index in range(npts): + for i in range(5): + inds = np.argsort(sims[5 * index + i])[::-1] + ranks[5 * index + i] = np.where(inds == index)[0][0] + top1[5 * index + i] = inds[0] + + # Compute metrics + r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + medr = np.floor(np.median(ranks)) + 1 + meanr = ranks.mean() + 1 + if return_ranks: + return (r1, r5, r10, medr, meanr), (ranks, top1) + else: + return (r1, r5, r10, medr, meanr) diff --git a/uniter_model/utils/logger.py b/uniter_model/utils/logger.py new file mode 100644 index 0000000..db634b3 --- /dev/null +++ b/uniter_model/utils/logger.py @@ -0,0 +1,91 @@ +""" +helper for logging +NOTE: loggers are global objects use with caution +""" +import logging +import math + +import tensorboardX + + +_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' +_DATE_FMT = 
'%m/%d/%Y %H:%M:%S' +logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) +LOGGER = logging.getLogger('__main__') # this is the global logger + + +def add_log_to_file(log_path): + fh = logging.FileHandler(log_path) + formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) + fh.setFormatter(formatter) + LOGGER.addHandler(fh) + + +class TensorboardLogger(object): + def __init__(self): + self._logger = None + self._global_step = 0 + + def create(self, path): + self._logger = tensorboardX.SummaryWriter(path) + + def noop(self, *args, **kwargs): + return + + def step(self): + self._global_step += 1 + + @property + def global_step(self): + return self._global_step + + def log_scaler_dict(self, log_dict, prefix=''): + """ log a dictionary of scalar values""" + if self._logger is None: + return + if prefix: + prefix = f'{prefix}_' + for name, value in log_dict.items(): + if isinstance(value, dict): + self.log_scaler_dict(value, self._global_step, + prefix=f'{prefix}{name}') + else: + self._logger.add_scalar(f'{prefix}{name}', value, + self._global_step) + + def __getattr__(self, name): + if self._logger is None: + return self.noop + return self._logger.__getattribute__(name) + + +TB_LOGGER = TensorboardLogger() + + +class RunningMeter(object): + """ running meteor of a scalar value + (useful for monitoring training loss) + """ + def __init__(self, name, val=None, smooth=0.99): + self._name = name + self._sm = smooth + self._val = val + + def __call__(self, value): + val = (value if self._val is None + else value*(1-self._sm) + self._val*self._sm) + if not math.isnan(val) and not math.isinf(val): + self._val = val + else: + print(f'Inf/Nan in {self._name}') + + def __str__(self): + return f'{self._name}: {self._val:.4f}' + + @property + def val(self): + return self._val + + @property + def name(self): + return self._name diff --git a/uniter_model/utils/misc.py b/uniter_model/utils/misc.py new file mode 100644 index 0000000..6781e1b --- /dev/null +++ b/uniter_model/utils/misc.py @@ -0,0 +1,67 @@ +""" +Misc utilities +""" +import json +import random +import sys + +import torch +import numpy as np + +from uniter_model.utils.logger import LOGGER + + +class NoOp(object): + """ useful for distributed training No-Ops """ + def __getattr__(self, name): + return self.noop + + def noop(self, *args, **kwargs): + return + + +def parse_with_config(parser): + args = parser.parse_args() + if args.config is not None: + config_args = json.load(open(args.config)) + override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] + if arg.startswith('--')} + for k, v in config_args.items(): + if k not in override_keys: + setattr(args, k, v) + del args.config + return args + + +VE_ENT2IDX = { + 'contradiction': 0, + 'entailment': 1, + 'neutral': 2 +} + +VE_IDX2ENT = { + 0: 'contradiction', + 1: 'entailment', + 2: 'neutral' +} + + +class Struct(object): + def __init__(self, dict_): + self.__dict__.update(dict_) + + +def set_dropout(model, drop_p): + for name, module in model.named_modules(): + # we might want to tune dropout for smaller dataset + if isinstance(module, torch.nn.Dropout): + if module.p != drop_p: + module.p = drop_p + LOGGER.info(f'{name} set to {drop_p}') + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/uniter_model/utils/save.py b/uniter_model/utils/save.py new file mode 100644 index 0000000..bd7a8b6 --- /dev/null +++ b/uniter_model/utils/save.py @@ -0,0 +1,76 @@ +""" +saving 
+import json
+import os
+from os.path import abspath, dirname, exists, join
+import subprocess
+
+import torch
+
+from uniter_model.utils.logger import LOGGER
+
+
+def save_training_meta(args):
+    if args.rank > 0:
+        return
+
+    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
+    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)
+
+    with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer:
+        json.dump(vars(args), writer, indent=4)
+    if False:
+        model_config = json.load(open(args.model_config))
+        with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer:
+            json.dump(model_config, writer, indent=4)
+    # git info
+    try:
+        LOGGER.info("Waiting on git info....")
+        c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
+                           timeout=10, stdout=subprocess.PIPE)
+        git_branch_name = c.stdout.decode().strip()
+        LOGGER.info("Git branch: %s", git_branch_name)
+        c = subprocess.run(["git", "rev-parse", "HEAD"],
+                           timeout=10, stdout=subprocess.PIPE)
+        git_sha = c.stdout.decode().strip()
+        LOGGER.info("Git SHA: %s", git_sha)
+        git_dir = abspath(dirname(__file__))
+        git_status = subprocess.check_output(
+            ['git', 'status', '--short'],
+            cwd=git_dir, universal_newlines=True).strip()
+        with open(join(args.output_dir, 'log', 'git_info.json'),
+                  'w') as writer:
+            json.dump({'branch': git_branch_name,
+                       'is_dirty': bool(git_status),
+                       'status': git_status,
+                       'sha': git_sha},
+                      writer, indent=4)
+    except subprocess.TimeoutExpired as e:
+        LOGGER.exception(e)
+        LOGGER.warn("Git info not found. Moving right along...")
+
+
+class ModelSaver(object):
+    def __init__(self, output_dir, prefix='model_step', suffix='pt'):
+        self.output_dir = output_dir
+        self.prefix = prefix
+        self.suffix = suffix
+
+    def save(self, model, step, optimizer=None):
+        output_model_file = join(self.output_dir,
+                                 f"{self.prefix}_{step}.{self.suffix}")
+        state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
+                      for k, v in model.state_dict().items()}
+        if hasattr(model, 'vocab_pad') and model.vocab_pad:
+            # store vocab embeddings before padding
+            emb_w = state_dict['bert.embeddings.word_embeddings.weight']
+            emb_w = emb_w[:-model.vocab_pad, :]
+            state_dict['bert.embeddings.word_embeddings.weight'] = emb_w
+            state_dict['cls.predictions.decoder.weight'] = emb_w
+        torch.save(state_dict, output_model_file)
+        if optimizer is not None:
+            dump = {'step': step, 'optimizer': optimizer.state_dict()}
+            if hasattr(optimizer, '_amp_stash'):
+                pass  # TODO fp16 optimizer
+            torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')
diff --git a/uniter_model/utils/visual_entailment.py b/uniter_model/utils/visual_entailment.py
new file mode 100644
index 0000000..4ebb4b7
--- /dev/null
+++ b/uniter_model/utils/visual_entailment.py
@@ -0,0 +1,46 @@
+"""
+NOTE: modified from ban-vqa
+This code is slightly modified from Hengyuan Hu's repository.
+https://github.com/hengyuan-hu/bottom-up-attention-vqa
+"""
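+# SNLI-VE has a fixed 3-way label set, so the answer-to-label mapping below is
+# hard-coded rather than mined from the annotations (cf. utils/vqa.py).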
+import os
+import sys
+import pickle
+
+
+def create_ans2label(path):
+    """
+    create the fixed SNLI-VE answer-to-label mapping
+    path: dir of the output file
+    """
+    ans2label = {"contradiction": 0, "entailment": 1, "neutral": 2}
+    label2ans = ["contradiction", "entailment", "neutral"]
+
+    output_file = os.path.join(path, 'visual_entailment_ans2label.pkl')
+    pickle.dump(ans2label, open(output_file, 'wb'))
+
+
+def compute_target(answers, ans2label):
+    answer_count = {}
+    for answer in answers:
+        answer_ = answer
+        answer_count[answer_] = answer_count.get(answer_, 0) + 1
+
+    labels = []
+    scores = []
+    for answer in answer_count:
+        if answer not in ans2label:
+            continue
+        labels.append(ans2label[answer])
+        score = answer_count[answer]/len(answers)
+        scores.append(score)
+    target = {'labels': labels, 'scores': scores}
+    return target
+
+
+if __name__ == '__main__':
+    output = sys.argv[1]
+    print(output)
+    if os.path.exists(f'{output}/visual_entailment_ans2label.pkl'):
+        raise ValueError(f'{output} already exists')
+    create_ans2label(output)
diff --git a/uniter_model/utils/vqa.py b/uniter_model/utils/vqa.py
new file mode 100644
index 0000000..0b1360f
--- /dev/null
+++ b/uniter_model/utils/vqa.py
@@ -0,0 +1,203 @@
+"""
+NOTE: modified from ban-vqa
+This code is slightly modified from Hengyuan Hu's repository.
+https://github.com/hengyuan-hu/bottom-up-attention-vqa
+"""
+import os
+import json
+import re
+import sys
+import pickle
+
+
+CONTRACTIONS = {
+    "aint": "ain't", "arent": "aren't", "cant": "can't",
+    "couldve": "could've", "couldnt": "couldn't",
+    "couldn'tve": "couldn't've", "couldnt've": "couldn't've",
+    "didnt": "didn't", "doesnt": "doesn't", "dont": "don't",
+    "hadnt": "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've",
+    "hasnt": "hasn't", "havent": "haven't", "hed": "he'd",
+    "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's",
+    "howd": "how'd", "howll": "how'll", "hows": "how's",
+    "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've",
+    "isnt": "isn't", "itd": "it'd", "itd've": "it'd've",
+    "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
+    "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've",
+    "mightn'tve": "mightn't've", "mightve": "might've",
+    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
+    "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
+    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at",
+    "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
+    "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't",
+    "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
+    "somebody'd": "somebodyd", "somebodyd've": "somebody'd've",
+    "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
+    "somebodys": "somebody's", "someoned": "someone'd",
+    "someoned've": "someone'd've", "someone'dve": "someone'd've",
+    "someonell": "someone'll", "someones": "someone's",
+    "somethingd": "something'd", "somethingd've": "something'd've",
+    "something'dve": "something'd've", "somethingll": "something'll",
+    "thats": "that's", "thered": "there'd", "thered've": "there'd've",
+    "there'dve": "there'd've", "therere": "there're", "theres": "there's",
+    "theyd": "they'd", "theyd've": "they'd've", "they'dve": "they'd've",
+    "theyll": "they'll", "theyre": "they're", "theyve": "they've",
+    "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've",
+    "we'dve": "we'd've", "weve": "we've", "werent": "weren't",
+    "whatll": "what'll", "whatre": "what're", "whats": "what's",
+    "whatve": "what've", "whens": "when's", "whered": "where'd",
+    "wheres": "where's", "whereve": "where've", "whod": "who'd",
+    "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll",
+    "whos": "who's", "whove": "who've", "whyll": "why'll",
+    "whyre": "why're", "whys": "why's", "wont": "won't",
+    "wouldve": "would've", "wouldnt": "wouldn't",
+    "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've",
+    "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll",
+    "yall'd've": "y'all'd've", "y'alld've": "y'all'd've",
+    "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've",
+    "you'dve": "you'd've", "youll": "you'll", "youre": "you're",
+    "youve": "you've"
+}
+
+MANUAL_MAP = {'none': '0',
+              'zero': '0',
+              'one': '1',
+              'two': '2',
+              'three': '3',
+              'four': '4',
+              'five': '5',
+              'six': '6',
+              'seven': '7',
+              'eight': '8',
+              'nine': '9',
+              'ten': '10'}
+ARTICLES = ['a', 'an', 'the']
+PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
+COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
+PUNCT = [';', r"/", '[', ']', '"', '{', '}',
+         '(', ')', '=', '+', '\\', '_', '-',
+         '>', '<', '@', '`', ',', '?', '!']
+
+
+# Notice that VQA score is the average of 10 choose 9 candidate answers cases
+# See http://visualqa.org/evaluation.html
+def get_score(occurences):
+    if occurences == 0:
+        return .0
+    elif occurences == 1:
+        return .3
+    elif occurences == 2:
+        return .6
+    elif occurences == 3:
+        return .9
+    else:
+        return 1.
+
+
+def process_punctuation(inText):
+    outText = inText
+    for p in PUNCT:
+        if (p + ' ' in inText
+                or ' ' + p in inText
+                or re.search(COMMA_STRIP, inText) is not None):
+            outText = outText.replace(p, '')
+        else:
+            outText = outText.replace(p, ' ')
+    outText = PERIOD_STRIP.sub("", outText, re.UNICODE)
+    return outText
+
+
+def process_digit_article(inText):
+    outText = []
+    tempText = inText.lower().split()
+    for word in tempText:
+        word = MANUAL_MAP.setdefault(word, word)
+        if word not in ARTICLES:
+            outText.append(word)
+        else:
+            pass
+    for wordId, word in enumerate(outText):
+        if word in CONTRACTIONS:
+            outText[wordId] = CONTRACTIONS[word]
+    outText = ' '.join(outText)
+    return outText
+
+
+def preprocess_answer(answer):
+    answer = process_digit_article(process_punctuation(answer))
+    answer = answer.replace(',', '')
+    return answer
+
+
+def filter_answers(answers_dset, min_occurence):
+    """This will change the answer to preprocessed version
+    """
+    occurence = {}
+
+    for ans_entry in answers_dset:
+        gtruth = ans_entry.get('multiple_choice_answer', None)
+        if gtruth is None:
+            gtruth = ans_entry['answers'][0]['answer']  # VG, GQA pretraining
+        gtruth = preprocess_answer(gtruth)
+        if gtruth not in occurence:
+            occurence[gtruth] = set()
+        occurence[gtruth].add(ans_entry['question_id'])
+    for answer in list(occurence):
+        if len(occurence[answer]) < min_occurence:
+            occurence.pop(answer)
+
+    print('Num of answers that appear >= %d times: %d' % (
+        min_occurence, len(occurence)))
+    return occurence
+
+
+def create_ans2label(occurence, path):
+    """
+    occurence: dict {answer -> whatever}
+    path: dir of the output file
+    """
+    ans2label = {}
+    label2ans = []
+    label = 0
+    for answer in occurence:
+        label2ans.append(answer)
+        ans2label[answer] = label
+        label += 1
+
+    output_file = os.path.join(path, 'ans2label.pkl')
+    pickle.dump(ans2label, open(output_file, 'wb'))
+
+
+def compute_target(answers, ans2label):
+    answer_count = {}
+    if len(answers) == 1:
+        # VG VQA, GQA
+        answer_ = preprocess_answer(answers[0]['answer'])
+        answer_count[answer_] = 10
+    else:
+        # COCO VQA
+        for answer in answers:
+            answer_ = preprocess_answer(answer['answer'])
+            answer_count[answer_] = answer_count.get(answer_, 0) + 1
+
+    labels = []
+    scores = []
+    for answer in answer_count:
+        if answer not in ans2label:
+            continue
+        labels.append(ans2label[answer])
+        score = get_score(answer_count[answer])
+        scores.append(score)
+    target = {'labels': labels, 'scores': scores}
+    return target
+
+
+if __name__ == '__main__':
+    *answer_files, output = sys.argv[1:]
+    answers = []
+    for ans_file in answer_files:
+        ans = json.load(open(ans_file))['annotations']
+        answers.extend(ans)
+
+    occurence = filter_answers(answers, 9)
+
+    if os.path.exists(f'{output}/ans2label.pkl'):
+        raise ValueError(f'{output} already exists')
+    create_ans2label(occurence, output)
diff --git a/utils.py b/utils.py
index 8350c97..c4d5b74 100644
--- a/utils.py
+++ b/utils.py
@@ -11,8 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import wget
 import torch
+import os
 from types import SimpleNamespace
+from torchvision.datasets.utils import download_url
 
 
 class Configs(SimpleNamespace):
@@ -39,3 +42,8 @@ def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
 #         gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
 #                                                        dtype=torch.long).data
     return gather_index
+
+def download_file(url, save_path):
+    filename = os.path.basename(url)
+    download_url(url, save_path, filename)
+