camel
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
127 lines
4.3 KiB
127 lines
4.3 KiB
# coding: utf8
|
|
from itertools import takewhile
|
|
|
|
import torch
|
|
from torch.utils.data.dataloader import default_collate
|
|
from torchvision.datasets.folder import default_loader
|
|
|
|
from .tokenizer.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
|
|
|
|
|
class RawField(object):
    """Generic, type-agnostic description of one kind of data in a dataset.

    A dataset is made of one or more kinds of data: a machine translation
    dataset holds paired texts, while an image captioning dataset holds
    images and texts. Each kind is represented by a RawField. A RawField
    makes no assumption about the data itself; it only stores the hooks
    that say how that data should be processed.

    Attributes:
        preprocessing: Optional callable applied to a single raw value
            before an example is created. Default: None.
        postprocessing: Optional callable applied to the list of values of
            a batch before they are collated.
            Function signature: (batch(list)) -> object
            Default: None.
    """

    def __init__(self, preprocessing=None, postprocessing=None):
        self.preprocessing, self.postprocessing = preprocessing, postprocessing

    def preprocess(self, x):
        """Return `x` passed through `preprocessing`, or `x` itself if unset."""
        hook = self.preprocessing
        return x if hook is None else hook(x)

    def process(self, batch, *args, **kwargs):
        """Turn a list of examples into a collated batch.

        The optional `postprocessing` hook runs first; its result (or the
        untouched batch when no hook is set) is handed to `default_collate`.

        Args:
            batch (list(object)): Values taken from a batch of examples.
            *args, **kwargs: Accepted but not used by this implementation.

        Returns:
            object: Whatever `default_collate` produces for the (optionally
                postprocessed) batch.
        """
        hook = self.postprocessing
        collatable = batch if hook is None else hook(batch)
        return default_collate(collatable)
|
|
|
|
|
|
class Merge(RawField):
    """Composite field that runs several wrapped fields over the same data."""

    def __init__(self, *fields):
        super(Merge, self).__init__()
        self.fields = fields

    def preprocess(self, x):
        """Return a tuple holding `x` preprocessed by each wrapped field, in order."""
        return tuple(field.preprocess(x) for field in self.fields)

    def process(self, batch, *args, **kwargs):
        """Process `batch` with every wrapped field and return the results as a list.

        Extra positional and keyword arguments are forwarded unchanged to
        each wrapped field's `process`.
        """
        # With a single wrapped field the batch is handed over as-is; with
        # several, zip(*batch) regroups it so each field gets its own column.
        columns = [batch] if len(self.fields) == 1 else list(zip(*batch))
        return [
            field.process(column, *args, **kwargs)
            for field, column in zip(self.fields, columns)
        ]
|
|
|
|
|
|
class ImageField(RawField):
    """Field that loads a sample through `loader` and optionally transforms it.

    Attributes:
        loader: Callable invoked on the raw value to obtain the sample.
            Defaults to torchvision's `default_loader`.
        transform: Optional callable applied to the loaded sample.
    """

    def __init__(self, preprocessing=None, postprocessing=None, loader=default_loader, transform=None):
        super().__init__(preprocessing, postprocessing)
        self.loader = loader
        self.transform = transform

    def preprocess(self, x):
        """Load `x` with `loader` and apply `transform` when one is set.

        Note that this override does not call the `preprocessing` hook
        stored by the base class.
        """
        loaded = self.loader(x)
        if self.transform is None:
            return loaded
        return self.transform(loaded)
|
|
|
|
|
|
class TextField(RawField):
    """Field that converts raw strings to padded token-id tensors and back."""

    def __init__(self):
        self._tokenizer = _Tokenizer()
        super(TextField, self).__init__()

    def preprocess(self, x):
        """Return `x` unchanged, substituting an empty string for None."""
        return '' if x is None else x

    def process(self, texts, *args, **kwargs):
        """Tokenize `texts` into a single right-padded LongTensor.

        Each text is encoded with the tokenizer and wrapped with the
        tokenizer's `bos_idx` / `eos_idx`. Rows shorter than the longest
        sequence are left filled with zeros.

        Args:
            texts: A single string or an iterable of strings.
            *args, **kwargs: Accepted and ignored, so the signature matches
                `RawField.process` and the arguments `Merge.process` forwards.

        Returns:
            torch.Tensor: Tensor of dtype long with shape
                (number of texts, longest token sequence). An empty input
                yields an empty (0, 0) tensor.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.bos_idx
        eot_token = self._tokenizer.eos_idx
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]

        # default=0 keeps max() from raising ValueError on an empty batch.
        longest = max((len(tokens) for tokens in all_tokens), default=0)
        result = torch.zeros(len(all_tokens), longest, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result

    def decode(self, word_idxs):
        """Decode token ids back into text, stopping at the first `eos_idx`.

        Args:
            word_idxs: A flat list of ints or a 1-D tensor (one sequence),
                or a list of sequences / 2-D tensor (a batch).

        Returns:
            A single decoded caption for one-sequence input, otherwise a
            list with one decoded caption per sequence.
        """
        # A flat (or empty) list and a 1-D tensor each describe one sequence:
        # wrap it as a batch of one and unwrap the single result.
        if isinstance(word_idxs, list) and (len(word_idxs) == 0 or isinstance(word_idxs[0], int)):
            return self.decode([word_idxs, ])[0]
        if isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
            return self.decode(word_idxs.unsqueeze(0))[0]

        eos_idx = self._tokenizer.eos_idx
        captions = []
        for wis in word_idxs:
            # Rows may be tensors or plain Python sequences; only objects
            # exposing .tolist() are converted, since a plain list has no
            # such method and would otherwise raise AttributeError.
            wis = wis.tolist() if hasattr(wis, 'tolist') else list(wis)
            wis = list(takewhile(lambda tok: tok != eos_idx, wis))
            captions.append(self._tokenizer.decode(wis))
        return captions
|