Skip to content

Commit

Permalink
chore: folding at i demands witness caching at i+1 and i+2
Browse files Browse the repository at this point in the history
Make folding, in the main thread, be the driving force that demands
caching witnesses for the next two MultiFrames.
  • Loading branch information
arthurpaulino committed Mar 18, 2024
1 parent 54e9292 commit 393ad52
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 24 deletions.
27 changes: 19 additions & 8 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use nova::{
};
use once_cell::sync::OnceCell;
use pasta_curves::pallas;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use std::{
marker::PhantomData,
Expand Down Expand Up @@ -297,24 +298,34 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
.is_parallel()
{
let cc = steps.into_iter().map(Mutex::new).collect::<Vec<_>>();
let (folding_idx_sender, folding_idx_receiver) =
std::sync::mpsc::channel::<Option<usize>>();

std::thread::scope(|s| {
s.spawn(|| {
// Skip the very first circuit's witness, so `prove_step` can begin immediately.
// That circuit's witness will not be cached and will just be computed on-demand.
cc.iter().skip(1).for_each(|mf| {
mf.lock()
.unwrap()
.cache_witness(store)
.expect("witness caching failed");
});
for folding_idx in folding_idx_receiver {
if let Some(folding_idx) = folding_idx {
[1, 2].into_par_iter().for_each(|shift| {
if let Some(mf) = cc.get(folding_idx + shift) {
mf.lock()
.unwrap()
.cache_witness(store)
.expect("witness caching failed");
}
});
} else {
return;
}
}
});

for (i, step) in cc.iter().enumerate() {
folding_idx_sender.send(Some(i)).unwrap();
let mut step = step.lock().unwrap();
prove_step(i, &step, &mut recursive_snark_option);
step.clear_cached_witness();
}
folding_idx_sender.send(None).unwrap();
recursive_snark_option
})
} else {
Expand Down
43 changes: 27 additions & 16 deletions src/proof/supernova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use nova::{
},
};
use once_cell::sync::OnceCell;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::{
marker::PhantomData,
Expand Down Expand Up @@ -260,24 +262,11 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
.into_iter()
.map(|mf| (mf.program_counter() == 0, Mutex::new(mf)))
.collect::<Vec<_>>();
let (folding_idx_sender, folding_idx_receiver) =
std::sync::mpsc::channel::<Option<usize>>();

std::thread::scope(|s| {
s.spawn(|| {
// Skip the very first circuit's witness, so `prove_step` can begin immediately.
// That circuit's witness will not be cached and will just be computed on-demand.

// There are many MultiFrames with PC = 0, each with several inner frames and heavy internal
// parallelism for witness generation. So we do it like on Nova's pipeline.
cc.iter()
.skip(1)
.filter(|(is_zero_pc, _)| *is_zero_pc)
.for_each(|(_, mf)| {
mf.lock()
.unwrap()
.cache_witness(store)
.expect("witness caching failed");
});

// There shouldn't be as many MultiFrames with PC != 0 and they only have one inner frame, each with
// poor internal parallelism for witness generation, so we can generate their witnesses in parallel.
// This is mimicking the behavior we had in the Nova pipeline before #941 so...
Expand All @@ -294,11 +283,33 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<F, C1LEM<
});
});

s.spawn(|| {
for folding_idx in folding_idx_receiver {
if let Some(folding_idx) = folding_idx {
[1, 2].into_par_iter().for_each(|shift| {
if let Some((is_zero_pc, mf)) = cc.get(folding_idx + shift) {
// We're already generating witnesses for MultiFrames with PC != 0 in parallel
if *is_zero_pc {
mf.lock()
.unwrap()
.cache_witness(store)
.expect("witness caching failed");
}
}
});
} else {
return;
}
}
});

for (i, (_, step)) in cc.iter().enumerate() {
folding_idx_sender.send(Some(i)).unwrap();
let mut step = step.lock().unwrap();
prove_step(i, &step, &mut recursive_snark_option);
step.clear_cached_witness();
}
folding_idx_sender.send(None).unwrap();
recursive_snark_option
})
} else {
Expand Down

0 comments on commit 393ad52

Please sign in to comment.