From 5feecc6f6d40c68ba391f88a25a2b76fc364eb2c Mon Sep 17 00:00:00 2001 From: ChengZi Date: Thu, 20 Jul 2023 17:07:16 +0800 Subject: [PATCH] refactor to support trt accelerate Signed-off-by: ChengZi --- rerank.py | 67 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/rerank.py b/rerank.py index 0b02fdd..37a2374 100644 --- a/rerank.py +++ b/rerank.py @@ -21,16 +21,17 @@ class Model: self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) self.model.to(self.device) self.model.eval() - if self.config.num_labels == 1: - self.activation_fct = torch.sigmoid - else: - self.activation_fct = partial(torch.softmax, dim=1) - def __call__(self, **features): - with torch.no_grad(): - logits = self.model(**features, return_dict=True).logits - scores = self.activation_fct(logits) - return scores + + def __call__(self, *args, **kwargs): + new_args = [] + for x in args: + new_args.append(x.to(self.device)) + new_kwargs = {} + for k, v in kwargs.items(): + new_kwargs[k] = v.to(self.device) + outs = self.model(*new_args, **new_kwargs, return_dict=True) + return outs.logits class ReRank(NNOperator): @@ -43,6 +44,10 @@ class ReRank(NNOperator): self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_length = max_length self._threshold = threshold + if self.config.num_labels == 1: + self.activation_fct = torch.sigmoid + else: + self.activation_fct = partial(torch.softmax, dim=1) def __call__(self, query: str, docs: List): if len(docs) == 0: @@ -60,12 +65,8 @@ class ReRank(NNOperator): for name in tokenized: tokenized[name] = tokenized[name].to(self.device) - scores = self.model(**tokenized).detach().cpu().numpy() - if self.config.num_labels == 1: - scores = [score[0] for score in scores] - else: - scores = scores[:, 1] - scores = [score for score in scores] + logits = self.model(**tokenized) + scores = self.post_proc(logits) re_ids = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True) if self._threshold is None: @@ -77,6 +78,16 @@ class ReRank(NNOperator): return re_docs, re_scores + def post_proc(self, logits): + scores = self.activation_fct(logits).detach().cpu().numpy() + if self.config.num_labels == 1: + scores = [score[0] for score in scores] + else: + scores = scores[:, 1] + scores = [score for score in scores] + return scores + + @property def _model(self): return self.model.model @@ -114,3 +125,29 @@ class ReRank(NNOperator): output=Path(path) ) return Path(path).resolve() + + + +if __name__ == '__main__': + model_name_list = [ + 'cross-encoder/ms-marco-TinyBERT-L-2-v2', + 'cross-encoder/ms-marco-MiniLM-L-2-v2', + 'cross-encoder/ms-marco-MiniLM-L-4-v2', + 'cross-encoder/ms-marco-MiniLM-L-6-v2', + 'cross-encoder/ms-marco-MiniLM-L-12-v2', + 'cross-encoder/ms-marco-TinyBERT-L-2', + 'cross-encoder/ms-marco-TinyBERT-L-4', + 'cross-encoder/ms-marco-TinyBERT-L-6', + 'cross-encoder/ms-marco-electra-base', + 'nboost/pt-tinybert-msmarco', + 'nboost/pt-bert-base-uncased-msmarco', + 'nboost/pt-bert-large-msmarco', + 'Capreolus/electra-base-msmarco', + 'amberoad/bert-multilingual-passage-reranking-msmarco', + ] + for model_name in model_name_list: + print('\n' + model_name) + op = ReRank(model_name, threshold=0) + res = op('abc', ['123', 'ABC', 'ABCabc']) + print(res) + op.save_model('onnx') \ No newline at end of file