From 8046da0e91f8ca9eb41b5733efcd23ca47c4fcae Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 23 Jan 2024 14:27:57 -0300 Subject: [PATCH] Generic `Prover` (#1074) * Removed `lang` and `reduction_count` from `prove_recursively` * Removed `Coprocessor` from `RecursiveSNARKTrait` * WIP removing coprocessor from `Prover` * Simplified `RecursiveSNARKTrait` * MultiFrame is FrameLike * Generic `Prover` * Generic `Proof` * Use multiframe input --- benches/end2end.rs | 14 ++++--- benches/fibonacci.rs | 3 +- benches/sha256.rs | 8 ++-- examples/sha256_nivc.rs | 8 +++- examples/tp_table.rs | 7 +--- src/cli/lurk_proof.rs | 4 +- src/cli/repl/meta_cmd.rs | 4 +- src/cli/repl/mod.rs | 4 +- src/lem/multiframe.rs | 18 +++++--- src/proof/mod.rs | 51 ++++++----------------- src/proof/nova.rs | 88 ++++++++++++++++++++++++---------------- src/proof/supernova.rs | 76 +++++++++++++++++++++++----------- src/proof/tests/mod.rs | 6 +-- 13 files changed, 160 insertions(+), 131 deletions(-) diff --git a/benches/end2end.rs b/benches/end2end.rs index 72e7d4a982..e6e5602edb 100644 --- a/benches/end2end.rs +++ b/benches/end2end.rs @@ -12,7 +12,7 @@ use lurk::{ pointers::Ptr, store::Store, }, - proof::{nova::NovaProver, Prover, RecursiveSNARKTrait}, + proof::{nova::NovaProver, RecursiveSNARKTrait}, public_parameters::{ self, instance::{Instance, Kind}, @@ -84,7 +84,7 @@ fn end2end_benchmark(c: &mut Criterion) { b.iter(|| { let ptr = go_base::(&store, state.clone(), s.0, s.1); let frames = evaluate::>(None, ptr, &store, limit).unwrap(); - let _result = prover.prove(&pp, &frames, &store).unwrap(); + let _result = prover.prove_from_frames(&pp, &frames, &store).unwrap(); }) }); @@ -253,7 +253,7 @@ fn prove_benchmark(c: &mut Criterion) { let frames = evaluate::>(None, ptr, &store, limit).unwrap(); b.iter(|| { - let result = prover.prove(&pp, &frames, &store).unwrap(); + let result = prover.prove_from_frames(&pp, &frames, &store).unwrap(); black_box(result); }) }); @@ -300,7 +300,7 @@ fn prove_compressed_benchmark(c: &mut Criterion) { let frames = evaluate::>(None, ptr, &store, limit).unwrap(); b.iter(|| { - let (proof, _, _, _) = prover.prove(&pp, &frames, &store).unwrap(); + let (proof, _, _, _) = prover.prove_from_frames(&pp, &frames, &store).unwrap(); let compressed_result = proof.compress(&pp).unwrap(); black_box(compressed_result); @@ -344,7 +344,8 @@ fn verify_benchmark(c: &mut Criterion) { let ptr = go_base(&store, state.clone(), s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone()); let frames = evaluate::>(None, ptr, &store, limit).unwrap(); - let (proof, z0, zi, _num_steps) = prover.prove(&pp, &frames, &store).unwrap(); + let (proof, z0, zi, _num_steps) = + prover.prove_from_frames(&pp, &frames, &store).unwrap(); b.iter_batched( || z0.clone(), @@ -396,7 +397,8 @@ fn verify_compressed_benchmark(c: &mut Criterion) { let ptr = go_base(&store, state.clone(), s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone()); let frames = evaluate::>(None, ptr, &store, limit).unwrap(); - let (proof, z0, zi, _num_steps) = prover.prove(&pp, &frames, &store).unwrap(); + let (proof, z0, zi, _num_steps) = + prover.prove_from_frames(&pp, &frames, &store).unwrap(); let compressed_proof = proof.compress(&pp).unwrap(); diff --git a/benches/fibonacci.rs b/benches/fibonacci.rs index edd3d07cbe..d4d3b2b7c2 100644 --- a/benches/fibonacci.rs +++ b/benches/fibonacci.rs @@ -10,7 +10,6 @@ use lurk::{ eval::lang::{Coproc, Lang}, lem::{eval::evaluate, store::Store}, proof::nova::NovaProver, - proof::Prover, public_parameters::{ instance::{Instance, Kind}, public_params, @@ -116,7 +115,7 @@ fn fibonacci_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, &store); + let result = prover.prove_from_frames(&pp, frames, &store); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/benches/sha256.rs b/benches/sha256.rs index 9242f146c1..97603cbd9f 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -22,7 +22,7 @@ use lurk::{ pointers::Ptr, store::Store, }, - proof::{nova::NovaProver, supernova::SuperNovaProver, Prover, RecursiveSNARKTrait}, + proof::{nova::NovaProver, supernova::SuperNovaProver, RecursiveSNARKTrait}, public_parameters::{ instance::{Instance, Kind}, public_params, supernova_public_params, @@ -138,7 +138,7 @@ fn sha256_ivc_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, store); + let result = prover.prove_from_frames(&pp, frames, store); let _ = black_box(result); }, BatchSize::LargeInput, @@ -219,7 +219,7 @@ fn sha256_ivc_prove_compressed( b.iter_batched( || frames, |frames| { - let (proof, _, _, _) = prover.prove(&pp, frames, store).unwrap(); + let (proof, _, _, _) = prover.prove_from_frames(&pp, frames, store).unwrap(); let compressed_result = proof.compress(&pp).unwrap(); let _ = black_box(compressed_result); @@ -303,7 +303,7 @@ fn sha256_nivc_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, store); + let result = prover.prove_from_frames(&pp, frames, store); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/examples/sha256_nivc.rs b/examples/sha256_nivc.rs index b8769f6102..c7ab304645 100644 --- a/examples/sha256_nivc.rs +++ b/examples/sha256_nivc.rs @@ -12,7 +12,7 @@ use lurk::{ pointers::Ptr, store::Store, }, - proof::{supernova::SuperNovaProver, Prover, RecursiveSNARKTrait}, + proof::{supernova::SuperNovaProver, RecursiveSNARKTrait}, public_parameters::{ instance::{Instance, Kind}, supernova_public_params, @@ -94,7 +94,11 @@ fn main() { println!("Beginning proof step..."); let proof_start = Instant::now(); let (proof, z0, zi, _num_steps) = tracing_texray::examine(tracing::info_span!("bang!")) - .in_scope(|| supernova_prover.prove(&pp, &frames, store).unwrap()); + .in_scope(|| { + supernova_prover + .prove_from_frames(&pp, &frames, store) + .unwrap() + }); let proof_end = proof_start.elapsed(); println!("Proofs took {:?}", proof_end); diff --git a/examples/tp_table.rs b/examples/tp_table.rs index 697105dd48..73459290e8 100644 --- a/examples/tp_table.rs +++ b/examples/tp_table.rs @@ -4,10 +4,7 @@ use criterion::black_box; use lurk::{ eval::lang::{Coproc, Lang}, lem::{eval::evaluate, multiframe::MultiFrame, store::Store}, - proof::{ - nova::{public_params, NovaProver, PublicParams}, - Prover, - }, + proof::nova::{public_params, NovaProver, PublicParams}, }; use num_traits::ToPrimitive; use pasta_curves::pallas::Scalar as Fr; @@ -179,7 +176,7 @@ fn main() { let mut timings = Vec::with_capacity(n_samples); for _ in 0..n_samples { let start = Instant::now(); - let result = prover.prove(&pp, frames, &store); + let result = prover.prove_from_frames(&pp, frames, &store); let _ = black_box(result); let end = start.elapsed().as_secs_f64(); timings.push(end); diff --git a/src/cli/lurk_proof.rs b/src/cli/lurk_proof.rs index 1dc1876f78..3f61213ebe 100644 --- a/src/cli/lurk_proof.rs +++ b/src/cli/lurk_proof.rs @@ -10,7 +10,7 @@ use crate::{ field::LurkField, lem::{pointers::ZPtr, store::Store}, proof::{ - nova::{self, CurveCycleEquipped, E1, E2}, + nova::{self, CurveCycleEquipped, C1LEM, E1, E2}, RecursiveSNARKTrait, }, public_parameters::{ @@ -131,7 +131,7 @@ pub(crate) enum LurkProof< < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { Nova { - proof: nova::Proof<'a, F, C>, + proof: nova::Proof>, public_inputs: Vec, public_outputs: Vec, rc: usize, diff --git a/src/cli/repl/meta_cmd.rs b/src/cli/repl/meta_cmd.rs index cfbd283942..f025098b28 100644 --- a/src/cli/repl/meta_cmd.rs +++ b/src/cli/repl/meta_cmd.rs @@ -23,7 +23,7 @@ use crate::{ }, package::{Package, SymbolRef}, proof::{ - nova::{self, CurveCycleEquipped, E1, E2}, + nova::{self, CurveCycleEquipped, C1LEM, E1, E2}, RecursiveSNARKTrait, }, public_parameters::{ @@ -1104,7 +1104,7 @@ where { Nova { args: LurkData, - proof: nova::Proof<'a, F, C>, + proof: nova::Proof>, }, } diff --git a/src/cli/repl/mod.rs b/src/cli/repl/mod.rs index 32f19dd12c..8b6973883a 100644 --- a/src/cli/repl/mod.rs +++ b/src/cli/repl/mod.rs @@ -41,7 +41,7 @@ use crate::{ parser, proof::{ nova::{CurveCycleEquipped, NovaProver}, - Prover, RecursiveSNARKTrait, + RecursiveSNARKTrait, }, public_parameters::{ instance::{Instance, Kind}, @@ -342,7 +342,7 @@ where info!("Proving"); let (proof, public_inputs, public_outputs, num_steps) = - prover.prove(&pp, frames, &self.store)?; + prover.prove_from_frames(&pp, frames, &self.store)?; info!("Compressing proof"); let proof = proof.compress(&pp)?; assert_eq!(self.rc * num_steps, pad(n_frames, self.rc)); diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs index f4099bebeb..71089df968 100644 --- a/src/lem/multiframe.rs +++ b/src/lem/multiframe.rs @@ -75,11 +75,6 @@ impl<'a, F: LurkField, C: Coprocessor> MultiFrame<'a, F, C> { self.frames.as_ref() } - #[inline] - pub fn output(&self) -> &Option> { - &self.output - } - pub fn emitted(_store: &Store, eval_frame: &Frame) -> Vec { eval_frame.emitted.clone() } @@ -385,6 +380,19 @@ impl CEKState for Vec { } } +impl<'a, F: LurkField, C: Coprocessor> FrameLike for MultiFrame<'a, F, C> { + type FrameIO = Vec; + #[inline] + fn input(&self) -> &Vec { + self.input.as_ref().unwrap() + } + + #[inline] + fn output(&self) -> &Vec { + self.output.as_ref().unwrap() + } +} + impl FrameLike for Frame { type FrameIO = Vec; fn input(&self) -> &Self::FrameIO { diff --git a/src/proof/mod.rs b/src/proof/mod.rs index d1daa03861..86cdf7a0a6 100644 --- a/src/proof/mod.rs +++ b/src/proof/mod.rs @@ -23,14 +23,11 @@ use crate::{ error::ProofError, eval::lang::Lang, field::LurkField, - lem::{eval::EvalConfig, interpreter::Frame, pointers::Ptr, store::Store}, + lem::{eval::EvalConfig, pointers::Ptr, store::Store}, proof::nova::E2, }; -use self::{ - nova::{CurveCycleEquipped, C1LEM}, - supernova::FoldingConfig, -}; +use self::{nova::CurveCycleEquipped, supernova::FoldingConfig}; /// The State of a CEK machine. pub trait CEKState { @@ -88,7 +85,7 @@ pub trait Provable { // * `Prover`, which abstracts over Nova and SuperNova provers /// Trait to abstract Nova and SuperNova proofs -pub trait RecursiveSNARKTrait<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> +pub trait RecursiveSNARKTrait where Self: Sized, { @@ -102,10 +99,8 @@ where fn prove_recursively( pp: &Self::PublicParams, z0: &[F], - steps: Vec>, - store: &'a Store, - reduction_count: usize, - lang: Arc>, + steps: Vec, + store: &Store, ) -> Result; /// Compress a proof @@ -155,12 +150,12 @@ impl FoldingMode { } /// A trait for a prover that works with a field `F`. -pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> { +pub trait Prover<'a, F: CurveCycleEquipped, M: FrameLike>> { /// Associated type for public parameters type PublicParams; /// Assiciated proof type, which must implement `RecursiveSNARKTrait` - type RecursiveSnark: RecursiveSNARKTrait<'a, F, C, PublicParams = Self::PublicParams>; + type RecursiveSnark: RecursiveSNARKTrait; /// Returns a reference to the prover's FoldingMode fn folding_mode(&self) -> &FoldingMode; @@ -168,36 +163,20 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> { /// Returns the number of reductions for the prover. fn reduction_count(&self) -> usize; - /// Returns a reference to the Prover's Lang. - fn lang(&self) -> &Arc>; - - /// Generate a proof from a sequence of frames + /// Generates a recursive proof from a vector of `M` fn prove( &self, pp: &Self::PublicParams, - frames: &[Frame], + steps: Vec, store: &'a Store, ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { store.hydrate_z_cache(); - let z0 = store.to_scalar_vector(frames[0].input()); - let zi = store.to_scalar_vector(frames.last().unwrap().output()); - - let lang = self.lang().clone(); - let folding_config = self - .folding_mode() - .folding_config(lang.clone(), self.reduction_count()); + let z0 = store.to_scalar_vector(steps[0].input()); + let zi = store.to_scalar_vector(steps.last().unwrap().output()); - let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into()); let num_steps = steps.len(); - let prove_output = Self::RecursiveSnark::prove_recursively( - pp, - &z0, - steps, - store, - self.reduction_count(), - lang, - )?; + let prove_output = Self::RecursiveSnark::prove_recursively(pp, &z0, steps, store)?; Ok((prove_output, z0, zi, num_steps)) } @@ -210,11 +189,7 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> { env: Ptr, store: &'a Store, limit: usize, - ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { - let eval_config = self.folding_mode().eval_config(self.lang()); - let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config)?; - self.prove(pp, &frames, store) - } + ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError>; /// Returns the expected total number of steps for the prover given raw iterations. fn expected_num_steps(&self, raw_iterations: usize) -> usize { diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 9ed1b39a79..03139a8651 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -26,7 +26,7 @@ use crate::{ error::ProofError, eval::lang::Lang, field::LurkField, - lem::store::Store, + lem::{interpreter::Frame, pointers::Ptr, store::Store}, proof::{supernova::FoldingConfig, FrameLike, Prover}, }; @@ -151,24 +151,19 @@ where /// An enum representing the two types of proofs that can be generated and verified. #[derive(Serialize, Deserialize)] #[serde(bound = "")] -pub enum Proof<'a, F: CurveCycleEquipped, C: Coprocessor> +pub enum Proof> where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { /// A proof for the intermediate steps of a recursive computation along with /// the number of steps used for verification - Recursive( - Box, E2, C1LEM<'a, F, C>, C2>>, - usize, - PhantomData<&'a C>, - ), + Recursive(Box, E2, C1, C2>>, usize), /// A proof for the final step of a recursive computation along with the number /// of steps used for verification Compressed( - Box, E2, C1LEM<'a, F, C>, C2, SS1, SS2>>, + Box, E2, C1, C2, SS1, SS2>>, usize, - PhantomData<&'a C>, ), } @@ -223,7 +218,8 @@ pub fn circuits<'a, F: CurveCycleEquipped, C: Coprocessor + 'a>( ) } -impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait<'a, F, C> for Proof<'a, F, C> +impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> + for Proof> where ::Repr: Abomonation, <<::E2 as Engine>::Scalar as PrimeField>::Repr: Abomonation, @@ -237,9 +233,7 @@ where pp: &PublicParams>, z0: &[F], steps: Vec>, - store: &'a Store, - reduction_count: usize, - lang: Arc>, + store: &Store, ) -> Result { assert!(!steps.is_empty()); assert_eq!(steps[0].arity(), z0.len()); @@ -247,11 +241,7 @@ where let z0_primary = z0; let z0_secondary = Self::z0_secondary(); - assert_eq!(steps[0].frames().unwrap().len(), reduction_count); - let (_circuit_primary, circuit_secondary): ( - C1LEM<'a, F, C>, - TrivialCircuit< as Engine>::Scalar>, - ) = circuits(reduction_count, lang); + let circuit_secondary = TrivialCircuit::default(); let num_steps = steps.len(); tracing::debug!("steps.len: {num_steps}"); @@ -283,7 +273,6 @@ where for circuit_primary in cc.iter() { let circuit_primary = circuit_primary.lock().unwrap(); - assert_eq!(reduction_count, circuit_primary.frames().unwrap().len()); let mut r_snark = recursive_snark.unwrap_or_else(|| { RecursiveSNARK::new( @@ -305,15 +294,12 @@ where .unwrap() } else { for circuit_primary in steps.iter() { - assert_eq!(reduction_count, circuit_primary.frames().unwrap().len()); if debug { // For debugging purposes, synthesize the circuit and check that the constraint system is satisfied. use bellpepper_core::test_cs::TestConstraintSystem; let mut cs = TestConstraintSystem::< as Engine>::Scalar>::new(); - // This is a CircuitFrame, not an EvalFrame - let first_frame = circuit_primary.frames().unwrap().iter().next().unwrap(); - let zi = store.to_scalar_vector(first_frame.input()); + let zi = store.to_scalar_vector(circuit_primary.input()); let zi_allocated: Vec<_> = zi .iter() .enumerate() @@ -348,20 +334,18 @@ where Ok(Self::Recursive( Box::new(recursive_snark.unwrap()), num_steps, - PhantomData, )) } fn compress(self, pp: &PublicParams>) -> Result { match self { - Self::Recursive(recursive_snark, num_steps, _) => Ok(Self::Compressed( + Self::Recursive(recursive_snark, num_steps) => Ok(Self::Compressed( Box::new(CompressedSNARK::<_, _, _, _, SS1, SS2>::prove( &pp.pp, &pp.pk, &recursive_snark, )?), num_steps, - PhantomData, )), Self::Compressed(..) => Ok(self), } @@ -373,10 +357,10 @@ where let zi_secondary = &z0_secondary; let (zi_primary_verified, zi_secondary_verified) = match self { - Self::Recursive(p, num_steps, _) => { + Self::Recursive(p, num_steps) => { p.verify(&pp.pp, *num_steps, z0_primary, &z0_secondary)? } - Self::Compressed(p, num_steps, _) => { + Self::Compressed(p, num_steps) => { p.verify(&pp.vk, *num_steps, z0_primary, &z0_secondary)? } }; @@ -395,7 +379,11 @@ pub struct NovaProver<'a, F: CurveCycleEquipped, C: Coprocessor> { _phantom: PhantomData<&'a ()>, } -impl<'a, F: CurveCycleEquipped, C: Coprocessor> NovaProver<'a, F, C> { +impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> NovaProver<'a, F, C> +where + < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, + < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, +{ /// Create a new NovaProver with a reduction count and a `Lang` #[inline] pub fn new(reduction_count: usize, lang: Arc>) -> Self { @@ -406,28 +394,56 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> NovaProver<'a, F, C> { _phantom: PhantomData, } } + + /// Generate a proof from a sequence of frames + pub fn prove_from_frames( + &self, + pp: &PublicParams>, + frames: &[Frame], + store: &'a Store, + ) -> Result<(Proof>, Vec, Vec, usize), ProofError> { + let folding_config = self + .folding_mode() + .folding_config(self.lang().clone(), self.reduction_count()); + let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into()); + self.prove(pp, steps, store) + } + + #[inline] + fn lang(&self) -> &Arc> { + &self.lang + } } -impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> Prover<'a, F, C> for NovaProver<'a, F, C> +impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> Prover<'a, F, C1LEM<'a, F, C>> + for NovaProver<'a, F, C> where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { type PublicParams = PublicParams>; - type RecursiveSnark = Proof<'a, F, C>; + type RecursiveSnark = Proof>; #[inline] fn reduction_count(&self) -> usize { self.reduction_count } - #[inline] - fn lang(&self) -> &Arc> { - &self.lang - } - #[inline] fn folding_mode(&self) -> &FoldingMode { &self.folding_mode } + + fn evaluate_and_prove( + &self, + pp: &Self::PublicParams, + expr: Ptr, + env: Ptr, + store: &'a Store, + limit: usize, + ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { + let eval_config = self.folding_mode().eval_config(self.lang()); + let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config)?; + self.prove_from_frames(pp, &frames, store) + } } diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs index 1e0decc608..74a7ef93db 100644 --- a/src/proof/supernova.rs +++ b/src/proof/supernova.rs @@ -28,7 +28,7 @@ use crate::{ error::ProofError, eval::lang::Lang, field::LurkField, - lem::store::Store, + lem::{interpreter::Frame, pointers::Ptr, store::Store}, proof::{ nova::{CurveCycleEquipped, NovaCircuitShape, E1, E2}, Prover, RecursiveSNARKTrait, @@ -128,7 +128,7 @@ where /// An enum representing the two types of proofs that can be generated and verified. #[derive(Serialize, Deserialize)] #[serde(bound = "")] -pub enum Proof<'a, F: CurveCycleEquipped, C: Coprocessor> +pub enum Proof> where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, @@ -136,10 +136,7 @@ where /// A proof for the intermediate steps of a recursive computation Recursive(Box, E2>>), /// A proof for the final step of a recursive computation - Compressed( - Box, E2, C1LEM<'a, F, C>, C2, SS1, SS2>>, - PhantomData<&'a C>, - ), + Compressed(Box, E2, C1, C2, SS1, SS2>>), } /// A struct for the Nova prover that operates on field elements of type `F`. @@ -153,7 +150,11 @@ pub struct SuperNovaProver<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> { _phantom: PhantomData<&'a ()>, } -impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> SuperNovaProver<'a, F, C> { +impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> SuperNovaProver<'a, F, C> +where + < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, + < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, +{ /// Create a new SuperNovaProver with a reduction count and a `Lang` #[inline] pub fn new(reduction_count: usize, lang: Arc>) -> Self { @@ -164,9 +165,29 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> SuperNovaProver<'a, F, C _phantom: PhantomData, } } + + /// Generate a proof from a sequence of frames + pub fn prove_from_frames( + &self, + pp: &PublicParams>, + frames: &[Frame], + store: &'a Store, + ) -> Result<(Proof>, Vec, Vec, usize), ProofError> { + let folding_config = self + .folding_mode() + .folding_config(self.lang().clone(), self.reduction_count()); + let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into()); + self.prove(pp, steps, store) + } + + #[inline] + fn lang(&self) -> &Arc> { + &self.lang + } } -impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait<'a, F, C> for Proof<'a, F, C> +impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> + for Proof> where < as Engine>::Scalar as PrimeField>::Repr: Abomonation, < as Engine>::Scalar as PrimeField>::Repr: Abomonation, @@ -180,9 +201,7 @@ where pp: &PublicParams>, z0: &[F], steps: Vec>, - store: &'a Store, - _reduction_count: usize, - _lang: Arc>, + store: &Store, ) -> Result { let mut recursive_snark_option: Option, E2>> = None; @@ -279,14 +298,14 @@ where fn compress(self, pp: &PublicParams>) -> Result { match &self { - Self::Recursive(recursive_snark) => Ok(Self::Compressed( - Box::new(CompressedSNARK::<_, _, _, _, SS1, SS2>::prove( + Self::Recursive(recursive_snark) => { + let snark = CompressedSNARK::<_, _, _, _, SS1, SS2>::prove( &pp.pp, &pp.pk, recursive_snark, - )?), - PhantomData, - )), + )?; + Ok(Self::Compressed(Box::new(snark))) + } Self::Compressed(..) => Ok(self), } } @@ -298,35 +317,44 @@ where let (zi_primary_verified, zi_secondary_verified) = match self { Self::Recursive(p) => p.verify(&pp.pp, z0_primary, &z0_secondary)?, - Self::Compressed(p, _) => p.verify(&pp.pp, &pp.vk, z0_primary, &z0_secondary)?, + Self::Compressed(p) => p.verify(&pp.pp, &pp.vk, z0_primary, &z0_secondary)?, }; Ok(zi_primary == zi_primary_verified && zi_secondary == &zi_secondary_verified) } } -impl<'a, F: CurveCycleEquipped, C: Coprocessor> Prover<'a, F, C> for SuperNovaProver<'a, F, C> +impl<'a, F: CurveCycleEquipped, C: Coprocessor> Prover<'a, F, C1LEM<'a, F, C>> + for SuperNovaProver<'a, F, C> where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { type PublicParams = PublicParams>; - type RecursiveSnark = Proof<'a, F, C>; + type RecursiveSnark = Proof>; #[inline] fn reduction_count(&self) -> usize { self.reduction_count } - #[inline] - fn lang(&self) -> &Arc> { - &self.lang - } - #[inline] fn folding_mode(&self) -> &FoldingMode { &self.folding_mode } + + fn evaluate_and_prove( + &self, + pp: &Self::PublicParams, + expr: Ptr, + env: Ptr, + store: &'a Store, + limit: usize, + ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { + let eval_config = self.folding_mode().eval_config(self.lang()); + let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config)?; + self.prove_from_frames(pp, &frames, store) + } } #[derive(Clone, Debug)] diff --git a/src/proof/tests/mod.rs b/src/proof/tests/mod.rs index 96901ad5de..fdaa215085 100644 --- a/src/proof/tests/mod.rs +++ b/src/proof/tests/mod.rs @@ -14,7 +14,7 @@ use crate::{ proof::{ nova::{public_params, CurveCycleEquipped, NovaProver, C1LEM, E1, E2}, supernova::FoldingConfig, - CEKState, EvaluationStore, Provable, Prover, RecursiveSNARKTrait, + CEKState, EvaluationStore, FrameLike, Provable, Prover, RecursiveSNARKTrait, }, }; @@ -140,7 +140,7 @@ where if check_nova { let pp = public_params(reduction_count, lang.clone()); - let (proof, z0, zi, _num_steps) = nova_prover.prove(&pp, &frames, s).unwrap(); + let (proof, z0, zi, _num_steps) = nova_prover.prove_from_frames(&pp, &frames, s).unwrap(); let res = proof.verify(&pp, &z0, &zi); if res.is_err() { @@ -207,7 +207,7 @@ where assert!(delta == Delta::Equal); } - let output = previous_frame.unwrap().output().as_ref().unwrap(); + let output = previous_frame.unwrap().output(); if let Some(expected_emitted) = expected_emitted { let mut emitted_vec = Vec::default();