diff --git a/auto_transformers.py b/auto_transformers.py index 7d88542..ff193d5 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -404,7 +404,7 @@ class AutoTransformers(NNOperator): if task == 'mlm' or task is None: model_with_head = AutoModelForMaskedLM.from_pretrained(self.model_name) if prepare_model_weights_f is not None: - model_with_head = prepare_model_weights_f(self.model, model_with_head, **kwargs) + model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs) train_mlm_with_hf_trainer( model_with_head, @@ -416,7 +416,7 @@ class AutoTransformers(NNOperator): elif task == 'clm': model_with_head = AutoModelForCausalLM.from_pretrained(self.model_name) if prepare_model_weights_f is not None: - model_with_head = prepare_model_weights_f(self.model, model_with_head, **kwargs) + model_with_head = prepare_model_weights_f(self._model, model_with_head, **kwargs) train_clm_with_hf_trainer( model_with_head,