|
@ -9,10 +9,6 @@ from dataclasses import dataclass, field |
|
|
from itertools import chain |
|
|
from itertools import chain |
|
|
from typing import Optional |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
import datasets |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
import evaluate |
|
|
|
|
|
import transformers |
|
|
import transformers |
|
|
from transformers import ( |
|
|
from transformers import ( |
|
|
MODEL_FOR_CAUSAL_LM_MAPPING, |
|
|
MODEL_FOR_CAUSAL_LM_MAPPING, |
|
@ -24,11 +20,8 @@ from transformers import ( |
|
|
from transformers.testing_utils import CaptureLogger |
|
|
from transformers.testing_utils import CaptureLogger |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) |
|
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) |
|
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) |
|
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) |
|
|
|
|
|
|
|
@ -41,7 +34,6 @@ def dataclass_from_dict(klass, d): |
|
|
return d # Not a dataclass field |
|
|
return d # Not a dataclass field |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
@dataclass |
|
|
class DataTrainingArguments: |
|
|
class DataTrainingArguments: |
|
|
""" |
|
|
""" |
|
@ -122,7 +114,11 @@ def train_clm_with_hf_trainer(model, |
|
|
data_args, |
|
|
data_args, |
|
|
training_args, |
|
|
training_args, |
|
|
**kwargs): |
|
|
**kwargs): |
|
|
|
|
|
import evaluate |
|
|
|
|
|
import datasets |
|
|
from transformers import Trainer |
|
|
from transformers import Trainer |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
print('train clm with hugging face transformers trainer') |
|
|
print('train clm with hugging face transformers trainer') |
|
|
|
|
|
|
|
|
data_args = dataclass_from_dict(DataTrainingArguments, data_args) |
|
|
data_args = dataclass_from_dict(DataTrainingArguments, data_args) |
|
|