Browse Source
Allow returning token usage
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
14 additions and
1 deletion
-
README.md
-
auto_transformers.py
|
|
@ -63,7 +63,13 @@ The path to local checkpoint, defaults to None. |
|
|
|
***tokenizer***: *object* |
|
|
|
|
|
|
|
The method to tokenize input text, defaults to None. |
|
|
|
If None, the operator will use default tokenizer by `model_name` from Huggingface transformers. |
|
|
|
If None, the operator will use default tokenizer by `model_name` from HuggingFace transformers. |
|
|
|
|
|
|
|
|
|
|
|
***return_usage***: *bool* |
|
|
|
|
|
|
|
Whether the `__call__` method returns token usage along with the embedding, defaults to False. |
|
|
|
If True, the `__call__` method will return a dictionary containing `data` (the embedding) and `token_usage`. |
|
|
|
|
|
|
|
<br /> |
|
|
|
|
|
|
|
|
|
@ -107,8 +107,10 @@ class AutoTransformers(NNOperator): |
|
|
|
tokenizer: object = None, |
|
|
|
pool: str = 'mean', |
|
|
|
device: str = None, |
|
|
|
return_usage: bool = False |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.return_usage = return_usage |
|
|
|
if pool not in ['mean', 'cls']: |
|
|
|
log.warning('Invalid pool %s, using mean pooling instead.', pool) |
|
|
|
pool = 'mean' |
|
|
@ -152,6 +154,8 @@ class AutoTransformers(NNOperator): |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
|
raise e |
|
|
|
|
|
|
|
num_tokens = outs.size(1) |
|
|
|
if self.pool == 'mean': |
|
|
|
outs = self.mean_pool(outs, inputs) |
|
|
|
elif self.pool == 'cls': |
|
|
@ -161,6 +165,9 @@ class AutoTransformers(NNOperator): |
|
|
|
features = features.squeeze(0) |
|
|
|
else: |
|
|
|
features = list(features) |
|
|
|
|
|
|
|
if self.return_usage: |
|
|
|
return {'data': features, 'token_usage': num_tokens} |
|
|
|
return features |
|
|
|
|
|
|
|
@property |
|
|
|