diff --git a/train_clm_with_hf_trainer.py b/train_clm_with_hf_trainer.py index 4895699..094e08a 100644 --- a/train_clm_with_hf_trainer.py +++ b/train_clm_with_hf_trainer.py @@ -9,10 +9,6 @@ from dataclasses import dataclass, field from itertools import chain from typing import Optional -import datasets -from datasets import load_dataset - -import evaluate import transformers from transformers import ( MODEL_FOR_CAUSAL_LM_MAPPING, @@ -24,11 +20,8 @@ from transformers import ( from transformers.testing_utils import CaptureLogger from transformers.trainer_utils import get_last_checkpoint - - logger = logging.getLogger(__name__) - MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -41,7 +34,6 @@ def dataclass_from_dict(klass, d): return d # Not a dataclass field - @dataclass class DataTrainingArguments: """ @@ -122,7 +114,11 @@ def train_clm_with_hf_trainer(model, data_args, training_args, **kwargs): + import evaluate + import datasets from transformers import Trainer + from datasets import load_dataset + print('train clm with hugging face transformers trainer') data_args = dataclass_from_dict(DataTrainingArguments, data_args) @@ -308,7 +304,7 @@ def train_clm_with_hf_trainer(model, total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() @@ -415,4 +411,4 @@ def train_clm_with_hf_trainer(model, trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - print('done clm.') \ No newline at end of file + print('done clm.') diff --git a/train_mlm_with_hf_trainer.py b/train_mlm_with_hf_trainer.py index aeebaf8..6cb5c7f 100644 --- a/train_mlm_with_hf_trainer.py +++ b/train_mlm_with_hf_trainer.py @@ -9,10 +9,6 @@ from dataclasses import dataclass, field from itertools import chain from typing import Optional -import datasets -from datasets import load_dataset - -import evaluate import transformers from transformers import ( MODEL_FOR_MASKED_LM_MAPPING, @@ -23,7 +19,6 @@ from transformers import ( ) from transformers.trainer_utils import get_last_checkpoint - logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -130,7 +125,11 @@ def train_mlm_with_hf_trainer(model, data_args, training_args, **kwargs): + import evaluate + import datasets from transformers import Trainer + from datasets import load_dataset + print('train mlm with hugging face transformers trainer') data_args = dataclass_from_dict(DataTrainingArguments, data_args) @@ -449,4 +448,4 @@ def train_mlm_with_hf_trainer(model, trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - print('done mlm.') \ No newline at end of file + print('done mlm.')