evaluations.py
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
#
# SPDX-License-Identifier: MIT
""" Evaluations on saved model outputs. """
from collections import defaultdict
import json
import numpy as np
import os
from abc import ABC, abstractmethod

from bert_score import BERTScorer

from data_schema import SchemaFactory


class Evaluation(ABC):
    def __init__(self, args):
        self.input_dir = args.input_dir
        self.schema = SchemaFactory.get_schema(args.dataset)
        self.annotation = args.annotation if hasattr(args, 'annotation') else None

    def read_file(self, filename):
        with open(os.path.join(self.input_dir, filename), 'r') as f:
            lines = map(str.strip, f.readlines())
        if self.annotation:
            lines = map(self.get_annotation_text, lines)
        return list(lines)

    def get_special_tokens(self):
        """ Returns the set of special text tokens. """
        return set(self.schema.get_special_text_tokens())

    def remove_special_tokens(self, text):
        for st in self.get_special_tokens():
            text = text.replace(st, '')
        text = ' '.join(text.split())  # replace multiple white-spaces
        return text

    def get_annotation_text(self, text):
        """ Returns the text of an annotation. """
        start_token = self.schema.mapping[self.annotation]['text_start']
        end_token = self.schema.mapping[self.annotation]['text_end']
        annotation_texts = []
        while start_token in text and end_token in text:
            start = text.index(start_token) + len(start_token)
            end = text.index(end_token)
            annotation_texts.append(text[start:end].strip())
            text = text[end + len(end_token):]
        return ', '.join(annotation_texts)
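    # Illustrative walk-through of get_annotation_text above (the marker strings and
    # the 'claim' key are hypothetical; the real ones come from the dataset schema):
    # with text_start='[claim]' and text_end='[/claim]', the input
    # "[claim] a [/claim] x [claim] b [/claim]" yields "a, b".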
    @abstractmethod
    def run(self):
        pass


class Rouge(Evaluation):
    """ Computes ROUGE-1/2/L, a measure of textual overlap between the generated and the reference StSts. """

    def __init__(self, args, remove_special_tokens=False):
        super().__init__(args)
        from rouge import RougeScorer, RougeAggregator
        self.sentence_separator = args.sentence_separator
        self.reference_separator = args.reference_separator
        self.do_remove_special_tokens = remove_special_tokens
        self.scorer = RougeScorer()
        self.aggregator = RougeAggregator()

    def run(self):
        references = self.read_file('references.txt')
        candidates = self.read_file('candidates.txt')
        assert len(references) == len(candidates), "Number of references and candidates differ"
        if self.do_remove_special_tokens:
            references = map(self.remove_special_tokens, references)
            candidates = map(self.remove_special_tokens, candidates)
        else:
            references = [r.replace('_', '') for r in references]  # remove _ in annotation markers, so they don't get
            candidates = [c.replace('_', '') for c in candidates]  # tokenized to separate tokens in rouge_score
        references = [r.split(self.reference_separator) for r in references]
        for candidate, reference_list in zip(candidates, references):
            if not candidate:
                candidate = ' '  # if candidate is empty, return ROUGE score of 0 (run with single whitespace)
            scores = [
                self.scorer.compute_rouge_score(candidate, r, sentence_sep=self.sentence_separator)
                for r in reference_list
            ]
            best_score = max(scores, key=lambda x: x['rouge1'].fmeasure + x['rouge2'].fmeasure + x['rougeLsum'].fmeasure)
            self.aggregator.add_scores(best_score)
        r1, r2, rL = self.aggregator.get_rouge_scores()
        return {'rouge1': r1, 'rouge2': r2, 'rougeL': rL}
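# Note: RougeScorer and RougeAggregator come from the project-local `rouge` module;
# judging by the rouge1/rouge2/rougeLsum fmeasure fields used above, they appear to
# wrap Google's `rouge_score` package. With multiple references per example, the
# reference with the highest combined F-measure is the one that gets aggregated.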


class DistinctNGrams(Evaluation):
    """ Computes the ratio of distinct to total n-grams in the generated StSts. """

    def __init__(self, args):
        super().__init__(args)
        self.sentence_separator = args.sentence_separator
        self.ngrams = args.ngrams

    def run(self):
        candidates = self.read_file('candidates.txt')
        candidates = map(self.remove_special_tokens, candidates)
        distinct_ngrams = defaultdict(set)
        total_ngrams = defaultdict(int)
        for candidate in candidates:
            sentences = candidate.split(self.sentence_separator)
            for sentence in sentences:
                words = sentence.split()
                for n in self.ngrams:
                    distinct_ngrams[n] = distinct_ngrams[n] | set(tuple(words[i:i + n]) for i in range(len(words) - n + 1))
                    total_ngrams[n] += len(words) - n + 1
        return {n: len(distinct_ngrams[n]) / total_ngrams[n] if total_ngrams[n] else 0 for n in self.ngrams}
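# Illustrative example: for a single candidate sentence "a b a b", the distinct
# unigrams are {a, b} out of 4 unigram positions, so distinct-1 = 2/4 = 0.5, and the
# distinct bigrams are {(a, b), (b, a)} out of 3, so distinct-2 = 2/3. The returned
# ratios pool distinct and total counts over the whole candidates file.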


class NovelNGrams(Evaluation):
    """ Computes the fraction of n-grams in the generated StSts that don't appear in the filtered refdoc. """

    def __init__(self, args):
        super().__init__(args)
        self.sentence_separator = args.sentence_separator
        self.ngrams = args.ngrams

    def run(self):
        sources = self.read_file('sources.txt')
        candidates = self.read_file('candidates.txt')
        sources = map(self.remove_special_tokens, sources)
        candidates = map(self.remove_special_tokens, candidates)

        def compute_ngram_overlap(src_words, tgt_words, n):
            """ Computes the fraction of target n-grams appearing in the source. """
            assert n > 0, "N for n-grams overlap needs to be positive"
            assert len(tgt_words) >= n, "Not enough words in target"
            src_ngrams = set(tuple(src_words[i:i + n]) for i in range(len(src_words) - n + 1))
            tgt_ngrams = set(tuple(tgt_words[i:i + n]) for i in range(len(tgt_words) - n + 1))
            ngram_overlap = src_ngrams & tgt_ngrams
            return len(ngram_overlap) / len(tgt_ngrams)

        novel_ngrams = defaultdict(list)
        for source, candidate in zip(sources, candidates):
            source_words = source.split()  # separated by periods, so matches across sentence boundaries are unlikely
            candidate_sentence_words = [sentence.split() for sentence in candidate.split(self.sentence_separator)]
            for n in self.ngrams:
                # remove candidate sentences that are shorter than n
                candidate_sents = [words for words in candidate_sentence_words if len(words) >= n]
                # compute overlaps individually per candidate sentence (avoid matches over sentence boundaries),
                # then weight each sentence by its length share
                overlaps = [compute_ngram_overlap(source_words, candidate_words, n) for candidate_words in candidate_sents]
                lengths = [len(candidate_words) for candidate_words in candidate_sents]
                if not sum(lengths):
                    continue  # skip if the candidate has no sentence of at least n words
                weighted_overlap = sum(o * l for o, l in zip(overlaps, lengths)) / sum(lengths)
                novel_ngrams[n].append(1 - weighted_overlap)
        return {n: np.mean(novel_ngrams[n]) for n in self.ngrams}
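# In other words, per candidate and per n the novelty is
# 1 - sum_s(overlap_s * len_s) / sum_s(len_s) over its sentences s, i.e. a
# length-weighted fraction of distinct candidate n-grams not found in the source,
# which is then averaged over all candidates.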


class Length(Evaluation):
    """ Computes the length of generated StSts measured by number of StSts, tokens and words generated. """

    def __init__(self, args):
        super().__init__(args)
        self.sentence_separator = args.sentence_separator

    def run(self):
        candidates = self.read_file('candidates.txt')
        num_sents = []
        num_tokens = []
        num_words = []
        num_special_tokens = []
        tokens_per_sent = []
        words_per_sent = []
        special_tokens_per_sent = []
        for candidate in candidates:
            sents = list(filter(None, candidate.split(self.sentence_separator)))
            tokens = [t for s in sents for t in s.split()]
            words = list(filter(lambda w: w not in self.get_special_tokens(), tokens))
            special_tokens = list(filter(lambda t: t in self.get_special_tokens(), tokens))
            num_sents.append(len(sents))
            num_tokens.append(len(tokens))
            num_words.append(len(words))
            num_special_tokens.append(len(special_tokens))
            tokens_per_sent.append(len(tokens) / len(sents) if len(sents) else 0)
            words_per_sent.append(len(words) / len(sents) if len(sents) else 0)
            special_tokens_per_sent.append(len(special_tokens) / len(sents) if len(sents) else 0)
        return {
            'num_sents': {'mean': np.mean(num_sents), 'std': np.std(num_sents)},
            'num_tokens': {'mean': np.mean(num_tokens), 'std': np.std(num_tokens)},
            'num_words': {'mean': np.mean(num_words), 'std': np.std(num_words)},
            'num_special_tokens': {'mean': np.mean(num_special_tokens), 'std': np.std(num_special_tokens)},
            'tokens_per_sent': {'mean': np.mean(tokens_per_sent), 'std': np.std(tokens_per_sent)},
            'words_per_sent': {'mean': np.mean(words_per_sent), 'std': np.std(words_per_sent)},
            'special_tokens_per_sent': {'mean': np.mean(special_tokens_per_sent),
                                        'std': np.std(special_tokens_per_sent)},
        }


class Annotations(Evaluation):
    """ Computes the number of each type of generated annotations, and whether they are correctly closed. """

    def __init__(self, args):
        super().__init__(args)
        self.sentence_separator = args.sentence_separator

    def run(self):
        candidates = self.read_file('candidates.txt')
        annotation_counts = defaultdict(list)
        closed_correctly = []
        for candidate in candidates:
            counts = defaultdict(int)
            sents = list(filter(None, candidate.split(self.sentence_separator)))
            for sent in sents:
                tokens = sent.split()
                special_tokens = list(filter(lambda t: any(map(lambda x: x in t, self.get_special_tokens())), tokens))
                for k, v in self.schema.mapping.items():
                    is_open = False
                    for st in special_tokens:
                        if v['text_start'] in st:
                            counts['opened'] += 1
                            is_open = True
                        elif v['text_end'] in st and is_open:
                            counts[k] += 1
                            counts['closed'] += 1
                            is_open = False
            for k in self.schema.mapping.keys():
                annotation_counts[k].append(counts[k])
            if counts['opened']:  # can only close when at least one annotation is opened
                closed_correctly.append(counts['closed'] / counts['opened'])
        results = {a: {'mean': np.mean(c), 'std': np.std(c)} for a, c in sorted(annotation_counts.items())}
        results.update({'closed_correctly': np.mean(closed_correctly) if closed_correctly else 0})
        return results
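# Illustrative example (marker names and schema keys hypothetical): for the single
# candidate sentence "<claim> strong statement </claim> <evidence> dangling", the
# claim annotation is opened and closed (so its per-type count is 1) while the
# evidence marker is opened but never closed, giving closed_correctly = 1/2 for
# this candidate.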


class BERTScore(Evaluation):
    """ Computes the semantic similarity between generated and reference StSts with BERTScore (Zhang et al., 2020). """

    def __init__(self, args):
        super().__init__(args)
        self.sentence_separator = args.sentence_separator
        self.reference_separator = args.reference_separator

    def run(self):
        references = self.read_file('references.txt')
        candidates = self.read_file('candidates.txt')
        references = map(self.remove_special_tokens, references)
        candidates = map(self.remove_special_tokens, candidates)
        references = [r.replace(self.sentence_separator, ' ') for r in references]
        references = [r.split(self.reference_separator) for r in references]
        candidates = [c.replace(self.sentence_separator, ' ') for c in candidates]
        scorer = BERTScorer(model_type='roberta-large', lang='en', rescale_with_baseline=True)
        (p, r, f1), hash = scorer.score(cands=candidates, refs=references, return_hash=True)
        return {'f1': np.mean(f1.tolist()), 'hash': hash}
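# Note: the rescale_with_baseline=True flag used above linearly rescales the raw
# BERTScore values against the baseline numbers shipped with the bert_score package
# for this language/model pair, which spreads the scores over a more readable range
# without changing their ranking.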


class EvaluationFactory:
    _registry = {
        'rouge': Rouge,
        'rouge_words': Rouge,
        'distinct_ngrams': DistinctNGrams,
        'novel_ngrams': NovelNGrams,
        'length': Length,
        'annotations': Annotations,
        'bertscore': BERTScore,
    }

    @staticmethod
    def get_evaluations():
        return EvaluationFactory._registry.keys()

    @staticmethod
    def run_evaluation(name, args):
        kwargs = {'args': args} if name != 'rouge_words' else {'args': args, 'remove_special_tokens': True}
        evaluation = EvaluationFactory._registry[name](**kwargs)
        return evaluation.run()
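# Adding a metric amounts to subclassing Evaluation and listing it in _registry.
# A minimal sketch (MyMetric and the 'my_metric' key are hypothetical, not part of
# this repository):
#
#     class MyMetric(Evaluation):
#         def run(self):
#             candidates = self.read_file('candidates.txt')
#             return {'num_candidates': len(candidates)}
#
#     EvaluationFactory._registry['my_metric'] = MyMetric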


def main(args):
    evaluations = EvaluationFactory.get_evaluations()
    results = {}
    results_path = os.path.join(args.input_dir, 'results.json')
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            results = json.load(f)
    for e in evaluations:
        if args.overwrite or e not in results:
            print("Running evaluation", e)
            results[e] = EvaluationFactory.run_evaluation(e, args)
    for a in args.annotations:
        assert a in SchemaFactory.get_schema(args.dataset).mapping.keys(), f"{a} is not a valid schema key"
        for e in evaluations:
            if e in ('rouge', 'annotations'):
                continue  # only compute ROUGE without the special tokens
            eval_name = f'{e}-{a}'
            if args.overwrite or eval_name not in results:
                print("Running evaluation", eval_name)
                args.annotation = a
                results[eval_name] = EvaluationFactory.run_evaluation(e, args)
    with open(results_path, 'w') as f:
        json.dump(results, f)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Run evaluations on saved model outputs.')
    parser.add_argument('--input_dir', required=True, help='Path to model outputs directory.')
    parser.add_argument('--dataset', default='us-russia', choices=['us-russia'], help='Dataset name')
    parser.add_argument('--overwrite', action='store_true', help='Rerun evaluations even if present.')
    parser.add_argument('--sentence_separator', default='<sent>', help='Sentence separator')
    parser.add_argument('--reference_separator', default='<ref>', help='Reference separator')
    parser.add_argument('--ngrams', type=int, nargs='+', default=[1, 2, 3, 4], help='List of n for n-gram statistics.')
    parser.add_argument('--annotations', nargs='+', default=[], help='Annotations to evaluate separately.')
    main(parser.parse_args())
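# Example invocation (a sketch: the output directory name is a placeholder and is
# expected to contain the references.txt / candidates.txt / sources.txt files read
# above; <schema_key> stands for one of the dataset schema's mapping keys):
#
#     python evaluations.py --input_dir outputs/my_run --overwrite --annotations <schema_key>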