|
@ -77,7 +77,7 @@ class STransformers(NNOperator): |
|
|
    Operator using pretrained Sentence Transformers
    """
|
|
|
|
|
|
|
|
def __init__(self, model_name: str = None, device: str = None): |
|
|
|
|
|
|
|
|
def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False): |
|
|
self.model_name = model_name |
|
|
self.model_name = model_name |
|
|
if device is None: |
|
|
if device is None: |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
@ -87,6 +87,7 @@ class STransformers(NNOperator): |
|
|
else: |
|
|
else: |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
log.warning('The operator is initialized without specified model.') |
|
|
self._tokenize = self.get_tokenizer() |
|
|
self._tokenize = self.get_tokenizer() |
|
|
|
|
|
self.return_usage = return_usage |
|
|
|
|
|
|
|
|
def __call__(self, txt: Union[List[str], str]): |
|
|
def __call__(self, txt: Union[List[str], str]): |
|
|
if isinstance(txt, str): |
|
|
if isinstance(txt, str): |
|
@ -94,11 +95,14 @@ class STransformers(NNOperator): |
|
|
else: |
|
|
else: |
|
|
sentences = txt |
|
|
sentences = txt |
|
|
inputs = self._tokenize(sentences) |
|
|
inputs = self._tokenize(sentences) |
|
|
|
|
|
num_tokens = int(torch.count_nonzero(inputs['input_ids'])) |
|
|
embs = self.model(**inputs).cpu().detach().numpy() |
|
|
embs = self.model(**inputs).cpu().detach().numpy() |
|
|
if isinstance(txt, str): |
|
|
if isinstance(txt, str): |
|
|
embs = embs.squeeze(0) |
|
|
embs = embs.squeeze(0) |
|
|
else: |
|
|
else: |
|
|
embs = list(embs) |
|
|
embs = list(embs) |
|
|
|
|
|
if self.return_usage: |
|
|
|
|
|
return {'data': embs, 'token_usage': num_tokens} |
|
|
return embs |
|
|
return embs |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|