logo
Browse Source

add3

Signed-off-by: xujinling <jinling.xu@zilliz.com>
main
xujinling 3 years ago
parent
commit
d5ee886801
  1. 37
      video_swin_transformer.py

37
video_swin_transformer.py

@ -66,6 +66,26 @@ class VideoSwinTransformer(NNOperator):
patch_norm=self.model_configs['patch_norm'],
device=self.device)
self.transform_cfgs = get_configs(
side_size=224,
crop_size=224,
num_frames=4,
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
)
def decoder_video(self, data: List[VideoFrame]):
video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0)
assert len(video.shape) == 4
video = video.transpose(3, 0, 1, 2) # twhc -> ctwh
video = transform_video(
video=video,
**self.transform_cfgs
)
# [B x C x T x H x W]
video = video.to(self.device)[None, ...]
return video
def __call__(self, video: List[VideoFrame]):
"""
Args:
@ -78,21 +98,8 @@ class VideoSwinTransformer(NNOperator):
OR emb
Video embedding.
"""
# Convert list of towhee.types.Image to numpy.ndarray in float32
video = numpy.stack([img.astype(numpy.float32)/255. for img in video], axis=0)
assert len(video.shape) == 4
video = video.transpose(3, 0, 1, 2) # twhc -> ctwh
# Transform video data given configs
if self.skip_preprocess:
self.cfg.update(num_frames=None)
data = transform_video(
video=video,
**self.cfg
)
inputs = data.to(self.device)[None, ...]
inputs = self.decoder_video(video)
# inputs [B x C x T x H x W]
feats = self.model.forward_features(inputs)
features = feats.to('cpu').squeeze(0).detach().numpy()

Loading…
Cancel
Save