-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn.py
28 lines (25 loc) · 1009 Bytes
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from tensorflow import keras
def get_cnn(final_layer_hidden_size, input_shape=(84, 84, 4)):
    """
    Build and compile a convolutional network that approximates the Q function.

    Layers, in order:
      * Conv2D: 16 filters, 8x8 kernel, stride 4, 'same' padding, ReLU
      * Conv2D: 32 filters, 4x4 kernel, stride 2, 'same' padding, ReLU
      * Flatten
      * Dense: 256 units, ReLU
      * Dense: ``final_layer_hidden_size`` units, linear (no activation)

    :param final_layer_hidden_size: The number of neurons in the last layer,
        i.e. the number of values the network outputs per input sample.
    :param input_shape: Shape of a single input sample, excluding the batch
        dimension (height, width, channels for the default channels-last
        layout). Defaults to ``(84, 84, 4)``.
    :return: The CNN, compiled with the Adam optimizer and mean squared
        error loss.
    """
    model = keras.models.Sequential([
        # 'same' padding keeps ceil(size / stride) spatial positions, so the
        # default 84x84 input becomes 21x21 here and 11x11 after the next
        # convolution. The ReLU is applied via ``activation``; a separate
        # Activation('relu') layer on top would be a redundant no-op because
        # ReLU is idempotent.
        keras.layers.Conv2D(16,
                            (8, 8),
                            strides=(4, 4),
                            padding='same',
                            input_shape=input_shape,
                            activation='relu'),
        keras.layers.Conv2D(32,
                            (4, 4),
                            strides=(2, 2),
                            padding='same',
                            activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation='relu'),
        # No activation: Q-value estimates are unbounded regression targets.
        keras.layers.Dense(final_layer_hidden_size)
    ])
    model.compile(optimizer='Adam', loss=keras.losses.mean_squared_error)
    return model