Browse Source
        
      
      make accelerate evaliable.
      
        Signed-off-by: wxywb <xy.wang@zilliz.com>
      
      
        main
      
      
     
    
      
        
          
             wxywb
          
          2 years ago
            wxywb
          
          2 years ago
          
         
        
        
       
      
     
    
    
	
		
			
				 1 changed files with 
6 additions and 
2 deletions
			 
			
		 
		
			
				- 
					
					
					 
					clip.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -26,7 +26,11 @@ from towhee.types.arg import arg, to_image_color | 
			
		
	
		
			
				
					|  |  |  | from towhee import register | 
			
		
	
		
			
				
					|  |  |  | from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor | 
			
		
	
		
			
				
					|  |  |  | from transformers import logging as t_logging | 
			
		
	
		
			
				
					|  |  |  | # from towhee.dc2 import accelerate | 
			
		
	
		
			
				
					|  |  |  | try: | 
			
		
	
		
			
				
					|  |  |  |     from towhee import accelerate | 
			
		
	
		
			
				
					|  |  |  | except: | 
			
		
	
		
			
				
					|  |  |  |     def accelerate(func): | 
			
		
	
		
			
				
					|  |  |  |         return func | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | log = logging.getLogger('run_op') | 
			
		
	
		
			
				
					|  |  |  | warnings.filterwarnings('ignore') | 
			
		
	
	
		
			
				
					|  |  | @ -70,7 +74,7 @@ class CLIPModelText(nn.Module): | 
			
		
	
		
			
				
					|  |  |  |         text_embeds = self.backbone.get_text_features(input_ids, attention_mask) | 
			
		
	
		
			
				
					|  |  |  |         return text_embeds | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  | # @accelerate | 
			
		
	
		
			
				
					|  |  |  | @accelerate | 
			
		
	
		
			
				
					|  |  |  | class Model: | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, model_name, modality, checkpoint_path, device): | 
			
		
	
		
			
				
					|  |  |  |         self.model = create_model(model_name, modality, checkpoint_path, device) | 
			
		
	
	
		
			
				
					|  |  | 
 |