Skip to content

Commit

Permalink
Implement env lookup query. (#1080)
Browse files Browse the repository at this point in the history
* Implement env lookup query.

* Add and use default implementation of recursive_eval.

* Add RecursiveQuery trait.

* Optimize Factorial arg in DemoQuery.

* Fix rebase on coroutine rc.

* Clippy.

---------

Co-authored-by: porcuquine <porcuquine@users.noreply.github.com>
  • Loading branch information
porcuquine and porcuquine authored Feb 5, 2024
1 parent d82def2 commit 107dc37
Show file tree
Hide file tree
Showing 7 changed files with 706 additions and 123 deletions.
4 changes: 2 additions & 2 deletions src/circuit/gadgets/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ macro_rules! equal_t {
// Returns a Boolean which is true if any of its arguments are true.
macro_rules! or {
($cs:expr, $a:expr, $b:expr) => {
or(
crate::circuit::gadgets::constraints::or(
$cs.namespace(|| format!("{} or {}", stringify!($a), stringify!($b))),
$a,
$b,
)
};
($cs:expr, $a:expr, $b:expr, $c:expr, $($x:expr),+) => {{
let or_tmp_cs_ = &mut $cs.namespace(|| format!("or({})", stringify!(vec![$a, $b, $c, $($x),*])));
or_v(or_tmp_cs_, &[$a, $b, $c, $($x),*])
bellpepper::gadgets::boolean_utils::or_v(or_tmp_cs_, &[$a, $b, $c, $($x),*])
}};
($cs:expr, $a:expr, $($x:expr),+) => {{
let or_tmp_cs_ = &mut $cs.namespace(|| format!("or {}", stringify!(vec![$a, $($x),*])));
Expand Down
81 changes: 81 additions & 0 deletions src/coprocessor/gadgets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
circuit::GlobalAllocator,
pointers::{Ptr, ZPtr},
store::{expect_ptrs, Store},
tag,
},
tag::{ExprTag, Tag},
};
Expand Down Expand Up @@ -145,6 +146,86 @@ pub(crate) fn construct_list<F: LurkField, CS: ConstraintSystem<F>>(
})
}

/// Constructs an `Env` pointer
#[allow(dead_code)]
#[inline]
pub(crate) fn construct_env<F: LurkField, CS: ConstraintSystem<F>>(
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
var_hash: &AllocatedNum<F>,
val: &AllocatedPtr<F>,
next_env: &AllocatedNum<F>,
) -> Result<AllocatedPtr<F>, SynthesisError> {
let tag = g.alloc_tag_cloned(cs, &ExprTag::Env);

let hash = hash_poseidon(
cs,
vec![
var_hash.clone(),
val.tag().clone(),
val.hash().clone(),
next_env.clone(),
],
store.poseidon_cache.constants.c4(),
)?;

Ok(AllocatedPtr::from_parts(tag, hash))
}

/// Deconstructs `env`, assumed to be a composition of a symbol hash, a value `Ptr`, and a next `Env` hash.
///
/// # Panics
/// Panics if the store can't deconstruct the env hash.
#[allow(dead_code)]
pub(crate) fn deconstruct_env<F: LurkField, CS: ConstraintSystem<F>>(
cs: &mut CS,
s: &Store<F>,
not_dummy: &Boolean,
env: &AllocatedNum<F>,
) -> Result<(AllocatedNum<F>, AllocatedPtr<F>, AllocatedNum<F>), SynthesisError> {
let env_zptr = ZPtr::from_parts(tag::Tag::Expr(ExprTag::Env), env.get_value().unwrap());
let env_ptr = s.to_ptr(&env_zptr);

let (a, b, c, d) = {
if let Some([v, val, new_env]) = s.pop_binding(env_ptr) {
let v_zptr = s.hash_ptr(&v);
let val_zptr = s.hash_ptr(&val);
let new_env_zptr = s.hash_ptr(&new_env);
(
*v_zptr.value(),
val_zptr.tag().to_field::<F>(),
*val_zptr.value(),
*new_env_zptr.value(),
)
} else {
(F::ZERO, F::ZERO, F::ZERO, F::ZERO)
}
};

let key_sym_hash = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "key_sym_hash"), || a);
let val_tag = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "val_tag"), || b);
let val_hash = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "val_hash"), || c);
let new_env_hash = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "new_env_hash"), || d);

let hash = hash_poseidon(
&mut cs.namespace(|| "hash"),
vec![
key_sym_hash.clone(),
val_tag.clone(),
val_hash.clone(),
new_env_hash.clone(),
],
s.poseidon_cache.constants.c4(),
)?;

let val = AllocatedPtr::from_parts(val_tag, val_hash);

implies_equal(&mut cs.namespace(|| "hash equality"), not_dummy, env, &hash);

Ok((key_sym_hash, val, new_env_hash))
}

/// Retrieves the `Ptr` that corresponds to `a_ptr` by using the `Store` as the
/// hint provider
#[allow(dead_code)]
Expand Down
150 changes: 57 additions & 93 deletions src/coroutine/memoset/demo.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError};

use super::{
query::{CircuitQuery, Query},
query::{CircuitQuery, Query, RecursiveQuery},
CircuitScope, CircuitTranscript, LogMemo, LogMemoCircuit, Scope,
};
use crate::circuit::gadgets::constraints::alloc_is_zero;
use crate::circuit::gadgets::pointer::AllocatedPtr;
use crate::coprocessor::gadgets::construct_list;
use crate::field::LurkField;
use crate::lem::circuit::GlobalAllocator;
use crate::lem::{pointers::Ptr, store::Store};
Expand All @@ -28,7 +27,6 @@ pub(crate) enum DemoCircuitQuery<F: LurkField> {
impl<F: LurkField> Query<F> for DemoQuery<F> {
type CQ = DemoCircuitQuery<F>;

// DemoQuery and Scope depend on each other.
fn eval(&self, s: &Store<F>, scope: &mut Scope<Self, LogMemo<F>>) -> Ptr {
match self {
Self::Factorial(n) => {
Expand All @@ -49,15 +47,6 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
}
}

fn recursive_eval(
&self,
scope: &mut Scope<Self, LogMemo<F>>,
s: &Store<F>,
subquery: Self,
) -> Ptr {
scope.query_recursively(s, self, subquery)
}

fn symbol(&self) -> Symbol {
match self {
Self::Factorial(_) => Symbol::sym(&["lurk", "user", "factorial"]),
Expand All @@ -70,7 +59,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
let sym = s.fetch_sym(&head).expect("head should be sym");

if sym == Symbol::sym(&["lurk", "user", "factorial"]) {
let (num, _) = s.car_cdr(&body).expect("query body should be cons");
let num = body;
Some(Self::Factorial(num))
} else {
None
Expand All @@ -82,7 +71,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
Self::Factorial(n) => {
let factorial = s.intern_symbol(&self.symbol());

s.list(vec![factorial, *n])
s.cons(factorial, *n)
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -116,6 +105,32 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
}
}

impl<F: LurkField> RecursiveQuery<F> for DemoCircuitQuery<F> {
// It would be nice if this could be passed to `CircuitQuery::recurse` as an optional closure, rather than be a
// trait method. That would allow more generality. The types get complicated, though. For generality, we should
// support a context object that can be initialized once in `synthesize_eval` and be passed through for use here.
fn post_recursion<CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
subquery_result: AllocatedPtr<F>,
) -> Result<AllocatedPtr<F>, SynthesisError> {
match self {
Self::Factorial(n) => {
let result_f = n.hash().mul(
&mut cs.namespace(|| "incremental multiplication"),
subquery_result.hash(),
)?;

AllocatedPtr::alloc_tag(
&mut cs.namespace(|| "result"),
ExprTag::Num.to_field(),
result_f,
)
}
}
}
}

impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
fn synthesize_eval<CS: ConstraintSystem<F>>(
&self,
Expand All @@ -127,8 +142,6 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
transcript: &CircuitTranscript<F>,
) -> Result<(AllocatedPtr<F>, AllocatedPtr<F>, CircuitTranscript<F>), SynthesisError> {
match self {
// TODO: Factor out the recursive boilerplate so individual queries can just implement their distinct logic
// using a sane framework.
Self::Factorial(n) => {
// FIXME: Check n tag or decide not to.
let base_case_f = g.alloc_const(cs, F::ONE);
Expand All @@ -140,85 +153,36 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {

let n_is_zero = alloc_is_zero(&mut cs.namespace(|| "n_is_zero"), n.hash())?;

let (recursive_result, recursive_acc, recursive_transcript) = {
let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || {
n.hash()
.get_value()
.map(|n| n - F::ONE)
.ok_or(SynthesisError::AssignmentMissing)
})?;

// new_n * 1 = n - 1
cs.enforce(
|| "enforce_new_n",
|lc| lc + new_n.get_variable(),
|lc| lc + CS::one(),
|lc| lc + n.hash().get_variable() - CS::one(),
);

let subquery = {
let symbol = g.alloc_ptr(cs, &self.symbol_ptr(store), store);

let new_num = AllocatedPtr::alloc_tag(
&mut cs.namespace(|| "new_num"),
ExprTag::Num.to_field(),
new_n,
)?;
construct_list(
&mut cs.namespace(|| "subquery"),
g,
store,
&[&symbol, &new_num],
None,
)?
};

let (sub_result, new_acc, new_transcript) = scope.synthesize_internal_query(
&mut cs.namespace(|| "recursive query"),
g,
store,
&subquery,
acc,
transcript,
&n_is_zero.not(),
)?;

let result_f = n.hash().mul(
&mut cs.namespace(|| "incremental multiplication"),
sub_result.hash(),
)?;

let result = AllocatedPtr::alloc_tag(
&mut cs.namespace(|| "result"),
ExprTag::Num.to_field(),
result_f,
)?;

(result, new_acc, new_transcript)
};

let value = AllocatedPtr::pick(
&mut cs.namespace(|| "pick value"),
&n_is_zero,
&base_case,
&recursive_result,
)?;

let acc = AllocatedPtr::pick(
&mut cs.namespace(|| "pick acc"),
&n_is_zero,
acc,
&recursive_acc,
)?;

let transcript = CircuitTranscript::pick(
&mut cs.namespace(|| "pick recursive_transcript"),
&n_is_zero,
transcript,
&recursive_transcript,
let new_n = AllocatedNum::alloc(&mut cs.namespace(|| "new_n"), || {
n.hash()
.get_value()
.map(|n| n - F::ONE)
.ok_or(SynthesisError::AssignmentMissing)
})?;

// new_n * 1 = n - 1
cs.enforce(
|| "enforce_new_n",
|lc| lc + new_n.get_variable(),
|lc| lc + CS::one(),
|lc| lc + n.hash().get_variable() - CS::one(),
);

let new_num = AllocatedPtr::alloc_tag(
&mut cs.namespace(|| "new_num"),
ExprTag::Num.to_field(),
new_n,
)?;

Ok((value, acc, transcript))
self.recurse(
cs,
g,
store,
scope,
&new_num,
&n_is_zero.not(),
(&base_case, acc, transcript),
)
}
}
}
Expand Down
Loading

0 comments on commit 107dc37

Please sign in to comment.