
Remove the unused ipdb import and add a last-checkpoint check.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
Branch: main
jinlingxu06 authored 2 years ago, committed by wxywb
Commit 5c7e5c574b
1 changed file: train_clip_with_hf_trainer.py (24 changed lines)

@@ -3,7 +3,6 @@ import os
 import sys
 import transformers
 import dataclasses
-import ipdb
 from dataclasses import dataclass, field
 from typing import Optional, List

@@ -21,6 +20,7 @@ from transformers import (
     is_torch_tpu_available,
     set_seed,
 )
+from transformers.trainer_utils import get_last_checkpoint

 # We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module,
 # so we jit it to be faster.

@@ -180,16 +180,15 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)

+    # 2. Setup logging
+    # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    # 2. Setup logging
-    # Setup logging

     #+ training_args
-    #log_level = training_args.get_process_log_level()
     log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
     transformers.utils.logging.set_verbosity(log_level)

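For readers unfamiliar with the log_level line kept in this hunk: TrainingArguments.get_process_log_level() resolves the effective logging level for the current process. A minimal sketch of what it returns with default arguments (the output_dir value below is illustrative only):

import logging
from transformers import TrainingArguments

# Illustrative arguments; only output_dir is required here.
args = TrainingArguments(output_dir="/tmp/clip_out")

# With the default log_level="passive", the main process resolves to logging.INFO,
# while replica processes fall back to log_level_replica (warning by default).
print(args.get_process_log_level() == logging.INFO)  # True on the main process
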
@@ -204,9 +203,19 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     logger.info(f"Training/evaluation parameters {training_args}")

     # Detecting last checkpoint and eventualy continue from last checkpoint
-    ### place holder ###
-    ### place holder ###
-    ### place holder ###
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )

     # Load dataset
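The newly imported get_last_checkpoint does the heavy lifting in the block above: it scans output_dir for checkpoint-<step> subfolders and returns the newest one, or None if nothing has been saved yet. A small, self-contained sketch of that behavior (the directory names are made up for illustration):

import os
import tempfile
from transformers.trainer_utils import get_last_checkpoint

# Fake output_dir containing two Trainer-style checkpoint folders.
output_dir = tempfile.mkdtemp()
os.makedirs(os.path.join(output_dir, "checkpoint-500"))
os.makedirs(os.path.join(output_dir, "checkpoint-1000"))

# Returns the highest-numbered checkpoint-* directory, or None if there is none.
print(get_last_checkpoint(output_dir))  # .../checkpoint-1000
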
@@ -384,7 +393,6 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     )

     # Training
-    last_checkpoint = None
     if training_args.do_train:
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
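The last hunk is cut off by the diff view. In Hugging Face example scripts this block usually continues as below; treat it as a sketch of that standard pattern, assuming trainer, training_args, and last_checkpoint are defined as in the surrounding file, not as the lines this commit leaves unchanged:

if training_args.do_train:
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        # An explicit --resume_from_checkpoint always wins.
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        # Otherwise resume from the checkpoint detected by get_last_checkpoint.
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()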
