Skip to content

Commit

Permalink
Add per-coroutine-circuit rc.
Browse files Browse the repository at this point in the history
  • Loading branch information
porcuquine committed Jan 30, 2024
1 parent e3008a0 commit a658e74
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 49 deletions.
34 changes: 21 additions & 13 deletions src/coroutine/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
}
}

fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS, s: &Store<F>) -> Self::CQ {
match self {
DemoQuery::Factorial(n) => {
Self::CQ::Factorial(AllocatedPtr::alloc_infallible(cs, || s.hash_ptr(n)))
}
_ => unreachable!(),
}
}

fn dummy_from_index(s: &Store<F>, index: usize) -> Self {
match index {
0 => Self::Factorial(s.num(0.into())),
_ => unreachable!(),
}
}

fn index(&self) -> usize {
match self {
Self::Factorial(_) => 0,
Expand Down Expand Up @@ -208,19 +224,11 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
}

fn from_ptr<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, ptr: &Ptr) -> Option<Self> {
let query = DemoQuery::from_ptr(s, ptr);
if let Some(q) = query {
match q {
DemoQuery::Factorial(n) => {
Some(Self::Factorial(AllocatedPtr::alloc_infallible(cs, || {
s.hash_ptr(&n)
})))
}
_ => unreachable!(),
}
} else {
None
}
DemoQuery::from_ptr(s, ptr).map(|q| q.to_circuit(cs, s))
}

fn dummy_from_index<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, index: usize) -> Self {
DemoQuery::dummy_from_index(s, index).to_circuit(cs, s)
}

fn symbol(&self) -> Symbol {
Expand Down
188 changes: 152 additions & 36 deletions src/coroutine/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
//! results computed 'naturally' during evaluation. We then separate and sort in an order matching that which the NIVC
//! prover will follow when provably maintaining the multiset accumulator and Fiat-Shamir transcript in the circuit.

use itertools::Itertools;
use std::collections::HashMap;
use std::marker::PhantomData;

Expand Down Expand Up @@ -193,18 +194,21 @@ pub struct Scope<Q, M> {
/// unique keys: query-index -> [key]
unique_inserted_keys: HashMap<usize, Vec<Ptr>>,
transcribe_internal_insertions: bool,
// This may become an explicit map or something allowing more fine-grained control.
default_rc: usize,
}

const DEFAULT_RC_FOR_QUERY: usize = 1;
const DEFAULT_TRANSCRIBE_INTERNAL_INSERTIONS: bool = false;

impl<F: LurkField, Q> Default for Scope<Q, LogMemo<F>> {
fn default() -> Self {
Self::new(DEFAULT_TRANSCRIBE_INTERNAL_INSERTIONS)
Self::new(DEFAULT_TRANSCRIBE_INTERNAL_INSERTIONS, DEFAULT_RC_FOR_QUERY)
}
}

impl<F: LurkField, Q> Scope<Q, LogMemo<F>> {
fn new(transcribe_internal_insertions: bool) -> Self {
fn new(transcribe_internal_insertions: bool, default_rc: usize) -> Self {
Self {
memoset: Default::default(),
queries: Default::default(),
Expand All @@ -213,6 +217,7 @@ impl<F: LurkField, Q> Scope<Q, LogMemo<F>> {
internal_insertions: Default::default(),
unique_inserted_keys: Default::default(),
transcribe_internal_insertions,
default_rc,
}
}
}
Expand All @@ -232,8 +237,10 @@ pub struct CoroutineCircuit<'a, F: LurkField, CM, Q> {
queries: &'a HashMap<Ptr, Ptr>,
memoset: CM,
keys: Vec<Ptr>,
query_index: usize,
store: &'a Store<F>,
transcribe_internal_insertions: bool,
rc: usize,
_p: PhantomData<Q>,
}

Expand All @@ -244,14 +251,19 @@ impl<'a, F: LurkField, Q: Query<F>> CoroutineCircuit<'a, F, LogMemoCircuit<F>, Q
scope: &'a Scope<Q, LogMemo<F>>,
memoset: LogMemoCircuit<F>,
keys: Vec<Ptr>,
query_index: usize,
store: &'a Store<F>,
rc: usize,
) -> Self {
assert!(keys.len() <= rc);
Self {
memoset,
queries: &scope.queries,
keys,
query_index,
store,
transcribe_internal_insertions: scope.transcribe_internal_insertions,
rc,
_p: Default::default(),
}
}
Expand Down Expand Up @@ -280,9 +292,21 @@ impl<'a, F: LurkField, Q: Query<F>> CoroutineCircuit<'a, F, LogMemoCircuit<F>, Q
);
circuit_scope.update_from_io(memoset_acc.clone(), transcript.clone(), r);

for (i, key) in self.keys.iter().enumerate() {
for (i, key) in self
.keys
.iter()
.map(Some)
.pad_using(self.rc, |_| None)
.enumerate()
{
let cs = &mut cs.namespace(|| format!("internal-{i}"));
circuit_scope.synthesize_prove_key_query::<_, Q>(cs, g, self.store, key)?;
circuit_scope.synthesize_prove_key_query::<_, Q>(
cs,
g,
self.store,
key,
self.query_index,
)?;
}

let (memoset_acc, transcript, r_num) = circuit_scope.io();
Expand Down Expand Up @@ -456,7 +480,7 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
r_num,
)?;
let dummy = g.alloc_ptr(cs, &s.intern_nil(), s);
let z = vec![
let mut z = vec![
dummy.clone(),
dummy.clone(),
dummy.clone(),
Expand All @@ -467,16 +491,37 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
for (index, keys) in self.unique_inserted_keys.iter() {
let cs = &mut cs.namespace(|| format!("query-index-{index}"));

let mut circuit: CoroutineCircuit<'_, F, LogMemoCircuit<F>, Q> =
CoroutineCircuit::new(self, memoset_circuit.clone(), keys.clone(), s);

let (_next_pc, z_out) = circuit.synthesize(cs, &z)?;
{
let memoset_acc = &z_out[3];
let transcript = &z_out[4];
let r = &z_out[5];

circuit_scope.update_from_io(memoset_acc.clone(), transcript.clone(), r);
let rc = self.rc_for_query(*index);

for (i, chunk) in keys.chunks(rc).enumerate() {
// This namespace exists only because we are putting multiple 'chunks' into a single, larger circuit (as a stage in development).
// It shouldn't exist, when instead we have only the single NIVC circuit repeated multiple times.
let cs = &mut cs.namespace(|| format!("chunk-{i}"));

let mut circuit: CoroutineCircuit<'_, F, LogMemoCircuit<F>, Q> =
CoroutineCircuit::new(
self,
memoset_circuit.clone(),
chunk.to_vec(),
*index,
s,
rc,
);

let (_next_pc, z_out) = circuit.synthesize(cs, &z)?;
{
let memoset_acc = &z_out[3];
let transcript = &z_out[4];
let r = &z_out[5];

circuit_scope.update_from_io(
memoset_acc.clone(),
transcript.clone(),
r,
);

z = z_out;
}
}
}
}
Expand All @@ -485,6 +530,10 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
circuit_scope.finalize(cs, g);
Ok(())
}

fn rc_for_query(&self, _index: usize) -> usize {
self.default_rc
}
}

impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
Expand Down Expand Up @@ -734,18 +783,34 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
key: &Ptr,
key: Option<&Ptr>,
index: usize,
) -> Result<(), SynthesisError> {
let allocated_key =
AllocatedPtr::alloc(
&mut cs.namespace(|| "allocated_key"),
|| Ok(s.hash_ptr(key)),
)
.unwrap();
let allocated_key = AllocatedPtr::alloc(&mut cs.namespace(|| "allocated_key"), || {
if let Some(key) = key {
Ok(s.hash_ptr(key))
} else {
Ok(s.hash_ptr(&s.intern_nil()))
}
})
.unwrap();

let circuit_query = if let Some(key) = key {
Q::CQ::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap()
} else {
Q::CQ::dummy_from_index(&mut cs.namespace(|| "circuit_query"), s, index)
};

let circuit_query = Q::CQ::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap();
let not_dummy = key.is_some();

self.synthesize_prove_query::<_, Q::CQ>(cs, g, s, &allocated_key, &circuit_query)?;
self.synthesize_prove_query::<_, Q::CQ>(
cs,
g,
s,
&allocated_key,
&circuit_query,
not_dummy,
)?;
Ok(())
}

Expand All @@ -756,6 +821,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
s: &Store<F>,
allocated_key: &AllocatedPtr<F>,
circuit_query: &CQ,
not_dummy: bool,
) -> Result<(), SynthesisError> {
let acc = self.acc.clone().unwrap();
let transcript = self.transcript.clone();
Expand All @@ -767,8 +833,22 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
let (new_acc, new_transcript) =
self.synthesize_remove(cs, g, s, &new_acc, &new_transcript, allocated_key, &val)?;

self.acc = Some(new_acc);
self.transcript = new_transcript;
// Prover can choose non-deterministically whether or not a given query is a dummy, to allow for padding.
let final_acc = AllocatedPtr::pick(
&mut cs.namespace(|| "final_acc"),
&Boolean::Constant(not_dummy),
&new_acc,
self.acc.as_ref().expect("acc missing"),
)?;
let final_transcript = CircuitTranscript::pick(
&mut cs.namespace(|| "final_transcripot"),
&Boolean::Constant(not_dummy),
&new_transcript,
&self.transcript,
)?;

self.acc = Some(final_acc);
self.transcript = final_transcript;

Ok(())
}
Expand Down Expand Up @@ -975,21 +1055,55 @@ mod test {
fn test_query_with_internal_insertion_transcript() {
test_query_aux(
true,
expect!["10831"],
expect!["10864"],
expect!["11413"],
expect!["11450"],
expect!["10875"],
expect!["10908"],
expect!["11457"],
expect!["11494"],
1,
);
test_query_aux(
true,
expect!["12908"],
expect!["12947"],
expect!["13490"],
expect!["13533"],
3,
);
test_query_aux(
true,
expect!["21106"],
expect!["21169"],
expect!["21688"],
expect!["21755"],
10,
)
}

#[test]
fn test_query_without_internal_insertion_transcript() {
test_query_aux(
false,
expect!["9386"],
expect!["9419"],
expect!["9968"],
expect!["10005"],
expect!["9430"],
expect!["9463"],
expect!["10012"],
expect!["10049"],
1,
);
test_query_aux(
false,
expect!["11174"],
expect!["11213"],
expect!["11756"],
expect!["11799"],
3,
);
test_query_aux(
false,
expect!["18216"],
expect!["18279"],
expect!["18798"],
expect!["18865"],
10,
)
}

Expand All @@ -999,9 +1113,11 @@ mod test {
expected_aux_simple: Expect,
expected_constraints_compound: Expect,
expected_aux_compound: Expect,
circuit_query_rc: usize,
) {
let s = &Store::<F>::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::new(transcribe_internal_insertions);
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> =
Scope::new(transcribe_internal_insertions, circuit_query_rc);
let state = State::init_lurk_state();

let fact_4 = s.read_with_default_state("(factorial 4)").unwrap();
Expand Down Expand Up @@ -1054,7 +1170,7 @@ mod test {

{
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> =
Scope::new(transcribe_internal_insertions);
Scope::new(transcribe_internal_insertions, circuit_query_rc);
scope.query(s, fact_4);
scope.query(s, fact_3);

Expand Down
5 changes: 5 additions & 0 deletions src/coroutine/memoset/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ where
) -> Ptr;
fn from_ptr(s: &Store<F>, ptr: &Ptr) -> Option<Self>;
fn to_ptr(&self, s: &Store<F>) -> Ptr;
fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS, s: &Store<F>) -> Self::CQ;
fn dummy_from_index(s: &Store<F>, index: usize) -> Self;

fn symbol(&self) -> Symbol;
fn symbol_ptr(&self, s: &Store<F>) -> Ptr {
s.intern_symbol(&self.symbol())
Expand Down Expand Up @@ -54,4 +57,6 @@ where
}

fn from_ptr<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, ptr: &Ptr) -> Option<Self>;

fn dummy_from_index<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, index: usize) -> Self;
}

0 comments on commit a658e74

Please sign in to comment.