camel
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
82 lines
2.6 KiB
82 lines
2.6 KiB
3 years ago
|
from contextlib import contextmanager
|
||
|
from torch import nn
|
||
|
from utils.typing import *
|
||
|
|
||
|
|
||
|
class Module(nn.Module):
    """``nn.Module`` with named, resettable "state" buffers.

    A state is a buffer registered through :meth:`register_state` together
    with a detached copy of its default value. :meth:`enable_statefulness`
    replaces every state with ``batch_size`` stacked copies of its default;
    :meth:`disable_statefulness` restores the unbatched defaults.

    All recursive operations only descend into *direct* children that are
    themselves instances of this ``Module`` class. A plain ``nn.Module``
    child ends the walk, so states registered below it are not visited.
    """

    def __init__(self):
        super(Module, self).__init__()
        # Toggled by enable/disable_statefulness; not read in this class
        # (presumably consulted by subclasses -- confirm at call sites).
        self._is_stateful = False
        # This module's own state names, in registration order, no duplicates.
        self._state_names = []
        # name -> detached clone of the registered default (or None).
        self._state_defaults = dict()

    def register_state(self, name: str, default: "TensorOrNone"):
        """Register a state buffer ``name`` whose initial value is ``default``.

        ``default`` itself becomes the live buffer. A detached clone is kept
        separately so the state can be reset later, even after the live
        buffer has been replaced or modified.

        Args:
            name: Buffer name; must satisfy ``nn.Module.register_buffer``.
            default: Initial/reset value, or ``None`` for a state that
                starts unset.

        Raises:
            Whatever ``default.clone()`` or ``register_buffer`` raises for an
            invalid default or name. No state bookkeeping is recorded then.
        """
        # Clone first so a non-tensor default fails exactly as before.
        default_copy = None if default is None else default.clone().detach()
        # Register before touching bookkeeping: a name rejected by
        # register_buffer must not leave a dangling entry in _state_names.
        self.register_buffer(name, default)
        # Re-registering replaces the default instead of listing the name
        # twice (a duplicate would make apply_to_states apply fn twice).
        if name not in self._state_names:
            self._state_names.append(name)
        self._state_defaults[name] = default_copy

    def _state_children(self):
        """Return the direct children that take part in state recursion."""
        return (m for m in self.children() if isinstance(m, Module))

    def states(self):
        """Yield the current value of every state, depth-first.

        This module's own states come first, in registration order, followed
        by those of each child ``Module``. Yielded values may be ``None``.
        """
        for name in self._state_names:
            yield self._buffers[name]
        for child in self._state_children():
            yield from child.states()

    def apply_to_states(self, fn):
        """Replace every non-``None`` state with ``fn(state)``, recursively.

        ``None`` states are skipped, so ``fn`` is never called with ``None``.
        """
        for name in self._state_names:
            if self._buffers[name] is not None:
                self._buffers[name] = fn(self._buffers[name])
        for child in self._state_children():
            child.apply_to_states(fn)

    def _default_state(self, name: str):
        """Return a fresh copy of state ``name``'s default, or ``None``.

        ``_state_defaults`` is an ordinary dict rather than a buffer, so
        module-level ``.to()``/``.cuda()`` does not move it. The copy is
        therefore placed on the live buffer's device. If the live buffer is
        ``None``, the default's own device is used instead of failing.
        """
        default = self._state_defaults[name]
        if default is None:
            return None
        current = self._buffers[name]
        device = default.device if current is None else current.device
        return default.clone().detach().to(device)

    def _init_states(self, batch_size: int):
        """Set this module's own states to ``batch_size`` copies of their defaults."""
        for name in self._state_names:
            state = self._default_state(name)
            if state is not None:
                # Prepend a batch axis and broadcast the default along it.
                # expand() only returns a view, so contiguous() is needed to
                # materialise it; otherwise the rows would share memory.
                state = state.unsqueeze(0)
                state = state.expand([batch_size] + list(state.shape[1:]))
                state = state.contiguous()
            self._buffers[name] = state

    def _reset_states(self):
        """Restore this module's own states to unbatched copies of their defaults."""
        for name in self._state_names:
            self._buffers[name] = self._default_state(name)

    def enable_statefulness(self, batch_size: int):
        """Switch this module and its child ``Module``s into stateful mode.

        Every state is reinitialised to ``batch_size`` copies of its default.
        Children are processed before this module itself.
        """
        for child in self._state_children():
            child.enable_statefulness(batch_size)
        self._init_states(batch_size)
        self._is_stateful = True

    def disable_statefulness(self):
        """Leave stateful mode and restore every state to its default, recursively."""
        for child in self._state_children():
            child.disable_statefulness()
        self._reset_states()
        self._is_stateful = False

    @contextmanager
    def statefulness(self, batch_size: int):
        """Context manager: stateful with ``batch_size`` inside, reset on exit.

        Enabling happens inside the ``try``. If it fails part-way (e.g. in a
        child module), the modules already switched are still restored by
        ``disable_statefulness``.
        """
        try:
            self.enable_statefulness(batch_size)
            yield
        finally:
            self.disable_statefulness()
|
||
|
|
||
|
|
||
|
class ModuleList(nn.ModuleList, Module):
    """``nn.ModuleList`` that also takes part in ``Module`` state recursion.

    The recursive state methods (``states``, ``apply_to_states``,
    ``enable_statefulness``, ...) only descend into children that are
    ``Module`` instances. Subclassing ``Module`` lets them pass through this
    container into the modules it holds.
    """
|
||
|
|
||
|
|
||
|
class ModuleDict(nn.ModuleDict, Module):
    """``nn.ModuleDict`` that also takes part in ``Module`` state recursion.

    The recursive state methods (``states``, ``apply_to_states``,
    ``enable_statefulness``, ...) only descend into children that are
    ``Module`` instances. Subclassing ``Module`` lets them pass through this
    container into the modules it holds.
    """
|