clipcap
              
                
                
            
Files and versions: 80 lines · 2.8 KiB
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from pathlib import Path

import torch
from torchvision import transforms
from transformers import GPT2Tokenizer

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip


class ClipCap(NNOperator):
    """
    ClipCap image captioning operator.

    Args:
        model_name (str): which pretrained weights to load; one of the keys
            returned by self._configs(), i.e. 'clipcap_coco' or 'clipcap_conceptual'.
    """
    def __init__(self, model_name: str):
        super().__init__()
        sys.path.append(str(Path(__file__).parent))
        # ClipCaptionModel and generate_beam are provided by the local models/clipcap.py
        # module shipped with this operator.
        from models.clipcap import ClipCaptionModel, generate_beam
        self._generate_beam = generate_beam
        config = self._configs()[model_name]
        # The original file never defines self.device although _preprocess uses it;
        # CPU is assumed here to match map_location='cpu' below.
        self.device = 'cpu'

        # CLIP preprocessing: resize/center-crop to 224 and normalize with CLIP statistics.
        self.clip_tfms = self.tfms = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

        # Frozen CLIP ViT-B/32 image encoder that produces the caption prefix embedding.
        clip_model_type = 'clip_vit_b32'
        self.clip_model = clip.create_model(model_name=clip_model_type, pretrained=True, jit=True)

        # GPT-2 tokenizer for the ClipCap language decoder (assumption: the original
        # file calls generate_beam with a tokenizer it never constructs).
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        self.model = ClipCaptionModel(prefix=10)
        # config['weights'] already carries the 'weights/' prefix, so join it directly
        # with the operator directory.
        model_path = os.path.join(os.path.dirname(__file__), config['weights'])
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()

    @arg(1, to_image_color('RGB'))
    def __call__(self, data):
        vec = self._inference_from_image(data)
        return vec

    def _preprocess(self, img):
        img = to_pil(img)
        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
        return processed_img

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        img = self._preprocess(img)
        clip_feat = self.clip_model.encode_image(img)

        # Project the CLIP image embedding into a fixed-length prefix for the decoder.
        prefix_length = 10
        prefix_embed = self.model.clip_project(clip_feat).reshape(1, prefix_length, -1)

        # Decode the caption with beam search over the GPT-2 decoder.
        generated_text_prefix = self._generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
        return generated_text_prefix

    def _configs(self):
        config = {}
        config['clipcap_coco'] = {}
        config['clipcap_coco']['weights'] = 'weights/coco_weights.pt'
        config['clipcap_conceptual'] = {}
        config['clipcap_conceptual']['weights'] = 'weights/conceptual_weights.pt'
        return config
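For reference, a minimal usage sketch (not part of the file above). It assumes this operator is published on the Towhee hub under an image-captioning namespace so that a Towhee 0.x DataCollection pipeline can resolve it as image_captioning.clipcap, and that the models/ and weights/ directories referenced by _configs() ship with the operator. The model_name argument accepts 'clipcap_coco' or 'clipcap_conceptual', matching _configs().

import towhee

# Decode one local image and caption it with the COCO-trained ClipCap weights.
# ('image_captioning.clipcap' is the assumed hub name for this operator.)
towhee.glob('./example.jpg') \
      .image_decode() \
      .image_captioning.clipcap(model_name='clipcap_coco') \
      .show()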