diff --git a/auto_transformers.py b/auto_transformers.py index 1a06d98..236a841 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -35,9 +35,6 @@ import warnings import logging from transformers import logging as t_logging -from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer -from .train_clm_with_hf_trainer import train_clm_with_hf_trainer - log = logging.getLogger('run_op') warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' @@ -60,7 +57,7 @@ def create_model(model_name, checkpoint_path, device): return model -# @accelerate +@accelerate class Model: def __init__(self, model_name, checkpoint_path, device): self.device = device @@ -277,6 +274,10 @@ class AutoTransformers(NNOperator): train_dataset=None, eval_dataset=None, resume_checkpoint_path=None, **kwargs): + from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer + from .train_clm_with_hf_trainer import train_clm_with_hf_trainer + + task = kwargs.pop('task', None) data_args = kwargs.pop('data_args', None) training_args = kwargs.pop('training_args', None)