-
Notifications
You must be signed in to change notification settings - Fork 0
/
rare_effects.py
302 lines (239 loc) · 12.2 KB
/
rare_effects.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import numpy as np
import pandas as pd
import scipy.io as sio
import os
import sys
import argparse
import matplotlib.pyplot as plt
import pickle
from pprint import pprint
from load_semantics import *
from plots_with_av_heatmap import plot_rarity_effects
# Previous result locations (kept for reference to an older machine layout).
# model_paths = {
# 'ALE': '/home/gdata/sandipan/BTP2021/new_zsl_models/ALE/CZSL',
# 'ESZSL': '/home/gdata/sandipan/BTP2021/new_zsl_models/ESZSL/CZSL',
# 'DEVISE': '/home/gdata/sandipan/BTP2021/new_zsl_models/DEVISE/CZSL',
# 'SAE': '/home/gdata/sandipan/BTP2021/new_zsl_models/SAE/CZSL',
# 'SJE': '/home/gdata/sandipan/BTP2021/new_zsl_models/SJE/CZSL',
# 'LSRGAN': '/home/gdata/sandipan/BTP2021/new_zsl_models/LSRGAN/CZSL',
# 'TFVAEGAN': '/home/gdata/sandipan/BTP2021/new_zsl_models/TFVAEGAN/tfvaegan-master/CZSL'
# }

# Maps each ZSL model name to the directory containing its CZSL result pickles.
# NOTE(review): paths are machine-specific; update them when running on another host.
model_paths = {
    #comment out CNZSL and TransZero if running for SUN dataset
    'ALE': '/workspace/arijit_pg/BTP2021/new_zsl_models/ALE/CZSL',
    'ESZSL': '/workspace/arijit_pg/BTP2021/new_zsl_models/ESZSL/CZSL',
    'DEVISE': '/workspace/arijit_pg/BTP2021/new_zsl_models/DEVISE/CZSL',
    'SAE': '/workspace/arijit_pg/BTP2021/new_zsl_models/SAE/CZSL',
    'SJE': '/workspace/arijit_pg/BTP2021/new_zsl_models/SJE/CZSL',
    'LSRGAN': '/workspace/arijit_pg/BTP2021/new_zsl_models/LSRGAN/CZSL',
    'TFVAEGAN': '/workspace/arijit_pg/BTP2021/new_zsl_models/TFVAEGAN/tfvaegan-master/CZSL',
    # 'CNZSL': '/workspace/arijit_pg/BTP2021/new_zsl_models/CNZSL/CZSL',
    'FREE': '/workspace/arijit_pg/BTP2021/new_zsl_models/FREE/CZSL',
    # 'TransZero': '/workspace/arijit_pg/BTP2021/new_zsl_models/TransZero/CZSL',
    'MSDN': '/workspace/arijit_pg/BTP2021/new_zsl_models/MSDN/CZSL'
}

# Maps each model name to the filename prefix its result pickles use.
model_names_in_resfiles = {
    # as per the model names with which the result filenames start in CZSL folders
    'ALE': 'ale',
    'ESZSL': 'eszsl',
    'DEVISE': 'devise',
    'SAE': 'sae',
    'SJE': 'sje',
    'LSRGAN': 'clswgan',
    'TFVAEGAN': 'tfvaegan_czsl',
    # 'CNZSL':'cnzsl_czsl',
    'FREE':'free_czsl',
    # 'TransZero':'TransZero_czsl',
    'MSDN':'MSDN_czsl'
}

# Keys of classwise-accuracy dicts inside each model's results pickle (historical note).
# classwise_accs_key_names = {
# # as per codes added by me for each model in its main python runner file
# 'ALE': 'common_unseen_classwise',
# 'ESZSL': 'common_unseen_classwise',
# 'DEVISE': 'common_unseen_classwise',
# 'SAE': 'common_unseen_classwise_F2S',
# 'SJE': 'common_unseen_classwise',
# 'LSRGAN': 'common_unseen_classwise',
# 'TFVAEGAN': 'common_unseen_classwise'
# }

# Split labels: 'ES' = original split, 'PS' = proposed (new seed) split; values are
# the substrings used in the corresponding result filenames.
split_names = {
    'ES':'original',
    'PS':'new_seed_final'
}
def get_paths():
    """Locate each model's result directory and build the result-pickle paths.

    Walks every model's u_split directory (for the configured active-learning lr
    and split number from the global ``args``) looking for ``.pickle`` files,
    then composes per-split result filenames.

    Returns:
        split_paths: model -> u_split directory for this lr / split number.
        result_paths: model -> directory that actually holds the .pickle results
            (the last matching directory found by the bottom-up walk).
        resfiles: split key ('ES'/'PS') -> model -> full path of its results pickle.

    Raises:
        KeyError: if no .pickle file was found for some model.
    """
    # ZSL folder name ends with lr value. For SUN it is 0.001 and for others it is 0.01.
    split_paths = {k: v + '_al_lr' + str(args.al_lr) + '/u_split' + str(args.sn)
                   for k, v in model_paths.items()}
    result_paths = {}
    resfiles = {'ES': {}, 'PS': {}}
    for m in model_paths.keys():
        for root, dirs, files in os.walk(split_paths[m], topdown=False):
            # original used os.path.join(root) — a single-arg no-op; keep the
            # last directory (bottom-up order) that contains any pickle file
            if any(f.endswith(".pickle") for f in files):
                result_paths[m] = root
        # segregate the result files for the current u_split
        for skey, sval in split_names.items():
            resfiles[skey][m] = os.path.join(
                result_paths[m],
                model_names_in_resfiles[m] + '_' + args.dataset + '_' + sval + '_results.pickle')
    return split_paths, result_paths, resfiles
def get_domain_semantics():
    """Load semantic matrices, separating common-unseen classes from the rest.

    The list of common unseen classes is read from ALE's PS results pickle
    (any model's file would do — they all store the same 'zsl_common_unseen'
    list), then the attribute dataframe is partitioned accordingly.

    Returns:
        att_df: semantics for all domain classes EXCEPT the common unseen ones.
        common_unseen_att_df: semantics for the common unseen classes only.
        common_unseen: list of common unseen class names.
    """
    # use a context manager so the pickle file is closed even if loading fails
    with open(resfiles['PS']['ALE'], 'rb') as res:
        test_res = pickle.load(res)
    common_unseen = test_res['zsl_common_unseen']
    att_df, data_complete_info, imagenet_overlapping_classes, given_testclasses = load_semantic_matrix(args.dataset)
    # for any run of our framework, the available experimental classes would be all
    # except common unseen (say T). Hence, for identifying rare and common attributes,
    # info should be extracted only from classes in T.
    common_unseen_att_df = att_df.loc[common_unseen]
    att_df.drop(common_unseen, axis=0, inplace=True)
    return att_df, common_unseen_att_df, common_unseen
def get_rare_and_common_atts(x, r_ratio=None, c_ratio=None):
    """Identify rare and common attributes from a class x attribute semantic matrix.

    An attribute is first binarized per class (1 if its value exceeds the mean of
    its non-zero values), then classified as *rare* if it occurs in fewer than
    ``r_ratio`` of the domain classes and *common* if it occurs in more than
    ``c_ratio`` of them. Irrelevant (all-zero) and unremarkable (never above
    threshold) attributes are dropped first.

    Args:
        x: pandas DataFrame, rows = classes, columns = attributes.
        r_ratio: rarity ratio; defaults to the global ``args.r_ratio``.
        c_ratio: commonness ratio; defaults to the global ``args.c_ratio``.

    Returns:
        (rare_atts, common_atts, thresh_df): rare/common attribute name lists and
        the binarized (0/1 int) semantic matrix after dropping unusable columns.
    """
    r_ratio = float(args.r_ratio if r_ratio is None else r_ratio)
    c_ratio = float(args.c_ratio if c_ratio is None else c_ratio)
    # removing irrelevant and unremarkable attributes, if any (very low chance of
    # getting any such attributes as we are considering the entire domain, except
    # only few classes, i.e. the common unseen classes)
    non_zero_counts = np.count_nonzero(x, axis=0)
    irrelevant_atts = x.columns[non_zero_counts == 0]
    print('\nIrrelevant_atts ({}) : {}'.format(len(irrelevant_atts), irrelevant_atts))
    x = x.drop(columns=irrelevant_atts)
    non_zero_counts = non_zero_counts[non_zero_counts != 0]
    # calculate threshold for each remaining attribute:
    # mean of only the nonzero values of each attribute
    clipped_semantic_mean = x.sum() / non_zero_counts
    # get binary semantic matrix: 1 where the value exceeds its attribute threshold
    thresh_df = (x - clipped_semantic_mean).clip(lower=0)
    thresh_df[thresh_df > 0] = 1
    thresh_df = thresh_df.astype(int)
    non_zero_atts = np.count_nonzero(thresh_df, axis=0)
    unremarkable_atts = thresh_df.columns[non_zero_atts == 0]
    print('\nUnremarkable_atts ({}) : {}'.format(len(unremarkable_atts), unremarkable_atts))
    thresh_df = thresh_df.drop(columns=unremarkable_atts)
    # infer rare and common attributes from their frequencies in the binary matrix
    occurences = thresh_df.sum()
    occurences.sort_values(inplace=True, ascending=True)
    # attribute will be rare if it occurs in less than r_ratio% of all domain classes
    rare_count = (occurences < round(thresh_df.shape[0] * r_ratio)).sum()
    rare_atts = occurences.index[:rare_count].tolist()
    # attribute will be common if it occurs in more than c_ratio% of all domain classes.
    # BUGFIX: when common_count == 0, index[-0:] would mistakenly return ALL
    # attributes — return an empty list in that case instead.
    common_count = (occurences > round(thresh_df.shape[0] * c_ratio)).sum()
    common_atts = occurences.index[-common_count:].tolist() if common_count > 0 else []
    print('\n\nRare ({}): {}'.format(rare_count, rare_atts))
    print('\n\nCommon ({}): {}'.format(common_count, common_atts))
    return rare_atts, common_atts, thresh_df
def get_classes_with_rare_and_common_atts(rare_atts, common_atts, common_unseen_att_df):
    """Find which common-unseen classes carry rare and common attributes.

    For common unseen classes, no thresholding and removal of irrelevant or
    unremarkable attributes is required - those were only for the classes
    incorporated in DiRaC-I. For easy processing, we consider any non-zero value
    in the semantic matrix of common unseen classes as 1, otherwise 0.

    Args:
        rare_atts: list of rare attribute names.
        common_atts: list of common attribute names.
        common_unseen_att_df: semantic matrix of the common unseen classes
            (NOT modified; a binarized copy is made — the original mutated it in place).

    Returns:
        (binary_df, classes_with_rare, classes_with_common): the binarized 0/1 matrix
        and the unique classes (first-seen order) carrying at least one rare/common attribute.
    """
    binary_df = common_unseen_att_df.copy()
    binary_df[binary_df > 0.0] = 1
    binary_df = binary_df.astype(int)

    def _classes_per_att(atts):
        # per_att: for each attribute, the classes that carry it;
        # uniq: all such classes, deduplicated in first-seen order
        per_att = []
        uniq = []
        for att in atts:
            found = binary_df[binary_df[att] == 1].index.tolist()
            per_att.append(found)
            for c in found:
                if c not in uniq:
                    uniq.append(c)
        return per_att, uniq

    rare_att_in_common_unseen, classes_with_rare = _classes_per_att(rare_atts)
    common_att_in_common_unseen, classes_with_common = _classes_per_att(common_atts)
    # dataframes indexed by attribute, one column per class found (built for inspection)
    rare_att_in_common_unseen = pd.DataFrame(rare_att_in_common_unseen, index=rare_atts)
    common_att_in_common_unseen = pd.DataFrame(common_att_in_common_unseen, index=common_atts)
    print('\n\n')
    print('Classes with rare attributes ({}/{}): {}'.format(len(classes_with_rare), len(binary_df.index), classes_with_rare))
    print('Classes with common attributes ({}/{}): {}'.format(len(classes_with_common), len(binary_df.index), classes_with_common))
    return binary_df, classes_with_rare, classes_with_common
def get_accuracies(classes_with_rare, classes_with_common, common_unseen, resfiles):
    """Compare per-class accuracies between rare-attribute and common-attribute classes.

    For every model (global ``model_paths``) and every split ('ES'/'PS'), loads the
    classwise accuracies from the model's results pickle and averages them over the
    classes flagged as carrying rare vs. common attributes.

    Args:
        classes_with_rare: classes carrying at least one rare attribute.
        classes_with_common: classes carrying at least one common attribute.
        common_unseen: all common unseen class names (iteration order of results).
        resfiles: split -> model -> results-pickle path.

    Returns:
        Nested dict: model -> split -> {avg_rare, avg_common, rare_accs,
        common_accs, rare_class_order, common_class_order}. Accuracies are in percent.
    """
    print('\n\nImpact of incorporating rarity\n==============================\n')
    effects = {}
    for m in model_paths.keys():
        print('\nModel = ', m)
        effects[m] = {}
        # SAE stores its classwise accuracies under a feature-to-semantic (F2S) key
        res_key = 'common_unseen_classwise_F2S' if m == 'SAE' else 'common_unseen_classwise'
        for split_name in resfiles.keys():  # ES or PS
            print('\nSplit name = ', split_name)
            # context manager closes the pickle even if loading fails
            with open(resfiles[split_name][m], 'rb') as pklfile:
                test_res = pickle.load(pklfile)
            classwise_accs_common_unseen = test_res[res_key]
            rare_accs, common_accs = [], []
            # sanity check: record the order in which classes were accumulated
            rare_class_order, common_class_order = [], []
            for c in common_unseen:
                a = classwise_accs_common_unseen[c] * 100  # converting to percentage values
                if c in classes_with_rare:
                    rare_accs.append(a)
                    rare_class_order.append(c)
                if c in classes_with_common:
                    common_accs.append(a)
                    common_class_order.append(c)
            # averages over ALL flagged classes (same denominator as the running-sum original)
            avg_rare = sum(rare_accs) / len(classes_with_rare)
            avg_common = sum(common_accs) / len(classes_with_common)
            print('\nAvg. accuracy for classes with rare attributes = ', avg_rare)
            print('\nAvg. accuracy for classes with common attributes = ', avg_common)
            effects[m][split_name] = {
                'avg_rare': avg_rare,
                'avg_common': avg_common,
                'rare_accs': rare_accs,
                'common_accs': common_accs,
                'rare_class_order': rare_class_order,
                'common_class_order': common_class_order,
            }
    print('\n\n\nFinal rare effect results\n=========================\n')
    pprint(effects)
    return effects
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Observing the effect of rare attributes on learning ability of ZSL models")
    parser.add_argument('-d','--dataset', default = 'SUN', help = 'AWA2, SUN, CUB')
    parser.add_argument('-sn', '--sn', type = int, default = 1, help='random unknown unknown split number')
    parser.add_argument('-al_lr', '--al_lr', default=0.01, type=float, help='learning rate used during active learning')
    parser.add_argument('-rare', '--r_ratio', default=0.1, type=float, help='ratio below which attribute considered rare')
    parser.add_argument('-common', '--c_ratio', default=0.5, type=float, help='ratio above which attribute considered common')
    #added argument for better path parsing - change it as per you need
    parser.add_argument('--home_dir', default='/workspace/arijit_pg/BTP2021/', help='path to dataset')
    args = parser.parse_args()
    # make folder to store final reports; makedirs creates missing parents and
    # exist_ok avoids the racy exists()-then-mkdir pattern
    res_folder = args.home_dir + args.dataset + '/rarity_reports_lr' + str(args.al_lr) + '/'
    os.makedirs(res_folder, exist_ok=True)
    result_filename = res_folder + 'u_split' + str(args.sn) + '_' + args.dataset + '_r' + str(args.r_ratio) + '_c' + str(args.c_ratio) + '_rarity_reports.txt'
    # all subsequent prints go to the report file
    sys.stdout = open(result_filename, 'w')
    # get the paths where results are stored
    split_paths, result_paths, resfiles = get_paths()
    pprint(resfiles)
    # get semantic matrix from which rare and common attributes are to be identified
    att_df, common_unseen_att_df, common_unseen = get_domain_semantics()
    num_domain_classes = att_df.shape[0]
    num_domain_atts = att_df.shape[1]
    print('\n\nDomain classes = {}, Domain attributes = {}'.format(num_domain_classes, num_domain_atts))
    rare_atts, common_atts, thresh_df = get_rare_and_common_atts(att_df)
    # obtain list of classes with rare and common attributes
    common_unseen_att_df_binary, classes_with_rare, classes_with_common = get_classes_with_rare_and_common_atts(rare_atts, common_atts, common_unseen_att_df)
    # extract classwise accuracies
    effects_dict = get_accuracies(classes_with_rare, classes_with_common, common_unseen, resfiles)
    # persist the effects dict for later plotting / re-analysis
    storage_name = res_folder + 'u_split' + str(args.sn) + '_' + args.dataset + '_r' + str(args.r_ratio) + '_c' + str(args.c_ratio) + '_rarity_effects.pickle'
    with open(storage_name, 'wb') as pfile:
        pickle.dump(effects_dict, pfile)
    # plot graphs
    plot_rarity_effects(args, effects_dict, model_paths)
    print('\n\nDone!')
    sys.stdout.close()