diff --git a/rerank.py b/rerank.py index 1365df6..85cb72f 100644 --- a/rerank.py +++ b/rerank.py @@ -67,6 +67,7 @@ class ReRank(NNOperator): texts[idx].append(text.strip()) tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) + token_count = torch.count_nonzero(tokenized['input_ids']) logits = self.model(**tokenized) scores = self.post_proc(logits) @@ -78,7 +79,7 @@ class ReRank(NNOperator): else: re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] - return re_docs, re_scores + return re_docs, re_scores, token_count def post_proc(self, logits):