towhee
/
distill-and-select
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
18 lines
725 B
18 lines
725 B
from .layers import *
|
|
from .losses import *
|
|
from .similarities import *
|
|
|
|
|
|
def check_dims(features, mask=None, axis=0):
    """Ensure a video feature tensor is 4-dimensional.

    A 4-D ``features`` tensor is returned unchanged together with ``mask``.
    A 3-D ``features`` tensor gets a singleton dimension inserted at
    ``axis``; when ``mask`` is given, the same ``unsqueeze(axis)`` is
    applied to it so both stay aligned.

    Args:
        features: Tensor-like object exposing ``ndim``, ``shape`` and
            ``unsqueeze``. Per the error message below, the expected layout
            is ``[N, T, R, D]`` (batched) or ``[T, R, D]`` (single video).
        mask: Optional tensor-like object exposing ``unsqueeze``.
            NOTE(review): its expected shape is not visible here -- confirm
            against callers.
        axis: Position at which the singleton dimension is inserted when
            ``features`` is 3-D. Defaults to 0.

    Returns:
        A ``(features, mask)`` tuple in which ``features`` has 4 dimensions.

    Raises:
        ValueError: If ``features`` has neither 3 nor 4 dimensions.
    """
    if features.ndim == 4:
        # Already in the 4-D form; nothing to do.
        return features, mask

    if features.ndim == 3:
        features = features.unsqueeze(axis)
        if mask is not None:
            mask = mask.unsqueeze(axis)
        return features, mask

    # ValueError (a subclass of Exception) is the idiomatic signal for an
    # argument with an unacceptable value; callers catching Exception
    # continue to work.
    raise ValueError(
        'Wrong shape of input video tensor. The shape of the tensor must be either '
        '[N, T, R, D] or [T, R, D], where N is the batch size, T the number of frames, '
        'R the number of regions and D number of dimensions. '
        'Input video tensor has shape {}'.format(features.shape)
    )
|