Skip to content

Commit

Permalink
Generic Prover (#1074)
Browse files Browse the repository at this point in the history
* Removed `lang` and `reduction_count` from `prove_recursively`

* Removed `Coprocessor` from `RecursiveSNARKTrait`

* WIP removing coprocessor from `Prover`

* Simplified `RecursiveSNARKTrait`

* MultiFrame is FrameLike

* Generic `Prover`

* Generic `Proof`

* Use multiframe input
  • Loading branch information
gabriel-barrett authored Jan 23, 2024
1 parent 7acfbb5 commit 8046da0
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 131 deletions.
14 changes: 8 additions & 6 deletions benches/end2end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use lurk::{
pointers::Ptr,
store::Store,
},
proof::{nova::NovaProver, Prover, RecursiveSNARKTrait},
proof::{nova::NovaProver, RecursiveSNARKTrait},
public_parameters::{
self,
instance::{Instance, Kind},
Expand Down Expand Up @@ -84,7 +84,7 @@ fn end2end_benchmark(c: &mut Criterion) {
b.iter(|| {
let ptr = go_base::<Fq>(&store, state.clone(), s.0, s.1);
let frames = evaluate::<Fq, Coproc<Fq>>(None, ptr, &store, limit).unwrap();
let _result = prover.prove(&pp, &frames, &store).unwrap();
let _result = prover.prove_from_frames(&pp, &frames, &store).unwrap();
})
});

Expand Down Expand Up @@ -253,7 +253,7 @@ fn prove_benchmark(c: &mut Criterion) {
let frames = evaluate::<Fq, Coproc<Fq>>(None, ptr, &store, limit).unwrap();

b.iter(|| {
let result = prover.prove(&pp, &frames, &store).unwrap();
let result = prover.prove_from_frames(&pp, &frames, &store).unwrap();
black_box(result);
})
});
Expand Down Expand Up @@ -300,7 +300,7 @@ fn prove_compressed_benchmark(c: &mut Criterion) {
let frames = evaluate::<Fq, Coproc<Fq>>(None, ptr, &store, limit).unwrap();

b.iter(|| {
let (proof, _, _, _) = prover.prove(&pp, &frames, &store).unwrap();
let (proof, _, _, _) = prover.prove_from_frames(&pp, &frames, &store).unwrap();

let compressed_result = proof.compress(&pp).unwrap();
black_box(compressed_result);
Expand Down Expand Up @@ -344,7 +344,8 @@ fn verify_benchmark(c: &mut Criterion) {
let ptr = go_base(&store, state.clone(), s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone());
let frames = evaluate::<Fq, Coproc<Fq>>(None, ptr, &store, limit).unwrap();
let (proof, z0, zi, _num_steps) = prover.prove(&pp, &frames, &store).unwrap();
let (proof, z0, zi, _num_steps) =
prover.prove_from_frames(&pp, &frames, &store).unwrap();

b.iter_batched(
|| z0.clone(),
Expand Down Expand Up @@ -396,7 +397,8 @@ fn verify_compressed_benchmark(c: &mut Criterion) {
let ptr = go_base(&store, state.clone(), s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone());
let frames = evaluate::<Fq, Coproc<Fq>>(None, ptr, &store, limit).unwrap();
let (proof, z0, zi, _num_steps) = prover.prove(&pp, &frames, &store).unwrap();
let (proof, z0, zi, _num_steps) =
prover.prove_from_frames(&pp, &frames, &store).unwrap();

let compressed_proof = proof.compress(&pp).unwrap();

Expand Down
3 changes: 1 addition & 2 deletions benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use lurk::{
eval::lang::{Coproc, Lang},
lem::{eval::evaluate, store::Store},
proof::nova::NovaProver,
proof::Prover,
public_parameters::{
instance::{Instance, Kind},
public_params,
Expand Down Expand Up @@ -116,7 +115,7 @@ fn fibonacci_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove(&pp, frames, &store);
let result = prover.prove_from_frames(&pp, frames, &store);
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand Down
8 changes: 4 additions & 4 deletions benches/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use lurk::{
pointers::Ptr,
store::Store,
},
proof::{nova::NovaProver, supernova::SuperNovaProver, Prover, RecursiveSNARKTrait},
proof::{nova::NovaProver, supernova::SuperNovaProver, RecursiveSNARKTrait},
public_parameters::{
instance::{Instance, Kind},
public_params, supernova_public_params,
Expand Down Expand Up @@ -138,7 +138,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove(&pp, frames, store);
let result = prover.prove_from_frames(&pp, frames, store);
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand Down Expand Up @@ -219,7 +219,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let (proof, _, _, _) = prover.prove(&pp, frames, store).unwrap();
let (proof, _, _, _) = prover.prove_from_frames(&pp, frames, store).unwrap();
let compressed_result = proof.compress(&pp).unwrap();

let _ = black_box(compressed_result);
Expand Down Expand Up @@ -303,7 +303,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove(&pp, frames, store);
let result = prover.prove_from_frames(&pp, frames, store);
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand Down
8 changes: 6 additions & 2 deletions examples/sha256_nivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use lurk::{
pointers::Ptr,
store::Store,
},
proof::{supernova::SuperNovaProver, Prover, RecursiveSNARKTrait},
proof::{supernova::SuperNovaProver, RecursiveSNARKTrait},
public_parameters::{
instance::{Instance, Kind},
supernova_public_params,
Expand Down Expand Up @@ -94,7 +94,11 @@ fn main() {
println!("Beginning proof step...");
let proof_start = Instant::now();
let (proof, z0, zi, _num_steps) = tracing_texray::examine(tracing::info_span!("bang!"))
.in_scope(|| supernova_prover.prove(&pp, &frames, store).unwrap());
.in_scope(|| {
supernova_prover
.prove_from_frames(&pp, &frames, store)
.unwrap()
});
let proof_end = proof_start.elapsed();

println!("Proofs took {:?}", proof_end);
Expand Down
7 changes: 2 additions & 5 deletions examples/tp_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ use criterion::black_box;
use lurk::{
eval::lang::{Coproc, Lang},
lem::{eval::evaluate, multiframe::MultiFrame, store::Store},
proof::{
nova::{public_params, NovaProver, PublicParams},
Prover,
},
proof::nova::{public_params, NovaProver, PublicParams},
};
use num_traits::ToPrimitive;
use pasta_curves::pallas::Scalar as Fr;
Expand Down Expand Up @@ -179,7 +176,7 @@ fn main() {
let mut timings = Vec::with_capacity(n_samples);
for _ in 0..n_samples {
let start = Instant::now();
let result = prover.prove(&pp, frames, &store);
let result = prover.prove_from_frames(&pp, frames, &store);
let _ = black_box(result);
let end = start.elapsed().as_secs_f64();
timings.push(end);
Expand Down
4 changes: 2 additions & 2 deletions src/cli/lurk_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
field::LurkField,
lem::{pointers::ZPtr, store::Store},
proof::{
nova::{self, CurveCycleEquipped, E1, E2},
nova::{self, CurveCycleEquipped, C1LEM, E1, E2},
RecursiveSNARKTrait,
},
public_parameters::{
Expand Down Expand Up @@ -131,7 +131,7 @@ pub(crate) enum LurkProof<
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
{
Nova {
proof: nova::Proof<'a, F, C>,
proof: nova::Proof<F, C1LEM<'a, F, C>>,
public_inputs: Vec<F>,
public_outputs: Vec<F>,
rc: usize,
Expand Down
4 changes: 2 additions & 2 deletions src/cli/repl/meta_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
},
package::{Package, SymbolRef},
proof::{
nova::{self, CurveCycleEquipped, E1, E2},
nova::{self, CurveCycleEquipped, C1LEM, E1, E2},
RecursiveSNARKTrait,
},
public_parameters::{
Expand Down Expand Up @@ -1104,7 +1104,7 @@ where
{
Nova {
args: LurkData<F>,
proof: nova::Proof<'a, F, C>,
proof: nova::Proof<F, C1LEM<'a, F, C>>,
},
}

Expand Down
4 changes: 2 additions & 2 deletions src/cli/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::{
parser,
proof::{
nova::{CurveCycleEquipped, NovaProver},
Prover, RecursiveSNARKTrait,
RecursiveSNARKTrait,
},
public_parameters::{
instance::{Instance, Kind},
Expand Down Expand Up @@ -342,7 +342,7 @@ where

info!("Proving");
let (proof, public_inputs, public_outputs, num_steps) =
prover.prove(&pp, frames, &self.store)?;
prover.prove_from_frames(&pp, frames, &self.store)?;
info!("Compressing proof");
let proof = proof.compress(&pp)?;
assert_eq!(self.rc * num_steps, pad(n_frames, self.rc));
Expand Down
18 changes: 13 additions & 5 deletions src/lem/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ impl<'a, F: LurkField, C: Coprocessor<F>> MultiFrame<'a, F, C> {
self.frames.as_ref()
}

#[inline]
pub fn output(&self) -> &Option<Vec<Ptr>> {
&self.output
}

pub fn emitted(_store: &Store<F>, eval_frame: &Frame) -> Vec<Ptr> {
eval_frame.emitted.clone()
}
Expand Down Expand Up @@ -385,6 +380,19 @@ impl CEKState<Ptr> for Vec<Ptr> {
}
}

impl<'a, F: LurkField, C: Coprocessor<F>> FrameLike<Ptr> for MultiFrame<'a, F, C> {
type FrameIO = Vec<Ptr>;
#[inline]
fn input(&self) -> &Vec<Ptr> {
self.input.as_ref().unwrap()
}

#[inline]
fn output(&self) -> &Vec<Ptr> {
self.output.as_ref().unwrap()
}
}

impl FrameLike<Ptr> for Frame {
type FrameIO = Vec<Ptr>;
fn input(&self) -> &Self::FrameIO {
Expand Down
51 changes: 13 additions & 38 deletions src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,11 @@ use crate::{
error::ProofError,
eval::lang::Lang,
field::LurkField,
lem::{eval::EvalConfig, interpreter::Frame, pointers::Ptr, store::Store},
lem::{eval::EvalConfig, pointers::Ptr, store::Store},
proof::nova::E2,
};

use self::{
nova::{CurveCycleEquipped, C1LEM},
supernova::FoldingConfig,
};
use self::{nova::CurveCycleEquipped, supernova::FoldingConfig};

/// The State of a CEK machine.
pub trait CEKState<Ptr> {
Expand Down Expand Up @@ -88,7 +85,7 @@ pub trait Provable<F: LurkField> {
// * `Prover`, which abstracts over Nova and SuperNova provers

/// Trait to abstract Nova and SuperNova proofs
pub trait RecursiveSNARKTrait<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>
pub trait RecursiveSNARKTrait<F: CurveCycleEquipped, M>
where
Self: Sized,
{
Expand All @@ -102,10 +99,8 @@ where
fn prove_recursively(
pp: &Self::PublicParams,
z0: &[F],
steps: Vec<C1LEM<'a, F, C>>,
store: &'a Store<F>,
reduction_count: usize,
lang: Arc<Lang<F, C>>,
steps: Vec<M>,
store: &Store<F>,
) -> Result<Self, ProofError>;

/// Compress a proof
Expand Down Expand Up @@ -155,49 +150,33 @@ impl FoldingMode {
}

/// A trait for a prover that works with a field `F`.
pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a> {
pub trait Prover<'a, F: CurveCycleEquipped, M: FrameLike<Ptr, FrameIO = Vec<Ptr>>> {
/// Associated type for public parameters
type PublicParams;

/// Assiciated proof type, which must implement `RecursiveSNARKTrait`
type RecursiveSnark: RecursiveSNARKTrait<'a, F, C, PublicParams = Self::PublicParams>;
type RecursiveSnark: RecursiveSNARKTrait<F, M, PublicParams = Self::PublicParams>;

/// Returns a reference to the prover's FoldingMode
fn folding_mode(&self) -> &FoldingMode;

/// Returns the number of reductions for the prover.
fn reduction_count(&self) -> usize;

/// Returns a reference to the Prover's Lang.
fn lang(&self) -> &Arc<Lang<F, C>>;

/// Generate a proof from a sequence of frames
/// Generates a recursive proof from a vector of `M`
fn prove(
&self,
pp: &Self::PublicParams,
frames: &[Frame],
steps: Vec<M>,
store: &'a Store<F>,
) -> Result<(Self::RecursiveSnark, Vec<F>, Vec<F>, usize), ProofError> {
store.hydrate_z_cache();
let z0 = store.to_scalar_vector(frames[0].input());
let zi = store.to_scalar_vector(frames.last().unwrap().output());

let lang = self.lang().clone();
let folding_config = self
.folding_mode()
.folding_config(lang.clone(), self.reduction_count());
let z0 = store.to_scalar_vector(steps[0].input());
let zi = store.to_scalar_vector(steps.last().unwrap().output());

let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into());
let num_steps = steps.len();

let prove_output = Self::RecursiveSnark::prove_recursively(
pp,
&z0,
steps,
store,
self.reduction_count(),
lang,
)?;
let prove_output = Self::RecursiveSnark::prove_recursively(pp, &z0, steps, store)?;

Ok((prove_output, z0, zi, num_steps))
}
Expand All @@ -210,11 +189,7 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a> {
env: Ptr,
store: &'a Store<F>,
limit: usize,
) -> Result<(Self::RecursiveSnark, Vec<F>, Vec<F>, usize), ProofError> {
let eval_config = self.folding_mode().eval_config(self.lang());
let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config)?;
self.prove(pp, &frames, store)
}
) -> Result<(Self::RecursiveSnark, Vec<F>, Vec<F>, usize), ProofError>;

/// Returns the expected total number of steps for the prover given raw iterations.
fn expected_num_steps(&self, raw_iterations: usize) -> usize {
Expand Down
Loading

1 comment on commit 8046da0

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7629541532

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=7acfbb53bdadf32cd4a3397bd3e193776ffbdba1 ref=8046da0e91f8ca9eb41b5733efcd23ca47c4fcae
num-100 1.74 s (✅ 1.00x) 1.73 s (✅ 1.00x faster)
num-200 3.36 s (✅ 1.00x) 3.35 s (✅ 1.00x faster)

LEM Fibonacci Prove - rc = 600

ref=7acfbb53bdadf32cd4a3397bd3e193776ffbdba1 ref=8046da0e91f8ca9eb41b5733efcd23ca47c4fcae
num-100 2.05 s (✅ 1.00x) 2.03 s (✅ 1.01x faster)
num-200 3.41 s (✅ 1.00x) 3.39 s (✅ 1.01x faster)

Made with criterion-table

Please sign in to comment.