logo
Browse Source

modifty output from video-decode from Image to VideoFrame

main
ChengZi 3 years ago
parent
commit
4fcde03a9d
  1. 4
      README.md
  2. 4
      __init__.py
  3. 14
      clip4clip.py

4
README.md

@ -28,7 +28,7 @@ import towhee
towhee.dc(['./demo_video.mp4']) \
.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
.runas_op(func=lambda x: [y[0] for y in x]) \
.runas_op(func=lambda x: [y for y in x]) \
.clip4clip(model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
.show()
@ -46,7 +46,7 @@ import towhee
towhee.dc['path'](['./demo_video.mp4']) \
.video_decode.ffmpeg['path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
.runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x]) \
.runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
.clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
.show()

4
__init__.py

@ -15,6 +15,6 @@
from .clip4clip import CLIP4Clip
def clip4clip(model_name: str, modality: str, weight_path: str):
return CLIP4Clip(model_name, modality, weight_path)
def clip4clip(model_name: str, modality: str, weight_path: str, device: str):
return CLIP4Clip(model_name, modality, weight_path, device)

14
clip4clip.py

@ -14,7 +14,6 @@
import numpy as np
import torch
import towhee
from typing import List, Union
from torchvision import transforms
@ -23,7 +22,7 @@ from towhee.operator.base import NNOperator
from towhee import register
from towhee.models import clip4clip
from PIL import Image as PILImage
from towhee.types.image import Image
from towhee.types import VideoFrame
@register(output_schema=['vec'])
@ -32,10 +31,13 @@ class CLIP4Clip(NNOperator):
CLIP4Clip multi-modal embedding operator
"""
def __init__(self, model_name: str, modality: str, weight_path: str = None):
def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
super().__init__()
self.modality = modality
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.model = clip4clip.create_model(model_name=model_name,
context_length=77,
pretrained=True,
@ -52,7 +54,7 @@ class CLIP4Clip(NNOperator):
])
self.model.eval()
def __call__(self, data: Union[str, List[Image]]):
def __call__(self, data: Union[str, List[VideoFrame]]):
if self.modality == 'video':
vec = self._inference_from_video(data)
elif self.modality == 'text':
@ -69,14 +71,14 @@ class CLIP4Clip(NNOperator):
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.detach().flatten().cpu().numpy()
def _inference_from_video(self, img_list: List[Image]):
def _inference_from_video(self, img_list: List[VideoFrame]):
self.model.eval()
max_frames = 12
video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
slice_len = len(img_list)
max_video_length = 0 if 0 > slice_len else slice_len
for i, img in enumerate(img_list):
pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
pil_img = PILImage.fromarray(img, img.mode)
tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
if slice_len >= 1:
video[0, i, ...] = tfmed_img

Loading…
Cancel
Save