logo
Browse Source

add making saving dir

main
ChengZi 2 years ago
parent
commit
5db34032f6
  1. 2
      train_isc.py

2
train_isc.py

@ -151,6 +151,8 @@ def main_worker(gpu, ngpus_per_node, model, training_args):
train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank)
if not distributed or (distributed and rank == 0):
if not os.path.exists(training_args.output_dir):
os.mkdir(training_args.output_dir)
torch.save({
'epoch': epoch + 1,
'state_dict': model.state_dict(),

Loading…
Cancel
Save