towhee
/
distill-and-select
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
19 lines
725 B
19 lines
725 B
3 years ago
|
from .layers import *
|
||
|
from .losses import *
|
||
|
from .similarities import *
|
||
|
|
||
|
|
||
|
def check_dims(features, mask=None, axis=0):
    """Ensure a video feature tensor carries a batch dimension.

    Accepts either a batched tensor of rank 4 (``[N, T, R, D]``: batch size,
    frames, regions, dimensions) or an unbatched tensor of rank 3
    (``[T, R, D]``). Rank-3 input gets a singleton dimension inserted at
    ``axis`` so that downstream code can always assume a batched tensor.

    Args:
        features: Tensor-like object exposing ``ndim``, ``shape`` and
            ``unsqueeze`` (e.g. a ``torch.Tensor``).
        mask: Optional mask tensor. It is returned unchanged for rank-4
            features. For rank-3 features it is unsqueezed at the same
            ``axis``. NOTE(review): presumably its leading dims match
            ``features``; this function does not check that.
        axis: Position of the singleton dimension added to rank-3 input.

    Returns:
        A ``(features, mask)`` tuple. ``mask`` is ``None`` if it was not given.

    Raises:
        ValueError: If ``features`` is neither rank 3 nor rank 4.
            (``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.)
    """
    if features.ndim == 4:
        # Already batched; nothing to do.
        return features, mask

    if features.ndim == 3:
        # Single video: add the batch dimension to the features and, if given, the mask.
        features = features.unsqueeze(axis)
        if mask is not None:
            mask = mask.unsqueeze(axis)
        return features, mask

    raise ValueError('Wrong shape of input video tensor. The shape of the tensor must be either '
                     '[N, T, R, D] or [T, R, D], where N is the batch size, T the number of frames, '
                     'R the number of regions and D number of dimensions. '
                     'Input video tensor has shape {}'.format(features.shape))
|