@@ -25,10 +25,21 @@ class Longformer(NNOperator):
     Args:
         model_name (`str`):
             Which model to use for the embeddings.
+        global_attention_mask (`torch.Tensor`):
+            Global attention mask, depending on the task.
+        pooler_output (`bool`):
+            Whether to pool features.
     """
-    def __init__(self, model_name: str = 'allenai/longformer-base-4096') -> None:
+    def __init__(
+            self,
+            model_name: str = 'allenai/longformer-base-4096',
+            global_attention_mask: torch.Tensor = None,
+            pooler_output: bool = False
+    ):
         super().__init__()
         self.model_name = model_name
+        self.global_attention_mask = global_attention_mask
+        self.pooler_output = pooler_output
         try:
             self.model = LongformerModel.from_pretrained(model_name)
         except Exception as e:
@@ -47,13 +58,22 @@ class Longformer(NNOperator):
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
         try:
-            attention_mask = None
-            outs = self.model(input_ids, attention_mask=attention_mask, labels=input_ids, output_hidden_states=True)
+            attention_mask = torch.ones(
+                input_ids.shape, dtype=torch.long, device=input_ids.device
+            )
+            outs = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                global_attention_mask=self.global_attention_mask
+            )
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
         try:
-            feature_vector = outs[1].squeeze()
+            if self.pooler_output:
+                feature_vector = outs.pooler_output.squeeze(0)
+            else:
+                feature_vector = outs.last_hidden_state.squeeze(0)
         except Exception as e:
             log.error(f'Fail to extract features by model: {self.model_name}')
             raise e
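For reviewers, a minimal usage sketch of the changed constructor (not part of the diff). It assumes the operator is imported directly from its module, that its __call__ accepts a raw text string, and that global attention follows the usual Longformer convention of marking the first token; the import path, sequence length of 16, and example text are made-up illustration values.

import torch

from longformer import Longformer  # hypothetical import path, adjust to the actual module

# Global attention on the first token only; the mask shape must match the
# tokenized input shape (assumed (1, 16) here purely for illustration).
global_attention_mask = torch.zeros((1, 16), dtype=torch.long)
global_attention_mask[:, 0] = 1

op = Longformer(
    model_name='allenai/longformer-base-4096',
    global_attention_mask=global_attention_mask,
    pooler_output=True,  # return the pooled output instead of per-token hidden states
)
feature_vector = op('An example sentence to embed.')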