logo
Browse Source

remove test in the clip training.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
8f521ea2f6
  1. 7
      train_clip_with_hf_trainer.py

7
train_clip_with_hf_trainer.py

@ -236,9 +236,6 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
if data_args.validation_file is not None: if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1] extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
dataset = load_dataset( dataset = load_dataset(
extension, extension,
data_files=data_files, data_files=data_files,
@ -273,10 +270,8 @@ def train_with_hf_trainer(model, tokenizer, data_args, training_args, **kwargs):
column_names = dataset["train"].column_names column_names = dataset["train"].column_names
elif training_args.do_eval: elif training_args.do_eval:
column_names = dataset["validation"].column_names column_names = dataset["validation"].column_names
elif training_args.do_predict:
column_names = dataset["test"].column_names
else: else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
logger.info("There is nothing to do. Please pass `do_train`, `do_eval`.")
return return
dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None) dataset_columns = dataset_name_mapping.get(data_args.dataset_name, None)

Loading…
Cancel
Save