Commit

Calculate expected future progress during optimization
pkel committed Apr 29, 2024
1 parent 44e39e6 commit d699b96
Showing 2 changed files with 17 additions and 14 deletions.
5 changes: 3 additions & 2 deletions mdp/mcvi_test.py
@@ -41,7 +41,7 @@ def mcvi(
     j = 0
     process = psutil.Process()
 
-    start_value = agent.start_value()
+    start_value, start_progress = agent.start_value_and_progress()
     assert start_value >= max_start_value, "value iteration is monotonic"
     max_start_value = max(start_value, max_start_value)
 
@@ -58,7 +58,8 @@ def mcvi(
         n_states_exploring_starts=len(agent.exploring_starts),
         start_value=start_value,
         start_value_by_horizon=start_value / horizon,
-        start_value_by_progress=start_value / agent.progress_gamma999,
+        start_value_by_progress=start_value / start_progress,
+        start_progress=start_progress,
         ram_usage_gb=process.memory_info().rss / 1024**3,
         exploration_states_per_step=len(agent.states) / (i + 1),
         exploration_gamma9999=agent.exploration_gamma9999,
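For reference, the new start metrics are a probability-weighted average over the start states followed by a ratio; this is the quantity the test now logs as start_value_by_progress instead of normalizing by agent.progress_gamma999. A tiny self-contained illustration with made-up numbers (the real values come from the agent's estimates):

# Made-up start distribution: (probability, value estimate, progress estimate).
start_states = [
    (0.25, 10.0, 4.0),
    (0.75, 6.0, 3.0),
]

start_value = sum(prob * v for prob, v, _ in start_states)     # 7.0
start_progress = sum(prob * p for prob, _, p in start_states)  # 3.25
start_value_by_progress = start_value / start_progress         # ~2.15

print(start_value, start_progress, start_value_by_progress)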
26 changes: 14 additions & 12 deletions mdp/monte_carlo_value_iteration.py
@@ -24,7 +24,8 @@ def collision_resistant_hash(x):
 
 class State:
     def __init__(self):
-        self.value = 0
+        self.value = 0  # estimate of future rewards
+        self.progress = 0  # estimate of future progress
         self.count = 0
         self._actions = None  # action idx -> state hash transition list
         self._honest = None  # honest action id
@@ -98,19 +99,14 @@ def set_exploration(self, *args, eps=None, eps_honest=None, eps_es=None):
         assert 0 <= eps_es <= 1
         self.eps_es = eps_es
 
-    def state_hash_value(self, state_hash):
-        if state_hash in self.states:
-            return self.states[state_hash].value
-        else:
-            assert False, "there should be an initial estimate for all states"
-            return 0
-
-    def start_value(self):
+    def start_value_and_progress(self):
         v = 0
+        p = 0
         for full_state, prob in self.model.start():
             state, state_hash = self.state_and_hash_of_full_state(full_state)
             v += prob * state.value
-        return v
+            p += prob * state.progress
+        return v, p
 
     def step(self):
         full_state = self.full_state
@@ -146,17 +142,23 @@ def step(self):
         # consider all available actions, tracking ...
         max_i = 0  # index of best action
         max_q = 0  # value of best action
+        max_p = 0  # progress of best action
         for i, transitions in enumerate(actions):
             q = 0  # action value estimate
+            p = 0
             for t in transitions:
-                q += t.probability * (t.reward + self.state_hash_value(t.state))
+                to_state = self.states[t.state]
+                q += t.probability * (t.reward + to_state.value)
+                p += t.probability * (t.progress + to_state.progress)
 
             if q > max_q:
                 max_i = i
                 max_q = q
+                max_p = p
 
-        # update state-value estimate
+        # update state-value and progress estimate
         state.value = max_q
+        state.progress = max_p
 
         # exploring starts heuristic:
         # we try to record such states that have better than honest value
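For context, here is a minimal, self-contained sketch of the joint value/progress backup that this commit introduces in step(). It is illustrative only: the Transition and StateEstimate shapes are assumptions read off the diff, and action selection, exploration, and the rest of the agent are omitted.

from dataclasses import dataclass


@dataclass
class Transition:
    probability: float
    reward: float
    progress: float
    state: int  # key of the successor state in the `states` table


@dataclass
class StateEstimate:
    value: float = 0.0     # estimate of expected future reward
    progress: float = 0.0  # estimate of expected future progress


def backup(state, actions, states):
    """Greedy one-step backup, mirroring the loop in step().

    `actions` is a list of actions, each given as a list of Transitions.
    The value-maximizing action determines both the value and the
    progress estimate written back into `state`.
    """
    max_q = 0.0  # value of best action
    max_p = 0.0  # progress of best action
    for transitions in actions:
        q = 0.0  # action value estimate
        p = 0.0  # action progress estimate
        for t in transitions:
            to_state = states[t.state]
            q += t.probability * (t.reward + to_state.value)
            p += t.probability * (t.progress + to_state.progress)
        if q > max_q:
            max_q, max_p = q, p
    state.value = max_q
    state.progress = max_p

Dividing the resulting start value by the start progress then gives the reward-per-progress metric logged in mcvi_test.py.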
