|
|
@@ -129,10 +129,16 @@ def train_mlm_with_hf_trainer(model,
import datasets
from transformers import Trainer
from datasets import load_dataset
from towhee.trainer.training_config import get_dataclasses_help

print('train mlm with hugging face transformers trainer')

print('**** DataTrainingArguments ****')
get_dataclasses_help(DataTrainingArguments)
data_args = dataclass_from_dict(DataTrainingArguments, data_args)

print('**** TrainingArguments ****')
get_dataclasses_help(TrainingArguments)
training_args = dataclass_from_dict(TrainingArguments, training_args)
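
# The dicts passed in as `data_args` / `training_args` are mapped onto the
# corresponding dataclasses by dataclass_from_dict. A minimal sketch of what a
# caller might pass -- the field names below are assumptions modeled on the
# Hugging Face run_mlm example, not shown in this diff; check the
# get_dataclasses_help() output for the authoritative list:
#
#     data_args = {
#         'dataset_name': 'imdb',            # assumed dataset
#         'max_seq_length': 512,
#         'mlm_probability': 0.15,
#     }
#     training_args = {
#         'output_dir': './mlm_output',      # hypothetical path
#         'num_train_epochs': 3,
#         'per_device_train_batch_size': 8,
#     }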

# Setup logging

@@ -321,20 +327,21 @@ def train_mlm_with_hf_trainer(model,

# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):  # examples: 1000 * (about 50~500) = total_length
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_seq_length:  # max_seq_length = 512
        total_length = (total_length // max_seq_length) * max_seq_length
    # Split by chunks of max_len.
    result = {
        k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    return result  # 573 * 512 = 293376 = total_length, dict(input_ids=[[token1, token2, ...token512], ...], token_type_ids=[[...], ...], attention_mask=[[...], ...], special_tokens_mask=[[...], ...])
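
# --- Illustration only (not part of the diff): the same chunking logic on toy
# data, with the chunk size shrunk from 512 to 4 so the drop-the-remainder
# behaviour is easy to see. ---
from itertools import chain

toy_max_seq_length = 4
toy_examples = {'input_ids': [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}          # 9 tokens in total
toy_concat = {k: list(chain(*v)) for k, v in toy_examples.items()}       # [1, 2, ..., 9]
toy_total = (len(toy_concat['input_ids']) // toy_max_seq_length) * toy_max_seq_length  # 8; the last token is dropped
toy_chunks = {k: [t[i: i + toy_max_seq_length] for i in range(0, toy_total, toy_max_seq_length)]
              for k, t in toy_concat.items()}
print(toy_chunks)  # {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}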

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
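
# --- Sketch only: a datasets .map call along the lines the note above refers
# to. The variable name `tokenized_datasets` and the keyword values are
# assumptions; only the comment itself appears in this hunk. ---
# tokenized_datasets = tokenized_datasets.map(
#     group_texts,
#     batched=True,      # hand group_texts whole batches of examples
#     batch_size=1000,   # the "1,000 texts" mentioned above (datasets' default)
#     num_proc=4,        # optional: parallel preprocessing workers
#     desc=f"Grouping texts in chunks of {max_seq_length}",
# )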

@@ -449,3 +456,19 @@ def train_mlm_with_hf_trainer(model,
trainer.save_metrics("eval", metrics)

print('done mlm.')

sequence = (
    f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief."
)

import torch
inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0')
# Position of the [MASK] token in the input ids.
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

token_logits = model(**inputs).logits  # [1, 28, 30522] = [batch, seq_len, vocab_size]
# Logits over the vocabulary at the masked position.
mask_token_logits = token_logits[0, mask_token_index, :]

# Ids of the 5 highest-scoring vocabulary tokens for the masked position.
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))
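
# --- Optional spot check (not in the diff): the same top-5 predictions via the
# transformers fill-mask pipeline, assuming the fine-tuned `model`/`tokenizer`
# above and a GPU at device 0. ---
from transformers import pipeline

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
for prediction in fill_mask(sequence, top_k=5):
    print(prediction['sequence'], round(prediction['score'], 4))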