Browse Source
Fix token count
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed file with
1 addition and
1 deletion
-
auto_transformers.py
|
@@ -146,6 +146,7 @@ class AutoTransformers(NNOperator): |
|
|
txt = data |
|
|
txt = data |
|
|
try: |
|
|
try: |
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') |
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') |
|
|
|
|
|
num_tokens = int(torch.count_nonzero(inputs['input_ids'])) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Fail to tokenize inputs: {e}') |
|
|
log.error(f'Fail to tokenize inputs: {e}') |
|
|
raise e |
|
|
raise e |
|
@@ -155,7 +156,6 @@ class AutoTransformers(NNOperator): |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
|
|
|
num_tokens = outs.size(1) |
|
|
|
|
|
if self.pool == 'mean': |
|
|
if self.pool == 'mean': |
|
|
outs = self.mean_pool(outs, inputs) |
|
|
outs = self.mean_pool(outs, inputs) |
|
|
elif self.pool == 'cls': |
|
|
elif self.pool == 'cls': |
|
|