Browse Source
        
      
      Adapt to models whose output label count is 2
      
        Signed-off-by: ChengZi <chen.zhang@zilliz.com>
      
      
        main
      
      
     
    
      
        
          
            
            ChengZi
          
          2 years ago
          
         
        
        
       
      
     
    
    
	
		
			
				 1 changed file with 
7 additions and 
2 deletions
			 
			
		 
		
			
				- 
					
					
					 
					rerank.py
				
 
			
		
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -10,12 +10,17 @@ class ReRank(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model_name = model_name | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model = CrossEncoder(self._model_name, device=device) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self._model.config.num_labels == 1: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self._model.default_activation_function = nn.Sigmoid() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._threshold = threshold | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __call__(self, query: str, docs: List): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if len(docs) == 0: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            return [], [] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self._model.config.num_labels > 1: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            scores = self._model.predict([(query, doc) for doc in docs]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if self._threshold is None: | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |