Skip to content

Commit

Permalink
Coroutine simplifications and generalization (#1160)
Browse files Browse the repository at this point in the history
* Some simplifications

* `GlobalAllocator` does not need to be mutable

* Generalized `recurse` and `post_recursion`

* Moved nil to outside the loop
  • Loading branch information
gabriel-barrett authored Feb 21, 2024
1 parent 3d2f7f6 commit 5458a53
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 72 deletions.
6 changes: 4 additions & 2 deletions src/coroutine/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ impl<F: LurkField> RecursiveQuery<F> for DemoCircuitQuery<F> {
fn post_recursion<CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
subquery_result: AllocatedPtr<F>,
subquery_results: &[AllocatedPtr<F>],
) -> Result<AllocatedPtr<F>, SynthesisError> {
assert_eq!(subquery_results.len(), 1);
let subquery_result = &subquery_results[0];
match self {
Self::Factorial(n) => {
let result_f = n.hash().mul(
Expand Down Expand Up @@ -185,7 +187,7 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
g,
store,
scope,
subquery,
&[subquery],
&n_is_zero.not(),
(&base_case, acc),
allocated_key,
Expand Down
33 changes: 15 additions & 18 deletions src/coroutine/memoset/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ impl<F: LurkField> Query<F> for EnvQuery<F> {
fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS, s: &Store<F>) -> Self::CQ {
match self {
EnvQuery::Lookup(var, env) => {
let mut var_cs = cs.namespace(|| "var");
let allocated_var =
AllocatedNum::alloc_infallible(&mut var_cs, || *s.hash_ptr(var).value());
let mut env_cs = var_cs.namespace(|| "env");
AllocatedNum::alloc_infallible(ns!(cs, "var"), || *s.hash_ptr(var).value());
let allocated_env =
AllocatedNum::alloc_infallible(&mut env_cs, || *s.hash_ptr(env).value());
AllocatedNum::alloc_infallible(ns!(cs, "env"), || *s.hash_ptr(env).value());
Self::CQ::Lookup(allocated_var, allocated_env)
}
_ => unreachable!(),
Expand All @@ -114,7 +112,16 @@ impl<F: LurkField> Query<F> for EnvQuery<F> {
}
}

impl<F: LurkField> RecursiveQuery<F> for EnvCircuitQuery<F> {}
impl<F: LurkField> RecursiveQuery<F> for EnvCircuitQuery<F> {
fn post_recursion<CS: ConstraintSystem<F>>(
&self,
_cs: &mut CS,
subquery_results: &[AllocatedPtr<F>],
) -> Result<AllocatedPtr<F>, SynthesisError> {
assert_eq!(subquery_results.len(), 1);
Ok(subquery_results[0].clone())
}
}

impl<F: LurkField> CircuitQuery<F> for EnvCircuitQuery<F> {
fn synthesize_args<CS: ConstraintSystem<F>>(
Expand Down Expand Up @@ -188,7 +195,7 @@ impl<F: LurkField> CircuitQuery<F> for EnvCircuitQuery<F> {
g,
store,
scope,
subquery,
&[subquery],
&is_immediate.not(),
(&immediate_result, acc),
allocated_key,
Expand All @@ -198,17 +205,7 @@ impl<F: LurkField> CircuitQuery<F> for EnvCircuitQuery<F> {
}

fn from_ptr<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, ptr: &Ptr) -> Option<Self> {
let query = EnvQuery::from_ptr(s, ptr);
query.and_then(|q| match q {
EnvQuery::Lookup(var, env) => {
let allocated_var =
AllocatedNum::alloc_infallible(ns!(cs, "var"), || *s.hash_ptr(&var).value());
let allocated_env =
AllocatedNum::alloc_infallible(ns!(cs, "env"), || *s.hash_ptr(&env).value());
Some(Self::Lookup(allocated_var, allocated_env))
}
_ => unreachable!(),
})
EnvQuery::from_ptr(s, ptr).map(|q| q.to_circuit(cs, s))
}

fn dummy_from_index<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, index: usize) -> Self {
Expand Down Expand Up @@ -376,7 +373,7 @@ mod test {
scope.finalize_transcript(s);

let cs = &mut TestConstraintSystem::new();
let g = &mut GlobalAllocator::default();
let g = &GlobalAllocator::default();

scope.synthesize(cs, g, s).unwrap();

Expand Down
29 changes: 12 additions & 17 deletions src/coroutine/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub struct CircuitTranscript<F: LurkField> {
}

impl<F: LurkField> CircuitTranscript<F> {
fn new<CS: ConstraintSystem<F>>(cs: &mut CS, g: &mut GlobalAllocator<F>, s: &Store<F>) -> Self {
fn new<CS: ConstraintSystem<F>>(cs: &mut CS, g: &GlobalAllocator<F>, s: &Store<F>) -> Self {
let nil = s.intern_nil();
let allocated_nil = g.alloc_ptr(cs, &nil, s);
Self {
Expand Down Expand Up @@ -486,7 +486,7 @@ impl<'a, F: LurkField, Q: Query<F>> CoroutineCircuit<'a, F, LogMemo<F>, Q> {
cs: &mut CS,
z: &[AllocatedPtr<F>],
) -> Result<(Option<AllocatedNum<F>>, Vec<AllocatedPtr<F>>), SynthesisError> {
let g = &mut GlobalAllocator::<F>::default();
let g = &GlobalAllocator::<F>::default();

assert_eq!(6, z.len());
let [c, e, k, memoset_acc, transcript, r] = z else {
Expand Down Expand Up @@ -805,7 +805,7 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
pub fn synthesize<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
) -> Result<(), SynthesisError> {
self.ensure_transcript_finalized(s);
Expand Down Expand Up @@ -892,7 +892,7 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
fn new<CS: ConstraintSystem<F>>(
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
memoset: LogMemoCircuit<F>,
provenances: &HashMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
Expand All @@ -905,12 +905,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
}
}

fn init<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
) {
fn init<CS: ConstraintSystem<F>>(&mut self, cs: &mut CS, g: &GlobalAllocator<F>, s: &Store<F>) {
self.acc =
Some(AllocatedPtr::alloc_constant(ns!(cs, "acc"), s.hash_ptr(&s.num_u64(0))).unwrap());

Expand Down Expand Up @@ -1035,7 +1030,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
Ok((new_acc, new_transcript))
}

fn finalize<CS: ConstraintSystem<F>>(&mut self, cs: &mut CS, _g: &mut GlobalAllocator<F>) {
fn finalize<CS: ConstraintSystem<F>>(&mut self, cs: &mut CS, _g: &GlobalAllocator<F>) {
let r = self.memoset.allocated_r();
enforce_equal(cs, || "r_matches_transcript", self.transcript.r(), &r);
enforce_equal_zero(cs, || "acc_is_zero", self.acc.clone().unwrap().hash());
Expand Down Expand Up @@ -1145,7 +1140,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
&mut self,
scope: &mut Scope<Q, LogMemo<F>, F>,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
) -> Result<(), SynthesisError> {
for (i, kv) in scope.toplevel_insertions.iter().enumerate() {
Expand All @@ -1166,7 +1161,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
fn synthesize_toplevel_query<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
i: usize,
allocated_key: &AllocatedPtr<F>,
Expand Down Expand Up @@ -1199,7 +1194,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
fn synthesize_prove_key_query<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
key: Option<&Ptr>,
index: usize,
Expand Down Expand Up @@ -1233,7 +1228,7 @@ impl<F: LurkField> CircuitScope<F, LogMemoCircuit<F>> {
fn synthesize_prove_query<CS: ConstraintSystem<F>, CQ: CircuitQuery<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
g: &GlobalAllocator<F>,
s: &Store<F>,
allocated_key: &AllocatedPtr<F>,
circuit_query: &CQ,
Expand Down Expand Up @@ -1525,7 +1520,7 @@ mod test {
scope.finalize_transcript(s);

let cs = &mut TestConstraintSystem::new();
let g = &mut GlobalAllocator::default();
let g = &GlobalAllocator::default();

scope.synthesize(cs, g, s).unwrap();

Expand Down Expand Up @@ -1565,7 +1560,7 @@ mod test {
scope.finalize_transcript(s);

let cs = &mut TestConstraintSystem::new();
let g = &mut GlobalAllocator::default();
let g = &GlobalAllocator::default();

scope.synthesize(cs, g, s).unwrap();

Expand Down
68 changes: 33 additions & 35 deletions src/coroutine/memoset/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,9 @@ where
store: &Store<F>,
result: AllocatedPtr<F>,
dependency_provenances: Vec<AllocatedPtr<F>>,
allocated_key: Option<&AllocatedPtr<F>>,
allocated_key: &AllocatedPtr<F>,
) -> Result<AllocatedPtr<F>, SynthesisError> {
let query = if let Some(q) = allocated_key {
q.clone()
} else {
self.synthesize_query(ns!(cs, "query"), g, store)?
};
let query = allocated_key.clone();
let p = AllocatedProvenance::new(query, result, dependency_provenances.clone());

Ok(p.to_ptr(cs, g, store)?.clone())
Expand All @@ -105,37 +101,48 @@ pub(crate) trait RecursiveQuery<F: LurkField>: CircuitQuery<F> {
fn post_recursion<CS: ConstraintSystem<F>>(
&self,
_cs: &mut CS,
subquery_result: AllocatedPtr<F>,
) -> Result<AllocatedPtr<F>, SynthesisError> {
// The default implementation provides tail recursion.
Ok(subquery_result)
}
subquery_results: &[AllocatedPtr<F>],
) -> Result<AllocatedPtr<F>, SynthesisError>;

fn recurse<CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, LogMemoCircuit<F>>,
subquery: Self,
subqueries: &[Self],
is_recursive: &Boolean,
immediate: (&AllocatedPtr<F>, &AllocatedPtr<F>),
allocated_key: &AllocatedPtr<F>,
) -> Result<((AllocatedPtr<F>, AllocatedPtr<F>), AllocatedPtr<F>), SynthesisError> {
let is_immediate = is_recursive.not();
let nil = g.alloc_ptr(ns!(cs, "nil"), &store.intern_nil(), store);

let subquery = subquery.synthesize_query(cs, g, store)?;

let ((sub_result, sub_provenance), new_acc) = scope.synthesize_internal_query(
ns!(cs, "recursive query"),
g,
store,
&subquery,
immediate.1,
is_recursive,
)?;

let (recursive_result, recursive_acc) = (self.post_recursion(cs, sub_result)?, new_acc);
let mut sub_results = vec![];
let mut dependency_provenances = vec![];
let mut new_acc = immediate.1.clone();
for subquery in subqueries {
let subquery = subquery.synthesize_query(cs, g, store)?;
let ((sub_result, sub_provenance), next_acc) = scope.synthesize_internal_query(
ns!(cs, "recursive query"),
g,
store,
&subquery,
&new_acc,
is_recursive,
)?;
let dependency_provenance = AllocatedPtr::pick(
ns!(cs, "dependency provenance"),
&is_immediate,
&nil,
&sub_provenance,
)?;
sub_results.push(sub_result);
dependency_provenances.push(dependency_provenance);
new_acc = next_acc;
}

let (recursive_result, recursive_acc) = (self.post_recursion(cs, &sub_results)?, new_acc);

let value = AllocatedPtr::pick(
ns!(cs, "pick value"),
Expand All @@ -151,22 +158,13 @@ pub(crate) trait RecursiveQuery<F: LurkField>: CircuitQuery<F> {
&recursive_acc,
)?;

let nil = g.alloc_ptr(ns!(cs, "nil"), &store.intern_nil(), store);

let dependency_provenance = AllocatedPtr::pick(
ns!(cs, "dependency provenance"),
&is_immediate,
&nil,
&sub_provenance,
)?;

let provenance = self.synthesize_provenance(
ns!(cs, "provenance"),
g,
store,
value.clone(),
vec![dependency_provenance.clone()],
Some(allocated_key),
dependency_provenances,
allocated_key,
)?;

Ok(((value, provenance.clone()), acc))
Expand Down

1 comment on commit 5458a53

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7982795723

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=3d2f7f61686c716bc3722a9e864eccfe03f926a4 ref=5458a5308d6a370a09ba6543aab3a1a3d7883c1a
num-100 1.45 s (✅ 1.00x) 1.45 s (✅ 1.00x faster)
num-200 2.77 s (✅ 1.00x) 2.78 s (✅ 1.00x slower)

LEM Fibonacci Prove - rc = 600

ref=3d2f7f61686c716bc3722a9e864eccfe03f926a4 ref=5458a5308d6a370a09ba6543aab3a1a3d7883c1a
num-100 1.82 s (✅ 1.00x) 1.83 s (✅ 1.01x slower)
num-200 3.01 s (✅ 1.00x) 3.02 s (✅ 1.00x slower)

Made with criterion-table

Please sign in to comment.