diff --git a/video_swin_transformer.py b/video_swin_transformer.py index 9a1c67f..b69e56e 100644 --- a/video_swin_transformer.py +++ b/video_swin_transformer.py @@ -59,7 +59,7 @@ class VideoSwinTransformer(NNOperator): device=self.device) self.model.to(self.device) self.transform_cfgs = get_configs( - side_size=224, + side_size=256, crop_size=224, num_frames=32, mean=[0.485, 0.456, 0.406], @@ -67,15 +67,20 @@ class VideoSwinTransformer(NNOperator): ) def decoder_video(self, data: List[VideoFrame]): + video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0) + print(video.shape) assert len(video.shape) == 4 video = video.transpose(3, 0, 1, 2) # twhc -> ctwh + print(video.shape) video = transform_video( video=video, **self.transform_cfgs ) + print(video.shape) # [B x C x T x H x W] video = video.to(self.device)[None, ...] + print(video.shape) return video def __call__(self, video: List[VideoFrame]):