
Debug

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Jael Gu committed 3 years ago
Commit: 409c7c56d5
  1. README.md (11 changed lines)
  2. longformer.py (28 changed lines)

README.md (11 changed lines)

@@ -59,6 +59,17 @@ Supported model names:
 - allenai/longformer-base-4096-extra.pos.embd.only
 - allenai/longformer-large-4096-extra.pos.embd.only
+***global_attention_mask***: *torch.Tensor*
+    A task-specific global attention mask passed to the model; defaults to None.
+***pooler_output***: *bool*
+    The flag controlling whether to return outputs with pooled features; defaults to False.
+    By default, the output is of shape (num_tokens, dim) for each input text.
+    If True, the output is a single vector of shape (dim,) for each input text.
 <br />
 ## Interface
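Based on the new README parameters above, a minimal usage sketch; the direct `from longformer import Longformer` import and the commented call pattern are assumptions, since the operator's call signature is not part of this diff.

```python
from longformer import Longformer  # assumption: the operator class defined in longformer.py

# Default: token-level features, one (num_tokens, dim) tensor per input text.
op_tokens = Longformer(model_name='allenai/longformer-base-4096')

# Pooled features: a single (dim,) vector per input text.
op_pooled = Longformer(
    model_name='allenai/longformer-base-4096',
    pooler_output=True,
)

# Assumed call pattern (the __call__ signature is not shown in this diff):
# vec = op_pooled('A very long document ...')
# print(vec.shape)  # expected: torch.Size([768]) for the base model
```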

longformer.py (28 changed lines)

@@ -25,10 +25,21 @@ class Longformer(NNOperator):
     Args:
         model_name (`str`):
             Which model to use for the embeddings.
+        global_attention_mask (`torch.Tensor`):
+            Global attention mask, which depends on the task; defaults to None.
+        pooler_output (`bool`):
+            Whether to return pooled features; defaults to False.
     """
-    def __init__(self, model_name: str = 'allenai/longformer-base-4096') -> None:
+    def __init__(
+        self,
+        model_name: str = 'allenai/longformer-base-4096',
+        global_attention_mask: torch.Tensor = None,
+        pooler_output: bool = False
+    ):
         super().__init__()
         self.model_name = model_name
+        self.global_attention_mask = global_attention_mask
+        self.pooler_output = pooler_output
         try:
             self.model = LongformerModel.from_pretrained(model_name)
         except Exception as e:
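The constructor now stores a task-specific `global_attention_mask` that is forwarded to the model at inference time (see the next hunk). A minimal sketch of the common way such a mask is built, giving the first token global attention as in the Hugging Face Longformer documentation; the fixed `seq_len` and the direct `Longformer` import are assumptions, since tokenization happens outside this hunk.

```python
import torch
from longformer import Longformer  # assumption, as in the README sketch above

# Assumption: inputs are padded/truncated to a fixed length so the mask
# shape matches input_ids at inference time.
seq_len = 512
global_attention_mask = torch.zeros((1, seq_len), dtype=torch.long)
global_attention_mask[:, 0] = 1  # global attention on the first (<s>) token

op = Longformer(
    model_name='allenai/longformer-base-4096',
    global_attention_mask=global_attention_mask,
)
```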
@@ -47,13 +58,22 @@ class Longformer(NNOperator):
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
         try:
-            attention_mask = None
-            outs = self.model(input_ids, attention_mask=attention_mask, labels=input_ids, output_hidden_states=True)
+            attention_mask = torch.ones(
+                input_ids.shape, dtype=torch.long, device=input_ids.device
+            )
+            outs = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                global_attention_mask=self.global_attention_mask
+            )
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
         try:
-            feature_vector = outs[1].squeeze()
+            if self.pooler_output:
+                feature_vector = outs.pooler_output.squeeze(0)
+            else:
+                feature_vector = outs.last_hidden_state.squeeze(0)
         except Exception as e:
             log.error(f'Fail to extract features by model: {self.model_name}')
             raise e
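For reference, a minimal sketch (using `transformers` directly rather than this operator) of the two outputs the `pooler_output` flag switches between: `last_hidden_state` yields per-token features, while `pooler_output` yields one pooled vector per input text.

```python
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')

input_ids = tokenizer('Hello, world!', return_tensors='pt').input_ids
attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
outs = model(input_ids, attention_mask=attention_mask)

token_features = outs.last_hidden_state.squeeze(0)  # shape (num_tokens, 768): pooler_output=False
pooled_feature = outs.pooler_output.squeeze(0)      # shape (768,): pooler_output=True
print(token_features.shape, pooled_feature.shape)
```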
