""" sampler for length bucketing (batch by tokens) """ import math import random import torch import horovod.torch as hvd from torch.utils.data import Sampler from cytoolz import partition_all class TokenBucketSampler(Sampler): def __init__(self, lens, bucket_size, batch_size, droplast=False, size_multiple=8): self._lens = lens self._max_tok = batch_size self._bucket_size = bucket_size self._droplast = droplast self._size_mul = size_multiple def _create_ids(self): return list(range(len(self._lens))) def _sort_fn(self, i): return self._lens[i] def __iter__(self): ids = self._create_ids() random.shuffle(ids) buckets = [sorted(ids[i:i+self._bucket_size], key=self._sort_fn, reverse=True) for i in range(0, len(ids), self._bucket_size)] # fill batches until max_token (include padding) batches = [] for bucket in buckets: max_len = 0 batch_indices = [] for indices in partition_all(self._size_mul, bucket): max_len = max(max_len, max(self._lens[i] for i in indices)) if (max_len * (len(batch_indices) + self._size_mul) > self._max_tok): if not batch_indices: raise ValueError( "max_tokens too small / max_seq_len too long") assert len(batch_indices) % self._size_mul == 0 batches.append(batch_indices) batch_indices = list(indices) else: batch_indices.extend(indices) if not self._droplast and batch_indices: batches.append(batch_indices) random.shuffle(batches) return iter(batches) def __len__(self): raise ValueError("NOT supported. " "This has some randomness across epochs") class DistributedSampler(Sampler): """Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it. .. note:: Dataset is assumed to be of constant size. Arguments: dataset: Dataset used for sampling. num_replicas (optional): Number of processes participating in distributed training. rank (optional): Rank of the current process within num_replicas. shuffle (optional): If true (default), sampler will shuffle the indices """ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): if num_replicas is None: num_replicas = hvd.size() if rank is None: rank = hvd.rank() self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = 0 self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle def __iter__(self): # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) indices = list(range(len(self.dataset))) # add extra samples to make it evenly divisible indices += indices[:(self.total_size - len(indices))] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] if self.shuffle: shufle_ind = torch.randperm(len(indices), generator=g).tolist() indices = [indices[i] for i in shufle_ind] assert len(indices) == self.num_samples return iter(indices) def __len__(self): return self.num_samples def set_epoch(self, epoch): self.epoch = epoch