logo
Browse Source

Support sentence embedding

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
6279bd4d80
  1. 20
      auto_transformers.py
  2. 7
      test_onnx.py

20
auto_transformers.py

@ -100,9 +100,9 @@ class AutoTransformers(NNOperator):
log.warning('The operator is initialized without specified model.')
pass
def __call__(self, txt: str) -> numpy.ndarray:
def __call__(self, txt: str, return_sentence_emb: bool = False) -> numpy.ndarray:
try:
inputs = self.tokenizer(txt, return_tensors="pt").to(self.device)
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
except Exception as e:
log.error(f'Invalid input for the tokenizer: {self.model_name}')
raise e
@ -111,11 +111,9 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
try:
if return_sentence_emb:
outs = self.post_proc(outs, inputs)
features = outs.squeeze(0)
except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}')
raise e
vec = features.cpu().detach().numpy()
return vec
@ -123,6 +121,14 @@ class AutoTransformers(NNOperator):
def _model(self):
return self.model.model
def post_proc(self, token_embeddings, inputs):
token_embeddings = token_embeddings.to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sentence_embs = torch.sum(
token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sentence_embs
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
output_file = str(Path(__file__).parent)
@ -138,7 +144,7 @@ class AutoTransformers(NNOperator):
raise AttributeError('Unsupported model_type.')
dummy_input = '[CLS]'
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary
if model_type == 'pytorch':
torch.save(self._model, output_file)
elif model_type == 'torchscript':

7
test_onnx.py

@ -9,6 +9,13 @@ import logging
import platform
import psutil
import warnings
from transformers import logging as t_logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]

Loading…
Cancel
Save