evaluate_v1.0.py
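"""
Evaluation script for MuSiQue v1.0 predictions.

Given a predictions JSONL file and a ground-truth JSONL file (one JSON object
per line, aligned by "id"), it reports:
  - answer_em / answer_f1: exact match and F1 of the predicted answer against
    the gold answer and its aliases.
  - support_f1: F1 of the predicted supporting-paragraph indices.
    (Both are computed over answerable instances only.)
  - group_answer_sufficiency_f1 / group_support_sufficiency_f1: reported only
    when the ground truth contains unanswerable instances, i.e. the full
    version of the dataset rather than only the answerable version.

Example invocation (hypothetical file names):
    python evaluate_v1.0.py predictions.jsonl ground_truth.jsonl \
        --output_filepath metrics.json
"""
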
import json
import argparse
from typing import List, Dict
from collections import Counter
from metrics.answer import AnswerMetric
from metrics.support import SupportMetric
from metrics.group_answer_sufficiency import GroupAnswerSufficiencyMetric
from metrics.group_support_sufficiency import GroupSupportSufficiencyMetric


def read_jsonl(file_path: str) -> List[Dict]:
    with open(file_path, "r") as file:
        instances = [json.loads(line.strip()) for line in file if line.strip()]
    return instances
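

# Instance fields expected by evaluate() below:
#   prediction JSONL  : "id", "predicted_answer", "predicted_support_idxs",
#                       "predicted_answerable"
#   ground-truth JSONL: "id", "answer", "answer_aliases", "answerable",
#                       "paragraphs" (each with "idx" and "is_supporting")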
def evaluate(filepath_with_predictions: str, filepath_with_ground_truths: str) -> Dict:

    prediction_instances = read_jsonl(filepath_with_predictions)
    ground_truth_instances = read_jsonl(filepath_with_ground_truths)

    do_sufficiency_eval = False

    answer_metric = AnswerMetric()
    support_metric = SupportMetric()
    group_answer_sufficiency_metric = GroupAnswerSufficiencyMetric()
    group_support_sufficiency_metric = GroupSupportSufficiencyMetric()

    assert len(prediction_instances) == len(
        ground_truth_instances
    ), "The number of lines in the two files is not the same."

    for ground_truth_instance, prediction_instance in zip(
        ground_truth_instances, prediction_instances
    ):
        assert (
            ground_truth_instance["id"] == prediction_instance["id"]
        ), "The instances (ids) in the prediction and gold jsonl files should be in the same order."

        question_id = ground_truth_instance["id"]

        predicted_answer = prediction_instance["predicted_answer"]
        ground_truth_answers = [
            ground_truth_instance["answer"]
        ] + ground_truth_instance["answer_aliases"]

        predicted_support_indices = prediction_instance["predicted_support_idxs"]
        ground_truth_support_indices = [
            paragraph["idx"]
            for paragraph in ground_truth_instance["paragraphs"]
            if paragraph["is_supporting"]
        ]

        predicted_sufficiency = prediction_instance["predicted_answerable"]
        ground_truth_sufficiency = ground_truth_instance["answerable"]

        # Answer and support metrics are only meaningful for answerable instances.
        if ground_truth_sufficiency:
            answer_metric(predicted_answer, ground_truth_answers)
            support_metric(predicted_support_indices, ground_truth_support_indices)

        # Group sufficiency metrics are updated for every instance, answerable or not.
        group_answer_sufficiency_metric(
            predicted_answer,
            ground_truth_answers,
            predicted_sufficiency,
            ground_truth_sufficiency,
            question_id,
        )
        group_support_sufficiency_metric(
            predicted_support_indices,
            ground_truth_support_indices,
            predicted_sufficiency,
            ground_truth_sufficiency,
            question_id,
        )

        # If there's any instance with a ground truth of unanswerable, we'll assume
        # it's the full version of the dataset and not only the answerable version.
        if not ground_truth_sufficiency:
            do_sufficiency_eval = True

    metrics = {}
    metrics["answer_f1"] = round(answer_metric.get_metric()[1], 3)
    metrics["answer_em"] = round(answer_metric.get_metric()[0], 3)
    metrics["support_f1"] = round(support_metric.get_metric()[1], 3)

    if do_sufficiency_eval:
        assert set(Counter([e["id"] for e in prediction_instances]).values()) == {2}, \
            "For sufficiency evaluation, there should be two instances for each question."
        metrics["group_answer_sufficiency_f1"] = round(
            group_answer_sufficiency_metric.get_metric()["f1"], 3
        )
        metrics["group_support_sufficiency_f1"] = round(
            group_support_sufficiency_metric.get_metric()["f1"], 3
        )

    return metrics


def main():
    parser = argparse.ArgumentParser(description="Evaluate MuSiQue predictions.")
    parser.add_argument(
        "filepath_with_predictions",
        type=str,
        help="jsonl filepath to the predicted instances.",
    )
    parser.add_argument(
        "filepath_with_ground_truths",
        type=str,
        help="jsonl filepath to the ground-truth data instances.",
    )
    parser.add_argument(
        "--output_filepath",
        type=str,
        help="(optional) filepath to save the output metrics.",
    )
    args = parser.parse_args()

    metrics = evaluate(args.filepath_with_predictions, args.filepath_with_ground_truths)

    if args.output_filepath:
        print(f"Writing metrics output to: {args.output_filepath}")
        with open(args.output_filepath, "w") as file:
            json.dump(metrics, file, indent=4)
    else:
        print(json.dumps(metrics, indent=4))


if __name__ == "__main__":
    main()