# Unsupervised Domain Adaptation of Language Model


## Catalogue:
* 1. MSCOCO Benchmark
    * 1.1. MSCOCO Data Preparation
    * 1.2. Unsupervised Domain Adaptation on MSCOCO
* 2. Flickr30k Benchmark
    * 2.1. Flickr30k Data Preparation
    * 2.2. Unsupervised Domain Adaptation on Flickr30k
* 3. Unsupervised Baselines
    * 3.1. Contrastive Search
    * 3.2. Top-k Sampling
    * 3.3. Nucleus Sampling


## 1. MSCOCO Benchmark:

We first describe how to perform unsupervised domain adaptation of the language model on the text corpus of the MSCOCO benchmark.

### 1.1. MSCOCO Data Preparation:

To prepare the MSCOCO benchmark, please follow the instructions [here].

### 1.2. Unsupervised Domain Adaptation on MSCOCO:

After preparing the MSCOCO data, run the following commands to train the language model.

```bash
chmod +x ./train_mscoco.sh
./train_mscoco.sh
```

The arguments are as follows:

* `--model_name`: The name of the Hugging Face pre-trained GPT model (e.g. gpt2, gpt2-large).
* `--train_path`: The file path of the training set.
* `--dev_path`: The file path of the validation set.
* `--test_path`: The file path of the test set.
* `--add_eos_token_to_data`: Whether to add an eos token at the end of each text sequence.
* `--margin`: The contrastive margin $\rho$.
* `--max_len`: The maximum length of training samples.
* `--number_of_gpu`: The number of available GPUs.
* `--batch_size_per_gpu`: The batch size for each GPU.
* `--gradient_accumulation_steps`: The number of forward passes between two gradient updates.
* `--effective_batch_size`: The overall batch size. It equals `batch_size_per_gpu` x `gradient_accumulation_steps` x `number_of_gpu` (see the sketch after this list).
* `--total_steps`: The total number of gradient update steps.
* `--print_every`: How many steps between printing intermediate results.
* `--save_every`: How many steps between saving checkpoints.
* `--learning_rate`: The learning rate.
* `--save_path_prefix`: Where to save the checkpoints.
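
As a quick sanity check of the batch-size relation above, the snippet below shows how the three flags combine. The values are hypothetical, not the ones set in train_mscoco.sh:

```python
# Hypothetical values for illustration; the actual ones are set in train_mscoco.sh.
batch_size_per_gpu = 8
gradient_accumulation_steps = 16
number_of_gpu = 1

# The overall batch size seen by the optimizer per gradient update.
effective_batch_size = batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu
print(effective_batch_size)  # 128
```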

## 2. Flickr30k Benchmark:

We then describe how to perform unsupervised domain adaptation of the language model on the text corpus of the Flickr30k benchmark.

### 2.1. Flickr30k Data Preparation:

To prepare the Flickr30k benchmark, please follow the instructions [here].

### 2.2. Unsupervised Domain Adaptation on Flickr30k:

After preparing the Flickr30k data, run the following commands to train the language model.

```bash
chmod +x ./train_flickr30k.sh
./train_flickr30k.sh
```

## 3. Unsupervised Baselines:

Here, we illustrate how to use the language model to perform the unsupervised baselines described in our paper. Note that all these methods are unsupervised, as the language model is text-only and does not take images as input.

### 3.1. Contrastive Search:

```python
# first, load the language model
import torch
from simctg import SimCTG
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
# we use the language model adapted on MSCOCO as an example
language_model_name = r'cambridgeltl/magic_mscoco'
generation_model = SimCTG(language_model_name, sos_token, pad_token)
generation_model.eval()

# then, prepare the input ids; the text is always generated from the same start-of-sentence token
tokens = generation_model.tokenizer.tokenize(sos_token)
input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)
'''
   Use contrastive search to generate the result.
   Note that contrastive search is a deterministic decoding method, so the generated text is always the same.
'''
beam_width, alpha, decoding_len = 45, 0.1, 16
output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len)
print (output_text)
'''
   A man is riding a skateboard down a street.
'''
```

The arguments are as follows:

* `--input_ids`: The id of the start-of-sentence token.
* `--beam_width`: The k in contrastive search.
* `--alpha`: The α in contrastive search (see the sketch below).
* `--decoding_len`: The number of tokens to generate.
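
For intuition, here is a minimal sketch of a single contrastive search decoding step, following the formulation in the SimCTG paper: from the top-k candidates, pick the token that balances model confidence against a degeneration penalty, i.e. the maximum cosine similarity between the candidate's hidden state and those of the preceding tokens. The function name and tensor inputs are hypothetical; the repository's actual implementation lives in simctg.py.

```python
import torch
import torch.nn.functional as F

def contrastive_search_step(top_probs, candidate_hiddens, context_hiddens, alpha):
    """One contrastive search step (illustrative sketch, not the repo's implementation).

    top_probs:         (k,)        model confidence of the top-k candidate tokens
    candidate_hiddens: (k, dim)    hidden state of each candidate appended to the context
    context_hiddens:   (seq, dim)  hidden states of the tokens generated so far
    """
    cand = F.normalize(candidate_hiddens, dim=-1)
    ctx = F.normalize(context_hiddens, dim=-1)
    # degeneration penalty: highest cosine similarity to any previous token representation
    degeneration_penalty = (cand @ ctx.t()).max(dim=-1).values  # (k,)
    scores = (1.0 - alpha) * top_probs - alpha * degeneration_penalty
    return scores.argmax()  # index of the chosen candidate among the top-k
```

With `alpha = 0` this reduces to greedy search; larger values penalize tokens whose representations are too similar to the existing context, which discourages repetition.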
### 3.2. Top-k Sampling:

```python
'''
   Use top-k sampling to generate the result.
   Note that this method is stochastic, so the generated text varies across runs.
'''
top_k, decoding_len = 40, 16
output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len)
print (output_text)
'''
   some very different types of vases with flowers together
'''
```

The arguments are as follows:

* `--input_ids`: The id of the start-of-sentence token.
* `--k`: The k in top-k sampling (see the sketch below).
* `--decoding_len`: The number of tokens to generate.
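
The decoding rule itself is standard: restrict sampling to the k most probable tokens and renormalize. A self-contained sketch, with a hypothetical function name rather than the repository's `top_k_sampling`:

```python
import torch

def sample_top_k(logits, k):
    # keep only the k highest-scoring tokens, then sample from the renormalized distribution
    top_logits, top_ids = logits.topk(k)
    probs = torch.softmax(top_logits, dim=-1)
    return top_ids[torch.multinomial(probs, num_samples=1)]

# usage with a dummy next-token distribution over the GPT-2 vocabulary
next_id = sample_top_k(torch.randn(50257), k=40)
```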
### 3.3. Nucleus Sampling:

```python
'''
   Use nucleus sampling to generate the result.
   Note that this method is stochastic, so the generated text varies across runs.
'''
nucleus_p, decoding_len = 0.95, 16
output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len)
print (output_text)
'''
   Two young girls enjoying a hot dog hot dog bun.
'''
```

The arguments are as follows:

* `--input_ids`: The id of the start-of-sentence token.
* `--nucleus_p`: The cumulative probability threshold p in nucleus sampling (see the sketch below).
* `--decoding_len`: The number of tokens to generate.
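
Nucleus (top-p) sampling keeps the smallest set of top tokens whose cumulative probability mass reaches p, then samples from that set. Again a hypothetical sketch, not the repository's `nucleus_sampling`:

```python
import torch

def sample_nucleus(logits, p):
    # sort tokens by probability and keep the smallest prefix whose mass reaches p
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    keep = cumulative - sorted_probs < p  # mass *before* each token; always keeps the top token
    kept = sorted_probs[keep]
    choice = torch.multinomial(kept / kept.sum(), num_samples=1)
    return sorted_ids[keep][choice]

# usage with a dummy next-token distribution over the GPT-2 vocabulary
next_id = sample_nucleus(torch.randn(50257), p=0.95)
```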