logo
Browse Source

Optimize triton

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
fdda753863
  1. 23
      auto_transformers.py

23
auto_transformers.py

@ -59,10 +59,17 @@ def create_model(model_name, checkpoint_path, device):
# @accelerate
class Model:
    """Callable wrapper around the network returned by ``create_model``.

    The wrapper remembers the target device so that callers may pass inputs
    that live elsewhere; they are moved to that device right before the
    forward pass.
    """

    def __init__(self, model_name, checkpoint_path, device):
        """Build the underlying model and record the device it runs on.

        Args:
            model_name: Forwarded unchanged to ``create_model``.
            checkpoint_path: Forwarded unchanged to ``create_model``.
            device: Forwarded to ``create_model`` and reused in ``__call__``
                as the destination for every input.
        """
        self.device = device
        self.model = create_model(model_name, checkpoint_path, device)

    def _to_device(self, value):
        """Return ``value`` moved to ``self.device`` if it exposes ``.to``.

        Values without a ``.to`` attribute (for example plain bool flags)
        are passed through untouched instead of raising ``AttributeError``.
        """
        return value.to(self.device) if hasattr(value, 'to') else value

    def __call__(self, *args, **kwargs):
        """Run a single forward pass and return ``last_hidden_state``.

        Args:
            *args: Positional model inputs; each is moved to ``self.device``.
            **kwargs: Keyword model inputs; each is moved to ``self.device``.

        Returns:
            The ``'last_hidden_state'`` entry of the model output.
        """
        # Inputs must be on the model's device *before* the model is invoked;
        # the model is therefore called exactly once, on the moved inputs.
        new_args = [self._to_device(x) for x in args]
        new_kwargs = {k: self._to_device(v) for k, v in kwargs.items()}
        outs = self.model(*new_args, **new_kwargs, return_dict=True)
        return outs['last_hidden_state']
@ -84,7 +91,6 @@ class AutoTransformers(NNOperator):
checkpoint_path: str = None,
tokenizer: object = None,
device: str = None,
norm: bool = False
):
super().__init__()
if device:
@ -95,7 +101,6 @@ class AutoTransformers(NNOperator):
self.model_name = 'sentence-transformers/' + model_name
else:
self.model_name = model_name
self.norm = norm
self.checkpoint_path = checkpoint_path
if self.model_name:
@ -118,7 +123,7 @@ class AutoTransformers(NNOperator):
else:
txt = data
try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device)
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
except Exception as e:
log.error(f'Fail to tokenize inputs: {e}')
raise e
@ -128,8 +133,6 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the model: {self.model_name}')
raise e
outs = self.post_proc(outs, inputs)
if self.norm:
outs = torch.nn.functional.normalize(outs, )
features = outs.cpu().detach().numpy()
if isinstance(data, str):
features = features.squeeze(0)
@ -169,8 +172,8 @@ class AutoTransformers(NNOperator):
return onnx_config
def post_proc(self, token_embeddings, inputs):
token_embeddings = token_embeddings.to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)
token_embeddings = token_embeddings
attention_mask = inputs['attention_mask']
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sentence_embs = torch.sum(
token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
@ -191,7 +194,7 @@ class AutoTransformers(NNOperator):
raise AttributeError('Unsupported model_type.')
dummy_input = 'test sentence'
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')
if model_type == 'pytorch':
torch.save(self._model, output_file)
elif model_type == 'torchscript':
@ -214,7 +217,7 @@ class AutoTransformers(NNOperator):
if hasattr(self._model.config, 'use_cache'):
self._model.config.use_cache = False
torch.onnx.export(
self._model,
self._model.to('cpu'),
tuple(inputs.values()),
output_file,
input_names=list(self.onnx_config['inputs'].keys()),

Loading…
Cancel
Save