Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Isolate some changes from #629 #663

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,041 changes: 573 additions & 468 deletions src/lem/circuit.rs

Large diffs are not rendered by default.

834 changes: 475 additions & 359 deletions src/lem/eval.rs

Large diffs are not rendered by default.

236 changes: 151 additions & 85 deletions src/lem/interpreter.rs

Large diffs are not rendered by default.

37 changes: 20 additions & 17 deletions src/lem/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,27 @@ macro_rules! ctrl {
$crate::lem::Ctrl::MatchTag($crate::var!($sii), cases, default)
}
};
( match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)? ) => {
( match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)? ) => {
{
let mut cases = indexmap::IndexMap::new();
$(
if cases.insert(
$crate::lit!($cnstr($val)),
$crate::state::lurk_sym($sym),
$crate::block!( $case_ops ),
).is_some() {
panic!("Repeated value on `match`");
};
$(
if cases.insert(
$crate::lit!($other_cnstr($other_val)),
$crate::state::lurk_sym($other_sym),
$crate::block!( $case_ops ),
).is_some() {
panic!("Repeated value on `match`");
};
)*
)*
let default = None $( .or (Some(Box::new($crate::block!( @seq {}, $($def)* )))) )?;
$crate::lem::Ctrl::MatchVal($crate::var!($sii), cases, default)
$crate::lem::Ctrl::MatchSymbol($crate::var!($sii), cases, default)
}
};
( if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => {
Expand Down Expand Up @@ -508,13 +508,13 @@ macro_rules! block {
$crate::ctrl!( match $sii.tag { $( $kind::$tag $(| $other_kind::$other_tag)* => $case_ops )* } $(; $($def)*)? )
)
};
(@seq {$($limbs:expr)*}, match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)?) => {
(@seq {$($limbs:expr)*}, match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)?) => {
$crate::block! (
@end
{
$($limbs)*
},
$crate::ctrl!( match $sii.val { $( $cnstr($val) $(| $other_cnstr($other_val))* => $case_ops )* } $(; $($def)*)? )
$crate::ctrl!( match symbol $sii { $( $sym $(, $other_sym)* => $case_ops )* } $(; $($def)*)? )
)
};
(@seq {$($limbs:expr)*}, if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => {
Expand Down Expand Up @@ -572,9 +572,12 @@ macro_rules! func {

#[cfg(test)]
mod tests {
use crate::lem::{Block, Ctrl, Lit, Op, Tag, Var};
use crate::state::lurk_sym;
use crate::tag::ExprTag::*;
use crate::{
lem::{Block, Ctrl, Op, Tag, Var},
state::lurk_sym,
tag::ExprTag::*,
Symbol,
};

#[inline]
fn mptr(name: &str) -> Var {
Expand All @@ -587,8 +590,8 @@ mod tests {
}

#[inline]
fn match_val(i: Var, cases: Vec<(Lit, Block)>, def: Block) -> Ctrl {
Ctrl::MatchVal(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def)))
fn match_symbol(i: Var, cases: Vec<(Symbol, Block)>, def: Block) -> Ctrl {
Ctrl::MatchSymbol(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def)))
}

#[test]
Expand Down Expand Up @@ -698,11 +701,11 @@ mod tests {
);

let moo = ctrl!(
match www.val {
Symbol("nil") => {
match symbol www {
"nil" => {
return (foo, foo, foo); // a single Ctrl will not turn into a Seq
}
Symbol("cons") => {
"cons" => {
let foo: Expr::Num;
let goo: Expr::Char;
return (foo, goo, goo);
Expand All @@ -713,18 +716,18 @@ mod tests {
);

assert!(
moo == match_val(
moo == match_symbol(
mptr("www"),
vec![
(
Lit::Symbol(lurk_sym("nil")),
lurk_sym("nil"),
Block {
ops: vec![],
ctrl: Ctrl::Return(vec![mptr("foo"), mptr("foo"), mptr("foo")]),
}
),
(
Lit::Symbol(lurk_sym("cons")),
lurk_sym("cons"),
Block {
ops: vec![
Op::Null(mptr("foo"), Tag::Expr(Num)),
Expand Down
157 changes: 103 additions & 54 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,25 @@
//! 6. We also check for variables that are not used. If intended they should
//! be prefixed by "_"

mod circuit;
mod eval;
mod interpreter;
pub mod circuit;
pub mod eval;
pub mod interpreter;
mod macros;
mod path;
mod pointers;
pub mod pointers;
mod slot;
mod store;
pub mod store;
mod var_map;
pub mod zstore;

use crate::coprocessor::Coprocessor;
use crate::eval::lang::Lang;
use crate::field::LurkField;
use crate::symbol::Symbol;
use crate::tag::{ContTag, ExprTag, Tag as TagTrait};
use anyhow::{bail, Result};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use self::{pointers::Ptr, slot::SlotsCounter, store::Store, var_map::VarMap};
Expand All @@ -84,26 +88,32 @@ pub type AString = Arc<str>;
/// function body, which is a `Block`
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Func {
name: String,
input_params: Vec<Var>,
output_size: usize,
body: Block,
slot: SlotsCounter,
pub name: String,
pub input_params: Vec<Var>,
pub output_size: usize,
pub body: Block,
pub slot: SlotsCounter,
}

impl<F: LurkField, C: Coprocessor<F>> From<&Lang<F, C>> for Func {
fn from(_lang: &Lang<F, C>) -> Self {
eval::eval_step().clone()
}
}

/// LEM variables
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct Var(AString);

/// LEM tags
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)]
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)]
pub enum Tag {
Expr(ExprTag),
Cont(ContTag),
Ctrl(CtrlTag),
}

#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)]
#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)]
pub enum CtrlTag {
Return,
MakeThunk,
Expand Down Expand Up @@ -169,18 +179,31 @@ impl Lit {
Self::Num(num) => Ptr::num(F::from_u128(*num)),
}
}

pub fn to_ptr_cached<F: LurkField>(&self, store: &Store<F>) -> Ptr<F> {
match self {
Self::Symbol(s) => *store
.interned_symbol(s)
.expect("Symbol should have been cached"),
Self::String(s) => *store
.interned_string(s)
.expect("String should have been cached"),
Self::Num(num) => Ptr::num(F::from_u128(*num)),
}
}

pub fn from_ptr<F: LurkField>(ptr: &Ptr<F>, store: &Store<F>) -> Option<Self> {
use ExprTag::*;
use Tag::*;
match ptr.tag() {
Expr(Num) => match ptr {
Ptr::Leaf(_, f) => {
Ptr::Atom(_, f) => {
let num = LurkField::to_u128_unchecked(f);
Some(Self::Num(num))
}
_ => unreachable!(),
},
Expr(Str) => store.fetch_string(ptr).cloned().map(Lit::String),
Expr(Str) => store.fetch_string(ptr).map(Lit::String),
Expr(Sym) => store.fetch_symbol(ptr).map(Lit::Symbol),
_ => None,
}
Expand Down Expand Up @@ -212,13 +235,14 @@ pub struct Block {
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ctrl {
/// `MatchTag(x, cases)` performs a match on the tag of `x`, choosing the
/// appropriate `Block` among the ones provided in `cases`
/// `MatchTag(x, cases, def)` checks whether the tag of `x` matches some tag
/// among the ones provided in `cases`. If so, run the corresponding `Block`.
/// Run `def` otherwise
MatchTag(Var, IndexMap<Tag, Block>, Option<Box<Block>>),
/// `MatchSymbol(x, cases, def)` checks whether `x` matches some symbol among
/// the ones provided in `cases`. If so, run the corresponding `Block`. Run
/// `def` otherwise
MatchVal(Var, IndexMap<Lit, Block>, Option<Box<Block>>),
/// `MatchSymbol(x, cases, def)` requires that `x` is a symbol and checks
/// whether `x` matches some symbol among the ones provided in `cases`. If so,
/// run the corresponding `Block`. Run `def` otherwise
MatchSymbol(Var, IndexMap<Symbol, Block>, Option<Box<Block>>),
/// `IfEq(x, y, eq_block, else_block)` runs `eq_block` if `x == y`, and
/// otherwise runs `else_block`
IfEq(Var, Var, Box<Block>, Box<Block>),
Expand Down Expand Up @@ -454,31 +478,13 @@ impl Func {
None => (),
}
}
Ctrl::MatchVal(var, cases, def) => {
Ctrl::MatchSymbol(var, cases, def) => {
is_bound(var, map)?;
let mut lits = HashSet::new();
let mut kind = None;
for (lit, block) in cases {
let lit_kind = match lit {
Lit::Num(..) => 0,
Lit::String(..) => 1,
Lit::Symbol(..) => 2,
};
if let Some(kind) = kind {
if kind != lit_kind {
bail!("Only values of the same kind allowed.");
}
} else {
kind = Some(lit_kind)
}
if !lits.insert(lit) {
bail!("Case {:?} already defined.", lit);
}
for block in cases.values() {
recurse(block, return_size, map)?;
}
match def {
Some(def) => recurse(def, return_size, map)?,
None => (),
if let Some(def) = def {
recurse(def, return_size, map)?;
}
}
Ctrl::IfEq(x, y, eq_block, else_block) => {
Expand Down Expand Up @@ -548,6 +554,12 @@ impl Func {
body,
)
}

pub fn init_store<F: LurkField>(&self) -> Store<F> {
let mut store = Store::default();
self.body.intern_lits(&mut store);
store
}
}

impl Block {
Expand Down Expand Up @@ -695,18 +707,18 @@ impl Block {
};
Ctrl::MatchTag(var, IndexMap::from_iter(new_cases), new_def)
}
Ctrl::MatchVal(var, cases, def) => {
Ctrl::MatchSymbol(var, cases, def) => {
let var = map.get_cloned(&var)?;
let mut new_cases = Vec::with_capacity(cases.len());
for (lit, case) in cases {
for (sym, case) in cases {
let new_case = case.deconflict(&mut map.clone(), uniq)?;
new_cases.push((lit.clone(), new_case));
new_cases.push((sym.clone(), new_case));
}
let new_def = match def {
Some(def) => Some(Box::new(def.deconflict(map, uniq)?)),
None => None,
};
Ctrl::MatchVal(var, IndexMap::from_iter(new_cases), new_def)
Ctrl::MatchSymbol(var, IndexMap::from_iter(new_cases), new_def)
}
Ctrl::IfEq(x, y, eq_block, else_block) => {
let x = map.get_cloned(&x)?;
Expand All @@ -719,6 +731,40 @@ impl Block {
};
Ok(Block { ops, ctrl })
}

fn intern_lits<F: LurkField>(&self, store: &mut Store<F>) {
for op in &self.ops {
match op {
Op::Call(_, func, _) => func.body.intern_lits(store),
Op::Lit(_, lit) => {
lit.to_ptr(store);
}
_ => (),
}
}
match &self.ctrl {
Ctrl::IfEq(.., a, b) => {
a.intern_lits(store);
b.intern_lits(store);
}
Ctrl::MatchTag(_, cases, def) => {
cases.values().for_each(|block| block.intern_lits(store));
if let Some(def) = def {
def.intern_lits(store);
}
}
Ctrl::MatchSymbol(_, cases, def) => {
for (sym, b) in cases {
store.intern_symbol(sym);
b.intern_lits(store);
}
if let Some(def) = def {
def.intern_lits(store);
}
}
Ctrl::Return(..) => (),
}
}
}

impl Var {
Expand All @@ -731,8 +777,7 @@ impl Var {
#[cfg(test)]
mod tests {
use super::slot::SlotsCounter;
use super::{store::Store, *};
use crate::state::lurk_sym;
use super::*;
use crate::{func, lem::pointers::Ptr};
use bellpepper::util_cs::Comparable;
use bellpepper_core::test_cs::TestConstraintSystem;
Expand All @@ -747,27 +792,31 @@ mod tests {
/// - `expected_slots` gives the number of expected slots for each type of hash.
fn synthesize_test_helper(func: &Func, inputs: Vec<Ptr<Fr>>, expected_num_slots: SlotsCounter) {
use crate::tag::ContTag::*;
let store = &mut Store::default();
let store = &mut func.init_store();
let outermost = Ptr::null(Tag::Cont(Outermost));
let terminal = Ptr::null(Tag::Cont(Terminal));
let error = Ptr::null(Tag::Cont(Error));
let nil = store.intern_symbol(&lurk_sym("nil"));
let nil = store.intern_nil();
let stop_cond = |output: &[Ptr<Fr>]| output[2] == terminal || output[2] == error;

assert_eq!(func.slot, expected_num_slots);

let computed_num_constraints = func.num_constraints::<Fr>(store);

let log_fmt = |_: usize, _: &[Ptr<Fr>], _: &[Ptr<Fr>], _: &Store<Fr>| String::default();

let mut cs_prev = None;
for input in inputs.into_iter() {
let input = vec![input, nil, outermost];
let (frames, _) = func.call_until(input, store, stop_cond).unwrap();
let input = [input, nil, outermost];
let (frames, ..) = func
.call_until(&input, store, stop_cond, 10, log_fmt)
.unwrap();

let mut cs;

for frame in frames.clone() {
for frame in frames {
cs = TestConstraintSystem::<Fr>::new();
func.synthesize(&mut cs, store, &frame).unwrap();
func.synthesize_frame_aux(&mut cs, store, &frame).unwrap();
assert!(cs.is_satisfied());
assert_eq!(computed_num_constraints, cs.num_constraints());
if let Some(cs_prev) = cs_prev {
Expand Down
Loading
Loading