filter_refdocs_with_tokenizer.py
# SPDX-FileCopyrightText: 2024 Idiap Research Institute
#
# SPDX-License-Identifier: MIT
""" Script that filters refdocs down to ROUGE-oracle or lead sentences, subject to a length limit for a given tokenizer. """
import json
import nltk
import os
import tqdm
from abc import ABC, abstractmethod
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AddedToken

from data_schema import SchemaFactory

SPLIT_SYMBOL = '<sent>'


class Selector(ABC):
    """ Base class for selectors that pick a subset of source sentences within a tokenizer-specific token limit. """

    def __init__(self, dataset, tokenizer_name, max_src_tokens):
        annotation_schema = SchemaFactory.get_schema(dataset)
        special_tokens = annotation_schema.get_special_text_tokens()
        special_tokens = [AddedToken(t) for t in special_tokens]
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, additional_special_tokens=special_tokens)
        self.max_src_tokens = max_src_tokens

    @abstractmethod
    def select_sentences(self, source_sentences, reference_texts):
        pass


class LeadSelector(Selector):
    """ Selects the leading source sentences (skipping duplicates) until the token limit is reached. """

    def select_sentences(self, source_sentences, reference_texts):
        selected_sentence_indices = {0}
        selected_sentence_texts = set()
        num_tokens = len(self.tokenizer.encode(source_sentences[0]))
        for i in range(1, len(source_sentences)):
            # skip duplicate sentences
            if source_sentences[i] in selected_sentence_texts:
                continue
            new_tokens = len(self.tokenizer.encode(source_sentences[i]))
            if num_tokens + new_tokens > self.max_src_tokens:
                break
            selected_sentence_indices.add(i)
            selected_sentence_texts.add(source_sentences[i])
            num_tokens += new_tokens
        return selected_sentence_indices


class RougeSelector(Selector):
    def __init__(self, dataset, tokenizer_name, max_src_tokens):
        super().__init__(dataset, tokenizer_name, max_src_tokens)
        self.scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)

    def select_sentences(self, source_sentences, targets):
        """ Greedily selects source sentences with the highest ROUGE-2 recall, until the token limit is reached. """
        references = '\n'.join(targets)
        selected_sentence_indices = set()
        possible_sentence_indices = set(range(len(source_sentences)))
        selected_sentence_texts = set()
        cur_recall = 0
        cur_length = 0
        # compute sentence lengths
        sentence_lengths = [len(self.tokenizer.encode(s)) for s in source_sentences]
        # add source sentences until above the token limit or the ROUGE recall does not increase anymore
        while True:
            # compute recall values of each source sentence when combined with the already selected sentences
            max_recall_increase = 0
            best_recall = 0
            best_idx = -1
            indices_to_remove = set()
            for i in possible_sentence_indices:
                # skip and remove sentences that are too long
                if cur_length + sentence_lengths[i] > self.max_src_tokens:
                    indices_to_remove.add(i)
                    continue
                # skip and remove duplicate sentences
                if source_sentences[i] in selected_sentence_texts:
                    indices_to_remove.add(i)
                    continue
                sentence_indices = sorted(list(selected_sentence_indices) + [i])
                candidates = '\n'.join([source_sentences[j] for j in sentence_indices])
                recall = self.scorer.score(references, candidates)['rouge2'].recall
                # compute recall increases normalized by sentence length
                recall_increase = (recall - cur_recall) / sentence_lengths[i]
                if recall_increase == 0:
                    # no more recall increases from this sentence: remove
                    indices_to_remove.add(i)
                elif recall_increase > max_recall_increase:
                    max_recall_increase = recall_increase
                    best_recall = recall
                    best_idx = i
            possible_sentence_indices -= indices_to_remove
            cur_recall = best_recall
            if max_recall_increase == 0:
                # no further increase in recall: stop
                break
            else:
                selected_sentence_indices.add(best_idx)
                possible_sentence_indices.remove(best_idx)
                selected_sentence_texts.add(source_sentences[best_idx])
                cur_length += sentence_lengths[best_idx]
        # if length limit has not been reached yet, add sentences from the top of the document
        if cur_length < self.max_src_tokens:
            for i, length in enumerate(sentence_lengths):
                if i in selected_sentence_indices or source_sentences[i] in selected_sentence_texts:
                    continue
                if cur_length + length > self.max_src_tokens:
                    break
                selected_sentence_indices.add(i)
                cur_length += length
        # if no sentence was selected, add the first one even if it exceeds the length limit
        if not selected_sentence_indices:
            selected_sentence_indices = {0}
        return selected_sentence_indices
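

# Usage sketch for the selectors (illustrative only: the sentences and target below are invented,
# and the dataset/tokenizer names are simply the argparse defaults from this script):
#
#     selector = RougeSelector('us-russia', 'bert-base-uncased', max_src_tokens=1024)
#     indices = selector.select_sentences(
#         ['First source sentence.', 'Second source sentence.'],
#         ['A cleaned target summary.'])
#     # `indices` is a set of positions into the source sentence list, e.g. {0, 1}.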


def clean_target_text(text, special_tokens):
    """ Removes special tokens from target text. """
    text = text.replace(SPLIT_SYMBOL, ' ')
    for token in special_tokens:
        text = text.replace(token, ' ')
    return ' '.join(text.split())  # remove multiple white-space
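

# For example, with a hypothetical '<claim>' special token (not necessarily one defined by the
# dataset schema):
#     clean_target_text('First claim.<sent><claim> Second claim.', ['<claim>'])
# returns 'First claim. Second claim.'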


def main(args):
    """ Filters the refdocs of each split and writes the filtered examples to a JSON file per split. """
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if args.filter_model == 'oracle':
        filter_model = RougeSelector(args.dataset, args.tokenizer_name, args.max_src_tokens)
    else:
        filter_model = LeadSelector(args.dataset, args.tokenizer_name, args.max_src_tokens)
    annotation_schema = SchemaFactory.get_schema(args.dataset)
    special_text_tokens = annotation_schema.get_special_text_tokens()
    for split in ['train', 'valid', 'test']:
        print(f'Processing {split} files...')
        outputs = []
        with open(os.path.join(args.input_dir, f'{split}.txt'), 'r') as f:
            filenames = list(map(str.strip, f.readlines()))
        for filename in tqdm.tqdm(filenames):
            with open(os.path.join(args.input_dir, f'{filename}.src.txt'), 'r') as f:
                source_text = f.read().strip()
            with open(os.path.join(args.input_dir, f'{filename}.tgt.txt'), 'r') as f:
                targets = list(map(str.strip, f.readlines()))
            target_texts = list(map(lambda t: clean_target_text(t, special_text_tokens), targets))
            # split source into sentences
            source_sentences = nltk.sent_tokenize(source_text)
            # select source sentences
            selected_sentence_indices = filter_model.select_sentences(source_sentences, target_texts)
            source_sentences = [s for i, s in enumerate(source_sentences) if i in selected_sentence_indices]
            # add filtered example to output
            for i, target in enumerate(targets):
                outputs.append({
                    'src_sents': source_sentences,
                    'tgt': target,
                    'name': filename,
                    'tgt_i': i,
                })
        with open(os.path.join(args.output_dir, f'{args.dataset}.{args.filter_model}.{split}.json'), 'w') as f:
            json.dump(outputs, f)
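

# Output format note: each record appended to `outputs` in main() has the shape
#     {'src_sents': [...selected source sentences...],
#      'tgt': '<raw target line, special tokens left in place>',
#      'name': '<refdoc file name>',
#      'tgt_i': <index of the target line within the .tgt.txt file>}
# (a description of the dict built above, not a verbatim sample from the dataset).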


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Filters the refdoc sentences.')
    parser.add_argument('--dataset', default='us-russia', choices=['us-russia'], help='Dataset name')
    parser.add_argument('--input_dir', default='data_us_russia_txt', help='Path to input text data directory')
    parser.add_argument('--output_dir', default='data_us_russia_filtered_text_bart',
                        help='Path to output directory')
    parser.add_argument('--filter_model', default='oracle', choices=['oracle', 'lead'],
                        help='Model to select refdoc sentences')
    parser.add_argument('--tokenizer_name', default='bert-base-uncased', help='Tokenizer model name or path to dir')
    parser.add_argument('--max_src_tokens', type=int, default=1024, help='Limit number of source tokens')
    main(parser.parse_args())
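

# Example invocation (a sketch; the values below are simply the argparse defaults above):
#     python filter_refdocs_with_tokenizer.py --dataset us-russia --input_dir data_us_russia_txt \
#         --output_dir data_us_russia_filtered_text_bart --filter_model oracle \
#         --tokenizer_name bert-base-uncased --max_src_tokens 1024
# This writes one JSON file per split, e.g. data_us_russia_filtered_text_bart/us-russia.oracle.train.json.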