forked from GerbenBeintema/gym-unbalanced-disk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ANN_Models.py
78 lines (59 loc) · 2.57 KB
/
ANN_Models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from torch import nn, zeros, stack, cat, float64
class NonLinear(nn.Module):
    """Feed-forward model with one LeakyReLU hidden layer.

    The input width is not fixed at construction time: the first layer is an
    ``nn.LazyLinear``, which infers its input size from the first batch passed
    to ``forward``. Parameters use torch's default dtype (no ``.double()``).

    Args:
        out_features: Width of the hidden layer (also embedded in ``name``).
        output_dim: Number of output features.
    """

    def __init__(self, out_features, output_dim: int = 1) -> None:
        super().__init__()
        layers = [
            nn.LazyLinear(out_features),
            nn.LeakyReLU(),
            nn.Linear(out_features, output_dim),
        ]
        # Kept as a single Sequential named ``fc1`` so state_dict keys are stable.
        self.fc1 = nn.Sequential(*layers)
        self.name = 'NonLinear_{}'.format(out_features)

    def forward(self, x):
        """Return the network output for the batch ``x``."""
        return self.fc1(x)


class NARX(nn.Module):
    """MLP used as the static map of a NARX model.

    Two stacked sub-networks: ``fc0`` maps the (lazily sized) input to
    ``out_features`` through two LeakyReLU layers, and ``fc1`` maps that
    representation to ``output_dim`` outputs. The regressor vector (presumably
    lagged inputs/outputs) is assembled by the caller — confirm at call sites.

    Args:
        out_features: Width of every hidden layer (also embedded in ``name``).
        output_dim: Number of output features.
    """

    def __init__(self, out_features, output_dim: int = 1):
        super().__init__()
        width = out_features
        encoder_layers = [
            nn.LazyLinear(width),
            nn.LeakyReLU(),
            nn.Linear(width, width),
            nn.LeakyReLU(),
        ]
        head_layers = [
            nn.Linear(width, width),
            nn.LeakyReLU(),
            nn.Linear(width, output_dim),
        ]
        # Submodule names fc0/fc1 are part of the state_dict layout; keep them.
        self.fc0 = nn.Sequential(*encoder_layers)
        self.fc1 = nn.Sequential(*head_layers)
        self.name = 'NARX_{}'.format(out_features)

    def forward(self, x):
        """Apply ``fc0`` then ``fc1`` to the batch ``x``."""
        features = self.fc0(x)
        return self.fc1(features)


class RNN(nn.Module):
    """Recurrent model built from two small MLPs over ``[u_t, h_{t-1}]``.

    At every time step the current input ``u_t`` is concatenated with the
    previous hidden state ``h_{t-1}``. ``H2H`` maps that vector to the next
    hidden state and ``H2O`` maps the same vector to the step's output. Both
    sub-networks are ``Linear -> LeakyReLU -> Linear`` stacks held in float64.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Size of the hidden state vector.
        output_size: Number of output features per time step.
        nr_nodes: Width of the inner layer of both sub-networks.
    """

    def __init__(self, input_size: int = 1, hidden_size: int = 40,
                 output_size: int = 1, nr_nodes: int = 40):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nr_nodes = nr_nodes
        self.output_size = output_size

        def net(n_in, n_out):
            # Two-layer MLP cast to float64 to match the float64 hidden state.
            return nn.Sequential(
                nn.Linear(n_in, self.nr_nodes),
                nn.LeakyReLU(),
                nn.Linear(self.nr_nodes, n_out),
            ).double()

        # Both sub-networks consume the same concatenated [input, hidden] vector.
        # Construction order (H2H, then H2O) is kept so seeded inits are unchanged.
        self.H2H = net(self.input_size + self.hidden_size, self.hidden_size)
        self.H2O = net(self.input_size + self.hidden_size, self.output_size)
        self.name = 'RNN'

    def forward(self, inputs):
        """Run the recurrence over the time axis (dim 1) of ``inputs``.

        Args:
            inputs: ``(batch, seq)`` when each step carries a single feature, or
                ``(batch, seq, input_size)`` for multi-feature steps.

        Returns:
            ``(batch, seq)`` when ``output_size == 1`` (the original layout),
            otherwise ``(batch, seq, output_size)``. Values are float64.

        The hidden state starts at zeros on every call; no state is kept
        between calls.
        """
        hidden = zeros(inputs.size(0), self.hidden_size, dtype=float64, device=inputs.device)
        outputs = []
        for i in range(inputs.size(1)):
            u = inputs[:, i]
            if u.dim() == 1:
                # One scalar feature per step: lift to (batch, 1) for concatenation.
                u = u[:, None]
            # The sub-networks are float64, so bring the input to the hidden dtype.
            combined = cat((u.to(hidden.dtype), hidden), dim=1)
            hidden = self.H2H(combined)
            y = self.H2O(combined)
            # Keep the historical (batch,) per-step shape for single outputs;
            # otherwise return every output channel instead of dropping them.
            outputs.append(y[:, 0] if self.output_size == 1 else y)
        return stack(outputs, dim=1)