diff --git a/train_clip_with_hf_trainer.py b/train_clip_with_hf_trainer.py index 5b36085..5184fde 100644 --- a/train_clip_with_hf_trainer.py +++ b/train_clip_with_hf_trainer.py @@ -236,9 +236,6 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.validation_file.split(".")[-1] - if data_args.test_file is not None: - data_files["test"] = data_args.test_file - extension = data_args.test_file.split(".")[-1] dataset = load_dataset( extension, data_files=data_files, @@ -273,10 +270,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs): column_names = dataset["train"].column_names elif training_args.do_eval: column_names = dataset["validation"].column_names - elif training_args.do_predict: - column_names = dataset["test"].column_names else: - logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + logger.info("There is nothing to do. Please pass `do_train`, `do_eval`.") return dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None)