Browse Source
        
      
      fix the operator.
      
        Signed-off-by: wxywb <xy.wang@zilliz.com>
      
      
        main
      
      
     
    
      
        
          
             wxywb
          
          3 years ago
            wxywb
          
          3 years ago
          
         
        
        
       
      
     
    
    
	
		
			
				 2 changed files with 
5 additions and 
4 deletions
			 
			
		 
		
			
				- 
					
					
					 
					__init__.py
				
- 
					
					
					 
					jclip.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -14,5 +14,5 @@ | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from .jclip import Jaclip | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def jclip(model_name: str, modality: str): | 
			
		
	
		
			
				
					|  |  |  | def japanese_clip(model_name: str, modality: str): | 
			
		
	
		
			
				
					|  |  |  |     return Jaclip(model_name, modality) | 
			
		
	
	
		
			
				
					|  |  | 
 | 
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -34,6 +34,7 @@ class Jaclip(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |         import japanese_clip as ja_clip | 
			
		
	
		
			
				
					|  |  |  |         sys.path.pop() | 
			
		
	
		
			
				
					|  |  |  |         self.device = "cuda" if torch.cuda.is_available() else "cpu" | 
			
		
	
		
			
				
					|  |  |  |         self._modality = modality | 
			
		
	
		
			
				
					|  |  |  |         model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device) | 
			
		
	
		
			
				
					|  |  |  |         self.model = model | 
			
		
	
		
			
				
					|  |  |  |         self.tfms = preprocess | 
			
		
	
	
		
			
				
					|  |  | @ -51,20 +52,20 @@ class Jaclip(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |         return vec.detach().cpu().numpy().flatten() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _inference_from_text(self, text): | 
			
		
	
		
			
				
					|  |  |  |         encodings = ja_clip.tokenize( | 
			
		
	
		
			
				
					|  |  |  |         encodings = self.ja_clip.tokenize( | 
			
		
	
		
			
				
					|  |  |  |             texts=[text], | 
			
		
	
		
			
				
					|  |  |  |             max_seq_len=77, | 
			
		
	
		
			
				
					|  |  |  |             device=self.device, | 
			
		
	
		
			
				
					|  |  |  |             tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time | 
			
		
	
		
			
				
					|  |  |  |         ) | 
			
		
	
		
			
				
					|  |  |  |         text_feature = model.get_text_features(**encodings) | 
			
		
	
		
			
				
					|  |  |  |         text_feature = self.model.get_text_features(**encodings) | 
			
		
	
		
			
				
					|  |  |  |         return text_feature | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     @arg(1, to_image_color('RGB')) | 
			
		
	
		
			
				
					|  |  |  |     def _inference_from_image(self, img): | 
			
		
	
		
			
				
					|  |  |  |         img = self._preprocess(img) | 
			
		
	
		
			
				
					|  |  |  |         caption = '' | 
			
		
	
		
			
				
					|  |  |  |         image_feature = self.model.get_image_features(image) | 
			
		
	
		
			
				
					|  |  |  |         image_feature = self.model.get_image_features(img) | 
			
		
	
		
			
				
					|  |  |  |         return image_feature | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def _preprocess(self, img): | 
			
		
	
	
		
			
				
					|  |  | 
 |