""" sampler for length bucketing (batch by tokens) """
import math
import random
import torch
import horovod.torch as hvd
from torch.utils.data import Sampler
from cytoolz import partition_all


class TokenBucketSampler(Sampler):
    def __init__(self, lens, bucket_size, batch_size,
                 droplast=False, size_multiple=8):
        self._lens = lens
        self._max_tok = batch_size
        self._bucket_size = bucket_size
        self._droplast = droplast
        self._size_mul = size_multiple

    def _create_ids(self):
        return list(range(len(self._lens)))

    def _sort_fn(self, i):
        return self._lens[i]

    def __iter__(self):
        ids = self._create_ids()
        random.shuffle(ids)
        # group shuffled ids into buckets and sort each bucket by length
        # (longest first) so that similarly sized examples end up together
        buckets = [sorted(ids[i:i+self._bucket_size],
                          key=self._sort_fn, reverse=True)
                   for i in range(0, len(ids), self._bucket_size)]
        # fill batches until max_token (include padding)
        batches = []
        for bucket in buckets:
            max_len = 0
            batch_indices = []
            for indices in partition_all(self._size_mul, bucket):
                max_len = max(max_len, max(self._lens[i] for i in indices))
                if (max_len * (len(batch_indices) + self._size_mul)
                        > self._max_tok):
                    if not batch_indices:
                        raise ValueError(
                            "max_tokens too small / max_seq_len too long")
                    assert len(batch_indices) % self._size_mul == 0
                    batches.append(batch_indices)
                    batch_indices = list(indices)
                else:
                    batch_indices.extend(indices)
            if not self._droplast and batch_indices:
                batches.append(batch_indices)
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        raise ValueError("NOT supported. "
                         "This has some randomness across epochs")


class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            num_replicas = hvd.size()
        if rank is None:
            rank = hvd.rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset)
                                         * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(range(len(self.dataset)))

        # add extra samples to make the split evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample the shard that belongs to this rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.shuffle:
            shuffle_ind = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in shuffle_ind]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
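

# Minimal smoke test / usage sketch (an assumption, not part of the original
# file): exercises both samplers on fake data. DistributedSampler is given an
# explicit num_replicas/rank so this runs without calling hvd.init().
if __name__ == "__main__":
    fake_lens = [random.randint(5, 60) for _ in range(1000)]
    bucket_sampler = TokenBucketSampler(fake_lens, bucket_size=256,
                                        batch_size=1024, droplast=True)
    for batch in bucket_sampler:
        # each yielded batch is a list of example indices whose padded size
        # (longest example in the batch times batch length) fits the budget
        assert max(fake_lens[i] for i in batch) * len(batch) <= 1024

    dist_sampler = DistributedSampler(list(range(10)), num_replicas=2, rank=0)
    dist_sampler.set_epoch(1)
    print(list(iter(dist_sampler)))  # the 5 indices assigned to rank 0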