diff --git a/s_bert.py b/s_bert.py index 18688f3..6a367af 100644 --- a/s_bert.py +++ b/s_bert.py @@ -77,7 +77,7 @@ class STransformers(NNOperator): Operator using pretrained Sentence Transformers """ - def __init__(self, model_name: str = None, device: str = None): + def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False): self.model_name = model_name if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -87,6 +87,7 @@ class STransformers(NNOperator): else: log.warning('The operator is initialized without specified model.') self._tokenize = self.get_tokenizer() + self.return_usage = return_usage def __call__(self, txt: Union[List[str], str]): if isinstance(txt, str): @@ -94,11 +95,14 @@ class STransformers(NNOperator): else: sentences = txt inputs = self._tokenize(sentences) + num_tokens = int(torch.count_nonzero(inputs['input_ids'])) embs = self.model(**inputs).cpu().detach().numpy() if isinstance(txt, str): embs = embs.squeeze(0) else: embs = list(embs) + if self.return_usage: + return {'data': embs, 'token_usage': num_tokens} return embs @property