#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-06-12 00:50:49
@LastEditor: John
@LastEditTime: 2021-05-07 16:30:05
@Description:
@Environment: python 3.7.7
'''
'''Off-policy DQN agent.
'''
import torch
import torch.nn as nn
import torch.optim as optim
import random
import math
import numpy as np
from memory import ReplayBuffer
from model import MLP
from hparam import hparams as hp
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class DQN:
def __init__(self):
        self.action_dim = hp.action_dim  # total number of actions
        self.gamma = hp.gamma  # discount factor for rewards
        # parameters of the epsilon-greedy exploration schedule
        self.frame_idx = 0  # step counter used for epsilon decay
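        # epsilon decays exponentially with the step count:
        # epsilon(t) = epsilon_end + (epsilon_start - epsilon_end) * exp(-t / epsilon_decay)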
self.epsilon = lambda frame_idx: hp.epsilon_end + \
(hp.epsilon_start - hp.epsilon_end) * \
math.exp(-1. * frame_idx / hp.epsilon_decay)
self.batch_size = hp.batch_size
self.policy_net = MLP(hp.state_dim, hp.action_dim, hidden_dim=hp.hidden_dim).to(device)
self.target_net = MLP(hp.state_dim, hp.action_dim, hidden_dim=hp.hidden_dim).to(device)
for target_param, param in zip(self.target_net.parameters(),
self.policy_net.parameters()): # copy params from policy net
target_param.data.copy_(param.data)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=hp.lr)
self.memory = ReplayBuffer(hp.memory_capacity)
def choose_action(self, state, mask):
        '''Choose an action with an epsilon-greedy policy, restricted to the actions allowed by mask.
        '''
self.frame_idx += 1
# print(f'action: {self.frame_idx}')
# print(f'epsilon:{self.epsilon(self.frame_idx)}')
if random.random() > self.epsilon(self.frame_idx):
action = self.predict(state, mask)
else:
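            # explore: sample uniformly, resampling until the action is allowed by the mask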
action = random.randrange(self.action_dim)
while not mask[action]:
action = random.randrange(self.action_dim)
return action
def predict(self, state, mask):
with torch.no_grad():
state = torch.tensor([state], device=device, dtype=torch.float32)
mask = torch.tensor(mask, device=device, dtype=torch.int32)
q_values = self.policy_net(state)
q_values_softmax = torch.softmax(q_values, dim=1)
q_values_softmax = q_values_softmax * mask
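            # softmax outputs are strictly positive, so zeroing the masked entries
            # guarantees the argmax below never selects an invalid action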
action = q_values_softmax.max(1)[1].item()
return action
def update(self):
if len(self.memory) < self.batch_size:
return
        # sample a random batch of transitions from the replay buffer
state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
self.batch_size)
        '''Convert the sampled batch to tensors,
        e.g. tensor([[-4.5543e-02, -2.3910e-01, 1.8344e-02, 2.3158e-01], ..., [-1.8615e-02, -2.3921e-01, -1.1791e-02, 2.3400e-01]])'''
        state_batch = torch.tensor(state_batch, device=device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=device).unsqueeze(1)  # e.g. tensor([[1], ..., [0]])
        reward_batch = torch.tensor(reward_batch, device=device, dtype=torch.float)  # e.g. tensor([1., 1., ..., 1.])
next_state_batch = torch.tensor(next_state_batch, device=device, dtype=torch.float)
done_batch = torch.tensor(np.float32(done_batch), device=device)
        '''Compute Q(s_t, a) for the actions actually taken in the batch'''
        '''torch.gather: for a = torch.tensor([[1, 2], [3, 4]]), a.gather(1, torch.tensor([[0], [1]])) = torch.tensor([[1], [3]])'''
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)  # calling the net is equivalent to self.forward
        # compute V(s_{t+1}) for all next states, i.e. the largest Q-value produced by the target net
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()  # e.g. tensor([ 0.0060, -0.0171, ...])
        # compute the expected (target) Q-values
        # for terminal states done_batch is 1, so the target reduces to the immediate reward
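        # Bellman target: y = r + gamma * max_a' Q_target(s', a') * (1 - done)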
expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
        # self.loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))  # Huber loss alternative
        loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # mean squared error loss
        # optimize the model
        self.optimizer.zero_grad()  # clear the gradients left over from the previous step
        # backpropagation computes the gradient of the loss w.r.t. every parameter that requires gradients
        loss.backward()
        # for param in self.policy_net.parameters():  # gradient clipping to prevent exploding gradients
        #     param.grad.data.clamp_(-1, 1)
        self.optimizer.step()  # update the model parameters
return loss.item()
def save(self, path):
torch.save(self.target_net.state_dict(), path + 'dqn_checkpoint.pth')
def load(self, path):
self.target_net.load_state_dict(torch.load(path + 'dqn_checkpoint.pth'))
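        # sync the policy net with the freshly loaded target-net weights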
for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
param.data.copy_(target_param.data)
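

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It shows how this
# agent is typically driven by a training loop. The environment interface
# (reset()/step() returning a state plus an action mask) and the
# ReplayBuffer.push(...) signature are assumptions for illustration only and
# are not guaranteed by memory.py or the rest of this repository.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     agent = DQN()
#     for episode in range(200):                    # episode count is arbitrary here
#         state, mask = env.reset()                 # hypothetical env API
#         done = False
#         while not done:
#             action = agent.choose_action(state, mask)
#             next_state, next_mask, reward, done = env.step(action)       # hypothetical env API
#             agent.memory.push(state, action, reward, next_state, done)   # assumed push() signature
#             agent.update()
#             state, mask = next_state, next_mask
#     agent.save('./checkpoints/')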