import os from pathlib import Path from typing import List from functools import partial import torch try: from towhee import accelerate except: def accelerate(func): return func from towhee.operator import NNOperator from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig @accelerate class Model: def __init__(self, model_name, checkpoint_path, config, device): self.device = device self.config = config if checkpoint_path: self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, config=self.config) else: self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) self.model.to(self.device) self.model.eval() def __call__(self, *args, **kwargs): new_args = [] for x in args: new_args.append(x.to(self.device)) new_kwargs = {} for k, v in kwargs.items(): new_kwargs[k] = v.to(self.device) outs = self.model(*new_args, **new_kwargs, return_dict=True) return outs.logits class ReRank(NNOperator): def __init__(self, model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2', threshold: float = 0.6, device: str = None, max_length=512, checkpoint_path=None): super().__init__() self._model_name = model_name self.config = AutoConfig.from_pretrained(model_name) self.device = device self.model = Model(model_name, checkpoint_path, self.config, device) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_length = max_length self._threshold = threshold if self.config.num_labels == 1: self.activation_fct = torch.sigmoid else: self.activation_fct = partial(torch.softmax, dim=1) def __call__(self, query: str, docs: List): if len(docs) == 0: return [], [] batch = [(query, doc) for doc in docs] texts = [[] for _ in range(len(batch[0]))] for example in batch: for idx, text in enumerate(example): texts[idx].append(text.strip()) tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) logits = self.model(**tokenized) scores = self.post_proc(logits) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) if self._threshold is None: re_docs = [docs[i] for i in re_ids] re_scores = [scores[i] for i in re_ids] else: re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] return re_docs, re_scores def post_proc(self, logits): scores = self.activation_fct(logits).detach().cpu().numpy() if self.config.num_labels == 1: scores = [float(score[0]) for score in scores] else: scores = scores[:, 1] scores = [float(score) for score in scores] return scores @property def _model(self): return self.model.model @property def supported_formats(self): return ['onnx'] def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) path = os.path.join(path, 'saved', format) os.makedirs(path, exist_ok=True) name = self._model_name.replace('/', '-') path = os.path.join(path, name) if format in ['pytorch',]: path = path + '.pt' elif format == 'onnx': path = path + '.onnx' else: raise AttributeError(f'Invalid format {format}.') if format == 'pytorch': torch.save(self._model, path) elif format == 'onnx': from transformers.onnx.features import FeaturesManager from transformers.onnx import export model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( self._model, feature='default') onnx_config = model_onnx_config(self._model.config) onnx_inputs, onnx_outputs = export( self.tokenizer, self._model, config=onnx_config, opset=13, output=Path(path) ) return Path(path).resolve() if __name__ == '__main__': model_name_list = [ 'cross-encoder/ms-marco-TinyBERT-L-2-v2', 'cross-encoder/ms-marco-MiniLM-L-2-v2', 'cross-encoder/ms-marco-MiniLM-L-4-v2', 'cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-MiniLM-L-12-v2', 'cross-encoder/ms-marco-TinyBERT-L-2', 'cross-encoder/ms-marco-TinyBERT-L-4', 'cross-encoder/ms-marco-TinyBERT-L-6', 'cross-encoder/ms-marco-electra-base', 'nboost/pt-tinybert-msmarco', 'nboost/pt-bert-base-uncased-msmarco', 'nboost/pt-bert-large-msmarco', 'Capreolus/electra-base-msmarco', 'amberoad/bert-multilingual-passage-reranking-msmarco', ] for model_name in model_name_list: print('\n' + model_name) op = ReRank(model_name, threshold=0) res = op('abc', ['123', 'ABC', 'ABCabc']) print(res) op.save_model('onnx')