Skip to content

Commit

Permalink
Adapt type simplifications (#1092)
Browse files Browse the repository at this point in the history
* chore: point to type simplifications PR

* refactor: Refactor to adapt to Arecibo simplifications

- Significant refactoring has been done across multiple files to include the `SecEng` trait and `RecursiveSNARKTrait`, replacing `C1LEM`, `E2`, `C2` as well as `GrumpkinEngine` and `VestaEngine`.
- Adjustments to various function arguments and returns to reflect the above changes, primarily noted in `public_parameters/mod.rs` and `proof/nova.rs`.
- The implementation of `NonUniformCircuit` has been updated to only work with `E1` and has `C1` and `C2` types added.
- removal of `MultiFrame` usage from `PublicParams` and simplification of type declarations.
- Replaced functions across files, mainly `public_params`, `supernova_circuit_params`, `supernova_aux_params`, and `supernova_public_params`, have been adjusted for syntax changes.

* refactor: Refactor to replace SecEng with Dual Alias

- Consolidated the use of `nova` related code with the use of `seceng` being changed to `Dual` in various contexts across multiple files.

* refactor: Improve readability through correct use of F

- Simplified type declarations and trait bounds in `src/public_parameters/mod.rs` for improved readability
- Updated `TestConstraintSystem` initialization in `src/proof/nova.rs`

* refactor: avoid most pedantic instances of compress/verify

* refactor: use default argument for Lang

* chore: clippy

* chore: rename nova_tests_lem -> nova_tests

* chore: rename SecEng -> DualEng

* chore: point back to dev
  • Loading branch information
huitseeker authored Feb 5, 2024
1 parent 107dc37 commit 222255f
Show file tree
Hide file tree
Showing 21 changed files with 195 additions and 227 deletions.
2 changes: 1 addition & 1 deletion benches/common/fib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ fn compute_coeffs<F: LurkField>(store: &Store<F>) -> (usize, usize) {
store.intern_empty_env(),
store.cont_outermost(),
];
let lang: Lang<F, Coproc<F>> = Lang::new();
let lang: Lang<F> = Lang::new();
let mut coef_lin = 0;
let coef_ang;
let step_func = eval_step();
Expand Down
10 changes: 5 additions & 5 deletions benches/end2end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn end2end_benchmark(c: &mut Criterion) {
let reduction_count = DEFAULT_REDUCTION_COUNT;

// setup
let lang_pallas = Lang::<Fq, Coproc<Fq>>::new();
let lang_pallas = Lang::<Fq>::new();
let lang_pallas_rc = Arc::new(lang_pallas.clone());

let store = Store::default();
Expand Down Expand Up @@ -234,7 +234,7 @@ fn prove_benchmark(c: &mut Criterion) {

let state = State::init_lurk_state().rccell();

let lang_pallas = Lang::<Fq, Coproc<Fq>>::new();
let lang_pallas = Lang::<Fq>::new();
let lang_pallas_rc = Arc::new(lang_pallas.clone());

// use cached public params
Expand Down Expand Up @@ -282,7 +282,7 @@ fn prove_compressed_benchmark(c: &mut Criterion) {

let state = State::init_lurk_state().rccell();

let lang_pallas = Lang::<Fq, Coproc<Fq>>::new();
let lang_pallas = Lang::<Fq>::new();
let lang_pallas_rc = Arc::new(lang_pallas.clone());

// use cached public params
Expand Down Expand Up @@ -324,7 +324,7 @@ fn verify_benchmark(c: &mut Criterion) {

let state = State::init_lurk_state().rccell();

let lang_pallas = Lang::<Fq, Coproc<Fq>>::new();
let lang_pallas = Lang::<Fq>::new();
let lang_pallas_rc = Arc::new(lang_pallas.clone());

// use cached public params
Expand Down Expand Up @@ -377,7 +377,7 @@ fn verify_compressed_benchmark(c: &mut Criterion) {

let state = State::init_lurk_state().rccell();

let lang_pallas = Lang::<Fq, Coproc<Fq>>::new();
let lang_pallas = Lang::<Fq>::new();
let lang_pallas_rc = Arc::new(lang_pallas.clone());

// use cached public params
Expand Down
2 changes: 1 addition & 1 deletion benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn fibonacci_prove<M: measurement::Measurement>(
c: &mut BenchmarkGroup<'_, M>,
) {
let limit = fib_limit(prove_params.fib_n, prove_params.reduction_count);
let lang_pallas = Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new();
let lang_pallas = Lang::<pallas::Scalar>::new();
let lang_rc = Arc::new(lang_pallas.clone());

// use cached public params
Expand Down
8 changes: 2 additions & 6 deletions benches/public_params.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode};
use lurk::{
eval::lang::{Coproc, Lang},
proof::nova,
};
use lurk::{eval::lang::Lang, proof::nova};
use std::sync::Arc;
use std::time::Duration;

Expand All @@ -14,8 +11,7 @@ const DEFAULT_REDUCTION_COUNT: usize = 10;
fn public_params_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("public_params_benchmark");
group.sampling_mode(SamplingMode::Flat);
let lang_pallas =
Lang::<pasta_curves::pallas::Scalar, Coproc<pasta_curves::pallas::Scalar>>::new();
let lang_pallas = Lang::<pasta_curves::pallas::Scalar>::new();
let lang_pallas_rc = Arc::new(lang_pallas);

let reduction_count = DEFAULT_REDUCTION_COUNT;
Expand Down
8 changes: 4 additions & 4 deletions benches/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn sha256_ivc<F: LurkField>(
state: Rc<RefCell<State>>,
arity: usize,
n: usize,
input: &Vec<usize>,
input: &[usize],
) -> Ptr {
assert_eq!(n, input.len());
let input = input
Expand Down Expand Up @@ -128,7 +128,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
&(0..prove_params.n).collect(),
&(0..prove_params.n).collect::<Vec<_>>(),
);

let prover = NovaProver::new(prove_params.reduction_count, lang_rc.clone());
Expand Down Expand Up @@ -209,7 +209,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
&(0..prove_params.n).collect(),
&(0..prove_params.n).collect::<Vec<_>>(),
);

let prover = NovaProver::new(prove_params.reduction_count, lang_rc.clone());
Expand Down Expand Up @@ -293,7 +293,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
&(0..prove_params.n).collect(),
&(0..prove_params.n).collect::<Vec<_>>(),
);

let prover = SuperNovaProver::new(prove_params.reduction_count, lang_rc.clone());
Expand Down
2 changes: 1 addition & 1 deletion benches/synthesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fn synthesize<M: measurement::Measurement>(
c: &mut BenchmarkGroup<'_, M>,
) {
let limit = 1_000_000;
let lang_rc = Arc::new(Lang::<Fq, Coproc<Fq>>::new());
let lang_rc = Arc::new(Lang::<Fq>::new());
let state = State::init_lurk_state().rccell();

c.bench_with_input(
Expand Down
7 changes: 3 additions & 4 deletions examples/tp_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use ascii_table::{Align, AsciiTable};
use criterion::black_box;
use lurk::{
eval::lang::{Coproc, Lang},
lem::{eval::evaluate, multiframe::MultiFrame, store::Store},
lem::{eval::evaluate, store::Store},
proof::nova::{public_params, NovaProver, PublicParams},
};
use num_traits::ToPrimitive;
Expand Down Expand Up @@ -157,7 +157,7 @@ fn main() {

let frames = evaluate::<Fr, Coproc<Fr>>(None, program, &store, limit).unwrap();

let lang = Lang::<Fr, Coproc<Fr>>::new();
let lang = Lang::<Fr>::new();
let lang_arc = Arc::new(lang.clone());

let mut data = Vec::with_capacity(rc_vec.len());
Expand All @@ -166,8 +166,7 @@ fn main() {
let prover: NovaProver<'_, _, _> = NovaProver::new(rc, lang_arc.clone());
println!("Getting public params for rc={rc}");
// TODO: use cache once it's fixed
let pp: PublicParams<_, MultiFrame<'_, _, Coproc<Fr>>> =
public_params(rc, lang_arc.clone());
let pp: PublicParams<_> = public_params(rc, lang_arc.clone());
let n_folds_data = (0..=max_n_folds)
.map(|n_folds| {
let n_frames = n_iters(n_folds, rc);
Expand Down
18 changes: 10 additions & 8 deletions src/cli/lurk_proof.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ::nova::traits::Engine;
use abomonation::Abomonation;
use anyhow::{bail, Result};
use ff::PrimeField;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

Expand All @@ -10,7 +10,7 @@ use crate::{
field::LurkField,
lem::{pointers::ZPtr, store::Store},
proof::{
nova::{self, CurveCycleEquipped, C1LEM, E1, E2},
nova::{self, CurveCycleEquipped, Dual, C1LEM},
RecursiveSNARKTrait,
},
public_parameters::{
Expand Down Expand Up @@ -154,9 +154,10 @@ impl<'a, F: CurveCycleEquipped + Serialize, C: Coprocessor<F> + Serialize + Dese
}

impl<
'a,
F: CurveCycleEquipped + DeserializeOwned,
C: Coprocessor<F> + Serialize + DeserializeOwned + 'static,
> LurkProof<'static, F, C>
C: Coprocessor<F> + Serialize + DeserializeOwned + 'a,
> LurkProof<'a, F, C>
{
#[inline]
pub(crate) fn is_cached(proof_key: &str) -> bool {
Expand All @@ -165,12 +166,13 @@ impl<
}

impl<
'a,
F: CurveCycleEquipped + DeserializeOwned,
C: Coprocessor<F> + Serialize + DeserializeOwned + 'static,
> LurkProof<'static, F, C>
C: Coprocessor<F> + Serialize + DeserializeOwned + 'a,
> LurkProof<'a, F, C>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
F::Repr: Abomonation,
<Dual<F> as PrimeField>::Repr: Abomonation,
{
pub(crate) fn verify_proof(proof_key: &str) -> Result<()> {
let lurk_proof = load::<Self>(&proof_path(proof_key))?;
Expand Down
27 changes: 10 additions & 17 deletions src/cli/repl/meta_cmd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ::nova::traits::Engine;
use ::nova::supernova::StepCircuit;
use abomonation::Abomonation;
use anyhow::{anyhow, bail, Context, Result};
use camino::{Utf8Path, Utf8PathBuf};
use ff::PrimeField;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{collections::HashMap, process};

Expand All @@ -22,7 +23,7 @@ use crate::{
},
package::{Package, SymbolRef},
proof::{
nova::{self, CurveCycleEquipped, C1LEM, E1, E2},
nova::{self, CurveCycleEquipped, Dual, C1LEM},
RecursiveSNARKTrait,
},
public_parameters::{
Expand All @@ -44,12 +45,13 @@ pub(super) struct MetaCmd<F: LurkField, C: Coprocessor<F> + Serialize + Deserial
}

impl<
'a,
F: CurveCycleEquipped + Serialize + DeserializeOwned,
C: Coprocessor<F> + Serialize + DeserializeOwned + 'static,
> MetaCmd<F, C>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
F::Repr: Abomonation,
<Dual<F> as PrimeField>::Repr: Abomonation,
{
const LOAD: MetaCmd<F, C> = MetaCmd {
name: "load",
Expand Down Expand Up @@ -992,7 +994,7 @@ where

let (fun, proto_rc) = Self::get_fun_and_rc(repl, ptcl)?;

match load::<ProtocolProof<'_, _, C>>(&path)? {
match load::<ProtocolProof<F, C1LEM<'a, F, C>>>(&path)? {
ProtocolProof::Nova {
args: LurkData { z_ptr, z_dag },
proof,
Expand Down Expand Up @@ -1096,23 +1098,14 @@ fn get_path<F: LurkField, C: Coprocessor<F> + Serialize + DeserializeOwned>(
#[non_exhaustive]
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "F: Serialize", deserialize = "F: DeserializeOwned"))]
enum ProtocolProof<'a, F: CurveCycleEquipped, C: Coprocessor<F> + Serialize + DeserializeOwned>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
{
enum ProtocolProof<F: CurveCycleEquipped, S> {
Nova {
args: LurkData<F>,
proof: nova::Proof<F, C1LEM<'a, F, C>>,
proof: nova::Proof<F, S>,
},
}

impl<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a + Serialize + DeserializeOwned>
HasFieldModulus for ProtocolProof<'a, F, C>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
{
impl<F: CurveCycleEquipped, S: StepCircuit<F>> HasFieldModulus for ProtocolProof<F, S> {
fn field_modulus() -> String {
F::MODULUS.to_owned()
}
Expand Down
8 changes: 4 additions & 4 deletions src/cli/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod meta_cmd;
use abomonation::Abomonation;
use anyhow::{anyhow, bail, Context, Result};
use camino::{Utf8Path, Utf8PathBuf};
use nova::traits::Engine;
use ff::PrimeField;
use rustyline::{
error::ReadlineError,
history::DefaultHistory,
Expand Down Expand Up @@ -39,7 +39,7 @@ use crate::{
},
parser,
proof::{
nova::{CurveCycleEquipped, NovaProver, E1, E2},
nova::{CurveCycleEquipped, Dual, NovaProver},
RecursiveSNARKTrait,
},
public_parameters::{
Expand Down Expand Up @@ -166,8 +166,8 @@ impl<
C: Coprocessor<F> + Serialize + DeserializeOwned + 'static,
> Repl<F, C>
where
<<E1<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
<<E2<F> as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation,
F::Repr: Abomonation,
<Dual<F> as PrimeField>::Repr: Abomonation,
{
pub(crate) fn new(
store: Store<F>,
Expand Down
6 changes: 3 additions & 3 deletions src/eval/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub enum Coproc<F: LurkField> {
///
// TODO: Define a trait for the Hash and parameterize on that also.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct Lang<F, C> {
pub struct Lang<F, C = Coproc<F>> {
/// An IndexMap that stores coprocessors with their associated `Sym` keys.
coprocessors: IndexMap<Symbol, C>,
_p: PhantomData<F>,
Expand Down Expand Up @@ -185,12 +185,12 @@ pub(crate) mod test {

#[test]
fn lang() {
Lang::<Fr, Coproc<Fr>>::new();
Lang::<Fr>::new();
}

#[test]
fn dummy_lang() {
let _lang = Lang::<Fr, Coproc<Fr>>::new_with_bindings(vec![(
let _lang = Lang::<Fr>::new_with_bindings(vec![(
sym!("coproc", "dummy"),
DummyCoprocessor::new().into(),
)]);
Expand Down
7 changes: 2 additions & 5 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1720,10 +1720,7 @@ fn make_thunk() -> Func {
#[cfg(test)]
mod tests {
use super::*;
use crate::{
eval::lang::{Coproc, Lang},
lem::store::Store,
};
use crate::{eval::lang::Lang, lem::store::Store};
use bellpepper_core::{test_cs::TestConstraintSystem, Comparable};
use expect_test::{expect, Expect};
use halo2curves::bn256::Fr;
Expand All @@ -1734,7 +1731,7 @@ mod tests {
let func = eval_step();
let frame = Frame::blank(func, 0, &store);
let mut cs = TestConstraintSystem::<Fr>::new();
let lang: Lang<Fr, Coproc<Fr>> = Lang::new();
let lang: Lang<Fr> = Lang::new();
let _ = func.synthesize_frame_aux(&mut cs, &store, &frame, &lang);
let expect_eq = |computed: usize, expected: Expect| {
expected.assert_eq(&computed.to_string());
Expand Down
11 changes: 7 additions & 4 deletions src/lem/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
eval::lang::Lang,
field::{LanguageField, LurkField},
proof::{
nova::{CurveCycleEquipped, E1, E2},
nova::{CurveCycleEquipped, E1},
supernova::{FoldingConfig, C2},
CEKState, EvaluationStore, FrameLike, Provable,
},
Expand Down Expand Up @@ -898,11 +898,14 @@ impl<'a, F: LurkField, C: Coprocessor<F>> nova::supernova::StepCircuit<F> for Mu
}
}

impl<'a, F, C> NonUniformCircuit<E1<F>, E2<F>, MultiFrame<'a, F, C>, C2<F>> for MultiFrame<'a, F, C>
impl<'a, F, C> NonUniformCircuit<E1<F>> for MultiFrame<'a, F, C>
where
F: CurveCycleEquipped + LurkField,
C: Coprocessor<F> + 'a,
{
type C1 = MultiFrame<'a, F, C>;
type C2 = C2<F>;

fn num_circuits(&self) -> usize {
assert_eq!(self.pc, 0);
self.get_lang().coprocessor_count() + 1
Expand Down Expand Up @@ -1042,7 +1045,7 @@ mod tests {

let mut cs_clone = cs.clone();

let lang = Lang::<Bn, Coproc<Bn>>::new();
let lang = Lang::<Bn>::new();

let output_sequential = synthesize_frames_sequential(
&mut cs,
Expand Down Expand Up @@ -1083,7 +1086,7 @@ mod tests {
// not self-evaluating
let expr = store.read_with_default_state("(+ 1 2)").unwrap();

let lang = Arc::new(Lang::<Bn, Coproc<Bn>>::new());
let lang = Arc::new(Lang::<Bn>::new());
let mut frames = evaluate::<Bn, Coproc<Bn>>(None, expr, &store, 1).unwrap();
assert_eq!(frames.len(), 1);

Expand Down
Loading

0 comments on commit 222255f

Please sign in to comment.