# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from pathlib import Path

from easydict import EasyDict as edict

import torch
from torchvision import transforms
from transformers import GPT2Tokenizer

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.command.s3 import S3Bucket


class Camel(NNOperator):
    """
    Camel image captioning operator.

    Encodes an image with the visual tower of a CLIP ``RN50x16`` model and
    decodes a caption with a CaMEL ``Captioner`` via beam search. The
    captioner weights are downloaded from S3 when the operator is built.
    """
    def _gen_args(self):
        """
        Return the hyper-parameters used to build the ``Captioner``.

        ``image_dim`` is a placeholder; ``__init__`` overwrites it with the
        embedding size of the loaded CLIP visual model.
        """
        args = edict()
        args.N_dec = 3
        args.N_enc = 3
        args.batch_size = 25
        args.d_ff = 2048
        args.d_model = 512
        args.disable_mesh = False
        args.head = 8
        args.image_dim = 3072
        args.m = 40
        args.network = 'target'
        args.with_pe = False
        args.workers = 0
        return args

    def __init__(self, model_name: str):
        """
        Build the CLIP image encoder and the CaMEL captioner.

        Args:
            model_name: key into ``self._configs()`` selecting which weights
                to download (currently only ``'camel_mesh'``).

        Raises:
            KeyError: if ``model_name`` is not a known configuration.
        """
        super().__init__()
        path = str(Path(__file__).parent)
        # The bundled `models` / `data` packages live next to this file. Keep
        # that directory importable for the whole constructor (unpickling the
        # checkpoint may also need it), but always remove it afterwards, even
        # when construction fails.
        sys.path.append(path)
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            from models import Captioner, clip
            from data import TextField

            # Pipeline for text
            self.text_field = TextField()
            args = self._gen_args()

            self.clip_model, self.clip_tfms = clip.load('RN50x16', jit=False)
            self.image_model = self.clip_model.visual
            # Make calling the visual tower return intermediate features
            # rather than the final CLIP embedding.
            self.image_model.forward = self.image_model.intermediate_features
            args.image_dim = self.image_model.embed_dim

            configs = self._configs()
            if model_name not in configs:
                raise KeyError(
                    f'Unknown model_name {model_name!r}; '
                    f'expected one of {sorted(configs)}'
                )
            config = configs[model_name]

            s3_bucket = S3Bucket()
            s3_bucket.download_file(config['weights'], path + '/weights/')
            model_path = path + '/weights/' + os.path.basename(config['weights'])

            # Create the model
            self.model = Captioner(args, self.text_field).to(self.device)
            self.model.forward = self.model.beam_search
            self.image_model = self.image_model.to(self.device)
            # NOTE: torch.load unpickles the file; it is only trusted because
            # it comes from the project's own S3 bucket.
            state = torch.load(model_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(state['state_dict_t'])
            self.model.eval()
        finally:
            if path in sys.path:
                sys.path.remove(path)

    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        """Caption one image (converted to RGB by the decorator)."""
        text = self._inference_from_image(data)
        return text

    def _preprocess(self, img):
        """Convert ``img`` to PIL, apply CLIP transforms, add a batch dim."""
        img = to_pil(img)
        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
        return processed_img

    def __call__(self, data):
        """
        Caption one image or a list of images.

        Returns a single result when exactly one image is given (either bare
        or as a one-element list); otherwise a list of results in input order.
        """
        batch = data if isinstance(data, list) else [data]
        results = [self.inference_single_data(item) for item in batch]
        return results[0] if len(batch) == 1 else results

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        """Encode ``img`` and decode a caption with beam search (beam size 5)."""
        img = self._preprocess(img)
        # Inference only: skip building autograd graphs to save memory.
        with torch.no_grad():
            feat = self.image_model(img)
            tokens, _ = self.model.beam_search(feat, beam_size=5, out_size=1)
        # The text field is an instance attribute; the bare name was undefined.
        text = self.text_field.decode(tokens)
        return text

    def _configs(self):
        """Map model names to their S3 weight paths."""
        config = {}
        config['camel_mesh'] = {}
        config['camel_mesh']['weights'] = 'image-captioning/camel/camel_mesh.pth'
        return config


if __name__ == '__main__':
    pass