capdec
copied
3 changed files with 100 additions and 0 deletions
@ -0,0 +1,17 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .capdec import Capdec


def capdec(model_name: str):
    """Factory entry point: build a :class:`Capdec` operator.

    Args:
        model_name: Name of the CapDec model variant to load.

    Returns:
        A ready-to-use ``Capdec`` operator instance.
    """
    operator = Capdec(model_name)
    return operator
@ -0,0 +1,83 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|||
# Imports, grouped stdlib / third-party / local.
# Fix: GPT2Tokenizer was imported twice (two separate `from transformers
# import ...` lines); the two are merged into one, keeping every name.
# NOTE(review): several names (os, torch, transforms, OperatorFlag,
# register, clip, AdamW, get_linear_schedule_with_warmup) are unused in
# the visible code — kept intentionally, they may be used by code loaded
# via the sys.path manipulation in Capdec.__init__; confirm before pruning.
import sys
import os
from pathlib import Path

import torch
from torchvision import transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
|||
class Capdec(NNOperator):
    """CapDec image captioning operator.

    Pipeline: an input image is preprocessed with the CLIP transforms,
    encoded by the CLIP image encoder, projected into a GPT-2 prefix
    embedding, and decoded into a caption via beam search.

    NOTE(review): the visible ``__init__`` only extends ``sys.path``; the
    attributes used by the methods below (``clip_model``, ``model``,
    ``tokenizer``, ``clip_tfms``, ``device``, ``generate_beam``) are
    presumably installed by code imported from the path added there —
    confirm they are set before the operator is invoked.
    """

    def __init__(self, model_name: str):
        super().__init__()
        # Make sibling modules that live next to this file importable.
        sys.path.append(str(Path(__file__).parent))

    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        """Generate a caption for one image and return it as text."""
        return self._inference_from_image(data)

    def _preprocess(self, img):
        """Convert a towhee image into a batched CLIP input tensor on ``self.device``."""
        pil_image = to_pil(img)
        return self.clip_tfms(pil_image).unsqueeze(0).to(self.device)

    def __call__(self, data):
        """Caption one image or a list of images.

        A single (non-list) input yields a single caption string; a list
        input yields a list of captions. A one-element list also yields a
        single caption, matching the original behavior.
        """
        batch = data if isinstance(data, list) else [data]
        captions = [self.inference_single_data(item) for item in batch]
        return captions[0] if len(batch) == 1 else captions

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        """Encode *img* with CLIP and decode a caption with beam search."""
        image_tensor = self._preprocess(img)
        clip_feat = self.clip_model.encode_image(image_tensor)

        # Number of prefix tokens the CLIP feature is projected into
        # before being handed to the GPT-2 decoder.
        self.prefix_length = 10
        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)

        return self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]

    def _configs(self):
        """Return per-model configuration: checkpoint file name per variant."""
        return {
            'clipcap_coco': {'weights': 'coco_weights.pt'},
            'clipcap_conceptual': {'weights': 'conceptual_weights.pt'},
        }
Loading…
Reference in new issue