
improve exploring starts logic
pkel committed May 12, 2024
1 parent b3d63f1 commit 1f0fbb4
Showing 4 changed files with 96 additions and 40 deletions.
16 changes: 15 additions & 1 deletion mdp/fc16sapirshtein.py
@@ -179,9 +179,23 @@ def honest(self, s: BState) -> list[Action]:
def shutdown(self, s: BState) -> list[Transition]:
# Rewards and progress are calculated on the common chain. Terminating with
# a no-op is already fair.
return [Transition(state=s, probability=1, reward=0, progress=0)]
# return [Transition(state=s, probability=1, reward=0, progress=0)]
# NOTE In principle, we could perform a full release here and award it, but
# this would change the model. Maybe evaluate this separately.
snew = BState(a=0, h=0, fork=IRRELEVANT)
if s.h > s.a:
return [Transition(state=snew, probability=1, reward=0, progress=s.h)]
if s.a > s.h:
return [Transition(state=snew, probability=1, reward=s.a, progress=s.a)]
if s.a == s.h:
return [
Transition(
state=snew, probability=self.gamma, reward=s.a, progress=s.a
),
Transition(
state=snew, probability=1 - self.gamma, reward=0, progress=s.h
),
]


mappable_params = dict(alpha=0.125, gamma=0.25)
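
For intuition, here is a minimal standalone sketch of the new termination rule (not part of the repository; a, h, and gamma mirror the BState fields and match rate used above, the function name is illustrative):

def terminal_payoff(a: int, h: int, gamma: float) -> list[tuple[float, int, int]]:
    # Returns (probability, reward, progress) outcomes when the game is cut off.
    if h > a:
        return [(1.0, 0, h)]  # defender chain is longer: no attacker reward
    if a > h:
        return [(1.0, a, a)]  # withheld attacker chain is longer and gets rewarded
    # tie: resolved in the attacker's favour with probability gamma
    return [(gamma, a, a), (1.0 - gamma, 0, h)]

# e.g. terminal_payoff(2, 2, 0.25) == [(0.25, 2, 2), (0.75, 0, 2)]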
54 changes: 39 additions & 15 deletions mdp/measure-rtdp.py
@@ -7,6 +7,7 @@
from sm import SelfishMining

# solving algorithm
from model import PTO_wrapper
from rtdp import RTDP
from compiler import Compiler

@@ -27,17 +28,17 @@
dict(alpha=1 / 4, gamma=1 / 4, attacker="weak"),
dict(alpha=1 / 3, gamma=1 / 3, attacker="intermediate"),
dict(
alpha=0.45, gamma=0.90, attacker="strong"
alpha=0.42, gamma=0.82, attacker="strong"
), # TODO double check whether we can do 1/2
]

rows = [
dict(row=1, protocol="bitcoin", model="fc16", trunc=40, algo="aft20", ref=1),
dict(row=2, protocol="bitcoin", model="aft20", trunc=40, algo="aft20", ref=1),
# dict(row=2, protocol="bitcoin", model="aft20", trunc=40, algo="aft20", ref=1),
dict(row=3, protocol="bitcoin", model="fc16", trunc=40, algo="rtdp", ref=1),
dict(row=4, protocol="bitcoin", model="aft20", trunc=40, algo="rtdp", ref=1),
dict(row=5, protocol="bitcoin", model="fc16", trunc=0, algo="rtdp", ref=1),
dict(row=6, protocol="bitcoin", model="aft20", trunc=0, algo="rtdp", ref=1),
# dict(row=4, protocol="bitcoin", model="aft20", trunc=40, algo="rtdp", ref=1),
# dict(row=5, protocol="bitcoin", model="fc16", trunc=0, algo="rtdp", ref=1),
# dict(row=6, protocol="bitcoin", model="aft20", trunc=0, algo="rtdp", ref=1),
# dict(row=7, protocol="bitcoin", model="generic", trunc=10, algo="aft20", ref=1),
# dict(row=8, protocol="bitcoin", model="generic", trunc=10, algo="rtdp", ref=1),
# dict(row=9, protocol="bitcoin", model="generic", trunc=0, algo="rtdp", ref=5),
@@ -47,7 +48,7 @@
# Algorithms


def post_algo(mdp, policy, value_estimate, start_value, start_progress):
def post_algo(mdp, policy, value_estimate, start_value, start_progress, **kwargs):
value_estimate = numpy.array(value_estimate)

# Steady States: I thought it would be cool to report on steady states
@@ -70,15 +71,15 @@ def post_algo(mdp, policy, value_estimate, start_value, start_progress):
mdp_n_states=mdp.n_states,
mdp_n_transitions=mdp.n_transitions,
pimc_n_states=pimc["prb"].get_shape()[0],
**kwargs,
)


def algo_aft20(implicit_mdp, *args, horizon, vi_delta, **kwargs):
# Compile Full MDP
mdp = Compiler(implicit_mdp).mdp()
implicit_ptmdp = PTO_wrapper(implicit_mdp, horizon=horizon, terminal_state=b"")

# Derive PTO MDP
mdp = aft20barzur.ptmdp(mdp, horizon=horizon)
# Compile Full MDP
mdp = Compiler(implicit_ptmdp).mdp()

# Solve PTO MDP
vi = mdp.value_iteration(stop_delta=vi_delta, eps=None, discount=1)
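
With this reordering, probabilistic termination is applied to the implicit model before compilation instead of to the compiled MDP. As a rough, hypothetical illustration of the PTO idea only (the actual PTO_wrapper lives in model.py and its details may differ), a transition that makes progress can be split into a terminating and a continuing branch:

from model import Transition

def split_for_pto(t: Transition, horizon: int, terminal_state=b""):
    # Hypothetical: terminate with probability 1/horizon per unit of progress.
    p_stop = 1.0 - (1.0 - 1.0 / horizon) ** t.progress
    go = Transition(state=t.state, probability=t.probability * (1.0 - p_stop),
                    reward=t.reward, progress=t.progress)
    if p_stop == 0.0:
        return [go]
    stop = Transition(state=terminal_state, probability=t.probability * p_stop,
                      reward=t.reward, progress=t.progress)
    return [stop, go]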
@@ -94,15 +94,38 @@ def algo_aft20(implicit_mdp, *args, horizon, vi_delta, **kwargs):


def algo_rtdp(implicit_mdp, *args, horizon, rtdp_steps, rtdp_eps, rtdp_es, **kwargs):
agent = RTDP(implicit_mdp, eps=rtdp_eps, eps_honest=0, es=rtdp_es, horizon=horizon)
implicit_ptmdp = PTO_wrapper(implicit_mdp, horizon=horizon, terminal_state=b"")

agent = RTDP(implicit_ptmdp, eps=rtdp_eps, eps_honest=0, es=rtdp_es)

for i in range(rtdp_steps):
log = []

i = 0
j = 0
while i < rtdp_steps:
i += 1
agent.step()

# logging
j += 1
if j >= 1000:
j = 0
sv, sp = agent.start_value_and_progress()
log.append(
dict(
step=i,
start_value=sv,
start_progress=sp,
n_states=len(agent.states),
)
)

m = agent.mdp()
start_value, start_progress = agent.start_value_and_progress()

return post_algo(m["mdp"], m["policy"], m["value"], start_value, start_progress)
return post_algo(
m["mdp"], m["policy"], m["value"], start_value, start_progress, log=log
)
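
The log list collects one record per 1000 RTDP steps (step, start_value, start_progress, n_states). As an optional offline sketch for eyeballing convergence of one such list (pandas is an assumption here; this diff does not show it being used in the repository):

import pandas

df = pandas.DataFrame(log_records)  # log_records: one `log` list as returned via post_algo
print(df.tail())  # has start_value / start_progress stopped moving?
print(df["n_states"].iloc[-1], "states explored")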


# How do we instantiate the models and run the algo?
@@ -141,10 +165,10 @@ def implicit_mdp(*args, model, protocol, trunc, alpha, gamma, **kwargs):
argp = argparse.ArgumentParser()
argp.add_argument("-j", "--n_jobs", type=int, default=1, metavar="INT")
argp.add_argument("-H", "--horizon", type=int, default=30, metavar="INT")
argp.add_argument("--rtdp_eps", type=float, default=0.1, metavar="FLOAT")
argp.add_argument("--rtdp_eps", type=float, default=0.2, metavar="FLOAT")
argp.add_argument("--rtdp_es", type=float, default=0.9, metavar="FLOAT")
argp.add_argument("--rtdp_steps", type=int, default=50_000, metavar="INT")
argp.add_argument("--vi_delta", type=float, default=0.01, metavar="FLOAT")
argp.add_argument("--vi_delta", type=float, default=0.001, metavar="FLOAT")
args = argp.parse_args()

# Single measurement
59 changes: 37 additions & 22 deletions mdp/rtdp.py
@@ -1,7 +1,6 @@
from collections import deque
import mdp
from mdp import sum_to_one
from model import Model, PTO_wrapper, Transition
from model import Model, Transition
import random
import xxhash
import sys
@@ -29,6 +28,7 @@ def __init__(self):
self.value = 0 # estimate of future rewards
self.progress = 0 # estimate of future progress
self.count = 0
self.es_last_seen = -1
self.actions = None # action idx -> state hash transition list
self.honest = None # honest action id

@@ -38,16 +38,12 @@ def __init__(
self,
model: Model,
*args,
horizon: int,
eps: float,
eps_honest: float = 0,
es: float = 0
es: float = 0,
es_threshold=500_000,
):
assert horizon > 0

model = PTO_wrapper(model, horizon=horizon, terminal_state=b"")
self.model = model
self.horizon = horizon

self.set_exploration(eps=eps, eps_honest=eps_honest, es=es)

@@ -64,20 +60,23 @@ def __init__(
# We overcome this by maintaining a set of good full states which are
# worth using as starting states. What is a good state? We just use the
# set of recently visited states.
self.exploring_starts = deque(maxlen=100 * horizon) # full states
self.es_buf = dict() # state hash -> full state
self.es_threshold = es_threshold

self.i = 0

# start states
self.start_states = list() # list[tuple[float, hash, full_state]]
for full_state, prob in self.model.start():
state, state_hash = self.state_and_hash_of_full_state(full_state)
self.start_states.append((prob, state_hash, full_state))
self.start_states.append((prob, state_hash, full_state, state))

# init state & state_id
self.start_new_episode()

# statistics
self.n_episodes = 0
self.progress_gamma999 = horizon
self.progress_gamma999 = 0
self.exploration_gamma9999 = 1
self.state_size_gamma9999 = 0
self.n_states_visited = 0
@@ -87,13 +86,23 @@ def start_new_episode(self):
self.episode_progress = 0 # statistics

# Barto and Sutton's "exploring starts"
if self.es > 0 and len(self.exploring_starts) > 0:
if self.es > 0:
if random.random() < self.es:
self.full_state = random.choice(self.exploring_starts)
return
candidates = []
for state_hash, state in self.states.items():
if state.es_last_seen < 1:
continue
if self.i - state.es_last_seen < self.es_threshold:
candidates.append(self.es_buf[state_hash])
else:
# We won't need these anymore
self.es_buf.pop(state_hash, None)
if len(candidates) > 0:
self.set_full_state(random.choice(candidates))
return

# start from an actual start state otherwise
self.full_state = sample(self.start_states, lambda x: x[0])[2]
self.set_full_state(sample(self.start_states, lambda x: x[0])[2])

def reset(self):
self.n_episodes += 1
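
In short, the commit replaces the fixed-size deque of recent states with an age-based buffer keyed by state hash. A condensed, standalone restatement of the selection rule implemented in start_new_episode above (simplified; variable names follow the agent's fields):

import random

def pick_exploring_start(states, es_buf, now, es_threshold):
    # Return a recently visited full state to restart from, or None.
    candidates = []
    for state_hash, state in states.items():
        if state.es_last_seen < 1:
            continue  # never reached via a greedy transition
        if now - state.es_last_seen < es_threshold:
            candidates.append(es_buf[state_hash])
        else:
            es_buf.pop(state_hash, None)  # stale; drop its cached full state
    return random.choice(candidates) if candidates else None

When this yields nothing, the episode falls back to a genuine start state, as in the last line of start_new_episode.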
@@ -118,15 +127,19 @@ def set_exploration(self, *args, eps=None, eps_honest=None, es=None):
def start_value_and_progress(self):
v = 0
p = 0
for full_state, prob in self.model.start():
state, state_hash = self.state_and_hash_of_full_state(full_state)
for prob, _hash, _full, state in self.start_states:
v += prob * state.value
p += prob * state.progress
return v, p

def set_full_state(self, full_state):
self.full_state = full_state
self.state, self.state_hash = self.state_and_hash_of_full_state(full_state)

def step(self):
self.i += 1
full_state = self.full_state
state, state_hash = self.state_and_hash_of_full_state(full_state)
state = self.state

# ## Statistics

@@ -152,6 +165,7 @@ def step(self):
# no action available, terminal state
self.reset()
assert state.value == 0
assert state.progress == 0
return

# value iteration step:
@@ -197,11 +211,12 @@ def step(self):
a = self.model.actions(full_state)[i]
to = sample(self.model.apply(a, full_state), lambda x: x.probability)
self.episode_progress += to.progress # statistics
self.full_state = to.state
self.set_full_state(to.state)

# exploring starts
# ## Exploring starts
if greedy:
self.exploring_starts.append(to.state)
self.state.es_last_seen = self.i + 1
self.es_buf[self.state_hash] = self.full_state

def actions(self, state, full_state):
if state.actions is not None:
@@ -348,7 +363,7 @@ def mdp(self):

# mdp: set start states
assert len(m.start) == 0
for prob, state_hash, full_state in self.start_states:
for prob, state_hash, full_state, state in self.start_states:
m.start[state_id[state_hash]] = prob

assert m.check()
7 changes: 5 additions & 2 deletions mdp/rtdp_test.py
@@ -1,4 +1,5 @@
import aft20barzur
from model import PTO_wrapper
from rtdp import RTDP
import pprint
import psutil
@@ -21,10 +22,12 @@ def rtdp(
honest_warmup_steps=0,
**kwargs
):
model = PTO_wrapper(model, horizon=horizon, terminal_state=b"")

if honest_warmup_steps > 0:
agent = RTDP(model, eps=0, eps_honest=1, horizon=horizon, **kwargs)
agent = RTDP(model, eps=0, eps_honest=1, **kwargs)
else:
agent = RTDP(model, eps=eps, eps_honest=eps_honest, horizon=horizon, **kwargs)
agent = RTDP(model, eps=eps, eps_honest=eps_honest, **kwargs)

max_start_value = 0

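
With the horizon folded into the wrapper, driving the agent end to end now looks roughly as follows (values mirror the defaults in measure-rtdp.py; construction of the underlying model is elided):

ptmdp = PTO_wrapper(model, horizon=30, terminal_state=b"")
agent = RTDP(ptmdp, eps=0.2, eps_honest=0, es=0.9)
for _ in range(50_000):
    agent.step()
start_value, start_progress = agent.start_value_and_progress()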
