lightningdot
231 lines
6.9 KiB
"""
distributed API using Horovod
"""
import math
import pickle

import torch
from horovod import torch as hvd

import msgpack
import msgpack_numpy
msgpack_numpy.patch()

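# Note (not part of the original file): these helpers assume Horovod has been
# initialized and that each process has pinned its own GPU, e.g.
#
#     hvd.init()
#     torch.cuda.set_device(hvd.local_rank())
#
# The serialization helpers below allocate torch.cuda.ByteTensor buffers, so a
# CUDA device is required.
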
def all_reduce_and_rescale_tensors(tensors, rescale_denom):
    """All-reduce and rescale tensors at once (as a flattened tensor)

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
    """
    # flatten all tensors into a single buffer so only one all-reduce is issued
    sz = sum(t.numel() for t in tensors)
    buffer_t = tensors[0].new(sz).zero_()

    # copy tensors into buffer_t
    offset = 0
    for t in tensors:
        numel = t.numel()
        buffer_t[offset:offset+numel].copy_(t.view(-1))
        offset += numel

    # all-reduce and rescale
    hvd.allreduce_(buffer_t[:offset])
    buffer_t.div_(rescale_denom)

    # copy all-reduced buffer back into tensors
    offset = 0
    for t in tensors:
        numel = t.numel()
        t.view(-1).copy_(buffer_t[offset:offset+numel])
        offset += numel

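# Illustrative usage of all_reduce_and_rescale_tensors (an assumption, not from
# the original file): after loss.backward(), the gradients of a hypothetical
# `model` could be synchronized in one flattened all-reduce:
#
#     grads = [p.grad.data for p in model.parameters() if p.grad is not None]
#     all_reduce_and_rescale_tensors(grads, rescale_denom=1)
#
# hvd.allreduce_ averages across workers by default, so rescale_denom is
# typically a local factor such as the number of gradient-accumulation steps
# rather than the world size.
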
def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,
                                           buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # buffer size in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale
        hvd.allreduce_(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            hvd.allreduce_(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        all_reduce_buffer()

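# Illustrative usage of the chunked variant (assumption): same call pattern as
# the sketch above with the hypothetical `grads` list, but tensors are packed
# into ~10 MB chunks so the temporary flat buffer stays small; tensors larger
# than the buffer are all-reduced individually.
#
#     all_reduce_and_rescale_tensors_chunked(grads, rescale_denom=1,
#                                            buffer_size=10 * 1024 * 1024)
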
def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
    """Broadcast tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to broadcast
        root_rank: rank to broadcast from
        buffer_size: broadcast chunk size in bytes
    """
    # buffer size in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def broadcast_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # broadcast
        hvd.broadcast_(buffer_t[:offset], root_rank)

        # copy broadcast buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, broadcast directly
            hvd.broadcast_(t, root_rank)
        elif filled + sz > buffer_size:
            # buffer is full, broadcast and replace buffer with tensor
            broadcast_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        broadcast_buffer()

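# Illustrative usage of broadcast_tensors (assumption): copy the parameters of
# a hypothetical `model` from rank 0 so every worker starts from identical
# weights.
#
#     params = [p.data for p in model.parameters()]
#     broadcast_tensors(params, root_rank=0)
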
def _encode(enc, max_size, buffer_=None):
    # write len(enc) as a big-endian base-256 header, then the payload bytes
    enc_size = len(enc)
    enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
    if buffer_ is None or len(buffer_) < enc_size + enc_byte:
        buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
    remainder = enc_size
    for i in range(enc_byte):
        base = 256 ** (enc_byte-i-1)
        buffer_[i] = remainder // base
        remainder %= base
    buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
    return buffer_, enc_byte


def _decode(buffer_, enc_byte):
    # read the base-256 length header, then slice out that many payload bytes
    size = sum(256 ** (enc_byte-i-1) * buffer_[i].item()
               for i in range(enc_byte))
    bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
    shift = size + enc_byte
    return bytes_list, shift

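# Worked example of the _encode/_decode length header (illustrative, requires a
# CUDA device): with max_size=300, enc_byte = floor(log256(300)) + 1 = 2, so the
# first two bytes hold len(enc) in big-endian base 256 and the payload follows.
#
#     >>> buf, enc_byte = _encode(b'hello', 300)
#     >>> _decode(buf, enc_byte)
#     (b'hello', 7)
#
# The second return value (7 = 2 header bytes + 5 payload bytes) tells callers
# how far to advance when several encoded messages are concatenated, as in
# all_gather_list below.
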
_BUFFER_SIZE = 4096

def all_gather_list(data):
    """Gathers arbitrary data from all nodes into a list."""
    if not hasattr(all_gather_list, '_buffer'):
        # keep a small buffer to avoid re-allocating on every call
        all_gather_list._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE)
    try:
        enc = msgpack.dumps(data, use_bin_type=True)
        msgpack_success = True
    except TypeError:
        enc = pickle.dumps(data)
        msgpack_success = False

    enc_size = len(enc)
    max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item()
    buffer_ = all_gather_list._buffer
    in_buffer, enc_byte = _encode(enc, max_size, buffer_)

    out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size])

    results = []
    for _ in range(hvd.size()):
        bytes_list, shift = _decode(out_buffer, enc_byte)
        out_buffer = out_buffer[shift:]

        if msgpack_success:
            result = msgpack.loads(bytes_list, raw=False)
        else:
            result = pickle.loads(bytes_list)
        results.append(result)
    return results

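# Illustrative usage of all_gather_list (assumption): collect arbitrary Python
# objects, e.g. per-rank evaluation counts, on every worker.
#
#     stats = {'rank': hvd.rank(), 'n_correct': 42}
#     all_stats = all_gather_list(stats)   # list with hvd.size() entries
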
def any_broadcast(data, root_rank):
    """Broadcast arbitrary data from root_rank to all nodes."""
    if not hasattr(any_broadcast, '_buffer'):
        # keep a small buffer to avoid re-allocating on every call
        any_broadcast._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE)
    try:
        enc = msgpack.dumps(data, use_bin_type=True)
        msgpack_success = True
    except TypeError:
        enc = pickle.dumps(data)
        msgpack_success = False

    max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item()
    buffer_ = any_broadcast._buffer
    buffer_, enc_byte = _encode(enc, max_size, buffer_)

    hvd.broadcast_(buffer_, root_rank)

    bytes_list, _ = _decode(buffer_, enc_byte)
    if msgpack_success:
        result = msgpack.loads(bytes_list, raw=False)
    else:
        result = pickle.loads(bytes_list)
    return result
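

# Illustrative usage of any_broadcast (assumption): share an object computed
# only on rank 0, e.g. a shuffled example order, so all workers see the same
# value.
#
#     order = list(range(10)) if hvd.rank() == 0 else None
#     order = any_broadcast(order, root_rank=0)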