```diff
@@ -22,6 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
+#from towhee.dc2 import accelerate
 
 #@accelerate
 class CLIPModelVision(nn.Module):
```
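Only the header of `CLIPModelVision` is visible in this hunk, and a matching `CLIPModelText` wrapper is referenced further down when the text modality is selected. A rough sketch of what these wrappers presumably look like, assuming they simply delegate to `transformers.CLIPModel`'s feature extractors (this code is not part of the diff):

```python
import torch.nn as nn

# Hypothetical wrapper definitions; the real ones live in the operator file
# and may differ in detail.
class CLIPModelVision(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.backbone = model  # a transformers.CLIPModel instance

    def forward(self, pixel_values):
        # Pooled, projected image embeddings.
        return self.backbone.get_image_features(pixel_values)


class CLIPModelText(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.backbone = model

    def forward(self, input_ids, attention_mask):
        # Pooled, projected text embeddings.
        return self.backbone.get_text_features(input_ids, attention_mask)
```

Wrappers of this shape are what let the plain `self.model(...)` calls in the inference methods below return embeddings rather than the full CLIP output object.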
```diff
@@ -49,10 +50,10 @@ class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str, device, checkpoint_path):
+    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         self.model_name = model_name
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         cfg = self._configs()[model_name]
         try:
             clip_model = CLIPModel.from_pretrained(cfg)
```
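`cfg = self._configs()[model_name]` implies a lookup table from operator model names to Hugging Face checkpoint ids, which is then handed to `from_pretrained`. A minimal sketch of such a table; the keys and checkpoint ids below are illustrative, not taken from the diff:

```python
def _configs(self):
    # Hypothetical mapping; the operator defines its own table elsewhere.
    return {
        'clip_vit_base_patch16': 'openai/clip-vit-base-patch16',
        'clip_vit_base_patch32': 'openai/clip-vit-base-patch32',
    }
```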
```diff
@@ -71,6 +72,7 @@ class Clip(NNOperator):
             self.model = CLIPModelText(clip_model)
         else:
             raise ValueError("modality[{}] not implemented.".format(self.modality))
+        self.model.to(self.device)
         self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
         self.processor =  CLIPProcessor.from_pretrained(cfg)
 
```
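With the new signature, the caller decides where the model lives: `device` now defaults to `'cpu'` and `checkpoint_path` to `None`, replacing the old behaviour of silently picking CUDA whenever it was available, and the constructor moves the wrapped model onto that device. A usage sketch, assuming the class is imported directly from this module (in practice it is normally constructed through towhee's operator registry):

```python
# Hypothetical direct construction; the model_name values are illustrative.
text_op = Clip(model_name='clip_vit_base_patch16', modality='text', device='cuda')
image_op = Clip(model_name='clip_vit_base_patch16', modality='image')  # CPU by default
```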
```diff
@@ -99,14 +101,14 @@ class Clip(NNOperator):
 
     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model(tokens['input_ids'],tokens['attention_mask'])
+        text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
         return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
         inputs = self.processor(images=img, return_tensors="pt")
-        image_features = self.model(inputs['pixel_values'])
+        image_features = self.model(inputs['pixel_values'].to(self.device))
         return image_features
 
     def train(self, **kwargs):
```
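Both inference paths now move their input tensors onto `self.device` before the forward pass, so the inputs always sit on the same device as the model weights. The same pattern written against `transformers` directly, as a standalone sketch (the checkpoint id is illustrative):

```python
import torch
from transformers import CLIPTokenizer, CLIPModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch16').to(device)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch16')

tokens = tokenizer(['a photo of a cat'], padding=True, return_tensors='pt')
with torch.no_grad():
    # Mirrors _inference_from_text: every input tensor is moved to the
    # model's device before calling the model.
    text_features = model.get_text_features(
        tokens['input_ids'].to(device),
        tokens['attention_mask'].to(device),
    )
print(text_features.shape)  # torch.Size([1, 512]) for the base checkpoints
```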