From 7210b706ac5c26d89be850f8fdb22951847a5981 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Fri, 23 Sep 2022 11:54:18 +0800
Subject: [PATCH] update the operator.

Signed-off-by: wxywb
---
 __init__.py | 19 +++++++++++++++++++
 taiyi.py    | 46 +++++++++++++++++++++++++++++++++++++---------
 2 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/__init__.py b/__init__.py
index e69de29..d1acbac 100644
--- a/__init__.py
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .taiyi import Taiyi
+
+
+def taiyi(model_name: str, modality: str):
+    return Taiyi(model_name, modality)
diff --git a/taiyi.py b/taiyi.py
index d8931ba..51084e0 100644
--- a/taiyi.py
+++ b/taiyi.py
@@ -15,6 +15,16 @@ import torch
 from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
 from transformers import CLIPProcessor, CLIPModel
+import sys
+from pathlib import Path
+import torch
+from torchvision import transforms
+
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee.types.arg import arg, to_image_color
+from towhee import register
+
 
 @register(output_schema=['vec'])
 class Taiyi(NNOperator):
     """
@@ -23,10 +33,13 @@ class Taiyi(NNOperator):
     def __init__(self, model_name: str, modality: str):
         self.modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese")
-        self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").eval()
-        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+        config = self._configs()[model_name]
+
+        self.text_tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
+        self.text_encoder = BertForSequenceClassification.from_pretrained(config['text_encoder']).eval()
+
+        self.clip_model = CLIPModel.from_pretrained(config['clip_model'])
+        self.processor = CLIPProcessor.from_pretrained(config['processor'])
 
     def inference_single_data(self, data):
         if self.modality == 'image':
@@ -52,14 +65,29 @@ class Taiyi(NNOperator):
         return results
 
     def _inference_from_text(self, text):
-        self.text = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].to(self.device)
-        text_features = text_encoder(text).logits
+        tokens = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].to(self.device)
+        text_features = self.text_encoder(tokens).logits
         return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        image = to_pil(image)
-        image = self.processor(images=image.raw), return_tensors="pt")
-        image_features = clip_model.get_image_features(**image)
+        image = to_pil(img)
+        image = self.processor(images=image, return_tensors="pt")
+        image_features = self.clip_model.get_image_features(**image)
         return image_features
 
+    def _configs(self):
+        config = {}
+        config['taiyi-clip-roberta-102m-chinese'] = {}
+        config['taiyi-clip-roberta-102m-chinese']['tokenizer'] = 'IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese'
+        config['taiyi-clip-roberta-102m-chinese']['text_encoder'] = 'IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese'
+        config['taiyi-clip-roberta-102m-chinese']['clip_model'] = 'openai/clip-vit-base-patch32'
+        config['taiyi-clip-roberta-102m-chinese']['processor'] = 'openai/clip-vit-base-patch32'
+
+        config['taiyi-clip-roberta-large-326m-chinese'] = {}
+        config['taiyi-clip-roberta-large-326m-chinese']['tokenizer'] = 'IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese'
+        config['taiyi-clip-roberta-large-326m-chinese']['text_encoder'] = 'IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese'
+        config['taiyi-clip-roberta-large-326m-chinese']['clip_model'] = 'openai/clip-vit-large-patch14'
+        config['taiyi-clip-roberta-large-326m-chinese']['processor'] = 'openai/clip-vit-large-patch14'
+        return config
+
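A quick way for reviewers to exercise the patch is the text path, whose full body is visible in the last hunk. The snippet below is a reviewer's sketch, not part of the patch: the model tag comes from the new _configs() table, the Chinese prompt is a made-up sample, and it assumes towhee and transformers are installed with hub access to the weights. Note that the patched __init__ moves token ids to self.device but never calls .to(self.device) on the encoders, so as written the sketch is only expected to run on a CPU-only machine.

    # Reviewer smoke test (not part of the patch); run from the repo root.
    from taiyi import Taiyi

    op = Taiyi('taiyi-clip-roberta-102m-chinese', modality='text')
    vec = op._inference_from_text('一只可爱的小猫')  # "a cute kitten"
    # The sequence-classification logits serve as the text embedding; their
    # width should match the paired CLIP model's image feature dimension.
    print(vec.shape)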