logo
Browse Source

Add optional token usage reporting (`return_usage`) to the STransformers operator

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 8 months ago
parent
commit
232797ba16
  1. 6
      s_bert.py

6
s_bert.py

@ -77,7 +77,7 @@ class STransformers(NNOperator):
Operator using pretrained Sentence Transformers
"""
def __init__(self, model_name: str = None, device: str = None):
def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False):
self.model_name = model_name
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -87,6 +87,7 @@ class STransformers(NNOperator):
else:
log.warning('The operator is initialized without specified model.')
self._tokenize = self.get_tokenizer()
self.return_usage = return_usage
def __call__(self, txt: Union[List[str], str]):
if isinstance(txt, str):
@ -94,11 +95,14 @@ class STransformers(NNOperator):
else:
sentences = txt
inputs = self._tokenize(sentences)
num_tokens = int(torch.count_nonzero(inputs['input_ids']))
embs = self.model(**inputs).cpu().detach().numpy()
if isinstance(txt, str):
embs = embs.squeeze(0)
else:
embs = list(embs)
if self.return_usage:
return {'data': embs, 'token_usage': num_tokens}
return embs
@property

Loading…
Cancel
Save