|
@ -67,6 +67,7 @@ class ReRank(NNOperator): |
|
|
texts[idx].append(text.strip()) |
|
|
texts[idx].append(text.strip()) |
|
|
|
|
|
|
|
|
tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) |
|
|
tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length) |
|
|
|
|
|
token_count = torch.count_nonzero(tokenized['input_ids']) |
|
|
|
|
|
|
|
|
logits = self.model(**tokenized) |
|
|
logits = self.model(**tokenized) |
|
|
scores = self.post_proc(logits) |
|
|
scores = self.post_proc(logits) |
|
@ -78,7 +79,7 @@ class ReRank(NNOperator): |
|
|
else: |
|
|
else: |
|
|
re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] |
|
|
re_docs = [docs[i] for i in re_ids if scores[i] >= self._threshold] |
|
|
re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] |
|
|
re_scores = [scores[i] for i in re_ids if scores[i] >= self._threshold] |
|
|
return re_docs, re_scores |
|
|
|
|
|
|
|
|
return re_docs, re_scores, token_count |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def post_proc(self, logits): |
|
|
def post_proc(self, logits): |
|
|