
add print args

main
ChengZi committed 1 year ago
commit 45c55a398c
  1. train_clm_with_hf_trainer.py (8 lines changed)
  2. train_mlm_with_hf_trainer.py (29 lines changed)

train_clm_with_hf_trainer.py  (8 lines changed)

@@ -118,10 +118,16 @@ def train_clm_with_hf_trainer(model,
import datasets
from transformers import Trainer
from datasets import load_dataset
from towhee.trainer.training_config import get_dataclasses_help
print('train clm with hugging face transformers trainer')
print('**** DataTrainingArguments ****')
get_dataclasses_help(DataTrainingArguments)
data_args = dataclass_from_dict(DataTrainingArguments, data_args)
print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
# Setup logging
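
The two added helpers come from towhee.trainer: get_dataclasses_help(cls) prints the fields a dataclass accepts, and dataclass_from_dict(cls, d) turns a plain dict into that dataclass. A minimal sketch of what the conversion step is assumed to do (not the towhee implementation, which may handle more cases):

import dataclasses

def dataclass_from_dict_sketch(klass, args_dict):
    # keep only keys that are declared fields, then build the dataclass,
    # so extra keys in the incoming dict do not raise a TypeError
    field_names = {f.name for f in dataclasses.fields(klass)}
    kwargs = {k: v for k, v in (args_dict or {}).items() if k in field_names}
    return klass(**kwargs)

# hypothetical usage with the dicts passed into train_clm_with_hf_trainer:
# training_args = dataclass_from_dict_sketch(TrainingArguments, {'output_dir': './clm_out', 'num_train_epochs': 1})
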
@@ -308,7 +314,7 @@ def train_clm_with_hf_trainer(model,
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
return result # 2318 chunks of 1024 tokens: dict(input_ids=[[token1, token2, ..., token1024], ...], attention_mask=[[...], ...], labels=[[...], ...])
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
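
The labels = result["input_ids"].copy() line is what makes these chunks a causal-LM dataset: the model shifts the labels internally, so no manual shift is needed here. A toy run of the same grouping logic, with a hypothetical block size of 4 standing in for the 1024 used in the script:

from itertools import chain

block_size = 4  # stands in for the script's 1024
examples = {'input_ids': [[1, 2, 3], [4, 5, 6, 7], [8, 9]],
            'attention_mask': [[1, 1, 1], [1, 1, 1, 1], [1, 1]]}

concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = (len(concatenated['input_ids']) // block_size) * block_size
result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
          for k, t in concatenated.items()}
result['labels'] = result['input_ids'].copy()
print(result['input_ids'])  # [[1, 2, 3, 4], [5, 6, 7, 8]] -- token 9 is the dropped remainder
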

train_mlm_with_hf_trainer.py  (29 lines changed)

@@ -129,10 +129,16 @@ def train_mlm_with_hf_trainer(model,
import datasets
from transformers import Trainer
from datasets import load_dataset
from towhee.trainer.training_config import get_dataclasses_help
print('train mlm with hugging face transformers trainer')
print('**** DataTrainingArguments ****')
get_dataclasses_help(DataTrainingArguments)
data_args = dataclass_from_dict(DataTrainingArguments, data_args)
print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
# Setup logging
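
For reference, a hypothetical call site for this function: the two dicts below are exactly what the added lines print help for and then convert via dataclass_from_dict. Field names and keyword arguments other than model, data_args and training_args are assumptions modeled on the Hugging Face run_mlm-style dataclasses.

data_args = {'dataset_name': 'wikitext', 'dataset_config_name': 'wikitext-2-raw-v1'}
training_args = {'output_dir': './mlm_out', 'num_train_epochs': 1, 'do_train': True}
# train_mlm_with_hf_trainer(model, data_args=data_args, training_args=training_args)
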
@@ -321,20 +327,21 @@ def train_mlm_with_hf_trainer(model,
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
def group_texts(examples): # examples: 1000 texts of roughly 50-500 tokens each; concatenated length = total_length
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= max_seq_length:
if total_length >= max_seq_length: # max_seq_length = 512
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result
return result # 573 chunks of 512 tokens (573 * 512 = 293376 = total_length): dict(input_ids=[[token1, token2, ..., token512], ...], token_type_ids=[[...], ...], attention_mask=[[...], ...], special_tokens_mask=[[...], ...])
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
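
Unlike the causal-LM version, these chunks carry no labels field; in the standard Hugging Face masked-LM recipe the labels are produced at batch time by a masking data collator. A hedged sketch of that step, assuming this script follows the usual DataCollatorForLanguageModeling setup:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # hypothetical checkpoint
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

features = [tokenizer('a short example sentence')]  # one tokenized feature
batch = collator(features)                          # ~15% of tokens get masked here
print(batch['input_ids'][0])  # input ids with some positions replaced by [MASK]
print(batch['labels'][0])     # -100 everywhere except the masked positions
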
@@ -449,3 +456,19 @@ def train_mlm_with_hf_trainer(model,
trainer.save_metrics("eval", metrics)
print('done mlm.')
sequence = (
f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief."
)
import torch
inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0')
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
token_logits = model(**inputs).logits # [1, 28, 30522]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for token in top_5_tokens:
print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
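
The added block is a quick sanity check after training: it locates the [MASK] position, takes the top-5 logits at that position, and prints the sentence with each candidate token filled in. The same check can be written with the fill-mask pipeline; a hedged equivalent sketch (not what the commit uses), reusing the model and tokenizer already in scope:

from transformers import pipeline

unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
for pred in unmasker(sequence, top_k=5):
    print(pred['sequence'], pred['score'])  # filled-in sentence and its probability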