@@ -14,7 +14,6 @@
 
 import numpy as np
 import torch
-import towhee
 
 from typing import List, Union
 from torchvision import transforms
@@ -23,7 +22,7 @@ from towhee.operator.base import NNOperator
 from towhee import register
 from towhee.models import clip4clip
 from PIL import Image as PILImage
-from towhee.types.image import Image
+from towhee.types import VideoFrame
 
 
 @register(output_schema=['vec'])
@@ -32,10 +31,13 @@ class CLIP4Clip(NNOperator):
     CLIP4Clip multi-modal embedding operator
     """
 
-    def __init__(self, model_name: str, modality: str, weight_path: str = None):
+    def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
         super().__init__()
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
         self.model = clip4clip.create_model(model_name=model_name,
                                             context_length=77,
                                             pretrained=True,
@@ -52,7 +54,7 @@ class CLIP4Clip(NNOperator):
         ])
         self.model.eval()
 
-    def __call__(self, data: Union[str, List[Image]]):
+    def __call__(self, data: Union[str, List[VideoFrame]]):
         if self.modality == 'video':
             vec = self._inference_from_video(data)
         elif self.modality == 'text':
@@ -69,14 +71,14 @@ class CLIP4Clip(NNOperator):
         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
         return text_features.detach().flatten().cpu().numpy()
 
-    def _inference_from_video(self, img_list: List[Image]):
+    def _inference_from_video(self, img_list: List[VideoFrame]):
         self.model.eval()
         max_frames = 12
         video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
         slice_len = len(img_list)
         max_video_length = 0 if 0 > slice_len else slice_len
         for i, img in enumerate(img_list):
-            pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
+            pil_img = PILImage.fromarray(img, img.mode)
             tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
             if slice_len >= 1:
                 video[0, i, ...] = tfmed_img
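
For context (not part of the patch), a minimal usage sketch of the operator after this change. CLIP4Clip is the class defined in the patched file; the model name 'clip_vit_b32' and the VideoFrame constructor arguments are assumptions about the towhee API of this version and may need adjusting.

import numpy as np
from towhee.types import VideoFrame

# New `device` argument: pin the model to an explicit device instead of
# relying on the CUDA auto-detection fallback.
text_op = CLIP4Clip(model_name='clip_vit_b32', modality='text', device='cpu')
text_vec = text_op('a person riding a bike')  # 1-D numpy embedding

# The video branch now takes towhee VideoFrame objects. VideoFrame is an
# ndarray subclass carrying a `mode`, so PILImage.fromarray(img, img.mode)
# works without the old to_ndarray() call. The constructor argument order
# (data, mode, timestamp, key_frame) is assumed here.
frames = [VideoFrame(np.zeros((224, 224, 3), dtype=np.uint8), 'RGB', i * 40, 1)
          for i in range(12)]  # 12 frames matches max_frames above
video_op = CLIP4Clip(model_name='clip_vit_b32', modality='video', device='cpu')
video_vec = video_op(frames)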