| 
					
					
					
				 | 
				@ -1,17 +1,18 @@ | 
			
		
		
	
		
			
				 | 
				 | 
				from typing import List | 
				 | 
				 | 
				from typing import List | 
			
		
		
	
		
			
				 | 
				 | 
				from xml.dom.expatbuilder import theDOMImplementation | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				from sentence_transformers import CrossEncoder | 
				 | 
				 | 
				from sentence_transformers import CrossEncoder | 
			
		
		
	
		
			
				 | 
				 | 
				from towhee.operator import NNOperator | 
				 | 
				 | 
				from towhee.operator import NNOperator | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				class ReRank(NNOperator): | 
				 | 
				 | 
				class ReRank(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'): | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2'): | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        self._model_name = model_name | 
				 | 
				 | 
				        self._model_name = model_name | 
			
		
		
	
		
			
				 | 
				 | 
				        self._model = CrossEncoder(self._model_name, max_length=1000) | 
				 | 
				 | 
				        self._model = CrossEncoder(self._model_name, max_length=1000) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def __call__(self, query: str, docs: List, threshold: float = None): | 
				 | 
				 | 
				    def __call__(self, query: str, docs: List, threshold: float = None): | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if len(docs) == 0: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            return [], [] | 
			
		
		
	
		
			
				 | 
				 | 
				        scores = self._model.predict([(query, doc) for doc in docs]) | 
				 | 
				 | 
				        scores = self._model.predict([(query, doc) for doc in docs]) | 
			
		
		
	
		
			
				 | 
				 | 
				        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) | 
				 | 
				 | 
				        re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) | 
			
		
		
	
		
			
				 | 
				 | 
				        if threshold is None: | 
				 | 
				 | 
				        if threshold is None: | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |