
Commit

Accumulate insertions in HashMap to avoid sort. (#1057)
* Accumulate insertions in HashMap to avoid sort.

* Refactor insert().

* Clippy: use if let.

* Use IndexSet since dedup is insufficient.

* Refactor prove_query.

* Save unique_keys in index-keyed HashMap to avoid last sort.

* Clippy.

---------

Co-authored-by: porcuquine <porcuquine@users.noreply.github.com>
porcuquine authored Jan 17, 2024
1 parent 8e5b5ea commit 46c7e62
Showing 3 changed files with 112 additions and 68 deletions.
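For orientation before the diffs: the change replaces "collect every insertion into one Vec, then sort and dedup by query index" with incremental grouping, bucketing each kv insertion under its key and recording each key once under its query index. Below is a minimal, self-contained sketch of that pattern; `Key`, `Kv`, and `query_index` are placeholders invented for this sketch rather than Lurk's `Ptr`, `Store`, and `Query::index()`, and only the `HashMap`/`IndexSet` shape mirrors the diff.

```rust
use std::collections::HashMap;

use indexmap::IndexSet;

// Placeholder types standing in for Lurk's `Ptr` key and kv pair.
type Key = u64;
type Kv = (u64, u64);

// Stand-in for `Q::from_ptr(s, &key).expect("invalid query").index()`.
fn query_index(key: &Key) -> usize {
    (*key % 2) as usize
}

fn group_insertions(insertions: impl IntoIterator<Item = Kv>) -> HashMap<usize, Vec<Key>> {
    // key -> set of kv insertions; `IndexSet` deduplicates while preserving
    // insertion order, which a post-hoc `dedup()` cannot, because duplicate
    // insertions need not be adjacent.
    let mut by_key: HashMap<Key, IndexSet<Kv>> = HashMap::new();
    // query index -> unique keys in first-seen order, so no final sort is needed.
    let mut unique_keys: HashMap<usize, Vec<Key>> = HashMap::new();

    for kv in insertions {
        let key = kv.0;
        if let Some(kvs) = by_key.get_mut(&key) {
            kvs.insert(kv);
        } else {
            unique_keys.entry(query_index(&key)).or_default().push(key);
            let mut kvs = IndexSet::new();
            kvs.insert(kv);
            by_key.insert(key, kvs);
        }
    }
    unique_keys
}

fn main() {
    let grouped = group_insertions([(3, 6), (4, 24), (3, 6), (5, 120)]);
    // Keys 3 and 5 fall under index 1, key 4 under index 0; the duplicate (3, 6) is absorbed.
    assert_eq!(grouped[&0], vec![4]);
    assert_eq!(grouped[&1], vec![3, 5]);
}
```

The `IndexSet` is used rather than a `Vec` plus `dedup()` because duplicate insertions of the same kv pair are not necessarily adjacent, which is the point of the "Use IndexSet since dedup is insufficient" bullet above.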
22 changes: 12 additions & 10 deletions src/coroutine/memoset/demo.rs
@@ -94,6 +94,10 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
_ => unreachable!(),
}
}

fn count() -> usize {
1
}
}

impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
@@ -203,22 +207,20 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
}
}

fn from_ptr<CS: ConstraintSystem<F>>(
cs: &mut CS,
s: &Store<F>,
ptr: &Ptr,
) -> Result<Option<Self>, SynthesisError> {
fn from_ptr<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, ptr: &Ptr) -> Option<Self> {
let query = DemoQuery::from_ptr(s, ptr);
Ok(if let Some(q) = query {
if let Some(q) = query {
match q {
DemoQuery::Factorial(n) => Some(Self::Factorial(AllocatedPtr::alloc(cs, || {
Ok(s.hash_ptr(&n))
})?)),
DemoQuery::Factorial(n) => {
Some(Self::Factorial(AllocatedPtr::alloc_infallible(cs, || {
s.hash_ptr(&n)
})))
}
_ => unreachable!(),
}
} else {
None
})
}
}

fn symbol(&self) -> Symbol {
149 changes: 96 additions & 53 deletions src/coroutine/memoset/mod.rs
@@ -31,7 +31,7 @@ use std::collections::HashMap;
use std::marker::PhantomData;

use bellpepper_core::{boolean::Boolean, num::AllocatedNum, ConstraintSystem, SynthesisError};
use itertools::Itertools;
use indexmap::IndexSet;
use once_cell::sync::OnceCell;

use crate::circuit::gadgets::{
@@ -190,8 +190,8 @@ pub struct Scope<Q, M> {
toplevel_insertions: Vec<Ptr>,
/// internally-inserted keys
internal_insertions: Vec<Ptr>,
/// unique keys
all_insertions: Vec<Ptr>,
/// unique keys: query-index -> [key]
unique_inserted_keys: HashMap<usize, Vec<Ptr>>,
}

impl<F: LurkField, Q> Default for Scope<Q, LogMemo<F>> {
@@ -202,7 +202,7 @@ impl<F: LurkField, Q> Default for Scope<Q, LogMemo<F>> {
dependencies: Default::default(),
toplevel_insertions: Default::default(),
internal_insertions: Default::default(),
all_insertions: Default::default(),
unique_inserted_keys: Default::default(),
}
}
}
@@ -258,7 +258,7 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
fn finalize_transcript(&mut self, s: &Store<F>) -> Transcript<F> {
let (transcript, insertions) = self.build_transcript(s);
self.memoset.finalize_transcript(s, transcript.clone());
self.all_insertions = insertions;
self.unique_inserted_keys = insertions;
transcript
}

@@ -268,28 +268,43 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
}
}

fn build_transcript(&self, s: &Store<F>) -> (Transcript<F>, Vec<Ptr>) {
fn build_transcript(&self, s: &Store<F>) -> (Transcript<F>, HashMap<usize, Vec<Ptr>>) {
let mut transcript = Transcript::new(s);

let internal_insertions_kv = self.internal_insertions.iter().map(|key| {
let value = self.queries.get(key).expect("value missing for key");
Transcript::make_kv(s, *key, *value)
});
// k -> [kv]
let mut insertions: HashMap<Ptr, IndexSet<Ptr>> = HashMap::new();
let mut unique_keys: HashMap<usize, Vec<Ptr>> = Default::default();

let mut insertions =
Vec::with_capacity(self.toplevel_insertions.len() + self.internal_insertions.len());
insertions.extend(&self.toplevel_insertions);
insertions.extend(internal_insertions_kv);
let mut insert = |kv: Ptr| {
let key = s.car_cdr(&kv).unwrap().0;

// Sort insertions by query type (index) for processing. This is because the transcript will be constructed
// sequentially by the circuits, and we potentially batch queries of the same type in a single coprocessor
// circuit.
insertions.sort_by_key(|kv| {
let (key, _) = s.car_cdr(kv).unwrap();
if let Some(kvs) = insertions.get_mut(&key) {
kvs.insert(kv);
} else {
let index = Q::from_ptr(s, &key).expect("bad query").index();
unique_keys
.entry(index)
.and_modify(|keys| keys.push(key))
.or_insert_with(|| vec![key]);
//unique_keys.push(key);
let mut x = IndexSet::new();
x.insert(kv);

insertions.insert(key, x);
}
};

Q::from_ptr(s, &key).expect("invalid query").index()
let internal_insertions_kv = self.internal_insertions.iter().map(|key| {
let value = self.queries.get(key).expect("value missing for key");
Transcript::make_kv(s, *key, *value)
});

for kv in &self.toplevel_insertions {
insert(*kv);
}
for kv in internal_insertions_kv {
insert(kv);
}
for kv in self.toplevel_insertions.iter() {
transcript.add(s, *kv);
}
@@ -298,16 +313,11 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
// because when proving later, each query's proof must record that its subquery proofs are being deferred
// (insertions) before then proving itself (making use of any subquery results) and removing the now-proved
// deferral from the MemoSet.
let unique_keys = insertions
.iter()
.dedup() // We need to process every key's dependencies once.
.map(|kv| {
let key = s.car_cdr(kv).unwrap().0;

if let Some(dependencies) = self.dependencies.get(&key) {
dependencies
.iter()
.for_each(|dependency| {
for index in 0..Q::count() {
for key in unique_keys.get(&index).expect("unreachable") {
for kv in insertions.get(key).unwrap().iter() {
if let Some(dependencies) = self.dependencies.get(key) {
dependencies.iter().for_each(|dependency| {
let k = dependency.to_ptr(s);
let v = self
.queries
Expand All @@ -317,21 +327,19 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
// that these keys might already have been inserted before, but we need to repeat if so
// because the proof must do so each time a query is used.
let kv = Transcript::make_kv(s, k, *v);
transcript.add(s, kv)
transcript.add(s, kv)
})
};
let count = self.memoset.count(kv);
let kv_count = Transcript::make_kv_count(s, *kv, count);

// Add removal for the query identified by `key`. The queries being removed here were deduplicated
// above, so each is removed only once. However, we freely choose the multiplicity (`count`) of the
// removal to match the total number of insertions actually made (considering dependencies).
transcript.add(s, kv_count);

key
})
.collect::<Vec<_>>();

};
let count = self.memoset.count(kv);
let kv_count = Transcript::make_kv_count(s, *kv, count);

// Add removal for the query identified by `key`. The queries being removed here were deduplicated
// above, so each is removed only once. However, we freely choose the multiplicity (`count`) of the
// removal to match the total number of insertions actually made (considering dependencies).
transcript.add(s, kv_count);
}
}
}
(transcript, unique_keys)
}

@@ -549,41 +557,65 @@ impl<F: LurkField> CircuitScope<F, LogMemo<F>> {
g: &mut GlobalAllocator<F>,
s: &Store<F>,
) -> Result<(), SynthesisError> {
for (i, kv) in scope.all_insertions.iter().enumerate() {
self.synthesize_prove_query::<_, Q::CQ>(cs, g, s, i, kv)?;
for (index, keys) in scope.unique_inserted_keys.iter() {
let cs = &mut cs.namespace(|| format!("query-index-{index}"));
self.synthesize_prove_queries::<_, Q>(cs, g, s, keys)?;
}
Ok(())
}

fn synthesize_prove_query<CS: ConstraintSystem<F>, CQ: CircuitQuery<F>>(
fn synthesize_prove_queries<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
i: usize,
key: &Ptr,
keys: &[Ptr],
) -> Result<(), SynthesisError> {
let cs = &mut cs.namespace(|| format!("internal-{i}"));
for (i, key) in keys.iter().enumerate() {
let cs = &mut cs.namespace(|| format!("internal-{i}"));

self.synthesize_prove_key_query::<_, Q>(cs, g, s, key)?;
}
Ok(())
}

fn synthesize_prove_key_query<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
key: &Ptr,
) -> Result<(), SynthesisError> {
let allocated_key =
AllocatedPtr::alloc(
&mut cs.namespace(|| "allocated_key"),
|| Ok(s.hash_ptr(key)),
)
.unwrap();

let circuit_query = CQ::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap();
let circuit_query = Q::CQ::from_ptr(&mut cs.namespace(|| "circuit_query"), s, key).unwrap();

self.synthesize_prove_query::<_, Q::CQ>(cs, g, s, &allocated_key, &circuit_query)?;
Ok(())
}

fn synthesize_prove_query<CS: ConstraintSystem<F>, CQ: CircuitQuery<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
allocated_key: &AllocatedPtr<F>,
circuit_query: &CQ,
) -> Result<(), SynthesisError> {
let acc = self.acc.clone().unwrap();
let transcript = self.transcript.clone();

let (val, new_acc, new_transcript) = circuit_query
.expect("not a query form")
.synthesize_eval(&mut cs.namespace(|| "eval"), g, s, self, &acc, &transcript)
.unwrap();

let (new_acc, new_transcript) =
self.synthesize_remove(cs, g, s, &new_acc, &new_transcript, &allocated_key, &val)?;
self.synthesize_remove(cs, g, s, &new_acc, &new_transcript, allocated_key, &val)?;

self.acc = Some(new_acc);
self.transcript = new_transcript;
@@ -798,6 +830,7 @@ mod test {
}
assert!(cs.is_satisfied());
}

{
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
scope.query(s, fact_4);
@@ -817,6 +850,16 @@

scope.synthesize(cs, g, s).unwrap();

println!(
"transcript: {}",
scope
.memoset
.transcript
.get()
.unwrap()
.fmt_to_string_simple(s)
);

expect_eq(cs.num_constraints(), expect!["11408"]);
expect_eq(cs.aux().len(), expect!["11445"]);

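As the comments in `build_transcript` above explain, each use of a query re-inserts its kv pair into the transcript, and the single removal per unique key is given a multiplicity (`count`) equal to the total number of insertions actually made. The toy sketch below illustrates only that counting argument, with an invented signed counter and string keys; it is not the `MemoSet` implementation.

```rust
use std::collections::HashMap;

fn main() {
    // Signed multiset: each insertion adds 1, a removal subtracts its multiplicity.
    let mut multiset: HashMap<&str, i64> = HashMap::new();

    // Suppose the query `(factorial 3)` is used twice, e.g. once at top level
    // and once as a dependency of another query, so its kv pair is inserted twice.
    for _ in 0..2 {
        *multiset.entry("(factorial 3) . 6").or_insert(0) += 1;
    }

    // It is removed exactly once, with multiplicity chosen to match the
    // number of insertions actually made.
    *multiset.entry("(factorial 3) . 6").or_insert(0) -= 2;

    // Insertions and the single weighted removal cancel per key.
    assert!(multiset.values().all(|&n| n == 0));
}
```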
9 changes: 4 additions & 5 deletions src/coroutine/memoset/query.rs
@@ -27,7 +27,10 @@ where
s.intern_symbol(&self.symbol())
}

/// What is this query's index? Used for ordering circuits and transcripts, grouped by query type.
fn index(&self) -> usize;
/// How many types of query are provided?
fn count() -> usize;
}

pub trait CircuitQuery<F: LurkField>
@@ -50,9 +53,5 @@
s.intern_symbol(&self.symbol())
}

fn from_ptr<CS: ConstraintSystem<F>>(
cs: &mut CS,
s: &Store<F>,
ptr: &Ptr,
) -> Result<Option<Self>, SynthesisError>;
fn from_ptr<CS: ConstraintSystem<F>>(cs: &mut CS, s: &Store<F>, ptr: &Ptr) -> Option<Self>;
}

1 comment on commit 46c7e62

@github-actions (Contributor)

Benchmarks

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7549743653

Benchmark Results

LEM Fibonacci Prove - rc = 100

|         | ref=8e5b5ea108e16ff6fdea88caf6950b1ee81f6398 | ref=46c7e624188170b5f7091ac765aa3c12ebdb1ac0 |
|---------|----------------------------------------------|----------------------------------------------|
| num-100 | 1.74 s (✅ 1.00x)                             | 1.74 s (✅ 1.00x slower)                      |
| num-200 | 3.36 s (✅ 1.00x)                             | 3.36 s (✅ 1.00x faster)                      |

LEM Fibonacci Prove - rc = 600

|         | ref=8e5b5ea108e16ff6fdea88caf6950b1ee81f6398 | ref=46c7e624188170b5f7091ac765aa3c12ebdb1ac0 |
|---------|----------------------------------------------|----------------------------------------------|
| num-100 | 2.03 s (✅ 1.00x)                             | 2.03 s (✅ 1.00x faster)                      |
| num-200 | 3.39 s (✅ 1.00x)                             | 3.39 s (✅ 1.00x faster)                      |

Made with criterion-table
