
and last checkpoint check.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
Branch: main
Author: jinlingxu06, 2 years ago (committed by wxywb)
Commit: 5c7e5c574b
1 changed file: train_clip_with_hf_trainer.py (24 changed lines)
@@ -3,7 +3,6 @@ import os
import sys
import transformers
import dataclasses
import ipdb
from dataclasses import dataclass, field
from typing import Optional, List
@@ -21,6 +20,7 @@ from transformers import (
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
# We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module,
# so we jit it to be faster.
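The comment above refers to the torchvision preprocessing used elsewhere in this script. For reference, the usual pattern (as in Hugging Face's run_clip.py example) composes nn.Module transforms inside an nn.Sequential and scripts the whole module with TorchScript. The sketch below follows that pattern under stated assumptions: the image size and normalization constants are illustrative, not values taken from this file.

# Minimal sketch of jit-scripted torchvision preprocessing, assuming CLIP-like
# image size and normalization constants (not quoted from this repository).
import torch
from torch import nn
from torchvision import transforms as T


class ImageTransform(nn.Module):
    """Resize/crop/normalize pipeline built from nn.Module transforms so it can be scripted."""

    def __init__(self, image_size: int = 224):
        super().__init__()
        self.transforms = nn.Sequential(
            T.Resize([image_size], interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(image_size),
            T.ConvertImageDtype(torch.float),
            T.Normalize((0.481, 0.458, 0.408), (0.269, 0.261, 0.276)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: a uint8 image tensor of shape (C, H, W)
        return self.transforms(x)


# Scripting removes per-image Python overhead, which is what "jit it to be faster" means.
image_transform = torch.jit.script(ImageTransform(image_size=224))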
@@ -180,16 +180,15 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    get_dataclasses_help(TrainingArguments)
    training_args = dataclass_from_dict(TrainingArguments, training_args)
    # 2. Setup logging
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # 2. Setup logging
    # Setup logging
    #+ training_args
    #log_level = training_args.get_process_log_level()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
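For reference, TrainingArguments.get_process_log_level() resolves the per-process verbosity (by default INFO on the main process and WARNING on replicas) and Hugging Face example scripts usually apply it to both the script's own logger and the library loggers. A minimal sketch of that boilerplate, assuming a hypothetical setup_logging helper and that the script also uses the datasets library:

import logging
import sys

import datasets
import transformers

logger = logging.getLogger(__name__)


def setup_logging(training_args):
    # Same basicConfig call as in the hunk above.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # INFO on the main process, WARNING on replicas, unless --log_level / --log_level_replica override it.
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()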
@@ -204,9 +203,19 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint and eventually continue from last checkpoint
    ### place holder ###
    ### place holder ###
    ### place holder ###
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Load dataset
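The checkpoint-detection block above follows the standard idiom from the Hugging Face Trainer examples. get_last_checkpoint scans a directory for checkpoint-<step> sub-directories and returns the newest one, or None if there is none, which is what drives the error and resume branches. A self-contained illustration, with hypothetical directory names:

# Standalone illustration of get_last_checkpoint; the paths are made up.
import os

from transformers.trainer_utils import get_last_checkpoint

output_dir = "clip_output"
os.makedirs(os.path.join(output_dir, "checkpoint-500"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "checkpoint-1000"), exist_ok=True)

# Returns the highest-numbered checkpoint directory: "clip_output/checkpoint-1000".
print(get_last_checkpoint(output_dir))

# A directory without checkpoint-* sub-directories yields None, which is why the
# script raises if output_dir is non-empty but holds no checkpoint.
os.makedirs("empty_dir", exist_ok=True)
print(get_last_checkpoint("empty_dir"))  # None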
@@ -384,7 +393,6 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
    )

    # Training
    last_checkpoint = None
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
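The last hunk is cut off here, but in the Trainer example scripts this block typically continues by letting an explicit --resume_from_checkpoint take precedence over the auto-detected last_checkpoint before calling trainer.train. A sketch of that usual continuation, not a verbatim quote of this file; trainer is assumed to be the Trainer instance built earlier in the function:

# Typical completion of the resume logic in Hugging Face Trainer example scripts (sketch).
if training_args.do_train:
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        # An explicit checkpoint passed on the command line wins.
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        # Otherwise fall back to the checkpoint auto-detected in output_dir.
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()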
