logo
Browse Source

remove redundant script.

main
ChengZi 1 year ago
parent
commit
009ddda7fc
  1. 18
      train_mlm_with_hf_trainer.py

18
train_mlm_with_hf_trainer.py

@ -455,20 +455,4 @@ def train_mlm_with_hf_trainer(model,
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
print('done mlm.')
sequence = (
f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief."
)
import torch
inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0')
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
token_logits = model(**inputs).logits # [1, 28, 30522]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
print('done mlm.')
Loading…
Cancel
Save