|
@ -9,10 +9,6 @@ from dataclasses import dataclass, field |
|
|
from itertools import chain |
|
|
from itertools import chain |
|
|
from typing import Optional |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
import datasets |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
import evaluate |
|
|
|
|
|
import transformers |
|
|
import transformers |
|
|
from transformers import ( |
|
|
from transformers import ( |
|
|
MODEL_FOR_CAUSAL_LM_MAPPING, |
|
|
MODEL_FOR_CAUSAL_LM_MAPPING, |
|
@ -24,11 +20,8 @@ from transformers import ( |
|
|
from transformers.testing_utils import CaptureLogger |
|
|
from transformers.testing_utils import CaptureLogger |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) |
|
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) |
|
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) |
|
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) |
|
|
|
|
|
|
|
@ -41,7 +34,6 @@ def dataclass_from_dict(klass, d): |
|
|
return d # Not a dataclass field |
|
|
return d # Not a dataclass field |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
@dataclass |
|
|
class DataTrainingArguments: |
|
|
class DataTrainingArguments: |
|
|
""" |
|
|
""" |
|
@ -122,7 +114,11 @@ def train_clm_with_hf_trainer(model, |
|
|
data_args, |
|
|
data_args, |
|
|
training_args, |
|
|
training_args, |
|
|
**kwargs): |
|
|
**kwargs): |
|
|
|
|
|
import evaluate |
|
|
|
|
|
import datasets |
|
|
from transformers import Trainer |
|
|
from transformers import Trainer |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
print('train clm with hugging face transformers trainer') |
|
|
print('train clm with hugging face transformers trainer') |
|
|
|
|
|
|
|
|
data_args = dataclass_from_dict(DataTrainingArguments, data_args) |
|
|
data_args = dataclass_from_dict(DataTrainingArguments, data_args) |
|
@ -308,7 +304,7 @@ def train_clm_with_hf_trainer(model, |
|
|
total_length = (total_length // block_size) * block_size |
|
|
total_length = (total_length // block_size) * block_size |
|
|
# Split by chunks of max_len. |
|
|
# Split by chunks of max_len. |
|
|
result = { |
|
|
result = { |
|
|
k: [t[i : i + block_size] for i in range(0, total_length, block_size)] |
|
|
k: [t[i: i + block_size] for i in range(0, total_length, block_size)] |
|
|
for k, t in concatenated_examples.items() |
|
|
for k, t in concatenated_examples.items() |
|
|
} |
|
|
} |
|
|
result["labels"] = result["input_ids"].copy() |
|
|
result["labels"] = result["input_ids"].copy() |
|
|