logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

82 lines
2.6 KiB

from contextlib import contextmanager
from torch import nn
from utils.typing import *
class Module(nn.Module):
    """``nn.Module`` with optional state buffers that can be batched on demand.

    Subclasses declare state tensors with :meth:`register_state`.  While the
    module is stateful (see :meth:`enable_statefulness` and the
    :meth:`statefulness` context manager), each non-``None`` state holds a copy
    of its default with a leading batch dimension of size ``batch_size``.  When
    statefulness is disabled, every state is put back to an unbatched copy of
    its default.  Enabling/disabling, :meth:`states` and
    :meth:`apply_to_states` recurse into direct children that are themselves
    :class:`Module` instances.
    """

    def __init__(self):
        super().__init__()
        # True between enable_statefulness() and disable_statefulness().
        self._is_stateful = False
        # Registered state names in registration order; each name appears once.
        self._state_names = []
        # name -> detached private copy of the default value (or None).
        self._state_defaults = dict()

    def register_state(self, name: str, default: "TensorOrNone"):
        """Register ``name`` as a state buffer whose initial value is ``default``.

        ``default`` itself (not a copy) is registered as the buffer.  A detached
        clone is kept as the value restored by the state-reset methods.
        Registering an existing state name again replaces its default instead
        of adding a second entry.

        Raises:
            Whatever ``nn.Module.register_buffer`` raises for an invalid name
            or value.  In that case no state bookkeeping is recorded.
        """
        # Register first, so that a rejected name leaves no dangling entries
        # in _state_names / _state_defaults.
        self.register_buffer(name, default)
        if name not in self._state_names:
            # A duplicate entry would make states() yield the buffer twice and
            # apply_to_states() transform it twice.
            self._state_names.append(name)
        self._state_defaults[name] = None if default is None else default.clone().detach()

    def states(self):
        """Yield this module's state buffers, then those of ``Module`` children.

        States whose current value is ``None`` are yielded as ``None``.
        """
        for name in self._state_names:
            yield self._buffers[name]
        for m in self.children():
            if isinstance(m, Module):
                yield from m.states()

    def apply_to_states(self, fn):
        """Replace every non-``None`` state ``s`` with ``fn(s)``, recursing into children."""
        for name in self._state_names:
            if self._buffers[name] is not None:
                self._buffers[name] = fn(self._buffers[name])
        for m in self.children():
            if isinstance(m, Module):
                m.apply_to_states(fn)

    def _fresh_default(self, name: str):
        """Return a new, unbatched copy of state ``name``'s default, or ``None``.

        The copy is placed on the live buffer's device.  Its dtype also follows
        the live buffer when both tensors are floating point.  This way a
        module cast with e.g. ``.half()`` gets half-precision states back
        rather than the dtype the default was registered with.  Non-floating
        states keep the default's dtype.
        """
        default = self._state_defaults[name]
        if default is None:
            return None
        fresh = default.clone().detach()
        current = self._buffers[name]
        if current is None:
            # The buffer was cleared externally, so there is nothing to match.
            # Keep the default's own device and dtype.
            return fresh
        both_float = fresh.is_floating_point() and current.is_floating_point()
        dtype = current.dtype if both_float else fresh.dtype
        return fresh.to(device=current.device, dtype=dtype)

    def _init_states(self, batch_size: int):
        """Set this module's states to their defaults with a leading batch dim of ``batch_size``."""
        for name in self._state_names:
            fresh = self._fresh_default(name)
            if fresh is not None:
                # expand() returns a view whose batch rows share storage.
                # contiguous() gives each row its own memory.
                fresh = fresh.unsqueeze(0).expand(batch_size, *fresh.shape).contiguous()
            self._buffers[name] = fresh

    def _reset_states(self):
        """Set this module's states back to fresh, unbatched copies of their defaults."""
        for name in self._state_names:
            self._buffers[name] = self._fresh_default(name)

    def enable_statefulness(self, batch_size: int):
        """Make this module and its ``Module`` children stateful with a batch of ``batch_size``."""
        for m in self.children():
            if isinstance(m, Module):
                m.enable_statefulness(batch_size)
        self._init_states(batch_size)
        self._is_stateful = True

    def disable_statefulness(self):
        """Leave stateful mode here and in ``Module`` children, resetting all states."""
        for m in self.children():
            if isinstance(m, Module):
                m.disable_statefulness()
        self._reset_states()
        self._is_stateful = False

    @contextmanager
    def statefulness(self, batch_size: int):
        """Enable statefulness for the ``with`` body.

        Statefulness is always disabled on exit, even if the body raises.
        """
        self.enable_statefulness(batch_size)
        try:
            yield
        finally:
            self.disable_statefulness()
class ModuleList(nn.ModuleList, Module):
    """``nn.ModuleList`` that is also a stateful :class:`Module`.

    Because it is a ``Module``, a parent's ``enable_statefulness``,
    ``disable_statefulness``, ``states`` and ``apply_to_states`` recurse into
    this container and, through it, into the ``Module`` entries it holds.
    """
class ModuleDict(nn.ModuleDict, Module):
    """``nn.ModuleDict`` that is also a stateful :class:`Module`.

    Because it is a ``Module``, a parent's ``enable_statefulness``,
    ``disable_statefulness``, ``states`` and ``apply_to_states`` recurse into
    this container and, through it, into the ``Module`` values it holds.
    """