lightningdot
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
209 lines
8.4 KiB
209 lines
8.4 KiB
import collections
|
|
import os
|
|
import torch
|
|
import tqdm
|
|
import logging
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader, ConcatDataset, ChainDataset
|
|
from uniter_model.data.loader import PrefetchLoader
|
|
|
|
from dvl.data.itm import TxtTokLmdb, ItmFastDataset, ItmValDataset, itm_fast_collate
|
|
from dvl.models.bi_encoder import BiEncoderNllLoss
|
|
from dvl.utils import _calc_loss
|
|
from dvl.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer
|
|
|
|
|
|
# Module-level logger; uses the root logger, so handlers/levels are expected to
# be configured by the application entry point.
logger = logging.getLogger()

# Serializable training snapshot: model/optimizer/scheduler state dicts plus
# the resume position (epoch, batch offset) and optional encoder metadata.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    [
        "model_dict",
        "optimizer_dict",
        "scheduler_dict",
        "offset",
        "epoch",
        "encoder_params",
    ],
)
|
|
|
|
|
|
class BiEncoderTrainer:
    """Placeholder trainer for the bi-encoder model.

    NOTE(review): the constructor currently ignores ``args`` and performs no
    setup — the actual training loop presumably lives elsewhere; confirm
    before relying on this class.
    """

    def __init__(self, args):
        # Intentionally a no-op; kept so callers can construct the object.
        pass
|
|
|
|
|
|
def build_dataloader(dataset, collate_fn, is_train, opts, batch_size=None):
    """Wrap ``dataset`` in a ``DataLoader`` followed by a ``PrefetchLoader``.

    Args:
        dataset: any map-style dataset accepted by ``torch.utils.data.DataLoader``.
        collate_fn: batch collation function.
        is_train: when True, shuffle and use ``opts.train_batch_size`` by default;
            otherwise no shuffling and ``opts.valid_batch_size``.
        opts: namespace providing train_batch_size / valid_batch_size /
            n_workers / pin_mem.
        batch_size: explicit batch size; overrides the opts-derived default.

    Returns:
        A ``PrefetchLoader``-wrapped dataloader.
    """
    if batch_size is None:
        batch_size = opts.train_batch_size if is_train else opts.valid_batch_size

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,  # only shuffle during training
        drop_last=False,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=collate_fn,
    )
    # PrefetchLoader presumably overlaps host-to-device copies with compute —
    # see uniter_model.data.loader for its exact semantics.
    return PrefetchLoader(loader)
|
|
|
|
|
|
def get_model_obj(model: nn.Module):
    """Return the underlying module, unwrapping a (Distributed)DataParallel
    wrapper if present (those expose the real model as ``.module``)."""
    return getattr(model, 'module', model)
|
|
|
|
|
|
def _save_checkpoint(args, biencoder, optimizer, scheduler, epoch: int, offset: int, cp_name: str = None) -> str:
    """Serialize model/optimizer/scheduler state into ``args.output_dir``.

    The checkpoint file is named ``biencoder.<epoch>[.<offset>].pt`` (the
    batch offset appears only for mid-epoch saves), or ``biencoder.<cp_name>.pt``
    when an explicit name is supplied.

    Returns:
        The path of the written checkpoint file.
    """
    model_to_save = get_model_obj(biencoder)

    if cp_name is not None:
        suffix = cp_name
    else:
        # Append the batch offset only when saving mid-epoch (offset > 0).
        suffix = str(epoch) + ('.' + str(offset) if offset > 0 else '')
    cp = os.path.join(args.output_dir, 'biencoder.' + suffix) + '.pt'

    meta_params = None  # reserved for encoder hyper-parameters; not populated yet

    state = CheckpointState(
        model_to_save.state_dict(),
        optimizer.state_dict(),
        scheduler.state_dict(),
        offset,
        epoch,
        meta_params,
    )
    # Store a plain dict so the file can be read back with torch.load and
    # re-hydrated via CheckpointState(**state_dict).
    torch.save(state._asdict(), cp)
    logger.info('Saved checkpoint at %s', cp)
    return cp
|
|
|
|
|
|
def load_saved_state(biencoder, optimizer=None, scheduler=None, saved_state: CheckpointState = None):
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        biencoder: model to load into (possibly wrapped in DataParallel).
        optimizer: if given and the checkpoint carries optimizer state, restore it.
        scheduler: if given and the checkpoint carries scheduler state, restore it.
        saved_state: a ``CheckpointState``, e.g. from ``load_states_from_checkpoint``.

    Raises:
        ValueError: if ``saved_state`` is not provided.
    """
    # Fix: the original default was the empty string '', which crashed with an
    # opaque AttributeError on ``.epoch``; fail fast with a clear message.
    if saved_state is None:
        raise ValueError('saved_state is required; obtain one via load_states_from_checkpoint()')

    epoch = saved_state.epoch
    offset = saved_state.offset
    if offset == 0:  # offset 0 means the epoch completed; resume from the next one
        epoch += 1
    logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch)

    model_to_load = get_model_obj(biencoder)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(saved_state.model_dict)  # set strict=False if you use extra projection

    if saved_state.optimizer_dict and optimizer is not None:
        logger.info('Loading saved optimizer state ...')
        optimizer.load_state_dict(saved_state.optimizer_dict)

    if saved_state.scheduler_dict and scheduler is not None:
        scheduler.load_state_dict(saved_state.scheduler_dict)
|
|
|
|
|
|
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Read a checkpoint written by ``_save_checkpoint`` back into a
    ``CheckpointState`` named tuple."""
    logger.info('Reading saved model from %s', model_file)
    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only host.
    state_dict = torch.load(model_file, map_location='cpu')
    logger.info('model_state_dict keys %s', state_dict.keys())
    return CheckpointState(**state_dict)
|
|
|
|
|
|
def get_indexer(bi_encoder, eval_dataloader, args, hnsw_index, img_retrieval=True):
    """Embed one modality of the dataset and build a dense FAISS indexer over it.

    Args:
        bi_encoder: model mapping a batch to (text_vecs, image_vecs, caption_vecs).
        eval_dataloader: yields batches with 'img_fname' and 'txt_index' keys.
        args: provides ``vector_size`` for the indexer.
        hnsw_index: if True use an approximate HNSW index, else an exact flat one.
        img_retrieval: index image embeddings when True, text embeddings otherwise.

    Returns:
        The populated indexer.
    """
    bi_encoder.eval()

    # HNSW trades exactness for speed/memory; the flat index is exact search.
    index_cls = DenseHNSWFlatIndexer if hnsw_index else DenseFlatIndexer
    indexer = index_cls(args.vector_size)  # modify in future

    embeddings = dict()
    for batch in tqdm.tqdm(eval_dataloader):
        with torch.no_grad():
            local_q_vector, local_ctx_vectors, _local_caption_vectors = bi_encoder(batch)
        if img_retrieval:
            pairs = zip(batch['img_fname'], local_ctx_vectors)
        else:
            pairs = zip(batch['txt_index'], local_q_vector)
        # dict keyed by id, so re-encountered ids keep only the latest vector.
        embeddings.update({key: vec.detach().cpu().numpy() for key, vec in pairs})

    indexer.index_data(list(embeddings.items()))
    return indexer
|
|
|
|
|
|
def eval_model_on_dataloader(bi_encoder, eval_dataloader, args, img2txt=None, num_tops=100, no_eval=False):
    """Evaluate the bi-encoder: NLL loss/accuracy and (optionally) retrieval recall.

    Args:
        bi_encoder: model mapping a batch to (text_vecs, image_vecs, caption_vecs).
        eval_dataloader: yields batches with 'txts', 'img_fname', 'txt_index' keys.
        args: needs ``hnsw_index`` and ``vector_size``; also read by ``_calc_loss``.
        img2txt: image id -> collection of ground-truth text ids; required when
            ``no_eval`` is False (used for image->text recall).
        num_tops: number of nearest neighbours retrieved per query.
        no_eval: if True, skip recall computation and return only loss/accuracy
            plus the built indexers.

    Returns:
        (avg_loss, correct_ratio, (indexer_img, indexer_txt),
         (recall_txt, recall_img), (rank_txt_res, rank_img_res));
        the last two pairs are (None, None) when ``no_eval`` is True.
    """
    bi_encoder.eval()

    total_loss = 0.0
    total_correct_predictions = 0
    batches, total_samples = 0, 0
    labels_img_name = []
    labels_txt_name = []
    img_embedding = dict()
    txt_embedding = dict()

    if args.hnsw_index:
        indexer_img = DenseHNSWFlatIndexer(args.vector_size)  # modify in future
        indexer_txt = DenseHNSWFlatIndexer(args.vector_size)  # modify in future
    else:
        indexer_img = DenseFlatIndexer(args.vector_size)  # modify in future
        indexer_txt = DenseFlatIndexer(args.vector_size)  # modify in future

    query_txt, query_txt_id = [], []
    query_img, query_img_id = [], []

    # Fix: the loss object was instantiated inside the batch loop although it is
    # loop-invariant; build it once.
    loss_function = BiEncoderNllLoss()

    for batch in eval_dataloader:
        with torch.no_grad():
            local_q_vector, local_ctx_vectors, local_caption_vectors = bi_encoder(batch)

        query_txt.extend([out.view(-1).detach().cpu().numpy() for out in local_q_vector])
        query_txt_id.extend(batch['txt_index'])

        query_img.extend([out.view(-1).detach().cpu().numpy() for out in local_ctx_vectors])
        query_img_id.extend(batch['img_fname'])

        # In-batch negatives: the positive image for text i is at position i.
        loss, correct_cnt, score = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors,
                                              local_caption_vectors,
                                              list(range(len(local_q_vector))), None)

        total_loss += loss.item()
        total_correct_predictions += correct_cnt.sum().item()
        batches += 1
        total_samples += batch['txts']['input_ids'].shape[0]

        img_embedding.update({img_id: img_vec.detach().cpu().numpy()
                              for img_id, img_vec in zip(batch['img_fname'], local_ctx_vectors)})
        txt_embedding.update({txt_id: txt_vec.detach().cpu().numpy()
                              for txt_id, txt_vec in zip(batch['txt_index'], local_q_vector)})
        labels_img_name.extend(batch['img_fname'])
        labels_txt_name.extend(batch['txt_index'])

    total_loss = total_loss / batches
    correct_ratio = total_correct_predictions / float(total_samples)

    query_txt_np = np.array(query_txt)
    indexer_img.index_data(list(img_embedding.items()))
    query_img_np = np.array(query_img)
    indexer_txt.index_data(list(txt_embedding.items()))

    if no_eval:
        return total_loss, correct_ratio, (indexer_img, indexer_txt), (None, None), (None, None)

    # Text -> image retrieval: query the image index with every text embedding.
    res_txt = indexer_img.search_knn(query_txt_np, num_tops)
    rank_txt_res = {query_txt_id[i]: r[0] for i, r in enumerate(res_txt)}

    # Image -> text retrieval: query the text index with every image embedding.
    res_img = indexer_txt.search_knn(query_img_np, num_tops)
    rank_img_res = {query_img_id[i]: r[0] for i, r in enumerate(res_img)}

    # Recall@k for text queries: did the ground-truth image land in the top k?
    recall_txt = {1: 0, 5: 0, 10: 0}
    for i, q in enumerate(query_txt_id):
        for top in recall_txt:
            recall_txt[top] += labels_img_name[i] in rank_txt_res[q][:top]
    for top in recall_txt:
        recall_txt[top] = recall_txt[top] / len(rank_txt_res)

    # Recall@k for image queries: any of the image's captions in the top k?
    # (each image appears once per caption in the loader, hence np.unique)
    recall_img = {1: 0, 5: 0, 10: 0}
    for q in np.unique(query_img_id):
        for top in recall_img:
            recall_img[top] += any(txt_id in rank_img_res[q][:top] for txt_id in img2txt[q])
    for top in recall_img:
        recall_img[top] = recall_img[top] / len(rank_img_res)

    return total_loss, correct_ratio, (indexer_img, indexer_txt), (recall_txt, recall_img), (rank_txt_res, rank_img_res)
|
|
|
|
|
|
def load_dataset(all_img_dbs, txt_dbs, img_dbs, args, is_train):
    """Build the ITM dataset(s) for training or evaluation.

    Args:
        all_img_dbs: mapping from image-db path to an opened image db.
        txt_dbs: list of text-db paths (train) or a single path (eval/test).
        img_dbs: list of image-db paths (train) or a single path (eval/test).
        args: provides max_txt_len, num_hard_negatives, inf_minibatch_size,
            img_meta and tokenizer.
        is_train: selects between the train and eval/test construction paths.

    Returns:
        A ``ConcatDataset`` over all train splits, or a single
        ``ItmFastDataset`` for eval/test.
    """
    if not is_train:
        # Eval/test: single split; -1 disables the text-length cap.
        img_db = all_img_dbs[img_dbs]
        txt_db = TxtTokLmdb(txt_dbs, -1)
        return ItmFastDataset(txt_db, img_db, args.inf_minibatch_size, args.img_meta, args.tokenizer)

    # Train: one dataset per (txt, img) db pair, concatenated.
    splits = [
        ItmFastDataset(TxtTokLmdb(txt_path, args.max_txt_len),
                       all_img_dbs[img_path],
                       args.num_hard_negatives, args.img_meta, args.tokenizer)
        for txt_path, img_path in zip(txt_dbs, img_dbs)
    ]
    return ConcatDataset(splits)
|