From 3ed58a383ae52cd073e5fa743bd42c031af2f30a Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 30 May 2024 16:14:44 +0800 Subject: [PATCH] Fix token count Signed-off-by: Jael Gu --- auto_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_transformers.py b/auto_transformers.py index dd90f30..abd077a 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -146,6 +146,7 @@ class AutoTransformers(NNOperator): txt = data try: inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') + num_tokens = int(torch.count_nonzero(inputs['input_ids'])) except Exception as e: log.error(f'Fail to tokenize inputs: {e}') raise e @@ -155,7 +156,6 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the model: {self.model_name}') raise e - num_tokens = outs.size(1) if self.pool == 'mean': outs = self.mean_pool(outs, inputs) elif self.pool == 'cls':