diff --git a/frozen_in_time.py b/frozen_in_time.py index d1e796e..a61d03e 100644 --- a/frozen_in_time.py +++ b/frozen_in_time.py @@ -70,7 +70,7 @@ class FrozenInTime(NNOperator): video_model_type='SpaceTimeTransformer', text_is_load_pretrained=False, device=self.device) - + self.model.to(self.device) self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False) self.transform_cfgs = get_configs( side_size=256, @@ -101,6 +101,8 @@ class FrozenInTime(NNOperator): # Convert list of towhee.types.Image to numpy.ndarray in float32 video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0) assert len(video.shape) == 4 + if video.shape[0] != 4: + self.transform_cfgs.update(num_frames=4) video = video.transpose(3, 0, 1, 2) # twhc -> ctwh video = transform_video( video=video,