Browse Source
Update acc
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
1 changed files with
5 additions and
4 deletions
-
auto_transformers.py
|
|
@ -35,9 +35,6 @@ import warnings |
|
|
|
import logging |
|
|
|
from transformers import logging as t_logging |
|
|
|
|
|
|
|
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer |
|
|
|
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer |
|
|
|
|
|
|
|
log = logging.getLogger('run_op') |
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
@ -60,7 +57,7 @@ def create_model(model_name, checkpoint_path, device): |
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
@accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, checkpoint_path, device): |
|
|
|
self.device = device |
|
|
@ -277,6 +274,10 @@ class AutoTransformers(NNOperator): |
|
|
|
train_dataset=None, |
|
|
|
eval_dataset=None, |
|
|
|
resume_checkpoint_path=None, **kwargs): |
|
|
|
from .train_mlm_with_hf_trainer import train_mlm_with_hf_trainer |
|
|
|
from .train_clm_with_hf_trainer import train_clm_with_hf_trainer |
|
|
|
|
|
|
|
|
|
|
|
task = kwargs.pop('task', None) |
|
|
|
data_args = kwargs.pop('data_args', None) |
|
|
|
training_args = kwargs.pop('training_args', None) |
|
|
|