logo
Browse Source

Lazily import evaluate and datasets to avoid potential errors.

main
ChengZi 2 years ago
parent
commit
44fac86041
  1. 14
      train_clm_with_hf_trainer.py
  2. 9
      train_mlm_with_hf_trainer.py

14
train_clm_with_hf_trainer.py

@ -9,10 +9,6 @@ from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from typing import Optional from typing import Optional
import datasets
from datasets import load_dataset
import evaluate
import transformers import transformers
from transformers import ( from transformers import (
MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING,
@ -24,11 +20,8 @@ from transformers import (
from transformers.testing_utils import CaptureLogger from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@ -41,7 +34,6 @@ def dataclass_from_dict(klass, d):
return d # Not a dataclass field return d # Not a dataclass field
@dataclass @dataclass
class DataTrainingArguments: class DataTrainingArguments:
""" """
@ -122,7 +114,11 @@ def train_clm_with_hf_trainer(model,
data_args, data_args,
training_args, training_args,
**kwargs): **kwargs):
import evaluate
import datasets
from transformers import Trainer from transformers import Trainer
from datasets import load_dataset
print('train clm with hugging face transformers trainer') print('train clm with hugging face transformers trainer')
data_args = dataclass_from_dict(DataTrainingArguments, data_args) data_args = dataclass_from_dict(DataTrainingArguments, data_args)
@ -308,7 +304,7 @@ def train_clm_with_hf_trainer(model,
total_length = (total_length // block_size) * block_size total_length = (total_length // block_size) * block_size
# Split by chunks of max_len. # Split by chunks of max_len.
result = { result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)] k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items() for k, t in concatenated_examples.items()
} }
result["labels"] = result["input_ids"].copy() result["labels"] = result["input_ids"].copy()

9
train_mlm_with_hf_trainer.py

@ -9,10 +9,6 @@ from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from typing import Optional from typing import Optional
import datasets
from datasets import load_dataset
import evaluate
import transformers import transformers
from transformers import ( from transformers import (
MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING,
@ -23,7 +19,6 @@ from transformers import (
) )
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@ -130,7 +125,11 @@ def train_mlm_with_hf_trainer(model,
data_args, data_args,
training_args, training_args,
**kwargs): **kwargs):
import evaluate
import datasets
from transformers import Trainer from transformers import Trainer
from datasets import load_dataset
print('train mlm with hugging face transformers trainer') print('train mlm with hugging face transformers trainer')
data_args = dataclass_from_dict(DataTrainingArguments, data_args) data_args = dataclass_from_dict(DataTrainingArguments, data_args)

Loading…
Cancel
Save