| 
					
					
						
							
						
					
					
				 | 
				@ -15,10 +15,13 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Auto | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				@accelerate | 
				 | 
				 | 
				@accelerate | 
			
		
		
	
		
			
				 | 
				 | 
				class Model: | 
				 | 
				 | 
				class Model: | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name, config, device): | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    def __init__(self, model_name, checkpoint_path, config, device): | 
			
		
		
	
		
			
				 | 
				 | 
				        self.device = device | 
				 | 
				 | 
				        self.device = device | 
			
		
		
	
		
			
				 | 
				 | 
				        self.config = config | 
				 | 
				 | 
				        self.config = config | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) | 
				 | 
				 | 
				        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if checkpoint_path: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            state_dict = torch.load(checkpoint_path, map_location=device) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            self.model.load_state_dict(state_dict) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model.to(self.device) | 
				 | 
				 | 
				        self.model.to(self.device) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model.eval() | 
				 | 
				 | 
				        self.model.eval() | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -35,12 +38,12 @@ class Model: | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				class ReRank(NNOperator): | 
				 | 
				 | 
				class ReRank(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512): | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None): | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        self._model_name = model_name | 
				 | 
				 | 
				        self._model_name = model_name | 
			
		
		
	
		
			
				 | 
				 | 
				        self.config = AutoConfig.from_pretrained(model_name) | 
				 | 
				 | 
				        self.config = AutoConfig.from_pretrained(model_name) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.device = device | 
				 | 
				 | 
				        self.device = device | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model = Model(model_name, self.config, device) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.model = Model(model_name, checkpoint_path, self.config, device) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.tokenizer = AutoTokenizer.from_pretrained(model_name) | 
				 | 
				 | 
				        self.tokenizer = AutoTokenizer.from_pretrained(model_name) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.max_length = max_length | 
				 | 
				 | 
				        self.max_length = max_length | 
			
		
		
	
		
			
				 | 
				 | 
				        self._threshold = threshold | 
				 | 
				 | 
				        self._threshold = threshold | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |