diff --git a/train_clm_with_hf_trainer.py b/train_clm_with_hf_trainer.py index b3f6ef5..4895699 100644 --- a/train_clm_with_hf_trainer.py +++ b/train_clm_with_hf_trainer.py @@ -16,7 +16,6 @@ import evaluate import transformers from transformers import ( MODEL_FOR_CAUSAL_LM_MAPPING, - Trainer, TrainingArguments, default_data_collator, is_torch_tpu_available, @@ -123,6 +122,7 @@ def train_clm_with_hf_trainer(model, data_args, training_args, **kwargs): + from transformers import Trainer print('train clm with hugging face transformers trainer') data_args = dataclass_from_dict(DataTrainingArguments, data_args) diff --git a/train_mlm_with_hf_trainer.py b/train_mlm_with_hf_trainer.py index 97e6c24..aeebaf8 100644 --- a/train_mlm_with_hf_trainer.py +++ b/train_mlm_with_hf_trainer.py @@ -17,7 +17,6 @@ import transformers from transformers import ( MODEL_FOR_MASKED_LM_MAPPING, DataCollatorForLanguageModeling, - Trainer, TrainingArguments, is_torch_tpu_available, set_seed, @@ -131,6 +130,7 @@ def train_mlm_with_hf_trainer(model, data_args, training_args, **kwargs): + from transformers import Trainer print('train mlm with hugging face transformers trainer') data_args = dataclass_from_dict(DataTrainingArguments, data_args)