Skip to content

Commit

Permalink
refactor: Refactor generic functions to specific C1LEM use
Browse files Browse the repository at this point in the history
- Refactored code, replacing the use of the `M: StepCircuit<F> + MultiFrameTrait<'a, F, C>` generic with the more specific `C1LEM<'a, F, C>` struct.
- Redesigned the `public_params` and `circuits` function outputs to return `C1LEM<'a, F, C>` instead of `M`.
- Observed substantial edits within code used for proof checking, compressed proof verification, and result checking.
- Eliminated usage of `MultiFrameTrait<'a, F, C>` and `MultiFrame` import from function parameters and file imports respectively across a number of files.
- Updated various function calls to accommodate the `C1LEM<'a, F, C>` struct instead of `M`.
- Altered the `public_params` function in `supernova.rs` by updating the return type and modifying setup for non-uniform circuit and public params.
  • Loading branch information
huitseeker committed Jan 6, 2024
1 parent 3a19179 commit b4176e3
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 308 deletions.
6 changes: 1 addition & 5 deletions benches/public_params.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode};
use lurk::{
eval::lang::{Coproc, Lang},
lem::multiframe::MultiFrame,
proof::nova,
};
use std::sync::Arc;
Expand All @@ -23,10 +22,7 @@ fn public_params_benchmark(c: &mut Criterion) {

group.bench_function("public_params_nova", |b| {
b.iter(|| {
let result = nova::public_params::<_, _, MultiFrame<'_, _, _>>(
reduction_count,
lang_pallas_rc.clone(),
);
let result = nova::public_params(reduction_count, lang_pallas_rc.clone());
black_box(result)
})
});
Expand Down
18 changes: 8 additions & 10 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,10 @@ pub fn circuit_cache_key<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
}

/// Generates the public parameters for the Nova proving system.
pub fn public_params<
'a,
F: CurveCycleEquipped,
C: Coprocessor<F> + 'a,
M: StepCircuit<F> + MultiFrameTrait<'a, F, C>,
>(
pub fn public_params<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
reduction_count: usize,
lang: Arc<Lang<F, C>>,
) -> PublicParams<F, M>
) -> PublicParams<F, C1LEM<'a, F, C>>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
Expand All @@ -219,12 +214,15 @@ where
}

/// Generates the circuits for the Nova proving system.
pub fn circuits<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a, M: MultiFrameTrait<'a, F, C>>(
pub fn circuits<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
reduction_count: usize,
lang: Arc<Lang<F, C>>,
) -> (M, C2<F>) {
) -> (C1LEM<'a, F, C>, C2<F>) {
let folding_config = Arc::new(FoldingConfig::new_ivc(lang, reduction_count));
(M::blank(folding_config, 0), TrivialCircuit::default())
(
C1LEM::<'a, F, C>::blank(folding_config, 0),
TrivialCircuit::default(),
)
}

impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> RecursiveSNARKTrait<'a, F, C> for Proof<'a, F, C>
Expand Down
13 changes: 4 additions & 9 deletions src/proof/supernova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,22 @@ pub type SS1<F> = nova::spartan::batched::BatchedRelaxedR1CSSNARK<E1<F>, EE1<F>>
pub type SS2<F> = nova::spartan::snark::RelaxedR1CSSNARK<E2<F>, EE2<F>>;

/// Generates the running claim params for the SuperNova proving system.
pub fn public_params<
'a,
F: CurveCycleEquipped,
C: Coprocessor<F> + 'a,
M: MultiFrameTrait<'a, F, C> + SuperStepCircuit<F> + NonUniformCircuit<E1<F>, E2<F>, M, C2<F>>,
>(
pub fn public_params<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
rc: usize,
lang: Arc<Lang<F, C>>,
) -> PublicParams<F, M>
) -> PublicParams<F, C1LEM<'a, F, C>>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
{
let folding_config = Arc::new(FoldingConfig::new_nivc(lang, rc));
let non_uniform_circuit = M::blank(folding_config, 0);
let non_uniform_circuit = C1LEM::<'a, F, C>::blank(folding_config, 0);

// grab hints for the compressed SNARK variants we will use this with
let commitment_size_hint1 = <SS1<F> as BatchedRelaxedR1CSSNARKTrait<E1<F>>>::ck_floor();
let commitment_size_hint2 = <SS2<F> as RelaxedR1CSSNARKTrait<E2<F>>>::ck_floor();

let pp = SuperNovaPublicParams::<F, M>::setup(
let pp = SuperNovaPublicParams::<F, C1LEM<'a, F, C>>::setup(
&non_uniform_circuit,
&*commitment_size_hint1,
&*commitment_size_hint2,
Expand Down
43 changes: 17 additions & 26 deletions src/proof/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod nova_tests_lem;

use abomonation::Abomonation;
use bellpepper::util_cs::{metric_cs::MetricCS, witness_cs::WitnessCS, Comparable};
use bellpepper_core::{test_cs::TestConstraintSystem, ConstraintSystem, Delta};
use bellpepper_core::{test_cs::TestConstraintSystem, Circuit, ConstraintSystem, Delta};
use expect_test::Expect;
use nova::traits::Engine;
use std::sync::Arc;
Expand All @@ -14,7 +14,7 @@ use crate::{
proof::{
nova::{public_params, CurveCycleEquipped, NovaProver, C1LEM, E1, E2},
supernova::FoldingConfig,
CEKState, EvaluationStore, MultiFrameTrait, Prover, RecursiveSNARKTrait,
CEKState, EvaluationStore, MultiFrameTrait, Provable, Prover, RecursiveSNARKTrait,
},
};

Expand All @@ -36,8 +36,8 @@ fn mismatch<T: PartialEq + Copy>(a: &[T], b: &[T]) -> Option<(usize, (Option<T>,
}
}

fn test_aux<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a, M: MultiFrameTrait<'a, F, C> + 'a>(
s: &'a Store<F>,
fn test_aux<F: CurveCycleEquipped, C: Coprocessor<F>>(
s: &Store<F>,
expr: &str,
expected_result: Option<Ptr>,
expected_env: Option<Ptr>,
Expand All @@ -52,7 +52,7 @@ where
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
{
for chunk_size in REDUCTION_COUNTS_TO_TEST {
nova_test_full_aux::<F, C, M>(
nova_test_full_aux::<F, C>(
s,
expr,
expected_result,
Expand All @@ -68,13 +68,8 @@ where
}
}

fn nova_test_full_aux<
'a,
F: CurveCycleEquipped,
C: Coprocessor<F> + 'a,
M: MultiFrameTrait<'a, F, C> + 'a,
>(
s: &'a Store<F>,
fn nova_test_full_aux<F: CurveCycleEquipped, C: Coprocessor<F>>(
s: &Store<F>,
expr: &str,
expected_result: Option<Ptr>,
expected_env: Option<Ptr>,
Expand All @@ -94,7 +89,7 @@ where
let expr = EvaluationStore::read(s, expr).unwrap();

let f = |l| {
nova_test_full_aux2::<F, C, M>(
nova_test_full_aux2::<F, C>(
s,
expr,
expected_result,
Expand All @@ -117,12 +112,7 @@ where
};
}

fn nova_test_full_aux2<
'a,
F: CurveCycleEquipped,
C: Coprocessor<F> + 'a,
M: MultiFrameTrait<'a, F, C>,
>(
fn nova_test_full_aux2<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a>(
s: &'a Store<F>,
expr: Ptr,
expected_result: Option<Ptr>,
Expand All @@ -144,11 +134,12 @@ where

let e = s.initial_empty_env();

let frames = M::build_frames(expr, e, s, limit, &EvalConfig::new_ivc(&lang)).unwrap();
let frames =
C1LEM::<'a, F, C>::build_frames(expr, e, s, limit, &EvalConfig::new_ivc(&lang)).unwrap();
let nova_prover = NovaProver::<'a, F, C>::new(reduction_count, lang.clone());

if check_nova {
let pp = public_params::<_, _, C1LEM<F, C>>(reduction_count, lang.clone());
let pp = public_params(reduction_count, lang.clone());
let (proof, z0, zi, _num_steps) = nova_prover.prove(&pp, &frames, s).unwrap();

let res = proof.verify(&pp, &z0, &zi);
Expand All @@ -165,16 +156,16 @@ where

let folding_config = Arc::new(FoldingConfig::new_ivc(lang, nova_prover.reduction_count()));

let multiframes = M::from_frames(&frames, s, folding_config.clone());
let multiframes = C1LEM::<'a, F, C>::from_frames(&frames, s, &folding_config);
let len = multiframes.len();

let expected_iterations_data = expected_iterations.data().parse::<usize>().unwrap();
let adjusted_iterations = nova_prover.expected_num_steps(expected_iterations_data);
let mut previous_frame: Option<&M> = None;
let mut previous_frame: Option<&C1LEM<'a, F, C>> = None;

let mut cs_blank = MetricCS::<F>::new();

let blank = M::blank(folding_config, 0);
let blank = C1LEM::<'a, F, C>::blank(folding_config, 0);
blank
.synthesize(&mut cs_blank)
.expect("failed to synthesize blank");
Expand Down Expand Up @@ -221,7 +212,7 @@ where
if let Some(expected_emitted) = expected_emitted {
let mut emitted_vec = Vec::default();
for frame in frames.iter() {
emitted_vec.extend(M::emitted(s, frame));
emitted_vec.extend(C1LEM::<'a, F, C>::emitted(s, frame));
}
assert_eq!(expected_emitted, &emitted_vec);
}
Expand All @@ -238,7 +229,7 @@ where
assert_eq!(&s.get_cont_terminal(), output.cont());
}

let iterations = M::significant_frame_count(&frames);
let iterations = C1LEM::<'a, F, C>::significant_frame_count(&frames);
assert!(iterations <= expected_iterations_data);
expected_iterations.assert_eq(&iterations.to_string());
assert_eq!(adjusted_iterations, len);
Expand Down
Loading

0 comments on commit b4176e3

Please sign in to comment.