magic
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
106 lines
4.2 KiB
106 lines
4.2 KiB
# coding=utf-8
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.multiprocessing as mp
|
|
import argparse, os
|
|
import random
|
|
import numpy as np
|
|
import time
|
|
import logging
|
|
|
|
import logging
|
|
# Switch off the `transformers.generation_utils` logger entirely. This is
# presumably done to keep generation-time log messages out of the training
# output; confirm if the logger is ever needed for debugging.
_generation_logger = logging.getLogger('transformers.generation_utils')
_generation_logger.disabled = True
|
|
|
|
def parse_config(argv=None):
    """Declare the SimCTG training command line and parse it.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` holding every option declared below.
        Options declared without a ``default`` are ``None`` when omitted.
    """
    parser = argparse.ArgumentParser()
    # data configuration
    parser.add_argument("--model_name", type=str, default='gpt2')
    parser.add_argument("--train_path", type=str)
    parser.add_argument("--dev_path", type=str)
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--max_len", type=int)
    # Kept as a string option (not store_true) so existing launch commands
    # such as `--add_eos_token_to_data True` keep working. `choices` makes
    # argparse reject typos ("true", "yes", ...) at parse time. A value that
    # is omitted entirely stays None and is still rejected by the script.
    parser.add_argument("--add_eos_token_to_data", type=str, choices=("True", "False"),
                        help="whether to add the eos token to the data: 'True' or 'False'.")
    # mini-batch training configuration
    parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.")
    parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.')
    parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.")
    parser.add_argument("--effective_batch_size", type=int,
                        help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps")
    # pre-training configuration
    parser.add_argument("--total_steps", type=int,
                        help="total effective training steps")
    parser.add_argument("--print_every", type=int,
                        help="how many update steps to print one intermediate result")
    parser.add_argument("--save_every", type=int,
                        help="how many update steps to save one model")
    # learning configuration
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--margin", type=float)
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.")
    return parser.parse_args(argv)
|
|
|
|
def load_previous_best_model(path):
    """Return the path of the first entry in *path* whose name starts with 'training_step'.

    Entries are scanned in sorted (lexicographic) order, so the result is
    deterministic; ``os.listdir`` alone returns names in arbitrary order.
    Lexicographic order is not numeric order (``training_step_100`` sorts
    before ``training_step_20``). This helper therefore presumably expects
    the directory to hold a single saved checkpoint; confirm against the
    code that writes it.

    Args:
        path: Directory to search.

    Returns:
        ``os.path.join(path, name)`` for the first matching entry.

    Raises:
        FileNotFoundError: if no entry name starts with ``'training_step'``.
            It is a subclass of ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    for name in sorted(os.listdir(path)):
        if name.startswith('training_step'):
            return os.path.join(path, name)
    raise FileNotFoundError('No best model found!')
|
|
|
|
import argparse
|
|
def parse_eos_flag(value):
    """Convert the raw ``--add_eos_token_to_data`` option into a bool.

    Args:
        value: The option value as parsed. Only the exact strings ``'True'``
            and ``'False'`` are accepted.

    Returns:
        ``True`` for ``'True'`` and ``False`` for ``'False'``.

    Raises:
        ValueError: for any other value, including ``None`` when the option
            was omitted. It subclasses ``Exception``, so handlers written for
            the previous bare ``Exception`` still catch it.
    """
    if value == 'True':
        return True
    if value == 'False':
        return False
    raise ValueError(
        'Wrong eos configuration for data!!! '
        "Expected 'True' or 'False', got {!r}.".format(value))


if __name__ == '__main__':
    cuda_available = torch.cuda.is_available()
    multi_gpu_training = False
    if cuda_available:
        print('Cuda is available.')
        if torch.cuda.device_count() > 1:
            multi_gpu_training = True
            print('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count()))
        else:
            print('Using single GPU training.')

    args = parse_config()
    # Pick CUDA only when it exists. The model below is moved to `device`
    # only if cuda_available, and `device` is also handed to model_training.
    # A hard-coded 'cuda' device would therefore be wrong on CPU-only hosts.
    device = torch.device('cuda' if cuda_available else 'cpu')
    model_name = args.model_name

    # Special tokens passed to both the Data loader and the SimCTG model.
    sos_token, pad_token = r'<-start_of_text->', r'<-pad->'

    add_eos_token_to_data = parse_eos_flag(args.add_eos_token_to_data)
    if add_eos_token_to_data:
        print('Add eos token to data!')
    else:
        print('Do not add eos token to data!')

    print('Loading data...')
    from dataclass import Data
    data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len,
                sos_token, pad_token, add_eos_token_to_data)
    print('Data loaded.')

    from trainer import model_training
    print('############################################################')
    print('Start Training...')
    from simctg import SimCTG
    print('Initializing SimCTG model...')
    model = SimCTG(model_name, sos_token, pad_token)
    if cuda_available:
        if multi_gpu_training:
            model = nn.DataParallel(model)  # multi-gpu training
        model = model.to(device)
    print('Model loaded')

    total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every
    ckpt_save_path = args.save_path_prefix
    model = model_training(args, data, model, total_steps, print_every, save_every,
                           ckpt_save_path, cuda_available, device)
    print('Training stage completed!')
    print('############################################################')
|