
modify output of video-decode from Image to VideoFrame

main · ChengZi · 2 years ago · commit 4fcde03a9d

Changed files:
  README.md    (4 lines changed)
  __init__.py  (4 lines changed)
  clip4clip.py (16 lines changed)

README.md (4 lines changed)

@@ -28,7 +28,7 @@ import towhee
 towhee.dc(['./demo_video.mp4']) \
     .video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
-    .runas_op(func=lambda x: [y[0] for y in x]) \
+    .runas_op(func=lambda x: [y for y in x]) \
     .clip4clip(model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
     .show()

@@ -46,7 +46,7 @@ import towhee
 towhee.dc['path'](['./demo_video.mp4']) \
     .video_decode.ffmpeg['path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
-    .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x]) \
+    .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
     .clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
     .show()
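With this change `video_decode.ffmpeg` emits `VideoFrame` objects directly, so the `runas_op` step only materializes the frame iterator into a list; the old `y[0]` unpacking of each decoded element is no longer needed. A minimal sketch of how such a frame behaves downstream, assuming towhee's `VideoFrame` is a `numpy.ndarray` subclass; the `(data, mode)` constructor arguments are an assumption for illustration, not taken from this diff:

    import numpy as np
    from towhee.types import VideoFrame

    # Hypothetical frame standing in for decoder output; constructor args assumed.
    frame = VideoFrame(np.zeros((480, 640, 3), dtype=np.uint8), 'RGB')
    assert isinstance(frame, np.ndarray)  # usable directly as its pixel array
    print(frame.mode, frame.shape)        # RGB (480, 640, 3)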

__init__.py (4 lines changed)

@@ -15,6 +15,6 @@
 from .clip4clip import CLIP4Clip

-def clip4clip(model_name: str, modality: str, weight_path: str):
-    return CLIP4Clip(model_name, modality, weight_path)
+def clip4clip(model_name: str, modality: str, weight_path: str, device: str):
+    return CLIP4Clip(model_name, modality, weight_path, device)
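Note that `device` is a required positional parameter in this factory (unlike the operator's `__init__`, where it defaults to `None`), so callers of the factory must now pass it explicitly. A hedged usage sketch; the import path is hypothetical, since in practice towhee resolves this operator by name:

    # Hypothetical direct import of the factory defined in this __init__.py.
    from clip4clip import clip4clip

    op = clip4clip(model_name='clip_vit_b32', modality='video',
                   weight_path='./pytorch_model.bin.1', device='cpu')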

clip4clip.py (16 lines changed)

@@ -14,7 +14,6 @@
 import numpy as np
 import torch
-import towhee
 from typing import List, Union
 from torchvision import transforms
@@ -23,7 +22,7 @@ from towhee.operator.base import NNOperator
 from towhee import register
 from towhee.models import clip4clip
 from PIL import Image as PILImage
-from towhee.types.image import Image
+from towhee.types import VideoFrame

 @register(output_schema=['vec'])
@@ -32,10 +31,13 @@ class CLIP4Clip(NNOperator):
     CLIP4Clip multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str, weight_path: str = None):
+    def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
         super().__init__()
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         self.model = clip4clip.create_model(model_name=model_name,
                                             context_length=77,
                                             pretrained=True,
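The new branch lets callers override automatic CUDA detection while keeping the old fallback. The same resolution logic as a standalone helper, shown only to make the behavior explicit (the helper name is ours, not from this repo):

    import torch

    def resolve_device(device=None):
        # Mirrors the diff: an explicit device wins, else prefer CUDA when available.
        if device is None:
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device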
@@ -52,7 +54,7 @@ class CLIP4Clip(NNOperator):
         ])
         self.model.eval()

-    def __call__(self, data: Union[str, List[Image]]):
+    def __call__(self, data: Union[str, List[VideoFrame]]):
         if self.modality == 'video':
             vec = self._inference_from_video(data)
         elif self.modality == 'text':
@@ -69,14 +71,14 @@ class CLIP4Clip(NNOperator):
         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
         return text_features.detach().flatten().cpu().numpy()

-    def _inference_from_video(self, img_list: List[Image]):
+    def _inference_from_video(self, img_list: List[VideoFrame]):
         self.model.eval()
         max_frames = 12
         video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
         slice_len = len(img_list)
         max_video_length = 0 if 0 > slice_len else slice_len
         for i, img in enumerate(img_list):
-            pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
+            pil_img = PILImage.fromarray(img, img.mode)
             tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
             if slice_len >= 1:
                 video[0, i, ...] = tfmed_img
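For context, this loop writes each transformed frame into a zero-initialized buffer of shape (1, max_frames, 1, 3, 224, 224), so clips shorter than 12 frames keep trailing zero padding. A simplified standalone sketch of that padding logic, using dummy uint8 arrays and a bare ToTensor in place of the operator's full `self.tfms` stack:

    import numpy as np
    from torchvision import transforms

    max_frames = 12
    tfms = transforms.Compose([transforms.ToTensor()])  # HWC uint8 -> CHW float in [0, 1]

    frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(5)]  # dummy 5-frame clip
    video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
    for i, f in enumerate(frames):
        video[0, i, 0] = tfms(f).numpy()  # slots past len(frames) stay zero-padded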
