diff --git a/auto_transformers.py b/auto_transformers.py index dd90f30..abd077a 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -146,6 +146,7 @@ class AutoTransformers(NNOperator): txt = data try: inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') + num_tokens = int(torch.count_nonzero(inputs['input_ids'])) except Exception as e: log.error(f'Fail to tokenize inputs: {e}') raise e @@ -155,7 +156,6 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the model: {self.model_name}') raise e - num_tokens = outs.size(1) if self.pool == 'mean': outs = self.mean_pool(outs, inputs) elif self.pool == 'cls':