From 6391af9f343345cc8c0673c26ba72cac7e7e52e7 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Wed, 13 Sep 2023 09:19:13 -0300 Subject: [PATCH] some isolated changes from #629 --- src/lem/circuit.rs | 1041 ++++++++++++++++++++++------------------ src/lem/eval.rs | 834 ++++++++++++++++++-------------- src/lem/interpreter.rs | 236 +++++---- src/lem/macros.rs | 37 +- src/lem/mod.rs | 157 +++--- src/lem/path.rs | 22 +- src/lem/pointers.rs | 98 ++-- src/lem/slot.rs | 7 +- src/lem/store.rs | 577 +++++++++++++++++----- src/lem/var_map.rs | 2 +- src/lem/zstore.rs | 162 +++++++ 11 files changed, 2031 insertions(+), 1142 deletions(-) create mode 100644 src/lem/zstore.rs diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index 8a13091130..880119d741 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -21,9 +21,7 @@ //! on a concrete or a virtual path and use such booleans as the premises to build //! the constraints we care about with implication gadgets. -use std::collections::{HashMap, HashSet, VecDeque}; - -use anyhow::{Context, Result}; +use anyhow::{anyhow, Result}; use bellpepper_core::{ ConstraintSystem, SynthesisError, { @@ -31,6 +29,7 @@ use bellpepper_core::{ num::AllocatedNum, }, }; +use std::collections::{HashMap, HashSet, VecDeque}; use crate::circuit::gadgets::{ constraints::{ @@ -58,268 +57,369 @@ use super::{ /// Manages global allocations for constants in a constraint system #[derive(Default)] -pub(crate) struct GlobalAllocator(HashMap, AllocatedNum>); +pub struct GlobalAllocator(HashMap, AllocatedNum>); #[inline] fn allocate_num>( cs: &mut CS, namespace: &str, value: F, -) -> Result> { +) -> Result, SynthesisError> { AllocatedNum::alloc(cs.namespace(|| namespace), || Ok(value)) - .with_context(|| format!("allocation for '{namespace}' failed")) +} + +impl GlobalAllocator { + /// Checks if the allocation for a numeric variable has already been cached. + /// If so, don't do anything. Otherwise, allocate and cache it. + fn alloc_const>(&mut self, cs: &mut CS, f: F) { + self.0.entry(FWrap(f)).or_insert_with(|| { + allocate_constant( + &mut cs.namespace(|| format!("allocate constant {}", f.hex_digits())), + f, + ) + }); + } + + #[inline] + fn get_allocated_const(&self, f: F) -> Result<&AllocatedNum> { + self.0 + .get(&FWrap(f)) + .ok_or_else(|| anyhow!("Global allocation not found for {}", f.hex_digits())) + } + + #[inline] + fn get_allocated_const_cloned(&self, f: F) -> Result> { + self.get_allocated_const(f).cloned() + } +} + +pub(crate) type BoundAllocations = VarMap>; + +/// Allocates an unconstrained pointer +fn allocate_ptr>( + cs: &mut CS, + z_ptr: &ZPtr, + var: &Var, + bound_allocations: &mut BoundAllocations, +) -> Result> { + let allocated_tag = allocate_num(cs, &format!("allocate {var}'s tag"), z_ptr.tag.to_field())?; + let allocated_hash = allocate_num(cs, &format!("allocate {var}'s hash"), z_ptr.hash)?; + let allocated_ptr = AllocatedPtr::from_parts(allocated_tag, allocated_hash); + bound_allocations.insert(var.clone(), allocated_ptr.clone()); + Ok(allocated_ptr) +} + +/// Allocates an unconstrained pointer for each output of the frame +fn allocate_output>( + cs: &mut CS, + store: &Store, + frame: &Frame, + bound_allocations: &mut BoundAllocations, +) -> Result>> { + frame + .output + .iter() + .enumerate() + .map(|(i, ptr)| { + allocate_ptr( + cs, + &store.hash_ptr(ptr)?, + &Var(format!("output[{}]", i).into()), + bound_allocations, + ) + }) + .collect() } #[inline] -fn allocate_const>( +fn allocate_preimg_component_for_slot>( cs: &mut CS, - namespace: &str, + slot: &Slot, + component_idx: usize, value: F, -) -> AllocatedNum { - allocate_constant(&mut cs.namespace(|| namespace), value) +) -> Result, SynthesisError> { + allocate_num( + cs, + &format!("component {component_idx} for slot {slot}"), + value, + ) } -impl GlobalAllocator { - /// Checks if the allocation for a numeric variable has already been cached. - /// If so, return the cached allocation variable. Allocate as a constant, - /// cache and return otherwise. - pub(crate) fn get_or_alloc_const>( - &mut self, - cs: &mut CS, - f: F, - ) -> AllocatedNum { - let wrap = FWrap(f); - match self.0.get(&wrap) { - Some(allocated_num) => allocated_num.to_owned(), - None => { - let allocated_num = - allocate_const(cs, &format!("allocate constant {}", f.hex_digits()), f); - self.0.insert(wrap, allocated_num.clone()); - allocated_num +fn allocate_img_for_slot>( + cs: &mut CS, + slot: &Slot, + preallocated_preimg: Vec>, + store: &Store, +) -> Result> { + let cs = &mut cs.namespace(|| format!("image for slot {slot}")); + let preallocated_img = { + match slot.typ { + SlotType::Hash2 => { + hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c4())? + } + SlotType::Hash3 => { + hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c6())? + } + SlotType::Hash4 => { + hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c8())? + } + SlotType::Commitment => { + hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c3())? + } + SlotType::LessThan => { + // When a and b have the same sign, a < b iff a - b < 0 + // When a and b have different signs, a < b iff a is negative + let a_num = &preallocated_preimg[0]; + let b_num = &preallocated_preimg[1]; + let slot_str = &slot.to_string(); + let a_is_negative = allocate_is_negative( + &mut cs.namespace(|| format!("a_is_negative for slot {slot_str}")), + a_num, + )?; + let a_is_negative_num = boolean_to_num( + &mut cs.namespace(|| format!("a_is_negative_num for slot {slot_str}")), + &a_is_negative, + )?; + let b_is_negative = allocate_is_negative( + &mut cs.namespace(|| format!("b_is_negative for slot {slot_str}")), + b_num, + )?; + let same_sign = Boolean::xor( + &mut cs.namespace(|| format!("same_sign for slot {slot_str}")), + &a_is_negative, + &b_is_negative, + )? + .not(); + let diff = sub( + &mut cs.namespace(|| format!("diff for slot {slot_str}")), + a_num, + b_num, + )?; + let diff_is_negative = allocate_is_negative( + &mut cs.namespace(|| format!("diff_is_negative for slot {slot_str}")), + &diff, + )?; + let diff_is_negative_num = boolean_to_num( + &mut cs.namespace(|| format!("diff_is_negative_num for slot {slot_str}")), + &diff_is_negative, + )?; + pick( + &mut cs.namespace(|| format!("pick for slot {slot}")), + &same_sign, + &diff_is_negative_num, + &a_is_negative_num, + )? } } - } + }; + Ok(preallocated_img) } -type BoundAllocations = VarMap>; +/// Allocates unconstrained slots +fn allocate_slots>( + cs: &mut CS, + preimg_data: &[Option>], + slot_type: SlotType, + num_slots: usize, + store: &Store, +) -> Result>, AllocatedNum)>> { + assert!( + preimg_data.len() == num_slots, + "collected preimages not equal to the number of available slots" + ); + + let mut preallocations = Vec::with_capacity(num_slots); + + // We must perform the allocations for the slots containing data collected + // by the interpreter. The `None` cases must be filled with dummy values + for (slot_idx, maybe_preimg_data) in preimg_data.iter().enumerate() { + if let Some(preimg_data) = maybe_preimg_data { + let slot = Slot { + idx: slot_idx, + typ: slot_type, + }; + + // Allocate the preimage because the image depends on it + let mut preallocated_preimg = Vec::with_capacity(slot_type.preimg_size()); + + match preimg_data { + PreimageData::PtrVec(ptr_vec) => { + let mut component_idx = 0; + for ptr in ptr_vec { + let z_ptr = store.hash_ptr(ptr)?; -impl Func { - /// Allocates an unconstrained pointer - fn allocate_ptr>( - cs: &mut CS, - z_ptr: &ZPtr, - var: &Var, - bound_allocations: &mut BoundAllocations, - ) -> Result> { - let allocated_tag = - allocate_num(cs, &format!("allocate {var}'s tag"), z_ptr.tag.to_field())?; - let allocated_hash = allocate_num(cs, &format!("allocate {var}'s hash"), z_ptr.hash)?; - let allocated_ptr = AllocatedPtr::from_parts(allocated_tag, allocated_hash); - bound_allocations.insert(var.clone(), allocated_ptr.clone()); - Ok(allocated_ptr) - } + // allocate pointer tag + preallocated_preimg.push(allocate_preimg_component_for_slot( + cs, + &slot, + component_idx, + z_ptr.tag.to_field(), + )?); - /// Allocates an unconstrained pointer for each input of the frame - fn allocate_input>( - &self, - cs: &mut CS, - store: &mut Store, - frame: &Frame, - bound_allocations: &mut BoundAllocations, - ) -> Result<()> { - for (i, ptr) in frame.input.iter().enumerate() { - let param = &self.input_params[i]; - Self::allocate_ptr(cs, &store.hash_ptr(ptr)?, param, bound_allocations)?; - } - Ok(()) - } + component_idx += 1; - /// Allocates an unconstrained pointer for each output of the frame - fn allocate_output>( - cs: &mut CS, - store: &mut Store, - frame: &Frame, - bound_allocations: &mut BoundAllocations, - ) -> Result>> { - frame - .output - .iter() - .enumerate() - .map(|(i, ptr)| { - Self::allocate_ptr( - cs, - &store.hash_ptr(ptr)?, - &Var(format!("output[{}]", i).into()), - bound_allocations, - ) - }) - .collect::>() - } + // allocate pointer hash + preallocated_preimg.push(allocate_preimg_component_for_slot( + cs, + &slot, + component_idx, + z_ptr.hash, + )?); - #[inline] - fn allocate_preimg_component_for_slot>( - cs: &mut CS, - slot: &Slot, - component_idx: usize, - value: F, - ) -> Result> { - allocate_num( - cs, - &format!("component {component_idx} for slot {slot}"), - value, - ) + component_idx += 1; + } + } + PreimageData::FPtr(f, ptr) => { + let z_ptr = store.hash_ptr(ptr)?; + // allocate first component + preallocated_preimg.push(allocate_preimg_component_for_slot(cs, &slot, 0, *f)?); + // allocate second component + preallocated_preimg.push(allocate_preimg_component_for_slot( + cs, + &slot, + 1, + z_ptr.tag.to_field(), + )?); + // allocate third component + preallocated_preimg.push(allocate_preimg_component_for_slot( + cs, &slot, 2, z_ptr.hash, + )?); + } + PreimageData::FPair(a, b) => { + // allocate first component + preallocated_preimg.push(allocate_preimg_component_for_slot(cs, &slot, 0, *a)?); + + // allocate second component + preallocated_preimg.push(allocate_preimg_component_for_slot(cs, &slot, 1, *b)?); + } + } + + // Allocate the image by calling the arithmetic function according + // to the slot type + let preallocated_img = + allocate_img_for_slot(cs, &slot, preallocated_preimg.clone(), store)?; + + preallocations.push((preallocated_preimg, preallocated_img)); + } else { + let slot = Slot { + idx: slot_idx, + typ: slot_type, + }; + let preallocated_preimg: Vec<_> = (0..slot_type.preimg_size()) + .map(|component_idx| { + allocate_preimg_component_for_slot(cs, &slot, component_idx, F::ZERO) + }) + .collect::>()?; + + let preallocated_img = + allocate_img_for_slot(cs, &slot, preallocated_preimg.clone(), store)?; + + preallocations.push((preallocated_preimg, preallocated_img)); + } } - fn allocate_img_for_slot>( + Ok(preallocations) +} + +impl Block { + fn alloc_globals>( + &self, cs: &mut CS, - slot: &Slot, - preallocated_preimg: Vec>, - store: &mut Store, - ) -> Result> { - let cs = &mut cs.namespace(|| format!("image for slot {slot}")); - let preallocated_img = { - match slot.typ { - SlotType::Hash2 => { - hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c4())? + store: &Store, + g: &mut GlobalAllocator, + ) -> Result<(), SynthesisError> { + for op in &self.ops { + match op { + Op::Call(_, func, _) => func.body.alloc_globals(cs, store, g)?, + Op::Hash2(_, tag, _) + | Op::Hash3(_, tag, _) + | Op::Hash4(_, tag, _) + | Op::Cast(_, tag, _) => { + g.alloc_const(cs, tag.to_field()); } - SlotType::Hash3 => { - hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c6())? + Op::Lit(_, lit) => { + let lit_ptr = lit.to_ptr_cached(store); + let lit_z_ptr = store.hash_ptr(&lit_ptr).unwrap(); + g.alloc_const(cs, lit_z_ptr.tag.to_field()); + g.alloc_const(cs, lit_z_ptr.hash); } - SlotType::Hash4 => { - hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c8())? + Op::Null(_, tag) => { + g.alloc_const(cs, tag.to_field()); + g.alloc_const(cs, F::ZERO); } - SlotType::Commitment => { - hash_poseidon(cs, preallocated_preimg, store.poseidon_cache.constants.c3())? + Op::EqTag(..) + | Op::EqVal(..) + | Op::Add(..) + | Op::Sub(..) + | Op::Mul(..) + | Op::Lt(..) + | Op::Trunc(..) + | Op::DivRem64(..) => { + g.alloc_const(cs, Tag::Expr(Num).to_field()); } - SlotType::LessThan => { - let a_num = &preallocated_preimg[0]; - let b_num = &preallocated_preimg[1]; - let diff = sub( - &mut cs.namespace(|| format!("sub for slot {slot}")), - a_num, - b_num, - )?; - let diff_is_negative = allocate_is_negative( - &mut cs.namespace(|| format!("is_negative for slot {slot}")), - &diff, - )?; - boolean_to_num( - &mut cs.namespace(|| format!("boolean_to_num for slot {slot}")), - &diff_is_negative, - )? + Op::Div(..) => { + g.alloc_const(cs, Tag::Expr(Num).to_field()); + g.alloc_const(cs, F::ONE); + } + Op::Hide(..) | Op::Open(..) => { + g.alloc_const(cs, Tag::Expr(Num).to_field()); + g.alloc_const(cs, Tag::Expr(Comm).to_field()); + } + _ => (), + } + } + match &self.ctrl { + Ctrl::IfEq(.., a, b) => { + a.alloc_globals(cs, store, g)?; + b.alloc_globals(cs, store, g)?; + } + Ctrl::MatchTag(_, cases, def) => { + for block in cases.values() { + block.alloc_globals(cs, store, g)?; + } + if let Some(def) = def { + def.alloc_globals(cs, store, g)?; } } - }; - Ok(preallocated_img) + Ctrl::MatchSymbol(_, cases, def) => { + g.alloc_const(cs, Tag::Expr(Sym).to_field()); + for block in cases.values() { + block.alloc_globals(cs, store, g)?; + } + if let Some(def) = def { + def.alloc_globals(cs, store, g)?; + } + } + Ctrl::Return(..) => (), + } + Ok(()) } +} - /// Allocates unconstrained slots - fn allocate_slots>( +impl Func { + /// Allocates an unconstrained pointer for each input of the frame + fn allocate_input>( + &self, cs: &mut CS, - preimg_data: &[Option>], - slot_type: SlotType, - num_slots: usize, - store: &mut Store, - ) -> Result>, AllocatedNum)>> { - assert!( - preimg_data.len() == num_slots, - "collected preimages not equal to the number of available slots" - ); - - let mut preallocations = Vec::with_capacity(num_slots); - - // We must perform the allocations for the slots containing data collected - // by the interpreter. The `None` cases must be filled with dummy values - for (slot_idx, maybe_preimg_data) in preimg_data.iter().enumerate() { - if let Some(preimg_data) = maybe_preimg_data { - let slot = Slot { - idx: slot_idx, - typ: slot_type, - }; - - // Allocate the preimage because the image depends on it - let mut preallocated_preimg = Vec::with_capacity(slot_type.preimg_size()); - - match preimg_data { - PreimageData::PtrVec(ptr_vec) => { - let mut component_idx = 0; - for ptr in ptr_vec { - let z_ptr = store.hash_ptr(ptr)?; - - // allocate pointer tag - preallocated_preimg.push(Self::allocate_preimg_component_for_slot( - cs, - &slot, - component_idx, - z_ptr.tag.to_field(), - )?); - - component_idx += 1; - - // allocate pointer hash - preallocated_preimg.push(Self::allocate_preimg_component_for_slot( - cs, - &slot, - component_idx, - z_ptr.hash, - )?); - - component_idx += 1; - } - } - PreimageData::FPtr(f, ptr) => { - let z_ptr = store.hash_ptr(ptr)?; - // allocate first component - preallocated_preimg - .push(Self::allocate_preimg_component_for_slot(cs, &slot, 0, *f)?); - // allocate second component - preallocated_preimg.push(Self::allocate_preimg_component_for_slot( - cs, - &slot, - 1, - z_ptr.tag.to_field(), - )?); - // allocate third component - preallocated_preimg.push(Self::allocate_preimg_component_for_slot( - cs, &slot, 2, z_ptr.hash, - )?); - } - PreimageData::FPair(a, b) => { - // allocate first component - preallocated_preimg - .push(Self::allocate_preimg_component_for_slot(cs, &slot, 0, *a)?); - - // allocate second component - preallocated_preimg - .push(Self::allocate_preimg_component_for_slot(cs, &slot, 1, *b)?); - } - } - - // Allocate the image by calling the arithmetic function according - // to the slot type - let preallocated_img = - Self::allocate_img_for_slot(cs, &slot, preallocated_preimg.clone(), store)?; - - preallocations.push((preallocated_preimg, preallocated_img)); - } else { - let slot = Slot { - idx: slot_idx, - typ: slot_type, - }; - let preallocated_preimg: Vec<_> = (0..slot_type.preimg_size()) - .map(|component_idx| { - Self::allocate_preimg_component_for_slot(cs, &slot, component_idx, F::ZERO) - }) - .collect::>()?; - - let preallocated_img = - Self::allocate_img_for_slot(cs, &slot, preallocated_preimg.clone(), store)?; - - preallocations.push((preallocated_preimg, preallocated_img)); - } + store: &Store, + frame: &Frame, + bound_allocations: &mut BoundAllocations, + ) -> Result<()> { + for (i, ptr) in frame.input.iter().enumerate() { + let param = &self.input_params[i]; + allocate_ptr(cs, &store.hash_ptr(ptr)?, param, bound_allocations)?; } + Ok(()) + } - Ok(preallocations) + pub fn alloc_globals>( + &self, + cs: &mut CS, + store: &Store, + ) -> Result, SynthesisError> { + let mut g = GlobalAllocator::default(); + self.body.alloc_globals(cs, store, &mut g)?; + Ok(g) } /// Create R1CS constraints for a LEM function given an evaluation frame. This @@ -330,24 +430,21 @@ impl Func { /// each slot and then, as we traverse the function, we add constraints to make /// sure that the witness satisfies the arithmetic equations for the /// corresponding slots. - pub fn synthesize>( + pub fn synthesize_frame>( &self, cs: &mut CS, - store: &mut Store, + store: &Store, frame: &Frame, - ) -> Result<()> { - let mut global_allocator = GlobalAllocator::default(); - let mut bound_allocations = BoundAllocations::new(); - - // Inputs are constrained by their usage inside the function body - self.allocate_input(cs, store, frame, &mut bound_allocations)?; + global_allocator: &GlobalAllocator, + bound_allocations: &mut BoundAllocations, + ) -> Result>> { // Outputs are constrained by the return statement. All functions return - let preallocated_outputs = Func::allocate_output(cs, store, frame, &mut bound_allocations)?; + let preallocated_outputs = allocate_output(cs, store, frame, bound_allocations)?; // Slots are constrained by their usage inside the function body. The ones // not used in throughout the concrete path are effectively unconstrained, // that's why they are filled with dummies - let preallocated_hash2_slots = Func::allocate_slots( + let preallocated_hash2_slots = allocate_slots( cs, &frame.preimages.hash2, SlotType::Hash2, @@ -355,7 +452,7 @@ impl Func { store, )?; - let preallocated_hash3_slots = Func::allocate_slots( + let preallocated_hash3_slots = allocate_slots( cs, &frame.preimages.hash3, SlotType::Hash3, @@ -363,7 +460,7 @@ impl Func { store, )?; - let preallocated_hash4_slots = Func::allocate_slots( + let preallocated_hash4_slots = allocate_slots( cs, &frame.preimages.hash4, SlotType::Hash4, @@ -371,7 +468,7 @@ impl Func { store, )?; - let preallocated_commitment_slots = Func::allocate_slots( + let preallocated_commitment_slots = allocate_slots( cs, &frame.preimages.commitment, SlotType::Commitment, @@ -379,7 +476,7 @@ impl Func { store, )?; - let preallocated_less_than_slots = Func::allocate_slots( + let preallocated_less_than_slots = allocate_slots( cs, &frame.preimages.less_than, SlotType::LessThan, @@ -388,8 +485,8 @@ impl Func { )?; struct Globals<'a, F: LurkField> { - store: &'a mut Store, - global_allocator: &'a mut GlobalAllocator, + store: &'a Store, + global_allocator: &'a GlobalAllocator, preallocated_hash2_slots: Vec<(Vec>, AllocatedNum)>, preallocated_hash3_slots: Vec<(Vec>, AllocatedNum)>, preallocated_hash4_slots: Vec<(Vec>, AllocatedNum)>, @@ -456,7 +553,9 @@ impl Func { // Allocate the image tag if it hasn't been allocated before, // create the full image pointer and add it to bound allocations - let img_tag = g.global_allocator.get_or_alloc_const(cs, $tag.to_field()); + let img_tag = g + .global_allocator + .get_allocated_const_cloned($tag.to_field())?; let img_hash = preallocated_img_hash.clone(); let img_ptr = AllocatedPtr::from_parts(img_tag, img_hash); bound_allocations.insert($img, img_ptr); @@ -512,17 +611,19 @@ impl Func { // Note that, because there's currently no way of deferring giving // a value to the allocated nums to be filled later, we must either // add the results of the call to the witness, or recompute them. + let dummy = Ptr::null(Tag::Expr(Nil)); let output_vals = if let Some(true) = not_dummy.get_value() { - g.call_outputs.pop_front().unwrap() + g.call_outputs + .pop_front() + .unwrap_or_else(|| (0..out.len()).map(|_| dummy).collect()) } else { - let dummy = Ptr::Leaf(Tag::Expr(Nil), F::ZERO); (0..out.len()).map(|_| dummy).collect() }; assert_eq!(output_vals.len(), out.len()); let mut output_ptrs = Vec::with_capacity(out.len()); for (ptr, var) in output_vals.iter().zip(out.iter()) { let zptr = &g.store.hash_ptr(ptr)?; - output_ptrs.push(Func::allocate_ptr(cs, zptr, var, bound_allocations)?); + output_ptrs.push(allocate_ptr(cs, zptr, var, bound_allocations)?); } // Get the pointers for the input, i.e. the arguments let args = bound_allocations.get_many_cloned(inp)?; @@ -563,23 +664,29 @@ impl Func { unhash_helper!(preimg, img, SlotType::Hash4); } Op::Null(tgt, tag) => { - let tag = g.global_allocator.get_or_alloc_const(cs, tag.to_field()); - let zero = g.global_allocator.get_or_alloc_const(cs, F::ZERO); + let tag = g + .global_allocator + .get_allocated_const_cloned(tag.to_field())?; + let zero = g.global_allocator.get_allocated_const_cloned(F::ZERO)?; let allocated_ptr = AllocatedPtr::from_parts(tag, zero); bound_allocations.insert(tgt.clone(), allocated_ptr); } Op::Lit(tgt, lit) => { - let lit_ptr = lit.to_ptr(g.store); + let lit_ptr = lit.to_ptr_cached(g.store); let lit_tag = lit_ptr.tag().to_field(); let lit_hash = g.store.hash_ptr(&lit_ptr)?.hash; - let allocated_tag = g.global_allocator.get_or_alloc_const(cs, lit_tag); - let allocated_hash = g.global_allocator.get_or_alloc_const(cs, lit_hash); + let allocated_tag = + g.global_allocator.get_allocated_const_cloned(lit_tag)?; + let allocated_hash = + g.global_allocator.get_allocated_const_cloned(lit_hash)?; let allocated_ptr = AllocatedPtr::from_parts(allocated_tag, allocated_hash); bound_allocations.insert(tgt.clone(), allocated_ptr); } Op::Cast(tgt, tag, src) => { let src = bound_allocations.get(src)?; - let tag = g.global_allocator.get_or_alloc_const(cs, tag.to_field()); + let tag = g + .global_allocator + .get_allocated_const_cloned(tag.to_field())?; let allocated_ptr = AllocatedPtr::from_parts(tag, src.hash().clone()); bound_allocations.insert(tgt.clone(), allocated_ptr); } @@ -592,7 +699,7 @@ impl Func { let c_num = boolean_to_num(&mut cs.namespace(|| "equal_tag.to_num"), &eq)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, c_num); bound_allocations.insert(tgt.clone(), c); } @@ -605,7 +712,7 @@ impl Func { let c_num = boolean_to_num(&mut cs.namespace(|| "equal_val.to_num"), &eq)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, c_num); bound_allocations.insert(tgt.clone(), c); } @@ -617,7 +724,7 @@ impl Func { let c_num = add(&mut cs.namespace(|| "add"), a_num, b_num)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, c_num); bound_allocations.insert(tgt.clone(), c); } @@ -629,7 +736,7 @@ impl Func { let c_num = sub(&mut cs.namespace(|| "sub"), a_num, b_num)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, c_num); bound_allocations.insert(tgt.clone(), c); } @@ -641,7 +748,7 @@ impl Func { let c_num = mul(&mut cs.namespace(|| "mul"), a_num, b_num)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, c_num); bound_allocations.insert(tgt.clone(), c); } @@ -652,12 +759,12 @@ impl Func { let b_num = b.hash(); let b_is_zero = &alloc_is_zero(&mut cs.namespace(|| "b_is_zero"), b_num)?; - let one = g.global_allocator.get_or_alloc_const(cs, F::ONE); + let one = g.global_allocator.get_allocated_const(F::ONE)?; let divisor = pick( &mut cs.namespace(|| "maybe-dummy divisor"), b_is_zero, - &one, + one, b_num, )?; @@ -665,7 +772,7 @@ impl Func { let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, quotient); bound_allocations.insert(tgt.clone(), c); } @@ -674,7 +781,7 @@ impl Func { let b = bound_allocations.get(b)?; let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let (preallocated_preimg, lt) = &g.preallocated_less_than_slots[next_slot.consume_less_than()]; for (i, n) in [a.hash(), b.hash()].into_iter().enumerate() { @@ -707,7 +814,7 @@ impl Func { enforce_pack(&mut cs.namespace(|| "enforce_trunc"), &trunc_bits, &trunc); let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let c = AllocatedPtr::from_parts(tag, trunc); bound_allocations.insert(tgt.clone(), c); } @@ -745,7 +852,7 @@ impl Func { ); let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let div_ptr = AllocatedPtr::from_parts(tag.clone(), div); let rem_ptr = AllocatedPtr::from_parts(tag, rem); bound_allocations.insert(tgt[0].clone(), div_ptr); @@ -757,7 +864,7 @@ impl Func { let pay = bound_allocations.get(pay)?; let sec_tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const(Tag::Expr(Num).to_field())?; let (preallocated_preimg, hash) = &g.preallocated_commitment_slots[next_slot.consume_commitment()]; implies_equal( @@ -766,7 +873,7 @@ impl Func { }), not_dummy, sec.tag(), - &sec_tag, + sec_tag, ); implies_equal( &mut cs.namespace(|| { @@ -794,7 +901,7 @@ impl Func { ); let tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Comm).to_field()); + .get_allocated_const_cloned(Tag::Expr(Comm).to_field())?; let allocated_ptr = AllocatedPtr::from_parts(tag, hash.clone()); bound_allocations.insert(tgt.clone(), allocated_ptr); } @@ -804,14 +911,14 @@ impl Func { &g.preallocated_commitment_slots[next_slot.consume_commitment()]; let comm_tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Comm).to_field()); + .get_allocated_const(Tag::Expr(Comm).to_field())?; implies_equal( &mut cs.namespace(|| { format!("implies equal for comm's tag (OP {:?})", &op) }), not_dummy, comm.tag(), - &comm_tag, + comm_tag, ); implies_equal( &mut cs.namespace(|| { @@ -823,7 +930,7 @@ impl Func { ); let sec_tag = g .global_allocator - .get_or_alloc_const(cs, Tag::Expr(Num).to_field()); + .get_allocated_const_cloned(Tag::Expr(Num).to_field())?; let allocated_sec_ptr = AllocatedPtr::from_parts(sec_tag, preallocated_preimg[0].clone()); let allocated_pay_ptr = AllocatedPtr::from_parts( @@ -836,6 +943,110 @@ impl Func { } } + let mut synthesize_match = |matched: &AllocatedNum, + cases: &[(F, &Block)], + def: &Option>, + bound_allocations: &mut VarMap>, + g: &mut Globals<'_, F>| + -> Result> { + // * One `Boolean` for each case + // * Maybe one `Boolean` for the default case + // * One `Boolean` for the negation of `not_dummy` + let selector_size = cases.len() + usize::from(def.is_some()) + 1; + let mut selector = Vec::with_capacity(selector_size); + let mut branch_slots = Vec::with_capacity(cases.len()); + for (i, (f, block)) in cases.iter().enumerate() { + // For each case, we compute `not_dummy_and_has_match: Boolean` + // and accumulate them on a `selector` vector + let not_dummy_and_has_match_bool = + not_dummy.get_value().and_then(|not_dummy| { + matched + .get_value() + .map(|matched_f| not_dummy && &matched_f == f) + }); + let not_dummy_and_has_match = Boolean::Is(AllocatedBit::alloc( + &mut cs.namespace(|| format!("{i}.allocated_bit")), + not_dummy_and_has_match_bool, + )?); + + // If `not_dummy_and_has_match` is true, then we enforce a match + implies_equal_const( + &mut cs.namespace(|| format!("{i}.implies_equal_const")), + ¬_dummy_and_has_match, + matched, + *f, + ); + + selector.push(not_dummy_and_has_match.clone()); + + let mut branch_slot = *next_slot; + recurse( + &mut cs.namespace(|| format!("{i}")), + block, + ¬_dummy_and_has_match, + &mut branch_slot, + bound_allocations, + preallocated_outputs, + g, + )?; + branch_slots.push(branch_slot); + } + + if let Some(def) = def { + // Compute `default: Boolean`, which tells whether the default case was chosen or not + let is_default_bool = selector.iter().fold(not_dummy.get_value(), |acc, b| { + // all the booleans in `selector` have to be false up to this point + // in order for the default case to be selected + acc.and_then(|acc| b.get_value().map(|b| acc && !b)) + }); + let is_default = Boolean::Is(AllocatedBit::alloc( + &mut cs.namespace(|| "_.allocated_bit"), + is_default_bool, + )?); + + for (i, (f, _)) in cases.iter().enumerate() { + // if the default path was taken, then there can be no tag in `cases` + // that equals the tag of the pointer being matched on + implies_unequal_const( + &mut cs.namespace(|| format!("{i}.implies_unequal_const")), + &is_default, + matched, + *f, + )?; + } + + recurse( + &mut cs.namespace(|| "_"), + def, + &is_default, + next_slot, + bound_allocations, + preallocated_outputs, + g, + )?; + + // Pushing `is_default` to `selector` to enforce summation = 1 + selector.push(is_default); + } + + // Now we need to enforce that exactly one path was taken. We do that by enforcing + // that the sum of the previously collected `Boolean`s is one. But, of course, this + // is irrelevant if we're on a virtual path and thus we use an implication gadget. + + // If `not_dummy` is false, then all booleans in `selector` are false up to this point. + // Thus we need to add a negation of `not_dummy` to make it satisfiable. If it's true, + // it will count as a 0 and will not influence the sum. + selector.push(not_dummy.not()); + + enforce_selector_with_premise( + &mut cs.namespace(|| "enforce_selector_with_premise"), + not_dummy, + &selector, + ); + + Ok(branch_slots) + }; + match &block.ctrl { Ctrl::Return(return_vars) => { for (i, return_var) in return_vars.iter().enumerate() { @@ -917,179 +1128,54 @@ impl Func { Ok(()) } Ctrl::MatchTag(match_var, cases, def) => { - let match_tag = bound_allocations.get(match_var)?.tag().clone(); - let mut selector = Vec::with_capacity(cases.len() + 2); - let mut branch_slots = Vec::with_capacity(cases.len()); - for (tag, block) in cases { - let is_eq = not_dummy.get_value().and_then(|not_dummy| { - match_tag - .get_value() - .map(|val| not_dummy && val == tag.to_field::()) - }); - - let has_match = Boolean::Is(AllocatedBit::alloc( - &mut cs.namespace(|| format!("{tag}.allocated_bit")), - is_eq, - )?); - implies_equal_const( - &mut cs.namespace(|| format!("implies equal for {match_var}'s {tag}")), - &has_match, - &match_tag, - tag.to_field(), - ); - - selector.push(has_match.clone()); - - let mut branch_slot = *next_slot; - recurse( - &mut cs.namespace(|| format!("{}", tag)), - block, - &has_match, - &mut branch_slot, - bound_allocations, - preallocated_outputs, - g, - )?; - branch_slots.push(branch_slot); - } - - match def { - Some(def) => { - let default = selector.iter().fold(not_dummy.get_value(), |acc, b| { - acc.and_then(|acc| b.get_value().map(|b| acc && !b)) - }); - let has_match = Boolean::Is(AllocatedBit::alloc( - &mut cs.namespace(|| "_.allocated_bit"), - default, - )?); - for (tag, _) in cases { - implies_unequal_const( - &mut cs.namespace(|| format!("{tag} implies_unequal")), - &has_match, - &match_tag, - tag.to_field(), - )?; - } - - selector.push(has_match.clone()); - - recurse( - &mut cs.namespace(|| "_"), - def, - &has_match, - next_slot, - bound_allocations, - preallocated_outputs, - g, - )?; - } - None => (), - } + let matched = bound_allocations.get(match_var)?.tag().clone(); + let cases_vec = cases + .iter() + .map(|(tag, block)| (tag.to_field::(), block)) + .collect::>(); + let branch_slots = + synthesize_match(&matched, &cases_vec, def, bound_allocations, g)?; // The number of slots the match used is the max number of slots of each branch - *next_slot = branch_slots - .into_iter() - .fold(*next_slot, |acc, branch_slot| acc.max(branch_slot)); - - // Now we need to enforce that at exactly one path was taken. We do that by enforcing - // that the sum of the previously collected `Boolean`s is one. But, of course, this - // irrelevant if we're on a virtual path and thus we use an implication gadget. - selector.push(not_dummy.not()); - enforce_selector_with_premise( - &mut cs.namespace(|| "enforce_selector_with_premise"), - not_dummy, - &selector, - ); + *next_slot = next_slot.fold_max(branch_slots); Ok(()) } - Ctrl::MatchVal(match_var, cases, def) => { - let match_lit = bound_allocations.get(match_var)?.hash().clone(); - let mut selector = Vec::with_capacity(cases.len() + 2); - let mut branch_slots = Vec::with_capacity(cases.len()); - for (i, (lit, block)) in cases.iter().enumerate() { - let lit_ptr = lit.to_ptr(g.store); - let lit_hash = g.store.hash_ptr(&lit_ptr)?.hash; - let is_eq = not_dummy.get_value().and_then(|not_dummy| { - match_lit - .get_value() - .map(|val| not_dummy && val == lit_hash) - }); - - let has_match = Boolean::Is(AllocatedBit::alloc( - &mut cs.namespace(|| format!("{i}.allocated_bit")), - is_eq, - )?); - implies_equal_const( - &mut cs.namespace(|| format!("implies equal for {match_var} ({i})")), - &has_match, - &match_lit, - lit_hash, - ); - - selector.push(has_match.clone()); - - let mut branch_slot = *next_slot; - recurse( - &mut cs.namespace(|| format!("{i}.case")), - block, - &has_match, - &mut branch_slot, - bound_allocations, - preallocated_outputs, - g, - )?; - branch_slots.push(branch_slot); + Ctrl::MatchSymbol(match_var, cases, def) => { + let match_var_ptr = bound_allocations.get(match_var)?.clone(); + + let mut cases_vec = Vec::with_capacity(cases.len()); + for (sym, block) in cases { + let sym_ptr = g + .store + .interned_symbol(sym) + .expect("symbol must have been interned"); + let sym_hash = g.store.hash_ptr(sym_ptr)?.hash; + cases_vec.push((sym_hash, block)); } - match def { - Some(def) => { - let default = selector.iter().fold(not_dummy.get_value(), |acc, b| { - acc.and_then(|acc| b.get_value().map(|b| acc && !b)) - }); - let has_match = Boolean::Is(AllocatedBit::alloc( - &mut cs.namespace(|| "_.allocated_bit"), - default, - )?); - for (i, (lit, _)) in cases.iter().enumerate() { - let lit_ptr = lit.to_ptr(g.store); - let lit_hash = g.store.hash_ptr(&lit_ptr)?.hash; - implies_unequal_const( - &mut cs.namespace(|| format!("{i} implies_unequal")), - &has_match, - &match_lit, - lit_hash, - )?; - } + let branch_slots = synthesize_match( + match_var_ptr.hash(), + &cases_vec, + def, + bound_allocations, + g, + )?; - selector.push(has_match.clone()); - - recurse( - &mut cs.namespace(|| "_"), - def, - &has_match, - next_slot, - bound_allocations, - preallocated_outputs, - g, - )?; - } - None => (), - } + // Now we enforce `match_var`'s tag - // The number of slots the match used is the max number of slots of each branch - *next_slot = branch_slots - .into_iter() - .fold(*next_slot, |acc, branch_slot| acc.max(branch_slot)); + let sym_tag = g + .global_allocator + .get_allocated_const(Tag::Expr(Sym).to_field())?; - // Now we need to enforce that at exactly one path was taken. We do that by enforcing - // that the sum of the previously collected `Boolean`s is one. But, of course, this - // irrelevant if we're on a virtual path and thus we use an implication gadget. - selector.push(not_dummy.not()); - enforce_selector_with_premise( - &mut cs.namespace(|| "enforce_selector_with_premise"), + implies_equal( + &mut cs.namespace(|| format!("implies equal for {match_var}'s tag (Sym)")), not_dummy, - &selector, + match_var_ptr.tag(), + sym_tag, ); + + // The number of slots the match used is the max number of slots of each branch + *next_slot = next_slot.fold_max(branch_slots); Ok(()) } } @@ -1101,11 +1187,11 @@ impl Func { &self.body, &Boolean::Constant(true), &mut SlotsCounter::default(), - &mut bound_allocations, + bound_allocations, &preallocated_outputs, &mut Globals { store, - global_allocator: &mut global_allocator, + global_allocator, preallocated_hash2_slots, preallocated_hash3_slots, preallocated_hash4_slots, @@ -1114,17 +1200,32 @@ impl Func { call_outputs, call_count: 0, }, - ) + )?; + Ok(preallocated_outputs) + } + + /// Helper API for tests + pub fn synthesize_frame_aux>( + &self, + cs: &mut CS, + store: &Store, + frame: &Frame, + ) -> Result<()> { + let bound_allocations = &mut BoundAllocations::new(); + let global_allocator = self.alloc_globals(cs, store)?; + self.allocate_input(cs, store, frame, bound_allocations)?; + self.synthesize_frame(cs, store, frame, &global_allocator, bound_allocations)?; + Ok(()) } /// Computes the number of constraints that `synthesize` should create. It's /// also an explicit way to document and attest how the number of constraints /// grow. - pub fn num_constraints(&self, store: &mut Store) -> usize { + pub fn num_constraints(&self, store: &Store) -> usize { fn recurse( block: &Block, globals: &mut HashSet>, - store: &mut Store, + store: &Store, ) -> usize { let mut num_constraints = 0; for op in &block.ops { @@ -1138,36 +1239,37 @@ impl Func { globals.insert(FWrap(F::ZERO)); } Op::Lit(_, lit) => { - let lit_ptr = lit.to_ptr(store); - let lit_hash = store.hash_ptr(&lit_ptr).unwrap().hash; - globals.insert(FWrap(Tag::Expr(Sym).to_field())); - globals.insert(FWrap(lit_hash)); + let lit_ptr = lit.to_ptr_cached(store); + let lit_z_ptr = store.hash_ptr(&lit_ptr).unwrap(); + globals.insert(FWrap(lit_z_ptr.tag.to_field())); + globals.insert(FWrap(lit_z_ptr.hash)); } - Op::Cast(_tgt, tag, _src) => { + Op::Cast(_, tag, _) => { globals.insert(FWrap(tag.to_field())); } - Op::EqTag(_, _, _) | Op::EqVal(_, _, _) => { + Op::EqTag(..) | Op::EqVal(..) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 5; } - Op::Add(_, _, _) | Op::Sub(_, _, _) | Op::Mul(_, _, _) => { + Op::Add(..) | Op::Sub(..) | Op::Mul(..) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 1; } - Op::Div(_, _, _) => { + Op::Div(..) => { + globals.insert(FWrap(Tag::Expr(Num).to_field())); globals.insert(FWrap(F::ONE)); num_constraints += 5; } - Op::Lt(_, _, _) => { + Op::Lt(..) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 2; } - Op::Trunc(_, _, _) => { + Op::Trunc(..) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); // bit decomposition + enforce_pack num_constraints += 389; } - Op::DivRem64(_, _, _) => { + Op::DivRem64(..) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); // three implies_u64, one sub and one linear num_constraints += 197; @@ -1224,28 +1326,31 @@ impl Func { for block in cases.values() { num_constraints += recurse(block, globals, store); } - match def { - Some(def) => { - // constraints for the boolean, the unequalities and the default case - num_constraints += 1 + cases.len(); - num_constraints += recurse(def, globals, store); - } - None => (), - }; + if let Some(def) = def { + // constraints for the boolean, the unequalities and the default case + num_constraints += 1 + cases.len(); + num_constraints += recurse(def, globals, store); + } num_constraints } - Ctrl::MatchVal(_, cases, def) => { + Ctrl::MatchSymbol(_, cases, def) => { + // First we enforce that the tag of the pointer being matched on + // is Sym + num_constraints += 1; + globals.insert(FWrap(Tag::Expr(Sym).to_field())); + // We allocate one boolean per case and constrain it once + // per case. Then we add 1 constraint to enforce only one + // case was selected num_constraints += 2 * cases.len() + 1; + for block in cases.values() { num_constraints += recurse(block, globals, store); } - match def { - Some(def) => { - num_constraints += 1 + cases.len(); - num_constraints += recurse(def, globals, store); - } - None => (), - }; + if let Some(def) = def { + // constraints for the boolean, the unequalities and the default case + num_constraints += 1 + cases.len(); + num_constraints += recurse(def, globals, store); + } num_constraints } } @@ -1256,8 +1361,8 @@ impl Func { + 337 * self.slot.hash3 + 388 * self.slot.hash4 + 265 * self.slot.commitment - + 391 * self.slot.less_than; - let num_constraints = recurse::(&self.body, globals, store); + + 1172 * self.slot.less_than; + let num_constraints = recurse(&self.body, globals, store); slot_constraints + num_constraints + globals.len() } } diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 33e468fd1b..4128215788 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -1,27 +1,89 @@ -use crate::func; +use anyhow::Result; +use once_cell::sync::OnceCell; -use super::Func; +use crate::{field::LurkField, func, state::initial_lurk_state, tag::ContTag::*}; + +use super::{interpreter::Frame, pointers::Ptr, store::Store, Func, Tag}; + +static EVAL_STEP: OnceCell = OnceCell::new(); /// Lurk's step function -#[allow(dead_code)] -pub(crate) fn eval_step() -> Func { - let reduce = reduce(); - let apply_cont = apply_cont(); - let make_thunk = make_thunk(); +pub fn eval_step() -> &'static Func { + EVAL_STEP.get_or_init(|| { + let reduce = reduce(); + let apply_cont = apply_cont(); + let make_thunk = make_thunk(); - func!(step(expr, env, cont): 3 => { - let (expr, env, cont, ctrl) = reduce(expr, env, cont); - let (expr, env, cont, ctrl) = apply_cont(expr, env, cont, ctrl); - let (expr, env, cont, _ctrl) = make_thunk(expr, env, cont, ctrl); - return (expr, env, cont) + func!(step(expr, env, cont): 3 => { + let (expr, env, cont, ctrl) = reduce(expr, env, cont); + let (expr, env, cont, ctrl) = apply_cont(expr, env, cont, ctrl); + let (expr, env, cont, _ctrl) = make_thunk(expr, env, cont, ctrl); + return (expr, env, cont) + }) }) } +pub fn evaluate_with_env_and_cont( + expr: Ptr, + env: Ptr, + cont: Ptr, + store: &mut Store, + limit: usize, +) -> Result<(Vec>, usize)> { + let stop_cond = |output: &[Ptr]| { + output[2] == Ptr::null(Tag::Cont(Terminal)) || output[2] == Ptr::null(Tag::Cont(Error)) + }; + let state = initial_lurk_state(); + let log_fmt = |i: usize, inp: &[Ptr], emit: &[Ptr], store: &Store| { + let mut out = format!( + "Frame: {i}\n\tExpr: {}\n\tEnv: {}\n\tCont: {}", + inp[0].fmt_to_string(store, state), + inp[1].fmt_to_string(store, state), + inp[2].fmt_to_string(store, state) + ); + if let Some(ptr) = emit.first() { + out.push_str(&format!("\n\tEmtd: {}", ptr.fmt_to_string(store, state))); + } + out + }; + + let input = &[expr, env, cont]; + let (frames, iterations, _) = + eval_step().call_until(input, store, stop_cond, limit, log_fmt)?; + Ok((frames, iterations)) +} + +pub fn evaluate( + expr: Ptr, + store: &mut Store, + limit: usize, +) -> Result<(Vec>, usize)> { + evaluate_with_env_and_cont( + expr, + store.intern_nil(), + Ptr::null(Tag::Cont(Outermost)), + store, + limit, + ) +} + +pub fn evaluate_simple( + expr: Ptr, + store: &mut Store, + limit: usize, +) -> Result<(Vec>, usize, Vec>)> { + let stop_cond = |output: &[Ptr]| { + output[2] == Ptr::null(Tag::Cont(Terminal)) || output[2] == Ptr::null(Tag::Cont(Error)) + }; + let input = vec![expr, store.intern_nil(), Ptr::null(Tag::Cont(Outermost))]; + eval_step().call_until_simple(input, store, stop_cond, limit) +} + fn safe_uncons() -> Func { func!(safe_uncons(xs): 2 => { let nil = Symbol("nil"); let nil = cast(nil, Expr::Nil); - let nilstr = Symbol(""); + let empty_str = String(""); match xs.tag { Expr::Nil => { return (nil, nil) @@ -31,8 +93,8 @@ fn safe_uncons() -> Func { return (car, cdr) } Expr::Str => { - if xs == nilstr { - return (nil, nilstr) + if xs == empty_str { + return (nil, empty_str) } let (car, cdr) = unhash2(xs); return (car, cdr) @@ -78,12 +140,12 @@ fn reduce() -> Func { return (expanded) }); let choose_let_cont = func!(choose_let_cont(head, var, env, expanded, cont): 1 => { - match head.val { - Symbol("let") => { + match symbol head { + "let" => { let cont: Cont::Let = hash4(var, env, expanded, cont); return (cont) } - Symbol("letrec") => { + "letrec" => { let cont: Cont::LetRec = hash4(var, env, expanded, cont); return (cont) } @@ -93,18 +155,8 @@ fn reduce() -> Func { let nil = Symbol("nil"); let nil = cast(nil, Expr::Nil); let t = Symbol("t"); - match head.val { - Symbol("car") - | Symbol("cdr") - | Symbol("commit") - | Symbol("num") - | Symbol("u64") - | Symbol("comm") - | Symbol("char") - | Symbol("open") - | Symbol("secret") - | Symbol("atom") - | Symbol("emit") => { + match symbol head { + "car", "cdr", "commit", "num", "u64", "comm", "char", "open", "secret", "atom", "emit" => { return (t) } }; @@ -115,32 +167,43 @@ fn reduce() -> Func { let nil = Symbol("nil"); let nil = cast(nil, Expr::Nil); let t = Symbol("t"); - match head.val { - Symbol("cons") - | Symbol("strcons") - | Symbol("hide") - | Symbol("+") - | Symbol("-") - | Symbol("*") - | Symbol("/") - | Symbol("%") - | Symbol("=") - | Symbol("eq") - | Symbol("<") - | Symbol(">") - | Symbol("<=") - | Symbol(">=") => { + match symbol head { + "cons", "strcons", "hide", "+", "-", "*", "/", "%", "=", "eq", "<", ">", "<=", ">=" => { return (t) } }; return (nil) }); + let make_call = func!(make_call(head, rest, env, cont): 4 => { + let ret: Ctrl::Return; + match rest.tag { + Expr::Nil => { + let cont: Cont::Call0 = hash2(env, cont); + return (head, env, cont, ret) + } + Expr::Cons => { + let (arg, more_args) = unhash2(rest); + match more_args.tag { + Expr::Nil => { + let cont: Cont::Call = hash3(arg, env, cont); + return (head, env, cont, ret) + } + }; + let nil = Symbol("nil"); + let nil = cast(nil, Expr::Nil); + let expanded_inner0: Expr::Cons = hash2(arg, nil); + let expanded_inner: Expr::Cons = hash2(head, expanded_inner0); + let expanded: Expr::Cons = hash2(expanded_inner, more_args); + return (expanded, env, cont, ret) + } + } + }); let is_potentially_fun = func!(is_potentially_fun(head): 1 => { let t = Symbol("t"); let nil = Symbol("nil"); let nil = cast(nil, Expr::Nil); match head.tag { - Expr::Fun | Expr::Cons | Expr::Sym | Expr::Thunk => { + Expr::Fun | Expr::Cons | Expr::Thunk => { return (t) } }; @@ -173,8 +236,8 @@ fn reduce() -> Func { return (thunk_expr, env, thunk_continuation, apply) } Expr::Sym => { - match expr.val { - Symbol("nil") | Symbol("t") => { + match symbol expr { + "nil", "t" => { return (expr, env, cont, apply) } }; @@ -234,70 +297,82 @@ fn reduce() -> Func { let cont: Cont::Lookup = hash2(env, cont); return (expr, env_to_use, cont, ret) } - } + }; + return (expr, env, err, errctrl) } Expr::Cons => { // No need for `safe_uncons` since the expression is already a `Cons` let (head, rest) = unhash2(expr); - match head.val { - Symbol("lambda") => { - let (args, body) = safe_uncons(rest); - let (arg, cdr_args) = extract_arg(args); + match rest.tag { + // rest's tag can only be Nil or Cons + Expr::Sym | Expr::Fun | Expr::Num | Expr::Thunk | Expr::Str + | Expr::Char | Expr::Comm | Expr::U64 | Expr::Key => { + return (expr, env, err, errctrl); + } + }; + match head.tag { + Expr::Sym => { + match symbol head { + "lambda" => { + let (args, body) = safe_uncons(rest); + let (arg, cdr_args) = extract_arg(args); - match arg.tag { - Expr::Sym => { - match cdr_args.tag { - Expr::Nil => { - let function: Expr::Fun = hash3(arg, body, env); + match arg.tag { + Expr::Sym => { + match cdr_args.tag { + Expr::Nil => { + let function: Expr::Fun = hash3(arg, body, env); + return (function, env, cont, apply) + } + }; + let inner: Expr::Cons = hash2(cdr_args, body); + let l: Expr::Cons = hash2(head, inner); + let inner_body: Expr::Cons = hash2(l, nil); + let function: Expr::Fun = hash3(arg, inner_body, env); return (function, env, cont, apply) } }; - let inner: Expr::Cons = hash2(cdr_args, body); - let lambda = Symbol("lambda"); - let l: Expr::Cons = hash2(lambda, inner); - let inner_body: Expr::Cons = hash2(l, nil); - let function: Expr::Fun = hash3(arg, inner_body, env); - return (function, env, cont, apply) + return (expr, env, err, errctrl) } - }; - return (expr, env, err, errctrl) - } - Symbol("quote") => { - let (quoted, end) = safe_uncons(rest); + "quote" => { + let (quoted, end) = safe_uncons(rest); - match end.tag { - Expr::Nil => { - return (quoted, env, cont, apply) - } - }; - return (expr, env, err, errctrl) - } - Symbol("let") | Symbol("letrec") => { - let (bindings, body) = safe_uncons(rest); - let (body1, rest_body) = safe_uncons(body); - // Only a single body form allowed for now. - match body.tag { - Expr::Nil => { + match end.tag { + Expr::Nil => { + return (quoted, env, cont, apply) + } + }; return (expr, env, err, errctrl) } - }; - match rest_body.tag { - Expr::Nil => { - match bindings.tag { + "let", "letrec" => { + let (bindings, body) = safe_uncons(rest); + let (body1, rest_body) = safe_uncons(body); + // Only a single body form allowed for now. + match body.tag { Expr::Nil => { - return (body1, env, cont, ret) + return (expr, env, err, errctrl) } }; - let (binding1, rest_bindings) = safe_uncons(bindings); - let (var, vals) = safe_uncons(binding1); - match var.tag { - Expr::Sym => { - let (val, end) = safe_uncons(vals); - match end.tag { + match rest_body.tag { + Expr::Nil => { + match bindings.tag { Expr::Nil => { - let (expanded) = expand_bindings(head, body, body1, rest_bindings); - let (cont) = choose_let_cont(head, var, env, expanded, cont); - return (val, env, cont, ret) + return (body1, env, cont, ret) + } + }; + let (binding1, rest_bindings) = safe_uncons(bindings); + let (var, vals) = safe_uncons(binding1); + match var.tag { + Expr::Sym => { + let (val, end) = safe_uncons(vals); + match end.tag { + Expr::Nil => { + let (expanded) = expand_bindings(head, body, body1, rest_bindings); + let (cont) = choose_let_cont(head, var, env, expanded, cont); + return (val, env, cont, ret) + } + }; + return (expr, env, err, errctrl) } }; return (expr, env, err, errctrl) @@ -305,112 +380,97 @@ fn reduce() -> Func { }; return (expr, env, err, errctrl) } - }; - return (expr, env, err, errctrl) - } - Symbol("begin") => { - let (arg1, more) = safe_uncons(rest); - match more.tag { - Expr::Nil => { + "begin" => { + let (arg1, more) = safe_uncons(rest); + match more.tag { + Expr::Nil => { + return (arg1, env, cont, ret) + } + }; + let cont: Cont::Binop = hash4(head, env, more, cont); return (arg1, env, cont, ret) } - }; - let cont: Cont::Binop = hash4(head, env, more, cont); - return (arg1, env, cont, ret) - } - Symbol("eval") => { - match rest.tag { - Expr::Nil => { - return (expr, env, err, errctrl) - } - }; - let (arg1, more) = safe_uncons(rest); - match more.tag { - Expr::Nil => { - let cont: Cont::Unop = hash2(head, cont); + "eval" => { + match rest.tag { + Expr::Nil => { + return (expr, env, err, errctrl) + } + }; + let (arg1, more) = safe_uncons(rest); + match more.tag { + Expr::Nil => { + let cont: Cont::Unop = hash2(head, cont); + return (arg1, env, cont, ret) + } + }; + let cont: Cont::Binop = hash4(head, env, more, cont); return (arg1, env, cont, ret) } - }; - let cont: Cont::Binop = hash4(head, env, more, cont); - return (arg1, env, cont, ret) - } - Symbol("if") => { - let (condition, more) = safe_uncons(rest); - match more.tag { - Expr::Nil => { - return (condition, env, err, errctrl) + "if" => { + let (condition, more) = safe_uncons(rest); + match more.tag { + Expr::Nil => { + return (expr, env, err, errctrl) + } + }; + let cont: Cont::If = hash2(more, cont); + return (condition, env, cont, ret) } - }; - let cont: Cont::If = hash2(more, cont); - return (condition, env, cont, ret) - } - Symbol("current-env") => { - match rest.tag { - Expr::Nil => { - return (env, env, cont, apply) + "current-env" => { + match rest.tag { + Expr::Nil => { + return (env, env, cont, apply) + } + }; + return (expr, env, err, errctrl) } }; - return (expr, env, err, errctrl) - } - }; - // unops - let (op) = is_unop(head); - if op == t { - match rest.tag { - Expr::Nil => { + // unops + let (op) = is_unop(head); + if op == t { + match rest.tag { + Expr::Nil => { + return (expr, env, err, errctrl) + } + }; + let (arg1, end) = unhash2(rest); + match end.tag { + Expr::Nil => { + let cont: Cont::Unop = hash2(head, cont); + return (arg1, env, cont, ret) + } + }; return (expr, env, err, errctrl) } - }; - let (arg1, end) = unhash2(rest); - match end.tag { - Expr::Nil => { - let cont: Cont::Unop = hash2(head, cont); + // binops + let (op) = is_binop(head); + if op == t { + match rest.tag { + Expr::Nil => { + return (expr, env, err, errctrl) + } + }; + let (arg1, more) = unhash2(rest); + match more.tag { + Expr::Nil => { + return (expr, env, err, errctrl) + } + }; + let cont: Cont::Binop = hash4(head, env, more, cont); return (arg1, env, cont, ret) } - }; - return (expr, env, err, errctrl) - } - // binops - let (op) = is_binop(head); - if op == t { - match rest.tag { - Expr::Nil => { - return (expr, env, err, errctrl) - } - }; - let (arg1, more) = unhash2(rest); - match more.tag { - Expr::Nil => { - return (expr, env, err, errctrl) - } - }; - let cont: Cont::Binop = hash4(head, env, more, cont); - return (arg1, env, cont, ret) - } + // just call assuming that the symbol is bound to a function + let (fun, env, cont, ret) = make_call(head, rest, env, cont); + return (fun, env, cont, ret); + } + }; // TODO coprocessors (could it be simply a `func`?) // head -> fn, rest -> args let (potentially_fun) = is_potentially_fun(head); if potentially_fun == t { - match rest.tag { - Expr::Nil => { - let cont: Cont::Call0 = hash2(env, cont); - return (head, env, cont, ret) - } - Expr::Cons => { - let (arg, more_args) = unhash2(rest); - match more_args.tag { - Expr::Nil => { - let cont: Cont::Call = hash3(arg, env, cont); - return (head, env, cont, ret) - } - }; - let expanded_inner0: Expr::Cons = hash2(arg, nil); - let expanded_inner: Expr::Cons = hash2(head, expanded_inner0); - let expanded: Expr::Cons = hash2(expanded_inner, more_args); - return (expanded, env, cont, ret) - } - } + let (fun, env, cont, ret) = make_call(head, rest, env, cont); + return (fun, env, cont, ret); } return (expr, env, err, errctrl) } @@ -452,39 +512,40 @@ fn apply_cont() -> Func { } } }); - // Returns 2 if both arguments are U64, 1 if the arguments are some kind of number (either U64 or Num), - // and 0 otherwise + // Returns 0u64 if both arguments are U64, 0 (num) if the arguments are some kind of number (either U64 or Num), + // and nil otherwise let args_num_type = func!(args_num_type(arg1, arg2): 1 => { - let other = Num(0); + let nil = Symbol("nil"); + let nil = cast(nil, Expr::Nil); match arg1.tag { Expr::Num => { match arg2.tag { Expr::Num => { - let ret = Num(1); + let ret: Expr::Num; return (ret) } Expr::U64 => { - let ret = Num(1); + let ret: Expr::Num; return (ret) } }; - return (other) + return (nil) } Expr::U64 => { match arg2.tag { Expr::Num => { - let ret = Num(1); + let ret: Expr::Num; return (ret) } Expr::U64 => { - let ret = Num(2); + let ret: Expr::U64; return (ret) } }; - return (other) + return (nil) } }; - return (other) + return (nil) }); func!(apply_cont(result, env, cont, ctrl): 4 => { // Useful constants @@ -497,6 +558,7 @@ fn apply_cont() -> Func { let t = Symbol("t"); let zero = Num(0); let size_u64 = Num(18446744073709551616); + let empty_str = String(""); match ctrl.tag { Ctrl::ApplyContinuation => { @@ -509,7 +571,6 @@ fn apply_cont() -> Func { return (result, env, cont, ret) } Cont::Emit => { - emit(result); // TODO Does this make sense? let (cont, _rest) = unhash2(cont); return (result, env, cont, makethunk) @@ -519,8 +580,8 @@ fn apply_cont() -> Func { match result.tag { Expr::Fun => { let (arg, body, closed_env) = unhash3(result); - match arg.val { - Symbol("dummy") => { + match symbol arg { + "dummy" => { match body.tag { Expr::Nil => { return (result, env, err, errctrl) @@ -556,8 +617,8 @@ fn apply_cont() -> Func { match function.tag { Expr::Fun => { let (arg, body, closed_env) = unhash3(function); - match arg.val { - Symbol("dummy") => { + match symbol arg { + "dummy" => { return (result, env, err, errctrl) } }; @@ -595,16 +656,48 @@ fn apply_cont() -> Func { } Cont::Unop => { let (operator, continuation) = unhash2(cont); - match operator.val { - Symbol("car") => { - let (car, _cdr) = safe_uncons(result); - return (car, env, continuation, makethunk) + match symbol operator { + "car" => { + // Almost like safe_uncons, except it returns + // an error in case it can't unhash it + match result.tag { + Expr::Nil => { + return (nil, env, continuation, makethunk) + } + Expr::Cons => { + let (car, _cdr) = unhash2(result); + return (car, env, continuation, makethunk) + } + Expr::Str => { + if result == empty_str { + return (nil, env, continuation, makethunk) + } + let (car, _cdr) = unhash2(result); + return (car, env, continuation, makethunk) + } + }; + return(result, env, err, errctrl) } - Symbol("cdr") => { - let (_car, cdr) = safe_uncons(result); - return (cdr, env, continuation, makethunk) + "cdr" => { + match result.tag { + Expr::Nil => { + return (nil, env, continuation, makethunk) + } + Expr::Cons => { + let (_car, cdr) = unhash2(result); + return (cdr, env, continuation, makethunk) + } + Expr::Str => { + if result == empty_str { + return (empty_str, env, continuation, makethunk) + } + let (_car, cdr) = unhash2(result); + return (cdr, env, continuation, makethunk) + } + }; + return(result, env, err, errctrl) } - Symbol("atom") => { + "atom" => { match result.tag { Expr::Cons => { return (nil, env, continuation, makethunk) @@ -612,24 +705,45 @@ fn apply_cont() -> Func { }; return (t, env, continuation, makethunk) } - Symbol("emit") => { + "emit" => { // TODO Does this make sense? - let emit: Cont::Emit = hash2(cont, nil); + emit(result); + let emit: Cont::Emit = hash2(continuation, nil); return (result, env, emit, makethunk) } - Symbol("open") => { - let (_secret, payload) = open(result); - return(payload, env, continuation, makethunk) + "open" => { + match result.tag { + Expr::Num => { + let result = cast(result, Expr::Comm); + let (_secret, payload) = open(result); + return(payload, env, continuation, makethunk) + } + Expr::Comm => { + let (_secret, payload) = open(result); + return(payload, env, continuation, makethunk) + } + }; + return(result, env, err, errctrl) } - Symbol("secret") => { - let (secret, _payload) = open(result); - return(secret, env, continuation, makethunk) + "secret" => { + match result.tag { + Expr::Num => { + let result = cast(result, Expr::Comm); + let (secret, _payload) = open(result); + return(secret, env, continuation, makethunk) + } + Expr::Comm => { + let (secret, _payload) = open(result); + return(secret, env, continuation, makethunk) + } + }; + return(result, env, err, errctrl) } - Symbol("commit") => { + "commit" => { let comm = hide(zero, result); return(comm, env, continuation, makethunk) } - Symbol("num") => { + "num" => { match result.tag { Expr::Num | Expr::Comm | Expr::Char | Expr::U64 => { let cast = cast(result, Expr::Num); @@ -638,7 +752,7 @@ fn apply_cont() -> Func { }; return(result, env, err, errctrl) } - Symbol("u64") => { + "u64" => { match result.tag { Expr::Num => { // The limit is 2**64 - 1 @@ -652,7 +766,7 @@ fn apply_cont() -> Func { }; return(result, env, err, errctrl) } - Symbol("comm") => { + "comm" => { match result.tag { Expr::Num | Expr::Comm => { let cast = cast(result, Expr::Comm); @@ -661,7 +775,7 @@ fn apply_cont() -> Func { }; return(result, env, err, errctrl) } - Symbol("char") => { + "char" => { match result.tag { Expr::Num => { // The limit is 2**32 - 1 @@ -675,7 +789,7 @@ fn apply_cont() -> Func { }; return(result, env, err, errctrl) } - Symbol("eval") => { + "eval" => { return(result, nil, continuation, ret) } }; @@ -684,15 +798,14 @@ fn apply_cont() -> Func { Cont::Binop => { let (operator, saved_env, unevaled_args, continuation) = unhash4(cont); let (arg2, rest) = safe_uncons(unevaled_args); - match operator.val { - Symbol("begin") => { + match symbol operator { + "begin" => { match rest.tag { Expr::Nil => { return (arg2, saved_env, continuation, ret) } }; - let begin = Symbol("begin"); - let begin_again: Expr::Cons = hash2(begin, unevaled_args); + let begin_again: Expr::Cons = hash2(operator, unevaled_args); return (begin_again, saved_env, continuation, ctrl) } }; @@ -707,20 +820,20 @@ fn apply_cont() -> Func { Cont::Binop2 => { let (operator, evaled_arg, continuation) = unhash3(cont); let (args_num_type) = args_num_type(evaled_arg, result); - match operator.val { - Symbol("eval") => { + match symbol operator { + "eval" => { return (evaled_arg, result, continuation, ret) } - Symbol("cons") => { + "cons" => { let val: Expr::Cons = hash2(evaled_arg, result); return (val, env, continuation, makethunk) } - Symbol("strcons") => { + "strcons" => { match evaled_arg.tag { Expr::Char => { - match evaled_arg.tag { + match result.tag { Expr::Str => { - let val: Expr::Cons = hash2(evaled_arg, result); + let val: Expr::Str = hash2(evaled_arg, result); return (val, env, continuation, makethunk) } }; @@ -729,89 +842,81 @@ fn apply_cont() -> Func { }; return (result, env, err, errctrl) } - Symbol("hide") => { - let num = cast(evaled_arg, Expr::Num); - let hidden = hide(num, result); - return(hidden, env, continuation, makethunk) + "hide" => { + match evaled_arg.tag { + Expr::Num => { + let hidden = hide(evaled_arg, result); + return(hidden, env, continuation, makethunk) + } + }; + return (result, env, err, errctrl) } - Symbol("eq") => { + "eq" => { let eq_tag = eq_tag(evaled_arg, result); let eq_val = eq_val(evaled_arg, result); let eq = mul(eq_tag, eq_val); - match eq.val { - Num(0) => { - return (nil, env, continuation, makethunk) - } - Num(1) => { - return (t, env, continuation, makethunk) - } + if eq == zero { + return (nil, env, continuation, makethunk) } + return (t, env, continuation, makethunk) } - Symbol("+") => { - match args_num_type.val { - Num(0) => { + "+" => { + match args_num_type.tag { + Expr::Nil => { return (result, env, err, errctrl) } - Num(1) => { + Expr::Num => { let val = add(evaled_arg, result); return (val, env, continuation, makethunk) } - Num(2) => { + Expr::U64 => { let val = add(evaled_arg, result); let not_overflow = lt(val, size_u64); - match not_overflow.val { - Num(0) => { - let val = sub(val, size_u64); - let val = cast(val, Expr::U64); - return (val, env, continuation, makethunk) - } - Num(1) => { - let val = cast(val, Expr::U64); - return (val, env, continuation, makethunk) - } + if not_overflow == zero { + let val = sub(val, size_u64); + let val = cast(val, Expr::U64); + return (val, env, continuation, makethunk) } + let val = cast(val, Expr::U64); + return (val, env, continuation, makethunk) } } } - Symbol("-") => { - match args_num_type.val { - Num(0) => { + "-" => { + match args_num_type.tag { + Expr::Nil => { return (result, env, err, errctrl) } - Num(1) => { + Expr::Num => { let val = sub(evaled_arg, result); return (val, env, continuation, makethunk) } - Num(2) => { + Expr::U64 => { // Subtraction in U64 is almost the same as subtraction // in the field. If the difference is negative, we need // to add 2^64 to get back to U64 domain. let val = sub(evaled_arg, result); let is_neg = lt(val, zero); - match is_neg.val { - Num(0) => { - let val = add(val, size_u64); - let val = cast(val, Expr::U64); - return (val, env, continuation, makethunk) - } - Num(1) => { - let val = cast(val, Expr::U64); - return (val, env, continuation, makethunk) - } + if is_neg == zero { + let val = cast(val, Expr::U64); + return (val, env, continuation, makethunk) } + let val = add(val, size_u64); + let val = cast(val, Expr::U64); + return (val, env, continuation, makethunk) } } } - Symbol("*") => { - match args_num_type.val { - Num(0) => { + "*" => { + match args_num_type.tag { + Expr::Nil => { return (result, env, err, errctrl) } - Num(1) => { + Expr::Num => { let val = mul(evaled_arg, result); return (val, env, continuation, makethunk) } - Num(2) => { + Expr::U64 => { let val = mul(evaled_arg, result); // The limit is 2**64 - 1 let trunc = truncate(val, 64); @@ -820,86 +925,79 @@ fn apply_cont() -> Func { } } } - Symbol("/") => { - match args_num_type.val { - Num(0) => { - return (result, env, err, errctrl) - } - Num(1) => { - let val = div(evaled_arg, result); - return (val, env, continuation, makethunk) - } - Num(2) => { - let (div, _rem) = div_rem64(evaled_arg, result); - let div = cast(div, Expr::U64); - return (div, env, continuation, makethunk) + "/" => { + let is_z = eq_val(result, zero); + if is_z == zero { + match args_num_type.tag { + Expr::Nil => { + return (result, env, err, errctrl) + } + Expr::Num => { + let val = div(evaled_arg, result); + return (val, env, continuation, makethunk) + } + Expr::U64 => { + let (div, _rem) = div_rem64(evaled_arg, result); + let div = cast(div, Expr::U64); + return (div, env, continuation, makethunk) + } } } + return (result, env, err, errctrl) } - Symbol("%") => { - match args_num_type.val { - Num(2) => { - let (_div, rem) = div_rem64(evaled_arg, result); - let rem = cast(rem, Expr::U64); - return (rem, env, continuation, makethunk) - } - }; + "%" => { + let is_z = eq_val(result, zero); + if is_z == zero { + match args_num_type.tag { + Expr::U64 => { + let (_div, rem) = div_rem64(evaled_arg, result); + let rem = cast(rem, Expr::U64); + return (rem, env, continuation, makethunk) + } + }; + return (result, env, err, errctrl) + } return (result, env, err, errctrl) } - Symbol("=") => { - match args_num_type.val { - Num(0) => { + "=" => { + match args_num_type.tag { + Expr::Nil => { return (result, env, err, errctrl) } }; - if evaled_arg == result { - return (t, env, continuation, makethunk) + let eq = eq_val(evaled_arg, result); + if eq == zero { + return (nil, env, continuation, makethunk) } - return (nil, env, continuation, makethunk) + return (t, env, continuation, makethunk) } - Symbol("<") => { + "<" => { let val = lt(evaled_arg, result); - match val.val { - Num(0) => { - return (nil, env, continuation, makethunk) - } - Num(1) => { - return (t, env, continuation, makethunk) - } + if val == zero { + return (nil, env, continuation, makethunk) } + return (t, env, continuation, makethunk) } - Symbol(">") => { + ">" => { let val = lt(result, evaled_arg); - match val.val { - Num(0) => { - return (nil, env, continuation, makethunk) - } - Num(1) => { - return (t, env, continuation, makethunk) - } + if val == zero { + return (nil, env, continuation, makethunk) } + return (t, env, continuation, makethunk) } - Symbol("<=") => { + "<=" => { let val = lt(result, evaled_arg); - match val.val { - Num(0) => { - return (t, env, continuation, makethunk) - } - Num(1) => { - return (nil, env, continuation, makethunk) - } + if val == zero { + return (t, env, continuation, makethunk) } + return (nil, env, continuation, makethunk) } - Symbol(">=") => { + ">=" => { let val = lt(evaled_arg, result); - match val.val { - Num(0) => { - return (t, env, continuation, makethunk) - } - Num(1) => { - return (nil, env, continuation, makethunk) - } + if val == zero { + return (t, env, continuation, makethunk) } + return (nil, env, continuation, makethunk) } }; return (result, env, err, errctrl) @@ -918,7 +1016,7 @@ fn apply_cont() -> Func { return (arg1, env, continuation, ret) } }; - return (result, env, err, errctrl) + return (arg1, env, err, errctrl) } Cont::Lookup => { let (saved_env, continuation) = unhash2(cont); @@ -964,15 +1062,16 @@ fn make_thunk() -> Func { #[cfg(test)] mod tests { use super::*; - use crate::lem::{pointers::Ptr, slot::SlotsCounter, store::Store, Tag}; - use crate::state::{lurk_sym, State}; - use crate::tag::ContTag::*; - use bellpepper_core::{test_cs::TestConstraintSystem, Comparable}; + use crate::{ + lem::{pointers::Ptr, slot::SlotsCounter, store::Store, Tag}, + state::State, + }; + use bellpepper_core::{test_cs::TestConstraintSystem, Comparable, Delta}; use blstrs::Scalar as Fr; const NUM_INPUTS: usize = 1; - const NUM_AUX: usize = 9885; - const NUM_CONSTRAINTS: usize = 12178; + const NUM_AUX: usize = 10744; + const NUM_CONSTRAINTS: usize = 13299; const NUM_SLOTS: SlotsCounter = SlotsCounter { hash2: 16, hash3: 4, @@ -981,9 +1080,11 @@ mod tests { less_than: 1, }; - fn test_eval_and_constrain_aux(store: &mut Store, pairs: Vec<(Ptr, Ptr)>) { - let eval_step = eval_step(); - + fn test_eval_and_constrain_aux( + eval_step: &Func, + store: &mut Store, + pairs: Vec<(Ptr, Ptr)>, + ) { assert_eq!(eval_step.slot, NUM_SLOTS); let computed_num_constraints = eval_step.num_constraints::(store); @@ -994,20 +1095,29 @@ mod tests { let outermost = Ptr::null(Tag::Cont(Outermost)); let terminal = Ptr::null(Tag::Cont(Terminal)); let error = Ptr::null(Tag::Cont(Error)); - let nil = store.intern_symbol(&lurk_sym("nil")); + let nil = store.intern_nil(); // Stop condition: the continuation is either terminal or error let stop_cond = |output: &[Ptr]| output[2] == terminal || output[2] == error; + let log_fmt = |_: usize, _: &[Ptr], _: &[Ptr], _: &Store| String::default(); + + let limit = 10000; + + let mut cs_prev = None; for (expr_in, expr_out) in pairs { - let input = vec![expr_in, nil, outermost]; - let (frames, paths) = eval_step.call_until(input, store, stop_cond).unwrap(); + let input = [expr_in, nil, outermost]; + let (frames, _, paths) = eval_step + .call_until(&input, store, stop_cond, limit, log_fmt) + .unwrap(); let last_frame = frames.last().expect("eval should add at least one frame"); assert_eq!(last_frame.output[0], expr_out); store.hydrate_z_cache(); for frame in frames.iter() { let mut cs = TestConstraintSystem::::new(); - eval_step.synthesize(&mut cs, store, frame).unwrap(); + eval_step + .synthesize_frame_aux(&mut cs, store, frame) + .unwrap(); assert!(cs.is_satisfied()); assert_eq!(cs.num_inputs(), NUM_INPUTS); assert_eq!(cs.aux().len(), NUM_AUX); @@ -1015,7 +1125,12 @@ mod tests { let num_constraints = cs.num_constraints(); assert_eq!(computed_num_constraints, num_constraints); assert_eq!(num_constraints, NUM_CONSTRAINTS); - // TODO: assert uniformity with `Delta` from bellperson + + if let Some(cs_prev) = cs_prev { + // Check for all input expresssions that all frames are uniform. + assert_eq!(cs.delta(&cs_prev, true), Delta::Equal); + } + cs_prev = Some(cs); } all_paths.extend(paths); } @@ -1072,9 +1187,9 @@ mod tests { (if (eq xs nil) 0 (+ (car xs) (sum (cdr xs))))))) - (sum (build 10)))", + (sum (build 4)))", ); - let fold_res = read("55"); + let fold_res = read("10"); vec![ (div, div_res), (rem, rem_res), @@ -1101,9 +1216,10 @@ mod tests { #[test] fn test_pairs() { - let mut store = Store::default(); - let pairs = expr_in_expr_out_pairs(&mut store); + let step_fn = eval_step(); + let store = &mut step_fn.init_store(); + let pairs = expr_in_expr_out_pairs(store); store.hydrate_z_cache(); - test_eval_and_constrain_aux(&mut store, pairs); + test_eval_and_constrain_aux(step_fn, store, pairs); } } diff --git a/src/lem/interpreter.rs b/src/lem/interpreter.rs index da573ccdcf..b3aa073f33 100644 --- a/src/lem/interpreter.rs +++ b/src/lem/interpreter.rs @@ -1,13 +1,9 @@ -use crate::field::{FWrap, LurkField}; -use crate::num::Num; use anyhow::{bail, Result}; use std::collections::VecDeque; -use super::{ - path::Path, pointers::Ptr, store::Store, var_map::VarMap, Block, Ctrl, Func, Lit, Op, Tag, -}; +use super::{path::Path, pointers::Ptr, store::Store, var_map::VarMap, Block, Ctrl, Func, Op, Tag}; -use crate::tag::ExprTag::*; +use crate::{field::LurkField, num::Num, state::initial_lurk_state, tag::ExprTag::*}; #[derive(Clone, Debug)] pub enum PreimageData { @@ -48,6 +44,24 @@ impl Preimages { call_outputs, } } + + pub fn blank(func: &Func) -> Preimages { + let slot = func.slot; + let hash2 = vec![None; slot.hash2]; + let hash3 = vec![None; slot.hash3]; + let hash4 = vec![None; slot.hash4]; + let commitment = vec![None; slot.commitment]; + let less_than = vec![None; slot.less_than]; + let call_outputs = VecDeque::new(); + Preimages { + hash2, + hash3, + hash4, + commitment, + less_than, + call_outputs, + } + } } /// A `Frame` carries the data that results from interpreting a LEM. That is, @@ -55,11 +69,26 @@ impl Preimages { /// running one iteration as a HashMap of variables to pointers. /// /// This information is used to generate the witness. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct Frame { pub input: Vec>, pub output: Vec>, pub preimages: Preimages, + pub blank: bool, +} + +impl Frame { + pub fn blank(func: &Func) -> Frame { + let input = vec![Ptr::null(Tag::Expr(Nil)); func.input_params.len()]; + let output = vec![Ptr::null(Tag::Expr(Nil)); func.output_size]; + let preimages = Preimages::blank(func); + Frame { + input, + output, + preimages, + blank: true, + } + } } impl Block { @@ -68,11 +97,12 @@ impl Block { /// in `circuit.rs`) fn run( &self, - input: Vec>, + input: &[Ptr], store: &mut Store, mut bindings: VarMap>, mut preimages: Preimages, mut path: Path, + emitted: &mut Vec>, ) -> Result<(Frame, Path)> { for op in &self.ops { match op { @@ -86,7 +116,7 @@ impl Block { // of it, then extend `call_outputs` let mut inner_call_outputs = VecDeque::new(); std::mem::swap(&mut inner_call_outputs, &mut preimages.call_outputs); - let (mut frame, func_path) = func.call(inp_ptrs, store, preimages)?; + let (mut frame, func_path) = func.call(&inp_ptrs, store, preimages, emitted)?; std::mem::swap(&mut inner_call_outputs, &mut frame.preimages.call_outputs); // Extend the path and bind the output variables to the output values @@ -115,9 +145,9 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = if a.tag() == b.tag() { - Ptr::Leaf(Tag::Expr(Num), F::ONE) + Ptr::Atom(Tag::Expr(Num), F::ONE) } else { - Ptr::Leaf(Tag::Expr(Num), F::ZERO) + Ptr::Atom(Tag::Expr(Num), F::ZERO) }; bindings.insert(tgt.clone(), c); } @@ -129,50 +159,50 @@ impl Block { let a_hash = store.hash_ptr(a)?.hash; let b_hash = store.hash_ptr(b)?.hash; let c = if a_hash == b_hash { - Ptr::Leaf(Tag::Expr(Num), F::ONE) + Ptr::Atom(Tag::Expr(Num), F::ONE) } else { - Ptr::Leaf(Tag::Expr(Num), F::ZERO) + Ptr::Atom(Tag::Expr(Num), F::ZERO) }; bindings.insert(tgt.clone(), c); } Op::Add(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let c = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { - Ptr::Leaf(Tag::Expr(Num), *f + *g) + let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { + Ptr::Atom(Tag::Expr(Num), *f + *g) } else { - bail!("`Add` only works on leaves") + bail!("`Add` only works on atoms") }; bindings.insert(tgt.clone(), c); } Op::Sub(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let c = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { - Ptr::Leaf(Tag::Expr(Num), *f - *g) + let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { + Ptr::Atom(Tag::Expr(Num), *f - *g) } else { - bail!("`Sub` only works on leaves") + bail!("`Sub` only works on atoms") }; bindings.insert(tgt.clone(), c); } Op::Mul(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let c = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { - Ptr::Leaf(Tag::Expr(Num), *f * *g) + let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { + Ptr::Atom(Tag::Expr(Num), *f * *g) } else { - bail!("`Mul` only works on leaves") + bail!("`Mul` only works on atoms") }; bindings.insert(tgt.clone(), c); } Op::Div(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let c = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { + let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { if g == &F::ZERO { bail!("Can't divide by zero") } - Ptr::Leaf(Tag::Expr(Num), *f * g.invert().expect("not zero")) + Ptr::Atom(Tag::Expr(Num), *f * g.invert().expect("not zero")) } else { bail!("`Div` only works on numbers") }; @@ -181,23 +211,23 @@ impl Block { Op::Lt(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let c = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { + let c = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { preimages.less_than.push(Some(PreimageData::FPair(*f, *g))); let f = Num::Scalar(*f); let g = Num::Scalar(*g); let b = if f < g { F::ONE } else { F::ZERO }; - Ptr::Leaf(Tag::Expr(Num), b) + Ptr::Atom(Tag::Expr(Num), b) } else { - bail!("`Lt` only works on leaves") + bail!("`Lt` only works on atoms") }; bindings.insert(tgt.clone(), c); } Op::Trunc(tgt, a, n) => { assert!(*n <= 64); let a = bindings.get(a)?; - let c = if let Ptr::Leaf(_, f) = a { + let c = if let Ptr::Atom(_, f) = a { let b = if *n < 64 { (1 << *n) - 1 } else { u64::MAX }; - Ptr::Leaf(Tag::Expr(Num), F::from_u64(f.to_u64_unchecked() & b)) + Ptr::Atom(Tag::Expr(Num), F::from_u64(f.to_u64_unchecked() & b)) } else { bail!("`Trunc` only works a leaf") }; @@ -206,24 +236,25 @@ impl Block { Op::DivRem64(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; - let (c1, c2) = if let (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) = (a, b) { + let (c1, c2) = if let (Ptr::Atom(_, f), Ptr::Atom(_, g)) = (a, b) { if g == &F::ZERO { bail!("Can't divide by zero") } let f = f.to_u64_unchecked(); let g = g.to_u64_unchecked(); - let c1 = Ptr::Leaf(Tag::Expr(Num), F::from_u64(f / g)); - let c2 = Ptr::Leaf(Tag::Expr(Num), F::from_u64(f % g)); + let c1 = Ptr::Atom(Tag::Expr(Num), F::from_u64(f / g)); + let c2 = Ptr::Atom(Tag::Expr(Num), F::from_u64(f % g)); (c1, c2) } else { - bail!("`DivRem64` only works on leaves") + bail!("`DivRem64` only works on atoms") }; bindings.insert(tgt[0].clone(), c1); bindings.insert(tgt[1].clone(), c2); } Op::Emit(a) => { let a = bindings.get(a)?; - println!("{}", a.dbg_display(store)) + println!("{}", a.fmt_to_string(store, initial_lurk_state())); + emitted.push(*a); } Op::Hash2(img, tag, preimg) => { let preimg_ptrs = bindings.get_many_cloned(preimg)?; @@ -306,30 +337,24 @@ impl Block { } Op::Hide(tgt, sec, src) => { let src_ptr = bindings.get(src)?; - let Ptr::Leaf(Tag::Expr(Num), secret) = bindings.get(sec)? else { + let Ptr::Atom(Tag::Expr(Num), secret) = bindings.get(sec)? else { bail!("{sec} is not a numeric pointer") }; - let z_ptr = store.hash_ptr(src_ptr)?; - let hash = - store - .poseidon_cache - .hash3(&[*secret, z_ptr.tag.to_field(), z_ptr.hash]); - let tgt_ptr = Ptr::comm(hash); - store.comms.insert(FWrap::(hash), (*secret, *src_ptr)); + let tgt_ptr = store.hide(*secret, *src_ptr)?; preimages .commitment .push(Some(PreimageData::FPtr(*secret, *src_ptr))); bindings.insert(tgt.clone(), tgt_ptr); } Op::Open(tgt_secret, tgt_ptr, comm) => { - let Ptr::Leaf(Tag::Expr(Comm), hash) = bindings.get(comm)? else { + let Ptr::Atom(Tag::Expr(Comm), hash) = bindings.get(comm)? else { bail!("{comm} is not a comm pointer") }; - let Some((secret, ptr)) = store.comms.get(&FWrap::(*hash)) else { + let Some((secret, ptr)) = store.open(*hash) else { bail!("No committed data for hash {}", &hash.hex_digits()) }; bindings.insert(tgt_ptr.clone(), *ptr); - bindings.insert(tgt_secret.clone(), Ptr::Leaf(Tag::Expr(Num), *secret)); + bindings.insert(tgt_secret.clone(), Ptr::Atom(Tag::Expr(Num), *secret)); preimages .commitment .push(Some(PreimageData::FPtr(*secret, *ptr))) @@ -342,39 +367,36 @@ impl Block { let tag = ptr.tag(); match cases.get(tag) { Some(block) => { - path.push_tag_inplace(tag); - block.run(input, store, bindings, preimages, path) + path.push_tag_inplace(*tag); + block.run(input, store, bindings, preimages, path, emitted) } None => { path.push_default_inplace(); match def { - Some(def) => def.run(input, store, bindings, preimages, path), + Some(def) => def.run(input, store, bindings, preimages, path, emitted), None => bail!("No match for tag {}", tag), } } } } - Ctrl::MatchVal(match_var, cases, def) => { + Ctrl::MatchSymbol(match_var, cases, def) => { let ptr = bindings.get(match_var)?; - let Some(lit) = Lit::from_ptr(ptr, store) else { - // If we can't find it in the store, it most certaily is not equal to any - // of the cases, which are all interned - path.push_default_inplace(); - match def { - Some(def) => return def.run(input, store, bindings, preimages, path), - None => bail!("No match for literal"), - } + if ptr.tag() != &Tag::Expr(Sym) { + bail!("{match_var} is not a symbol"); + } + let Some(sym) = store.fetch_symbol(ptr) else { + bail!("Symbol bound to {match_var} wasn't interned"); }; - match cases.get(&lit) { + match cases.get(&sym) { Some(block) => { - path.push_lit_inplace(&lit); - block.run(input, store, bindings, preimages, path) + path.push_symbol_inplace(sym); + block.run(input, store, bindings, preimages, path, emitted) } None => { path.push_default_inplace(); match def { - Some(def) => def.run(input, store, bindings, preimages, path), - None => bail!("No match for literal {:?}", lit), + Some(def) => def.run(input, store, bindings, preimages, path, emitted), + None => bail!("No match for symbol {sym}"), } } } @@ -385,9 +407,9 @@ impl Block { let b = x == y; path.push_bool_inplace(b); if b { - eq_block.run(input, store, bindings, preimages, path) + eq_block.run(input, store, bindings, preimages, path, emitted) } else { - else_block.run(input, store, bindings, preimages, path) + else_block.run(input, store, bindings, preimages, path, emitted) } } Ctrl::Return(output_vars) => { @@ -395,11 +417,13 @@ impl Block { for var in output_vars.iter() { output.push(*bindings.get(var)?) } + let input = input.to_vec(); Ok(( Frame { input, output, preimages, + blank: false, }, path, )) @@ -411,9 +435,10 @@ impl Block { impl Func { pub fn call( &self, - args: Vec>, + args: &[Ptr], store: &mut Store, preimages: Preimages, + emitted: &mut Vec>, ) -> Result<(Frame, Path)> { let mut bindings = VarMap::new(); for (i, param) in self.input_params.iter().enumerate() { @@ -430,7 +455,7 @@ impl Func { let mut res = self .body - .run(args, store, bindings, preimages, Path::default())?; + .run(args, store, bindings, preimages, Path::default(), emitted)?; let preimages = &mut res.0.preimages; let hash2_used = preimages.hash2.len() - hash2_init; @@ -460,39 +485,80 @@ impl Func { /// Calls a `Func` on an input until the stop contidion is satisfied, using the output of one /// iteration as the input of the next one. - pub fn call_until]) -> bool>( + pub fn call_until< + F: LurkField, + StopCond: Fn(&[Ptr]) -> bool, + // iteration -> input -> emitted -> store -> string + LogFmt: Fn(usize, &[Ptr], &[Ptr], &Store) -> String, + >( &self, - mut args: Vec>, + args: &[Ptr], store: &mut Store, - stop_cond: Stop, - ) -> Result<(Vec>, Vec)> { - if self.input_params.len() != self.output_size { - assert_eq!(self.input_params.len(), self.output_size) - } - if self.input_params.len() != args.len() { - assert_eq!(args.len(), self.input_params.len()) - } + stop_cond: StopCond, + limit: usize, + // TODO: make this argument optional + log_fmt: LogFmt, + ) -> Result<(Vec>, usize, Vec)> { + assert_eq!(self.input_params.len(), self.output_size); + assert_eq!(self.input_params.len(), args.len()); - // Initial path vector and frames + // Initial input, path vector and frames + let mut input = args.to_vec(); let mut frames = vec![]; let mut paths = vec![]; - loop { + let mut iterations = 0; + + tracing::info!("{}", &log_fmt(iterations, &input, &[], store)); + + for _ in 0..limit { let preimages = Preimages::new_from_func(self); - let (frame, path) = self.call(args, store, preimages)?; + let mut emitted = vec![]; + let (frame, path) = self.call(&input, store, preimages, &mut emitted)?; + input = frame.output.clone(); + iterations += 1; + tracing::info!("{}", &log_fmt(iterations, &input, &emitted, store)); if stop_cond(&frame.output) { frames.push(frame); paths.push(path); break; } - // Should frames take borrowed vectors instead, as to avoid cloning? - // Using AVec is a possibility, but to create a dynamic AVec, currently, - // requires 2 allocations since it must be created from a Vec and - // Vec -> Arc<[T]> uses `copy_from_slice`. - args = frame.output.clone(); frames.push(frame); paths.push(path); } - Ok((frames, paths)) + if iterations < limit { + // pushing a frame that can be padded + let preimages = Preimages::new_from_func(self); + let (frame, path) = self.call(&input, store, preimages, &mut vec![])?; + frames.push(frame); + paths.push(path); + } + Ok((frames, iterations, paths)) + } + + pub fn call_until_simple]) -> bool>( + &self, + args: Vec>, + store: &mut Store, + stop_cond: StopCond, + limit: usize, + ) -> Result<(Vec>, usize, Vec>)> { + assert_eq!(self.input_params.len(), self.output_size); + assert_eq!(self.input_params.len(), args.len()); + + let mut input = args; + let mut emitted = vec![]; + + let mut iterations = 0; + + for _ in 0..limit { + let (frame, _) = self.call(&input, store, Preimages::default(), &mut emitted)?; + input = frame.output.clone(); + iterations += 1; + if stop_cond(&frame.output) { + break; + } + } + Ok((input, iterations, emitted)) } } diff --git a/src/lem/macros.rs b/src/lem/macros.rs index 10869a34be..d24adbf2fc 100644 --- a/src/lem/macros.rs +++ b/src/lem/macros.rs @@ -201,19 +201,19 @@ macro_rules! ctrl { $crate::lem::Ctrl::MatchTag($crate::var!($sii), cases, default) } }; - ( match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)? ) => { + ( match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)? ) => { { let mut cases = indexmap::IndexMap::new(); $( if cases.insert( - $crate::lit!($cnstr($val)), + $crate::state::lurk_sym($sym), $crate::block!( $case_ops ), ).is_some() { panic!("Repeated value on `match`"); }; $( if cases.insert( - $crate::lit!($other_cnstr($other_val)), + $crate::state::lurk_sym($other_sym), $crate::block!( $case_ops ), ).is_some() { panic!("Repeated value on `match`"); @@ -221,7 +221,7 @@ macro_rules! ctrl { )* )* let default = None $( .or (Some(Box::new($crate::block!( @seq {}, $($def)* )))) )?; - $crate::lem::Ctrl::MatchVal($crate::var!($sii), cases, default) + $crate::lem::Ctrl::MatchSymbol($crate::var!($sii), cases, default) } }; ( if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => { @@ -508,13 +508,13 @@ macro_rules! block { $crate::ctrl!( match $sii.tag { $( $kind::$tag $(| $other_kind::$other_tag)* => $case_ops )* } $(; $($def)*)? ) ) }; - (@seq {$($limbs:expr)*}, match $sii:ident.val { $( $cnstr:ident($val:literal) $(| $other_cnstr:ident($other_val:literal))* => $case_ops:tt )* } $(; $($def:tt)*)?) => { + (@seq {$($limbs:expr)*}, match symbol $sii:ident { $( $sym:expr $(, $other_sym:expr)* => $case_ops:tt )* } $(; $($def:tt)*)?) => { $crate::block! ( @end { $($limbs)* }, - $crate::ctrl!( match $sii.val { $( $cnstr($val) $(| $other_cnstr($other_val))* => $case_ops )* } $(; $($def)*)? ) + $crate::ctrl!( match symbol $sii { $( $sym $(, $other_sym)* => $case_ops )* } $(; $($def)*)? ) ) }; (@seq {$($limbs:expr)*}, if $x:ident == $y:ident { $($true_block:tt)+ } $($false_block:tt)+ ) => { @@ -572,9 +572,12 @@ macro_rules! func { #[cfg(test)] mod tests { - use crate::lem::{Block, Ctrl, Lit, Op, Tag, Var}; - use crate::state::lurk_sym; - use crate::tag::ExprTag::*; + use crate::{ + lem::{Block, Ctrl, Op, Tag, Var}, + state::lurk_sym, + tag::ExprTag::*, + Symbol, + }; #[inline] fn mptr(name: &str) -> Var { @@ -587,8 +590,8 @@ mod tests { } #[inline] - fn match_val(i: Var, cases: Vec<(Lit, Block)>, def: Block) -> Ctrl { - Ctrl::MatchVal(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def))) + fn match_symbol(i: Var, cases: Vec<(Symbol, Block)>, def: Block) -> Ctrl { + Ctrl::MatchSymbol(i, indexmap::IndexMap::from_iter(cases), Some(Box::new(def))) } #[test] @@ -698,11 +701,11 @@ mod tests { ); let moo = ctrl!( - match www.val { - Symbol("nil") => { + match symbol www { + "nil" => { return (foo, foo, foo); // a single Ctrl will not turn into a Seq } - Symbol("cons") => { + "cons" => { let foo: Expr::Num; let goo: Expr::Char; return (foo, goo, goo); @@ -713,18 +716,18 @@ mod tests { ); assert!( - moo == match_val( + moo == match_symbol( mptr("www"), vec![ ( - Lit::Symbol(lurk_sym("nil")), + lurk_sym("nil"), Block { ops: vec![], ctrl: Ctrl::Return(vec![mptr("foo"), mptr("foo"), mptr("foo")]), } ), ( - Lit::Symbol(lurk_sym("cons")), + lurk_sym("cons"), Block { ops: vec![ Op::Null(mptr("foo"), Tag::Expr(Num)), diff --git a/src/lem/mod.rs b/src/lem/mod.rs index c652d62437..1c25237953 100644 --- a/src/lem/mod.rs +++ b/src/lem/mod.rs @@ -59,21 +59,25 @@ //! 6. We also check for variables that are not used. If intended they should //! be prefixed by "_" -mod circuit; -mod eval; -mod interpreter; +pub mod circuit; +pub mod eval; +pub mod interpreter; mod macros; mod path; -mod pointers; +pub mod pointers; mod slot; -mod store; +pub mod store; mod var_map; +pub mod zstore; +use crate::coprocessor::Coprocessor; +use crate::eval::lang::Lang; use crate::field::LurkField; use crate::symbol::Symbol; use crate::tag::{ContTag, ExprTag, Tag as TagTrait}; use anyhow::{bail, Result}; use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; use std::sync::Arc; use self::{pointers::Ptr, slot::SlotsCounter, store::Store, var_map::VarMap}; @@ -84,11 +88,17 @@ pub type AString = Arc; /// function body, which is a `Block` #[derive(Debug, Clone, PartialEq, Eq)] pub struct Func { - name: String, - input_params: Vec, - output_size: usize, - body: Block, - slot: SlotsCounter, + pub name: String, + pub input_params: Vec, + pub output_size: usize, + pub body: Block, + pub slot: SlotsCounter, +} + +impl> From<&Lang> for Func { + fn from(_lang: &Lang) -> Self { + eval::eval_step().clone() + } } /// LEM variables @@ -96,14 +106,14 @@ pub struct Func { pub struct Var(AString); /// LEM tags -#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)] +#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)] pub enum Tag { Expr(ExprTag), Cont(ContTag), Ctrl(CtrlTag), } -#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash)] +#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)] pub enum CtrlTag { Return, MakeThunk, @@ -169,18 +179,31 @@ impl Lit { Self::Num(num) => Ptr::num(F::from_u128(*num)), } } + + pub fn to_ptr_cached(&self, store: &Store) -> Ptr { + match self { + Self::Symbol(s) => *store + .interned_symbol(s) + .expect("Symbol should have been cached"), + Self::String(s) => *store + .interned_string(s) + .expect("String should have been cached"), + Self::Num(num) => Ptr::num(F::from_u128(*num)), + } + } + pub fn from_ptr(ptr: &Ptr, store: &Store) -> Option { use ExprTag::*; use Tag::*; match ptr.tag() { Expr(Num) => match ptr { - Ptr::Leaf(_, f) => { + Ptr::Atom(_, f) => { let num = LurkField::to_u128_unchecked(f); Some(Self::Num(num)) } _ => unreachable!(), }, - Expr(Str) => store.fetch_string(ptr).cloned().map(Lit::String), + Expr(Str) => store.fetch_string(ptr).map(Lit::String), Expr(Sym) => store.fetch_symbol(ptr).map(Lit::Symbol), _ => None, } @@ -212,13 +235,14 @@ pub struct Block { #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Eq)] pub enum Ctrl { - /// `MatchTag(x, cases)` performs a match on the tag of `x`, choosing the - /// appropriate `Block` among the ones provided in `cases` + /// `MatchTag(x, cases, def)` checks whether the tag of `x` matches some tag + /// among the ones provided in `cases`. If so, run the corresponding `Block`. + /// Run `def` otherwise MatchTag(Var, IndexMap, Option>), - /// `MatchSymbol(x, cases, def)` checks whether `x` matches some symbol among - /// the ones provided in `cases`. If so, run the corresponding `Block`. Run - /// `def` otherwise - MatchVal(Var, IndexMap, Option>), + /// `MatchSymbol(x, cases, def)` requires that `x` is a symbol and checks + /// whether `x` matches some symbol among the ones provided in `cases`. If so, + /// run the corresponding `Block`. Run `def` otherwise + MatchSymbol(Var, IndexMap, Option>), /// `IfEq(x, y, eq_block, else_block)` runs `eq_block` if `x == y`, and /// otherwise runs `else_block` IfEq(Var, Var, Box, Box), @@ -454,31 +478,13 @@ impl Func { None => (), } } - Ctrl::MatchVal(var, cases, def) => { + Ctrl::MatchSymbol(var, cases, def) => { is_bound(var, map)?; - let mut lits = HashSet::new(); - let mut kind = None; - for (lit, block) in cases { - let lit_kind = match lit { - Lit::Num(..) => 0, - Lit::String(..) => 1, - Lit::Symbol(..) => 2, - }; - if let Some(kind) = kind { - if kind != lit_kind { - bail!("Only values of the same kind allowed."); - } - } else { - kind = Some(lit_kind) - } - if !lits.insert(lit) { - bail!("Case {:?} already defined.", lit); - } + for block in cases.values() { recurse(block, return_size, map)?; } - match def { - Some(def) => recurse(def, return_size, map)?, - None => (), + if let Some(def) = def { + recurse(def, return_size, map)?; } } Ctrl::IfEq(x, y, eq_block, else_block) => { @@ -548,6 +554,12 @@ impl Func { body, ) } + + pub fn init_store(&self) -> Store { + let mut store = Store::default(); + self.body.intern_lits(&mut store); + store + } } impl Block { @@ -695,18 +707,18 @@ impl Block { }; Ctrl::MatchTag(var, IndexMap::from_iter(new_cases), new_def) } - Ctrl::MatchVal(var, cases, def) => { + Ctrl::MatchSymbol(var, cases, def) => { let var = map.get_cloned(&var)?; let mut new_cases = Vec::with_capacity(cases.len()); - for (lit, case) in cases { + for (sym, case) in cases { let new_case = case.deconflict(&mut map.clone(), uniq)?; - new_cases.push((lit.clone(), new_case)); + new_cases.push((sym.clone(), new_case)); } let new_def = match def { Some(def) => Some(Box::new(def.deconflict(map, uniq)?)), None => None, }; - Ctrl::MatchVal(var, IndexMap::from_iter(new_cases), new_def) + Ctrl::MatchSymbol(var, IndexMap::from_iter(new_cases), new_def) } Ctrl::IfEq(x, y, eq_block, else_block) => { let x = map.get_cloned(&x)?; @@ -719,6 +731,40 @@ impl Block { }; Ok(Block { ops, ctrl }) } + + fn intern_lits(&self, store: &mut Store) { + for op in &self.ops { + match op { + Op::Call(_, func, _) => func.body.intern_lits(store), + Op::Lit(_, lit) => { + lit.to_ptr(store); + } + _ => (), + } + } + match &self.ctrl { + Ctrl::IfEq(.., a, b) => { + a.intern_lits(store); + b.intern_lits(store); + } + Ctrl::MatchTag(_, cases, def) => { + cases.values().for_each(|block| block.intern_lits(store)); + if let Some(def) = def { + def.intern_lits(store); + } + } + Ctrl::MatchSymbol(_, cases, def) => { + for (sym, b) in cases { + store.intern_symbol(sym); + b.intern_lits(store); + } + if let Some(def) = def { + def.intern_lits(store); + } + } + Ctrl::Return(..) => (), + } + } } impl Var { @@ -731,8 +777,7 @@ impl Var { #[cfg(test)] mod tests { use super::slot::SlotsCounter; - use super::{store::Store, *}; - use crate::state::lurk_sym; + use super::*; use crate::{func, lem::pointers::Ptr}; use bellpepper::util_cs::Comparable; use bellpepper_core::test_cs::TestConstraintSystem; @@ -747,27 +792,31 @@ mod tests { /// - `expected_slots` gives the number of expected slots for each type of hash. fn synthesize_test_helper(func: &Func, inputs: Vec>, expected_num_slots: SlotsCounter) { use crate::tag::ContTag::*; - let store = &mut Store::default(); + let store = &mut func.init_store(); let outermost = Ptr::null(Tag::Cont(Outermost)); let terminal = Ptr::null(Tag::Cont(Terminal)); let error = Ptr::null(Tag::Cont(Error)); - let nil = store.intern_symbol(&lurk_sym("nil")); + let nil = store.intern_nil(); let stop_cond = |output: &[Ptr]| output[2] == terminal || output[2] == error; assert_eq!(func.slot, expected_num_slots); let computed_num_constraints = func.num_constraints::(store); + let log_fmt = |_: usize, _: &[Ptr], _: &[Ptr], _: &Store| String::default(); + let mut cs_prev = None; for input in inputs.into_iter() { - let input = vec![input, nil, outermost]; - let (frames, _) = func.call_until(input, store, stop_cond).unwrap(); + let input = [input, nil, outermost]; + let (frames, ..) = func + .call_until(&input, store, stop_cond, 10, log_fmt) + .unwrap(); let mut cs; - for frame in frames.clone() { + for frame in frames { cs = TestConstraintSystem::::new(); - func.synthesize(&mut cs, store, &frame).unwrap(); + func.synthesize_frame_aux(&mut cs, store, &frame).unwrap(); assert!(cs.is_satisfied()); assert_eq!(computed_num_constraints, cs.num_constraints()); if let Some(cs_prev) = cs_prev { diff --git a/src/lem/path.rs b/src/lem/path.rs index b1a757efbf..dbd63d4af1 100644 --- a/src/lem/path.rs +++ b/src/lem/path.rs @@ -1,11 +1,13 @@ use std::collections::HashSet; -use super::{Block, Ctrl, Func, Lit, Op, Tag}; +use crate::Symbol; + +use super::{Block, Ctrl, Func, Op, Tag}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) enum PathNode { Tag(Tag), - Lit(Lit), + Symbol(Symbol), Bool(bool), Default, } @@ -14,7 +16,7 @@ impl std::fmt::Display for PathNode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Tag(tag) => write!(f, "Tag({})", tag), - Self::Lit(lit) => write!(f, "{:?}", lit), + Self::Symbol(sym) => write!(f, "Symbol({})", sym), Self::Bool(b) => write!(f, "Bool({})", b), Self::Default => write!(f, "Default"), } @@ -44,9 +46,9 @@ impl Path { Path(path) } - pub fn push_lit(&self, lit: &Lit) -> Path { + pub fn push_symbol(&self, sym: Symbol) -> Path { let mut path = self.0.clone(); - path.push(PathNode::Lit(lit.clone())); + path.push(PathNode::Symbol(sym)); Path(path) } @@ -57,8 +59,8 @@ impl Path { } #[inline] - pub fn push_tag_inplace(&mut self, tag: &Tag) { - self.0.push(PathNode::Tag(*tag)); + pub fn push_tag_inplace(&mut self, tag: Tag) { + self.0.push(PathNode::Tag(tag)); } #[inline] @@ -67,8 +69,8 @@ impl Path { } #[inline] - pub fn push_lit_inplace(&mut self, lit: &Lit) { - self.0.push(PathNode::Lit(lit.clone())); + pub fn push_symbol_inplace(&mut self, sym: Symbol) { + self.0.push(PathNode::Symbol(sym)); } #[inline] @@ -119,7 +121,7 @@ impl Block { .values() .fold(init, |acc, block| acc + block.num_paths()) } - Ctrl::MatchVal(_, cases, def) => { + Ctrl::MatchSymbol(_, cases, def) => { let init = def.as_ref().map_or(0, |def| def.num_paths()); cases .values() diff --git a/src/lem/pointers.rs b/src/lem/pointers.rs index fbb0c41a1f..8f41dba21b 100644 --- a/src/lem/pointers.rs +++ b/src/lem/pointers.rs @@ -1,10 +1,12 @@ -use crate::{field::*, tag::ContTag::Dummy, tag::ExprTag::*}; +use serde::{Deserialize, Serialize}; + +use crate::{field::*, tag::ExprTag::*}; use super::Tag; /// `Ptr` is the main piece of data LEMs operate on. We can think of a pointer /// as a building block for trees that represent Lurk data. A pointer can be a -/// leaf that contains data encoded as an element of a `LurkField` or it can have +/// atom that contains data encoded as an element of a `LurkField` or it can have /// children. For performance, the children of a pointer are stored on an /// `IndexSet` and the resulding index is carried by the pointer itself. /// @@ -13,9 +15,9 @@ use super::Tag; /// children a pointer has. However, LEMs require extra flexibility because LEM /// hashing operations can plug any tag to the resulting pointer. Thus, the /// number of children have to be made explicit as the `Ptr` enum. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Ptr { - Leaf(Tag, F), + Atom(Tag, F), Tuple2(Tag, usize), Tuple3(Tag, usize), Tuple4(Tag, usize), @@ -24,7 +26,7 @@ pub enum Ptr { impl std::hash::Hash for Ptr { fn hash(&self, state: &mut H) { match self { - Ptr::Leaf(tag, f) => (0, tag, f.to_repr().as_ref()).hash(state), + Ptr::Atom(tag, f) => (0, tag, f.to_repr().as_ref()).hash(state), Ptr::Tuple2(tag, x) => (1, tag, x).hash(state), Ptr::Tuple3(tag, x) => (2, tag, x).hash(state), Ptr::Tuple4(tag, x) => (3, tag, x).hash(state), @@ -35,7 +37,7 @@ impl std::hash::Hash for Ptr { impl Ptr { pub fn tag(&self) -> &Tag { match self { - Ptr::Leaf(tag, _) | Ptr::Tuple2(tag, _) | Ptr::Tuple3(tag, _) | Ptr::Tuple4(tag, _) => { + Ptr::Atom(tag, _) | Ptr::Tuple2(tag, _) | Ptr::Tuple3(tag, _) | Ptr::Tuple4(tag, _) => { tag } } @@ -43,31 +45,60 @@ impl Ptr { #[inline] pub fn num(f: F) -> Self { - Ptr::Leaf(Tag::Expr(Num), f) + Ptr::Atom(Tag::Expr(Num), f) + } + + #[inline] + pub fn num_u64(u: u64) -> Self { + Ptr::Atom(Tag::Expr(Num), F::from_u64(u)) + } + + #[inline] + pub fn u64(u: u64) -> Self { + Ptr::Atom(Tag::Expr(U64), F::from_u64(u)) } #[inline] pub fn char(c: char) -> Self { - Ptr::Leaf(Tag::Expr(Char), F::from_char(c)) + Ptr::Atom(Tag::Expr(Char), F::from_char(c)) } #[inline] pub fn comm(hash: F) -> Self { - Ptr::Leaf(Tag::Expr(Comm), hash) + Ptr::Atom(Tag::Expr(Comm), hash) } #[inline] pub fn null(tag: Tag) -> Self { - Ptr::Leaf(tag, F::ZERO) + Ptr::Atom(tag, F::ZERO) + } + + pub fn is_null(&self) -> bool { + match self { + Ptr::Atom(_, f) => f == &F::ZERO, + _ => false, + } + } + + pub fn is_nil(&self) -> bool { + self.tag() == &Tag::Expr(Nil) + } + + #[inline] + pub fn cast(self, tag: Tag) -> Self { + match self { + Ptr::Atom(_, f) => Ptr::Atom(tag, f), + Ptr::Tuple2(_, x) => Ptr::Tuple2(tag, x), + Ptr::Tuple3(_, x) => Ptr::Tuple3(tag, x), + Ptr::Tuple4(_, x) => Ptr::Tuple4(tag, x), + } } #[inline] - pub fn cast(&self, tag: Tag) -> Self { + pub fn get_atom(&self) -> Option<&F> { match self { - Ptr::Leaf(_, f) => Ptr::Leaf(tag, *f), - Ptr::Tuple2(_, x) => Ptr::Tuple2(tag, *x), - Ptr::Tuple3(_, x) => Ptr::Tuple3(tag, *x), - Ptr::Tuple4(_, x) => Ptr::Tuple4(tag, *x), + Ptr::Atom(_, f) => Some(f), + _ => None, } } @@ -106,32 +137,43 @@ impl Ptr { /// An important note is that computing the respective `ZPtr` of a `Ptr` can be /// expensive because of the Poseidon hashes. That's why we operate on `Ptr`s /// when interpreting LEMs and delay the need for `ZPtr`s as much as possible. -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)] pub struct ZPtr { pub tag: Tag, pub hash: F, } +impl PartialOrd for ZPtr { + fn partial_cmp(&self, other: &Self) -> Option { + ( + self.tag.to_field::().to_repr().as_ref(), + self.hash.to_repr().as_ref(), + ) + .partial_cmp(&( + other.tag.to_field::().to_repr().as_ref(), + other.hash.to_repr().as_ref(), + )) + } +} + +impl Ord for ZPtr { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.partial_cmp(other) + .expect("ZPtr::cmp: partial_cmp domain invariant violation") + } +} + /// `ZChildren` keeps track of the children of `ZPtr`s, in case they have any. /// This information is saved during hydration and is needed to content-address /// a store. -#[derive(Debug)] -pub(crate) enum ZChildren { +#[derive(Debug, Serialize, Deserialize)] +pub enum ZChildren { + Atom, Tuple2(ZPtr, ZPtr), Tuple3(ZPtr, ZPtr, ZPtr), Tuple4(ZPtr, ZPtr, ZPtr, ZPtr), } -impl ZPtr { - #[inline] - pub fn dummy() -> Self { - Self { - tag: Tag::Cont(Dummy), - hash: F::ZERO, - } - } -} - impl std::hash::Hash for ZPtr { fn hash(&self, state: &mut H) { self.tag.hash(state); diff --git a/src/lem/slot.rs b/src/lem/slot.rs index 6456e7c2c2..676cc3f80f 100644 --- a/src/lem/slot.rs +++ b/src/lem/slot.rs @@ -179,6 +179,11 @@ impl SlotsCounter { less_than: self.less_than + other.less_than, } } + + #[inline] + pub fn fold_max(self, vec: Vec) -> Self { + vec.into_iter().fold(self, |acc, i| acc.max(i)) + } } impl Block { @@ -204,7 +209,7 @@ impl Block { .values() .fold(init, |acc, block| acc.max(block.count_slots())) } - Ctrl::MatchVal(_, cases, def) => { + Ctrl::MatchSymbol(_, cases, def) => { let init = def .as_ref() .map_or(SlotsCounter::default(), |def| def.count_slots()); diff --git a/src/lem/store.rs b/src/lem/store.rs index c65ac9d2fa..e2c6685216 100644 --- a/src/lem/store.rs +++ b/src/lem/store.rs @@ -1,33 +1,34 @@ +use anyhow::{bail, Result}; +use indexmap::IndexSet; +use nom::{sequence::preceded, Parser}; use rayon::prelude::*; use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::{ + cache_map::CacheMap, field::{FWrap, LurkField}, hash::PoseidonCache, lem::Tag, + parser::*, state::{lurk_sym, State}, symbol::Symbol, syntax::Syntax, tag::ExprTag::*, uint::UInt, }; -use anyhow::{bail, Result}; -use dashmap::DashMap; -use indexmap::IndexSet; -use super::pointers::{Ptr, ZChildren, ZPtr}; +use super::pointers::{Ptr, ZPtr}; /// The `Store` is a crucial part of Lurk's implementation and tries to be a /// vesatile data structure for many parts of Lurk's data pipeline. /// -/// It holds Lurk data structured as trees of `Ptr`s (or `ZPtr`s). When a `Ptr` -/// has children`, we store them in the `IndexSet`s available: `tuple2`, `tuple3` -/// or `tuple4`. These data structures speed up LEM interpretation because lookups -/// by indices are fast. +/// It holds Lurk data structured as trees of `Ptr`s. When a `Ptr` has children, +/// we store them in the `IndexSet`s available: `tuple2`, `tuple3` or `tuple4`. +/// These data structures speed up LEM interpretation because lookups by indices +/// are fast. /// -/// The `Store` also provides an infra to speed up interning strings and symbols. -/// This data is saved in `str_tails_cache` and `sym_tails_cache`, which are better -/// explained in `intern_string` and `intern_symbol_path` respectively. +/// The `Store` provides an infra to speed up interning strings and symbols. This +/// data is saved in `string_ptr_cache` and `symbol_ptr_cache`. /// /// There's also a process that we call "hydration", in which we use Poseidon /// hashes to compute the (stable) hash of the children of a pointer. These hashes @@ -42,17 +43,18 @@ pub struct Store { tuple3: IndexSet<(Ptr, Ptr, Ptr)>, tuple4: IndexSet<(Ptr, Ptr, Ptr, Ptr)>, - str_cache: HashMap>, - ptr_str_cache: HashMap, String>, - sym_cache: HashMap, Ptr>, - ptr_sym_cache: HashMap, Vec>, + string_ptr_cache: HashMap>, + symbol_ptr_cache: HashMap>, + + ptr_string_cache: CacheMap, String>, + ptr_symbol_cache: CacheMap, Box>, pub poseidon_cache: PoseidonCache, + dehydrated: Vec>, - z_cache: DashMap, ZPtr, ahash::RandomState>, - z_dag: DashMap, ZChildren, ahash::RandomState>, + z_cache: CacheMap, Box>>, - pub comms: HashMap, (F, Ptr)>, // hash -> (secret, src) + comms: HashMap, (F, Ptr)>, // hash -> (secret, src) } impl Store { @@ -69,10 +71,11 @@ impl Store { /// Similar to `intern_2_ptrs` but doesn't add the resulting pointer to /// `dehydrated`. This function is used when converting a `ZStore` to a - /// `Store` (TODO). - #[inline] - pub fn intern_2_ptrs_not_dehydrated(&mut self, tag: Tag, a: Ptr, b: Ptr) -> Ptr { - Ptr::Tuple2(tag, self.tuple2.insert_full((a, b)).0) + /// `Store`. + pub fn intern_2_ptrs_hydrated(&mut self, tag: Tag, a: Ptr, b: Ptr, z: ZPtr) -> Ptr { + let ptr = Ptr::Tuple2(tag, self.tuple2.insert_full((a, b)).0); + self.z_cache.insert(ptr, Box::new(z)); + ptr } /// Creates a `Ptr` that's a parent of three children @@ -88,16 +91,18 @@ impl Store { /// Similar to `intern_3_ptrs` but doesn't add the resulting pointer to /// `dehydrated`. This function is used when converting a `ZStore` to a - /// `Store` (TODO). - #[inline] - pub fn intern_3_ptrs_not_dehydrated( + /// `Store`. + pub fn intern_3_ptrs_hydrated( &mut self, tag: Tag, a: Ptr, b: Ptr, c: Ptr, + z: ZPtr, ) -> Ptr { - Ptr::Tuple3(tag, self.tuple3.insert_full((a, b, c)).0) + let ptr = Ptr::Tuple3(tag, self.tuple3.insert_full((a, b, c)).0); + self.z_cache.insert(ptr, Box::new(z)); + ptr } /// Creates a `Ptr` that's a parent of four children @@ -120,17 +125,19 @@ impl Store { /// Similar to `intern_4_ptrs` but doesn't add the resulting pointer to /// `dehydrated`. This function is used when converting a `ZStore` to a - /// `Store` (TODO). - #[inline] - pub fn intern_4_ptrs_not_dehydrated( + /// `Store`. + pub fn intern_4_ptrs_hydrated( &mut self, tag: Tag, a: Ptr, b: Ptr, c: Ptr, d: Ptr, + z: ZPtr, ) -> Ptr { - Ptr::Tuple4(tag, self.tuple4.insert_full((a, b, c, d)).0) + let ptr = Ptr::Tuple4(tag, self.tuple4.insert_full((a, b, c, d)).0); + self.z_cache.insert(ptr, Box::new(z)); + ptr } #[inline] @@ -148,131 +155,304 @@ impl Store { self.tuple4.get_index(idx) } - /// Interns a string recursively pub fn intern_string(&mut self, s: &str) -> Ptr { - if s.is_empty() { - let ptr = Ptr::null(Tag::Expr(Str)); - self.ptr_str_cache.insert(ptr, "".into()); - return ptr; + if let Some(ptr) = self.string_ptr_cache.get(s) { + *ptr + } else { + let ptr = s.chars().rev().fold(Ptr::null(Tag::Expr(Str)), |acc, c| { + self.intern_2_ptrs(Tag::Expr(Str), Ptr::char(c), acc) + }); + self.string_ptr_cache.insert(s.to_string(), ptr); + self.ptr_string_cache.insert(ptr, s.to_string()); + ptr } + } + + #[inline] + pub fn interned_string(&self, s: &str) -> Option<&Ptr> { + self.string_ptr_cache.get(s) + } - match self.str_cache.get(s) { - Some(ptr_cache) => *ptr_cache, - None => { - let tail = &s.chars().skip(1).collect::(); - let tail_ptr = self.intern_string(tail); - let head = s.chars().next().unwrap(); - let s_ptr = self.intern_2_ptrs(Tag::Expr(Str), Ptr::char(head), tail_ptr); - self.str_cache.insert(s.into(), s_ptr); - self.ptr_str_cache.insert(s_ptr, s.into()); - s_ptr + pub fn fetch_string(&self, ptr: &Ptr) -> Option { + if let Some(str) = self.ptr_string_cache.get(ptr) { + Some(str.to_string()) + } else { + let mut string = String::new(); + let mut ptr = *ptr; + loop { + match ptr { + Ptr::Atom(Tag::Expr(Str), f) => { + if f == F::ZERO { + self.ptr_string_cache.insert(ptr, string.clone()); + return Some(string); + } else { + return None; + } + } + Ptr::Tuple2(Tag::Expr(Str), idx) => { + let (car, cdr) = self.fetch_2_ptrs(idx)?; + match car { + Ptr::Atom(Tag::Expr(Char), c) => { + string.push(c.to_char().expect("char pointers are well formed")); + ptr = *cdr + } + _ => return None, + } + } + _ => return None, + } } } } - #[inline] - pub fn fetch_string(&self, ptr: &Ptr) -> Option<&String> { - match ptr.tag() { - Tag::Expr(Str) => self.ptr_str_cache.get(ptr), - _ => None, + pub fn intern_symbol_path(&mut self, path: &[String]) -> Ptr { + path.iter().fold(Ptr::null(Tag::Expr(Sym)), |acc, s| { + let s_ptr = self.intern_string(s); + self.intern_2_ptrs(Tag::Expr(Sym), s_ptr, acc) + }) + } + + pub fn intern_symbol(&mut self, sym: &Symbol) -> Ptr { + if let Some(ptr) = self.symbol_ptr_cache.get(sym) { + *ptr + } else { + let path_ptr = self.intern_symbol_path(sym.path()); + let sym_ptr = if sym == &lurk_sym("nil") { + path_ptr.cast(Tag::Expr(Nil)) + } else if sym.is_keyword() { + path_ptr.cast(Tag::Expr(Key)) + } else { + path_ptr + }; + self.symbol_ptr_cache.insert(sym.clone(), sym_ptr); + self.ptr_symbol_cache.insert(sym_ptr, Box::new(sym.clone())); + sym_ptr } } - /// Interns a symbol path recursively - pub fn intern_symbol_path(&mut self, path: &[String]) -> Ptr { - if path.is_empty() { - let ptr = Ptr::null(Tag::Expr(Sym)); - self.ptr_sym_cache.insert(ptr, vec![]); - return ptr; + #[inline] + pub fn interned_symbol(&self, s: &Symbol) -> Option<&Ptr> { + self.symbol_ptr_cache.get(s) + } + + pub fn fetch_symbol_path(&self, mut idx: usize) -> Option> { + let mut path = vec![]; + loop { + let (car, cdr) = self.fetch_2_ptrs(idx)?; + let string = self.fetch_string(car)?; + path.push(string); + match cdr { + Ptr::Atom(Tag::Expr(Sym), f) => { + if f == &F::ZERO { + path.reverse(); + return Some(path); + } else { + return None; + } + } + Ptr::Tuple2(Tag::Expr(Sym), idx_cdr) => idx = *idx_cdr, + _ => return None, + } } + } - match self.sym_cache.get(path) { - Some(ptr_cache) => *ptr_cache, - None => { - let tail = &path[1..]; - let tail_ptr = self.intern_symbol_path(tail); - let head = &path[0]; - let head_ptr = self.intern_string(head); - let path_ptr = self.intern_2_ptrs(Tag::Expr(Sym), head_ptr, tail_ptr); - self.sym_cache.insert(path.to_vec(), path_ptr); - self.ptr_sym_cache.insert(path_ptr, path.to_vec()); - path_ptr + pub fn fetch_symbol(&self, ptr: &Ptr) -> Option { + if let Some(sym) = self.ptr_symbol_cache.get(ptr) { + Some(sym.clone()) + } else { + match ptr { + Ptr::Atom(Tag::Expr(Sym), f) => { + if f == &F::ZERO { + let sym = Symbol::root_sym(); + self.ptr_symbol_cache.insert(*ptr, Box::new(sym.clone())); + Some(sym) + } else { + None + } + } + Ptr::Atom(Tag::Expr(Key), f) => { + if f == &F::ZERO { + let key = Symbol::root_key(); + self.ptr_symbol_cache.insert(*ptr, Box::new(key.clone())); + Some(key) + } else { + None + } + } + Ptr::Tuple2(Tag::Expr(Sym), idx) | Ptr::Tuple2(Tag::Expr(Nil), idx) => { + let path = self.fetch_symbol_path(*idx)?; + let sym = Symbol::sym_from_vec(path); + self.ptr_symbol_cache.insert(*ptr, Box::new(sym.clone())); + Some(sym) + } + Ptr::Tuple2(Tag::Expr(Key), idx) => { + let path = self.fetch_symbol_path(*idx)?; + let key = Symbol::key_from_vec(path); + self.ptr_symbol_cache.insert(*ptr, Box::new(key.clone())); + Some(key) + } + _ => None, } } } + pub fn fetch_sym(&self, ptr: &Ptr) -> Option { + if ptr.tag() == &Tag::Expr(Sym) { + self.fetch_symbol(ptr) + } else { + None + } + } + + pub fn fetch_key(&self, ptr: &Ptr) -> Option { + if ptr.tag() == &Tag::Expr(Key) { + self.fetch_symbol(ptr) + } else { + None + } + } + + pub fn hide(&mut self, secret: F, payload: Ptr) -> Result> { + let z_ptr = self.hash_ptr(&payload)?; + let hash = self + .poseidon_cache + .hash3(&[secret, z_ptr.tag.to_field(), z_ptr.hash]); + self.comms.insert(FWrap::(hash), (secret, payload)); + Ok(Ptr::comm(hash)) + } + + pub fn hide_and_return_z_payload( + &mut self, + secret: F, + payload: Ptr, + ) -> Result<(F, ZPtr)> { + let z_ptr = self.hash_ptr(&payload)?; + let hash = self + .poseidon_cache + .hash3(&[secret, z_ptr.tag.to_field(), z_ptr.hash]); + self.comms.insert(FWrap::(hash), (secret, payload)); + Ok((hash, z_ptr)) + } + #[inline] - pub fn fetch_sym_path(&self, ptr: &Ptr) -> Option<&Vec> { - self.ptr_sym_cache.get(ptr) + pub fn commit(&mut self, payload: Ptr) -> Result> { + self.hide(F::NON_HIDING_COMMITMENT_SECRET, payload) + } + + pub fn open(&self, hash: F) -> Option<&(F, Ptr)> { + self.comms.get(&FWrap(hash)) } #[inline] - pub fn fetch_symbol(&self, ptr: &Ptr) -> Option { + pub fn intern_lurk_sym(&mut self, name: &str) -> Ptr { + self.intern_symbol(&lurk_sym(name)) + } + + #[inline] + pub fn intern_nil(&mut self) -> Ptr { + self.intern_lurk_sym("nil") + } + + #[inline] + pub fn key(&mut self, name: &str) -> Ptr { + self.intern_symbol(&Symbol::key(&[name.to_string()])) + } + + pub fn car_cdr(&mut self, ptr: &Ptr) -> Result<(Ptr, Ptr)> { match ptr.tag() { - Tag::Expr(Sym) | Tag::Expr(Key) => Some(Symbol::new( - self.fetch_sym_path(ptr)?, - ptr.tag() == &Tag::Expr(Key), - )), - _ => None, + Tag::Expr(Nil) => { + let nil = self.intern_nil(); + Ok((nil, nil)) + } + Tag::Expr(Cons) => { + let Some(idx) = ptr.get_index2() else { + bail!("malformed cons pointer") + }; + match self.fetch_2_ptrs(idx) { + Some(res) => Ok(*res), + None => bail!("car/cdr not found"), + } + } + Tag::Expr(Str) => { + if ptr.is_null() { + Ok((self.intern_nil(), Ptr::null(Tag::Expr(Str)))) + } else { + let Some(idx) = ptr.get_index2() else { + bail!("malformed str pointer") + }; + match self.fetch_2_ptrs(idx) { + Some(res) => Ok(*res), + None => bail!("car/cdr not found"), + } + } + } + _ => bail!("invalid pointer to extract car/cdr from"), } } - pub fn intern_symbol(&mut self, sym: &Symbol) -> Ptr { - let path_ptr = self.intern_symbol_path(sym.path()); - if sym == &lurk_sym("nil") { - path_ptr.cast(Tag::Expr(Nil)) - } else if !sym.is_keyword() { - path_ptr - } else { - path_ptr.cast(Tag::Expr(Key)) - } + pub fn list(&mut self, elts: Vec>) -> Ptr { + elts.into_iter().rev().fold(self.intern_nil(), |acc, elt| { + self.intern_2_ptrs(Tag::Expr(Cons), elt, acc) + }) } - pub fn intern_syntax(&mut self, syn: Syntax) -> Result> { + pub fn intern_syntax(&mut self, syn: Syntax) -> Ptr { match syn { - Syntax::Num(_, x) => Ok(Ptr::Leaf(Tag::Expr(Num), x.into_scalar())), - Syntax::UInt(_, UInt::U64(x)) => Ok(Ptr::Leaf(Tag::Expr(U64), x.into())), - Syntax::Char(_, x) => Ok(Ptr::Leaf(Tag::Expr(Char), (x as u64).into())), - Syntax::Symbol(_, symbol) => Ok(self.intern_symbol(&symbol)), - Syntax::String(_, x) => Ok(self.intern_string(&x)), + Syntax::Num(_, x) => Ptr::Atom(Tag::Expr(Num), x.into_scalar()), + Syntax::UInt(_, UInt::U64(x)) => Ptr::Atom(Tag::Expr(U64), x.into()), + Syntax::Char(_, x) => Ptr::Atom(Tag::Expr(Char), (x as u64).into()), + Syntax::Symbol(_, symbol) => self.intern_symbol(&symbol), + Syntax::String(_, x) => self.intern_string(&x), Syntax::Quote(pos, x) => { let xs = vec![Syntax::Symbol(pos, lurk_sym("quote").into()), *x]; self.intern_syntax(Syntax::List(pos, xs)) } - Syntax::List(_, xs) => { - let mut cdr = self.intern_symbol(&lurk_sym("nil")); - for x in xs.into_iter().rev() { - let car = self.intern_syntax(x)?; - cdr = self.intern_2_ptrs(Tag::Expr(Cons), car, cdr); - } - Ok(cdr) - } + Syntax::List(_, xs) => xs.into_iter().rev().fold(self.intern_nil(), |acc, x| { + let car = self.intern_syntax(x); + self.intern_2_ptrs(Tag::Expr(Cons), car, acc) + }), Syntax::Improper(_, xs, end) => { - let mut cdr = self.intern_syntax(*end)?; - for x in xs.into_iter().rev() { - let car = self.intern_syntax(x)?; - cdr = self.intern_2_ptrs(Tag::Expr(Cons), car, cdr); - } - Ok(cdr) + xs.into_iter() + .rev() + .fold(self.intern_syntax(*end), |acc, x| { + let car = self.intern_syntax(x); + self.intern_2_ptrs(Tag::Expr(Cons), car, acc) + }) } } } pub fn read(&mut self, state: Rc>, input: &str) -> Result> { - use crate::parser::*; - use nom::sequence::preceded; - use nom::Parser; match preceded( syntax::parse_space, syntax::parse_syntax(state, false, false), ) .parse(Span::new(input)) { - Ok((_i, x)) => self.intern_syntax(x), + Ok((_, x)) => Ok(self.intern_syntax(x)), Err(e) => bail!("{}", e), } } + pub fn read_maybe_meta<'a>( + &mut self, + state: Rc>, + input: &'a str, + ) -> Result<(Span<'a>, Ptr, bool), crate::parser::Error> { + match preceded(syntax::parse_space, syntax::parse_maybe_meta(state, false)) + .parse(input.into()) + { + Ok((i, Some((is_meta, x)))) => Ok((i, self.intern_syntax(x), is_meta)), + Ok((_, None)) => Err(Error::NoInput), + Err(e) => Err(Error::Syntax(format!("{}", e))), + } + } + + #[inline] + pub fn read_with_default_state(&mut self, input: &str) -> Result> { + self.read(State::init_lurk_state().rccell(), input) + } + /// Recursively hashes the children of a `Ptr` in order to obtain its /// corresponding `ZPtr`. While traversing a `Ptr` tree, it consults the /// cache of `Ptr`s that have already been hydrated and also populates this @@ -282,14 +462,14 @@ impl Store { /// depth limit. This limitation is circumvented by calling `hydrate_z_cache`. pub fn hash_ptr(&self, ptr: &Ptr) -> Result> { match ptr { - Ptr::Leaf(tag, x) => Ok(ZPtr { + Ptr::Atom(tag, x) => Ok(ZPtr { tag: *tag, hash: *x, }), Ptr::Tuple2(tag, idx) => match self.z_cache.get(ptr) { Some(z_ptr) => Ok(*z_ptr), None => { - let Some((a, b)) = self.tuple2.get_index(*idx) else { + let Some((a, b)) = self.fetch_2_ptrs(*idx) else { bail!("Index {idx} not found on tuple2") }; let a = self.hash_ptr(a)?; @@ -303,15 +483,14 @@ impl Store { b.hash, ]), }; - self.z_dag.insert(z_ptr, ZChildren::Tuple2(a, b)); - self.z_cache.insert(*ptr, z_ptr); + self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } }, Ptr::Tuple3(tag, idx) => match self.z_cache.get(ptr) { Some(z_ptr) => Ok(*z_ptr), None => { - let Some((a, b, c)) = self.tuple3.get_index(*idx) else { + let Some((a, b, c)) = self.fetch_3_ptrs(*idx) else { bail!("Index {idx} not found on tuple3") }; let a = self.hash_ptr(a)?; @@ -328,15 +507,14 @@ impl Store { c.hash, ]), }; - self.z_dag.insert(z_ptr, ZChildren::Tuple3(a, b, c)); - self.z_cache.insert(*ptr, z_ptr); + self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } }, Ptr::Tuple4(tag, idx) => match self.z_cache.get(ptr) { Some(z_ptr) => Ok(*z_ptr), None => { - let Some((a, b, c, d)) = self.tuple4.get_index(*idx) else { + let Some((a, b, c, d)) = self.fetch_4_ptrs(*idx) else { bail!("Index {idx} not found on tuple4") }; let a = self.hash_ptr(a)?; @@ -356,8 +534,7 @@ impl Store { d.hash, ]), }; - self.z_dag.insert(z_ptr, ZChildren::Tuple4(a, b, c, d)); - self.z_cache.insert(*ptr, z_ptr); + self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } }, @@ -372,6 +549,16 @@ impl Store { }); self.dehydrated = Vec::new(); } + + pub fn to_vector(&self, ptrs: &[Ptr]) -> Result> { + ptrs.iter() + .try_fold(Vec::with_capacity(2 * ptrs.len()), |mut acc, ptr| { + let z_ptr = self.hash_ptr(ptr)?; + acc.push(z_ptr.tag.to_field()); + acc.push(z_ptr.hash); + Ok(acc) + }) + } } impl Ptr { @@ -383,7 +570,7 @@ impl Ptr { return format!("{}", s); } match self { - Ptr::Leaf(tag, f) => { + Ptr::Atom(tag, f) => { if let Some(x) = f.to_u64() { format!("{}{}", tag, x) } else { @@ -422,4 +609,156 @@ impl Ptr { } } } + + fn unfold_list(&self, store: &Store) -> Option<(Vec>, Option>)> { + let mut idx = self.get_index2()?; + let mut list = vec![]; + let mut last = None; + while let Some((car, cdr)) = store.fetch_2_ptrs(idx) { + list.push(*car); + match cdr.tag() { + Tag::Expr(Nil) => break, + Tag::Expr(Cons) => { + idx = cdr.get_index2()?; + } + _ => { + last = Some(*cdr); + break; + } + } + } + Some((list, last)) + } + + pub fn fmt_to_string(&self, store: &Store, state: &State) -> String { + match self.tag() { + Tag::Expr(t) => match t { + Nil => { + if let Some(sym) = store.fetch_symbol(self) { + state.fmt_to_string(&sym.into()) + } else { + "".into() + } + } + Sym => { + if let Some(sym) = store.fetch_sym(self) { + state.fmt_to_string(&sym.into()) + } else { + "".into() + } + } + Key => { + if let Some(key) = store.fetch_key(self) { + state.fmt_to_string(&key.into()) + } else { + "".into() + } + } + Str => { + if let Some(str) = store.fetch_string(self) { + format!("\"{str}\"") + } else { + "".into() + } + } + Char => match self.get_atom().map(F::to_char) { + Some(Some(c)) => format!("\'{c}\'"), + _ => "".into(), + }, + Cons => { + if let Some((list, last)) = self.unfold_list(store) { + let list = list + .iter() + .map(|p| p.fmt_to_string(store, state)) + .collect::>(); + if let Some(last) = last { + format!( + "({} . {})", + list.join(" "), + last.fmt_to_string(store, state) + ) + } else { + format!("({})", list.join(" ")) + } + } else { + "".into() + } + } + Num => match self.get_atom() { + None => "".into(), + Some(f) => { + if let Some(u) = f.to_u64() { + u.to_string() + } else { + format!("0x{}", f.hex_digits()) + } + } + }, + U64 => match self.get_atom().map(F::to_u64) { + Some(Some(u)) => format!("{u}u64"), + _ => "".into(), + }, + Fun => match self.get_index3() { + None => "".into(), + Some(idx) => { + if let Some((arg, bod, _)) = store.fetch_3_ptrs(idx) { + match bod.tag() { + Tag::Expr(Nil) => { + format!( + "", + arg.fmt_to_string(store, state), + bod.fmt_to_string(store, state) + ) + } + Tag::Expr(Cons) => { + if let Some(idx) = bod.get_index2() { + if let Some((bod, _)) = store.fetch_2_ptrs(idx) { + format!( + "", + arg.fmt_to_string(store, state), + bod.fmt_to_string(store, state) + ) + } else { + "".into() + } + } else { + "".into() + } + } + _ => "".into(), + } + } else { + "".into() + } + } + }, + Thunk => match self.get_index2() { + None => "".into(), + Some(idx) => { + if let Some((val, cont)) = store.fetch_2_ptrs(idx) { + format!( + "Thunk{{ value: {} => cont: {} }}", + val.fmt_to_string(store, state), + cont.fmt_to_string(store, state) + ) + } else { + "".into() + } + } + }, + Comm => match self.get_atom() { + Some(f) => { + if store.comms.contains_key(&FWrap(*f)) { + format!("(comm 0x{})", f.hex_digits()) + } else { + format!("", f.hex_digits()) + } + } + None => "".into(), + }, + }, + Tag::Cont(_) => "".into(), + Tag::Ctrl(_) => unreachable!(), + } + } } diff --git a/src/lem/var_map.rs b/src/lem/var_map.rs index 82ef94d342..b904f56f51 100644 --- a/src/lem/var_map.rs +++ b/src/lem/var_map.rs @@ -9,7 +9,7 @@ use super::Var; /// variables before using them, so we don't expect to need some piece of /// information from a variable that hasn't been defined. #[derive(Clone)] -pub(crate) struct VarMap(HashMap); +pub struct VarMap(HashMap); impl VarMap { /// Creates an empty `VarMap` diff --git a/src/lem/zstore.rs b/src/lem/zstore.rs new file mode 100644 index 0000000000..c49ece6dd1 --- /dev/null +++ b/src/lem/zstore.rs @@ -0,0 +1,162 @@ +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap}; + +use crate::field::{FWrap, LurkField}; + +use super::{ + pointers::{Ptr, ZChildren, ZPtr}, + store::Store, +}; + +#[derive(Default, Serialize, Deserialize)] +pub struct ZStore { + dag: BTreeMap, ZChildren>, + comms: BTreeMap, (F, ZPtr)>, +} + +impl ZStore { + #[inline] + pub fn add_comm(&mut self, hash: F, secret: F, payload: ZPtr) { + self.comms.insert(FWrap(hash), (secret, payload)); + } + + #[inline] + pub fn open(&self, hash: F) -> Option<&(F, ZPtr)> { + self.comms.get(&FWrap(hash)) + } + + #[inline] + pub fn get_children(&self, z_ptr: &ZPtr) -> Option<&ZChildren> { + self.dag.get(z_ptr) + } +} + +pub fn populate_z_store( + z_store: &mut ZStore, + ptr: &Ptr, + store: &Store, +) -> Result> { + let mut cache: HashMap, ZPtr> = HashMap::default(); + let mut recurse = |ptr: &Ptr| -> Result> { + if let Some(z_ptr) = cache.get(ptr) { + Ok(*z_ptr) + } else { + let z_ptr = match ptr { + Ptr::Atom(tag, f) => { + let z_ptr = ZPtr { + tag: *tag, + hash: *f, + }; + z_store.dag.insert(z_ptr, ZChildren::Atom); + z_ptr + } + Ptr::Tuple2(tag, idx) => { + let Some((a, b)) = store.fetch_2_ptrs(*idx) else { + bail!("Index {idx} not found on tuple2") + }; + let a = populate_z_store(z_store, a, store)?; + let b = populate_z_store(z_store, b, store)?; + let z_ptr = ZPtr { + tag: *tag, + hash: store.poseidon_cache.hash4(&[ + a.tag.to_field(), + a.hash, + b.tag.to_field(), + b.hash, + ]), + }; + z_store.dag.insert(z_ptr, ZChildren::Tuple2(a, b)); + z_ptr + } + Ptr::Tuple3(tag, idx) => { + let Some((a, b, c)) = store.fetch_3_ptrs(*idx) else { + bail!("Index {idx} not found on tuple3") + }; + let a = populate_z_store(z_store, a, store)?; + let b = populate_z_store(z_store, b, store)?; + let c = populate_z_store(z_store, c, store)?; + let z_ptr = ZPtr { + tag: *tag, + hash: store.poseidon_cache.hash6(&[ + a.tag.to_field(), + a.hash, + b.tag.to_field(), + b.hash, + c.tag.to_field(), + c.hash, + ]), + }; + z_store.dag.insert(z_ptr, ZChildren::Tuple3(a, b, c)); + z_ptr + } + Ptr::Tuple4(tag, idx) => { + let Some((a, b, c, d)) = store.fetch_4_ptrs(*idx) else { + bail!("Index {idx} not found on tuple4") + }; + let a = populate_z_store(z_store, a, store)?; + let b = populate_z_store(z_store, b, store)?; + let c = populate_z_store(z_store, c, store)?; + let d = populate_z_store(z_store, d, store)?; + let z_ptr = ZPtr { + tag: *tag, + hash: store.poseidon_cache.hash8(&[ + a.tag.to_field(), + a.hash, + b.tag.to_field(), + b.hash, + c.tag.to_field(), + c.hash, + d.tag.to_field(), + d.hash, + ]), + }; + z_store.dag.insert(z_ptr, ZChildren::Tuple4(a, b, c, d)); + z_ptr + } + }; + cache.insert(*ptr, z_ptr); + Ok(z_ptr) + } + }; + recurse(ptr) +} + +pub fn populate_store( + store: &mut Store, + z_ptr: &ZPtr, + z_store: &ZStore, +) -> Result> { + let mut cache: HashMap, Ptr> = HashMap::default(); + let mut recurse = |z_ptr: &ZPtr| -> Result> { + if let Some(z_ptr) = cache.get(z_ptr) { + Ok(*z_ptr) + } else { + let ptr = match z_store.get_children(z_ptr) { + None => bail!("Couldn't find ZPtr"), + Some(ZChildren::Atom) => Ptr::Atom(z_ptr.tag, z_ptr.hash), + Some(ZChildren::Tuple2(z1, z2)) => { + let ptr1 = populate_store(store, z1, z_store)?; + let ptr2 = populate_store(store, z2, z_store)?; + store.intern_2_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, *z_ptr) + } + Some(ZChildren::Tuple3(z1, z2, z3)) => { + let ptr1 = populate_store(store, z1, z_store)?; + let ptr2 = populate_store(store, z2, z_store)?; + let ptr3 = populate_store(store, z3, z_store)?; + store.intern_3_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, ptr3, *z_ptr) + } + Some(ZChildren::Tuple4(z1, z2, z3, z4)) => { + let ptr1 = populate_store(store, z1, z_store)?; + let ptr2 = populate_store(store, z2, z_store)?; + let ptr3 = populate_store(store, z3, z_store)?; + let ptr4 = populate_store(store, z4, z_store)?; + store.intern_4_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, ptr3, ptr4, *z_ptr) + } + }; + cache.insert(*z_ptr, ptr); + Ok(ptr) + } + }; + recurse(z_ptr) +}