diff --git a/mdp/mcvi_test.py b/mdp/mcvi_test.py
index f8b77f79..869b557b 100644
--- a/mdp/mcvi_test.py
+++ b/mdp/mcvi_test.py
@@ -41,7 +41,7 @@ def mcvi(
         j = 0
         process = psutil.Process()
 
-        start_value = agent.start_value()
+        start_value, start_progress = agent.start_value_and_progress()
         assert start_value >= max_start_value, "value iteration is monotonic"
         max_start_value = max(start_value, max_start_value)
 
@@ -58,7 +58,8 @@ def mcvi(
             n_states_exploring_starts=len(agent.exploring_starts),
             start_value=start_value,
             start_value_by_horizon=start_value / horizon,
-            start_value_by_progress=start_value / agent.progress_gamma999,
+            start_value_by_progress=start_value / start_progress,
+            start_progress=start_progress,
             ram_usage_gb=process.memory_info().rss / 1024**3,
             exploration_states_per_step=len(agent.states) / (i + 1),
             exploration_gamma9999=agent.exploration_gamma9999,
diff --git a/mdp/monte_carlo_value_iteration.py b/mdp/monte_carlo_value_iteration.py
index e30cb8a5..6b28ed34 100644
--- a/mdp/monte_carlo_value_iteration.py
+++ b/mdp/monte_carlo_value_iteration.py
@@ -24,7 +24,8 @@ def collision_resistant_hash(x):
 
 class State:
     def __init__(self):
-        self.value = 0
+        self.value = 0  # estimate of future rewards
+        self.progress = 0  # estimate of future progress
         self.count = 0
         self._actions = None  # action idx -> state hash transition list
         self._honest = None  # honest action id
@@ -98,19 +99,14 @@ def set_exploration(self, *args, eps=None, eps_honest=None, eps_es=None):
         assert 0 <= eps_es <= 1
         self.eps_es = eps_es
 
-    def state_hash_value(self, state_hash):
-        if state_hash in self.states:
-            return self.states[state_hash].value
-        else:
-            assert False, "there should be an initial estimate for all states"
-            return 0
-
-    def start_value(self):
+    def start_value_and_progress(self):
         v = 0
+        p = 0
         for full_state, prob in self.model.start():
             state, state_hash = self.state_and_hash_of_full_state(full_state)
             v += prob * state.value
-        return v
+            p += prob * state.progress
+        return v, p
 
     def step(self):
         full_state = self.full_state
@@ -146,17 +142,23 @@ def step(self):
         # consider all available actions, tracking ...
         max_i = 0  # index of best action
         max_q = 0  # value of best action
+        max_p = 0  # progress of best action
         for i, transitions in enumerate(actions):
             q = 0  # action value estimate
+            p = 0
             for t in transitions:
-                q += t.probability * (t.reward + self.state_hash_value(t.state))
+                to_state = self.states[t.state]
+                q += t.probability * (t.reward + to_state.value)
+                p += t.probability * (t.progress + to_state.progress)
             if q > max_q:
                 max_i = i
                 max_q = q
+                max_p = p
 
-        # update state-value estimate
+        # update state-value and progress estimate
         state.value = max_q
+        state.progress = max_p
 
         # exploring starts heuristic:
         # we try to record such states that have better than honest value