# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import pathlib
import json
from pathlib import Path
from PIL import Image
import numpy as np
import torch
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torchvision.transforms.functional import InterpolationMode
from timm.models.vision_transformer import resize_pos_embed
from towhee.types.image_utils import to_pil
from towhee.types.arg import arg, to_image_color
from towhee.operator.base import NNOperator, OperatorFlag


class ClipCaptionReward(NNOperator):
    """
    Image captioning operator based on CLIP-reward captioning: a CLIP RN50
    image encoder followed by a transformer decoder that generates a caption.
    """
    def __init__(self, model_name: str):
        super().__init__()
        # Import the helper modules that ship alongside this operator.
        sys.path.append(str(Path(__file__).parent))
        from utils import opts
        from transformer_model import TransformerModel
        from captioning.models.model_utils import decode_sequence
        self.decode_sequence = decode_sequence
        import mclip
        sys.path.pop()

        path = pathlib.Path(__file__).parent
        cfg = self._configs()[model_name]
        config = str(path) + cfg['config']
        opt = opts.parse_opt(parse=False, cfg=config)

        # Vocabulary mapping token ids back to words.
        with open('{}/data/cocotalk.json'.format(path)) as f:
            dict_json = json.load(f)
        ix_to_word = dict_json['ix_to_word']

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # CLIP RN50 visual backbone used as the image encoder.
        clip_model, clip_transform = mclip.load('RN50', jit=False, device=self.device)
        self.clip_model = clip_model
        self.clip_transform = clip_transform

        vocab_size = len(ix_to_word)
        seq_length = 1
        opt.vocab_size = vocab_size
        opt.seq_length = seq_length
        opt.batch_size = 1
        opt.vocab = ix_to_word

        # Resize CLIP's positional embedding for the 448x448 input:
        # 448 // 32 = 14 patches per side, i.e. 14 * 14 = 196 patches.
        num_patches = 196
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                num_patches + 1,
                clip_model.visual.attnpool.positional_embedding.shape[-1],
                device=self.device,
            ),
        )
        pos_embed.weight = resize_pos_embed(
            clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed
        )
        self.clip_model.visual.attnpool.positional_embedding = pos_embed

        # Transformer caption decoder and its pretrained weights.
        ckpt_path = str(path) + cfg['weights']
        raw_state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        self.model = TransformerModel(opt)
        self.model.load_state_dict(raw_state_dict)

        # CLIP normalization statistics applied after preprocessing.
        self.image_mean = (
            torch.Tensor([0.48145466, 0.4578275, 0.40821073])
            .to(self.device)
            .reshape(3, 1, 1)
        )
        self.image_std = (
            torch.Tensor([0.26862954, 0.26130258, 0.27577711])
            .to(self.device)
            .reshape(3, 1, 1)
        )
        self._preprocess = Compose(
            [
                Resize((448, 448), interpolation=InterpolationMode.BILINEAR),
                CenterCrop((448, 448)),
                ToTensor(),
            ]
        )

        self.eval_kwargs = {}
        self.eval_kwargs.update(vars(opt))
    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        text = self._inference_from_image(data)
        return text
    def _inference_from_image(self, img):
        # Preprocess: resize/crop to 448x448 and normalize with CLIP statistics.
        img = to_pil(img)
        img = self._preprocess(img)
        img = torch.tensor(np.stack([img])).to(self.device)
        img -= self.image_mean
        img /= self.image_std

        # Encode the image with CLIP and keep the spatial attention features.
        tmp_att, tmp_fc = self.clip_model.encode_image(img)
        tmp_att = tmp_att[0].permute(1, 2, 0)
        att_feat = tmp_att

        with torch.no_grad():
            fc_feats = torch.zeros((1, 0)).to(self.device)
            att_feats = att_feat.view(1, 196, 2048).float().to(self.device)
            att_masks = None
            # Forward the model to generate a caption for the image,
            # keeping only one sample per image to avoid duplicates.
            tmp_eval_kwargs = self.eval_kwargs.copy()
            tmp_eval_kwargs.update({"sample_n": 1})
            seq, seq_logprobs = self.model(
                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode="sample"
            )
            seq = seq.data

        sents = self.decode_sequence(self.model.vocab, seq)
        return sents[0]
    def __call__(self, data):
        if not isinstance(data, list):
            data = [data]
        results = []
        for single_data in data:
            result = self.inference_single_data(single_data)
            results.append(result)
        if len(data) == 1:
            return results[0]
        return results
    def _configs(self):
        config = {}
        config['clipRN50_clips_grammar'] = {}
        config['clipRN50_clips_grammar']['weights'] = '/weights/clipRN50_clips_grammar-last.pth'
        config['clipRN50_clips_grammar']['config'] = '/configs/phase2/clipRN50_clips_grammar.yml'
        return config
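# ---------------------------------------------------------------------------
# Usage sketch (not part of the operator). A minimal example of how this class
# is typically invoked through Towhee's DataCollection API; the hub path
# `image_captioning.clip_caption_reward` and the pipeline calls below are
# assumptions based on this file, not confirmed by it. It also assumes the
# files __init__ expects (weights/, configs/, data/cocotalk.json and the
# helper modules) sit next to this module.
#
#   import towhee
#
#   towhee.glob('./example.jpg') \
#         .image_decode() \
#         .image_captioning.clip_caption_reward(model_name='clipRN50_clips_grammar') \
#         .show()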