diff --git a/README.md b/README.md index 9d4caa1..2e6111d 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,17 @@ Supported model names: - allenai/longformer-base-4096-extra.pos.embd.only - allenai/longformer-large-4096-extra.pos.embd.only +***global_attention_mask***: *torch.Tensor* + +defaults to None. + + +***pooler_output***: *bool* + +The flag controllling whether to return outputs with pooled features, defaults to False. +The default output is in shape of (num_tokens, dim) for each input text. +If True, then the output will be a vector in (dim,) for each input text. +
## Interface diff --git a/longformer.py b/longformer.py index f0f420a..4aa96b9 100644 --- a/longformer.py +++ b/longformer.py @@ -25,10 +25,21 @@ class Longformer(NNOperator): Args: model_name (`str`): Which model to use for the embeddings. + global_attention_mask (`torch.Tensor`): + Global attention mask depending on tasks + pooler_output (`bool`): + Whether to pool features """ - def __init__(self, model_name: str = 'allenai/longformer-base-4096') -> None: + def __init__( + self, + model_name: str = 'allenai/longformer-base-4096', + global_attention_mask: torch.Tensor = None, + pooler_output: bool = False + ): super().__init__() self.model_name = model_name + self.global_attention_mask = global_attention_mask + self.pooler_output = pooler_output try: self.model = LongformerModel.from_pretrained(model_name) except Exception as e: @@ -47,13 +58,22 @@ class Longformer(NNOperator): log.error(f'Invalid input for the tokenizer: {self.model_name}') raise e try: - attention_mask = None - outs = self.model(input_ids, attention_mask=attention_mask, labels=input_ids, output_hidden_states=True) + attention_mask = torch.ones( + input_ids.shape, dtype=torch.long, device=input_ids.device + ) + outs = self.model( + input_ids, + attention_mask=attention_mask, + global_attention_mask=self.global_attention_mask + ) except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e try: - feature_vector = outs[1].squeeze() + if self.pooler_output: + feature_vector = outs.pooler_output.squeeze(0) + else: + feature_vector = outs.last_hidden_state.squeeze(0) except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e