You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            156 lines
          
        
        
          
            5.3 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            156 lines
          
        
        
          
            5.3 KiB
          
        
        
      | import os | |
| from pathlib import Path | |
| from typing import List | |
| from functools import partial | |
| 
 | |
| import torch | |
try:
    from towhee import accelerate
except ImportError:
    # Only a missing towhee module / missing ``accelerate`` name should trigger
    # the fallback. A bare ``except`` would also swallow KeyboardInterrupt and
    # SystemExit raised during the import.
    def accelerate(func):
        """No-op stand-in for towhee's ``accelerate``: return ``func`` unchanged."""
        return func
| from towhee.operator import NNOperator | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
| 
 | |
| 
 | |
@accelerate
class Model:
    """Inference wrapper around a Hugging Face sequence-classification model.

    The checkpoint is loaded once, moved to ``device`` and switched to eval
    mode. Calling the wrapper moves every argument to the same device, runs a
    forward pass and returns the raw logits.
    """

    def __init__(self, model_name, config, device):
        """Load ``model_name`` with ``config`` and place it on ``device``.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
            config: Config object forwarded to ``from_pretrained``.
            device: Target device, passed straight to ``Module.to()``.
        """
        self.device = device
        self.config = config
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, *args, **kwargs):
        """Run a forward pass and return ``outputs.logits``.

        Every positional and keyword argument must provide ``.to()`` (i.e. be a
        tensor); each one is moved to ``self.device`` before the call.
        """
        device_args = [arg.to(self.device) for arg in args]
        device_kwargs = {key: value.to(self.device) for key, value in kwargs.items()}
        # The model is only used in eval mode for inference, so skip building an
        # autograd graph; this saves memory and time and does not change the logits.
        with torch.no_grad():
            outs = self.model(*device_args, **device_kwargs, return_dict=True)
        return outs.logits
| 
 | |
| 
 | |
class ReRank(NNOperator):
    """Re-rank documents against a query with a cross-encoder model.

    Each ``(query, doc)`` pair is scored by a Hugging Face sequence-classification
    model. Documents come back sorted by descending score and can optionally be
    filtered by a minimum score.

    Args:
        model_name: Hugging Face hub id or local path of the cross-encoder.
        threshold: Minimum score a document needs in order to be returned;
            ``None`` disables filtering.
        device: Device the model and the tokenized inputs are moved to.
        max_length: Maximum tokenized length of a ``(query, doc)`` pair; longer
            pairs are truncated with the ``'longest_first'`` strategy.
    """

    def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length: int = 512):
        super().__init__()
        self._model_name = model_name
        self.config = AutoConfig.from_pretrained(model_name)
        self.device = device
        self.model = Model(model_name, self.config, device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_length = max_length
        self._threshold = threshold
        # A single-logit head gets an independent sigmoid score; a multi-logit
        # head is normalized with a softmax across the label dimension.
        if self.config.num_labels == 1:
            self.activation_fct = torch.sigmoid
        else:
            self.activation_fct = partial(torch.softmax, dim=1)

    def __call__(self, query: str, docs: List):
        """Score ``docs`` against ``query`` and return them best-first.

        Args:
            query: The search query string.
            docs: Candidate document strings.

        Returns:
            A ``(docs, scores)`` tuple of lists, sorted by descending score and,
            when a threshold is set, restricted to ``score >= threshold``.
            Both lists are empty when ``docs`` is empty.
        """
        if len(docs) == 0:
            return [], []

        # The tokenizer pairs the i-th first-segment text with the i-th
        # second-segment text, so the query is repeated once per document.
        queries = [query.strip()] * len(docs)
        passages = [doc.strip() for doc in docs]
        tokenized = self.tokenizer(queries, passages, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self.device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            logits = self.model(**tokenized)
        scores = self.post_proc(logits)

        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if self._threshold is not None:
            order = [i for i in order if scores[i] >= self._threshold]
        re_docs = [docs[i] for i in order]
        re_scores = [scores[i] for i in order]
        return re_docs, re_scores

    def post_proc(self, logits):
        """Convert a batch of logits into one Python ``float`` score per row.

        With a single-logit head the score is the sigmoid of that logit.
        Otherwise it is the softmax probability of label index 1, presumably
        the "relevant" label for binary re-rankers (NOTE(review): confirm for
        checkpoints with more than two labels).
        """
        probs = self.activation_fct(logits).detach().cpu().numpy()
        column = 0 if self.config.num_labels == 1 else 1
        return [float(score) for score in probs[:, column]]

    @property
    def _model(self):
        # The underlying transformers model, unwrapped from ``Model``.
        return self.model.model

    @property
    def supported_formats(self):
        return ['onnx']

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """Serialize the underlying model and return the resolved output path.

        Args:
            format: ``'pytorch'`` (whole module via ``torch.save``) or ``'onnx'``.
            path: Output file. ``'default'`` means
                ``<this file's dir>/saved/<format>/<model_name with '/' -> '-'>``
                plus ``.pt`` or ``.onnx``; that directory is created if needed.

        Raises:
            AttributeError: If ``format`` is not supported (for any ``path``).
        """
        extensions = {'pytorch': '.pt', 'onnx': '.onnx'}
        # Validate before touching the filesystem, and for every ``path``, so an
        # unknown format can never return a path to a file that was not written.
        if format not in extensions:
            raise AttributeError(f'Invalid format {format}.')

        if path == 'default':
            save_dir = Path(__file__).parent / 'saved' / format
            save_dir.mkdir(parents=True, exist_ok=True)
            file_name = self._model_name.replace('/', '-') + extensions[format]
            path = str(save_dir / file_name)

        if format == 'pytorch':
            torch.save(self._model, path)
        else:  # 'onnx'
            from transformers.onnx import export
            from transformers.onnx.features import FeaturesManager
            # Raises if this model type cannot be exported with the requested feature.
            # NOTE(review): 'default' describes base-model outputs; for a
            # sequence-classification head, 'sequence-classification' may be the
            # more accurate feature (it changes the ONNX output name) - confirm
            # before switching.
            _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
            export(
                self.tokenizer,
                self._model,
                config=onnx_config,
                opset=13,
                output=Path(path),
            )
        return Path(path).resolve()
| 
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Smoke test: for each listed checkpoint, load it, re-rank a toy query
    # against three documents, print the ranking, then export it to ONNX.
    # threshold=0 keeps effectively every document, because scores are
    # sigmoid/softmax probabilities.
    checkpoints = (
        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-2-v2',
        'cross-encoder/ms-marco-MiniLM-L-4-v2',
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        'cross-encoder/ms-marco-MiniLM-L-12-v2',
        'cross-encoder/ms-marco-TinyBERT-L-2',
        'cross-encoder/ms-marco-TinyBERT-L-4',
        'cross-encoder/ms-marco-TinyBERT-L-6',
        'cross-encoder/ms-marco-electra-base',
        'nboost/pt-tinybert-msmarco',
        'nboost/pt-bert-base-uncased-msmarco',
        'nboost/pt-bert-large-msmarco',
        'Capreolus/electra-base-msmarco',
        'amberoad/bert-multilingual-passage-reranking-msmarco',
    )
    for checkpoint in checkpoints:
        print('\n' + checkpoint)
        reranker = ReRank(checkpoint, threshold=0)
        print(reranker('abc', ['123', 'ABC', 'ABCabc']))
        reranker.save_model('onnx')
 | 
