# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import pathlib
import pickle
from argparse import Namespace

import torch
import torchvision
from torchvision import transforms
from transformers import GPT2Tokenizer

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
from towhee.command.s3 import S3Bucket


class ExpansionNetV2(NNOperator):
    """
    ExpansionNet V2 image captioning operator.

    Builds an ``End_ExpansionNet_v2`` model from the code shipped next to this
    file, downloads the weights configured for ``model_name``, and generates a
    caption for an image (or a list of images) with beam search.

    Args:
        model_name (`str`):
            Key into the table returned by `_configs()` (currently only
            ``'expansionnet_rf'``).

    Raises:
        KeyError: if ``model_name`` is not a known configuration.
    """

    def __init__(self, model_name: str):
        super().__init__()

        # Validate the model name before touching sys.path or the network so
        # a typo fails fast and leaves no side effects behind.
        configs = self._configs()
        if model_name not in configs:
            raise KeyError(
                'Unknown model_name {!r}; available models: {}'.format(
                    model_name, sorted(configs)))
        cfg = configs[model_name]

        path = str(pathlib.Path(__file__).parent)

        # The model code lives in sibling packages of this file, so the
        # operator directory is put on sys.path only for the duration of the
        # imports/download and is removed again even if one of them fails.
        sys.path.append(path)
        try:
            from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
            from utils.language_utils import convert_vector_idx2word
            self.convert_vector_idx2word = convert_vector_idx2word

            s3_bucket = S3Bucket()
            s3_bucket.download_file(cfg['weights'], path + '/weights/')
        finally:
            sys.path.pop()

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file. This is only acceptable because the token file is shipped
        # with the operator; never point this at untrusted data.
        with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
            coco_tokens = pickle.load(f)
        self.coco_tokens = coco_tokens

        img_size = 384
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # All dropout rates are zero: the operator is used for inference only.
        drop_args = Namespace(enc=0.0,
                              dec=0.0,
                              enc_input=0.0,
                              dec_input=0.0,
                              other=0.0)
        model_args = Namespace(model_dim=512,
                               N_enc=3,
                               N_dec=3,
                               dropout=0.0,
                               drop_args=drop_args)
        max_seq_len = 74
        beam_size = 5

        self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
                                         swin_embed_dim=192, swin_depths=[2, 2, 18, 2],
                                         swin_num_heads=[6, 12, 24, 48],
                                         swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True,
                                         swin_qk_scale=None,
                                         swin_drop_rate=0.0, swin_attn_drop_rate=0.0,
                                         swin_drop_path_rate=0.0,
                                         swin_norm_layer=torch.nn.LayerNorm, swin_ape=False,
                                         swin_patch_norm=True,
                                         swin_use_checkpoint=False,
                                         final_swin_dim=1536,

                                         d_model=model_args.model_dim, N_enc=model_args.N_enc,
                                         N_dec=model_args.N_dec, num_heads=8, ff=2048,
                                         num_exp_enc_list=[32, 64, 128, 256, 512],
                                         num_exp_dec=16,
                                         output_word2idx=coco_tokens['word2idx_dict'],
                                         output_idx2word=coco_tokens['idx2word_list'],
                                         max_seq_len=max_seq_len, drop_args=model_args.drop_args,
                                         rank='cpu')

        # Weights are deserialized on CPU first and the model is moved to the
        # selected device afterwards.
        checkpoint = torch.load('{}/weights/{}'.format(path, os.path.basename(cfg['weights'])),
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        # Inference only: make sure train-time behaviour of any layer is off.
        # (Assumes End_ExpansionNet_v2 is a torch.nn.Module, as its use of
        # load_state_dict()/to() above suggests.)
        self.model.eval()

        # Resize to the fixed model input size and convert to a tensor ...
        self.transf_1 = torchvision.transforms.Compose(
            [torchvision.transforms.Resize((img_size, img_size)),
             torchvision.transforms.ToTensor()])
        # ... then apply fixed per-channel mean/std normalization.
        self.transf_2 = torchvision.transforms.Compose(
            [torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])])

        self.beam_search_kwargs = {'beam_size': beam_size,
                                   'beam_max_seq_len': max_seq_len,
                                   'sample_or_max': 'max',
                                   'how_many_outputs': 1,
                                   'sos_idx': coco_tokens['word2idx_dict'][coco_tokens['sos_str']],
                                   'eos_idx': coco_tokens['word2idx_dict'][coco_tokens['eos_str']]}

    def _preprocess(self, img):
        """
        Convert an input image into a normalized tensor on `self.device`.

        The image is converted to PIL, resized to the model input size, turned
        into a tensor, normalized, and moved to the inference device.
        """
        img = to_pil(img)
        processed_img = self.transf_1(img)
        processed_img = self.transf_2(processed_img)
        processed_img = processed_img.to(self.device)
        return processed_img

    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        """
        Generate a caption for a single image.

        Returns:
            (`str`) The generated caption.
        """
        text = self._inference_from_image(data)
        return text

    def __call__(self, data):
        """
        Caption one image or a list of images.

        Args:
            data:
                A single image, or a `list` of images.

        Returns:
            A single caption (`str`) when exactly one image was processed --
            this includes a one-element list -- otherwise a `list` of
            captions in input order (an empty list for empty input).
        """
        batch = data if isinstance(data, list) else [data]
        results = [self.inference_single_data(single_data) for single_data in batch]

        # Kept for backward compatibility: a single result is unwrapped even
        # when the caller passed a one-element list.
        if len(batch) == 1:
            return results[0]
        return results

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        """
        Run beam search on one image and format the result as a sentence.

        Returns:
            (`str`) The caption, capitalized and terminated with a period, or
            an empty string if the model produced no words.
        """
        # Add the batch dimension expected by the model.
        img = self._preprocess(img).unsqueeze(0)
        with torch.no_grad():
            pred, _ = self.model(enc_x=img,
                                 enc_x_num_pads=[0],
                                 mode='beam_search',
                                 **self.beam_search_kwargs)

        # Take the first output of the first (only) batch item and drop the
        # first and last tokens (presumably the sos/eos markers).
        words = self.convert_vector_idx2word(pred[0][0], self.coco_tokens['idx2word_list'])[1:-1]

        # Guard against an empty caption: indexing words[-1] would raise
        # IndexError otherwise.
        if len(words) == 0:
            return ''

        words[-1] = words[-1] + '.'
        return ' '.join(words).capitalize()

    def _configs(self):
        """
        Return the table of supported model names and their settings.

        Each entry maps a model name to a dict whose ``'weights'`` value is
        the remote path of the checkpoint to download.
        """
        config = {}
        config['expansionnet_rf'] = {}
        config['expansionnet_rf']['weights'] = 'image-captioning/expansionnet-v2/rf_model.pth'
        return config