From fdda753863837095622e33487f59c6aa2991bae1 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Tue, 7 Feb 2023 18:41:48 +0800
Subject: [PATCH] Optimize triton

Signed-off-by: Jael Gu
---
 auto_transformers.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index 1f25fd8..911319c 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -59,10 +59,17 @@ def create_model(model_name, checkpoint_path, device):
 # @accelerate
 class Model:
     def __init__(self, model_name, checkpoint_path, device):
+        self.device = device
         self.model = create_model(model_name, checkpoint_path, device)

     def __call__(self, *args, **kwargs):
-        outs = self.model(*args, **kwargs, return_dict=True)
+        new_args = []
+        for x in args:
+            new_args.append(x.to(self.device))
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(*new_args, **new_kwargs, return_dict=True)
         return outs['last_hidden_state']


@@ -84,7 +91,6 @@ class AutoTransformers(NNOperator):
                  checkpoint_path: str = None,
                  tokenizer: object = None,
                  device: str = None,
-                 norm: bool = False
                  ):
         super().__init__()
         if device:
@@ -95,7 +101,6 @@ class AutoTransformers(NNOperator):
             self.model_name = 'sentence-transformers/' + model_name
         else:
             self.model_name = model_name
-        self.norm = norm
         self.checkpoint_path = checkpoint_path

         if self.model_name:
@@ -118,7 +123,7 @@ class AutoTransformers(NNOperator):
         else:
             txt = data
         try:
-            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device)
+            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
         except Exception as e:
             log.error(f'Fail to tokenize inputs: {e}')
             raise e
@@ -128,8 +133,6 @@ class AutoTransformers(NNOperator):
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
         outs = self.post_proc(outs, inputs)
-        if self.norm:
-            outs = torch.nn.functional.normalize(outs, )
         features = outs.cpu().detach().numpy()
         if isinstance(data, str):
             features = features.squeeze(0)
@@ -169,8 +172,8 @@ class AutoTransformers(NNOperator):
         return onnx_config

     def post_proc(self, token_embeddings, inputs):
-        token_embeddings = token_embeddings.to(self.device)
-        attention_mask = inputs['attention_mask'].to(self.device)
+        token_embeddings = token_embeddings
+        attention_mask = inputs['attention_mask']
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         sentence_embs = torch.sum(
             token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
@@ -191,7 +194,7 @@ class AutoTransformers(NNOperator):
             raise AttributeError('Unsupported model_type.')

         dummy_input = 'test sentence'
-        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
+        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')
         if model_type == 'pytorch':
             torch.save(self._model, output_file)
         elif model_type == 'torchscript':
@@ -214,7 +217,7 @@ class AutoTransformers(NNOperator):
             if hasattr(self._model.config, 'use_cache'):
                 self._model.config.use_cache = False
             torch.onnx.export(
-                self._model,
+                self._model.to('cpu'),
                 tuple(inputs.values()),
                 output_file,
                 input_names=list(self.onnx_config['inputs'].keys()),
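
Below is a minimal standalone sketch of the device-handling pattern this patch
introduces in Model.__call__: tokenizer outputs now stay on CPU, and the wrapper
instead remembers its device and moves every incoming tensor argument onto it
just before the forward pass, so callers (e.g. a Triton backend) can hand over
CPU tensors; for the same reason the ONNX path exports from a CPU copy of the
model. The names DeviceWrapper and encoder are stand-in assumptions, not part
of the patched file, and the sketch assumes all arguments are torch tensors.

    import torch

    class DeviceWrapper:
        # Mirrors the patched Model wrapper: pin a device, move inputs to it.
        def __init__(self, encoder, device):
            self.device = device
            self.model = encoder.to(device)  # assumption: create_model does this internally

        def __call__(self, *args, **kwargs):
            # Move every positional and keyword tensor onto the model's device.
            new_args = [x.to(self.device) for x in args]
            new_kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
            outs = self.model(*new_args, **new_kwargs, return_dict=True)
            return outs['last_hidden_state']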