# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import pickle
from pathlib import Path
from types import SimpleNamespace

import torch
from torchvision import transforms
from transformers import GPT2Tokenizer

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip


class ExpansionNetV2(NNOperator):
    """
    ExpansionNet V2 image captioning operator
    """

    def __init__(self, model_name: str):
        """
        Build the End_ExpansionNet_v2 captioning model on CPU.

        Args:
            model_name (`str`):
                Name of the model to use. It is accepted for interface
                compatibility but is not read by this constructor.

        Raises:
            FileNotFoundError: if ``demo_coco_tokens.pickle`` is found neither
                next to this file nor in the current working directory.
        """
        super().__init__()

        op_dir = Path(__file__).parent

        # The model code ships next to this file rather than as an installed
        # package, so the operator directory is put on sys.path only for the
        # duration of the import. Removing the exact entry (instead of a blind
        # pop()) keeps sys.path intact even if the import itself altered it,
        # and the finally clause cleans up when the import fails.
        op_dir_entry = str(op_dir)
        sys.path.append(op_dir_entry)
        try:
            from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
        finally:
            try:
                sys.path.remove(op_dir_entry)
            except ValueError:
                pass

        # Prefer the vocabulary file located beside this operator; fall back
        # to the original CWD-relative location for backward compatibility.
        tokens_path = op_dir / 'demo_coco_tokens.pickle'
        if not tokens_path.is_file():
            tokens_path = Path('demo_coco_tokens.pickle')

        # Pickle data is binary and must be opened with 'rb'.
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # the vocabulary file distributed with this operator.
        with open(tokens_path, 'rb') as token_file:
            coco_tokens = pickle.load(token_file)

        # NOTE(review): the values below are the defaults of the upstream
        # ExpansionNet v2 demo script (img_size=384, model_dim=512, N_enc=3,
        # N_dec=3, max_seq_len=74, all dropout rates 0.0). Confirm that they
        # match the checkpoint this operator is meant to load.
        img_size = 384
        max_seq_len = 74
        drop_args = SimpleNamespace(enc=0.0,
                                    dec=0.0,
                                    enc_input=0.0,
                                    dec_input=0.0,
                                    other=0.0)
        model_args = SimpleNamespace(model_dim=512,
                                     N_enc=3,
                                     N_dec=3,
                                     dropout=0.0,
                                     drop_args=drop_args)

        self.model = End_ExpansionNet_v2(
            # Swin Transformer visual backbone configuration.
            swin_img_size=img_size,
            swin_patch_size=4,
            swin_in_chans=3,
            swin_embed_dim=192,
            swin_depths=[2, 2, 18, 2],
            swin_num_heads=[6, 12, 24, 48],
            swin_window_size=12,
            swin_mlp_ratio=4.,
            swin_qkv_bias=True,
            swin_qk_scale=None,
            swin_drop_rate=0.0,
            swin_attn_drop_rate=0.0,
            swin_drop_path_rate=0.0,
            swin_norm_layer=torch.nn.LayerNorm,
            swin_ape=False,
            swin_patch_norm=True,
            swin_use_checkpoint=False,
            final_swin_dim=1536,
            # Encoder / decoder configuration.
            d_model=model_args.model_dim,
            N_enc=model_args.N_enc,
            N_dec=model_args.N_dec,
            num_heads=8,
            ff=2048,
            num_exp_enc_list=[32, 64, 128, 256, 512],
            num_exp_dec=16,
            # Vocabulary mappings loaded from the pickle file.
            output_word2idx=coco_tokens['word2idx_dict'],
            output_idx2word=coco_tokens['idx2word_list'],
            max_seq_len=max_seq_len,
            drop_args=model_args.drop_args,
            rank='cpu')