diff --git a/README.md b/README.md index 54b2ac4..84a0406 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,13 @@ The path to local checkpoint, defaults to None. ***tokenizer***: *object* The method to tokenize input text, defaults to None. -If None, the operator will use default tokenizer by `model_name` from Huggingface transformers. +If None, the operator will use default tokenizer by `model_name` from HuggingFace transformers. + + +***return_usage***: *bool* + +The flag to return token usage from the `__call__` method, defaults to False. +If True, the `__call__` method will return a dictionary containing `data` (the embedding) and `token_usage` (the number of tokens).
diff --git a/auto_transformers.py b/auto_transformers.py index b49d841..dd90f30 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -107,8 +107,10 @@ class AutoTransformers(NNOperator): tokenizer: object = None, pool: str = 'mean', device: str = None, + return_usage: bool = False ): super().__init__() + self.return_usage = return_usage if pool not in ['mean', 'cls']: log.warning('Invalid pool %s, using mean pooling instead.', pool) pool = 'mean' @@ -152,6 +154,8 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e + + num_tokens = outs.size(1) if self.pool == 'mean': outs = self.mean_pool(outs, inputs) elif self.pool == 'cls': @@ -161,6 +165,9 @@ class AutoTransformers(NNOperator): features = features.squeeze(0) else: features = list(features) + + if self.return_usage: + return {'data': features, 'token_usage': num_tokens} return features @property