logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

303 lines
11 KiB

import collections
import collections.abc
import itertools
import os

import numpy as np
import torch
from pycocotools.coco import COCO as pyCOCO
from torch.utils.data import Dataset as PthDataset

from utils import nostdout
from .example import Example
class Dataset(PthDataset):
    """Map-style dataset over example objects, producing one value per named field.

    Args:
        examples: sequence of example objects; each field value is read with
            ``getattr(example, field_name, None)``.
        fields: mapping (or iterable of ``(name, field)`` pairs) from field name
            to a field object exposing ``preprocess(value)`` and ``process(batch)``.
    """

    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)

    def collate_fn(self):
        """Build a ``collate(batch)`` callable (e.g. for ``DataLoader(collate_fn=...)``).

        The callable transposes ``batch`` (a list of ``__getitem__`` results)
        into one column per field and runs each column through ``field.process``.
        A field output that is a sequence containing tensors is flattened into
        the result. The callable returns the single remaining output when there
        is exactly one, otherwise the list of outputs.
        """
        def collate(batch):
            if len(self.fields) == 1:
                # __getitem__ returns the bare value for single-field datasets,
                # so the whole batch is already the one column.
                batch = [batch, ]
            else:
                batch = list(zip(*batch))

            tensors = []
            for field, data in zip(self.fields.values(), batch):
                tensor = field.process(data)
                # The Sequence ABC lives in collections.abc; the old
                # collections.Sequence alias was removed in Python 3.10.
                if isinstance(tensor, collections.abc.Sequence) \
                        and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else:
                    tensors.append(tensor)

            if len(tensors) > 1:
                return tensors
            else:
                return tensors[0]

        return collate

    def __getitem__(self, i):
        """Return the preprocessed field values of example ``i``.

        Returns a list with one entry per field, in ``self.fields`` order, or
        the bare value when the dataset has a single field. Missing attributes
        are passed to ``preprocess`` as ``None``.
        """
        example = self.examples[i]
        data = [field.preprocess(getattr(example, field_name, None))
                for field_name, field in self.fields.items()]
        if len(data) == 1:
            data = data[0]
        return data

    def __len__(self):
        """Return the number of examples."""
        return len(self.examples)
class ValueDataset(Dataset):
    """Dataset whose item ``i`` is the list of examples grouped under index ``i``.

    Args:
        examples: the flat sequence of value examples.
        fields: field mapping, as for ``Dataset``.
        dictionary: mapping from item index to the list of indices into
            ``examples`` that belong to that item.
    """

    def __init__(self, examples, fields, dictionary):
        self.dictionary = dictionary
        super(ValueDataset, self).__init__(examples, fields)

    def collate_fn(self):
        """Build a collate callable for batches of variable-length groups.

        All groups in the batch are flattened and collated with the parent
        ``Dataset`` collate function. The result is then sliced back into one
        chunk per group using the cumulative group lengths. When the parent
        returns a sequence of tensors, each tensor is sliced per group.
        """
        def collate(batch):
            value_batch_flattened = list(itertools.chain(*batch))
            # Explicit super() arguments: zero-argument super() here would bind
            # to ``batch`` (this nested function's first argument), not ``self``.
            value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
            lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
            bounds = list(zip(lengths[:-1], lengths[1:]))
            # collections.Sequence was removed in Python 3.10; use collections.abc.
            if isinstance(value_tensors_flattened, collections.abc.Sequence) \
                    and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
                value_tensors = [[vt[s:e] for (s, e) in bounds]
                                 for vt in value_tensors_flattened]
            else:
                value_tensors = [value_tensors_flattened[s:e] for (s, e) in bounds]
            return value_tensors

        return collate

    def __getitem__(self, i):
        """Return the preprocessed examples grouped under index ``i``.

        Raises:
            IndexError: if ``i`` is not a key of ``self.dictionary``.
        """
        # A membership test (rather than indexing) avoids growing a defaultdict,
        # which is what DictionaryDataset passes in, on a missing key.
        if i not in self.dictionary:
            raise IndexError(i)
        values_data = []
        for idx in self.dictionary[i]:
            values_data.append(super(ValueDataset, self).__getitem__(idx))
        return values_data

    def __len__(self):
        """Return the number of groups (keys in ``self.dictionary``)."""
        return len(self.dictionary)
class DictionaryDataset(Dataset):
    """Dataset that groups examples by the values of one or more key fields.

    Item ``i`` is a pair ``(key, values)``. ``key`` comes from a ``Dataset``
    over the distinct key examples, in first-seen order. ``values`` comes from
    a ``ValueDataset`` over the examples whose key fields matched that key.

    Args:
        examples: sequence of example objects.
        fields: mapping from field name to field object.
        key_fields: a field name, or a tuple/list of field names, to group by.
            Each name must be a key of ``fields`` (checked with ``assert``).
    """

    def __init__(self, examples, fields, key_fields):
        names = key_fields if isinstance(key_fields, (tuple, list)) else (key_fields,)
        for name in names:
            assert (name in fields)

        key_fields = {name: fields[name] for name in names}
        value_fields = {name: fields[name] for name in fields.keys() if name not in key_fields}

        # NOTE(review): identical keys only collapse if Example instances are
        # hashable with value-based equality — confirm in .example.
        dictionary = collections.defaultdict(list)
        key_position = dict()
        key_examples = []
        value_examples = []
        for position, example in enumerate(examples):
            key_example = Example.fromdict({name: getattr(example, name) for name in key_fields})
            value_example = Example.fromdict({name: getattr(example, name) for name in value_fields})
            if key_example in key_position:
                slot = key_position[key_example]
            else:
                slot = len(key_examples)
                key_position[key_example] = slot
                key_examples.append(key_example)
            value_examples.append(value_example)
            # value_examples is index-aligned with examples, so ``position``
            # also indexes the value dataset.
            dictionary[slot].append(position)

        self.key_dataset = Dataset(key_examples, key_fields)
        self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
        super(DictionaryDataset, self).__init__(examples, fields)

    def collate_fn(self):
        """Build a collate callable returning ``(key_tensors, value_tensors)``."""
        def collate(batch):
            keys, values = list(zip(*batch))
            return (self.key_dataset.collate_fn()(keys),
                    self.value_dataset.collate_fn()(values))

        return collate

    def __getitem__(self, i):
        """Return ``(key item, list of grouped value items)`` for group ``i``."""
        key_item = self.key_dataset[i]
        value_item = self.value_dataset[i]
        return key_item, value_item

    def __len__(self):
        """Return the number of distinct keys."""
        return len(self.key_dataset)
def unique(sequence):
    """Return the items of ``sequence`` with duplicates removed, in first-seen order.

    List items are unhashable, so each list is compared by its ``tuple`` form.
    Any other item is compared directly and must be hashable. The original
    items, not the comparison keys, are returned.

    Args:
        sequence: any iterable of items.

    Returns:
        A new list, which is empty when ``sequence`` is empty.
    """
    seen = set()
    result = []
    for item in sequence:
        # Choose the key per item, not from sequence[0] only. This lets empty
        # input, generators and mixed list/scalar inputs work.
        key = tuple(item) if isinstance(item, list) else item
        if key not in seen:
            seen.add(key)
            result.append(item)
    return result
# class ImageDataset(Dataset):
# def __init__(self, examples, image_field):
# super().__init__(examples, {'image': image_field})
#
#
# class COCO2017Unlabeled(ImageDataset):
# def __init__(self, image_field, img_root, load_in_tmp=False):
# tmp_path = os.path.join('/tmp/coco2017unlabeled')
# if load_in_tmp and sync_path(img_root, tmp_path, size=19*1024**3):
# img_root = tmp_path
#
# examples = [Example.fromdict({'image': os.path.join(img_root, f)}) for f in os.listdir(img_root)]
# super().__init__(examples, image_field)
#
#
# class Nocaps(ImageDataset):
# def __init__(self, image_field, img_root, load_in_tmp=False):
# tmp_path = os.path.join('/tmp/nocaps')
# if load_in_tmp and sync_path(img_root, tmp_path, size=3.6*1024**3):
# img_root = tmp_path
#
# examples = []
# for split in ('validation', 'test'):
# examples += [Example.fromdict({'image': os.path.join(img_root, split, f)}) for f in os.listdir(img_root + '/' + split)]
# super().__init__(examples, image_field)
class PairedDataset(Dataset):
    """Dataset of (image, text) examples with helpers for deduplicated views.

    Args:
        examples: sequence of examples exposing ``image`` and ``text`` attributes.
        image_field: field object used for the 'image' entry.
        text_field: field object used for the 'text' entry.
    """

    def __init__(self, examples, image_field, text_field):
        super(PairedDataset, self).__init__(examples, {'image': image_field, 'text': text_field})
        self.image_field = self.fields['image']
        self.text_field = self.fields['text']

    def image_set(self):
        """Return a single-field ``Dataset`` of the distinct images (first-seen order)."""
        distinct_images = unique([example.image for example in self.examples])
        return Dataset([Example.fromdict({'image': img}) for img in distinct_images],
                       {'image': self.image_field})

    def text_set(self):
        """Return a single-field ``Dataset`` of the distinct texts (first-seen order)."""
        distinct_texts = unique([example.text for example in self.examples])
        return Dataset([Example.fromdict({'text': txt}) for txt in distinct_texts],
                       {'text': self.text_field})

    def image_dictionary(self, fields=None):
        """Group examples by image; falsy ``fields`` means use ``self.fields``."""
        return DictionaryDataset(self.examples, fields or self.fields, key_fields='image')

    def text_dictionary(self, fields=None):
        """Group examples by text; falsy ``fields`` means use ``self.fields``."""
        return DictionaryDataset(self.examples, fields or self.fields, key_fields='text')

    @property
    def splits(self):
        """Subclasses return their split datasets; the base class has none."""
        raise NotImplementedError
class COCO(PairedDataset):
    """MS-COCO 2014 captions as a ``PairedDataset`` with train/val/test splits.

    Args:
        image_field: field object for the 'image' entry (an image file path).
        text_field: field object for the 'text' entry (a caption string).
        img_root: directory containing the 'train2014' and 'val2014' image folders.
        ann_root: directory containing 'captions_train2014.json' and
            'captions_val2014.json'.
        id_root: optional directory with the annotation-id split files
            'coco_train_ids.npy', 'coco_dev_ids.npy', 'coco_test_ids.npy' and
            'coco_restval_ids.npy'. When None, every annotation in each split's
            caption file is used, so val and test both read all of val2014.
        use_restval: only applies with ``id_root``. It appends the restval ids,
            which index into val2014, to the train split.
        cut_validation: only applies with ``id_root``. It keeps the first 5000
            val ids.
    """

    def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
                 cut_validation=False):
        roots = {}
        roots['train'] = {
            'img': os.path.join(img_root, 'train2014'),
            'cap': os.path.join(ann_root, 'captions_train2014.json')
        }
        roots['val'] = {
            'img': os.path.join(img_root, 'val2014'),
            'cap': os.path.join(ann_root, 'captions_val2014.json')
        }
        roots['test'] = {
            'img': os.path.join(img_root, 'val2014'),
            'cap': os.path.join(ann_root, 'captions_val2014.json')
        }
        # Two sources, train2014 first and val2014 second. The restval ids below
        # pair position-wise with the second source.
        roots['trainrestval'] = {
            'img': (roots['train']['img'], roots['val']['img']),
            'cap': (roots['train']['cap'], roots['val']['cap'])
        }

        if id_root is not None:
            ids = {}
            ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy'))
            ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy'))
            if cut_validation:
                ids['val'] = ids['val'][:5000]
            ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy'))
            ids['trainrestval'] = (
                ids['train'],
                np.load(os.path.join(id_root, 'coco_restval_ids.npy')))

            if use_restval:
                roots['train'] = roots['trainrestval']
                ids['train'] = ids['trainrestval']
        else:
            ids = None

        # nostdout presumably silences pycocotools' loading messages — confirm
        # against utils.nostdout.
        with nostdout():
            self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids)
        examples = self.train_examples + self.val_examples + self.test_examples
        super(COCO, self).__init__(examples, image_field, text_field)

    @property
    def splits(self):
        """Return new ``PairedDataset`` objects for train, val and test sharing this dataset's fields."""
        train_split = PairedDataset(self.train_examples, self.image_field, self.text_field)
        val_split = PairedDataset(self.val_examples, self.image_field, self.text_field)
        test_split = PairedDataset(self.test_examples, self.image_field, self.text_field)
        return train_split, val_split, test_split

    @classmethod
    def get_samples(cls, roots, ids_dataset=None):
        """Build the train/val/test example lists from COCO caption files.

        Args:
            roots: mapping from split name ('train', 'val', 'test') to a dict
                with 'img' (image directory) and 'cap' (caption JSON path). The
                two entries are either both single paths or both tuples with
                one path per source.
            ids_dataset: optional mapping from split name to annotation ids. A
                tuple value holds one id sequence per source of that split. A
                non-tuple value is read from the first source only. When None,
                every annotation of every source is used.

        Returns:
            ``(train_samples, val_samples, test_samples)``: lists of Examples
            with 'image' (the image path under its source directory) and
            'text' (the caption).

        Raises:
            ValueError: if a split has more id groups than caption sources.
        """
        samples = {'train': [], 'val': [], 'test': []}
        for split, split_samples in samples.items():
            cap = roots[split]['cap']
            if isinstance(cap, tuple):
                coco_sources = tuple(pyCOCO(path) for path in cap)
                img_roots = roots[split]['img']
            else:
                coco_sources = (pyCOCO(cap),)
                img_roots = (roots[split]['img'],)

            if ids_dataset is None:
                # Collect ids from each loaded file. The old code read ``.anns``
                # off the tuple itself and raised AttributeError.
                ids_per_source = [list(coco.anns.keys()) for coco in coco_sources]
            else:
                split_ids = ids_dataset[split]
                if isinstance(split_ids, tuple):
                    ids_per_source = [list(part) for part in split_ids]
                else:
                    ids_per_source = [list(split_ids)]

            if len(ids_per_source) > len(coco_sources):
                raise ValueError('split %r has %d id groups but only %d caption files'
                                 % (split, len(ids_per_source), len(coco_sources)))

            # Each id group is resolved against its own caption file and image root.
            for coco, img_root, ann_ids in zip(coco_sources, img_roots, ids_per_source):
                for ann_id in ann_ids:
                    ann = coco.anns[ann_id]
                    filename = coco.loadImgs(ann['image_id'])[0]['file_name']
                    split_samples.append(Example.fromdict(
                        {'image': os.path.join(img_root, filename), 'text': ann['caption']}))

        return samples['train'], samples['val'], samples['test']