diff --git a/.clippy.toml b/.clippy.toml index 9c71707ab2..0538654263 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -5,3 +5,4 @@ disallowed-methods = [ { path = "pasta_curves::Fp", reason = "use pasta_curves::pallas::Base or pasta_curves::vesta::Scalar instead to communicate your intent" }, { path = "pasta_curves::Fq", reason = "use pasta_curves::pallas::Scalar or pasta_curves::vesta::Base instead to communicate your intent" }, ] +allow-dbg-in-tests = true diff --git a/src/circuit/gadgets/constraints.rs b/src/circuit/gadgets/constraints.rs index 98a01e12d4..9788b804b0 100644 --- a/src/circuit/gadgets/constraints.rs +++ b/src/circuit/gadgets/constraints.rs @@ -310,6 +310,28 @@ pub(crate) fn div>( Ok(res) } +pub(crate) fn invert>( + mut cs: CS, + a: &AllocatedNum, +) -> Result, SynthesisError> { + let inv = AllocatedNum::alloc(cs.namespace(|| "invert"), || { + let inv = (a.get_value().ok_or(SynthesisError::AssignmentMissing)?).invert(); + + let inv_opt: Option<_> = inv.into(); + inv_opt.ok_or(SynthesisError::DivisionByZero) + })?; + + // inv * a = 1 + cs.enforce( + || "inversion", + |lc| lc + inv.get_variable(), + |lc| lc + a.get_variable(), + |lc| lc + CS::one(), + ); + + Ok(inv) +} + /// Select the nth element of `from`, where `path_bits` represents n, least-significant bit first. /// The returned result contains the selected element, and constraints are enforced. /// `from.len()` must be a power of two. diff --git a/src/circuit/gadgets/pointer.rs b/src/circuit/gadgets/pointer.rs index 458c07e3da..b3a6072d5d 100644 --- a/src/circuit/gadgets/pointer.rs +++ b/src/circuit/gadgets/pointer.rs @@ -132,6 +132,14 @@ impl AllocatedPtr { &self.hash } + pub fn get_value(&self) -> Option> { + self.tag.get_value().and_then(|tag| { + self.hash + .get_value() + .map(|hash| ZPtr::from_parts(Tag::from_field(&tag).expect("bad tag"), hash)) + }) + } + pub fn enforce_equal>(&self, cs: &mut CS, other: &Self) { // debug_assert_eq!(self.tag.get_value(), other.tag.get_value()); enforce_equal(cs, || "tags equal", &self.tag, &other.tag); diff --git a/src/coprocessor/memoset/mod.rs b/src/coprocessor/memoset/mod.rs new file mode 100644 index 0000000000..b7ccbd275d --- /dev/null +++ b/src/coprocessor/memoset/mod.rs @@ -0,0 +1,716 @@ +//! The `memoset` module implements a MemoSet. + +use std::collections::HashMap; +use std::marker::PhantomData; + +use bellpepper_core::{boolean::Boolean, num::AllocatedNum, ConstraintSystem, SynthesisError}; +use itertools::Itertools; +use once_cell::sync::OnceCell; + +use super::gadgets::construct_cons; +use crate::circuit::gadgets::{ + constraints::{enforce_equal, enforce_equal_zero, invert, sub}, + pointer::AllocatedPtr, +}; +use crate::field::LurkField; +use crate::lem::circuit::GlobalAllocator; +use crate::lem::Tag; +use crate::lem::{pointers::Ptr, store::Store}; +use crate::tag::{ExprTag, Tag as XTag}; +use crate::z_ptr::ZPtr; + +use multiset::MultiSet; +use query::{CircuitQuery, DemoCircuitQuery, DemoQuery, Query}; + +mod multiset; +mod query; + +#[derive(Debug, Default)] +pub struct Scope, M: MemoSet> { + memoset: M, + // k => v + queries: HashMap, + // k => ordered subqueries + dependencies: HashMap>, + // kv pairs + toplevel_insertions: Vec, + // internally-inserted keys + internal_insertions: Vec, + // unique keys + all_insertions: Vec, + _p: PhantomData<(F, Q)>, +} + +impl Default for Scope, LogMemo> { + fn default() -> Self { + Self { + memoset: Default::default(), + queries: Default::default(), + dependencies: Default::default(), + toplevel_insertions: Default::default(), + internal_insertions: Default::default(), + all_insertions: Default::default(), + _p: Default::default(), + } + } +} + +struct CircuitScope, M: MemoSet> { + memoset: M, + // k -> v + queries: HashMap, ZPtr>, + // k -> allocated v + queries_alloc: HashMap, AllocatedPtr>, + transcript: Option>, + acc: Option>, + _p: PhantomData, +} + +type ScopeQuery = DemoQuery; + +impl Scope, LogMemo> { + pub fn query(&mut self, s: &Store, form: Ptr) -> Ptr { + let (response, kv_ptr) = self.query_aux(s, form); + + self.toplevel_insertions.push(kv_ptr); + self.memoset.add(kv_ptr); + + response + } + + fn query_internal( + &mut self, + s: &Store, + parent: &ScopeQuery, + child: ScopeQuery, + ) -> Ptr { + let form = child.to_ptr(s); + self.internal_insertions.push(form); + let (response, kv_ptr) = self.query_aux(s, form); + + self.dependencies + .entry(parent.to_ptr(s)) + .and_modify(|children| children.push(child.clone())) + .or_insert_with(|| vec![child]); + + self.memoset.add(kv_ptr); + response + } + + fn query_aux(&mut self, s: &Store, form: Ptr) -> (Ptr, Ptr) { + let found = self.queries.get(&form); + let response = if let Some(found) = found { + *found + } else { + let query = ScopeQuery::from_ptr(s, &form).expect("invalid query"); + + let evaluated = query.eval(s, self); + + self.queries.insert(form, evaluated); + evaluated + }; + + let kv = s.cons(form, response); + + (response, kv) + } + + fn finalize_transcript(&mut self, s: &Store) -> Ptr { + let (transcript, insertions) = self.build_transcript(s); + self.memoset.finalize_transcript(s, transcript); + self.all_insertions = insertions; + transcript + } + + fn ensure_transcript_finalized(&mut self, s: &Store) { + if !self.memoset.is_finalized() { + self.finalize_transcript(s); + } + } + + fn build_transcript(&self, s: &Store) -> (Ptr, Vec) { + let internal_insertions_kv = self.internal_insertions.iter().map(|key| { + let value = self.queries.get(key).expect("value missing for key"); + s.cons(*key, *value) + }); + + let mut insertions = + Vec::with_capacity(self.toplevel_insertions.len() + self.internal_insertions.len()); + insertions.extend(&self.toplevel_insertions); + insertions.extend(internal_insertions_kv); + + // Sort insertions by query type (index) for processing. + // This is because the transcript will be constructed sequentially by the circuits. + insertions.sort_by_key(|kv| { + let (key, _) = s.car_cdr(kv).unwrap(); + + ScopeQuery::::from_ptr(s, &key) + .expect("invalid query") + .index() + }); + + let mut transcript = s.intern_nil(); + + // Toplevel insertions must come first in transcript. + for kv in self.toplevel_insertions.iter() { + transcript = s.cons(*kv, transcript); + } + + // Then add insertions and deletions interleaved, sorted by query type. + let unique_keys = insertions + .iter() + .dedup() + .map(|kv| { + let key = s.car_cdr(kv).unwrap().0; + + if let Some(dependencies) = self.dependencies.get(&key) { + transcript = dependencies + .iter() + .map(|dependency| { + let k = dependency.to_ptr(s); + let v = self + .queries + .get(&k) + .expect("value missing for dependency key"); + s.cons(k, *v) + }) + .fold(transcript, |acc, dependency_kv| s.cons(dependency_kv, acc)); + }; + let count = self.memoset.count(kv); + let count_num = s.num(F::from_u64(count as u64)); + let kv_count = s.cons(*kv, count_num); + + transcript = s.cons(kv_count, transcript); + + key + }) + .collect::>(); + + (transcript, unique_keys) + } + + pub fn synthesize>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + self.ensure_transcript_finalized(s); + + { + let circuit_scope = &mut CircuitScope::from_scope(s, self); + circuit_scope.init(cs, g, s); + { + self.synthesize_insert_toplevel_queries(circuit_scope, cs, g, s)?; + self.synthesize_prove_all_queries(circuit_scope, cs, g, s)?; + } + circuit_scope.finalize(cs, g); + Ok(()) + } + } + + fn synthesize_insert_toplevel_queries>( + &mut self, + circuit_scope: &mut CircuitScope, LogMemo>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in self.toplevel_insertions.iter().enumerate() { + circuit_scope.synthesize_toplevel_query(cs, g, s, i, kv)?; + } + Ok(()) + } + + fn synthesize_prove_all_queries>( + &mut self, + circuit_scope: &mut CircuitScope, LogMemo>, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) -> Result<(), SynthesisError> { + for (i, kv) in self.all_insertions.iter().enumerate() { + circuit_scope.synthesize_prove_query(cs, g, s, i, kv)?; + } + Ok(()) + } +} + +impl> CircuitScope> { + fn from_scope(s: &Store, scope: &Scope>) -> Self { + let queries = scope + .queries + .iter() + .map(|(k, v)| (s.hash_ptr(k), s.hash_ptr(v))) + .collect(); + Self { + memoset: scope.memoset.clone(), + queries, + queries_alloc: Default::default(), + transcript: Default::default(), + acc: Default::default(), + _p: Default::default(), + } + } + + fn init>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + ) { + g.alloc_tag(cs, &ExprTag::Cons); + g.alloc_z_ptr(cs, s.hash_ptr(&s.intern_nil())); + + self.acc = Some( + AllocatedPtr::alloc_constant(&mut cs.namespace(|| "acc"), s.hash_ptr(&s.num_u64(0))) + .unwrap(), + ); + let nil = s.intern_nil(); + let allocated_nil = g.alloc_ptr(cs, &nil, s); + + self.transcript = Some(allocated_nil.clone()); + } + + fn synthesize_insert_query>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + acc: &AllocatedPtr, + transcript: &AllocatedPtr, + key: &AllocatedPtr, + value: &AllocatedPtr, + ) -> Result<(AllocatedPtr, AllocatedPtr), SynthesisError> { + let kv = construct_cons(&mut cs.namespace(|| "kv"), g, s, key, value)?; + let new_transcript = construct_cons( + &mut cs.namespace(|| "new_transcript"), + g, + s, + &kv, + transcript, + )?; + + let acc_v = acc.hash(); + + let new_acc_v = + self.memoset + .synthesize_add(&mut cs.namespace(|| "new_acc_v"), acc_v, &kv)?; + + let new_acc = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_acc"), + ExprTag::Num.to_field(), + new_acc_v, + )?; + + Ok((new_acc, new_transcript)) + } + + // TODO: Rename + fn synthesize_remove_n>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + s: &Store, + acc: &AllocatedPtr, + transcript: &AllocatedPtr, + key: &AllocatedPtr, + value: &AllocatedPtr, + ) -> Result<(AllocatedPtr, AllocatedPtr), SynthesisError> { + let kv = construct_cons(&mut cs.namespace(|| "kv"), g, s, key, value)?; + let count = { + // FIXME: What about when synthesizing shape? + AllocatedNum::alloc(&mut cs.namespace(|| "count"), || { + let zptr = kv.get_value().unwrap(); + Ok(F::from_u64(self.memoset.count(&s.to_ptr(&zptr)) as u64)) + })? + }; + let count_ptr = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "count_ptr"), + ExprTag::Num.to_field(), + count.clone(), + )?; + + let kv_count = construct_cons(&mut cs.namespace(|| "kv_count"), g, s, &kv, &count_ptr)?; + + let new_transcript = construct_cons( + &mut cs.namespace(|| "new_removal_transcript"), + g, + s, + &kv_count, + transcript, + )?; + + let new_acc_v = self.memoset.synthesize_remove_n( + &mut cs.namespace(|| "new_acc_v"), + acc.hash(), + &kv, + &count, + )?; + + let new_acc = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_acc"), + ExprTag::Num.to_field(), + new_acc_v, + )?; + Ok((new_acc, new_transcript)) + } + + fn finalize>(&mut self, cs: &mut CS, _g: &mut GlobalAllocator) { + let r = self.memoset.allocated_r(cs); + enforce_equal( + cs, + || "r_matches_transcript", + self.transcript.clone().unwrap().hash(), + &r, + ); + enforce_equal_zero(cs, || "acc_is_zero", self.acc.clone().unwrap().hash()); + } + + fn synthesize_query>( + &mut self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + key: &AllocatedPtr, + acc: &AllocatedPtr, + transcript: &AllocatedPtr, + not_dummy: &Boolean, // TODO: use this more deeply? + ) -> Result<(AllocatedPtr, AllocatedPtr, AllocatedPtr), SynthesisError> { + let value = key + .get_value() + .map(|k| { + self.queries_alloc + .entry(k) + .or_insert_with(|| { + AllocatedPtr::alloc(&mut cs.namespace(|| "value"), || { + if not_dummy.get_value() == Some(true) { + Ok(*self + .queries + .get(&k) + .ok_or(SynthesisError::AssignmentMissing)?) + } else { + // Dummy value that will not be used. + Ok(k) + } + }) + .unwrap() + }) + .clone() + }) + .ok_or(SynthesisError::AssignmentMissing)?; + + let (new_acc, new_insertion_transcript) = + self.synthesize_insert_query(cs, g, store, acc, transcript, key, &value)?; + + Ok((value, new_acc, new_insertion_transcript)) + } +} + +impl CircuitScope, LogMemo> { + fn synthesize_toplevel_query>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + i: usize, + kv: &Ptr, + ) -> Result<(), SynthesisError> { + let (key, value) = s.car_cdr(kv).unwrap(); + let cs = &mut cs.namespace(|| format!("toplevel-{i}")); + let allocated_key = AllocatedPtr::alloc(&mut cs.namespace(|| "allocated_key"), || { + Ok(s.hash_ptr(&key)) + }) + .unwrap(); + + let acc = self.acc.clone().unwrap(); + let insertion_transcript = self.transcript.clone().unwrap(); + + let (val, new_acc, new_transcript) = self.synthesize_query( + cs, + g, + s, + &allocated_key, + &acc, + &insertion_transcript, + &Boolean::Constant(true), + )?; + + assert_eq!(Some(value), val.get_value().map(|x| s.to_ptr(&x))); + + self.acc = Some(new_acc); + self.transcript = Some(new_transcript); + Ok(()) + } + + fn synthesize_prove_query>( + &mut self, + cs: &mut CS, + g: &mut GlobalAllocator, + s: &Store, + i: usize, + key: &Ptr, + ) -> Result<(), SynthesisError> { + let cs = &mut cs.namespace(|| format!("internal-{i}")); + + let allocated_key = + AllocatedPtr::alloc( + &mut cs.namespace(|| "allocated_key"), + || Ok(s.hash_ptr(key)), + ) + .unwrap(); + + let circuit_query = + DemoCircuitQuery::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap(); + + let acc = self.acc.clone().unwrap(); + let transcript = self.transcript.clone().unwrap(); + + let (val, new_acc, new_transcript) = circuit_query + .expect("not a query form") + .synthesize_eval(&mut cs.namespace(|| "eval"), g, s, self, &acc, &transcript) + .unwrap(); + + let (new_acc, new_transcript) = + self.synthesize_remove_n(cs, g, s, &new_acc, &new_transcript, &allocated_key, &val)?; + + self.acc = Some(new_acc); + self.transcript = Some(new_transcript); + + self.dbg_transcript(s); + Ok(()) + } + + fn dbg_transcript(&self, s: &Store) { + let z = self.transcript.clone().unwrap().get_value::().unwrap(); + let transcript = s.to_ptr(&z); + // dbg!(transcript.fmt_to_string_dammit(s)); + tracing::debug!("transcript: {}", transcript.fmt_to_string_dammit(s)); + } +} + +pub trait MemoSet: Clone { + fn is_finalized(&self) -> bool; + fn finalize_transcript(&mut self, s: &Store, transcript: Ptr); + fn r(&self) -> Option<&F>; + fn map_to_element(&self, x: F) -> Option; + fn add(&mut self, kv: Ptr); + fn synthesize_remove_n>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + count: &AllocatedNum, + ) -> Result, SynthesisError>; + fn count(&self, form: &Ptr) -> usize; + + // Circuit + + fn allocated_r>(&self, cs: &mut CS) -> AllocatedNum; + + // x is H(k,v) = hash part of (cons k v) + fn synthesize_map_to_element>( + &self, + cs: &mut CS, + x: AllocatedNum, + ) -> Result, SynthesisError>; + + fn synthesize_add>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + ) -> Result, SynthesisError>; +} + +#[derive(Debug, Clone)] +pub struct LogMemo { + multiset: MultiSet, + r: OnceCell, + transcript: OnceCell, + + allocated_r: OnceCell>>, +} + +impl Default for LogMemo { + fn default() -> Self { + // Be explicit. + Self { + multiset: MultiSet::new(), + r: Default::default(), + transcript: Default::default(), + allocated_r: Default::default(), + } + } +} + +impl MemoSet for LogMemo { + fn count(&self, form: &Ptr) -> usize { + self.multiset.get(form).unwrap_or(0) + } + + fn is_finalized(&self) -> bool { + self.transcript.get().is_some() + } + fn finalize_transcript(&mut self, s: &Store, transcript: Ptr) { + self.transcript + .set(transcript) + .expect("transcript already finalized"); + + let z_ptr = s.hash_ptr(&transcript); + assert_eq!(Tag::Expr(ExprTag::Cons), *z_ptr.tag()); + self.r.set(*z_ptr.value()).expect("r has already been set"); + } + + fn r(&self) -> Option<&F> { + self.r.get() + } + + fn allocated_r>(&self, cs: &mut CS) -> AllocatedNum { + self.allocated_r + .get_or_init(|| { + self.r() + .map(|r| AllocatedNum::alloc_infallible(&mut cs.namespace(|| "r"), || *r)) + }) + .clone() + .unwrap() + } + + // x is H(k,v) = hash part of (cons k v) + fn map_to_element(&self, x: F) -> Option { + self.r().and_then(|r| { + let d = *r + x; + d.invert().into() + }) + } + + // x is H(k,v) = hash part of (cons k v) + // 1 / r + x + fn synthesize_map_to_element>( + &self, + cs: &mut CS, + x: AllocatedNum, + ) -> Result, SynthesisError> { + let r = self.allocated_r(cs); + let r_plus_x = r.add(&mut cs.namespace(|| "r+x"), &x)?; + + invert(&mut cs.namespace(|| "invert(r+x)"), &r_plus_x) + } + + fn add(&mut self, kv: Ptr) { + self.multiset.add(kv); + } + + fn synthesize_add>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + ) -> Result, SynthesisError> { + let kv_num = kv.hash().clone(); + let element = self.synthesize_map_to_element(&mut cs.namespace(|| "element"), kv_num)?; + acc.add(&mut cs.namespace(|| "add to acc"), &element) + } + + fn synthesize_remove_n>( + &self, + cs: &mut CS, + acc: &AllocatedNum, + kv: &AllocatedPtr, + count: &AllocatedNum, + ) -> Result, SynthesisError> { + let kv_num = kv.hash().clone(); + let element = self.synthesize_map_to_element(&mut cs.namespace(|| "element"), kv_num)?; + let scaled = element.mul(&mut cs.namespace(|| "scaled"), count)?; + sub(&mut cs.namespace(|| "add to acc"), acc, &scaled) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::state::State; + use bellpepper_core::{test_cs::TestConstraintSystem, Comparable}; + use pasta_curves::pallas::Scalar as F; + use std::default::Default; + + #[test] + fn test_query() { + let s = &Store::::default(); + let mut scope: Scope, LogMemo> = Scope::default(); + let state = State::init_lurk_state(); + + let fact_4 = s.read_with_default_state("(factorial 4)").unwrap(); + let fact_3 = s.read_with_default_state("(factorial 3)").unwrap(); + + { + scope.query(s, fact_4); + + for (k, v) in scope.queries.iter() { + println!("k: {}", k.fmt_to_string(s, &state)); + println!("v: {}", v.fmt_to_string(s, &state)); + } + // Factorial 4 will memoize calls to: + // fact(4), fact(3), fact(2), fact(1), and fact(0) + assert_eq!(5, scope.queries.len()); + assert_eq!(1, scope.toplevel_insertions.len()); + assert_eq!(4, scope.internal_insertions.len()); + + scope.finalize_transcript(s); + + let cs = &mut TestConstraintSystem::new(); + let g = &mut GlobalAllocator::default(); + + scope.synthesize(cs, g, s).unwrap(); + + println!( + "transcript: {}", + scope + .memoset + .transcript + .get() + .unwrap() + .fmt_to_string_dammit(s) + ); + + assert_eq!(10826, cs.num_constraints()); + assert_eq!(10859, cs.aux().len()); + + let unsat = cs.which_is_unsatisfied(); + + if unsat.is_some() { + dbg!(unsat); + } + assert!(cs.is_satisfied()); + } + { + let mut scope: Scope, LogMemo> = Scope::default(); + scope.query(s, fact_4); + scope.query(s, fact_3); + + // // No new queries. + assert_eq!(5, scope.queries.len()); + // // One new top-level insertion. + assert_eq!(2, scope.toplevel_insertions.len()); + // // No new internal insertions. + assert_eq!(4, scope.internal_insertions.len()); + + scope.finalize_transcript(s); + + let cs = &mut TestConstraintSystem::new(); + let g = &mut GlobalAllocator::default(); + + scope.synthesize(cs, g, s).unwrap(); + + assert_eq!(11408, cs.num_constraints()); + assert_eq!(11443, cs.aux().len()); + + let unsat = cs.which_is_unsatisfied(); + if unsat.is_some() { + dbg!(unsat); + } + assert!(cs.is_satisfied()); + } + } +} diff --git a/src/coprocessor/memoset/multiset.rs b/src/coprocessor/memoset/multiset.rs new file mode 100644 index 0000000000..9c17db4869 --- /dev/null +++ b/src/coprocessor/memoset/multiset.rs @@ -0,0 +1,23 @@ +use std::collections::HashMap; +use std::default::Default; +use std::hash::Hash; + +#[derive(PartialEq, Eq, Debug, Default, Clone)] +pub(crate) struct MultiSet(pub(crate) HashMap); + +impl MultiSet { + pub(crate) fn new() -> Self { + Self(Default::default()) + } + pub(crate) fn add(&mut self, element: T) { + self.add_n(element, 1); + } + + pub(crate) fn add_n(&mut self, element: T, n: usize) { + *self.0.entry(element).or_insert(0) += n; + } + + pub(crate) fn get(&self, element: &T) -> Option { + self.0.get(element).copied() + } +} diff --git a/src/coprocessor/memoset/query.rs b/src/coprocessor/memoset/query.rs new file mode 100644 index 0000000000..aaf9c13cf3 --- /dev/null +++ b/src/coprocessor/memoset/query.rs @@ -0,0 +1,311 @@ +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + +use super::{CircuitScope, LogMemo, Scope}; +use crate::circuit::gadgets::constraints::alloc_is_zero; +use crate::coprocessor::gadgets::construct_list; +use crate::coprocessor::AllocatedPtr; +use crate::field::LurkField; +use crate::lem::circuit::GlobalAllocator; +use crate::lem::{pointers::Ptr, store::Store}; +use crate::symbol::Symbol; +use crate::tag::{ExprTag, Tag}; + +pub trait Query { + type T: Query; + type C: CircuitQuery; + + fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr; + fn recursive_eval( + &self, + scope: &mut Scope>, + s: &Store, + subquery: Self::T, + ) -> Ptr; + fn from_ptr(s: &Store, ptr: &Ptr) -> Option; + fn to_ptr(&self, s: &Store) -> Ptr; + fn symbol(&self) -> Symbol; + fn symbol_ptr(&self, s: &Store) -> Ptr { + s.intern_symbol(&self.symbol()) + } + + fn index(&self) -> usize; +} + +#[allow(unreachable_pub)] +pub trait CircuitQuery { + type T: CircuitQuery; + + fn synthesize_eval>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + scope: &mut CircuitScope, LogMemo>, + acc: &AllocatedPtr, + transcript: &AllocatedPtr, + ) -> Result<(AllocatedPtr, AllocatedPtr, AllocatedPtr), SynthesisError>; + + fn symbol(&self, s: &Store) -> Symbol; + + fn symbol_ptr(&self, s: &Store) -> Ptr; + + fn from_ptr>( + cs: &mut CS, + s: &Store, + ptr: &Ptr, + ) -> Result, SynthesisError>; +} + +#[derive(Debug, Clone)] +pub enum DemoQuery { + Factorial(Ptr), + Phantom(F), +} + +impl DemoQuery { + fn factorial(s: &Store) -> Self { + Self::Factorial(s.num(F::ZERO)) + } +} + +pub enum DemoCircuitQuery { + Factorial(AllocatedPtr), +} + +impl Query for DemoQuery { + type T = Self; + type C = DemoCircuitQuery; + + // DemoQuery and Scope depend on each other. + fn eval(&self, s: &Store, scope: &mut Scope>) -> Ptr { + match self { + Self::Factorial(n) => { + let n_zptr = s.hash_ptr(n); + let n = n_zptr.value(); + + if *n == F::ZERO { + s.num(F::ONE) + } else { + let m_ptr = self.recursive_eval(scope, s, Self::Factorial(s.num(*n - F::ONE))); + let m_zptr = s.hash_ptr(&m_ptr); + let m = m_zptr.value(); + + s.num(*n * m) + } + } + _ => unreachable!(), + } + } + + fn recursive_eval( + &self, + scope: &mut Scope>, + s: &Store, + subquery: Self::T, + ) -> Ptr { + scope.query_internal(s, self, subquery) + } + + fn symbol(&self) -> Symbol { + match self { + Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]), + _ => unreachable!(), + } + } + + fn from_ptr(s: &Store, ptr: &Ptr) -> Option { + let (head, body) = s.car_cdr(ptr).expect("query should be cons"); + let sym = s.fetch_sym(&head).expect("head should be sym"); + + if sym == Symbol::sym(&["lurk", "user", "factorial"]) { + let (num, _) = s.car_cdr(&body).expect("query body should be cons"); + Some(Self::Factorial(num)) + } else { + None + } + } + + fn to_ptr(&self, s: &Store) -> Ptr { + match self { + Self::Factorial(n) => { + let factorial = s.intern_symbol(&self.symbol()); + + s.list(vec![factorial, *n]) + } + _ => unreachable!(), + } + } + + fn index(&self) -> usize { + match self { + Self::Factorial(_) => 0, + _ => unreachable!(), + } + } +} + +impl CircuitQuery for DemoCircuitQuery { + type T = Self; + + fn symbol(&self, s: &Store) -> Symbol { + match self { + Self::Factorial(_) => DemoQuery::factorial(s).symbol(), + } + } + fn symbol_ptr(&self, s: &Store) -> Ptr { + match self { + Self::Factorial(_) => DemoQuery::factorial(s).symbol_ptr(s), + } + } + + fn synthesize_eval>( + &self, + cs: &mut CS, + g: &GlobalAllocator, + store: &Store, + scope: &mut CircuitScope, LogMemo>, + acc: &AllocatedPtr, + transcript: &AllocatedPtr, + ) -> Result<(AllocatedPtr, AllocatedPtr, AllocatedPtr), SynthesisError> { + match self { + // TODO: Factor out the recursive boilerplate so individual queries can just implement their distinct logic + // using a sane framework. + Self::Factorial(n) => { + // FIXME: Check n tag or decide not to. + let base_case_f = g.alloc_const(cs, F::ONE); + let base_case = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "base_case"), + ExprTag::Num.to_field(), + base_case_f.clone(), + )?; + + let n_is_zero = alloc_is_zero(&mut cs.namespace(|| "n_is_zero"), n.hash())?; + + let (recursive_result, recursive_acc, recursive_transcript) = { + let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || { + n.hash() + .get_value() + .map(|n| n - F::ONE) + .ok_or(SynthesisError::AssignmentMissing) + })?; + + // new_n * 1 = n - 1 + cs.enforce( + || "enforce_new_n", + |lc| lc + new_n.get_variable(), + |lc| lc + CS::one(), + |lc| lc + n.hash().get_variable() - CS::one(), + ); + + let subquery = { + let symbol = + g.alloc_ptr(cs, &store.intern_symbol(&self.symbol(store)), store); + + let new_num = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "new_num"), + ExprTag::Num.to_field(), + new_n, + )?; + construct_list( + &mut cs.namespace(|| "subquery"), + g, + store, + &[&symbol, &new_num], + None, + )? + }; + + let (sub_result, new_acc, new_transcript) = scope.synthesize_query( + &mut cs.namespace(|| "recursive query"), + g, + store, + &subquery, + acc, + transcript, + &n_is_zero.not(), + )?; + + let result_f = n.hash().mul( + &mut cs.namespace(|| "incremental multiplication"), + sub_result.hash(), + )?; + + let result = AllocatedPtr::alloc_tag( + &mut cs.namespace(|| "result"), + ExprTag::Num.to_field(), + result_f, + )?; + + (result, new_acc, new_transcript) + }; + + let value = AllocatedPtr::pick( + &mut cs.namespace(|| "pick value"), + &n_is_zero, + &base_case, + &recursive_result, + )?; + + let acc = AllocatedPtr::pick( + &mut cs.namespace(|| "pick acc"), + &n_is_zero, + acc, + &recursive_acc, + )?; + + let transcript = AllocatedPtr::pick( + &mut cs.namespace(|| "pick insertion_transcript"), + &n_is_zero, + transcript, + &recursive_transcript, + )?; + + Ok((value, acc, transcript)) + } + } + } + + fn from_ptr>( + cs: &mut CS, + s: &Store, + ptr: &Ptr, + ) -> Result, SynthesisError> { + let query = DemoQuery::from_ptr(s, ptr); + Ok(if let Some(q) = query { + match q { + DemoQuery::Factorial(n) => Some(Self::Factorial(AllocatedPtr::alloc(cs, || { + Ok(s.hash_ptr(&n)) + })?)), + _ => unreachable!(), + } + } else { + None + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + use ff::Field; + use pasta_curves::pallas::Scalar as F; + + #[test] + fn test_factorial() { + let s = Store::default(); + let mut scope: Scope, LogMemo> = Scope::default(); + let zero = s.num(F::ZERO); + let one = s.num(F::ONE); + let two = s.num(F::from_u64(2)); + let three = s.num(F::from_u64(3)); + let four = s.num(F::from_u64(4)); + let six = s.num(F::from_u64(6)); + let twenty_four = s.num(F::from_u64(24)); + assert_eq!(one, DemoQuery::Factorial(zero).eval(&s, &mut scope)); + assert_eq!(one, DemoQuery::Factorial(one).eval(&s, &mut scope)); + assert_eq!(two, DemoQuery::Factorial(two).eval(&s, &mut scope)); + assert_eq!(six, DemoQuery::Factorial(three).eval(&s, &mut scope)); + assert_eq!(twenty_four, DemoQuery::Factorial(four).eval(&s, &mut scope)); + } +} diff --git a/src/coprocessor/mod.rs b/src/coprocessor/mod.rs index 8c711370db..81685dcaf9 100644 --- a/src/coprocessor/mod.rs +++ b/src/coprocessor/mod.rs @@ -10,6 +10,7 @@ use crate::{ pub mod circom; pub mod gadgets; +pub mod memoset; pub mod sha256; pub mod trie; diff --git a/src/lem/store.rs b/src/lem/store.rs index acb504d97a..19a1903b57 100644 --- a/src/lem/store.rs +++ b/src/lem/store.rs @@ -1109,6 +1109,10 @@ impl Ptr { } } + pub fn fmt_to_string_dammit(&self, store: &Store) -> String { + self.fmt_to_string(store, crate::state::initial_lurk_state()) + } + fn fmt_cont2_to_string( &self, name: &str, diff --git a/src/proof/tests/mod.rs b/src/proof/tests/mod.rs index 902646af84..976e52f881 100644 --- a/src/proof/tests/mod.rs +++ b/src/proof/tests/mod.rs @@ -186,7 +186,7 @@ where if unsat.is_some() { // For some reason, this isn't getting printed from within the implementation as expected. // Since we always want to know this information, if the condition occurs, just print it here. - tracing::debug!("{:?}", unsat); + println!("{:?}", unsat); } assert!(cs.is_satisfied()); assert!(cs.verify(&multiframe.public_inputs()));