logo
Browse Source

Update acc

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
a6a0672269
  1. 9
      auto_transformers.py

9
auto_transformers.py

@ -35,9 +35,6 @@ import warnings
import logging import logging
from transformers import logging as t_logging from transformers import logging as t_logging
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
log = logging.getLogger('run_op') log = logging.getLogger('run_op')
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@ -60,7 +57,7 @@ def create_model(model_name, checkpoint_path, device):
return model return model
# @accelerate
@accelerate
class Model: class Model:
def __init__(self, model_name, checkpoint_path, device): def __init__(self, model_name, checkpoint_path, device):
self.device = device self.device = device
@ -277,6 +274,10 @@ class AutoTransformers(NNOperator):
train_dataset=None, train_dataset=None,
eval_dataset=None, eval_dataset=None,
resume_checkpoint_path=None, **kwargs): resume_checkpoint_path=None, **kwargs):
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer
task = kwargs.pop('task', None) task = kwargs.pop('task', None)
data_args = kwargs.pop('data_args', None) data_args = kwargs.pop('data_args', None)
training_args = kwargs.pop('training_args', None) training_args = kwargs.pop('training_args', None)

Loading…
Cancel
Save