logo
Browse Source

Fix token count

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 7 months ago
parent
commit
3ed58a383a
  1. 2
      auto_transformers.py

2
auto_transformers.py

@ -146,6 +146,7 @@ class AutoTransformers(NNOperator):
txt = data txt = data
try: try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
num_tokens = int(torch.count_nonzero(inputs['input_ids']))
except Exception as e: except Exception as e:
log.error(f'Fail to tokenize inputs: {e}') log.error(f'Fail to tokenize inputs: {e}')
raise e raise e
@ -155,7 +156,6 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
num_tokens = outs.size(1)
if self.pool == 'mean': if self.pool == 'mean':
outs = self.mean_pool(outs, inputs) outs = self.mean_pool(outs, inputs)
elif self.pool == 'cls': elif self.pool == 'cls':

Loading…
Cancel
Save