diff --git a/src/coprocessor/memoset/demo.rs b/src/coprocessor/memoset/demo.rs new file mode 100644 index 0000000000..53f04fe84a --- /dev/null +++ b/src/coprocessor/memoset/demo.rs @@ -0,0 +1,255 @@ +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + +use super::{ + query::{CircuitQuery, Query}, + CircuitScope, CircuitTranscript, LogMemo, Scope, +}; +use crate::circuit::gadgets::constraints::alloc_is_zero; +use crate::coprocessor::gadgets::construct_list; +use crate::coprocessor::AllocatedPtr; +use crate::field::LurkField; +use crate::lem::circuit::GlobalAllocator; +use crate::lem::{pointers::Ptr, store::Store}; +use crate::symbol::Symbol; +use crate::tag::{ExprTag, Tag}; + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub(crate) enum DemoQuery { + Factorial(Ptr), + Phantom(F), +} + +#[derive(Debug, Clone)] +pub(crate) enum DemoCircuitQuery { + Factorial(AllocatedPtr), +} + +impl Query for DemoQuery { + type CQ = DemoCircuitQuery; + + // DemoQuery and Scope depend on each other. + fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr { + match self { + Self::Factorial(n) => { + let n_zptr = s.hash_ptr(n); + let n = n_zptr.value(); + + if *n == F::ZERO { + s.num(F::ONE) + } else { + let m_ptr = self.recursive_eval(scope, s, Self::Factorial(s.num(*n - F::ONE))); + let m_zptr = s.hash_ptr(&m_ptr); + let m = m_zptr.value(); + + s.num(*n * m) + } + } + _ => unreachable!(), + } + } + + fn recursive_eval( + &self, + scope: &mut Scope>, + s: &Store, + subquery: Self, + ) -> Ptr { + scope.query_recursively(s, self, subquery) + } + + fn symbol(&self) -> Symbol { + match self { + Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]), + _ => unreachable!(), + } + } + + fn from_ptr(s: &Store, ptr: &Ptr) -> Option { + let (head, body) = s.car_cdr(ptr).expect("query should be cons"); + let sym = s.fetch_sym(&head).expect("head should be sym"); + + if sym == Symbol::sym(&["lurk", "user", "factorial"]) { + let (num, _) = s.car_cdr(&body).expect("query body should be cons"); + Some(Self::Factorial(num)) + } else { + None + } + } + + fn to_ptr(&self, s: &Store) -> Ptr { + match self { + Self::Factorial(n) => { + let factorial = s.intern_symbol(&self.symbol()); + + s.list(vec![factorial, *n]) + } + _ => unreachable!(), + } + } + + fn index(&self) -> usize { + match self { + Self::Factorial(_) => 0, + _ => unreachable!(), + } + } +} + +impl CircuitQuery for DemoCircuitQuery { + fn synthesize_eval>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + scope: &mut CircuitScope>, + acc: &AllocatedPtr, + transcript: &CircuitTranscript, + ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError> { + match self { + // TODO: Factor out the recursive boilerplate so individual queries can just implement their distinct logic + // using a sane framework. + Self::Factorial(n) => { + // FIXME: Check n tag or decide not to. + let base_case_f = g.alloc_const(cs, F::ONE); + let base_case = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "base_case"), + ExprTag::Num.to_field(), + base_case_f.clone(), + )?; + + let n_is_zero = alloc_is_zero(&mut cs.namespace(|| "n_is_zero"), n.hash())?; + + let (recursive_result, recursive_acc, recursive_transcript) = { + let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || { + n.hash() + .get_value() + .map(|n| n - F::ONE) + .ok_or(SynthesisError::AssignmentMissing) + })?; + + // new_n * 1 = n - 1 + cs.enforce( + || "enforce_new_n", + |lc| lc + new_n.get_variable(), + |lc| lc + CS::one(), + |lc| lc + n.hash().get_variable() - CS::one(), + ); + + let subquery = { + let symbol = g.alloc_ptr(cs, &self.symbol_ptr(store), store); + + let new_num = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_num"), + ExprTag::Num.to_field(), + new_n, + )?; + construct_list( + &mut cs.namespace(|| "subquery"), + g, + store, + &[&symbol, &new_num], + None, + )? + }; + + let (sub_result, new_acc, new_transcript) = scope.synthesize_query( + &mut cs.namespace(|| "recursive query"), + g, + store, + &subquery, + acc, + transcript, + &n_is_zero.not(), + )?; + + let result_f = n.hash().mul( + &mut cs.namespace(|| "incremental multiplication"), + sub_result.hash(), + )?; + + let result = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "result"), + ExprTag::Num.to_field(), + result_f, + )?; + + (result, new_acc, new_transcript) + }; + + let value = AllocatedPtr::pick( + &mut cs.namespace(|| "pick value"), + &n_is_zero, + &base_case, + &recursive_result, + )?; + + let acc = AllocatedPtr::pick( + &mut cs.namespace(|| "pick acc"), + &n_is_zero, + acc, + &recursive_acc, + )?; + + let transcript = CircuitTranscript::pick( + &mut cs.namespace(|| "pick recursive_transcript"), + &n_is_zero, + transcript, + &recursive_transcript, + )?; + + Ok((value, acc, transcript)) + } + } + } + + fn from_ptr>( + cs: &mut CS, + s: &Store, + ptr: &Ptr, + ) -> Result, SynthesisError> { + let query = DemoQuery::from_ptr(s, ptr); + Ok(if let Some(q) = query { + match q { + DemoQuery::Factorial(n) => Some(Self::Factorial(AllocatedPtr::alloc(cs, || { + Ok(s.hash_ptr(&n)) + })?)), + _ => unreachable!(), + } + } else { + None + }) + } + + fn symbol(&self) -> Symbol { + match self { + Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use ff::Field; + use pasta_curves::pallas::Scalar as F; + + #[test] + fn test_factorial() { + let s = Store::default(); + let mut scope: Scope, LogMemo> = Scope::default(); + let zero = s.num(F::ZERO); + let one = s.num(F::ONE); + let two = s.num(F::from_u64(2)); + let three = s.num(F::from_u64(3)); + let four = s.num(F::from_u64(4)); + let six = s.num(F::from_u64(6)); + let twenty_four = s.num(F::from_u64(24)); + assert_eq!(one, DemoQuery::Factorial(zero).eval(&s, &mut scope)); + assert_eq!(one, DemoQuery::Factorial(one).eval(&s, &mut scope)); + assert_eq!(two, DemoQuery::Factorial(two).eval(&s, &mut scope)); + assert_eq!(six, DemoQuery::Factorial(three).eval(&s, &mut scope)); + assert_eq!(twenty_four, DemoQuery::Factorial(four).eval(&s, &mut scope)); + } +} diff --git a/src/coprocessor/memoset/mod.rs b/src/coprocessor/memoset/mod.rs index ced6ee354f..629a452cc7 100644 --- a/src/coprocessor/memoset/mod.rs +++ b/src/coprocessor/memoset/mod.rs @@ -47,14 +47,12 @@ use crate::tag::{ExprTag, Tag as XTag}; use crate::z_ptr::ZPtr; use multiset::MultiSet; -use query::{CircuitQuery, DemoCircuitQuery, Query}; +use query::{CircuitQuery, Query}; +mod demo; mod multiset; mod query; -type ScopeCircuitQuery = DemoCircuitQuery; -type ScopeQuery = as CircuitQuery>::Q; - #[derive(Clone, Debug)] pub struct Transcript { acc: Ptr, @@ -93,7 +91,6 @@ impl Transcript { #[allow(dead_code)] fn dbg(&self, s: &Store) { - //dbg!(self.acc.fmt_to_string_simple(s)); tracing::debug!("transcript: {}", self.acc.fmt_to_string_simple(s)); } @@ -175,7 +172,6 @@ impl CircuitTranscript { fn dbg(&self, s: &Store) { let z = self.acc.get_value::().unwrap(); let transcript = s.to_ptr(&z); - // dbg!(transcript.fmt_to_string_simple(s)); tracing::debug!("transcript: {}", transcript.fmt_to_string_simple(s)); } } @@ -196,10 +192,10 @@ pub struct Scope { internal_insertions: Vec, /// unique keys all_insertions: Vec, - _p: PhantomData<(F, Q)>, + _p: PhantomData, } -impl Default for Scope, LogMemo> { +impl Default for Scope> { fn default() -> Self { Self { memoset: Default::default(), @@ -213,17 +209,17 @@ impl Default for Scope, LogMemo> { } } -pub struct CircuitScope, M: MemoSet> { +pub struct CircuitScope, M: MemoSet> { memoset: M, /// k -> v queries: HashMap, ZPtr>, /// k -> allocated v transcript: CircuitTranscript, acc: Option>, - _p: PhantomData, + _p: PhantomData, } -impl Scope, LogMemo> { +impl> Scope> { pub fn query(&mut self, s: &Store, form: Ptr) -> Ptr { let (response, kv_ptr) = self.query_aux(s, form); @@ -232,12 +228,7 @@ impl Scope, LogMemo> { response } - fn query_recursively( - &mut self, - s: &Store, - parent: &ScopeQuery, - child: ScopeQuery, - ) -> Ptr { + fn query_recursively(&mut self, s: &Store, parent: &Q, child: Q) -> Ptr { let form = child.to_ptr(s); self.internal_insertions.push(form); @@ -253,7 +244,7 @@ impl Scope, LogMemo> { fn query_aux(&mut self, s: &Store, form: Ptr) -> (Ptr, Ptr) { let response = self.queries.get(&form).cloned().unwrap_or_else(|| { - let query = ScopeQuery::from_ptr(s, &form).expect("invalid query"); + let query = Q::from_ptr(s, &form).expect("invalid query"); let evaluated = query.eval(s, self); @@ -299,9 +290,7 @@ impl Scope, LogMemo> { insertions.sort_by_key(|kv| { let (key, _) = s.car_cdr(kv).unwrap(); - ScopeQuery::::from_ptr(s, &key) - .expect("invalid query") - .index() + Q::from_ptr(s, &key).expect("invalid query").index() }); for kv in self.toplevel_insertions.iter() { @@ -356,49 +345,20 @@ impl Scope, LogMemo> { s: &Store, ) -> Result<(), SynthesisError> { self.ensure_transcript_finalized(s); - + let mut circuit_scope = + CircuitScope::from_scope(&mut cs.namespace(|| "transcript"), g, s, self); + circuit_scope.init(cs, g, s); { - let circuit_scope = - &mut CircuitScope::from_scope(&mut cs.namespace(|| "transcript"), g, s, self); - circuit_scope.init(cs, g, s); - { - self.synthesize_insert_toplevel_queries(circuit_scope, cs, g, s)?; - self.synthesize_prove_all_queries(circuit_scope, cs, g, s)?; - } - circuit_scope.finalize(cs, g); - Ok(()) - } - } - - fn synthesize_insert_toplevel_queries>( - &mut self, - circuit_scope: &mut CircuitScope, LogMemo>, - cs: &mut CS, - g: &mut GlobalAllocator, - s: &Store, - ) -> Result<(), SynthesisError> { - for (i, kv) in self.toplevel_insertions.iter().enumerate() { - circuit_scope.synthesize_toplevel_query(cs, g, s, i, kv)?; - } - Ok(()) - } - - fn synthesize_prove_all_queries>( - &mut self, - circuit_scope: &mut CircuitScope, LogMemo>, - cs: &mut CS, - g: &mut GlobalAllocator, - s: &Store, - ) -> Result<(), SynthesisError> { - for (i, kv) in self.all_insertions.iter().enumerate() { - circuit_scope.synthesize_prove_query(cs, g, s, i, kv)?; + circuit_scope.synthesize_insert_toplevel_queries(self, cs, g, s)?; + circuit_scope.synthesize_prove_all_queries(self, cs, g, s)?; } + circuit_scope.finalize(cs, g); Ok(()) } } -impl> CircuitScope> { - fn from_scope>( +impl> CircuitScope> { + fn from_scope, Q: Query>( cs: &mut CS, g: &mut GlobalAllocator, s: &Store, @@ -535,9 +495,20 @@ impl> CircuitScope> { Ok((value, new_acc, new_insertion_transcript)) } -} -impl CircuitScope, LogMemo> { + fn synthesize_insert_toplevel_queries, Q: Query>( + &mut self, + scope: &mut Scope>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in scope.toplevel_insertions.iter().enumerate() { + self.synthesize_toplevel_query(cs, g, s, i, kv)?; + } + Ok(()) + } + fn synthesize_toplevel_query>( &mut self, cs: &mut CS, @@ -575,6 +546,19 @@ impl CircuitScope, LogMemo> { Ok(()) } + fn synthesize_prove_all_queries, Q: Query>( + &mut self, + scope: &mut Scope>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in scope.all_insertions.iter().enumerate() { + self.synthesize_prove_query(cs, g, s, i, kv)?; + } + Ok(()) + } + fn synthesize_prove_query>( &mut self, cs: &mut CS, @@ -592,8 +576,7 @@ impl CircuitScope, LogMemo> { ) .unwrap(); - let circuit_query = - ScopeCircuitQuery::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap(); + let circuit_query = CQ::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap(); let acc = self.acc.clone().unwrap(); let transcript = self.transcript.clone(); @@ -761,6 +744,7 @@ mod test { use crate::state::State; use bellpepper_core::{test_cs::TestConstraintSystem, Comparable}; + use demo::DemoQuery; use expect_test::{expect, Expect}; use pasta_curves::pallas::Scalar as F; use std::default::Default; @@ -768,7 +752,7 @@ mod test { #[test] fn test_query() { let s = &Store::::default(); - let mut scope: Scope, LogMemo> = Scope::default(); + let mut scope: Scope, LogMemo> = Scope::default(); let state = State::init_lurk_state(); let fact_4 = s.read_with_default_state("(factorial 4)").unwrap(); @@ -819,7 +803,7 @@ mod test { assert!(cs.is_satisfied()); } { - let mut scope: Scope, LogMemo> = Scope::default(); + let mut scope: Scope, LogMemo> = Scope::default(); scope.query(s, fact_4); scope.query(s, fact_3); diff --git a/src/coprocessor/memoset/query.rs b/src/coprocessor/memoset/query.rs index fe85db927e..c6b58e069e 100644 --- a/src/coprocessor/memoset/query.rs +++ b/src/coprocessor/memoset/query.rs @@ -1,19 +1,18 @@ -use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; +use bellpepper_core::{ConstraintSystem, SynthesisError}; use super::{CircuitScope, CircuitTranscript, LogMemo, Scope}; -use crate::circuit::gadgets::constraints::alloc_is_zero; -use crate::coprocessor::gadgets::construct_list; use crate::coprocessor::AllocatedPtr; use crate::field::LurkField; use crate::lem::circuit::GlobalAllocator; use crate::lem::{pointers::Ptr, store::Store}; use crate::symbol::Symbol; -use crate::tag::{ExprTag, Tag}; pub trait Query where - Self: Sized, + Self: Sized + Clone, { + type CQ: CircuitQuery; + fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr; fn recursive_eval( &self, @@ -31,29 +30,24 @@ where fn index(&self) -> usize; } -#[allow(unreachable_pub)] pub trait CircuitQuery where - Self: Sized, + Self: Sized + Clone, { - type Q: Query; - fn synthesize_eval>( &self, cs: &mut CS, g: &GlobalAllocator, store: &Store, - scope: &mut CircuitScope>, + scope: &mut CircuitScope>, acc: &AllocatedPtr, transcript: &CircuitTranscript, ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError>; - fn symbol(&self, s: &Store) -> Symbol { - self.dummy_query_variant(s).symbol() - } + fn symbol(&self) -> Symbol; fn symbol_ptr(&self, s: &Store) -> Ptr { - self.dummy_query_variant(s).symbol_ptr(s) + s.intern_symbol(&self.symbol()) } fn from_ptr>( @@ -61,246 +55,4 @@ where s: &Store, ptr: &Ptr, ) -> Result, SynthesisError>; - - fn dummy_query_variant(&self, s: &Store) -> Self::Q; -} - -#[derive(Debug, Clone)] -pub enum DemoQuery { - Factorial(Ptr), - Phantom(F), -} - -pub enum DemoCircuitQuery { - Factorial(AllocatedPtr), -} - -impl Query for DemoQuery { - // DemoQuery and Scope depend on each other. - fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr { - match self { - Self::Factorial(n) => { - let n_zptr = s.hash_ptr(n); - let n = n_zptr.value(); - - if *n == F::ZERO { - s.num(F::ONE) - } else { - let m_ptr = self.recursive_eval(scope, s, Self::Factorial(s.num(*n - F::ONE))); - let m_zptr = s.hash_ptr(&m_ptr); - let m = m_zptr.value(); - - s.num(*n * m) - } - } - _ => unreachable!(), - } - } - - fn recursive_eval( - &self, - scope: &mut Scope>, - s: &Store, - subquery: Self, - ) -> Ptr { - scope.query_recursively(s, self, subquery) - } - - fn symbol(&self) -> Symbol { - match self { - Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]), - _ => unreachable!(), - } - } - - fn from_ptr(s: &Store, ptr: &Ptr) -> Option { - let (head, body) = s.car_cdr(ptr).expect("query should be cons"); - let sym = s.fetch_sym(&head).expect("head should be sym"); - - if sym == Symbol::sym(&["lurk", "user", "factorial"]) { - let (num, _) = s.car_cdr(&body).expect("query body should be cons"); - Some(Self::Factorial(num)) - } else { - None - } - } - - fn to_ptr(&self, s: &Store) -> Ptr { - match self { - Self::Factorial(n) => { - let factorial = s.intern_symbol(&self.symbol()); - - s.list(vec![factorial, *n]) - } - _ => unreachable!(), - } - } - - fn index(&self) -> usize { - match self { - Self::Factorial(_) => 0, - _ => unreachable!(), - } - } -} - -impl CircuitQuery for DemoCircuitQuery { - type Q = DemoQuery; - - fn dummy_query_variant(&self, s: &Store) -> Self::Q { - match self { - Self::Factorial(_) => Self::Q::Factorial(s.num(F::ZERO)), - } - } - - fn synthesize_eval>( - &self, - cs: &mut CS, - g: &GlobalAllocator, - store: &Store, - scope: &mut CircuitScope>, - acc: &AllocatedPtr, - transcript: &CircuitTranscript, - ) -> Result<(AllocatedPtr, AllocatedPtr, CircuitTranscript), SynthesisError> { - match self { - // TODO: Factor out the recursive boilerplate so individual queries can just implement their distinct logic - // using a sane framework. - Self::Factorial(n) => { - // FIXME: Check n tag or decide not to. - let base_case_f = g.alloc_const(cs, F::ONE); - let base_case = AllocatedPtr::alloc_tag( - &mut cs.namespace(|| "base_case"), - ExprTag::Num.to_field(), - base_case_f.clone(), - )?; - - let n_is_zero = alloc_is_zero(&mut cs.namespace(|| "n_is_zero"), n.hash())?; - - let (recursive_result, recursive_acc, recursive_transcript) = { - let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || { - n.hash() - .get_value() - .map(|n| n - F::ONE) - .ok_or(SynthesisError::AssignmentMissing) - })?; - - // new_n * 1 = n - 1 - cs.enforce( - || "enforce_new_n", - |lc| lc + new_n.get_variable(), - |lc| lc + CS::one(), - |lc| lc + n.hash().get_variable() - CS::one(), - ); - - let subquery = { - let symbol = - g.alloc_ptr(cs, &store.intern_symbol(&self.symbol(store)), store); - - let new_num = AllocatedPtr::alloc_tag( - &mut cs.namespace(|| "new_num"), - ExprTag::Num.to_field(), - new_n, - )?; - construct_list( - &mut cs.namespace(|| "subquery"), - g, - store, - &[&symbol, &new_num], - None, - )? - }; - - let (sub_result, new_acc, new_transcript) = scope.synthesize_query( - &mut cs.namespace(|| "recursive query"), - g, - store, - &subquery, - acc, - transcript, - &n_is_zero.not(), - )?; - - let result_f = n.hash().mul( - &mut cs.namespace(|| "incremental multiplication"), - sub_result.hash(), - )?; - - let result = AllocatedPtr::alloc_tag( - &mut cs.namespace(|| "result"), - ExprTag::Num.to_field(), - result_f, - )?; - - (result, new_acc, new_transcript) - }; - - let value = AllocatedPtr::pick( - &mut cs.namespace(|| "pick value"), - &n_is_zero, - &base_case, - &recursive_result, - )?; - - let acc = AllocatedPtr::pick( - &mut cs.namespace(|| "pick acc"), - &n_is_zero, - acc, - &recursive_acc, - )?; - - let transcript = CircuitTranscript::pick( - &mut cs.namespace(|| "pick recursive_transcript"), - &n_is_zero, - transcript, - &recursive_transcript, - )?; - - Ok((value, acc, transcript)) - } - } - } - - fn from_ptr>( - cs: &mut CS, - s: &Store, - ptr: &Ptr, - ) -> Result, SynthesisError> { - let query = Self::Q::from_ptr(s, ptr); - Ok(if let Some(q) = query { - match q { - Self::Q::Factorial(n) => Some(Self::Factorial(AllocatedPtr::alloc(cs, || { - Ok(s.hash_ptr(&n)) - })?)), - _ => unreachable!(), - } - } else { - None - }) - } -} - -#[cfg(test)] -mod test { - use super::*; - - use ff::Field; - use pasta_curves::pallas::Scalar as F; - - #[test] - fn test_factorial() { - let s = Store::default(); - let mut scope: Scope, LogMemo> = Scope::default(); - let zero = s.num(F::ZERO); - let one = s.num(F::ONE); - let two = s.num(F::from_u64(2)); - let three = s.num(F::from_u64(3)); - let four = s.num(F::from_u64(4)); - let six = s.num(F::from_u64(6)); - let twenty_four = s.num(F::from_u64(24)); - assert_eq!(one, DemoQuery::Factorial(zero).eval(&s, &mut scope)); - assert_eq!(one, DemoQuery::Factorial(one).eval(&s, &mut scope)); - assert_eq!(two, DemoQuery::Factorial(two).eval(&s, &mut scope)); - assert_eq!(six, DemoQuery::Factorial(three).eval(&s, &mut scope)); - assert_eq!(twenty_four, DemoQuery::Factorial(four).eval(&s, &mut scope)); - } }