lightningdot
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
54 lines
2.0 KiB
54 lines
2.0 KiB
import pandas as pd
|
|
import json
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
|
|
|
|
def _annot_sort_key(annot_id):
    """Return the integer used to order an annotation ID.

    Parses the second ``-``-separated field of ``annot_id`` as an int
    (e.g. ``"test-42"`` -> 42). IDs without such a numeric field raise
    ``IndexError`` / ``ValueError``.
    """
    return int(annot_id.split("-")[1])


def _predictions_to_frame(data):
    """Flatten ``{annot_id: scores}`` into a DataFrame indexed by ``annot_id``.

    Each ``scores`` value is reshaped to 5 x 4. The 5 groups are
    [answer, rationale_conditioned_on_a0, rationale_conditioned_on_a1,
    rationale_conditioned_on_a2, rationale_conditioned_on_a3], with 4 scores
    each. Columns are named ``<group>_<i>`` for i in 0..3, giving 20 columns
    in group-major order. Rows are sorted by ``_annot_sort_key``.

    Raises:
        ValueError: if a ``scores`` entry does not hold exactly 20 values.
    """
    group_names = ['answer'] + [f'rationale_conditioned_on_a{i}'
                                for i in range(4)]
    columns = [f'{group_name}_{i}'
               for group_name in group_names for i in range(4)]

    ordered = sorted(data.items(), key=lambda item: _annot_sort_key(item[0]))
    ids = [annot_id for annot_id, _ in ordered]
    # reshape(5, 4) rejects entries that do not hold exactly 20 scores;
    # ravel() lays each entry out group-major, matching ``columns``.
    rows = [np.asarray(scores).reshape(5, 4).ravel() for _, scores in ordered]
    # np.stack needs at least one array, so build the empty table explicitly
    # and still emit a CSV with the full header.
    probs = np.stack(rows) if rows else np.empty((0, len(columns)))
    return pd.DataFrame(probs, columns=columns,
                        index=pd.Index(ids, name='annot_id'))


def main(opts):
    """Convert a JSON prediction file into a flat submission CSV.

    Reads ``<input_folder>/<pred_file>``, a JSON object mapping each
    annotation ID to 20 scores (any nesting that reshapes to 5 x 4). Writes
    ``<input_folder>/<output_file>`` with one row per annotation ID, indexed
    by ``annot_id`` and ordered by the ID's numeric suffix.

    Args:
        opts: namespace with string attributes ``input_folder``,
            ``pred_file`` and ``output_file``.

    Raises:
        ValueError: if an entry does not hold exactly 20 scores, or an
            annotation ID has a non-numeric suffix.
    """
    pred_path = os.path.join(opts.input_folder, opts.pred_file)
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(pred_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    probs_df = _predictions_to_frame(data)
    probs_df.to_csv(os.path.join(opts.input_folder, opts.output_file))
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters. Without them main() would call
    # os.path.join(None, ...) and die with an unhelpful TypeError, so let
    # argparse reject the invocation up front with a usage message instead.
    parser.add_argument("--pred_file",
                        required=True, type=str,
                        help="The input JSON file.")
    parser.add_argument("--output_file",
                        required=True, type=str,
                        help="The output CSV file.")
    # main() joins both --pred_file and --output_file onto this folder.
    parser.add_argument(
        "--input_folder", required=True, type=str,
        help="The directory where the predicted JSON files are in")

    args = parser.parse_args()

    main(args)
|
|
|