camel
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
19 lines
383 B
19 lines
383 B
from .logger import *
|
|
from .typing import *
|
|
from .utils import *
|
|
|
|
|
|
def get_batch_size(x: TensorOrSequence) -> int:
    """Return the size of the first dimension of ``x``.

    When ``x`` is a ``torch.Tensor`` its own first dimension is reported;
    otherwise ``x`` is indexed at position 0 and the first dimension of that
    element is reported.
    """
    # For a non-tensor input, the first element stands in for the whole input.
    reference = x if isinstance(x, torch.Tensor) else x[0]
    return reference.size(0)
|
|
|
|
|
|
def get_device(x: TensorOrSequence) -> torch.device:
    """Return the device on which ``x`` is stored.

    When ``x`` is a ``torch.Tensor`` its own ``device`` attribute is
    returned; otherwise ``x`` is indexed at position 0 and the ``device`` of
    that element is returned.

    Note: the return annotation was previously ``int``, which did not match
    the ``torch.device`` object this function actually returns.
    """
    if isinstance(x, torch.Tensor):
        return x.device
    # For a non-tensor input, only the first element is inspected.
    return x[0].device
|