From 83e096c6dbdad3f2412ca133ff460ed8169480e2 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 12 Jan 2023 17:32:41 +0800 Subject: [PATCH] Remove sentence post_proc Signed-off-by: Jael Gu --- auto_transformers.py | 45 +++++++++++++++----------------------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index b6c1df6..cc9f6d7 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -40,21 +40,8 @@ t_logging.set_verbosity_error() # @accelerate class Model: - def __init__(self, model_name, device, checkpoint_path): - try: - self.model = AutoModel.from_pretrained(model_name).to(device) - if hasattr(self.model, 'pooler') and self.model.pooler: - self.model.pooler = None - except Exception as e: - log.error(f"Fail to load model by name: {self.model_name}") - raise e - if checkpoint_path: - try: - state_dict = torch.load(checkpoint_path, map_location=device) - self.model.load_state_dict(state_dict) - except Exception as e: - log.error(f"Fail to load state dict from {checkpoint_path}: {e}") - self.model.eval() + def __init__(self, model): + self.model = model def __call__(self, *args, **kwargs): outs = self.model(*args, **kwargs, return_dict=True) @@ -79,20 +66,18 @@ class AutoTransformers(NNOperator): checkpoint_path: str = None, tokenizer: object = None, device: str = None, - return_sentence_emb: bool = True ): super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name - self.return_sentence_emb = return_sentence_emb + self.checkpoint_path = checkpoint_path if self.model_name: model_list = self.supported_model_names() assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" - self.model = Model( - model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) + self.model = Model(self._model) if tokenizer is None: try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -122,8 +107,6 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e - if self.return_sentence_emb: - outs = self.post_proc(outs, inputs) features = outs.cpu().detach().numpy() if isinstance(data, str): features = features.squeeze(0) @@ -133,15 +116,17 @@ class AutoTransformers(NNOperator): @property def _model(self): - return self.model.model - - def post_proc(self, token_embeddings, inputs): - token_embeddings = token_embeddings.to(self.device) - attention_mask = inputs['attention_mask'].to(self.device) - input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() - sentence_embs = torch.sum( - token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) - return sentence_embs + model = AutoModel.from_pretrained(self.model_name).to(self.device) + if hasattr(model, 'pooler') and model.pooler: + model.pooler = None + if self.checkpoint_path: + try: + state_dict = torch.load(self.checkpoint_path, map_location=self.device) + model.load_state_dict(state_dict) + except Exception: + log.error(f'Fail to load weights from {self.checkpoint_path}') + model.eval() + return model def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): if output_file == 'default':