diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 821edf67d..1146d636c 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -13,6 +13,7 @@ use nova::{ }; use once_cell::sync::OnceCell; use pasta_curves::pallas; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::{Deserialize, Serialize}; use std::{ marker::PhantomData, @@ -297,24 +298,34 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>(); + let (folding_idx_sender, folding_idx_receiver) = + std::sync::mpsc::channel::>(); std::thread::scope(|s| { s.spawn(|| { - // Skip the very first circuit's witness, so `prove_step` can begin immediately. - // That circuit's witness will not be cached and will just be computed on-demand. - cc.iter().skip(1).for_each(|mf| { - mf.lock() - .unwrap() - .cache_witness(store) - .expect("witness caching failed"); - }); + for folding_idx in folding_idx_receiver { + if let Some(folding_idx) = folding_idx { + [1, 2].into_par_iter().for_each(|shift| { + if let Some(mf) = cc.get(folding_idx + shift) { + mf.lock() + .unwrap() + .cache_witness(store) + .expect("witness caching failed"); + } + }); + } else { + return; + } + } }); for (i, step) in cc.iter().enumerate() { + folding_idx_sender.send(Some(i)).unwrap(); let mut step = step.lock().unwrap(); prove_step(i, &step, &mut recursive_snark_option); step.clear_cached_witness(); } + folding_idx_sender.send(None).unwrap(); recursive_snark_option }) } else { diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs index f5e6d4a7f..c7d11a08b 100644 --- a/src/proof/supernova.rs +++ b/src/proof/supernova.rs @@ -11,7 +11,9 @@ use nova::{ }, }; use once_cell::sync::OnceCell; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use serde::{Deserialize, Serialize}; use std::{ marker::PhantomData, @@ -260,24 +262,11 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>(); + let (folding_idx_sender, folding_idx_receiver) = + std::sync::mpsc::channel::>(); std::thread::scope(|s| { s.spawn(|| { - // Skip the very first circuit's witness, so `prove_step` can begin immediately. - // That circuit's witness will not be cached and will just be computed on-demand. - - // There are many MultiFrames with PC = 0, each with several inner frames and heavy internal - // parallelism for witness generation. So we do it like on Nova's pipeline. - cc.iter() - .skip(1) - .filter(|(is_zero_pc, _)| *is_zero_pc) - .for_each(|(_, mf)| { - mf.lock() - .unwrap() - .cache_witness(store) - .expect("witness caching failed"); - }); - // There shouldn't be as many MultiFrames with PC != 0 and they only have one inner frame, each with // poor internal parallelism for witness generation, so we can generate their witnesses in parallel. // This is mimicking the behavior we had in the Nova pipeline before #941 so... @@ -294,11 +283,33 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait