Skip to content

Commit

Permalink
migrate some changes from #629
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurpaulino committed Sep 12, 2023
1 parent fc9156d commit 502e0c6
Show file tree
Hide file tree
Showing 11 changed files with 2,031 additions and 1,142 deletions.
1,041 changes: 573 additions & 468 deletions src/lem/circuit.rs

Large diffs are not rendered by default.

834 changes: 475 additions & 359 deletions src/lem/eval.rs

Large diffs are not rendered by default.

236 changes: 151 additions & 85 deletions src/lem/interpreter.rs

Large diffs are not rendered by default.

37 changes: 20 additions & 17 deletions src/lem/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,27 @@ macro_rules! ctrl {
$crate::lem::Ctrl::MatchTag($crate::var!($sii), cases, default)
}
};
( match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)? ) => {
( match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)? ) => {
{
let mut cases = indexmap::IndexMap::new();
$(
if cases.insert(
$crate::lit!($cnstr($val)),
$crate::state::lurk_sym($sym),
$crate::block!( $case_ops ),
).is_some() {
panic!("Repeated value on `match`");
};
$(
if cases.insert(
$crate::lit!($other_cnstr($other_val)),
$crate::state::lurk_sym($other_sym),
$crate::block!( $case_ops ),
).is_some() {
panic!("Repeated value on `match`");
};
)*
)*
let default = None $( .or (Some(Box::new($crate::block!( @seq {}, $($def)* )))) )?;
$crate::lem::Ctrl::MatchVal($crate::var!($sii), cases, default)
$crate::lem::Ctrl::MatchSymbol($crate::var!($sii), cases, default)
}
};
( if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => {
Expand Down Expand Up @@ -508,13 +508,13 @@ macro_rules! block {
$crate::ctrl!( match $sii.tag { $( $kind::$tag $(| $other_kind::$other_tag)* => $case_ops )* } $(; $($def)*)? )
)
};
(@seq {$($limbs:expr)*}, match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)?) => {
(@seq {$($limbs:expr)*}, match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)?) => {
$crate::block! (
@end
{
$($limbs)*
},
$crate::ctrl!( match $sii.val { $( $cnstr($val) $(| $other_cnstr($other_val))* => $case_ops )* } $(; $($def)*)? )
$crate::ctrl!( match symbol $sii { $( $sym $(, $other_sym)* => $case_ops )* } $(; $($def)*)? )
)
};
(@seq {$($limbs:expr)*}, if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => {
Expand Down Expand Up @@ -572,9 +572,12 @@ macro_rules! func {

#[cfg(test)]
mod tests {
use crate::lem::{Block, Ctrl, Lit, Op, Tag, Var};
use crate::state::lurk_sym;
use crate::tag::ExprTag::*;
use crate::{
lem::{Block, Ctrl, Op, Tag, Var},
state::lurk_sym,
tag::ExprTag::*,
Symbol,
};

#[inline]
fn mptr(name: &str) -> Var {
Expand All @@ -587,8 +590,8 @@ mod tests {
}

#[inline]
fn match_val(i: Var, cases: Vec<(Lit, Block)>, def: Block) -> Ctrl {
Ctrl::MatchVal(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def)))
fn match_symbol(i: Var, cases: Vec<(Symbol, Block)>, def: Block) -> Ctrl {
Ctrl::MatchSymbol(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def)))
}

#[test]
Expand Down Expand Up @@ -698,11 +701,11 @@ mod tests {
);

let moo = ctrl!(
match www.val {
Symbol("nil") => {
match symbol www {
"nil" => {
return (foo, foo, foo); // a single Ctrl will not turn into a Seq
}
Symbol("cons") => {
"cons" => {
let foo: Expr::Num;
let goo: Expr::Char;
return (foo, goo, goo);
Expand All @@ -713,18 +716,18 @@ mod tests {
);

assert!(
moo == match_val(
moo == match_symbol(
mptr("www"),
vec![
(
Lit::Symbol(lurk_sym("nil")),
lurk_sym("nil"),
Block {
ops: vec![],
ctrl: Ctrl::Return(vec![mptr("foo"), mptr("foo"), mptr("foo")]),
}
),
(
Lit::Symbol(lurk_sym("cons")),
lurk_sym("cons"),
Block {
ops: vec![
Op::Null(mptr("foo"), Tag::Expr(Num)),
Expand Down
157 changes: 103 additions & 54 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,25 @@
//! 6. We also check for variables that are not used. If intended they should
//! be prefixed by "_"

mod circuit;
mod eval;
mod interpreter;
pub mod circuit;
pub mod eval;
pub mod interpreter;
mod macros;
mod path;
mod pointers;
pub mod pointers;
mod slot;
mod store;
pub mod store;
mod var_map;
pub mod zstore;

use crate::coprocessor::Coprocessor;
use crate::eval::lang::Lang;
use crate::field::LurkField;
use crate::symbol::Symbol;
use crate::tag::{ContTag, ExprTag, Tag as TagTrait};
use anyhow::{bail, Result};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use self::{pointers::Ptr, slot::SlotsCounter, store::Store, var_map::VarMap};
Expand All @@ -84,26 +88,32 @@ pub type AString = Arc<str>;
/// function body, which is a `Block`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Func {
name: String,
input_params: Vec<Var>,
output_size: usize,
body: Block,
slot: SlotsCounter,
pub name: String,
pub input_params: Vec<Var>,
pub output_size: usize,
pub body: Block,
pub slot: SlotsCounter,
}

impl<F: LurkField, C: Coprocessor<F>> From<&Lang<F, C>> for Func {
fn from(_lang: &Lang<F, C>) -> Self {
eval::eval_step().clone()
}
}

/// LEM variables
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct Var(AString);

/// LEM tags
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)]
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)]
pub enum Tag {
Expr(ExprTag),
Cont(ContTag),
Ctrl(CtrlTag),
}

#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)]
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)]
pub enum CtrlTag {
Return,
MakeThunk,
Expand Down Expand Up @@ -169,18 +179,31 @@ impl Lit {
Self::Num(num) => Ptr::num(F::from_u128(*num)),
}
}

pub fn to_ptr_cached<F: LurkField>(&self, store: &Store<F>) -> Ptr<F> {
match self {
Self::Symbol(s) => *store
.interned_symbol(s)
.expect("Symbol should have been cached"),
Self::String(s) => *store
.interned_string(s)
.expect("String should have been cached"),
Self::Num(num) => Ptr::num(F::from_u128(*num)),
}
}

pub fn from_ptr<F: LurkField>(ptr: &Ptr<F>, store: &Store<F>) -> Option<Self> {
use ExprTag::*;
use Tag::*;
match ptr.tag() {
Expr(Num) => match ptr {
Ptr::Leaf(_, f) => {
Ptr::Atom(_, f) => {
let num = LurkField::to_u128_unchecked(f);
Some(Self::Num(num))
}
_ => unreachable!(),
},
Expr(Str) => store.fetch_string(ptr).cloned().map(Lit::String),
Expr(Str) => store.fetch_string(ptr).map(Lit::String),
Expr(Sym) => store.fetch_symbol(ptr).map(Lit::Symbol),
_ => None,
}
Expand Down Expand Up @@ -212,13 +235,14 @@ pub struct Block {
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ctrl {
/// `MatchTag(x, cases)` performs a match on the tag of `x`, choosing the
/// appropriate `Block` among the ones provided in `cases`
/// `MatchTag(x, cases, def)` checks whether the tag of `x` matches some tag
/// among the ones provided in `cases`. If so, run the corresponding `Block`.
/// Run `def` otherwise
MatchTag(Var, IndexMap<Tag, Block>, Option<Box<Block>>),
/// `MatchSymbol(x, cases, def)` checks whether `x` matches some symbol among
/// the ones provided in `cases`. If so, run the corresponding `Block`. Run
/// `def` otherwise
MatchVal(Var, IndexMap<Lit, Block>, Option<Box<Block>>),
/// `MatchSymbol(x, cases, def)` requires that `x` is a symbol and checks
/// whether `x` matches some symbol among the ones provided in `cases`. If so,
/// run the corresponding `Block`. Run `def` otherwise
MatchSymbol(Var, IndexMap<Symbol, Block>, Option<Box<Block>>),
/// `IfEq(x, y, eq_block, else_block)` runs `eq_block` if `x == y`, and
/// otherwise runs `else_block`
IfEq(Var, Var, Box<Block>, Box<Block>),
Expand Down Expand Up @@ -454,31 +478,13 @@ impl Func {
None => (),
}
}
Ctrl::MatchVal(var, cases, def) => {
Ctrl::MatchSymbol(var, cases, def) => {
is_bound(var, map)?;
let mut lits = HashSet::new();
let mut kind = None;
for (lit, block) in cases {
let lit_kind = match lit {
Lit::Num(..) => 0,
Lit::String(..) => 1,
Lit::Symbol(..) => 2,
};
if let Some(kind) = kind {
if kind != lit_kind {
bail!("Only values of the same kind allowed.");
}
} else {
kind = Some(lit_kind)
}
if !lits.insert(lit) {
bail!("Case {:?} already defined.", lit);
}
for block in cases.values() {
recurse(block, return_size, map)?;
}
match def {
Some(def) => recurse(def, return_size, map)?,
None => (),
if let Some(def) = def {
recurse(def, return_size, map)?;
}
}
Ctrl::IfEq(x, y, eq_block, else_block) => {
Expand Down Expand Up @@ -548,6 +554,12 @@ impl Func {
body,
)
}

pub fn init_store<F: LurkField>(&self) -> Store<F> {
let mut store = Store::default();
self.body.intern_lits(&mut store);
store
}
}

impl Block {
Expand Down Expand Up @@ -695,18 +707,18 @@ impl Block {
};
Ctrl::MatchTag(var, IndexMap::from_iter(new_cases), new_def)
}
Ctrl::MatchVal(var, cases, def) => {
Ctrl::MatchSymbol(var, cases, def) => {
let var = map.get_cloned(&var)?;
let mut new_cases = Vec::with_capacity(cases.len());
for (lit, case) in cases {
for (sym, case) in cases {
let new_case = case.deconflict(&mut map.clone(), uniq)?;
new_cases.push((lit.clone(), new_case));
new_cases.push((sym.clone(), new_case));
}
let new_def = match def {
Some(def) => Some(Box::new(def.deconflict(map, uniq)?)),
None => None,
};
Ctrl::MatchVal(var, IndexMap::from_iter(new_cases), new_def)
Ctrl::MatchSymbol(var, IndexMap::from_iter(new_cases), new_def)
}
Ctrl::IfEq(x, y, eq_block, else_block) => {
let x = map.get_cloned(&x)?;
Expand All @@ -719,6 +731,40 @@ impl Block {
};
Ok(Block { ops, ctrl })
}

fn intern_lits<F: LurkField>(&self, store: &mut Store<F>) {
for op in &self.ops {
match op {
Op::Call(_, func, _) => func.body.intern_lits(store),
Op::Lit(_, lit) => {
lit.to_ptr(store);
}
_ => (),
}
}
match &self.ctrl {
Ctrl::IfEq(.., a, b) => {
a.intern_lits(store);
b.intern_lits(store);
}
Ctrl::MatchTag(_, cases, def) => {
cases.values().for_each(|block| block.intern_lits(store));
if let Some(def) = def {
def.intern_lits(store);
}
}
Ctrl::MatchSymbol(_, cases, def) => {
for (sym, b) in cases {
store.intern_symbol(sym);
b.intern_lits(store);
}
if let Some(def) = def {
def.intern_lits(store);
}
}
Ctrl::Return(..) => (),
}
}
}

impl Var {
Expand All @@ -731,8 +777,7 @@ impl Var {
#[cfg(test)]
mod tests {
use super::slot::SlotsCounter;
use super::{store::Store, *};
use crate::state::lurk_sym;
use super::*;
use crate::{func, lem::pointers::Ptr};
use bellpepper::util_cs::Comparable;
use bellpepper_core::test_cs::TestConstraintSystem;
Expand All @@ -747,27 +792,31 @@ mod tests {
/// - `expected_slots` gives the number of expected slots for each type of hash.
fn synthesize_test_helper(func: &Func, inputs: Vec<Ptr<Fr>>, expected_num_slots: SlotsCounter) {
use crate::tag::ContTag::*;
let store = &mut Store::default();
let store = &mut func.init_store();
let outermost = Ptr::null(Tag::Cont(Outermost));
let terminal = Ptr::null(Tag::Cont(Terminal));
let error = Ptr::null(Tag::Cont(Error));
let nil = store.intern_symbol(&lurk_sym("nil"));
let nil = store.intern_nil();
let stop_cond = |output: &[Ptr<Fr>]| output[2] == terminal || output[2] == error;

assert_eq!(func.slot, expected_num_slots);

let computed_num_constraints = func.num_constraints::<Fr>(store);

let log_fmt = |_: usize, _: &[Ptr<Fr>], _: &[Ptr<Fr>], _: &Store<Fr>| String::default();

let mut cs_prev = None;
for input in inputs.into_iter() {
let input = vec![input, nil, outermost];
let (frames, _) = func.call_until(input, store, stop_cond).unwrap();
let input = [input, nil, outermost];
let (frames, ..) = func
.call_until(&input, store, stop_cond, 10, log_fmt)
.unwrap();

let mut cs;

for frame in frames.clone() {
for frame in frames {
cs = TestConstraintSystem::<Fr>::new();
func.synthesize(&mut cs, store, &frame).unwrap();
func.synthesize_frame_aux(&mut cs, store, &frame).unwrap();
assert!(cs.is_satisfied());
assert_eq!(computed_num_constraints, cs.num_constraints());
if let Some(cs_prev) = cs_prev {
Expand Down
Loading

0 comments on commit 502e0c6

Please sign in to comment.