clip4clip
ChengZi
3 years ago
3 changed files with 177 additions and 0 deletions
@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .clip4clip import CLIP4Clip


def clip4clip(model_name: str, modality: str, weight_path: str):
    return CLIP4Clip(model_name, modality, weight_path)

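For reference, a minimal sketch of calling this factory directly; the import path, model name, and weight path are assumptions taken from the commented-out test code in the operator file below and may need adjusting:

# Hypothetical usage of the factory above (names and paths are illustrative).
from clip4clip import clip4clip

text_op = clip4clip('clip_vit_b32', 'text', './pytorch_model.bin.1')
text_vec = text_op('kids feeding and playing with the horse')  # L2-normalized text embedding
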
@@ -0,0 +1,157 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import towhee
from torchvision import transforms
from towhee.models.clip4clip import convert_tokens_to_id

from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from towhee.models import clip4clip
from PIL import Image as PILImage


@register(name='clip4clip', output_schema=['vec'])
class CLIP4Clip(NNOperator):
    """
    CLIP4Clip multi-modal embedding operator for video-text retrieval.
    """
    def __init__(self, model_name: str, modality: str, weight_path: str = None):
        super().__init__()
        self.modality = modality
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = clip4clip.create_model(model_name=model_name,
                                            context_length=77,
                                            pretrained=True,
                                            weights_path=weight_path,
                                            device=self.device)

        self.tokenize = clip4clip.SimpleTokenizer()
        # Standard CLIP preprocessing: resize, center-crop to 224x224, normalize.
        self.tfms = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        self.model.eval()

    def __call__(self, data):
        if self.modality == 'video':
            vec = self._inference_from_video(data)
        elif self.modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text):
        self.model.eval()
        text_ids = convert_tokens_to_id(self.tokenize, text)
        text_ids = torch.tensor(text_ids).unsqueeze(0).to(self.device)
        text_features = self.model.get_sequence_output(text_ids)
        # L2-normalize the text embedding.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def _inference_from_video(self, img_list):
        self.model.eval()
        max_frames = 12
        # Buffer of shape (pair, batch, ts, channel, height, width) for up to `max_frames` frames.
        video = np.zeros((1, max_frames, 1, 3, 224, 224), dtype=np.float64)
        slice_len = min(len(img_list), max_frames)
        for i, img in enumerate(img_list[:max_frames]):
            pil_img = PILImage.fromarray(img.to_ndarray(), img.mode)
            # Keep the transformed frame on CPU so it can be copied into the numpy buffer.
            tfmed_img = self.tfms(pil_img).unsqueeze(0)
            video[0, i, ...] = tfmed_img

        # Mark which frame slots are actually filled.
        video_mask = np.zeros((1, max_frames), dtype=np.int64)
        video_mask[0, :slice_len] = [1] * slice_len

        video = torch.as_tensor(video).float().to(self.device)
        pair, bs, ts, channel, h, w = video.shape
        video = video.view(pair * bs * ts, channel, h, w)
        video_mask = torch.as_tensor(video_mask).float().to(self.device)

        visual_output = self.model.get_visual_output(video, video_mask, shaped=True)

        visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)

        # Mean-pool the per-frame embeddings over the valid (unmasked) frames.
        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        visual_output = torch.sum(visual_output, dim=1) / video_mask_un_sum

        visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)

        return visual_output

# Alternative single-image variant, kept commented out for reference:
# @arg(1, to_image_color('RGB'))
# def _inference_from_video(self, img):
#     img = to_pil(img)
#     image = self.tfms(img).unsqueeze(0).to(self.device)
#     image_features = self.model.encode_image(image)
#     return image_features

if __name__ == '__main__':
    # op = CLIP4Clip('clip_vit_b32', 'text', './pytorch_model.bin.1')
    # res = op('kids feeding and playing with the horse')
    # print(res.shape)

    # from towhee import ops
    # op = CLIP4Clip('clip_vit_b32', 'video', './pytorch_model.bin.1')
    # d = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample',
    #                             args={'num_samples': 12})
    # # ops.video_decode.get_video_duration()
    video_path = '/Users/zilliz/dataset/MSRVTT/MSRVTT/videos/all/video9451.mp4'
    # img_list = []
    # for frame in d(video_path):
    #     print(frame)
    #     img_list.append(frame[0])
    # res = op(img_list)
    # print(res.shape)

    dc = (
        towhee.dc['path']([video_path])
            .video_decode.ffmpeg['path', 'frames'](
                sample_type='uniform_temporal_subsample',
                args={'num_samples': 12})
            .runas_op['frames', 'frames'](func=lambda x: [y[0] for y in x])
            .clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', weight_path='./pytorch_model.bin.1')
    )
    dc.show()
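
    # Sketch (an assumption, not part of the original test block): the text modality can be
    # exercised through the same registered operator, mirroring the video pipeline above.
    # The query sentence comes from the commented-out text test; adjust the weight path as needed.
    dc_text = (
        towhee.dc['text'](['kids feeding and playing with the horse'])
            .clip4clip['text', 'vec'](model_name='clip_vit_b32', modality='text', weight_path='./pytorch_model.bin.1')
    )
    dc_text.show()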