logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

209 lines
8.4 KiB

import collections
import os
import torch
import tqdm
import logging
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset, ChainDataset
from uniter_model.data.loader import PrefetchLoader
from dvl.data.itm import TxtTokLmdb, ItmFastDataset, ItmValDataset, itm_fast_collate
from dvl.models.bi_encoder import BiEncoderNllLoss
from dvl.utils import _calc_loss
from dvl.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer
# Module-wide logger; getLogger() with no name returns the root logger.
logger = logging.getLogger()

# Record describing everything stored in a checkpoint file. The field order
# matters for positional construction elsewhere in this module.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    (
        'model_dict',
        'optimizer_dict',
        'scheduler_dict',
        'offset',
        'epoch',
        'encoder_params',
    ),
)
class BiEncoderTrainer:
    """Placeholder trainer type; it currently holds no state and no behaviour."""

    def __init__(self, args):
        """Accept ``args`` for interface compatibility; nothing is stored."""
def build_dataloader(dataset, collate_fn, is_train, opts, batch_size=None):
    """Create a ``DataLoader`` for ``dataset`` and wrap it in a ``PrefetchLoader``.

    Args:
        dataset: dataset handed to ``DataLoader``.
        collate_fn: batch collation callable.
        is_train: selects the default batch size and enables shuffling.
        opts: options object; reads ``train_batch_size`` / ``valid_batch_size``
            (only when ``batch_size`` is None), ``n_workers`` and ``pin_mem``.
        batch_size: explicit batch size; overrides the value from ``opts``.

    Returns:
        The ``PrefetchLoader`` wrapping the constructed ``DataLoader``.
    """
    if batch_size is None:
        # Fall back to the split-specific size configured in opts.
        if is_train:
            batch_size = opts.train_batch_size
        else:
            batch_size = opts.valid_batch_size

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=is_train,  # only training data is shuffled
        drop_last=False,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=collate_fn,
    )
    return PrefetchLoader(DataLoader(dataset, **loader_kwargs))
def get_model_obj(model: nn.Module):
    """Return ``model.module`` when that attribute exists, otherwise ``model``.

    NOTE(review): presumably used to unwrap parallel/distributed wrappers that
    expose the real network as ``.module`` - confirm against callers.
    """
    return getattr(model, 'module', model)
def _save_checkpoint(args, biencoder, optimizer, scheduler, epoch: int, offset: int, cp_name: str = None) -> str:
    """Serialize model, optimizer and scheduler state to ``args.output_dir``.

    The file is named ``biencoder.<cp_name>.pt`` when ``cp_name`` is given,
    otherwise ``biencoder.<epoch>[.<offset>].pt`` (the offset part is only
    added when ``offset > 0``).

    Args:
        args: options object; only ``output_dir`` is read.
        biencoder: model (possibly wrapped; unwrapped via ``get_model_obj``).
        optimizer: object exposing ``state_dict()``.
        scheduler: object exposing ``state_dict()``.
        epoch: epoch number stored in the checkpoint.
        offset: batch offset stored in the checkpoint.
        cp_name: optional explicit checkpoint tag.

    Returns:
        Path of the written checkpoint file.
    """
    unwrapped = get_model_obj(biencoder)

    if cp_name is not None:
        tag = cp_name
    elif offset > 0:
        tag = str(epoch) + '.' + str(offset)
    else:
        tag = str(epoch)
    cp = os.path.join(args.output_dir, 'biencoder.' + tag) + '.pt'

    state = CheckpointState(
        model_dict=unwrapped.state_dict(),
        optimizer_dict=optimizer.state_dict(),
        scheduler_dict=scheduler.state_dict(),
        offset=offset,
        epoch=epoch,
        encoder_params=None,  # no encoder meta parameters are recorded
    )
    # Saved as a plain mapping so it can be rebuilt with CheckpointState(**d).
    torch.save(state._asdict(), cp)
    logger.info('Saved checkpoint at %s', cp)
    return cp
def load_saved_state(biencoder, optimizer=None, scheduler=None, saved_state: CheckpointState = ''):
    """Restore model (and optionally optimizer/scheduler) state in place.

    Args:
        biencoder: model to restore (unwrapped via ``get_model_obj``).
        optimizer: if given and the checkpoint has optimizer state, it is loaded.
        scheduler: if given and the checkpoint has scheduler state, it is loaded.
        saved_state: checkpoint record; must expose ``epoch``, ``offset``,
            ``model_dict``, ``optimizer_dict`` and ``scheduler_dict``.
            NOTE(review): the ``''`` default cannot work (no ``.epoch``), so a
            real checkpoint has to be passed.

    Returns:
        None. The computed resume epoch is only logged, not returned.
    """
    epoch, offset = saved_state.epoch, saved_state.offset
    # offset == 0 means the epoch has been completed, so report the next one.
    resume_epoch = epoch + 1 if offset == 0 else epoch
    logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, resume_epoch)

    target = get_model_obj(biencoder)
    logger.info('Loading saved model state ...')
    target.load_state_dict(saved_state.model_dict)  # set strict=False if you use extra projection

    optimizer_state = saved_state.optimizer_dict
    if optimizer_state and optimizer is not None:
        logger.info('Loading saved optimizer state ...')
        optimizer.load_state_dict(optimizer_state)

    scheduler_state = saved_state.scheduler_dict
    if scheduler_state and scheduler is not None:
        scheduler.load_state_dict(scheduler_state)
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Read a checkpoint file from disk and return it as a ``CheckpointState``.

    Args:
        model_file: path to a file written with ``torch.save`` whose top-level
            keys match the ``CheckpointState`` fields.

    Returns:
        ``CheckpointState`` built from the loaded mapping (tensors mapped to CPU).
    """
    logger.info('Reading saved model from %s', model_file)
    # NOTE(review): torch.load may unpickle arbitrary objects - only load
    # checkpoints from trusted sources.
    saved_fields = torch.load(model_file, map_location='cpu')
    logger.info('model_state_dict keys %s', saved_fields.keys())
    checkpoint = CheckpointState(**saved_fields)
    return checkpoint
def get_indexer(bi_encoder, eval_dataloader, args, hnsw_index, img_retrieval=True):
    """Encode every batch of ``eval_dataloader`` and index the resulting vectors.

    Args:
        bi_encoder: model; put into eval mode and called as ``bi_encoder(batch)``,
            expected to return a 3-tuple of (text, image/context, caption) vectors.
        eval_dataloader: iterable of batches providing ``'img_fname'`` and
            ``'txt_index'`` entries.
        args: options object; only ``vector_size`` is read.
        hnsw_index: choose ``DenseHNSWFlatIndexer`` when truthy, else
            ``DenseFlatIndexer``.
        img_retrieval: when true, index image vectors keyed by ``'img_fname'``;
            otherwise index text vectors keyed by ``'txt_index'``.

    Returns:
        The populated indexer.
    """
    bi_encoder.eval()
    indexer_cls = DenseHNSWFlatIndexer if hnsw_index else DenseFlatIndexer
    indexer_img = indexer_cls(args.vector_size)  # modify in future

    # id -> numpy vector; later batches overwrite earlier entries with the same id.
    embeddings = {}
    for batch in tqdm.tqdm(eval_dataloader):
        with torch.no_grad():
            model_out = bi_encoder(batch)
        local_q_vector, local_ctx_vectors, local_caption_vectors = model_out

        if img_retrieval:
            ids, vectors = batch['img_fname'], local_ctx_vectors
        else:
            ids, vectors = batch['txt_index'], local_q_vector
        for item_id, vec in zip(ids, vectors):
            embeddings[item_id] = vec.detach().cpu().numpy()

    # Index once, after all batches have been encoded.
    indexer_img.index_data(list(embeddings.items()))
    return indexer_img
def eval_model_on_dataloader(bi_encoder, eval_dataloader, args, img2txt=None, num_tops=100, no_eval=False):
    """Evaluate the bi-encoder on ``eval_dataloader`` and build retrieval indexes.

    Every batch is encoded, its loss and number of correct predictions are
    accumulated via ``_calc_loss``, and the text / image vectors are collected
    both as queries and as index entries. Unless ``no_eval`` is set, each text
    query is searched against the image index and each image query against the
    text index, and recall at 1, 5 and 10 is computed.

    Args:
        bi_encoder: model; put into eval mode and called as ``bi_encoder(batch)``,
            expected to return a 3-tuple of (text, image/context, caption) vectors.
        eval_dataloader: iterable of batches providing ``'txt_index'``,
            ``'img_fname'`` and ``['txts']['input_ids']``.
        args: options object; reads ``hnsw_index`` and ``vector_size`` and is
            forwarded to ``_calc_loss``.
        img2txt: mapping from image id to the text ids that count as correct
            for that image. Required unless ``no_eval`` is true.
        num_tops: number of neighbours requested from each index.
        no_eval: when true, skip the nearest-neighbour search and recall.

    Returns:
        ``(total_loss, correct_ratio, (indexer_img, indexer_txt),
        (recall_txt, recall_img), (rank_txt_res, rank_img_res))`` where
        ``total_loss`` is the mean per-batch loss and ``correct_ratio`` is the
        number of correct predictions divided by the number of samples. With
        ``no_eval`` the last two pairs are ``(None, None)``.

    Raises:
        TypeError: if ``no_eval`` is false and ``img2txt`` is None.
        ZeroDivisionError: if the dataloader yields no batches or no samples.
    """
    # Fail fast: without img2txt the image-recall step below cannot work, and
    # discovering that only after a full evaluation pass wastes the whole run.
    if not no_eval and img2txt is None:
        raise TypeError('img2txt is required when no_eval is False')

    bi_encoder.eval()

    indexer_cls = DenseHNSWFlatIndexer if args.hnsw_index else DenseFlatIndexer
    indexer_img = indexer_cls(args.vector_size)  # modify in future
    indexer_txt = indexer_cls(args.vector_size)  # modify in future

    # The loss object does not depend on the batch, so build it once.
    loss_function = BiEncoderNllLoss()

    total_loss = 0.0
    total_correct_predictions = 0
    batches, total_samples = 0, 0
    img_embedding = {}
    txt_embedding = {}
    query_txt, query_txt_id = [], []
    query_img, query_img_id = [], []

    for batch in eval_dataloader:
        with torch.no_grad():
            model_out = bi_encoder(batch)
        local_q_vector, local_ctx_vectors, local_caption_vectors = model_out

        txt_ids = batch['txt_index']
        img_ids = batch['img_fname']

        query_txt.extend(out.view(-1).detach().cpu().numpy() for out in local_q_vector)
        query_txt_id.extend(txt_ids)
        query_img.extend(out.view(-1).detach().cpu().numpy() for out in local_ctx_vectors)
        query_img_id.extend(img_ids)

        # Positive for the i-th text is the i-th image of the same batch.
        loss, correct_cnt, _ = _calc_loss(
            args, loss_function, local_q_vector, local_ctx_vectors, local_caption_vectors,
            list(range(len(local_q_vector))), None)

        total_loss += loss.item()
        total_correct_predictions += correct_cnt.sum().item()
        batches += 1
        total_samples += batch['txts']['input_ids'].shape[0]

        # Keyed by id, so repeated ids keep only the most recent vector.
        img_embedding.update(
            {img_id: img_vec.detach().cpu().numpy() for img_id, img_vec in zip(img_ids, local_ctx_vectors)})
        txt_embedding.update(
            {txt_id: txt_vec.detach().cpu().numpy() for txt_id, txt_vec in zip(txt_ids, local_q_vector)})

    total_loss = total_loss / batches
    correct_ratio = total_correct_predictions / float(total_samples)

    indexer_img.index_data(list(img_embedding.items()))
    indexer_txt.index_data(list(txt_embedding.items()))

    if no_eval:
        return total_loss, correct_ratio, (indexer_img, indexer_txt), (None, None), (None, None)

    # Text queries search the image index; image queries search the text index.
    # The first element of each search result is used as the ranked id list.
    res_txt = indexer_img.search_knn(np.array(query_txt), num_tops)
    rank_txt_res = {query_txt_id[i]: r[0] for i, r in enumerate(res_txt)}
    res_img = indexer_txt.search_knn(np.array(query_img), num_tops)
    rank_img_res = {query_img_id[i]: r[0] for i, r in enumerate(res_img)}

    cutoffs = (1, 5, 10)

    # Text -> image recall: the paired image must appear in the top-k results.
    recall_txt = {top: 0 for top in cutoffs}
    for i, q in enumerate(query_txt_id):
        ranked = rank_txt_res[q]
        for top in cutoffs:
            recall_txt[top] += query_img_id[i] in ranked[:top]
    for top in cutoffs:
        recall_txt[top] = recall_txt[top] / len(rank_txt_res)

    # Image -> text recall: any of the image's texts in the top-k counts as a hit.
    recall_img = {top: 0 for top in cutoffs}
    for q in np.unique(query_img_id):
        ranked = rank_img_res[q]
        for top in cutoffs:
            top_hits = ranked[:top]
            recall_img[top] += any(txt_id in top_hits for txt_id in img2txt[q])
    for top in cutoffs:
        recall_img[top] = recall_img[top] / len(rank_img_res)

    return total_loss, correct_ratio, (indexer_img, indexer_txt), (recall_txt, recall_img), (rank_txt_res, rank_img_res)
def load_dataset(all_img_dbs, txt_dbs, img_dbs, args, is_train):
    """Build the training or evaluation dataset from text and image databases.

    Args:
        all_img_dbs: indexable collection of image databases, keyed by path.
        txt_dbs: for training, an iterable of text-db paths; otherwise one path.
        img_dbs: for training, an iterable of image-db paths aligned with
            ``txt_dbs``; otherwise one path.
        args: options object; reads ``max_txt_len`` and ``num_hard_negatives``
            (training), ``inf_minibatch_size`` (eval), ``img_meta``, ``tokenizer``.
        is_train: select the training branch when truthy.

    Returns:
        A ``ConcatDataset`` of per-pair ``ItmFastDataset`` objects for
        training, or a single ``ItmFastDataset`` for eval/test.
    """
    if not is_train:
        # eval or test: a single text/image database pair, no length limit (-1)
        eval_img_db = all_img_dbs[img_dbs]
        eval_txt_db = TxtTokLmdb(txt_dbs, -1)
        return ItmFastDataset(eval_txt_db, eval_img_db, args.inf_minibatch_size, args.img_meta, args.tokenizer)

    # train datasets: one ItmFastDataset per (text db, image db) pair
    train_parts = []
    for txt_path, img_path in zip(txt_dbs, img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, args.max_txt_len)
        part = ItmFastDataset(txt_db, img_db, args.num_hard_negatives, args.img_meta, args.tokenizer)
        train_parts.append(part)
    return ConcatDataset(train_parts)