diff --git a/clip4clip.py b/clip4clip.py
index 0b1f040..b6c1ea7 100644
--- a/clip4clip.py
+++ b/clip4clip.py
@@ -79,16 +79,16 @@ class CLIP4Clip(NNOperator):
         max_video_length = 0 if 0 > slice_len else slice_len
         for i, img in enumerate(img_list):
             pil_img = PILImage.fromarray(img, img.mode)
-            tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
+            tfmed_img = self.tfms(pil_img).unsqueeze(0)
             if slice_len >= 1:
-                video[0, i, ...] = tfmed_img
+                video[0, i, ...] = tfmed_img.cpu().numpy()
 
         video_mask = np.zeros((1, max_frames), dtype=np.int32)
         video_mask[0, :max_video_length] = [1] * max_video_length
-        video = torch.as_tensor(video).float()
+        video = torch.as_tensor(video).float().to(self.device)
 
         pair, bs, ts, channel, h, w = video.shape
         video = video.view(pair * bs * ts, channel, h, w)
-        video_mask = torch.as_tensor(video_mask).float()
+        video_mask = torch.as_tensor(video_mask).float().to(self.device)
 
         visual_output = self.model.get_visual_output(video, video_mask, shaped=True)