lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
63 lines
1.9 KiB
63 lines
1.9 KiB
2 years ago
|
"""
|
||
|
Copied from the official NLVR2 GitHub repository.
|
||
|
Usage: python eval/nlvr2.py <output.csv> <annotation.json>
|
||
|
"""
|
||
|
import json
|
||
|
import sys
|
||
|
|
||
|
def parse_predictions(lines):
    """Build an identifier -> predicted-label mapping from CSV rows.

    Each non-blank row is split on "," after stripping surrounding
    whitespace; the first field is the identifier and the second is the
    prediction. If an identifier occurs more than once, the last row wins.

    Args:
        lines: iterable of text rows (for example an open file).

    Returns:
        dict mapping identifier strings to prediction strings.

    Raises:
        IndexError: if a non-blank row has fewer than two fields.
    """
    predictions = {}
    for raw_line in lines:
        row = raw_line.strip()
        if not row:
            # Rows read from a file keep their newline, so a blank row is
            # still a truthy string; test it only after stripping.
            continue
        splits = row.split(",")
        # We assume identifiers are in the format "split-####-#-#.png".
        predictions[splits[0]] = splits[1]
    return predictions


def parse_labeled_examples(lines):
    """Decode one JSON object per non-blank row.

    Args:
        lines: iterable of text rows (for example an open file).

    Returns:
        list of the decoded objects, in input order.
    """
    return [json.loads(row) for row in lines if row.strip()]


def find_missing(predictions, labeled_examples):
    """Return the identifiers of labeled examples that have no prediction.

    Membership is tested per identifier rather than by comparing counts, so
    surplus or mismatched predictions cannot hide a missing one.
    """
    return [
        example["identifier"]
        for example in labeled_examples
        if example["identifier"] not in predictions
    ]


def sentence_key(identifier):
    """Return the identifier with its third "-"-separated field blanked.

    Examples whose identifiers differ only in that field share a key and are
    scored together for consistency.

    Raises:
        IndexError: if the identifier has fewer than three fields.
    """
    parts = identifier.split("-")
    parts[2] = ''
    return '-'.join(parts)


def score(predictions, labeled_examples):
    """Compute accuracy and consistency of the predictions.

    Accuracy is the fraction of examples whose prediction equals the label,
    compared case-insensitively. Consistency is the fraction of sentence
    keys (see sentence_key) for which every example was predicted correctly.

    Args:
        predictions: dict mapping identifier to predicted label.
        labeled_examples: list of dicts with "identifier" and "label" keys.

    Returns:
        (accuracy, consistency) as a tuple of floats.

    Raises:
        KeyError: if an example has no entry in predictions.
        ZeroDivisionError: if labeled_examples is empty.
    """
    num_correct = 0
    consistency_dict = {}
    for example in labeled_examples:
        lookup = example["identifier"]
        anon_label = sentence_key(lookup)
        consistency_dict.setdefault(anon_label, True)
        if predictions[lookup].lower() == example["label"].lower():
            num_correct += 1
        else:
            # A single wrong prediction makes the whole group inconsistent.
            consistency_dict[anon_label] = False

    num_consistent = sum(1 for consistent in consistency_dict.values() if consistent)
    # float() keeps true division on interpreters where "/" truncates ints.
    accuracy = float(num_correct) / len(labeled_examples)
    consistency = float(num_consistent) / len(consistency_dict)
    return accuracy, consistency


def main(argv):
    """Score a predictions CSV against a JSON-lines annotation file.

    Args:
        argv: command line; argv[1] is the predictions CSV path and argv[2]
            is the annotation path.

    Prints "accuracy=..." and "consistency=..." on success. Exits with
    status 1 after printing a diagnostic when the arguments are missing,
    when there are no labeled examples, or when any labeled example lacks a
    prediction.
    """
    if len(argv) < 3:
        print("Usage: python " + str(argv[0] if argv else "nlvr2.py")
              + " <output.csv> <annotation.json>")
        sys.exit(1)

    # Load the predictions file. Assume it is a CSV.
    with open(argv[1]) as predictions_file:
        predictions = parse_predictions(predictions_file)

    # Load the labeled examples.
    with open(argv[2]) as annotations_file:
        labeled_examples = parse_labeled_examples(annotations_file)

    total_num = len(labeled_examples)
    if total_num == 0:
        # Nothing to score; avoid dividing by zero below.
        print("No labeled examples found in " + str(argv[2]))
        sys.exit(1)

    # Check that every labeled example has a prediction.
    # If not, identify the ones that are missing, and exit.
    missing = find_missing(predictions, labeled_examples)
    if missing:
        print("Some predictions are missing!")
        print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num))

        for lookup in missing:
            print("Missing prediction for item " + str(lookup))
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)

    # Report values.
    accuracy, consistency = score(predictions, labeled_examples)
    print("accuracy=" + str(accuracy))
    print("consistency=" + str(consistency))


if __name__ == "__main__":
    main(sys.argv)
|