logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

19 lines
383 B

from .logger import *
from .typing import *
from .utils import *
def get_batch_size(x: TensorOrSequence) -> int:
    """Return the size of dimension 0 of ``x``, used as the batch size.

    Args:
        x: Either a single ``torch.Tensor`` or a sequence of tensors. For a
            sequence, only the first element is inspected. This assumes every
            element shares the same leading dimension (TODO confirm with
            callers).

    Returns:
        ``size(0)`` of the tensor (or of the sequence's first element).

    Raises:
        IndexError: If ``x`` is an empty sequence.
    """
    # A bare tensor is its own reference; otherwise fall back to the first item.
    reference = x if isinstance(x, torch.Tensor) else x[0]
    return reference.size(0)
def get_device(x: TensorOrSequence) -> "torch.device":
    """Return the device on which ``x`` (or its first element) lives.

    Args:
        x: Either a single ``torch.Tensor`` or a sequence of tensors. For a
            sequence, only the first element is inspected. This presumably
            assumes all elements share one device (TODO confirm with callers).

    Returns:
        The ``torch.device`` reported by the tensor's ``.device`` attribute.

    Raises:
        IndexError: If ``x`` is an empty sequence.
    """
    # The annotation is a string so it is resolved lazily and adds no import-time
    # dependency beyond what the function body already needs.
    tensor = x if isinstance(x, torch.Tensor) else x[0]
    return tensor.device