"""
minimal running script of distributed training
"""
import argparse
import os
import random
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch import optim
# communication operations
from utils.distributed import all_reduce_and_rescale_tensors, all_gather_list
class DataLoader(object):
    """Yields random (padded input, target) batches indefinitely."""
    def __init__(self, vocab_size, n_class, batch_size=8, lengths=(5, 10)):
        self.vsize = vocab_size
        self.ncls = n_class
        self.bs = batch_size
        self.lengths = lengths

    def __iter__(self):
        while True:
            input_, target = self._random_batch()
            yield input_, target

    def _random_batch(self):
        inputs = []
        targets = []
        for _ in range(self.bs):
            i, t = self._random_inputs()
            inputs.append(i)
            targets.append(t)
        # pad to the longest sequence in the batch; shape: [max_len, batch_size]
        input_ = pad_sequence(inputs)
        targets = torch.LongTensor(targets)
        return input_, targets

    def _random_inputs(self):
        # random token ids of random length, plus a random class label
        len_ = random.randint(*self.lengths)
        inputs = [random.randint(0, self.vsize - 1) for _ in range(len_)]
        target = random.randint(0, self.ncls - 1)
        return torch.LongTensor(inputs), target
class Model(nn.Module):
    def __init__(self, vsize, ncls):
        super().__init__()
        self.emb = nn.Embedding(vsize, 100)
        self.rnn = nn.LSTM(100, 100, 1)
        self.proj = nn.Linear(100, ncls)

    def forward(self, input_):
        emb_out = self.emb(input_)
        _, (h, c) = self.rnn(emb_out)
        # classify from the last layer's final hidden state
        output = self.proj(h[-1])
        return output
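# Illustrative usage sketch (not part of the original script): pad_sequence
# defaults to sequence-first output ([max_len, batch]) and nn.LSTM also
# defaults to sequence-first input, so a DataLoader batch feeds Model directly:
#
#   loader = DataLoader(vocab_size=200, n_class=10)
#   input_, target = next(iter(loader))    # input_: [max_len, 8] LongTensor
#   logits = Model(200, 10)(input_)        # logits: [8, 10]
#   loss = F.cross_entropy(logits, target)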
class InputExample(object):
    def __init__(self, input, target):
        self.input = input
        self.target = target
def main():
    vsize = 200
    ncls = 10
    accum = 4
    total_step = 100
    seed = 777

    random.seed(seed)
    torch.manual_seed(seed)

    global_step = 0
    loader = DataLoader(vsize, ncls)
    examples = []
    print("generating examples")
    for step, (input_, target) in enumerate(loader):
        print("example appended " + str(step))
        examples.append(InputExample(input=input_, target=target))
        global_step += 1
        if global_step >= total_step:
            break

    print("saving with torch.save")
    # make sure the output directory exists before saving
    os.makedirs('data/test_data', exist_ok=True)
    torch.save(examples, 'data/test_data/input0.txt')
    # load back the same file that was just saved
    examples = torch.load('data/test_data/input0.txt')
    for step, ie in enumerate(examples):
        print(step)
        print(ie.input)
        print(ie.target)


if __name__ == '__main__':
    main()
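
# Hedged sketch (assumption, not part of the original file): how the imported
# distributed helpers are typically wired into a training loop. It assumes
# all_reduce_and_rescale_tensors(tensors, rescale_denom) averages gradients
# across workers, as in OpenNMT-style utils.distributed implementations;
# adjust the call to match the actual API in utils/distributed.py.
def _train_step_sketch(model, loader, total_step=100, accum=4, n_gpu=1):
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    for step, (input_, target) in enumerate(loader):
        loss = F.cross_entropy(model(input_), target)
        # accumulate gradients over `accum` micro-batches before stepping
        (loss / accum).backward()
        if (step + 1) % accum == 0:
            if n_gpu > 1:
                grads = [p.grad.data for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                all_reduce_and_rescale_tensors(grads, float(n_gpu))
            optimizer.step()
            optimizer.zero_grad()
        if step + 1 >= total_step:
            break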