Browse Source
        
      
      fix gpu related problem.
      
        Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
      
      
        main
      
      
     
    
      
        
          
             jinlingxu06
          
          3 years ago
            jinlingxu06
          
          3 years ago
          
         
        
        
       
      
     
    
    
	
		
			
				 2 changed files with 
7 additions and 
5 deletions
			 
			
		 
		
			
				- 
					
					
					 
					__init__.py
				
- 
					
					
					 
					expansionnet_v2.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -14,5 +14,5 @@ | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from .expansionnet_v2 import ExpansionNetV2  | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def expansionnet_v2(model_name: str): | 
			
		
	
		
			
				
					|  |  |  |     return ExpansionNetV2(model_name) | 
			
		
	
		
			
				
					|  |  |  | def expansionnet_v2(model_name, device = None): | 
			
		
	
		
			
				
					|  |  |  |     return ExpansionNetV2(model_name, device) | 
			
		
	
	
		
			
				
					|  |  | 
 | 
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |     ExpansionNet V2 image captioning operator | 
			
		
	
		
			
				
					|  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, model_name: str): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, model_name: str, device: str = None): | 
			
		
	
		
			
				
					|  |  |  |         super().__init__() | 
			
		
	
		
			
				
					|  |  |  |         path = str(pathlib.Path(__file__).parent) | 
			
		
	
		
			
				
					|  |  |  |         sys.path.append(path) | 
			
		
	
	
		
			
				
					|  |  | @ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |             self.coco_tokens = coco_tokens | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         img_size = 384 | 
			
		
	
		
			
				
					|  |  |  |         self.device = "cuda" if torch.cuda.is_available() else "cpu" | 
			
		
	
		
			
				
					|  |  |  |         if device == None: | 
			
		
	
		
			
				
					|  |  |  |             device = "cuda" if torch.cuda.is_available() else "cpu" | 
			
		
	
		
			
				
					|  |  |  |         self.device = device | 
			
		
	
		
			
				
					|  |  |  |         drop_args = Namespace(enc=0.0, | 
			
		
	
		
			
				
					|  |  |  |                       dec=0.0, | 
			
		
	
		
			
				
					|  |  |  |                       enc_input=0.0, | 
			
		
	
	
		
			
				
					|  |  | @ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |                                     output_word2idx=coco_tokens['word2idx_dict'], | 
			
		
	
		
			
				
					|  |  |  |                                     output_idx2word=coco_tokens['idx2word_list'], | 
			
		
	
		
			
				
					|  |  |  |                                     max_seq_len=max_seq_len, drop_args=model_args.drop_args, | 
			
		
	
		
			
				
					|  |  |  |                                 rank='cpu') | 
			
		
	
		
			
				
					|  |  |  |                                     rank=self.device) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) | 
			
		
	
		
			
				
					|  |  |  |         self.model.load_state_dict(checkpoint['model_state_dict']) | 
			
		
	
	
		
			
				
					|  |  | 
 |