From 3dc67453f0ed2844d5837855841ff98e5466e198 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Fri, 15 Apr 2022 16:50:08 +0800
Subject: [PATCH] update clip.

Signed-off-by: wxywb
---
 __init__.py               | 20 +++++++++++++++++++-
 clip.py                   | 40 ++++++++++++++++++++++++----------------
 clip_impl.py              |  4 ++--
 model.py => clip_model.py |  0
 4 files changed, 45 insertions(+), 19 deletions(-)
 rename model.py => clip_model.py (100%)

diff --git a/__init__.py b/__init__.py
index dcc5619..b80b17d 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1 +1,19 @@
-from .clip import *
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .clip import Clip
+
+
+def clip(name: str, modality: str):
+    return Clip(name, modality)
diff --git a/clip.py b/clip.py
index 62ab943..8e888bc 100644
--- a/clip.py
+++ b/clip.py
@@ -12,42 +12,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-@register(output_schema=['vec'])
 import numpy
 import towhee
 import sys
 from pathlib import Path
+import torch
 from torchvision import transforms
 
-from towhee.types.image_utils import to_pil 
+from towhee.types.image_utils import to_pil
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 
+
 @register(output_schema=['vec'])
 class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
-    def __init__(self, modality: str):
-        self._modality = modality
+    def __init__(self, name: str, modality: str):
+        # Make the bundled clip_impl/clip_model modules importable.
+        sys.path.append(str(Path(__file__).parent))
+        import clip_impl
+        self.modality = modality
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self._model, self.preprocess = clip_impl.load(name, self.device)
+        self.tokenize = clip_impl.tokenize
 
     def __call__(self, data):
-        if self._modality == 'image'
-            emb = self._inference_from_image(data)
-        elif self._modality == 'text'
-            emb = self._inference_from_text(data)
-        else
-            raise ValueError("modality[{}] not implemented.".format(self._modality))
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        return vec
 
     def _inference_from_text(self, text):
-        return text
+        text = self.tokenize(text).to(self.device)
+        text_features = self._model.encode_text(text)
+        return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        return img
-
-
-
+        image = self.preprocess(to_pil(img)).unsqueeze(0).to(self.device)
+        image_features = self._model.encode_image(image)
+        return image_features
 
diff --git a/clip_impl.py b/clip_impl.py
index cf2ba38..658ac2c 100644
--- a/clip_impl.py
+++ b/clip_impl.py
@@ -10,8 +10,8 @@ from PIL import Image
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from tqdm import tqdm
 
-from .model import build_model
-from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from clip_model import build_model
+from simple_tokenizer import SimpleTokenizer as _Tokenizer
 
 try:
     from torchvision.transforms import InterpolationMode
diff --git a/model.py b/clip_model.py
similarity index 100%
rename from model.py
rename to clip_model.py
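
-- 
Usage sketch (not part of the patch): a minimal direct call of the patched
operator. The model name 'ViT-B/32' is an assumption based on the standard
OpenAI CLIP checkpoint names; the exact set of names accepted by
clip_impl.load() may differ.

    import torch
    from clip import Clip

    # Text modality: the input is tokenized, then encoded with encode_text()
    # on whichever device __init__ selected (CUDA if available, else CPU).
    op = Clip('ViT-B/32', 'text')
    with torch.no_grad():
        vec = op('a photo of a cat')  # tensor of text features

For the 'image' modality, __call__ expects a towhee image; the
@arg(1, to_image_color('RGB')) decorator converts it to RGB before the CLIP
preprocess transform and encode_image() are applied.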