diff --git a/train_isc.py b/train_isc.py index 1f7dd39..9820eb7 100644 --- a/train_isc.py +++ b/train_isc.py @@ -151,6 +151,8 @@ def main_worker(gpu, ngpus_per_node, model, training_args): train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank) if not distributed or (distributed and rank == 0): + if not os.path.exists(training_args.output_dir): + os.mkdir(training_args.output_dir) torch.save({ 'epoch': epoch + 1, 'state_dict': model.state_dict(),