capdec
copied
3 changed files with 100 additions and 0 deletions
@ -0,0 +1,17 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .capdec import Capdec |
||||
|
def capdec(model_name: str):
    """Factory entry point: build a CapDec image-captioning operator.

    Args:
        model_name: Name of the pretrained CapDec model variant to load.

    Returns:
        A ready-to-call ``Capdec`` operator instance.
    """
    return Capdec(model_name)
@ -0,0 +1,83 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import sys |
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
|
||||
|
import torch |
||||
|
from torchvision import transforms |
||||
|
from transformers import GPT2Tokenizer |
||||
|
|
||||
|
from towhee.types.arg import arg, to_image_color |
||||
|
from towhee.types.image_utils import to_pil |
||||
|
from towhee.operator.base import NNOperator, OperatorFlag |
||||
|
from towhee import register |
||||
|
from towhee.models import clip |
||||
|
|
||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup |
||||
|
|
||||
|
|
||||
|
class Capdec(NNOperator):
    """
    CapDec image captioning operator.

    Projects CLIP image features into a GPT-2 prefix embedding and decodes
    a caption with beam search.
    """

    def __init__(self, model_name: str):
        super().__init__()
        # Make modules bundled next to this file importable (model code, helpers).
        sys.path.append(str(Path(__file__).parent))
        # NOTE(review): the visible source is truncated here. Later methods read
        # self.clip_model, self.clip_tfms, self.model, self.tokenizer,
        # self.device and call self.generate_beam — presumably the full
        # __init__ loads all of these from `model_name`; confirm against the
        # complete file before relying on this constructor.

    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        """Caption a single RGB image and return the generated text."""
        text = self._inference_from_image(data)
        return text

    def _preprocess(self, img):
        """Convert a towhee image into a batched CLIP input tensor on self.device."""
        img = to_pil(img)
        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
        return processed_img

    def __call__(self, data):
        """Caption one image or a list of images.

        Args:
            data: A single image, or a list of images.

        Returns:
            A single caption string for a single input, or a list of caption
            strings for a list input (output arity mirrors input arity).
        """
        # Normalize to a list; the redundant `else: data = data` branch of the
        # original was removed (no behavior change).
        if not isinstance(data, list):
            data = [data]
        results = [self.inference_single_data(single_data) for single_data in data]
        # Preserve original contract: unwrap when exactly one input was given.
        if len(results) == 1:
            return results[0]
        return results

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        """Encode an image with CLIP and beam-decode a caption with GPT-2."""
        img = self._preprocess(img)
        clip_feat = self.clip_model.encode_image(img)

        # Length of the GPT-2 prefix produced from the CLIP feature.
        # NOTE(review): set on every call; if it is a model constant it
        # belongs in __init__ — kept here to match the original behavior.
        self.prefix_length = 10
        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)

        generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
        return generated_text_prefix

    def _configs(self):
        """Return the per-model weight-file lookup table.

        NOTE(review): keys are named 'clipcap_*' although this operator is
        CapDec — looks copied from the clipcap operator; confirm the expected
        `model_name` values with callers before renaming. Keys and values are
        kept byte-identical to the original.
        """
        config = {}
        config['clipcap_coco'] = {}
        config['clipcap_coco']['weights'] = 'coco_weights.pt'
        config['clipcap_conceptual'] = {}
        config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
        return config
Loading…
Reference in new issue