diff --git a/README.md b/README.md
index 39e773f..111d9fa 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@
 import towhee
 
 towhee.dc(['./demo_video.mp4']) \
       .video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
-      .runas_op(func=lambda x: [y[0] for y in x]) \
+      .runas_op(func=lambda x: [y for y in x]) \
       .clip4clip(model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
       .show()
@@ -46,7 +46,7 @@
 import towhee
 
 towhee.dc['path'](['./demo_video.mp4']) \
       .video_decode.ffmpeg['path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
-      .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x]) \
+      .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
       .clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1') \
       .show()
diff --git a/__init__.py b/__init__.py
index 5bc9b3f..c68e45b 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,6 +15,6 @@
 from .clip4clip import CLIP4Clip
 
 
-def clip4clip(model_name: str, modality: str, weight_path: str):
-    return CLIP4Clip(model_name, modality, weight_path)
+def clip4clip(model_name: str, modality: str, weight_path: str, device: str):
+    return CLIP4Clip(model_name, modality, weight_path, device)
diff --git a/clip4clip.py b/clip4clip.py
index 74a4ecd..0b1f040 100644
--- a/clip4clip.py
+++ b/clip4clip.py
@@ -14,7 +14,6 @@
 import numpy as np
 import torch
-import towhee
 
 from typing import List, Union
 from torchvision import transforms
@@ -23,7 +22,7 @@
 from towhee.operator.base import NNOperator
 from towhee import register
 from towhee.models import clip4clip
 from PIL import Image as PILImage
-from towhee.types.image import Image
+from towhee.types import VideoFrame
 
 
 @register(output_schema=['vec'])
@@ -32,10 +31,13 @@
 class CLIP4Clip(NNOperator):
     """
     CLIP4Clip multi-modal embedding operator
     """
 
-    def __init__(self, model_name: str, modality: str, weight_path: str = None):
+    def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
         super().__init__()
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         self.model = clip4clip.create_model(model_name=model_name,
                                             context_length=77,
                                             pretrained=True,
@@ -52,7 +54,7 @@
         ])
         self.model.eval()
 
-    def __call__(self, data: Union[str, List[Image]]):
+    def __call__(self, data: Union[str, List[VideoFrame]]):
         if self.modality == 'video':
             vec = self._inference_from_video(data)
         elif self.modality == 'text':
@@ -69,14 +71,14 @@
         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
         return text_features.detach().flatten().cpu().numpy()
 
-    def _inference_from_video(self, img_list: List[Image]):
+    def _inference_from_video(self, img_list: List[VideoFrame]):
         self.model.eval()
         max_frames = 12
         video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
         slice_len = len(img_list)
         max_video_length = 0 if 0 > slice_len else slice_len
         for i, img in enumerate(img_list):
-            pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
+            pil_img = PILImage.fromarray(img, img.mode)
             tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
             if slice_len >= 1:
                 video[0, i, ...] = tfmed_img
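
Since this patch adds a `device` argument to the operator's constructor, a minimal usage sketch follows. It reuses the pipeline and paths from the README hunk above and assumes towhee forwards the operator's keyword arguments to `__init__`; the `device='cpu'` value is only an illustration.

```python
import towhee

# Same pipeline as in the README; the extra `device` kwarg pins the operator
# to a specific device instead of auto-selecting CUDA when it is available.
towhee.dc['path'](['./demo_video.mp4']) \
      .video_decode.ffmpeg['path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 12}) \
      .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
      .clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video',
                                  weight_path='./pytorch_model.bin.1', device='cpu') \
      .show()
```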