diff --git a/bridge_former.py b/bridge_former.py index 400e6eb..14fcedb 100644 --- a/bridge_former.py +++ b/bridge_former.py @@ -52,14 +52,8 @@ class BridgeFormer(NNOperator): model_name=self.model_name) self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False) - # self.transform_cfgs = configs(self.model_name) - self.transform_cfgs = get_configs( - side_size=256, - crop_size=224, - num_frames=None, - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - ) + self.transform_cfgs = configs(self.model_name) + self.model.eval() def decoder_video(self, data: List[VideoFrame]): @@ -73,6 +67,7 @@ class BridgeFormer(NNOperator): video=video, **self.transform_cfgs ) + # [B x C x T x H x W] video = video.to(self.device)[None, ...] return video