diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index 880119d741..83c3ff0734 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -102,8 +102,8 @@ fn allocate_ptr>( var: &Var, bound_allocations: &mut BoundAllocations, ) -> Result> { - let allocated_tag = allocate_num(cs, &format!("allocate {var}'s tag"), z_ptr.tag.to_field())?; - let allocated_hash = allocate_num(cs, &format!("allocate {var}'s hash"), z_ptr.hash)?; + let allocated_tag = allocate_num(cs, &format!("allocate {var}'s tag"), z_ptr.tag_field())?; + let allocated_hash = allocate_num(cs, &format!("allocate {var}'s hash"), *z_ptr.value())?; let allocated_ptr = AllocatedPtr::from_parts(allocated_tag, allocated_hash); bound_allocations.insert(var.clone(), allocated_ptr.clone()); Ok(allocated_ptr) @@ -253,7 +253,7 @@ fn allocate_slots>( cs, &slot, component_idx, - z_ptr.tag.to_field(), + z_ptr.tag_field(), )?); component_idx += 1; @@ -263,7 +263,7 @@ fn allocate_slots>( cs, &slot, component_idx, - z_ptr.hash, + *z_ptr.value(), )?); component_idx += 1; @@ -278,11 +278,14 @@ fn allocate_slots>( cs, &slot, 1, - z_ptr.tag.to_field(), + z_ptr.tag_field(), )?); // allocate third component preallocated_preimg.push(allocate_preimg_component_for_slot( - cs, &slot, 2, z_ptr.hash, + cs, + &slot, + 2, + *z_ptr.value(), )?); } PreimageData::FPair(a, b) => { @@ -340,8 +343,8 @@ impl Block { Op::Lit(_, lit) => { let lit_ptr = lit.to_ptr_cached(store); let lit_z_ptr = store.hash_ptr(&lit_ptr).unwrap(); - g.alloc_const(cs, lit_z_ptr.tag.to_field()); - g.alloc_const(cs, lit_z_ptr.hash); + g.alloc_const(cs, lit_z_ptr.tag_field()); + g.alloc_const(cs, *lit_z_ptr.value()); } Op::Null(_, tag) => { g.alloc_const(cs, tag.to_field()); @@ -674,11 +677,11 @@ impl Func { Op::Lit(tgt, lit) => { let lit_ptr = lit.to_ptr_cached(g.store); let lit_tag = lit_ptr.tag().to_field(); - let lit_hash = g.store.hash_ptr(&lit_ptr)?.hash; let allocated_tag = g.global_allocator.get_allocated_const_cloned(lit_tag)?; - let allocated_hash = - g.global_allocator.get_allocated_const_cloned(lit_hash)?; + let allocated_hash = g + .global_allocator + .get_allocated_const_cloned(*g.store.hash_ptr(&lit_ptr)?.value())?; let allocated_ptr = AllocatedPtr::from_parts(allocated_tag, allocated_hash); bound_allocations.insert(tgt.clone(), allocated_ptr); } @@ -1149,7 +1152,7 @@ impl Func { .store .interned_symbol(sym) .expect("symbol must have been interned"); - let sym_hash = g.store.hash_ptr(sym_ptr)?.hash; + let sym_hash = *g.store.hash_ptr(sym_ptr)?.value(); cases_vec.push((sym_hash, block)); } @@ -1241,8 +1244,8 @@ impl Func { Op::Lit(_, lit) => { let lit_ptr = lit.to_ptr_cached(store); let lit_z_ptr = store.hash_ptr(&lit_ptr).unwrap(); - globals.insert(FWrap(lit_z_ptr.tag.to_field())); - globals.insert(FWrap(lit_z_ptr.hash)); + globals.insert(FWrap(lit_z_ptr.tag_field())); + globals.insert(FWrap(*lit_z_ptr.value())); } Op::Cast(_, tag, _) => { globals.insert(FWrap(tag.to_field())); diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 4128215788..a10410cf1c 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -1070,8 +1070,8 @@ mod tests { use blstrs::Scalar as Fr; const NUM_INPUTS: usize = 1; - const NUM_AUX: usize = 10744; - const NUM_CONSTRAINTS: usize = 13299; + const NUM_AUX: usize = 10748; + const NUM_CONSTRAINTS: usize = 13303; const NUM_SLOTS: SlotsCounter = SlotsCounter { hash2: 16, hash3: 4, diff --git a/src/lem/interpreter.rs b/src/lem/interpreter.rs index b3aa073f33..234540a35c 100644 --- a/src/lem/interpreter.rs +++ b/src/lem/interpreter.rs @@ -156,9 +156,7 @@ impl Block { let b = bindings.get(b)?; // In order to compare Ptrs, we *must* resolve the hashes. Otherwise, we risk failing to recognize equality of // compound data with opaque data in either element's transitive closure. - let a_hash = store.hash_ptr(a)?.hash; - let b_hash = store.hash_ptr(b)?.hash; - let c = if a_hash == b_hash { + let c = if store.hash_ptr(a)?.value() == store.hash_ptr(b)?.value() { Ptr::Atom(Tag::Expr(Num), F::ONE) } else { Ptr::Atom(Tag::Expr(Num), F::ZERO) diff --git a/src/lem/mod.rs b/src/lem/mod.rs index 1c25237953..39b2bc2266 100644 --- a/src/lem/mod.rs +++ b/src/lem/mod.rs @@ -69,17 +69,19 @@ mod slot; pub mod store; mod var_map; pub mod zstore; - -use crate::coprocessor::Coprocessor; -use crate::eval::lang::Lang; -use crate::field::LurkField; -use crate::symbol::Symbol; -use crate::tag::{ContTag, ExprTag, Tag as TagTrait}; use anyhow::{bail, Result}; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use std::sync::Arc; +use crate::{ + coprocessor::Coprocessor, + eval::lang::Lang, + field::LurkField, + symbol::Symbol, + tag::{ContTag, CtrlTag, ExprTag, Tag as TagTrait}, +}; + use self::{pointers::Ptr, slot::SlotsCounter, store::Store, var_map::VarMap}; pub type AString = Arc; @@ -113,40 +115,47 @@ pub enum Tag { Ctrl(CtrlTag), } -#[derive(Copy, Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize)] -pub enum CtrlTag { - Return, - MakeThunk, - ApplyContinuation, - Error, +impl From for Tag { + fn from(val: u16) -> Self { + if let Ok(tag) = ExprTag::try_from(val) { + Tag::Expr(tag) + } else if let Ok(tag) = ContTag::try_from(val) { + Tag::Cont(tag) + } else if let Ok(tag) = CtrlTag::try_from(val) { + Tag::Ctrl(tag) + } else { + panic!("Invalid u16 for Tag: {val}") + } + } } -impl Tag { - #[inline] - pub fn to_field(self) -> F { - use Tag::*; - match self { - Expr(tag) => tag.to_field(), - Cont(tag) => tag.to_field(), - Ctrl(tag) => tag.to_field(), +impl From for u16 { + fn from(val: Tag) -> Self { + match val { + Tag::Expr(tag) => tag.into(), + Tag::Cont(tag) => tag.into(), + Tag::Ctrl(tag) => tag.into(), } } } -impl CtrlTag { - #[inline] - fn to_field(self) -> F { - F::from(self as u64) +impl TagTrait for Tag { + fn from_field(f: &F) -> Option { + Self::try_from(f.to_u16()?).ok() + } + + fn to_field(&self) -> F { + Tag::to_field(self) } } -impl std::fmt::Display for CtrlTag { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Tag { + #[inline] + pub fn to_field(&self) -> F { match self { - Self::Return => write!(f, "return#"), - Self::ApplyContinuation => write!(f, "apply-cont#"), - Self::MakeThunk => write!(f, "make-thunk#"), - Self::Error => write!(f, "error#"), + Tag::Expr(tag) => tag.to_field(), + Tag::Cont(tag) => tag.to_field(), + Tag::Ctrl(tag) => tag.to_field(), } } } @@ -456,10 +465,11 @@ impl Func { let mut tags = HashSet::new(); let mut kind = None; for (tag, block) in cases { + // make sure that this `MatchTag` doesn't have weird semantics let tag_kind = match tag { Tag::Expr(..) => 0, Tag::Cont(..) => 1, - Tag::Ctrl(..) => 4, + Tag::Ctrl(..) => 2, }; if let Some(kind) = kind { if kind != tag_kind { diff --git a/src/lem/pointers.rs b/src/lem/pointers.rs index 8f41dba21b..22f7cc2a0b 100644 --- a/src/lem/pointers.rs +++ b/src/lem/pointers.rs @@ -137,31 +137,7 @@ impl Ptr { /// An important note is that computing the respective `ZPtr` of a `Ptr` can be /// expensive because of the Poseidon hashes. That's why we operate on `Ptr`s /// when interpreting LEMs and delay the need for `ZPtr`s as much as possible. -#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)] -pub struct ZPtr { - pub tag: Tag, - pub hash: F, -} - -impl PartialOrd for ZPtr { - fn partial_cmp(&self, other: &Self) -> Option { - ( - self.tag.to_field::().to_repr().as_ref(), - self.hash.to_repr().as_ref(), - ) - .partial_cmp(&( - other.tag.to_field::().to_repr().as_ref(), - other.hash.to_repr().as_ref(), - )) - } -} - -impl Ord for ZPtr { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.partial_cmp(other) - .expect("ZPtr::cmp: partial_cmp domain invariant violation") - } -} +pub type ZPtr = crate::z_data::z_ptr::ZPtr; /// `ZChildren` keeps track of the children of `ZPtr`s, in case they have any. /// This information is saved during hydration and is needed to content-address @@ -173,10 +149,3 @@ pub enum ZChildren { Tuple3(ZPtr, ZPtr, ZPtr), Tuple4(ZPtr, ZPtr, ZPtr, ZPtr), } - -impl std::hash::Hash for ZPtr { - fn hash(&self, state: &mut H) { - self.tag.hash(state); - self.hash.to_repr().as_ref().hash(state); - } -} diff --git a/src/lem/store.rs b/src/lem/store.rs index e2c6685216..58ea3c945c 100644 --- a/src/lem/store.rs +++ b/src/lem/store.rs @@ -316,7 +316,7 @@ impl Store { let z_ptr = self.hash_ptr(&payload)?; let hash = self .poseidon_cache - .hash3(&[secret, z_ptr.tag.to_field(), z_ptr.hash]); + .hash3(&[secret, z_ptr.tag_field(), *z_ptr.value()]); self.comms.insert(FWrap::(hash), (secret, payload)); Ok(Ptr::comm(hash)) } @@ -329,7 +329,7 @@ impl Store { let z_ptr = self.hash_ptr(&payload)?; let hash = self .poseidon_cache - .hash3(&[secret, z_ptr.tag.to_field(), z_ptr.hash]); + .hash3(&[secret, z_ptr.tag_field(), *z_ptr.value()]); self.comms.insert(FWrap::(hash), (secret, payload)); Ok((hash, z_ptr)) } @@ -462,10 +462,7 @@ impl Store { /// depth limit. This limitation is circumvented by calling `hydrate_z_cache`. pub fn hash_ptr(&self, ptr: &Ptr) -> Result> { match ptr { - Ptr::Atom(tag, x) => Ok(ZPtr { - tag: *tag, - hash: *x, - }), + Ptr::Atom(tag, x) => Ok(ZPtr::from_parts(*tag, *x)), Ptr::Tuple2(tag, idx) => match self.z_cache.get(ptr) { Some(z_ptr) => Ok(*z_ptr), None => { @@ -474,15 +471,15 @@ impl Store { }; let a = self.hash_ptr(a)?; let b = self.hash_ptr(b)?; - let z_ptr = ZPtr { - tag: *tag, - hash: self.poseidon_cache.hash4(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, + let z_ptr = ZPtr::from_parts( + *tag, + self.poseidon_cache.hash4(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), ]), - }; + ); self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } @@ -496,17 +493,17 @@ impl Store { let a = self.hash_ptr(a)?; let b = self.hash_ptr(b)?; let c = self.hash_ptr(c)?; - let z_ptr = ZPtr { - tag: *tag, - hash: self.poseidon_cache.hash6(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, - c.tag.to_field(), - c.hash, + let z_ptr = ZPtr::from_parts( + *tag, + self.poseidon_cache.hash6(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), + c.tag_field(), + *c.value(), ]), - }; + ); self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } @@ -521,19 +518,19 @@ impl Store { let b = self.hash_ptr(b)?; let c = self.hash_ptr(c)?; let d = self.hash_ptr(d)?; - let z_ptr = ZPtr { - tag: *tag, - hash: self.poseidon_cache.hash8(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, - c.tag.to_field(), - c.hash, - d.tag.to_field(), - d.hash, + let z_ptr = ZPtr::from_parts( + *tag, + self.poseidon_cache.hash8(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), + c.tag_field(), + *c.value(), + d.tag_field(), + *d.value(), ]), - }; + ); self.z_cache.insert(*ptr, Box::new(z_ptr)); Ok(z_ptr) } @@ -554,8 +551,8 @@ impl Store { ptrs.iter() .try_fold(Vec::with_capacity(2 * ptrs.len()), |mut acc, ptr| { let z_ptr = self.hash_ptr(ptr)?; - acc.push(z_ptr.tag.to_field()); - acc.push(z_ptr.hash); + acc.push(z_ptr.tag_field()); + acc.push(*z_ptr.value()); Ok(acc) }) } diff --git a/src/lem/zstore.rs b/src/lem/zstore.rs index c49ece6dd1..f82621d3e9 100644 --- a/src/lem/zstore.rs +++ b/src/lem/zstore.rs @@ -44,10 +44,7 @@ pub fn populate_z_store( } else { let z_ptr = match ptr { Ptr::Atom(tag, f) => { - let z_ptr = ZPtr { - tag: *tag, - hash: *f, - }; + let z_ptr = ZPtr::from_parts(*tag, *f); z_store.dag.insert(z_ptr, ZChildren::Atom); z_ptr } @@ -57,15 +54,15 @@ pub fn populate_z_store( }; let a = populate_z_store(z_store, a, store)?; let b = populate_z_store(z_store, b, store)?; - let z_ptr = ZPtr { - tag: *tag, - hash: store.poseidon_cache.hash4(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, + let z_ptr = ZPtr::from_parts( + *tag, + store.poseidon_cache.hash4(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), ]), - }; + ); z_store.dag.insert(z_ptr, ZChildren::Tuple2(a, b)); z_ptr } @@ -76,17 +73,17 @@ pub fn populate_z_store( let a = populate_z_store(z_store, a, store)?; let b = populate_z_store(z_store, b, store)?; let c = populate_z_store(z_store, c, store)?; - let z_ptr = ZPtr { - tag: *tag, - hash: store.poseidon_cache.hash6(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, - c.tag.to_field(), - c.hash, + let z_ptr = ZPtr::from_parts( + *tag, + store.poseidon_cache.hash6(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), + c.tag_field(), + *c.value(), ]), - }; + ); z_store.dag.insert(z_ptr, ZChildren::Tuple3(a, b, c)); z_ptr } @@ -98,19 +95,19 @@ pub fn populate_z_store( let b = populate_z_store(z_store, b, store)?; let c = populate_z_store(z_store, c, store)?; let d = populate_z_store(z_store, d, store)?; - let z_ptr = ZPtr { - tag: *tag, - hash: store.poseidon_cache.hash8(&[ - a.tag.to_field(), - a.hash, - b.tag.to_field(), - b.hash, - c.tag.to_field(), - c.hash, - d.tag.to_field(), - d.hash, + let z_ptr = ZPtr::from_parts( + *tag, + store.poseidon_cache.hash8(&[ + a.tag_field(), + *a.value(), + b.tag_field(), + *b.value(), + c.tag_field(), + *c.value(), + d.tag_field(), + *d.value(), ]), - }; + ); z_store.dag.insert(z_ptr, ZChildren::Tuple4(a, b, c, d)); z_ptr } @@ -134,24 +131,24 @@ pub fn populate_store( } else { let ptr = match z_store.get_children(z_ptr) { None => bail!("Couldn't find ZPtr"), - Some(ZChildren::Atom) => Ptr::Atom(z_ptr.tag, z_ptr.hash), + Some(ZChildren::Atom) => Ptr::Atom(z_ptr.tag(), *z_ptr.value()), Some(ZChildren::Tuple2(z1, z2)) => { let ptr1 = populate_store(store, z1, z_store)?; let ptr2 = populate_store(store, z2, z_store)?; - store.intern_2_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, *z_ptr) + store.intern_2_ptrs_hydrated(z_ptr.tag(), ptr1, ptr2, *z_ptr) } Some(ZChildren::Tuple3(z1, z2, z3)) => { let ptr1 = populate_store(store, z1, z_store)?; let ptr2 = populate_store(store, z2, z_store)?; let ptr3 = populate_store(store, z3, z_store)?; - store.intern_3_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, ptr3, *z_ptr) + store.intern_3_ptrs_hydrated(z_ptr.tag(), ptr1, ptr2, ptr3, *z_ptr) } Some(ZChildren::Tuple4(z1, z2, z3, z4)) => { let ptr1 = populate_store(store, z1, z_store)?; let ptr2 = populate_store(store, z2, z_store)?; let ptr3 = populate_store(store, z3, z_store)?; let ptr4 = populate_store(store, z4, z_store)?; - store.intern_4_ptrs_hydrated(z_ptr.tag, ptr1, ptr2, ptr3, ptr4, *z_ptr) + store.intern_4_ptrs_hydrated(z_ptr.tag(), ptr1, ptr2, ptr3, ptr4, *z_ptr) } }; cache.insert(*z_ptr, ptr); diff --git a/src/tag.rs b/src/tag.rs index f68eca5c90..aa30813840 100644 --- a/src/tag.rs +++ b/src/tag.rs @@ -466,6 +466,67 @@ impl fmt::Display for Op2 { } } +#[derive( + Copy, + Clone, + Debug, + PartialEq, + PartialOrd, + Eq, + Hash, + Serialize_repr, + Deserialize_repr, + TryFromRepr, +)] +#[cfg_attr(not(target_arch = "wasm32"), derive(Arbitrary))] +#[repr(u16)] +pub enum CtrlTag { + Return = 0b0100_0000_0000_0000, + MakeThunk, + ApplyContinuation, + Error, +} + +impl From for u16 { + fn from(val: CtrlTag) -> Self { + val as u16 + } +} + +impl From for u64 { + fn from(val: CtrlTag) -> Self { + val as u64 + } +} + +impl Tag for CtrlTag { + fn from_field(f: &F) -> Option { + Self::try_from(f.to_u16()?).ok() + } + + fn to_field + ff::Field>(&self) -> F { + F::from(*self as u64) + } + + fn to_field_bytes(&self) -> F::Repr { + let mut res = F::Repr::default(); + let u: u16 = (*self).into(); + res.as_mut()[..2].copy_from_slice(&u.to_le_bytes()); + res + } +} + +impl std::fmt::Display for CtrlTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Return => write!(f, "return#"), + Self::ApplyContinuation => write!(f, "apply-cont#"), + Self::MakeThunk => write!(f, "make-thunk#"), + Self::Error => write!(f, "error#"), + } + } +} + #[cfg(test)] pub mod tests {