@@ -455,20 +455,4 @@ def train_mlm_with_hf_trainer(model,

    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    print('done mlm.')

    # Quick qualitative check: fill the masked position with the fine-tuned model.
    sequence = (
        f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief."
    )

    import torch

    # Tokenize and move the inputs to the GPU the model lives on.
    inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0')
    # Position of the mask token within the input sequence.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    token_logits = model(**inputs).logits  # shape [batch, seq_len, vocab_size], here [1, 28, 30522]
    # Keep only the vocabulary logits at the masked position.
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Indices of the 5 highest-scoring vocabulary tokens for the masked position.
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

    # Print the sentence with the mask replaced by each of the top-5 predictions.
    for token in top_5_tokens:
        print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))

    print('done mlm.')
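
    # Optional cross-check (a minimal sketch, not part of the original diff): the same
    # top-5 predictions can be read off the fill-mask pipeline, assuming `model` and
    # `tokenizer` above are the objects we want to probe and that cuda:0 is available.
    from transformers import pipeline

    mask_filler = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
    for pred in mask_filler(sequence, top_k=5):
        # Each prediction carries the filled token string and its probability score.
        print(f"{pred['token_str']}: {pred['score']:.4f}")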