| 
					
					
						
							
						
					
					
				 | 
				@ -10,13 +10,18 @@ class ReRank(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): | 
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        self._model_name = model_name | 
				 | 
				 | 
				        self._model_name = model_name | 
			
		
		
	
		
			
				 | 
				 | 
				        self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self._model = CrossEncoder(self._model_name, device=device) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if self._model.config.num_labels == 1: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            self._model.default_activation_function = nn.Sigmoid() | 
			
		
		
	
		
			
				 | 
				 | 
				        self._threshold = threshold | 
				 | 
				 | 
				        self._threshold = threshold | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def __call__(self, query: str, docs: List): | 
				 | 
				 | 
				    def __call__(self, query: str, docs: List): | 
			
		
		
	
		
			
				 | 
				 | 
				        if len(docs) == 0: | 
				 | 
				 | 
				        if len(docs) == 0: | 
			
		
		
	
		
			
				 | 
				 | 
				            return [], [] | 
				 | 
				 | 
				            return [], [] | 
			
		
		
	
		
			
				 | 
				 | 
				        scores = self._model.predict([(query, doc) for doc in docs]) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if self._model.config.num_labels > 1: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            scores = self._model.predict([(query, doc) for doc in docs], apply_softmax=True)[:, 1] | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        else: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            scores = self._model.predict([(query, doc) for doc in docs]) | 
			
		
		
	
		
			
				 | 
				 | 
				        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) | 
				 | 
				 | 
				        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) | 
			
		
		
	
		
			
				 | 
				 | 
				        if self._threshold is None: | 
				 | 
				 | 
				        if self._threshold is None: | 
			
		
		
	
		
			
				 | 
				 | 
				            re_docs = [docs[i] for i in re_ids] | 
				 | 
				 | 
				            re_docs = [docs[i] for i in re_ids] | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |