From 61179b49504d5766d7f61353b74fdb7ab90029c0 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Wed, 9 Nov 2022 15:52:35 +0800
Subject: [PATCH] init the operator.

Signed-off-by: wxywb
---
 __init__.py      | 17 +++++++++++++
 capdec.py        | 76 ++++++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt |  0
 3 files changed, 93 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 capdec.py
 create mode 100644 requirements.txt

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..9a8818b
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .capdec import Capdec
+def capdec(model_name: str):
+    return Capdec(model_name)
diff --git a/capdec.py b/capdec.py
new file mode 100644
index 0000000..c628d0c
--- /dev/null
+++ b/capdec.py
@@ -0,0 +1,76 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from pathlib import Path
+
+import torch
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator
+from towhee.models import clip
+
+
+class Capdec(NNOperator):
+    """
+    CapDec image captioning operator.
+
+    Projects a CLIP image embedding into a prefix that conditions a GPT-2
+    language model to generate a caption.
+    """
+    def __init__(self, model_name: str):
+        super().__init__()
+        # Local helper modules (e.g. the CapDec caption model) are expected
+        # to live next to this file.
+        sys.path.append(str(Path(__file__).parent))
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.prefix_length = 10
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        # TODO: load the CLIP encoder (self.clip_model, e.g. via
+        # towhee.models.clip), its image transforms (self.clip_tfms), the
+        # caption model (self.model) and the beam-search decoder
+        # (self.generate_beam) from _configs()[model_name]['weights'].
+
+    @arg(1, to_image_color('RGB'))
+    def inference_single_data(self, data):
+        return self._inference_from_image(data)
+
+    def _preprocess(self, img):
+        img = to_pil(img)
+        return self.clip_tfms(img).unsqueeze(0).to(self.device)
+
+    def __call__(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        results = [self.inference_single_data(single_data) for single_data in data]
+        if len(results) == 1:
+            return results[0]
+        return results
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        img = self._preprocess(img)
+        clip_feat = self.clip_model.encode_image(img)
+        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
+        return self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
+
+    def _configs(self):
+        config = {}
+        config['clipcap_coco'] = {}
+        config['clipcap_coco']['weights'] = 'coco_weights.pt'
+        config['clipcap_conceptual'] = {}
+        config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
+        return config
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e69de29
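
Usage sketch (a minimal example, under assumptions: the operator is published on
the Towhee hub as image-captioning/capdec, the weight file named in _configs()
is available locally, and the model loading marked TODO in __init__ is
completed; the hub paths and the image file below are illustrative, not part of
this patch):

    from towhee import ops

    # Decode a local image to RGB (assumed image-decode/cv2-rgb hub operator),
    # then caption it with the CapDec operator defined in this patch.
    img = ops.image_decode.cv2_rgb()('example.jpg')
    capdec_op = ops.image_captioning.capdec(model_name='clipcap_coco')
    print(capdec_op(img))

Passing a list of decoded images instead of a single image returns a list of
captions, per the branching in __call__.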