From 009ddda7fc78c979bd3481337c1c55c825eb18f2 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Fri, 16 Dec 2022 16:25:34 +0800 Subject: [PATCH] remove redundant script. --- train_mlm_with_hf_trainer.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/train_mlm_with_hf_trainer.py b/train_mlm_with_hf_trainer.py index 45b262f..fc651f4 100644 --- a/train_mlm_with_hf_trainer.py +++ b/train_mlm_with_hf_trainer.py @@ -455,20 +455,4 @@ def train_mlm_with_hf_trainer(model, trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - print('done mlm.') - - sequence = ( - f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief." - ) - - import torch - inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0') - mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1] - - token_logits = model(**inputs).logits # [1, 28, 30522] - mask_token_logits = token_logits[0, mask_token_index, :] - - top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() - - for token in top_5_tokens: - print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token]))) \ No newline at end of file + print('done mlm.') \ No newline at end of file