import json
import random

import torch
import numpy as np
from torch.nn.utils import rnn

class Data:
    def __init__(self, model_name, train_path, dev_path, test_path, max_len,
                 sos_token, pad_token, add_eos_token_to_data):
        '''
        model_name: gpt2
        train_path: training data path
        dev_path: validation data path
        test_path: test data path
        max_len: maximum length for training sequences
        sos_token: the start-of-sequence token, <-start_of_text->
        pad_token: the token used to pad the sequences, <-pad->
        add_eos_token_to_data: whether the model should learn to generate the
            eos token; if so, the model can stop generation automatically by
            generating the eos token
        '''
        from transformers import GPT2TokenizerFast
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
        print('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
        self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
        print('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
        # GPT-2 uses the same token (<|endoftext|>) for bos and eos, so the
        # tokenizer's bos token is reused here as the eos token.
        self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
        print('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
        self.add_eos_token_to_data = add_eos_token_to_data

        self.max_len = max_len
        self.train_token_list, self.train_token_id_list = self.process_one_file(train_path)
        self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path)
        self.test_token_list, self.test_token_id_list = self.process_one_file(test_path)
        self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \
            len(self.test_token_list)
        print('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num))

        # Training examples are drawn in shuffled order; dev/test are
        # iterated sequentially via the current-index pointers below.
        self.train_idx_list = [i for i in range(self.train_num)]
        random.shuffle(self.train_idx_list)
        self.dev_idx_list = [j for j in range(self.dev_num)]
        self.test_idx_list = [j for j in range(self.test_num)]
        self.dev_current_idx, self.test_current_idx = 0, 0

    def add_special_token(self, special_token):
        if special_token in self.tokenizer.vocab:
            print(special_token + ' token exists.')
        else:
            print('Add token to the tokenizer.')
            print('Original vocabulary size is {}'.format(len(self.tokenizer)))
            self.tokenizer.add_tokens([special_token])
            print('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
            assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
        special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
        return special_token, special_token_id

    def process_one_file(self, path):
        print('Processing {}'.format(path))
        with open(path) as f:
            item_list = json.load(f)
        lines = []
        for item in item_list:
            captions_list = item['captions']
            for one_caption in captions_list:
                lines.append(one_caption.strip())

        res_token_list, res_token_id_list = [], []
        n = len(lines)
        for i in range(n):
            text = lines[i].strip('\n')
            self.process_one_text(text, res_token_list, res_token_id_list)
        print('{} processed!'.format(path))
        return res_token_list, res_token_id_list

    def process_one_text(self, text, res_token_list, res_token_id_list):
        tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True)
        if len(tokens) <= 1:  # filter out too-short sequences
            return
        tokens = [self.sos_token] + tokens[:self.max_len]
        if self.add_eos_token_to_data:
            tokens = tokens + [self.eos_token]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        res_token_list.append(tokens)
        res_token_id_list.append(token_ids)
        return

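    # A rough sketch of what process_one_text stores, assuming GPT-2's BPE
    # vocabulary and add_eos_token_to_data=True: the caption "a dog runs"
    # becomes approximately
    #   ['<-start_of_text->', 'a', 'Ġdog', 'Ġruns', '<|endoftext|>']
    # i.e. every stored sequence is [sos] + caption tokens (+ [eos]).
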
    def pad_batch(self, batch_id_list):
        batch_id_list = [torch.LongTensor(item) for item in batch_id_list]
        batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id)
        # Attention mask: 1.0 for real tokens, 0.0 for padding positions.
        batch_mask = torch.ones_like(batch_tensor)
        batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor)
        return batch_tensor, batch_mask

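    # A small sketch of pad_batch, assuming pad_token_id == 0: the two id
    # lists [5, 6, 7] and [5, 6] would be padded to
    #   batch_tensor = [[5, 6, 7], [5, 6, 0]]
    #   batch_mask   = [[1., 1., 1.], [1., 1., 0.]]
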
    def process_output(self, batch_tgt_id_list):
        batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list]
        batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list)  # padded target sequence
        # Shift by one position for next-token prediction: the model reads
        # positions [0, n-1) and is trained to predict positions [1, n).
        batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone()
        batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone()
        return batch_tgt_input_tensor, batch_tgt_output_tensor

    def parse_batch(self, batch_id_list):
        batch_input, batch_labels = self.process_output(batch_id_list)
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so
        # padding positions do not contribute to the training loss.
        batch_labels[batch_labels[:, :] == self.pad_token_id] = -100
        return batch_input, batch_labels

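    # A sketch of parse_batch on one padded sequence [sos, w1, w2, eos, pad]:
    #   batch_input  = [sos, w1, w2, eos]
    #   batch_labels = [w1,  w2,  eos, -100]
    # Each input position predicts the next token, and the padded position
    # is excluded from the loss via the -100 label.
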
    def get_next_train_batch(self, batch_size):
        # Sample a batch of distinct training indices without replacement.
        batch_idx_list = random.sample(self.train_idx_list, batch_size)
        batch_id_list, batch_token_list = [], []

        for idx in batch_idx_list:
            batch_id_list.append(self.train_token_id_list[idx])
            batch_token_list.append(self.train_token_list[idx])
        batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
        return batch_input_tensor, batch_labels, batch_token_list

    def get_next_validation_batch(self, batch_size, mode):
        batch_id_list, batch_token_list = [], []
        if mode == 'dev':
            curr_select_idx, instance_num = self.dev_current_idx, self.dev_num
            tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list
        elif mode == 'test':
            curr_select_idx, instance_num = self.test_current_idx, self.test_num
            tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list
        else:
            raise ValueError('Wrong validation mode: expected "dev" or "test", got {}'.format(mode))

        if curr_select_idx + batch_size < instance_num:
            for i in range(batch_size):
                curr_idx = curr_select_idx + i
                batch_id_list.append(tgt_token_id_list[curr_idx])
                batch_token_list.append(tgt_token_list[curr_idx])
            if mode == 'dev':
                self.dev_current_idx += batch_size
            else:
                self.test_current_idx += batch_size
        else:
            # Final (partial) batch of the epoch: wrap around to the start of
            # the data so the returned batch still has batch_size examples,
            # then reset the pointer for the next pass.
            for i in range(batch_size):
                curr_idx = (curr_select_idx + i) % instance_num
                batch_id_list.append(tgt_token_id_list[curr_idx])
                batch_token_list.append(tgt_token_list[curr_idx])
            if mode == 'dev':
                self.dev_current_idx = 0
            else:
                self.test_current_idx = 0
        batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
        return batch_input_tensor, batch_labels, batch_token_list
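
# A minimal usage sketch (the paths and batch size below are illustrative
# assumptions; the sos/pad token strings follow the docstring above). Each
# data file is expected to be a JSON list of items with a 'captions' field,
# as read by process_one_file.
if __name__ == '__main__':
    data = Data(model_name='gpt2',
                train_path='./data/train.json',
                dev_path='./data/dev.json',
                test_path='./data/test.json',
                max_len=64,
                sos_token='<-start_of_text->',
                pad_token='<-pad->',
                add_eos_token_to_data=True)
    batch_input, batch_labels, batch_token_list = data.get_next_train_batch(batch_size=8)
    print(batch_input.shape, batch_labels.shape)  # both of shape [8, seq_len - 1]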