|  |  | @ -129,10 +129,16 @@ def train_mlm_with_hf_trainer(model, | 
			
		
	
		
			
				
					|  |  |  |     import datasets | 
			
		
	
		
			
				
					|  |  |  |     from transformers import Trainer | 
			
		
	
		
			
				
					|  |  |  |     from datasets import load_dataset | 
			
		
	
		
			
				
					|  |  |  |     from towhee.trainer.training_config import get_dataclasses_help | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     print('train mlm with hugging face transformers trainer') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     print('**** DataTrainingArguments ****') | 
			
		
	
		
			
				
					|  |  |  |     get_dataclasses_help(DataTrainingArguments) | 
			
		
	
		
			
				
					|  |  |  |     data_args = dataclass_from_dict(DataTrainingArguments, data_args) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     print('**** TrainingArguments ****') | 
			
		
	
		
			
				
					|  |  |  |     get_dataclasses_help(TrainingArguments) | 
			
		
	
		
			
				
					|  |  |  |     training_args = dataclass_from_dict(TrainingArguments, training_args) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     # Setup logging | 
			
		
	
	
		
			
				
					|  |  | @ -321,20 +327,21 @@ def train_mlm_with_hf_trainer(model, | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         # Main data processing function that will concatenate all texts from our dataset and generate chunks of | 
			
		
	
		
			
				
					|  |  |  |         # max_seq_length. | 
			
		
	
		
			
				
					|  |  |  |         def group_texts(examples): | 
			
		
	
		
			
				
					|  |  |  |         def group_texts(examples): # examples: 1000 * (about 50~500) = total_length | 
			
		
	
		
			
				
					|  |  |  |             # Concatenate all texts. | 
			
		
	
		
			
				
					|  |  |  |             concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} | 
			
		
	
		
			
				
					|  |  |  |             total_length = len(concatenated_examples[list(examples.keys())[0]]) | 
			
		
	
		
			
				
					|  |  |  |             # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can | 
			
		
	
		
			
				
					|  |  |  |             # customize this part to your needs. | 
			
		
	
		
			
				
					|  |  |  |             if total_length >= max_seq_length: | 
			
		
	
		
			
				
					|  |  |  |             if total_length >= max_seq_length: # max_seq_length = 512 | 
			
		
	
		
			
				
					|  |  |  |                 total_length = (total_length // max_seq_length) * max_seq_length | 
			
		
	
		
			
				
					|  |  |  |             # Split by chunks of max_len. | 
			
		
	
		
			
				
					|  |  |  |             result = { | 
			
		
	
		
			
				
					|  |  |  |                 k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)] | 
			
		
	
		
			
				
					|  |  |  |                 for k, t in concatenated_examples.items() | 
			
		
	
		
			
				
					|  |  |  |             } | 
			
		
	
		
			
				
					|  |  |  |             return result | 
			
		
	
		
			
				
					|  |  |  |             return result # 573 * 512 = 293376 = total_length, dict(input_ids=[[token1, token2, ...token512], ...], token_type_ids=[[...],...], attention_mask=[[...],...], special_tkens_mask=[[...],...]) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a | 
			
		
	
		
			
				
					|  |  |  |         # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value | 
			
		
	
	
		
			
				
					|  |  | @ -449,3 +456,19 @@ def train_mlm_with_hf_trainer(model, | 
			
		
	
		
			
				
					|  |  |  |         trainer.save_metrics("eval", metrics) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     print('done mlm.') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     sequence = ( | 
			
		
	
		
			
				
					|  |  |  |         f"I have this film out of the {tokenizer.mask_token} right now and I haven't finished watching it. It is so bad I am in disbelief." | 
			
		
	
		
			
				
					|  |  |  |     ) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     import torch | 
			
		
	
		
			
				
					|  |  |  |     inputs = tokenizer(sequence, return_tensors="pt").to('cuda:0') | 
			
		
	
		
			
				
					|  |  |  |     mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     token_logits = model(**inputs).logits  # [1, 28, 30522] | 
			
		
	
		
			
				
					|  |  |  |     mask_token_logits = token_logits[0, mask_token_index, :] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist() | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     for token in top_5_tokens: | 
			
		
	
		
			
				
					|  |  |  |         print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token]))) |