logo
Browse Source

Allow returning token usage

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
4c36b8a18b
  1. 8
      README.md
  2. 7
      auto_transformers.py

8
README.md

@ -63,7 +63,13 @@ The path to local checkpoint, defaults to None.
***tokenizer***: *object* ***tokenizer***: *object*
The method to tokenize input text, defaults to None. The method to tokenize input text, defaults to None.
If None, the operator will use default tokenizer by `model_name` from Huggingface transformers.
If None, the operator will use default tokenizer by `model_name` from HuggingFace transformers.
***return_usage***: *bool*
The flag to return token usage from the `__call__` method, defaults to False.
If True, the `__call__` method returns a dictionary with the keys `data` (the embedding) and `token_usage`.
<br /> <br />

7
auto_transformers.py

@ -107,8 +107,10 @@ class AutoTransformers(NNOperator):
tokenizer: object = None, tokenizer: object = None,
pool: str = 'mean', pool: str = 'mean',
device: str = None, device: str = None,
return_usage: bool = False
): ):
super().__init__() super().__init__()
self.return_usage = return_usage
if pool not in ['mean', 'cls']: if pool not in ['mean', 'cls']:
log.warning('Invalid pool %s, using mean pooling instead.', pool) log.warning('Invalid pool %s, using mean pooling instead.', pool)
pool = 'mean' pool = 'mean'
@ -152,6 +154,8 @@ class AutoTransformers(NNOperator):
except Exception as e: except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
num_tokens = outs.size(1)
if self.pool == 'mean': if self.pool == 'mean':
outs = self.mean_pool(outs, inputs) outs = self.mean_pool(outs, inputs)
elif self.pool == 'cls': elif self.pool == 'cls':
@ -161,6 +165,9 @@ class AutoTransformers(NNOperator):
features = features.squeeze(0) features = features.squeeze(0)
else: else:
features = list(features) features = list(features)
if self.return_usage:
return {'data': features, 'token_usage': num_tokens}
return features return features
@property @property

Loading…
Cancel
Save