From d5ee886801df889738a45b0c08ed1001bdbee793 Mon Sep 17 00:00:00 2001 From: xujinling Date: Mon, 13 Jun 2022 14:34:23 +0800 Subject: [PATCH] add3 Signed-off-by: xujinling --- video_swin_transformer.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/video_swin_transformer.py b/video_swin_transformer.py index 31a08f0..bf6947c 100644 --- a/video_swin_transformer.py +++ b/video_swin_transformer.py @@ -66,6 +66,26 @@ class VideoSwinTransformer(NNOperator): patch_norm=self.model_configs['patch_norm'], device=self.device) + self.transform_cfgs = get_configs( + side_size=224, + crop_size=224, + num_frames=4, + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + + def decoder_video(self, data: List[VideoFrame]): + video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0) + assert len(video.shape) == 4 + video = video.transpose(3, 0, 1, 2) # thwc -> cthw + video = transform_video( + video=video, + **self.transform_cfgs + ) + # [B x C x T x H x W] + video = video.to(self.device)[None, ...] + return video + def __call__(self, video: List[VideoFrame]): """ Args: @@ -78,21 +98,8 @@ class VideoSwinTransformer(NNOperator): OR emb Video embedding. """ - # Convert list of towhee.types.Image to numpy.ndarray in float32 - video = numpy.stack([img.astype(numpy.float32)/255. for img in video], axis=0) - assert len(video.shape) == 4 - video = video.transpose(3, 0, 1, 2) # twhc -> ctwh - - # Transform video data given configs - if self.skip_preprocess: - self.cfg.update(num_frames=None) - - data = transform_video( - video=video, - **self.cfg - ) - inputs = data.to(self.device)[None, ...] - + inputs = self.decoder_video(video) + # inputs [B x C x T x H x W] feats = self.model.forward_features(inputs) features = feats.to('cpu').squeeze(0).detach().numpy()