import argparse
import datetime

import numpy as np
import torch
import gym
from tensorboardX import SummaryWriter

from env.make_env import make_env
from agent import BiCNet
from normalized_env import ActionNormalizedEnv, ObsEnv, reward_from_state
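
# Training script for BiCNet on the multi-agent particle environments
# (default: 'simple_spread' with three agents). Per-agent episode rewards,
# actor/critic losses and exploration statistics can be logged to
# TensorBoard via --tensorboard.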

def main(args):
    env = make_env('simple_spread')
    # env = make_env('simple')
    # env = gym.make('Pendulum-v0')
    env = ActionNormalizedEnv(env)
    env = ObsEnv(env)

    kwargs = dict()
    kwargs['config'] = args

    torch.manual_seed(args.seed)

    if args.tensorboard:
        writer = SummaryWriter(log_dir='runs/' + args.log_dir)

    # BiCNet(obs_dim, action_dim, n_agents): 14-dim observations, 2-dim
    # actions, 3 agents for simple_spread.
    model = BiCNet(14, 2, 3, **kwargs)
    # model = BiCNet(4, 2, 1, **kwargs)

    episode = 0
    total_step = 0

    while episode < args.max_episodes:
        state = env.reset()
        episode += 1
        step = 0
        accum_reward = 0
        rewardA = 0
        rewardB = 0
        rewardC = 0
        prev_reward = np.zeros((3), dtype=np.float32)

        while True:
            # Random exploration during warm-up, noisy policy actions afterwards.
            # action = agent.random_action()
            if episode > args.warmup:
                action = model.choose_action(state, noisy=True)
            else:
                action = model.random_action()

            next_state, reward, done, info = env.step(action)
            step += 1
            total_step += 1
            reward = np.array(reward)

            # Reward shaping: distance to landmarks.
            rew1 = reward_from_state(next_state)
            # if step % 5 == 0:
            #     rew1 -= 0.1
            reward = rew1 + (np.array(reward, dtype=np.float32) / 100.)
            accum_reward += sum(reward)
            rewardA += reward[0]
            rewardB += reward[1]
            rewardC += reward[2]

            if args.render and episode % 100 == 0:
                env.render(mode='rgb_array')

            model.memory(state, action, reward, next_state, done)
            state = next_state

            if len(model.replay_buffer) >= args.batch_size and total_step % args.steps_per_update == 0:
                model.prep_train()
                model.train()
                model.prep_eval()

            # End of episode: log, reset and start the next one.
            if args.episode_length < step or (True in done):
                c_loss, a_loss = model.get_loss()
                action_std = model.get_action_std()
                print("[Episode %05d] reward %6.4f eps %.4f" % (episode, accum_reward, model.epsilon), end='')
                if args.tensorboard:
                    writer.add_scalar(tag='agent/reward', global_step=episode, scalar_value=accum_reward.item())
                    writer.add_scalar(tag='agent/reward_0', global_step=episode, scalar_value=rewardA.item())
                    writer.add_scalar(tag='agent/reward_1', global_step=episode, scalar_value=rewardB.item())
                    writer.add_scalar(tag='agent/reward_2', global_step=episode, scalar_value=rewardC.item())
                    writer.add_scalar(tag='agent/epsilon', global_step=episode, scalar_value=model.epsilon)
                    if c_loss and a_loss:
                        writer.add_scalars('agent/loss', global_step=episode, tag_scalar_dict={'actor': a_loss, 'critic': c_loss})
                    if action_std:
                        writer.add_scalar(tag='agent/action_std', global_step=episode, scalar_value=action_std)
                if c_loss and a_loss:
                    print(" a_loss %3.2f c_loss %3.2f" % (a_loss, c_loss), end='')
                if action_std:
                    print(" action_std %3.2f" % (action_std), end='')
                print()
                env.reset()
                model.reset()
                break

    if args.tensorboard:
        writer.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_episodes', default=1000000, type=int)
    parser.add_argument('--episode_length', default=50, type=int)
    parser.add_argument('--memory_length', default=int(1e5), type=int)
    parser.add_argument('--steps_per_update', default=100, type=int)
    parser.add_argument('--tau', default=0.001, type=float)
    parser.add_argument('--gamma', default=0.95, type=float)
    parser.add_argument('--warmup', default=100, type=int)
    # argparse's type=bool treats any non-empty string as True, so parse the
    # flag explicitly.
    parser.add_argument('--use_cuda', default=True,
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'))
    parser.add_argument('--seed', default=777, type=int)
    parser.add_argument('--a_lr', default=0.001, type=float)
    parser.add_argument('--c_lr', default=0.001, type=float)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--ou_theta', default=0.15, type=float)
    parser.add_argument('--ou_mu', default=0.0, type=float)
    parser.add_argument('--ou_sigma', default=0.2, type=float)
    parser.add_argument('--epsilon_decay', default=1000000, type=int)
    parser.add_argument('--reward_coef', default=1, type=float)
    parser.add_argument('--tensorboard', action='store_true')
    parser.add_argument('--save_interval', default=1000, type=int)
    parser.add_argument('--log_dir', default=datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    args = parser.parse_args()

    main(args)
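
# Example invocation (assumes the repo's env/, agent.py and normalized_env.py
# modules are importable from the working directory; adjust flags as needed):
#   python train.py --tensorboard --render --log_dir bicnet_simple_spread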