diff --git a/rerank.py b/rerank.py index 85cb72f..7a44e36 100644 --- a/rerank.py +++ b/rerank.py @@ -67,7 +67,7 @@ class ReRank(NNOperator): texts[idx].append(text.strip()) tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) - token_count = torch.count_nonzero(tokenized['input_ids']) + token_count = int(torch.count_nonzero(tokenized['input_ids'])) logits = self.model(**tokenized) scores = self.post_proc(logits)