lightningdot
              
                 
                
            
          copied
			You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            62 lines
          
        
        
          
            1.9 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            62 lines
          
        
        
          
            1.9 KiB
          
        
        
      | """ | |
| copied from official NLVR2 github | |
| python eval/nlvr2.py <output.csv> <annotation.json> | |
| """ | |
| import json | |
| import sys | |
| 
 | |
| # Load the predictions file. Assume it is a CSV. | |
| predictions = { } | |
| for line in open(sys.argv[1]).readlines(): | |
|   if line: | |
|     splits = line.strip().split(",") | |
|     # We assume identifiers are in the format "split-####-#-#.png". | |
|     identifier = splits[0] | |
|     prediction = splits[1] | |
|     predictions[identifier] = prediction | |
| 
 | |
| # Load the labeled examples. | |
| labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line] | |
| 
 | |
| # If not, identify the ones that are missing, and exit. | |
| total_num = len(labeled_examples) | |
| if len(predictions) < total_num: | |
|   print("Some predictions are missing!") | |
|   print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num)) | |
| 
 | |
|   for example in labeled_examples: | |
|       lookup = example["identifier"] | |
|       if not lookup in predictions: | |
|         print("Missing prediction for item " + str(lookup)) | |
|   exit() | |
| 
 | |
| # Get the precision by iterating through the examples and checking the value | |
| # that was predicted. | |
| # Also update the "consistency" dictionary that keeps track of whether all | |
| # predictions for a given sentence were correct. | |
| num_correct = 0. | |
| consistency_dict = { } | |
| 
 | |
| for example in labeled_examples: | |
|   anon_label = example["identifier"].split("-") | |
|   anon_label[2] = '' | |
|   anon_label = '-'.join(anon_label) | |
|   if not anon_label in consistency_dict: | |
|     consistency_dict[anon_label] = True | |
|   lookup = example["identifier"] | |
|   prediction = predictions[lookup] | |
|   if prediction.lower() == example["label"].lower(): | |
|     num_correct += 1. | |
|   else: | |
|     consistency_dict[anon_label] = False | |
| 
 | |
| # Calculate consistency. | |
| num_consistent = 0. | |
| unique_sentence = len(consistency_dict) | |
| for identifier, consistent in consistency_dict.items(): | |
|   if consistent: | |
|     num_consistent += 1 | |
| 
 | |
| # Report values. | |
| print("accuracy=" + str(num_correct / total_num)) | |
| print("consistency=" + str(num_consistent / unique_sentence))
 | 
