|
@ -151,6 +151,8 @@ def main_worker(gpu, ngpus_per_node, model, training_args): |
|
|
train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank) |
|
|
train_one_epoch(train_loader, model, loss_fn, optimizer, scaler, epoch, rank) |
|
|
|
|
|
|
|
|
if not distributed or (distributed and rank == 0): |
|
|
if not distributed or (distributed and rank == 0): |
|
|
|
|
|
if not os.path.exists(training_args.output_dir): |
|
|
|
|
|
os.mkdir(training_args.output_dir) |
|
|
torch.save({ |
|
|
torch.save({ |
|
|
'epoch': epoch + 1, |
|
|
'epoch': epoch + 1, |
|
|
'state_dict': model.state_dict(), |
|
|
'state_dict': model.state_dict(), |
|
|