From 1902ba932be400dfcec8361a7a022b51b126c2aa Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 30 May 2024 16:10:29 +0800 Subject: [PATCH] Output token usage Signed-off-by: Jael Gu --- rerank.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rerank.py b/rerank.py index 1365df6..85cb72f 100644 --- a/rerank.py +++ b/rerank.py @@ -67,6 +67,7 @@ class ReRank(NNOperator): texts[idx].append(text.strip()) tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) + token_count = torch.count_nonzero(tokenized['input_ids']) logits = self.model(**tokenized) scores = self.post_proc(logits) @@ -78,7 +79,7 @@ class ReRank(NNOperator): else: re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] - return re_docs, re_scores + return re_docs, re_scores, token_count def post_proc(self, logits):