From 604143caab31117424a4eb6846a87dbacbf30f1b Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Tue, 13 Feb 2024 10:31:12 -0300 Subject: [PATCH] chore: simplify and homogenize proving code for Nova and SuperNova proofs * Remove code duplication in Nova proving * Factor out and reuse debugging code that was triggered only for Nova proving without "parallel steps" * Remove Option wrapping and cloning for recursive snarks --- src/proof/nova.rs | 124 +++++++++++++++++++---------------------- src/proof/supernova.rs | 49 +++++++--------- 2 files changed, 75 insertions(+), 98 deletions(-) diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 7410f7b16f..e8cb98d382 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -1,4 +1,4 @@ -use bellpepper_core::{num::AllocatedNum, ConstraintSystem}; +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; use halo2curves::bn256::Fr as Bn256Scalar; use nova::{ errors::NovaError, @@ -18,6 +18,7 @@ use std::{ marker::PhantomData, sync::{Arc, Mutex}, }; +use tracing::info; use crate::{ config::lurk_config, @@ -25,7 +26,7 @@ use crate::{ error::ProofError, eval::lang::Lang, field::LurkField, - lem::{interpreter::Frame, pointers::Ptr, store::Store}, + lem::{interpreter::Frame, multiframe::MultiFrame, pointers::Ptr, store::Store}, proof::{supernova::FoldingConfig, FrameLike, Prover}, }; @@ -223,6 +224,28 @@ pub fn circuits<'a, F: CurveCycleEquipped, C: Coprocessor + 'a>( ) } +/// For debugging purposes, synthesize the circuit and check that the constraint +/// system is satisfied +pub(crate) fn debug_step>( + circuit: &MultiFrame<'_, F, C>, + store: &Store, +) -> Result<(), SynthesisError> { + use bellpepper_core::test_cs::TestConstraintSystem; + let mut cs = TestConstraintSystem::::new(); + + let zi = store.to_scalar_vector(circuit.input()); + let zi_allocated: Vec<_> = zi + .iter() + .enumerate() + .map(|(i, x)| AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x))) + .collect::>()?; + + circuit.synthesize(&mut cs, zi_allocated.as_slice())?; + + assert!(cs.is_satisfied()); + Ok(()) +} + impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> for Proof> { @@ -237,22 +260,35 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>, store: &Store, ) -> Result { - assert!(!steps.is_empty()); - assert_eq!(steps[0].arity(), z0.len()); let debug = false; - let z0_primary = z0; - let z0_secondary = Self::z0_secondary(); + let first_step = steps.first().expect("steps can't be empty"); + assert_eq!(first_step.arity(), z0.len()); - let circuit_secondary = TrivialCircuit::default(); + let secondary_circuit = TrivialCircuit::default(); let num_steps = steps.len(); - tracing::debug!("steps.len: {num_steps}"); - - // produce a recursive SNARK - let mut recursive_snark: Option>> = None; + info!("proving {num_steps} steps"); + + let mut recursive_snark = RecursiveSNARK::new( + &pp.pp, + first_step, + &secondary_circuit, + z0, + &Self::z0_secondary(), + ) + .map_err(ProofError::Nova)?; + + let mut prove_step = |i: usize, step: &C1LEM<'a, F, C>| { + if debug { + debug_step(step, store).unwrap(); + } + info!("prove_step {i}"); + recursive_snark + .prove_step(&pp.pp, step, &secondary_circuit) + .unwrap(); + }; - // the shadowing here is voluntary - let recursive_snark = if lurk_config(None, None) + if lurk_config(None, None) .perf .parallelism .recursive_steps @@ -272,69 +308,21 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait::new(); - - let zi = store.to_scalar_vector(circuit_primary.input()); - let zi_allocated: Vec<_> = zi - .iter() - .enumerate() - .map(|(i, x)| { - AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x)) - }) - .collect::>()?; - - circuit_primary.synthesize(&mut cs, zi_allocated.as_slice())?; - - assert!(cs.is_satisfied()); - } - - let mut r_snark = recursive_snark.unwrap_or_else(|| { - RecursiveSNARK::new( - &pp.pp, - circuit_primary, - &circuit_secondary, - z0_primary, - &z0_secondary, - ) - .expect("Failed to construct initial recursive snark") - }); - r_snark - .prove_step(&pp.pp, circuit_primary, &circuit_secondary) - .expect("failure to prove Nova step"); - recursive_snark = Some(r_snark); + for (i, step) in steps.iter().enumerate() { + prove_step(i, step); } - recursive_snark }; Ok(Self::Recursive( - Box::new(recursive_snark.unwrap()), + Box::new(recursive_snark), num_steps, PhantomData, )) diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs index dfd2bbae93..66883efd6e 100644 --- a/src/proof/supernova.rs +++ b/src/proof/supernova.rs @@ -28,7 +28,7 @@ use crate::{ field::LurkField, lem::{interpreter::Frame, pointers::Ptr, store::Store}, proof::{ - nova::{CurveCycleEquipped, Dual, NovaCircuitShape, E1}, + nova::{debug_step, CurveCycleEquipped, Dual, NovaCircuitShape, E1}, Prover, RecursiveSNARKTrait, }, }; @@ -202,36 +202,29 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>, store: &Store, ) -> Result { - let mut recursive_snark_option: Option>> = None; + let debug = false; + let first_step = steps.first().expect("steps can't be empty"); - let z0_primary = z0; - let z0_secondary = Self::z0_secondary(); + info!("proving {} steps", steps.len()); - let mut prove_step = |i: usize, step: &C1LEM<'a, F, C>| { - info!("prove_recursively, step {i}"); - - let secondary_circuit = step.secondary_circuit(); - - let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| { - info!("RecursiveSnark::new {i}"); - RecursiveSNARK::new( - &pp.pp, - step, - step, - &secondary_circuit, - z0_primary, - &z0_secondary, - ) - .unwrap() - }); + let mut recursive_snark = RecursiveSNARK::new( + &pp.pp, + first_step, + first_step, + &first_step.secondary_circuit(), + z0, + &Self::z0_secondary(), + ) + .map_err(ProofError::SuperNova)?; + let mut prove_step = |i: usize, step: &C1LEM<'a, F, C>| { + if debug { + debug_step(step, store).unwrap(); + } info!("prove_step {i}"); - recursive_snark - .prove_step(&pp.pp, step, &secondary_circuit) + .prove_step(&pp.pp, step, &step.secondary_circuit()) .unwrap(); - - recursive_snark_option = Some(recursive_snark); }; if lurk_config(None, None) @@ -291,11 +284,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait) -> Result {