import logging
import random
import tqdm
import torch
import pickle
import torch.distributed as dist
from collections import defaultdict
from horovod import torch as hvd
from torch import Tensor as T
from typing import Tuple
logger = logging.getLogger()


def get_rank():
    return hvd.rank()


def get_world_size():
    return hvd.size()


def print_args(args):
    logger.info(" **************** CONFIGURATION **************** ")
    for key, val in sorted(vars(args).items()):
        keystr = "{}".format(key) + (" " * (30 - len(key)))
        logger.info("%s --> %s", keystr, val)
    logger.info(" **************** CONFIGURATION **************** ")


def num_of_parameters(model, requires_grad=False):
    if requires_grad:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())


def get_default_group():
    return dist.group.WORLD


def all_reduce(tensor, group=None):
    if group is None:
        group = get_default_group()
    return dist.all_reduce(tensor, group=group)


def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
    """
    SIZE_STORAGE_BYTES = 4  # int32 to encode the payload size

    enc = pickle.dumps(data)
    enc_size = len(enc)

    if enc_size + SIZE_STORAGE_BYTES > max_size:
        raise ValueError(
            'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size))

    rank = get_rank()
    world_size = get_world_size()
    buffer_size = max_size * world_size

    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()

    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format(
        256 ** SIZE_STORAGE_BYTES)

    size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big')

    cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes))
    cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc))

    start = rank * max_size
    size = enc_size + SIZE_STORAGE_BYTES
    buffer[start: start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size: (i + 1) * max_size]
            size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big')
            if size > 0:
                result.append(
                    pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
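

# Illustrative usage sketch (not part of the original file): collecting a small,
# picklable stats dict from every worker via all_gather_list. Assumes horovod and
# the torch.distributed process group have been initialised by the training
# script; the function name below is hypothetical.
def _example_gather_stats(local_stats, max_size=16384):
    # all_gather_list returns one entry per worker, in rank order.
    gathered = all_gather_list(local_stats, max_size=max_size)
    merged = defaultdict(float)
    for worker_stats in gathered:
        for key, val in worker_stats.items():
            merged[key] += val
    return dict(merged)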


def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, local_positive_idxs,
               local_hard_negatives_idxs: list = None, experiment=None
               ):
    """
    Calculates the in-batch negatives loss and supports running in DDP mode by exchanging
    the representations across all the nodes.
    """
    distributed_world_size = 1  # args.distributed_world_size or 1
    if distributed_world_size > 1:
        # TODO: Add local_caption_vectors
        q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
        ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()

        global_question_ctx_vectors = all_gather_list(
            [q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs],
            max_size=args.global_loss_buf_sz)

        global_q_vector = []
        global_ctxs_vector = []

        # ctxs_per_question = local_ctx_vectors.size(0)
        positive_idx_per_question = []
        hard_negatives_per_question = []

        total_ctxs = 0
        for i, item in enumerate(global_question_ctx_vectors):
            q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item

            if i != args.local_rank:
                global_q_vector.append(q_vector.to(local_q_vector.device))
                global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
                positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs])
            else:
                global_q_vector.append(local_q_vector)
                global_ctxs_vector.append(local_ctx_vectors)
                positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs])
            total_ctxs += ctx_vectors.size(0)

        global_q_vector = torch.cat(global_q_vector, dim=0)
        global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)
    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        global_caption_vector = local_caption_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs

    loss, is_correct, scores = loss_function.calc(global_q_vector, global_ctxs_vector, global_caption_vector,
                                                  positive_idx_per_question, hard_negatives_per_question,
                                                  args.caption_score_weight, experiment)

    return loss, is_correct, scores
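

# Illustrative usage sketch (not part of the original file): the inputs _calc_loss
# expects for a single local batch. The batch size, vector dimension and the empty
# hard-negative lists are assumptions for illustration only; args is assumed to
# carry caption_score_weight as used above.
def _example_calc_loss(args, loss_function, experiment=None):
    batch_size, dim = 4, 768  # hypothetical sizes
    q_vectors = torch.randn(batch_size, dim)        # one query vector per example
    ctx_vectors = torch.randn(batch_size, dim)      # one image/context vector per example
    caption_vectors = torch.randn(batch_size, dim)  # one caption vector per example
    positive_idxs = list(range(batch_size))         # the i-th context is the positive for the i-th query
    hard_negative_idxs = [[] for _ in range(batch_size)]
    return _calc_loss(args, loss_function, q_vectors, ctx_vectors, caption_vectors,
                      positive_idxs, local_hard_negatives_idxs=hard_negative_idxs,
                      experiment=experiment)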


def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            pass
        else:
            models_differ += 1
            if key_item_1[0] == key_item_2[0]:
                print('Mismatch found at', key_item_1[0])
            else:
                raise Exception('state_dict keys do not match: {} vs {}'.format(key_item_1[0], key_item_2[0]))
    if models_differ == 0:
        print('Models match perfectly! :)')


def is_main_process():
    return hvd.rank() == 0


def display_img(img_meta, name, img_only=False):
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg

    img = mpimg.imread(img_meta[name]['img_file'])
    plt.imshow(img)
    plt.show()
    if not img_only:
        print('annotation')
        print('\t' + '\n\t'.join(img_meta[name]['annotation']))
        print('caption')
        print('\t' + img_meta[name]['caption'][0])


def retrieve_query(model, query, indexer, args, top=10):
    input_ids = args.tokenizer.encode(query)
    input_ids = torch.LongTensor(input_ids).to(args.device).unsqueeze(0)
    attn_mask = torch.ones(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
    pos_ids = torch.arange(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
    _, query_vector, _ = model.txt_model(input_ids=input_ids, attention_mask=attn_mask, position_ids=pos_ids)
    res = indexer.search_knn(query_vector.detach().cpu().numpy(), top)
    return res
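

# Illustrative usage sketch (not part of the original file): issuing a free-text
# query against a pre-built index. The model, indexer and args objects are assumed
# to be constructed elsewhere in the project; the query string is arbitrary.
def _example_retrieve(model, indexer, args):
    # Returns the indexer's nearest-neighbour results for the encoded query.
    return retrieve_query(model, "a person riding a bicycle", indexer, args, top=10)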


def get_model_encoded_vecs(model, dataloader):
    img_embedding, caption_embedding, query_embedding = dict(), dict(), defaultdict(list)
    labels_img_name = []
    # for i, batch in enumerate(dataloader):
    for i, batch in enumerate(tqdm.tqdm(dataloader)):
        with torch.no_grad():
            model_out = model(batch)
        local_q_vectors, local_ctx_vectors, local_caption_vectors = model_out
        img_embedding.update({img_id: img_vec.detach().cpu().numpy()
                              for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
        caption_embedding.update({img_id: cap_vec.detach().cpu().numpy()
                                  for img_id, cap_vec in zip(batch['img_fname'], local_caption_vectors)})
        query_embedding.update({txt_id: q_vec.detach().cpu().numpy()
                                for txt_id, q_vec in zip(batch['txt_index'], local_q_vectors)})
        labels_img_name.extend(batch['img_fname'])
    return {
        'img_embed': img_embedding,
        'caption_embed': caption_embedding,
        'txt_embed': query_embedding,
        'img_name': labels_img_name
    }
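

# Illustrative usage sketch (not part of the original file): consuming the dictionary
# returned by get_model_encoded_vecs, e.g. stacking the per-image vectors into a single
# matrix before indexing. numpy is assumed to be available alongside torch.
def _example_stack_image_embeddings(encoded_vecs):
    import numpy as np

    names = list(encoded_vecs['img_embed'].keys())
    matrix = np.stack([encoded_vecs['img_embed'][name] for name in names], axis=0)
    return names, matrix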