import logging
import pickle
import random
from collections import defaultdict
from typing import Tuple

import torch
import torch.distributed as dist
import tqdm
from horovod import torch as hvd
from torch import Tensor as T

logger = logging.getLogger()

def get_rank():
    """Return the rank of the current Horovod worker."""
    return hvd.rank()


def get_world_size():
    """Return the total number of Horovod workers."""
    return hvd.size()

def print_args(args):
    logger.info(" **************** CONFIGURATION **************** ")
    for key, val in sorted(vars(args).items()):
        keystr = "{}".format(key) + (" " * (30 - len(key)))
        logger.info("%s --> %s", keystr, val)
    logger.info(" **************** CONFIGURATION **************** ")

def num_of_parameters(model, requires_grad=False):
    """Count model parameters; only trainable ones if requires_grad is True."""
    if requires_grad:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())
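
# Usage sketch (illustrative, not part of the original module): counting the
# parameters of an arbitrary nn.Module.
#
#     model = torch.nn.Linear(10, 10)
#     num_of_parameters(model)                      # 110 (100 weights + 10 biases)
#     num_of_parameters(model, requires_grad=True)  # 110 while everything is trainable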

def get_default_group():
    return dist.group.WORLD

def all_reduce(tensor, group=None):
    if group is None:
        group = get_default_group()
    return dist.all_reduce(tensor, group=group)

def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group (optional): group of the collective
    """
    SIZE_STORAGE_BYTES = 4  # int32 to encode the payload size

    enc = pickle.dumps(data)
    enc_size = len(enc)

    if enc_size + SIZE_STORAGE_BYTES > max_size:
        raise ValueError(
            'encoded data size ({} bytes) exceeds max_size; increase the buffer size'.format(enc_size))

    rank = get_rank()
    world_size = get_world_size()
    buffer_size = max_size * world_size

    # Lazily allocate (and grow) a shared GPU buffer plus a pinned CPU staging buffer.
    if not hasattr(all_gather_list, '_buffer') or \
            all_gather_list._buffer.numel() < buffer_size:
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()

    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format(
        256 ** SIZE_STORAGE_BYTES)

    size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big')

    # Write the payload size header followed by the pickled payload.
    cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes))
    cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc))

    start = rank * max_size
    size = enc_size + SIZE_STORAGE_BYTES
    buffer[start: start + size].copy_(cpu_buffer[:size])

    # Each worker writes to a disjoint slice of the zeroed buffer, so summing
    # via all_reduce is equivalent to an all_gather.
    all_reduce(buffer, group=group)

    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size: (i + 1) * max_size]
            size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big')
            if size > 0:
                result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist())))
        return result
    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
            'that the workers have fallen out of sync somehow. Workers can fall out of '
            'sync if one of them runs out of memory, or if there are other conditions '
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
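
# Usage sketch (illustrative, not part of the original module): gather a small
# picklable object from every Horovod worker. The variable names below are
# hypothetical; it assumes distributed training is already initialized.
#
#     local_stats = {'rank': get_rank(), 'n_examples': 128}
#     all_stats = all_gather_list(local_stats, max_size=16384)
#     # all_stats is a world_size-long list, ordered by rank
#     total = sum(s['n_examples'] for s in all_stats)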

def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors, local_positive_idxs,
               local_hard_negatives_idxs: list = None, experiment=None
               ):
    """
    Calculates the in-batch negatives loss and supports DDP mode by exchanging
    the representations across all the nodes.
    """
    distributed_world_size = 1  # args.distributed_world_size or 1
    if distributed_world_size > 1:
        # TODO: Add local_caption_vectors
        # Captions are not exchanged across workers yet (see TODO above), so
        # fall back to the local caption vectors to keep the call below valid.
        global_caption_vector = local_caption_vectors
        q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
        ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()

        global_question_ctx_vectors = all_gather_list(
            [q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs],
            max_size=args.global_loss_buf_sz)

        global_q_vector = []
        global_ctxs_vector = []

        # ctxs_per_question = local_ctx_vectors.size(0)
        positive_idx_per_question = []
        hard_negatives_per_question = []

        total_ctxs = 0

        for i, item in enumerate(global_question_ctx_vectors):
            q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item

            if i != args.local_rank:
                global_q_vector.append(q_vector.to(local_q_vector.device))
                global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
                # Shift per-worker indices by the number of contexts seen so far.
                positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs])
            else:
                global_q_vector.append(local_q_vector)
                global_ctxs_vector.append(local_ctx_vectors)
                positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs])
            total_ctxs += ctx_vectors.size(0)

        global_q_vector = torch.cat(global_q_vector, dim=0)
        global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)

    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        global_caption_vector = local_caption_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs

    loss, is_correct, scores = loss_function.calc(global_q_vector, global_ctxs_vector, global_caption_vector,
                                                  positive_idx_per_question, hard_negatives_per_question,
                                                  args.caption_score_weight, experiment)

    return loss, is_correct, scores
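
# In-batch negatives, conceptually (illustrative sketch only; the actual loss
# lives in `loss_function.calc` and additionally mixes in caption scores and
# hard negatives): every other context in the batch serves as a negative for
# each question, so the core term reduces to cross-entropy over a similarity
# matrix.
#
#     scores = global_q_vector @ global_ctxs_vector.t()   # [n_questions, n_ctxs]
#     loss = torch.nn.functional.cross_entropy(
#         scores, torch.tensor(positive_idx_per_question))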

def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            pass
        else:
            models_differ += 1
            if key_item_1[0] == key_item_2[0]:
                print('Mismatch found at', key_item_1[0])
            else:
                raise Exception('state_dict keys do not align: {} vs. {}'.format(key_item_1[0], key_item_2[0]))
    if models_differ == 0:
        print('Models match perfectly! :)')
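
# Quick sanity check (illustrative; `build_model` is a hypothetical factory):
# verify that a state_dict round-trip is lossless.
#
#     model_a, model_b = build_model(), build_model()
#     model_b.load_state_dict(model_a.state_dict())
#     compare_models(model_a, model_b)  # prints 'Models match perfectly! :)'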

def is_main_process():
    return hvd.rank() == 0

def display_img(img_meta, name, img_only=False):
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    img = mpimg.imread(img_meta[name]['img_file'])
    plt.imshow(img)
    plt.show()
    if not img_only:
        print('annotation')
        print('\t' + '\n\t'.join(img_meta[name]['annotation']))
        print('caption')
        print('\t' + img_meta[name]['caption'][0])

def retrieve_query(model, query, indexer, args, top=10):
    input_ids = args.tokenizer.encode(query)
    input_ids = torch.LongTensor(input_ids).to(args.device).unsqueeze(0)
    attn_mask = torch.ones(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
    pos_ids = torch.arange(len(input_ids[0]), dtype=torch.long, device=args.device).unsqueeze(0)
    _, query_vector, _ = model.txt_model(input_ids=input_ids, attention_mask=attn_mask, position_ids=pos_ids)
    res = indexer.search_knn(query_vector.detach().cpu().numpy(), top)
    return res
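
# Example (illustrative; assumes, as the calls above do, that `args` carries a
# tokenizer and device and that `indexer` wraps an already-populated index):
#
#     hits = retrieve_query(model, 'a dog catching a frisbee', indexer, args, top=10)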

def get_model_encoded_vecs(model, dataloader):
    img_embedding, caption_embedding, query_embedding = dict(), dict(), defaultdict(list)
    labels_img_name = []
    for i, batch in enumerate(tqdm.tqdm(dataloader)):
        with torch.no_grad():
            model_out = model(batch)
        local_q_vectors, local_ctx_vectors, local_caption_vectors = model_out

        img_embedding.update({img_id: img_vec.detach().cpu().numpy()
                              for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
        caption_embedding.update({img_id: cap_vec.detach().cpu().numpy()
                                  for img_id, cap_vec in zip(batch['img_fname'], local_caption_vectors)})
        query_embedding.update({txt_id: q_vec.detach().cpu().numpy()
                                for txt_id, q_vec in zip(batch['txt_index'], local_q_vectors)})

        labels_img_name.extend(batch['img_fname'])
    return {
        'img_embed': img_embedding,
        'caption_embed': caption_embedding,
        'txt_embed': query_embedding,
        'img_name': labels_img_name
    }
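
# Usage sketch (illustrative; `eval_dataloader` and the indexer's `index_data`
# method are assumptions, not part of this module): pre-compute all embeddings
# once, then index the image vectors for retrieval.
#
#     vecs = get_model_encoded_vecs(model, eval_dataloader)
#     indexer.index_data(list(vecs['img_embed'].items()))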