lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
203 lines
7.1 KiB
203 lines
7.1 KiB
"""
|
|
NOTE: modified from ban-vqa
|
|
This code is slightly modified from Hengyuan Hu's repository.
|
|
https://github.com/hengyuan-hu/bottom-up-attention-vqa
|
|
"""
|
|
import os
|
|
import json
|
|
import re
|
|
import sys
|
|
import pickle
|
|
|
|
|
|
# Map of contraction spellings (mostly apostrophe-less, e.g. "dont") to their
# apostrophized forms (e.g. "don't"). process_digit_article applies it to each
# lower-cased token of an answer, after number-word mapping and article removal.
# NOTE(review): because lookups are done on lower-cased tokens, the capitalized
# keys ("Id've", "I'dve", "Im", "Ive") can never match.
# NOTE(review): "somebody'd" -> "somebodyd" maps in the opposite direction to
# every other entry; presumably copied verbatim from the reference VQA
# evaluation code, so it is left as-is. Confirm before "fixing" it: any change
# alters which answers normalize to the same string.
CONTRACTIONS = {
    "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
    "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
    "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
    "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
    "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
    "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
    "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
    "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
    "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
    "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
    "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
    "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
    "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
    "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
    "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
    "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
    "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
    "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
    "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
    "someoned've": "someone'd've", "someone'dve": "someone'd've",
    "someonell": "someone'll", "someones": "someone's", "somethingd":
    "something'd", "somethingd've": "something'd've", "something'dve":
    "something'd've", "somethingll": "something'll", "thats":
    "that's", "thered": "there'd", "thered've": "there'd've",
    "there'dve": "there'd've", "therere": "there're", "theres":
    "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
    "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
    "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
    "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
    "weren't", "whatll": "what'll", "whatre": "what're", "whats":
    "what's", "whatve": "what've", "whens": "when's", "whered":
    "where'd", "wheres": "where's", "whereve": "where've", "whod":
    "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
    "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
    "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
    "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
    "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
    "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
    "you'll", "youre": "you're", "youve": "you've"
}
|
|
|
|
# Number words rewritten as digit strings during answer normalization.
# 'none' also maps to '0'; 'zero'..'ten' map to '0'..'10' in order.
MANUAL_MAP = {
    'none': '0',
    **{word: str(digit) for digit, word in enumerate((
        'zero', 'one', 'two', 'three', 'four', 'five',
        'six', 'seven', 'eight', 'nine', 'ten'))},
}

# Tokens dropped entirely by process_digit_article.
ARTICLES = 'a an the'.split()

# Matches every '.' that is not followed by a digit, so '1.5' keeps its
# decimal point while a trailing period such as in 'yes.' is removed.
# NOTE(review): the leading '(?!<=\d)' is a negative *lookahead* for the
# literal text '<=' plus a digit, not the lookbehind '(?<!\d)' it appears to
# intend. Since the next character is always '.', it never blocks a match.
# Kept unchanged (presumably for parity with the reference VQA evaluation
# code), because changing it changes which answers normalize identically.
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")

# Digit-comma-digit sequence, e.g. the comma in '1,000'.
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")

# Punctuation characters handled one at a time (in this order) by
# process_punctuation; a trailing period is handled separately by PERIOD_STRIP.
PUNCT = list(';/[]"{}()=+\\_-><@`,?!')
|
|
|
|
|
|
# VQA accuracy for one answer is min(#annotators who gave it / 3, 1), averaged
# over every way of choosing 9 of the 10 annotators ("10 choose 9").
# See http://visualqa.org/evaluation.html
def get_score(occurences):
    """Return the soft VQA score for an answer given by ``occurences`` annotators.

    0 -> 0.0, 1 -> 0.3, 2 -> 0.6, 3 -> 0.9; any other value -> 1.0.
    """
    for count, partial_score in ((0, .0), (1, .3), (2, .6), (3, .9)):
        if occurences == count:
            return partial_score
    return 1.
|
|
|
|
|
|
def process_punctuation(inText):
    """Normalize punctuation in an answer string.

    For each character ``p`` in ``PUNCT`` (in order), every occurrence of
    ``p`` is either deleted or replaced by a space:

    * deleted if ``p`` is adjacent to a space somewhere in ``inText``
      (``p + ' '`` or ``' ' + p``), or if ``inText`` contains a
      digit-comma-digit sequence (``COMMA_STRIP``);
    * replaced by a space otherwise.

    Finally every period not followed by a digit is removed (``PERIOD_STRIP``),
    so decimals such as ``1.5`` survive.

    Args:
        inText: raw answer string.

    Returns:
        The punctuation-normalized string.
    """
    # Loop-invariant: does the input contain a number like '1,000'?
    has_digit_comma = COMMA_STRIP.search(inText) is not None
    outText = inText
    for p in PUNCT:
        # The adjacency checks look at the unmodified input, not outText.
        if p + ' ' in inText or ' ' + p in inText or has_digit_comma:
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    # Pattern.sub's third positional argument is ``count``. Passing re.UNICODE
    # (== 32) there silently capped removal at 32 periods. A compiled str
    # pattern is already Unicode-aware, so no flag is needed.
    outText = PERIOD_STRIP.sub('', outText)
    return outText
|
|
|
|
|
|
def process_digit_article(inText):
    """Lower-case an answer, map number words to digits, drop articles, fix contractions.

    Each whitespace-separated token of the lower-cased input is:

    * mapped through ``MANUAL_MAP`` (e.g. ``'two'`` -> ``'2'``);
    * dropped if the result is in ``ARTICLES``;
    * otherwise mapped through ``CONTRACTIONS`` (e.g. ``'dont'`` -> ``"don't"``).

    The surviving tokens are re-joined with single spaces.

    Args:
        inText: answer string; when called from preprocess_answer it has
            already been through process_punctuation.

    Returns:
        The normalized answer string.
    """
    outText = []
    for word in inText.lower().split():
        # Use .get, not .setdefault: setdefault would insert every unseen
        # token into the shared module-level MANUAL_MAP, growing it without
        # bound across calls.
        word = MANUAL_MAP.get(word, word)
        if word in ARTICLES:
            continue
        outText.append(CONTRACTIONS.get(word, word))
    return ' '.join(outText)
|
|
|
|
|
|
def preprocess_answer(answer):
    """Return the canonical (normalized) form of a raw answer string.

    The answer is processed in three steps:

    1. punctuation is normalized with ``process_punctuation``;
    2. number words, articles and contractions are handled with
       ``process_digit_article``;
    3. any commas still present are removed.
    """
    without_punct = process_punctuation(answer)
    normalized = process_digit_article(without_punct)
    return normalized.replace(',', '')
|
|
|
|
|
|
def filter_answers(answers_dset, min_occurence):
    """Group question ids by preprocessed answer and keep only frequent answers.

    Each entry's ground-truth answer is read from ``'multiple_choice_answer'``.
    If that key is missing (or None), the first element of ``'answers'`` is
    used instead (VG / GQA pretraining annotations). The answer is then
    normalized with ``preprocess_answer``, so the returned keys are
    preprocessed answers.

    Args:
        answers_dset: iterable of annotation dicts, each with a
            ``'question_id'``.
        min_occurence: minimum number of distinct question ids an answer must
            be associated with to be kept.

    Returns:
        dict mapping each kept answer to the set of question ids it is the
        ground truth for, in first-seen order.
    """
    qids_by_answer = {}
    for entry in answers_dset:
        raw_answer = entry.get('multiple_choice_answer', None)
        if raw_answer is None:
            raw_answer = entry['answers'][0]['answer']  # VG, GQA pretraining
        qids_by_answer.setdefault(
            preprocess_answer(raw_answer), set()).add(entry['question_id'])

    frequent = {
        answer: qids
        for answer, qids in qids_by_answer.items()
        if len(qids) >= min_occurence
    }
    print('Num of answers that appear >= %d times: %d' % (
        min_occurence, len(frequent)))
    return frequent
|
|
|
|
|
|
def create_ans2label(occurence, path):
    """Assign consecutive integer labels to answers and pickle the mapping.

    Args:
        occurence: iterable of answers, e.g. the dict returned by
            filter_answers (only its keys are used). Labels 0, 1, 2, ...
            follow its iteration order; a repeated answer keeps its last label.
        path: directory in which ``ans2label.pkl`` is written.

    Returns:
        None. The ``{answer: label}`` dict is written to
        ``<path>/ans2label.pkl``.
    """
    ans2label = {answer: label for label, answer in enumerate(occurence)}

    output_file = os.path.join(path, 'ans2label.pkl')
    # Context manager so the pickle is flushed and the file closed
    # deterministically instead of whenever the handle is garbage-collected.
    with open(output_file, 'wb') as f:
        pickle.dump(ans2label, f)
|
|
|
|
|
|
def compute_target(answers, ans2label):
    """Build the soft VQA target for one question.

    Args:
        answers: list of annotation dicts, each with an ``'answer'`` string.
            A one-element list (VG VQA, GQA) is counted as if that answer
            occurred 10 times. Otherwise each element is one annotator's
            answer and identical normalized answers are tallied (COCO VQA).
        ans2label: mapping from preprocessed answer to integer label.

    Returns:
        ``{'labels': [...], 'scores': [...]}`` with parallel lists. Answers
        missing from ``ans2label`` are skipped, and each score is
        ``get_score`` of that answer's count.
    """
    if len(answers) == 1:
        # VG VQA, GQA
        counts = {preprocess_answer(answers[0]['answer']): 10}
    else:
        # COCO VQA
        counts = {}
        for entry in answers:
            key = preprocess_answer(entry['answer'])
            counts[key] = counts.get(key, 0) + 1

    in_vocab = [ans for ans in counts if ans in ans2label]
    return {
        'labels': [ans2label[ans] for ans in in_vocab],
        'scores': [get_score(counts[ans]) for ans in in_vocab],
    }
|
|
|
|
|
|
if __name__ == '__main__':
    # Usage: <script> ANSWER_JSON [ANSWER_JSON ...] OUTPUT_DIR
    # Each ANSWER_JSON must contain a top-level 'annotations' list.
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} ANSWER_JSON [ANSWER_JSON ...] OUTPUT_DIR')
    *answer_files, output = sys.argv[1:]

    # Refuse to overwrite an existing vocabulary. Check before the potentially
    # slow loading/preprocessing, using the same path create_ans2label writes.
    output_file = os.path.join(output, 'ans2label.pkl')
    if os.path.exists(output_file):
        raise ValueError(f'{output_file} already exists')

    answers = []
    for ans_file in answer_files:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(ans_file, encoding='utf-8') as f:
            answers.extend(json.load(f)['annotations'])

    # Keep answers that are the ground truth for at least 9 distinct questions.
    occurence = filter_answers(answers, 9)
    create_ans2label(occurence, output)
|