Skip to content

Commit

Permalink
Draft fc16-like observation space for generic model
Browse files Browse the repository at this point in the history
  • Loading branch information
pkel committed Apr 10, 2024
1 parent cf85fd8 commit 8abf0d1
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 85 deletions.
16 changes: 8 additions & 8 deletions gym/rust/cpr_gym_rs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from . import _rust
from gymnasium.envs.registration import register

default_assumptions = dict(
alpha=0.42,
gamma=0.84,
horizon=100,
)

register(
id="FC16SSZwPT-v0",
entry_point="cpr_gym_rs.envs:FC16SSZwPT",
nondeterministic=True,
order_enforce=False,
max_episode_steps=1000, # forced termination, not Markovian
kwargs=dict(
alpha=1 / 3,
gamma=0.5,
horizon=100,
),
kwargs=default_assumptions,
)

register(
Expand All @@ -20,7 +22,5 @@
nondeterministic=True,
order_enforce=False,
max_episode_steps=1000, # forced termination, not Markovian
kwargs=dict(
protocol="nakamoto", alpha=1 / 3, gamma=0.5, horizon=25, max_blocks=128
),
kwargs=dict(protocol="nakamoto") | default_assumptions,
)
8 changes: 3 additions & 5 deletions gym/rust/cpr_gym_rs/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ def step(self, action):
class Generic(gymnasium.Env):
protocols = {"nakamoto": lambda: _rust.Protocol.Nakamoto}

def __init__(self, protocol, *args, alpha, gamma, horizon, max_blocks, **kwargs):
def __init__(self, protocol, *args, alpha, gamma, horizon, **kwargs):
p = Generic.protocols[protocol](*args, **kwargs)

self.rs_env = _rust.GenericEnv(
p, alpha=alpha, gamma=gamma, horizon=horizon, max_blocks=max_blocks
)
self.rs_env = _rust.GenericEnv(p, alpha=alpha, gamma=gamma, horizon=horizon)

self.action_space = gymnasium.spaces.Discrete(
255, start=-127
Expand All @@ -37,7 +35,7 @@ def __init__(self, protocol, *args, alpha, gamma, horizon, max_blocks, **kwargs)
obs, _info = self.rs_env.reset()

self.observation_space = gymnasium.spaces.Box(
shape=obs.shape, low=0, high=2, dtype=numpy.uint8
shape=obs.shape, low=float("-inf"), high=float("inf"), dtype=numpy.float32
)

def reset(self, *args, seed=None, options=None):
Expand Down
4 changes: 2 additions & 2 deletions gym/rust/hyperparams/dqn.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ FC16SSZwPT-v0:
Nakamoto-v0:
n_envs: 24
n_timesteps: 50_000_000
buffer_size: 1_000_000 # how many experiences recorded for sampling
buffer_size: 5_000_000 # how many experiences recorded for sampling
train_freq: 10_000 # how regularly to update / length of rollout in steps
gradient_steps: 100 # how many updates per rollout (at the end of each rollout)
learning_rate: 0.001 # weight of individual update
batch_size: 500 # how many experiences to sample for each update
learning_starts: 100_000 # steps before learning starts
learning_starts: 500_000 # steps before learning starts
gamma: 1 # discount factor
target_update_interval: 10_000 # how often to update target network (steps)
tau: 0.01 # weight of target network update; 1 implies hard update
Expand Down
8 changes: 4 additions & 4 deletions gym/rust/hyperparams/ppo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ FC16SSZwPT-v0:

Nakamoto-v0:
n_timesteps: 4_000_000
n_envs: 12 # = size of rollout buffer / n_steps
n_steps: 128 # = size of rollout buffer / n_envs
n_envs: 32 # = size of rollout buffer / n_steps
n_steps: 8192 # = size of rollout buffer / n_envs
learning_rate: 0.0001 # weight of each update
batch_size: 128 # number of steps to consider per update
n_epochs: 1 # how often to process each rollout buffer
batch_size: 512 # number of steps to consider per update
n_epochs: 10 # how often to process each rollout buffer
ent_coef: 0.05 # entropy, exploration term
policy: 'MlpPolicy'
13 changes: 13 additions & 0 deletions gym/rust/src/generic/intf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@ pub trait Protocol<Block, Miner, Data> {
// Given block `b`, what rewards where allocated since `pred(b)`?
// Block `b` is guaranteed to be part of a linear history of some tip.
}

pub trait FeatureExtractor<Block, Miner, Data> {
// Interface for specifying protocol-dependent feature extractors (WIP)

// BlockDAG with attacker and defender entrypoints and common ancestor in linear history.
fn observe<DAG: BlockDAG<Block, Miner, Data>>(
&self,
d: &DAG,
atk: Block,
def: Block,
ca: Block,
) -> Vec<f32>;
}
104 changes: 42 additions & 62 deletions gym/rust/src/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type EdgeWeight = ();
type Graph<ProtoData> = petgraph::graph::DiGraph<NodeWeight<ProtoData>, EdgeWeight>;

pub mod intf;
use intf::{BlockDAG, Protocol};
use intf::{BlockDAG, FeatureExtractor, Protocol};

impl<ProtoData> BlockDAG<Block, Party, ProtoData> for Graph<ProtoData> {
fn parents(&self, b: Block) -> Vec<Block> {
Expand Down Expand Up @@ -302,23 +302,23 @@ use rand::{Rng, SeedableRng};

// environment logic

use numpy::ndarray::Array2;

pub struct Env<P, D>
pub struct Env<P, D, O>
where
P: Protocol<Block, Party, D>,
O: FeatureExtractor<Block, Party, D>,
{
g: Graph<D>,
p: P,
o: O,
a: AvailableActions,
alpha: f32,
gamma: f32,
horizon: f32,
rng: StdRng,
pub obs: Array2<u8>,
ca: Block, // cached common ancestor
}

fn init_graph<P, D>(g: &mut Graph<D>, p: &P)
fn init_graph<P, D>(g: &mut Graph<D>, p: &P) -> Block
where
P: Protocol<Block, Party, D>,
{
Expand All @@ -328,12 +328,13 @@ where
nv: NView::Honest,
pd: p.init(),
};
g.add_node(genesis);
g.add_node(genesis)
}

impl<P, D> BlockDAG<Block, Party, D> for Env<P, D>
impl<P, D, O> BlockDAG<Block, Party, D> for Env<P, D, O>
where
P: Protocol<Block, Party, D>,
O: FeatureExtractor<Block, Party, D>,
{
fn parents(&self, b: Block) -> Vec<Block> {
self.g.parents(b)
Expand All @@ -349,44 +350,41 @@ where
}
}

impl<P, D> Env<P, D>
impl<P, D, O> Env<P, D, O>
where
P: Protocol<Block, Party, D>,
O: FeatureExtractor<Block, Party, D>,
{
pub fn new(p: P, alpha: f32, gamma: f32, horizon: f32, max_blocks: usize) -> Self {
pub fn new(p: P, o: O, alpha: f32, gamma: f32, horizon: f32) -> Self {
let mut g = Graph::new();
init_graph(&mut g, &p);
let genesis = init_graph(&mut g, &p);
let a = available_actions(&g);
let mut self_ = Env {
Env {
g,
p,
o,
a,
alpha,
gamma,
horizon,
rng: StdRng::from_entropy(),
obs: Array2::zeros((max_blocks, max_blocks + 3)),
};
self_.observe();
self_
ca: genesis,
}
}

pub fn describe(&self) -> String {
format!(
"Generic {{ alpha: {}, gamma: {}, horizon: {:?}, obs_bytes: {} }}",
self.alpha,
self.gamma,
self.horizon,
self.obs.len()
"Generic {{ alpha: {}, gamma: {}, horizon: {:?} }}",
self.alpha, self.gamma, self.horizon,
)
// TODO: possible actions
}

fn reset(&mut self) {
self.g.clear();
init_graph(&mut self.g, &self.p);
let genesis = init_graph(&mut self.g, &self.p);
self.a = available_actions(&self.g);
self.observe();
self.ca = genesis;
}

pub fn describe_action(&self, a: Action) -> String {
Expand All @@ -409,8 +407,11 @@ where
}

fn step(&mut self, a: Action) -> (f64, bool, bool) {
// check the dag invariants
assert!(dag_check(&self.g));

// derive pre-action state
let old_ca = *self.common_history().last().unwrap(); // TODO/perf: persist in self
let old_ca = self.ca;

// decode action & apply
match decode_action(self.guarded_action(a)) {
Expand Down Expand Up @@ -448,12 +449,13 @@ where
// progress := blocks mined or blocks on defender chain
// reward := number of blocks rewritten in defender chain
let h = self.common_history();

let (terminate, progress, reward);
{
let mut prg = 0.;
let mut rew_atk = 0.;
let mut rew_def = 0.;
for b in h.into_iter().rev() {
for &b in h.iter().rev() {
if b == old_ca {
break;
}
Expand All @@ -472,11 +474,11 @@ where
terminate = self.env_terminates(progress);
}

// update observation buffer
self.observe();
// cache common ancestor; observation and next step rely on this
self.ca = *h.last().unwrap();

// truncate when observation buffer reaches limit
let truncate = self.g.node_count() >= self.obs.dim().0;
// we do not need truncation
let truncate = false;

// return
(<f64>::try_from(reward).unwrap(), terminate, truncate)
Expand All @@ -490,28 +492,6 @@ where
self.g.node_weight_mut(b).unwrap()
}

fn observe(&mut self) {
// observation is triggered often; check the dag invariants here
assert!(dag_check(&self.g));
// reset buffer
self.obs.fill(0);
// fill buffer
for src in self.g.node_indices() {
// adjacency
for dst in self.parents(src) {
self.obs[[src.index(), dst.index() + 3]] = 1;
}
// color / weight
let w = self.weight(src);
let av = w.av as u8;
let dv = w.dv as u8;
let nv = w.nv as u8;
self.obs[[src.index(), 0]] = av;
self.obs[[src.index(), 1]] = dv;
self.obs[[src.index(), 2]] = nv;
}
}

fn entrypoint(&self, m: Party) -> Block {
if m == Party::Defender {
self.g
Expand Down Expand Up @@ -737,14 +717,21 @@ use numpy::IntoPyArray;
use pyo3::prelude::*;
use std::collections::HashMap;

impl<P, D> Env<P, D>
impl<P, D, O> Env<P, D, O>
where
P: Protocol<Block, Party, D>,
O: FeatureExtractor<Block, Party, D>,
{
fn py_observe(&self, py: Python) -> PyObject {
let atk = self.entrypoint(Party::Attacker);
let def = self.entrypoint(Party::Defender);
let obs: Vec<f32> = self.o.observe(&self.g, atk, def, self.ca);
obs.into_pyarray(py).into()
}

pub fn py_reset(&mut self, py: Python) -> (PyObject, HashMap<String, PyObject>) {
self.reset();
// TODO/perf avoid obs cloning?
(self.obs.clone().into_pyarray(py).into(), HashMap::new())
(self.py_observe(py), HashMap::new())
}

pub fn py_step(
Expand All @@ -753,13 +740,6 @@ where
a: Action,
) -> (PyObject, f64, bool, bool, HashMap<String, PyObject>) {
let (rew, term, trunc) = self.step(a);
// TODO/perf avoid obs cloning?
(
self.obs.clone().into_pyarray(py).into(),
rew,
term,
trunc,
HashMap::new(),
)
(self.py_observe(py), rew, term, trunc, HashMap::new())
}
}
6 changes: 3 additions & 3 deletions gym/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum Protocol {
// investigating this is currently not worth the effort.

enum BoxedEnv {
Nakamoto(Env<nakamoto::Protocol, nakamoto::Data>),
Nakamoto(Env<nakamoto::Protocol, nakamoto::Data, nakamoto::BaseObserver>),
}
// Boxed here means boxed type, not boxed data.

Expand All @@ -31,15 +31,15 @@ struct GenericEnv {
#[pymethods]
impl GenericEnv {
#[new]
fn new(p: Protocol, alpha: f32, gamma: f32, horizon: f32, max_blocks: usize) -> Self {
fn new(p: Protocol, alpha: f32, gamma: f32, horizon: f32) -> Self {
match p {
Protocol::Nakamoto => GenericEnv {
env: BoxedEnv::Nakamoto(Env::new(
nakamoto::Protocol {},
nakamoto::BaseObserver {},
alpha,
gamma,
horizon,
max_blocks,
)),
},
}
Expand Down
22 changes: 22 additions & 0 deletions gym/rust/src/proto/nakamoto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,25 @@ where
vec![(d.miner(b), 1.)]
}
}

pub struct BaseObserver {}

impl<Block, Miner> intf::FeatureExtractor<Block, Miner, Data> for BaseObserver
where
Block: Copy,
Miner: Copy,
{
fn observe<DAG: BlockDAG<Block, Miner, Data>>(
&self,
d: &DAG,
atk: Block,
def: Block,
ca: Block,
) -> Vec<f32> {
// TODO derive third observable of fc16.rs
let ch = d.data(ca).height; // height common ancestor
let ah = d.data(atk).height - ch; // height attacker preferred block
let dh = d.data(def).height - ch; // height defender preferred block
vec![ah as f32, dh as f32]
}
}
2 changes: 1 addition & 1 deletion gym/rust/test/test_rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_FC16SSZwPT():

def test_nakamoto():
proto = _rust.Protocol.Nakamoto
env = _rust.GenericEnv(proto, alpha=0.5, gamma=0.5, horizon=25, max_blocks=128)
env = _rust.GenericEnv(proto, alpha=0.5, gamma=0.5, horizon=25)
for _ in range(1000):
mina, maxa = env.action_range()
a = randint(mina, maxa)
Expand Down

0 comments on commit 8abf0d1

Please sign in to comment.