-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
135 lines (109 loc) · 3.3 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
from logger.logger import logger
# Module-level seed value.
# NOTE(review): this differs from configuration["training"]["seed"] (23);
# confirm which one callers actually use.
seed = 0

# Project-wide settings, wrapped by the Config class below.
configuration = {
    "DEBUG": False,
    # These paths are relative to the main directory.
    "paths": {
        # Other dataset roots present in the original file:
        # "data/SemEval16/", "data/MAMS_ACSA/", "data/FourSquared/",
        # "data/SamsungGalaxy/".
        'data_root': "data/MR/",
        "dataset": "sample.csv",
        # The "MR_" prefixes below match the active dataset; presumably they
        # need updating together with data_root / data.dataset.name — confirm.
        "saved_train_graphs": "MR_train_graphs.bin",
        "saved_large_graph": "MR_large_graph.bin",
        "dataset_dataframe": "MR_dataframe.csv",
        "label_text_to_label_id": "MR_label_text_to_label_id.json",
        "output": "output/",
    },
    "data": {
        "dataset": {
            # Other dataset names present in the original file:
            # 'MAMS_ATSA', 'SamsungGalaxy', 'FourSquared'.
            "name": "MR"
        },
        "trainval_test_split": 0.3,
        "train_val_split": 0.3,
        # Key spelling ("occurences") is kept as-is because lookups elsewhere
        # presumably use this exact string.
        "min_label_occurences": 0,
        "multi_label": False
    },
    "model": {
        "in_dim": 300,
        "hidden_dim": 150,
        "out_dim": 2,
        "num_heads": 2,
    },
    "training": {
        "seed": 23,
        "epochs": 2,
        "create_dataset": False,
        "dropout": 0.2,
        "train_batch_size": 30,
        "val_batch_size": 60,
        "test_batch_size": 60,
        "early_stopping_patience": 6,
        "early_stopping_delta": 0,
        "optimizer": {
            "optimizer_type": "adam",
            "learning_rate": 3e-4,
        },
        "threshold": 0.5,
    },
    "embeddings": {
        "embedding_file": "glove-twitter-25",
    },
    "hardware": {
        "num_workers": 16
    }
}
class Config(object):
    """Contains all configuration details of the project.

    Wraps the module-level ``configuration`` dict. The instance stores a
    reference to that same dict (not a copy), so changes made through
    ``get_config()`` are visible to every other user of the module.
    """

    def __init__(self):
        super(Config, self).__init__()
        # Shared reference to the module-level dict, deliberately not copied.
        self.configuration = configuration

    def get_config(self):
        """Return the configuration dict held by this instance."""
        return self.configuration

    def print_config(self, indent=4, sort=True):
        """Log the configuration as pretty-printed JSON via ``logger.info``.

        :param indent: indentation width passed to ``json.dumps``.
        :param sort: if true, keys are sorted in the dumped JSON.
        """
        logger.info("[{}] : {}".format("Configuration",
                                       json.dumps(self.configuration,
                                                  indent=indent,
                                                  sort_keys=sort)))

    @staticmethod
    def get_platform():
        """Return a short name for the host operating system.

        ``'Windows'`` and ``'Linux'`` are passed through unchanged from
        ``platform.system()``. Every other value is mapped to ``'OSX'``,
        including macOS, which reports ``'Darwin'``.

        :return: str
        """
        import platform
        system_name = platform.system()
        if system_name in ('Windows', 'Linux'):
            return system_name
        return "OSX"

    @staticmethod
    def get_username():
        """Return the name of the user running this process.

        Resolution order:
          1. the password-database entry for the current uid (POSIX only);
          2. ``getpass.getuser()``, which checks the LOGNAME, USER, LNAME and
             USERNAME environment variables before consulting ``pwd``;
          3. the ``USER`` environment variable, or ``None`` if it is unset.

        :return: str, or None if no source yields a name.
        """
        import os

        try:
            import pwd
            return pwd.getpwuid(os.getuid()).pw_name
        except (ImportError, AttributeError, KeyError):
            # pwd / os.getuid do not exist on Windows. KeyError means the
            # uid has no passwd entry, which is common in containers.
            pass
        try:
            import getpass
            return getpass.getuser()
        except Exception:
            # getpass raises OSError (3.13+) or ImportError/KeyError (older)
            # when it finds no name; keep the env-var fallback.
            return os.environ.get('USER')
# Import-time side effects: build the shared Config instance and log the
# full configuration once.
config_cls = Config()
config_cls.print_config()

# Host information exposed as module globals. Note that the module-level
# name ``platform`` holds a string here, not the stdlib ``platform`` module.
platform = config_cls.get_platform()
username = config_cls.get_username()