clip4clip
ChengZi
3 years ago
3 changed files with 177 additions and 0 deletions
@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .clip4clip import CLIP4Clip


def clip4clip(model_name: str, modality: str, weight_path: str):
    return CLIP4Clip(model_name, modality, weight_path)
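For context, a minimal sketch of how this factory could be called directly. The model name, modality, query text, and weight path are reused from the example code later in this commit; the package import path is an assumption, not something this commit defines.

# Hypothetical direct usage of the clip4clip factory (not part of this commit).
from clip4clip import clip4clip  # assumed package layout

text_op = clip4clip(model_name='clip_vit_b32', modality='text',
                    weight_path='./pytorch_model.bin.1')
vec = text_op('kids feeding and playing with the horse')
print(vec.shape)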
@@ -0,0 +1,157 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
import towhee
from torchvision import transforms
from towhee.models.clip4clip import convert_tokens_to_id

from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from towhee.models import clip4clip
from towhee.utils.ndarray_utils import to_ndarray
from PIL import Image as PILImage


@register(name='clip4clip', output_schema=['vec'])
class CLIP4Clip(NNOperator):
    """
    CLIP4Clip multi-modal embedding operator: encodes a text query or a list of
    video frames into a shared embedding space, depending on `modality`.
    """
    def __init__(self, model_name: str, modality: str, weight_path: str = None):
        super().__init__()
        self.modality = modality
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = clip4clip.create_model(model_name=model_name,
                                            context_length=77,
                                            pretrained=True,
                                            weights_path=weight_path,
                                            device=self.device)

        self.tokenize = clip4clip.SimpleTokenizer()
        self.tfms = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        self.model.eval()

    def __call__(self, data):
        if self.modality == 'video':
            vec = self._inference_from_video(data)
        elif self.modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text):
        self.model.eval()
        # text = self.tokenize(text)
        text_ids = convert_tokens_to_id(self.tokenize, text)
        print(text_ids)
        text_ids = torch.tensor(text_ids).unsqueeze(0).to(self.device)
        text_features = self.model.get_sequence_output(text_ids)
        # L2-normalize the sentence embedding.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # print(text_features.norm(dim=-1, keepdim=True))
        return text_features  # .unsqueeze(0).cpu().numpy()

    def _inference_from_video(self, img_list):
        self.model.eval()
        # video = self.tfms(video)
        max_frames = 12
        # Pre-allocate a buffer for up to max_frames transformed frames.
        video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
        slice_len = len(img_list)
        max_video_length = 0 if 0 > slice_len else slice_len
        for i, img in enumerate(img_list):
            pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
            tfmed_img = self.tfms(pil_img).unsqueeze(0).to(self.device)
            print('tfmed_img.shape', tfmed_img.shape)

            if slice_len >= 1:
                video[0, i, ...] = tfmed_img
        video_mask = np.zeros((1, max_frames), dtype=np.int64)
        video_mask[0, :max_video_length] = [1] * max_video_length

        video = torch.as_tensor(video).float()
        pair, bs, ts, channel, h, w = video.shape
        video = video.view(pair * bs * ts, channel, h, w)
        video_mask = torch.as_tensor(video_mask).float()
        # video_list.append(video)
        # video_mask_list.append(video_mask)
        # video_list_tensor = torch.cat(video_list, dim=0)
        # video_mask_list_tensor = torch.cat(video_mask_list, dim=0)

        visual_output = self.model.get_visual_output(video, video_mask, shaped=True)

        visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)

        # Masked mean-pooling over the frame dimension.
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        visual_output = torch.sum(visual_output, dim=1) / video_mask_un_sum

        visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)

        return visual_output  # .unsqueeze(0).cpu().numpy()

    #
    # @arg(1, to_image_color('RGB'))
    # def _inference_from_video(self, img):
    #     img = to_pil(img)
    #     image = self.tfms(img).unsqueeze(0).to(self.device)
    #     image_features = self.model.encode_image(image)
    #     return image_features


if __name__ == '__main__':
    # op = CLIP4Clip('clip_vit_b32', 'text', './pytorch_model.bin.1')
    # res = op('kids feeding and playing with the horse')
    # print(res.shape)

    # from towhee import ops
    # op = CLIP4Clip('clip_vit_b32', 'video', './pytorch_model.bin.1')
    # d = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample',
    #                             args={'num_samples': 12})
    # # ops.video_decode.get_video_duration()
    video_path = '/Users/zilliz/dataset/MSRVTT/MSRVTT/videos/all/video9451.mp4'
    # img_list = []
    # for frame in d(video_path):
    #     print(frame)
    #     img_list.append(frame[0])
    # res = op(img_list)
    # print(res.shape)

    dc = (
        towhee.dc['path']([video_path])
        .video_decode.ffmpeg['path', 'frames'](
            sample_type='uniform_temporal_subsample',
            args={'num_samples': 12})
        .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
        .clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1')
    )
    dc.show()
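For comparison, a text-side pipeline sketched under the same assumptions as the video pipeline above: the registered operator name, model name, and weight path are reused from this commit, and the query string comes from the commented-out example; the column names are illustrative only.

# Hypothetical text-modality pipeline mirroring the video one above (not part of this commit).
dc_text = (
    towhee.dc['text'](['kids feeding and playing with the horse'])
    .clip4clip['text', 'vec'](model_name='clip_vit_b32', modality='text',
                              weight_path='./pytorch_model.bin.1')
)
dc_text.show()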