import collections import itertools import os import numpy as np import torch from pycocotools.coco import COCO as pyCOCO from torch.utils.data import Dataset as PthDataset from utils import nostdout from .example import Example class Dataset(PthDataset): def __init__(self, examples, fields): self.examples = examples self.fields = dict(fields) def collate_fn(self): def collate(batch): if len(self.fields) == 1: batch = [batch, ] else: batch = list(zip(*batch)) tensors = [] for field, data in zip(self.fields.values(), batch): tensor = field.process(data) if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): tensors.extend(tensor) else: tensors.append(tensor) if len(tensors) > 1: return tensors else: return tensors[0] return collate def __getitem__(self, i): example = self.examples[i] data = [] for field_name, field in self.fields.items(): data.append(field.preprocess(getattr(example, field_name, None))) if len(data) == 1: data = data[0] return data def __len__(self): return len(self.examples) class ValueDataset(Dataset): def __init__(self, examples, fields, dictionary): self.dictionary = dictionary super(ValueDataset, self).__init__(examples, fields) def collate_fn(self): def collate(batch): value_batch_flattened = list(itertools.chain(*batch)) value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) if isinstance(value_tensors_flattened, collections.Sequence) \ and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] else: value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] return value_tensors return collate def __getitem__(self, i): if i not in self.dictionary: raise IndexError values_data = [] for idx in self.dictionary[i]: value_data = super(ValueDataset, self).__getitem__(idx) values_data.append(value_data) return values_data def __len__(self): return len(self.dictionary) class DictionaryDataset(Dataset): def __init__(self, examples, fields, key_fields): if not isinstance(key_fields, (tuple, list)): key_fields = (key_fields,) for field in key_fields: assert (field in fields) dictionary = collections.defaultdict(list) key_fields = {k: fields[k] for k in key_fields} value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} key_examples = [] key_dict = dict() value_examples = [] for i, e in enumerate(examples): key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) if key_example not in key_dict: key_dict[key_example] = len(key_examples) key_examples.append(key_example) value_examples.append(value_example) dictionary[key_dict[key_example]].append(i) self.key_dataset = Dataset(key_examples, key_fields) self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) super(DictionaryDataset, self).__init__(examples, fields) def collate_fn(self): def collate(batch): key_batch, value_batch = list(zip(*batch)) key_tensors = self.key_dataset.collate_fn()(key_batch) value_tensors = self.value_dataset.collate_fn()(value_batch) return key_tensors, value_tensors return collate def __getitem__(self, i): return self.key_dataset[i], self.value_dataset[i] def __len__(self): return len(self.key_dataset) def unique(sequence): seen = set() if isinstance(sequence[0], list): return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] else: return [x for x in sequence if not (x in seen or seen.add(x))] # class ImageDataset(Dataset): # def __init__(self, examples, image_field): # super().__init__(examples, {'image': image_field}) # # # class COCO2017Unlabeled(ImageDataset): # def __init__(self, image_field, img_root, load_in_tmp=False): # tmp_path = os.path.join('/tmp/coco2017unlabeled') # if load_in_tmp and sync_path(img_root, tmp_path, size=19*1024**3): # img_root = tmp_path # # examples = [Example.fromdict({'image': os.path.join(img_root, f)}) for f in os.listdir(img_root)] # super().__init__(examples, image_field) # # # class Nocaps(ImageDataset): # def __init__(self, image_field, img_root, load_in_tmp=False): # tmp_path = os.path.join('/tmp/nocaps') # if load_in_tmp and sync_path(img_root, tmp_path, size=3.6*1024**3): # img_root = tmp_path # # examples = [] # for split in ('validation', 'test'): # examples += [Example.fromdict({'image': os.path.join(img_root, split, f)}) for f in os.listdir(img_root + '/' + split)] # super().__init__(examples, image_field) class PairedDataset(Dataset): def __init__(self, examples, image_field, text_field): super(PairedDataset, self).__init__(examples, {'image': image_field, 'text': text_field}) self.image_field = self.fields['image'] self.text_field = self.fields['text'] def image_set(self): img_list = [e.image for e in self.examples] image_set = unique(img_list) examples = [Example.fromdict({'image': i}) for i in image_set] dataset = Dataset(examples, {'image': self.image_field}) return dataset def text_set(self): text_list = [e.text for e in self.examples] text_list = unique(text_list) examples = [Example.fromdict({'text': t}) for t in text_list] dataset = Dataset(examples, {'text': self.text_field}) return dataset def image_dictionary(self, fields=None): if not fields: fields = self.fields dataset = DictionaryDataset(self.examples, fields, key_fields='image') return dataset def text_dictionary(self, fields=None): if not fields: fields = self.fields dataset = DictionaryDataset(self.examples, fields, key_fields='text') return dataset @property def splits(self): raise NotImplementedError class COCO(PairedDataset): def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, cut_validation=False): roots = {} roots['train'] = { 'img': os.path.join(img_root, 'train2014'), 'cap': os.path.join(ann_root, 'captions_train2014.json') } roots['val'] = { 'img': os.path.join(img_root, 'val2014'), 'cap': os.path.join(ann_root, 'captions_val2014.json') } roots['test'] = { 'img': os.path.join(img_root, 'val2014'), 'cap': os.path.join(ann_root, 'captions_val2014.json') } roots['trainrestval'] = { 'img': (roots['train']['img'], roots['val']['img']), 'cap': (roots['train']['cap'], roots['val']['cap']) } if id_root is not None: ids = {} ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) if cut_validation: ids['val'] = ids['val'][:5000] ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) ids['trainrestval'] = ( ids['train'], np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) if use_restval: roots['train'] = roots['trainrestval'] ids['train'] = ids['trainrestval'] else: ids = None with nostdout(): self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) examples = self.train_examples + self.val_examples + self.test_examples super(COCO, self).__init__(examples, image_field, text_field) @property def splits(self): train_split = PairedDataset(self.train_examples, self.image_field, self.text_field) val_split = PairedDataset(self.val_examples, self.image_field, self.text_field) test_split = PairedDataset(self.test_examples, self.image_field, self.text_field) return train_split, val_split, test_split @classmethod def get_samples(cls, roots, ids_dataset=None): train_samples = [] val_samples = [] test_samples = [] for split in ['train', 'val', 'test']: if isinstance(roots[split]['cap'], tuple): coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1])) root = roots[split]['img'] else: coco_dataset = (pyCOCO(roots[split]['cap']),) root = (roots[split]['img'],) if ids_dataset is None: ids = list(coco_dataset.anns.keys()) else: ids = ids_dataset[split] if isinstance(ids, tuple): bp = len(ids[0]) ids = list(ids[0]) + list(ids[1]) else: bp = len(ids) for index in range(len(ids)): if index < bp: coco = coco_dataset[0] img_root = root[0] else: coco = coco_dataset[1] img_root = root[1] ann_id = ids[index] caption = coco.anns[ann_id]['caption'] img_id = coco.anns[ann_id]['image_id'] filename = coco.loadImgs(img_id)[0]['file_name'] example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption}) if split == 'train': train_samples.append(example) elif split == 'val': val_samples.append(example) elif split == 'test': test_samples.append(example) return train_samples, val_samples, test_samples