Browse Source
        
      
      add device and sigmoid
      
        Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 2 changed files with 
6 additions and 
5 deletions
			 
			
		 
		
			
				- 
					
					
					 
					README.md
				
- 
					
					
					 
					rerank.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -36,7 +36,6 @@ DataCollection(p('What is Towhee?', | 
			
		
	
		
			
				
					|  |  |  |               ).show() | 
			
		
	
		
			
				
					|  |  |  | ``` | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | <img src="./result.png" height="100px"/> | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | <br /> | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -56,8 +55,8 @@ Create the operator via the following factory method | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 	***threshold***: float | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     The threshold for filtering with score, defaults to none, i.e., no filtering. | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     The threshold for filtering with score | 
			
		
	
		
			
				
					|  |  |  |    ***device***: str | 
			
		
	
		
			
				
					|  |  |  | <br /> | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | 
 | 
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -1,14 +1,16 @@ | 
			
		
	
		
			
				
					|  |  |  | from typing import List | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from torch import nn | 
			
		
	
		
			
				
					|  |  |  | from sentence_transformers import CrossEncoder | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from towhee.operator import NNOperator | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | class ReRank(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = None): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None): | 
			
		
	
		
			
				
					|  |  |  |         super().__init__() | 
			
		
	
		
			
				
					|  |  |  |         self._model_name = model_name | 
			
		
	
		
			
				
					|  |  |  |         self._model = CrossEncoder(self._model_name, max_length=1000) | 
			
		
	
		
			
				
					|  |  |  |         self._model = CrossEncoder(self._model_name, device=device, default_activation_function=nn.Sigmoid()) | 
			
		
	
		
			
				
					|  |  |  |         self._threshold = threshold | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def __call__(self, query: str, docs: List): | 
			
		
	
	
		
			
				
					|  |  | 
 |