From ff2af20d547dbdcebab6ab757f4daa2567797218 Mon Sep 17 00:00:00 2001
From: Arthur Paulino
Date: Wed, 3 Jan 2024 12:02:33 -0300
Subject: [PATCH] Chore: parallel steps witness generation for NIVC

* MultiFrames with PC=0 have their witnesses cached just like in the IVC pipeline
* MultiFrames with PC!=0 have their witnesses cached in parallel due to their limited size (agnostic to RC) and lack of internal parallelism
---
 src/lem/multiframe.rs  |   4 ++
 src/proof/mod.rs       |   3 +
 src/proof/nova.rs      |   2 -
 src/proof/supernova.rs | 128 ++++++++++++++++++++++++++++++++---------
 4 files changed, 108 insertions(+), 29 deletions(-)

diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs
index f8c4e2640c..c245d7cfda 100644
--- a/src/lem/multiframe.rs
+++ b/src/lem/multiframe.rs
@@ -686,6 +686,10 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
             .skip_while(|f| f.input == f.output && stop_cond(&f.output))
             .count()
     }
+
+    fn program_counter(&self) -> usize {
+        self.pc
+    }
 }
 
 impl<'a, F: LurkField, C: Coprocessor<F>> Circuit<F> for MultiFrame<'a, F, C> {
diff --git a/src/proof/mod.rs b/src/proof/mod.rs
index 8730d9c63a..a788569616 100644
--- a/src/proof/mod.rs
+++ b/src/proof/mod.rs
@@ -149,6 +149,9 @@ pub trait MultiFrameTrait<'a, F: LurkField, C: Coprocessor<F> + 'a>:
         store: &'a Self::Store,
         folding_config: Arc<FoldingConfig<F, C>>,
     ) -> Vec<Self>;
+
+    /// The program counter (circuit index)
+    fn program_counter(&self) -> usize;
 }
 
 /// Represents a sequential Constraint System for a given proof.
diff --git a/src/proof/nova.rs b/src/proof/nova.rs
index 74bb4a6728..c7788b5d98 100644
--- a/src/proof/nova.rs
+++ b/src/proof/nova.rs
@@ -1,5 +1,3 @@
-#![allow(non_snake_case)]
-
 use abomonation::Abomonation;
 use bellpepper_core::{num::AllocatedNum, ConstraintSystem};
 use ff::PrimeField;
diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs
index f8e28ed7aa..c919c56ffc 100644
--- a/src/proof/supernova.rs
+++ b/src/proof/supernova.rs
@@ -1,5 +1,3 @@
-#![allow(non_snake_case)]
-
 use abomonation::Abomonation;
 use ff::PrimeField;
 use nova::{
@@ -15,11 +13,17 @@ use nova::{
         Engine,
     },
 };
+use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
 use serde::{Deserialize, Serialize};
-use std::{marker::PhantomData, ops::Index, sync::Arc};
+use std::{
+    marker::PhantomData,
+    ops::Index,
+    sync::{Arc, Mutex},
+};
 use tracing::info;
 
 use crate::{
+    config::lurk_config,
     coprocessor::Coprocessor,
     error::ProofError,
     eval::lang::Lang,
@@ -199,7 +203,7 @@ where
         pp: &PublicParams<F, M>,
         z0: &[F],
         steps: Vec<M>,
-        _store: &'a M::Store,
+        store: &'a M::Store,
         _reduction_count: usize,
         _lang: Arc<Lang<F, C>>,
     ) -> Result<Self, ProofError> {
@@ -208,29 +212,99 @@
         let z0_primary = z0;
         let z0_secondary = Self::z0_secondary();
 
-        for (i, step) in steps.iter().enumerate() {
-            info!("prove_recursively, step {i}");
-
-            let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| {
-                info!("RecursiveSnark::new {i}");
-                RecursiveSNARK::new(
-                    &pp.pp,
-                    step,
-                    step,
-                    &step.secondary_circuit(),
-                    z0_primary,
-                    &z0_secondary,
-                )
-                .unwrap()
-            });
-
-            info!("prove_step {i}");
-
-            recursive_snark
-                .prove_step(&pp.pp, step, &step.secondary_circuit())
-                .unwrap();
-
-            recursive_snark_option = Some(recursive_snark);
+        if lurk_config(None, None)
+            .perf
+            .parallelism
+            .recursive_steps
+            .is_parallel()
+        {
+            let cc = steps
+                .into_iter()
+                .map(|mf| (mf.program_counter() == 0, Mutex::new(mf)))
+                .collect::<Vec<_>>();
+
+            crossbeam::thread::scope(|s| {
+                s.spawn(|_| {
+                    // Skip the very first circuit's witness, so `prove_step` can begin immediately.
+                    // That circuit's witness will not be cached and will just be computed on demand.
+
+                    // There are many MultiFrames with PC = 0 and they have several inner frames, with proper internal
+                    // parallelism for witness generation, so we do it as in Nova's pipeline
+                    cc.iter()
+                        .skip(1)
+                        .filter(|(is_zero_pc, _)| *is_zero_pc)
+                        .for_each(|(_, mf)| {
+                            mf.lock()
+                                .unwrap()
+                                .cache_witness(store)
+                                .expect("witness caching failed");
+                        });
+
+                    // There shouldn't be too many MultiFrames with PC != 0 and they only have one inner frame, without
+                    // internal parallelism for witness generation, so we can generate their witnesses in parallel
+                    cc.par_iter()
+                        .skip(1)
+                        .filter(|(is_zero_pc, _)| !*is_zero_pc)
+                        .for_each(|(_, mf)| {
+                            mf.lock()
+                                .unwrap()
+                                .cache_witness(store)
+                                .expect("witness caching failed");
+                        });
+                });
+
+                for (i, (_, step)) in cc.iter().enumerate() {
+                    let step = step.lock().unwrap();
+                    info!("prove_recursively, step {i}");
+
+                    let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| {
+                        info!("RecursiveSnark::new {i}");
+                        RecursiveSNARK::new(
+                            &pp.pp,
+                            &*step,
+                            &step,
+                            &step.secondary_circuit(),
+                            z0_primary,
+                            &z0_secondary,
+                        )
+                        .unwrap()
+                    });
+
+                    info!("prove_step {i}");
+
+                    recursive_snark
+                        .prove_step(&pp.pp, &step, &step.secondary_circuit())
+                        .unwrap();
+
+                    recursive_snark_option = Some(recursive_snark);
+                }
+            })
+            .unwrap()
+        } else {
+            for (i, step) in steps.iter().enumerate() {
+                info!("prove_recursively, step {i}");
+
+                let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| {
+                    info!("RecursiveSnark::new {i}");
+                    RecursiveSNARK::new(
+                        &pp.pp,
+                        step,
+                        step,
+                        &step.secondary_circuit(),
+                        z0_primary,
+                        &z0_secondary,
+                    )
+                    .unwrap()
+                });
+
+                info!("prove_step {i}");
+
+                recursive_snark
+                    .prove_step(&pp.pp, step, &step.secondary_circuit())
+                    .unwrap();
+
+                recursive_snark_option = Some(recursive_snark);
+            }
+        }
 
         // This probably should be made unnecessary.