logo
Browse Source

repair device problem

main
ChengZi 3 years ago
parent
commit
fd265f9674
  1. 8
      clip4clip.py

8
clip4clip.py

@ -79,16 +79,16 @@ class CLIP4Clip(NNOperator):
max_video_length = 0 if 0 > slice_len else slice_len max_video_length = 0 if 0 > slice_len else slice_len
for i, img in enumerate(img_list): for i, img in enumerate(img_list):
pil_img = PILImage.fromarray(img, img.mode) pil_img = PILImage.fromarray(img, img.mode)
tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
tfmed_img = self.tfms(pil_img).unsqueeze(0)
if slice_len >= 1: if slice_len >= 1:
video[0, i, ...] = tfmed_img
video[0, i, ...] = tfmed_img.cpu().numpy()
video_mask = np.zeros((1, max_frames), dtype=np.int32) video_mask = np.zeros((1, max_frames), dtype=np.int32)
video_mask[0, :max_video_length] = [1] * max_video_length video_mask[0, :max_video_length] = [1] * max_video_length
video = torch.as_tensor(video).float()
video = torch.as_tensor(video).float().to(self.device)
pair, bs, ts, channel, h, w = video.shape pair, bs, ts, channel, h, w = video.shape
video = video.view(pair * bs * ts, channel, h, w) video = video.view(pair * bs * ts, channel, h, w)
video_mask = torch.as_tensor(video_mask).float()
video_mask = torch.as_tensor(video_mask).float().to(self.device)
visual_output = self.model.get_visual_output(video, video_mask, shaped=True) visual_output = self.model.get_visual_output(video, video_mask, shaped=True)

Loading…
Cancel
Save