From 8abf0d155c723bad306b8c606776f5dbd60d700f Mon Sep 17 00:00:00 2001 From: Patrik Keller Date: Wed, 10 Apr 2024 19:40:51 +0200 Subject: [PATCH] Draft fc16-like observation space for generic model --- gym/rust/cpr_gym_rs/__init__.py | 16 ++--- gym/rust/cpr_gym_rs/envs.py | 8 +-- gym/rust/hyperparams/dqn.yml | 4 +- gym/rust/hyperparams/ppo.yml | 8 +-- gym/rust/src/generic/intf.rs | 13 ++++ gym/rust/src/generic/mod.rs | 104 +++++++++++++------------------- gym/rust/src/lib.rs | 6 +- gym/rust/src/proto/nakamoto.rs | 22 +++++++ gym/rust/test/test_rust.py | 2 +- 9 files changed, 98 insertions(+), 85 deletions(-) diff --git a/gym/rust/cpr_gym_rs/__init__.py b/gym/rust/cpr_gym_rs/__init__.py index 8730f944..5998a552 100644 --- a/gym/rust/cpr_gym_rs/__init__.py +++ b/gym/rust/cpr_gym_rs/__init__.py @@ -1,17 +1,19 @@ from . import _rust from gymnasium.envs.registration import register +default_assumptions = dict( + alpha=0.42, + gamma=0.84, + horizon=100, +) + register( id="FC16SSZwPT-v0", entry_point="cpr_gym_rs.envs:FC16SSZwPT", nondeterministic=True, order_enforce=False, max_episode_steps=1000, # forced termination, not Markovian - kwargs=dict( - alpha=1 / 3, - gamma=0.5, - horizon=100, - ), + kwargs=default_assumptions, ) register( @@ -20,7 +22,5 @@ nondeterministic=True, order_enforce=False, max_episode_steps=1000, # forced termination, not Markovian - kwargs=dict( - protocol="nakamoto", alpha=1 / 3, gamma=0.5, horizon=25, max_blocks=128 - ), + kwargs=dict(protocol="nakamoto") | default_assumptions, ) diff --git a/gym/rust/cpr_gym_rs/envs.py b/gym/rust/cpr_gym_rs/envs.py index 92f88b83..96a695e3 100644 --- a/gym/rust/cpr_gym_rs/envs.py +++ b/gym/rust/cpr_gym_rs/envs.py @@ -22,12 +22,10 @@ def step(self, action): class Generic(gymnasium.Env): protocols = {"nakamoto": lambda: _rust.Protocol.Nakamoto} - def __init__(self, protocol, *args, alpha, gamma, horizon, max_blocks, **kwargs): + def __init__(self, protocol, *args, alpha, gamma, horizon, **kwargs): p = Generic.protocols[protocol](*args, **kwargs) - self.rs_env = _rust.GenericEnv( - p, alpha=alpha, gamma=gamma, horizon=horizon, max_blocks=max_blocks - ) + self.rs_env = _rust.GenericEnv(p, alpha=alpha, gamma=gamma, horizon=horizon) self.action_space = gymnasium.spaces.Discrete( 255, start=-127 @@ -37,7 +35,7 @@ def __init__(self, protocol, *args, alpha, gamma, horizon, max_blocks, **kwargs) obs, _info = self.rs_env.reset() self.observation_space = gymnasium.spaces.Box( - shape=obs.shape, low=0, high=2, dtype=numpy.uint8 + shape=obs.shape, low=float("-inf"), high=float("inf"), dtype=numpy.float32 ) def reset(self, *args, seed=None, options=None): diff --git a/gym/rust/hyperparams/dqn.yml b/gym/rust/hyperparams/dqn.yml index bfa5fdb7..b63d242b 100644 --- a/gym/rust/hyperparams/dqn.yml +++ b/gym/rust/hyperparams/dqn.yml @@ -39,12 +39,12 @@ FC16SSZwPT-v0: Nakamoto-v0: n_envs: 24 n_timesteps: 50_000_000 - buffer_size: 1_000_000 # how many experiences recorded for sampling + buffer_size: 5_000_000 # how many experiences recorded for sampling train_freq: 10_000 # how regularly to update / length of rollout in steps gradient_steps: 100 # how many updates per rollout (at the end of each rollout) learning_rate: 0.001 # weight of individual update batch_size: 500 # how many experiences to sample for each update - learning_starts: 100_000 # steps before learning starts + learning_starts: 500_000 # steps before learning starts gamma: 1 # discount factor target_update_interval: 10_000 # how often to update target network (steps) tau: 0.01 # weight of target network update; 1 implies hard update diff --git a/gym/rust/hyperparams/ppo.yml b/gym/rust/hyperparams/ppo.yml index db346da7..9ab17369 100644 --- a/gym/rust/hyperparams/ppo.yml +++ b/gym/rust/hyperparams/ppo.yml @@ -10,10 +10,10 @@ FC16SSZwPT-v0: Nakamoto-v0: n_timesteps: 4_000_000 - n_envs: 12 # = size of rollout buffer / n_steps - n_steps: 128 # = size of rollout buffer / n_envs + n_envs: 32 # = size of rollout buffer / n_steps + n_steps: 8192 # = size of rollout buffer / n_envs learning_rate: 0.0001 # weight of each update - batch_size: 128 # number of steps to consider per update - n_epochs: 1 # how often to process each rollout buffer + batch_size: 512 # number of steps to consider per update + n_epochs: 10 # how often to process each rollout buffer ent_coef: 0.05 # entropy, exploration term policy: 'MlpPolicy' diff --git a/gym/rust/src/generic/intf.rs b/gym/rust/src/generic/intf.rs index 0077f800..75bff18e 100644 --- a/gym/rust/src/generic/intf.rs +++ b/gym/rust/src/generic/intf.rs @@ -52,3 +52,16 @@ pub trait Protocol { // Given block `b`, what rewards where allocated since `pred(b)`? // Block `b` is guaranteed to be part of a linear history of some tip. } + +pub trait FeatureExtractor { + // Interface for specifying protocol-dependent feature extractors (WIP) + + // BlockDAG with attacker and defender entrypoints and common ancestor in linear history. + fn observe>( + &self, + d: &DAG, + atk: Block, + def: Block, + ca: Block, + ) -> Vec; +} diff --git a/gym/rust/src/generic/mod.rs b/gym/rust/src/generic/mod.rs index 7bdca77b..f0ca2a78 100644 --- a/gym/rust/src/generic/mod.rs +++ b/gym/rust/src/generic/mod.rs @@ -60,7 +60,7 @@ type EdgeWeight = (); type Graph = petgraph::graph::DiGraph, EdgeWeight>; pub mod intf; -use intf::{BlockDAG, Protocol}; +use intf::{BlockDAG, FeatureExtractor, Protocol}; impl BlockDAG for Graph { fn parents(&self, b: Block) -> Vec { @@ -302,23 +302,23 @@ use rand::{Rng, SeedableRng}; // environment logic -use numpy::ndarray::Array2; - -pub struct Env +pub struct Env where P: Protocol, + O: FeatureExtractor, { g: Graph, p: P, + o: O, a: AvailableActions, alpha: f32, gamma: f32, horizon: f32, rng: StdRng, - pub obs: Array2, + ca: Block, // cached common ancestor } -fn init_graph(g: &mut Graph, p: &P) +fn init_graph(g: &mut Graph, p: &P) -> Block where P: Protocol, { @@ -328,12 +328,13 @@ where nv: NView::Honest, pd: p.init(), }; - g.add_node(genesis); + g.add_node(genesis) } -impl BlockDAG for Env +impl BlockDAG for Env where P: Protocol, + O: FeatureExtractor, { fn parents(&self, b: Block) -> Vec { self.g.parents(b) @@ -349,44 +350,41 @@ where } } -impl Env +impl Env where P: Protocol, + O: FeatureExtractor, { - pub fn new(p: P, alpha: f32, gamma: f32, horizon: f32, max_blocks: usize) -> Self { + pub fn new(p: P, o: O, alpha: f32, gamma: f32, horizon: f32) -> Self { let mut g = Graph::new(); - init_graph(&mut g, &p); + let genesis = init_graph(&mut g, &p); let a = available_actions(&g); - let mut self_ = Env { + Env { g, p, + o, a, alpha, gamma, horizon, rng: StdRng::from_entropy(), - obs: Array2::zeros((max_blocks, max_blocks + 3)), - }; - self_.observe(); - self_ + ca: genesis, + } } pub fn describe(&self) -> String { format!( - "Generic {{ alpha: {}, gamma: {}, horizon: {:?}, obs_bytes: {} }}", - self.alpha, - self.gamma, - self.horizon, - self.obs.len() + "Generic {{ alpha: {}, gamma: {}, horizon: {:?} }}", + self.alpha, self.gamma, self.horizon, ) // TODO: possible actions } fn reset(&mut self) { self.g.clear(); - init_graph(&mut self.g, &self.p); + let genesis = init_graph(&mut self.g, &self.p); self.a = available_actions(&self.g); - self.observe(); + self.ca = genesis; } pub fn describe_action(&self, a: Action) -> String { @@ -409,8 +407,11 @@ where } fn step(&mut self, a: Action) -> (f64, bool, bool) { + // check the dag invariants + assert!(dag_check(&self.g)); + // derive pre-action state - let old_ca = *self.common_history().last().unwrap(); // TODO/perf: persist in self + let old_ca = self.ca; // decode action & apply match decode_action(self.guarded_action(a)) { @@ -448,12 +449,13 @@ where // progress := blocks mined or blocks on defender chain // reward := number of blocks rewritten in defender chain let h = self.common_history(); + let (terminate, progress, reward); { let mut prg = 0.; let mut rew_atk = 0.; let mut rew_def = 0.; - for b in h.into_iter().rev() { + for &b in h.iter().rev() { if b == old_ca { break; } @@ -472,11 +474,11 @@ where terminate = self.env_terminates(progress); } - // update observation buffer - self.observe(); + // cache common ancestor; observation and next step rely on this + self.ca = *h.last().unwrap(); - // truncate when observation buffer reaches limit - let truncate = self.g.node_count() >= self.obs.dim().0; + // we do not need truncation + let truncate = false; // return (::try_from(reward).unwrap(), terminate, truncate) @@ -490,28 +492,6 @@ where self.g.node_weight_mut(b).unwrap() } - fn observe(&mut self) { - // observation is triggered often; check the dag invariants here - assert!(dag_check(&self.g)); - // reset buffer - self.obs.fill(0); - // fill buffer - for src in self.g.node_indices() { - // adjacency - for dst in self.parents(src) { - self.obs[[src.index(), dst.index() + 3]] = 1; - } - // color / weight - let w = self.weight(src); - let av = w.av as u8; - let dv = w.dv as u8; - let nv = w.nv as u8; - self.obs[[src.index(), 0]] = av; - self.obs[[src.index(), 1]] = dv; - self.obs[[src.index(), 2]] = nv; - } - } - fn entrypoint(&self, m: Party) -> Block { if m == Party::Defender { self.g @@ -737,14 +717,21 @@ use numpy::IntoPyArray; use pyo3::prelude::*; use std::collections::HashMap; -impl Env +impl Env where P: Protocol, + O: FeatureExtractor, { + fn py_observe(&self, py: Python) -> PyObject { + let atk = self.entrypoint(Party::Attacker); + let def = self.entrypoint(Party::Defender); + let obs: Vec = self.o.observe(&self.g, atk, def, self.ca); + obs.into_pyarray(py).into() + } + pub fn py_reset(&mut self, py: Python) -> (PyObject, HashMap) { self.reset(); - // TODO/perf avoid obs cloning? - (self.obs.clone().into_pyarray(py).into(), HashMap::new()) + (self.py_observe(py), HashMap::new()) } pub fn py_step( @@ -753,13 +740,6 @@ where a: Action, ) -> (PyObject, f64, bool, bool, HashMap) { let (rew, term, trunc) = self.step(a); - // TODO/perf avoid obs cloning? - ( - self.obs.clone().into_pyarray(py).into(), - rew, - term, - trunc, - HashMap::new(), - ) + (self.py_observe(py), rew, term, trunc, HashMap::new()) } } diff --git a/gym/rust/src/lib.rs b/gym/rust/src/lib.rs index e81981ae..aea10dbb 100644 --- a/gym/rust/src/lib.rs +++ b/gym/rust/src/lib.rs @@ -19,7 +19,7 @@ enum Protocol { // investigating this is currently not worth the effort. enum BoxedEnv { - Nakamoto(Env), + Nakamoto(Env), } // Boxed here means boxed type, not boxed data. @@ -31,15 +31,15 @@ struct GenericEnv { #[pymethods] impl GenericEnv { #[new] - fn new(p: Protocol, alpha: f32, gamma: f32, horizon: f32, max_blocks: usize) -> Self { + fn new(p: Protocol, alpha: f32, gamma: f32, horizon: f32) -> Self { match p { Protocol::Nakamoto => GenericEnv { env: BoxedEnv::Nakamoto(Env::new( nakamoto::Protocol {}, + nakamoto::BaseObserver {}, alpha, gamma, horizon, - max_blocks, )), }, } diff --git a/gym/rust/src/proto/nakamoto.rs b/gym/rust/src/proto/nakamoto.rs index c1e70857..107a8d13 100644 --- a/gym/rust/src/proto/nakamoto.rs +++ b/gym/rust/src/proto/nakamoto.rs @@ -53,3 +53,25 @@ where vec![(d.miner(b), 1.)] } } + +pub struct BaseObserver {} + +impl intf::FeatureExtractor for BaseObserver +where + Block: Copy, + Miner: Copy, +{ + fn observe>( + &self, + d: &DAG, + atk: Block, + def: Block, + ca: Block, + ) -> Vec { + // TODO derive third observable of fc16.rs + let ch = d.data(ca).height; // height common ancestor + let ah = d.data(atk).height - ch; // height attacker preferred block + let dh = d.data(def).height - ch; // height defender preferred block + vec![ah as f32, dh as f32] + } +} diff --git a/gym/rust/test/test_rust.py b/gym/rust/test/test_rust.py index 437866dd..5d11afd3 100644 --- a/gym/rust/test/test_rust.py +++ b/gym/rust/test/test_rust.py @@ -16,7 +16,7 @@ def test_FC16SSZwPT(): def test_nakamoto(): proto = _rust.Protocol.Nakamoto - env = _rust.GenericEnv(proto, alpha=0.5, gamma=0.5, horizon=25, max_blocks=128) + env = _rust.GenericEnv(proto, alpha=0.5, gamma=0.5, horizon=25) for _ in range(1000): mina, maxa = env.action_range() a = randint(mina, maxa)