Skip to content

Commit

Permalink
Observe available actions, simulate honest behaviour.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkel committed Apr 17, 2024
1 parent 177c2ef commit 34852d0
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 90 deletions.
177 changes: 95 additions & 82 deletions gym/rust/misc/nakamoto-baseline.ipynb

Large diffs are not rendered by default.

54 changes: 46 additions & 8 deletions gym/rust/src/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,49 @@ where
}
common
}

/// Builds the observation vector for the current state.
///
/// The vector starts with the features produced by the protocol-dependent
/// feature extractor `self.o` and is followed by two extra entries that
/// describe the currently available actions:
///
/// 1. the encoding of `ActionHum::Release(i)` where `i` is the last index into
///    `self.a.release`, or `0.` if `self.a.release` is empty;
/// 2. the encoding of `ActionHum::Consider(i)` where `i` is the last index into
///    `self.a.consider`, or `0.` if `self.a.consider` is empty.
fn observe(&self) -> Vec<f32> {
    // Last valid index among `n` available actions, or `None` when there are none.
    // Action indices are `u8`; saturate at `u8::MAX` rather than letting a plain
    // `as u8` cast wrap around silently when more than 256 actions are available.
    let last_index = |n: usize| -> Option<u8> {
        n.checked_sub(1)
            .map(|i| i.min(usize::from(u8::MAX)) as u8)
    };

    // apply protocol-dependent feature extractor
    let mut obs: Vec<f32> = {
        let atk = self.entrypoint(Party::Attacker);
        let def = self.entrypoint(Party::Defender);
        let ktad = |b: Block| -> bool { self.weight(b).dv != DView::Unknown };
        let mbd = |b: Block| -> bool { self.weight(b).nv == NView::Honest };
        self.o.observe(&self.g, atk, def, self.ca, &ktad, &mbd)
    };

    // add information about available actions
    // a) release
    obs.push(
        last_index(self.a.release.len())
            .map_or(0., |i| encode_action(ActionHum::Release(i))),
    );
    // b) consider
    obs.push(
        last_index(self.a.consider.len())
            .map_or(0., |i| encode_action(ActionHum::Consider(i))),
    );

    obs
}

/// Lower bounds of the observation vector.
///
/// Takes the feature extractor's lower bounds (`self.o.low()`) and appends the
/// lower bounds of the two action-availability entries, in this order:
/// release (`-1.`), then consider (`0.`).
fn low(&self) -> Vec<f32> {
    let mut bounds = self.o.low();
    // [minimum release action, minimum consider action]
    bounds.extend_from_slice(&[-1., 0.]);
    bounds
}

/// Upper bounds of the observation vector.
///
/// Takes the feature extractor's upper bounds (`self.o.high()`) and appends the
/// upper bounds of the two action-availability entries, in this order:
/// release (`0.`), then consider (`1.`).
fn high(&self) -> Vec<f32> {
    let mut bounds = self.o.high();
    // [maximum release action, maximum consider action]
    bounds.extend_from_slice(&[0., 1.]);
    bounds
}
}

// We need a slightly different interface for Python-interop
Expand All @@ -900,12 +943,7 @@ where
O: FeatureExtractor<Block, Party, D>,
{
fn py_observe(&self, py: Python) -> PyObject {
let atk = self.entrypoint(Party::Attacker);
let def = self.entrypoint(Party::Defender);
let ktad = |b: Block| -> bool { self.weight(b).dv != DView::Unknown };
let mbd = |b: Block| -> bool { self.weight(b).nv == NView::Honest };
let obs: Vec<f32> = self.o.observe(&self.g, atk, def, self.ca, &ktad, &mbd);
obs.into_pyarray(py).into()
self.observe().into_pyarray(py).into()
}

fn py_info(&self, py: Python, i: Info) -> HashMap<String, PyObject> {
Expand All @@ -917,11 +955,11 @@ where
}

/// Returns the lower bounds of the observation vector as a Python object.
///
/// Delegates to `self.low()` rather than `self.o.low()`, so the bounds include
/// the entries appended for the release and consider actions.
pub fn py_low(&self, py: Python) -> PyObject {
    self.low().into_pyarray(py).into()
}

/// Returns the upper bounds of the observation vector as a Python object.
///
/// Delegates to `self.high()` rather than `self.o.high()`, so the bounds include
/// the entries appended for the release and consider actions.
pub fn py_high(&self, py: Python) -> PyObject {
    self.high().into_pyarray(py).into()
}

pub fn py_reset(&mut self, py: Python) -> (PyObject, HashMap<String, PyObject>) {
Expand Down

0 comments on commit 34852d0

Please sign in to comment.