From 34f1a1e4b87270b157f541044e21bf827bf71dde Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 7 Sep 2023 13:36:22 +0800
Subject: [PATCH] Add cls pooling

Signed-off-by: Jael Gu
---
 auto_transformers.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index 580e222..b49d841 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -97,15 +97,22 @@ class AutoTransformers(NNOperator):
             The local checkpoint path.
         tokenizer (`object`):
             The tokenizer to tokenize input text as model inputs.
+        pool (`str`):
+            The type of post-process pooling after token embeddings, defaults to "mean". Options: "mean", "cls"
     """

     def __init__(self,
                  model_name: str = None,
                  checkpoint_path: str = None,
                  tokenizer: object = None,
+                 pool: str = 'mean',
                  device: str = None,
                  ):
         super().__init__()
+        if pool not in ['mean', 'cls']:
+            log.warning('Invalid pool %s, using mean pooling instead.', pool)
+            pool = 'mean'
+        self.pool = pool
         if device:
             self.device = device
         else:
@@ -145,7 +152,10 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
-        outs = self.post_proc(outs, inputs)
+        if self.pool == 'mean':
+            outs = self.mean_pool(outs, inputs)
+        elif self.pool == 'cls':
+            outs = self.cls_pool(outs)
         features = outs.detach().numpy()
         if isinstance(data, str):
             features = features.squeeze(0)
@@ -184,7 +194,7 @@ class AutoTransformers(NNOperator):
             }
         return onnx_config

-    def post_proc(self, token_embeddings, inputs):
+    def mean_pool(self, token_embeddings, inputs):
         token_embeddings = token_embeddings
         attention_mask = inputs['attention_mask']
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -192,6 +202,15 @@ class AutoTransformers(NNOperator):
             token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
         return sentence_embs

+    def cls_pool(self, token_embeddings):
+        _shape = token_embeddings.shape
+        if len(_shape) == 3:
+            return token_embeddings[:, 0, :]
+        elif len(_shape) == 2:
+            return token_embeddings[0]
+        else:
+            raise RuntimeError(f'Invalid shape of token embeddings: {_shape}')
+
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         if output_file == 'default':
             output_file = str(Path(__file__).parent)
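
Reviewer note: a minimal, self-contained sketch of the two pooling strategies this patch switches between, using random tensors in place of real model outputs. The function names mirror the patched methods, but nothing below is part of the operator itself:

import torch

def mean_pool(token_embeddings, attention_mask):
    # Average token embeddings, masking out padded positions.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def cls_pool(token_embeddings):
    # Use the first ([CLS]) token's embedding as the sentence embedding.
    if token_embeddings.dim() == 3:    # (batch, seq_len, hidden)
        return token_embeddings[:, 0, :]
    if token_embeddings.dim() == 2:    # (seq_len, hidden), unbatched
        return token_embeddings[0]
    raise RuntimeError(f'Invalid shape of token embeddings: {token_embeddings.shape}')

# Dummy "last hidden state": batch of 2 sequences, 4 tokens, hidden size 8.
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # trailing padding

print(mean_pool(token_embeddings, attention_mask).shape)  # torch.Size([2, 8])
print(cls_pool(token_embeddings).shape)                   # torch.Size([2, 8])

Both strategies return one vector per input sequence: mean pooling averages all non-padding token embeddings, while cls pooling keeps only the first token, which BERT-style models are trained to use as a sequence summary.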