diff --git a/train_clip_with_hf_trainer.py b/train_clip_with_hf_trainer.py
index 5184fde..2f53662 100644
--- a/train_clip_with_hf_trainer.py
+++ b/train_clip_with_hf_trainer.py
@@ -3,7 +3,6 @@ import os
 import sys
 import transformers
 import dataclasses
-import ipdb
 from dataclasses import dataclass, field
 from typing import Optional, List
 
@@ -21,6 +20,7 @@ from transformers import (
     is_torch_tpu_available,
     set_seed,
 )
+from transformers.trainer_utils import get_last_checkpoint
 
 # We use torchvision for faster image pre-processing. The transforms are implemented as nn.Module,
 # so we jit it to be faster.
@@ -180,16 +180,15 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     get_dataclasses_help(TrainingArguments)
     training_args = dataclass_from_dict(TrainingArguments, training_args)
 
-    # 2. Setup logging
+    # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
 
-    # 2. Setup logging
+    # Setup logging
     #+ training_args
-    #log_level = training_args.get_process_log_level()
     log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
     transformers.utils.logging.set_verbosity(log_level)
@@ -204,9 +203,19 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Detecting last checkpoint and eventualy continue from last checkpoint
-    ### place holder ###
-    ### place holder ###
-    ### place holder ###
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
 
     # Load dataset
 
@@ -384,7 +393,6 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
     )
 
     # Training
-    last_checkpoint = None
     if training_args.do_train:
         checkpoint = None
         if training_args.resume_from_checkpoint is not None: