diff --git a/train_clm_with_hf_trainer.py b/train_clm_with_hf_trainer.py
index 094e08a..181237e 100644
--- a/train_clm_with_hf_trainer.py
+++ b/train_clm_with_hf_trainer.py
@@ -118,10 +118,16 @@ def train_clm_with_hf_trainer(model,
     import datasets
     from transformers import Trainer
     from datasets import load_dataset
+    from towhee.trainer.training_config import get_dataclasses_help

     print('train clm with hugging face transformers trainer')

+    print('**** DataTrainingArguments ****')
+    get_dataclasses_help(DataTrainingArguments)
     data_args = dataclass_from_dict(DataTrainingArguments, data_args)
+
+    print('**** TrainingArguments ****')
+    get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)

     # Setup logging
@@ -308,7 +314,7 @@ def train_clm_with_hf_trainer(model,
             for k, t in concatenated_examples.items()
         }
         result["labels"] = result["input_ids"].copy()
-        return result
+        return result  # 2318 * 1024, dict(input_ids=[[token1, token2, ...token1024], ...], attention_mask=[[...], ...], labels=[[...], ...])

     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
     # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
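For reference, a minimal standalone sketch of the chunking that the annotated group_texts performs in the CLM script (the toy batch and block_size=8 are illustrative assumptions; the real run uses 1024-token blocks, matching the 2318 * 1024 annotation above):

    from itertools import chain

    def group_texts(examples, block_size=8):
        # Concatenate every field (input_ids, attention_mask, ...) into one long list.
        concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the tail so every chunk is exactly block_size long; no padding needed.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        # Causal LM: labels start as a copy of input_ids; the model shifts them internally.
        result["labels"] = result["input_ids"].copy()
        return result

    batch = {
        "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12]],
        "attention_mask": [[1] * 5, [1] * 7],
    }
    print(group_texts(batch))
    # 12 concatenated tokens -> one chunk of 8; the 4-token remainder is dropped:
    # {'input_ids': [[1, 2, 3, 4, 5, 6, 7, 8]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1]],
    #  'labels': [[1, 2, 3, 4, 5, 6, 7, 8]]}

The same drop-the-remainder logic appears in the MLM patch below, with max_seq_length = 512 instead of 1024.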
diff --git a/train_mlm_with_hf_trainer.py b/train_mlm_with_hf_trainer.py
index 6cb5c7f..45b262f 100644
--- a/train_mlm_with_hf_trainer.py
+++ b/train_mlm_with_hf_trainer.py
@@ -129,10 +129,16 @@ def train_mlm_with_hf_trainer(model,
     import datasets
     from transformers import Trainer
     from datasets import load_dataset
+    from towhee.trainer.training_config import get_dataclasses_help

     print('train mlm with hugging face transformers trainer')

+    print('**** DataTrainingArguments ****')
+    get_dataclasses_help(DataTrainingArguments)
     data_args = dataclass_from_dict(DataTrainingArguments, data_args)
+
+    print('**** TrainingArguments ****')
+    get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)

     # Setup logging
@@ -321,20 +327,21 @@ def train_mlm_with_hf_trainer(model,

     # Main data processing function that will concatenate all texts from our dataset and generate chunks of
     # max_seq_length.
-    def group_texts(examples):
+    def group_texts(examples):  # examples: 1000 texts * (about 50~500 tokens each) = total_length
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
         # customize this part to your needs.
-        if total_length >= max_seq_length:
+        if total_length >= max_seq_length:  # max_seq_length = 512
            total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
             k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
             for k, t in concatenated_examples.items()
         }
-        return result
+        return result  # 573 * 512 = 293376 = total_length, dict(input_ids=[[token1, token2, ...token512], ...], token_type_ids=[[...], ...], attention_mask=[[...], ...], special_tokens_mask=[[...], ...])
+

     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
     # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
@@ -449,3 +456,19 @@ def train_mlm_with_hf_trainer(model,
         trainer.save_metrics("eval", metrics)

     print('done mlm.')
+
+    sequence = (
+        f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief."
+    )
+
+    import torch
+    inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0')
+    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
+
+    token_logits = model(**inputs).logits  # [1, 28, 30522]
+    mask_token_logits = token_logits[0, mask_token_index, :]
+
+    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
+
+    for token in top_5_tokens:
+        print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
\ No newline at end of file
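The tail of the MLM patch adds a quick fill-mask smoke test against the freshly fine-tuned model. Below is a self-contained sketch of the same top-5 decoding loop that runs on CPU with a stock checkpoint ('distilbert-base-uncased' and the shorter test sentence are illustrative assumptions; the patch reuses the just-trained model and tokenizer on cuda:0):

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

    sequence = f"This movie is really {tokenizer.mask_token}."
    inputs = tokenizer(sequence, return_tensors="pt")
    # Position of the [MASK] token within the tokenized sequence.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        token_logits = model(**inputs).logits  # [batch, seq_len, vocab_size]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Five highest-scoring vocabulary ids for the masked position.
    top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
    for token in top_5_tokens:
        print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))

Apart from torch.no_grad() and staying on CPU, the indexing and topk-then-decode steps mirror the patched code above.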