From 878686c44440ab0a8ba152e0a9e798fe7986b69f Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Tue, 19 Dec 2023 09:42:38 -0300 Subject: [PATCH] add max-chunk-size as a REPL parameter for memory control --- benches/end2end.rs | 20 +-- benches/fibonacci.rs | 5 +- benches/sha256.rs | 18 +-- benches/synthesis.rs | 2 +- examples/sha256_nivc.rs | 15 +- src/cli/config.rs | 12 +- src/cli/mod.rs | 46 ++++-- src/cli/repl/meta_cmd.rs | 29 ++-- src/cli/repl/mod.rs | 305 +++++++++++++++++++++++------------- src/lem/eval.rs | 59 +++++-- src/lem/multiframe.rs | 8 +- src/lem/tests/nivc_steps.rs | 2 +- src/proof/mod.rs | 32 ++-- src/proof/nova.rs | 15 +- src/proof/supernova.rs | 36 ++--- src/proof/tests/mod.rs | 6 +- 16 files changed, 359 insertions(+), 251 deletions(-) diff --git a/benches/end2end.rs b/benches/end2end.rs index 078450bbe9..e03041ef93 100644 --- a/benches/end2end.rs +++ b/benches/end2end.rs @@ -84,8 +84,8 @@ fn end2end_benchmark(c: &mut Criterion) { group.bench_with_input(benchmark_id, &size, |b, &s| { b.iter(|| { let ptr = go_base::(&store, state.clone(), s.0, s.1); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); - let _result = prover.prove(&pp, &frames, &store).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); + let _result = prover.prove(&pp, &frames, &store, None).unwrap(); }) }); @@ -251,10 +251,10 @@ fn prove_benchmark(c: &mut Criterion) { let ptr = go_base::(&store, state.clone(), s.0, s.1); let prover: NovaProver<'_, Fq, Coproc, MultiFrame<'_, Fq, Coproc>> = NovaProver::new(reduction_count, lang_pallas_rc.clone()); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); b.iter(|| { - let result = prover.prove(&pp, &frames, &store).unwrap(); + let result = prover.prove(&pp, &frames, &store, None).unwrap(); black_box(result); }) }); @@ -298,10 +298,10 @@ fn prove_compressed_benchmark(c: &mut Criterion) { group.bench_with_input(benchmark_id, &size, |b, &s| { let ptr = go_base::(&store, state.clone(), s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone()); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); b.iter(|| { - let (proof, _, _, _) = prover.prove(&pp, &frames, &store).unwrap(); + let (proof, _, _, _) = prover.prove(&pp, &frames, &store, None).unwrap(); let compressed_result = proof.compress(&pp).unwrap(); black_box(compressed_result); @@ -344,8 +344,8 @@ fn verify_benchmark(c: &mut Criterion) { group.bench_with_input(benchmark_id, &size, |b, &s| { let ptr = go_base(&store, state.clone(), s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone()); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); - let (proof, z0, zi, num_steps) = prover.prove(&pp, &frames, &store).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); + let (proof, z0, zi, num_steps) = prover.prove(&pp, &frames, &store, None).unwrap(); b.iter_batched( || z0.clone(), @@ -396,8 +396,8 @@ fn verify_compressed_benchmark(c: &mut Criterion) { group.bench_with_input(benchmark_id, &size, |b, &s| { let ptr = go_base(&store, state.clone(), s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas_rc.clone()); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); - let (proof, z0, zi, num_steps) = prover.prove(&pp, &frames, &store).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); + let (proof, z0, zi, num_steps) = prover.prove(&pp, &frames, &store, None).unwrap(); let compressed_proof = proof.compress(&pp).unwrap(); diff --git a/benches/fibonacci.rs b/benches/fibonacci.rs index 0d665a4c85..fbaa1b5c80 100644 --- a/benches/fibonacci.rs +++ b/benches/fibonacci.rs @@ -142,13 +142,12 @@ fn fibonacci_prove( let frames = &evaluate::>(None, ptr, &store, limit) - .unwrap() - .0; + .unwrap(); b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, &store); + let result = prover.prove(&pp, frames, &store, None); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/benches/sha256.rs b/benches/sha256.rs index a9ad6b5b03..80cc5eb702 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -134,14 +134,12 @@ fn sha256_ivc_prove( let prover = NovaProver::new(prove_params.reduction_count, lang_rc.clone()); - let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit) - .unwrap() - .0; + let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit).unwrap(); b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, store); + let result = prover.prove(&pp, frames, store, None); let _ = black_box(result); }, BatchSize::LargeInput, @@ -217,14 +215,12 @@ fn sha256_ivc_prove_compressed( let prover = NovaProver::new(prove_params.reduction_count, lang_rc.clone()); - let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit) - .unwrap() - .0; + let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit).unwrap(); b.iter_batched( || frames, |frames| { - let (proof, _, _, _) = prover.prove(&pp, frames, store).unwrap(); + let (proof, _, _, _) = prover.prove(&pp, frames, store, None).unwrap(); let compressed_result = proof.compress(&pp).unwrap(); let _ = black_box(compressed_result); @@ -302,14 +298,12 @@ fn sha256_nivc_prove( let prover = SuperNovaProver::new(prove_params.reduction_count, lang_rc.clone()); - let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit) - .unwrap() - .0; + let frames = &evaluate(Some((&lurk_step, &lang)), ptr, store, limit).unwrap(); b.iter_batched( || frames, |frames| { - let result = prover.prove(&pp, frames, store); + let result = prover.prove(&pp, frames, store, None); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/benches/synthesis.rs b/benches/synthesis.rs index cd1c6e6be0..da264cac8a 100644 --- a/benches/synthesis.rs +++ b/benches/synthesis.rs @@ -50,7 +50,7 @@ fn synthesize( let store = Store::default(); let fib_n = (reduction_count / 3) as u64; // Heuristic, since one fib is 35 iterations. let ptr = fib::(&store, state.clone(), black_box(fib_n)); - let (frames, _) = evaluate::>(None, ptr, &store, limit).unwrap(); + let frames = evaluate::>(None, ptr, &store, limit).unwrap(); let folding_config = Arc::new(FoldingConfig::new_ivc(lang_rc.clone(), *reduction_count)); diff --git a/examples/sha256_nivc.rs b/examples/sha256_nivc.rs index 446740c510..af18aa3e87 100644 --- a/examples/sha256_nivc.rs +++ b/examples/sha256_nivc.rs @@ -1,5 +1,5 @@ use pasta_curves::pallas::Scalar as Fr; -use std::{sync::Arc, time::Instant}; +use std::{marker::PhantomData, sync::Arc, time::Instant}; use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry}; use tracing_texray::TeXRayLayer; @@ -77,7 +77,7 @@ fn main() { let lang_rc = Arc::new(lang.clone()); let lurk_step = make_eval_step_from_config(&EvalConfig::new_nivc(&lang)); - let (frames, _) = evaluate(Some((&lurk_step, &lang)), call, store, 1000).unwrap(); + let frames = evaluate(Some((&lurk_step, &lang)), call, store, 1000).unwrap(); let supernova_prover = SuperNovaProver::, MultiFrame<'_, _, _>>::new( REDUCTION_COUNT, @@ -95,9 +95,8 @@ fn main() { println!("Beginning proof step..."); let proof_start = Instant::now(); - let ((proof, last_circuit_index), z0, zi, _num_steps) = - tracing_texray::examine(tracing::info_span!("bang!")) - .in_scope(|| supernova_prover.prove(&pp, &frames, store).unwrap()); + let (proof, z0, zi, _num_steps) = tracing_texray::examine(tracing::info_span!("bang!")) + .in_scope(|| supernova_prover.prove(&pp, &frames, store, None).unwrap()); let proof_end = proof_start.elapsed(); println!("Proofs took {:?}", proof_end); @@ -105,7 +104,7 @@ fn main() { println!("Verifying proof..."); let verify_start = Instant::now(); - assert!(proof.verify(&pp, &z0, &zi, last_circuit_index).unwrap()); + assert!(proof.verify(&pp, &z0, &zi, PhantomData).unwrap()); let verify_end = verify_start.elapsed(); println!("Verify took {:?}", verify_end); @@ -118,9 +117,7 @@ fn main() { println!("Compression took {:?}", compress_end); let compressed_verify_start = Instant::now(); - let res = compressed_proof - .verify(&pp, &z0, &zi, last_circuit_index) - .unwrap(); + let res = compressed_proof.verify(&pp, &z0, &zi, PhantomData).unwrap(); let compressed_verify_end = compressed_verify_start.elapsed(); println!("Final verification took {:?}", compressed_verify_end); diff --git a/src/cli/config.rs b/src/cli/config.rs index 76f55e463b..ffff14ca43 100644 --- a/src/cli/config.rs +++ b/src/cli/config.rs @@ -60,6 +60,8 @@ pub(crate) struct CliSettings { /// Iteration limit for the program, which is arbitrary to user preferences /// Used mainly as a safety check, similar to default stack size pub(crate) limit: usize, + /// Maximum number of frames held at once, used to avoid memory overflows + pub(crate) max_chunk_size: usize, } impl CliSettings { @@ -68,7 +70,7 @@ impl CliSettings { config_file: &Utf8PathBuf, cli_settings: Option<&HashMap<&str, String>>, ) -> Result { - let (proofs, commits, circom, backend, field, rc, limit) = ( + let (proofs, commits, circom, backend, field, rc, limit, max_chunk_size) = ( "proofs_dir", "commits_dir", "circom_dir", @@ -76,6 +78,7 @@ impl CliSettings { "field", "rc", "limit", + "max_chunk_size", ); Config::builder() .set_default(proofs, proofs_default_dir().to_string())? @@ -96,6 +99,7 @@ impl CliSettings { .set_override_option(field, cli_settings.and_then(|s| s.get(field).map(|v| v.to_owned())))? .set_override_option(rc, cli_settings.and_then(|s| s.get(rc).map(|v| v.to_owned())))? .set_override_option(limit, cli_settings.and_then(|s| s.get(limit).map(|v| v.to_owned())))? + .set_override_option(max_chunk_size, cli_settings.and_then(|s| s.get(max_chunk_size).map(|v| v.to_owned())))? .build() .and_then(|c| c.try_deserialize()) } @@ -111,6 +115,7 @@ impl Default for CliSettings { field: LanguageField::default(), rc: 10, limit: 100_000_000, + max_chunk_size: 1_000_000, } } } @@ -140,6 +145,7 @@ mod tests { let field = "Pallas"; let rc = 100; let limit = 100_000; + let max_chunk_size = 10_000; let mut config_file = std::fs::File::create(config_dir.clone()).unwrap(); config_file @@ -166,6 +172,9 @@ mod tests { config_file .write_all(format!("limit = {limit}\n").as_bytes()) .unwrap(); + config_file + .write_all(format!("max_chunk_size = {max_chunk_size}\n").as_bytes()) + .unwrap(); let cli_config = CliSettings::from_config(&config_dir, None).unwrap(); let lurk_config = Settings::from_config(&config_dir, None).unwrap(); @@ -177,5 +186,6 @@ mod tests { assert_eq!(cli_config.field, LanguageField::Pallas); assert_eq!(cli_config.rc, rc); assert_eq!(cli_config.limit, limit); + assert_eq!(cli_config.max_chunk_size, max_chunk_size); } } diff --git a/src/cli/mod.rs b/src/cli/mod.rs index bec951a1dc..257e137c3f 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -89,6 +89,10 @@ struct LoadArgs { #[clap(long, value_parser)] limit: Option, + /// Maximum number of held frames (defaults to 1_000_000; rounded up to the next multiple of rc) + #[clap(long, value_parser)] + max_chunk_size: Option, + /// Prover backend (defaults to "Nova") #[clap(long, value_enum)] backend: Option, @@ -138,6 +142,9 @@ struct LoadCli { #[clap(long, value_parser)] limit: Option, + #[clap(long, value_parser)] + max_chunk_size: Option, + #[clap(long, value_enum)] backend: Option, @@ -169,6 +176,7 @@ impl LoadArgs { config: self.config, rc: self.rc, limit: self.limit, + max_chunk_size: self.max_chunk_size, backend: self.backend, field: self.field, public_params_dir: self.public_params_dir, @@ -202,6 +210,10 @@ struct ReplArgs { #[clap(long, value_parser)] limit: Option, + /// Maximum number of held frames (defaults to 1_000_000; rounded up to the next multiple of rc) + #[clap(long, value_parser)] + max_chunk_size: Option, + /// Prover backend (defaults to "Nova") #[clap(long, value_enum)] backend: Option, @@ -244,6 +256,9 @@ struct ReplCli { #[clap(long, value_parser)] limit: Option, + #[clap(long, value_parser)] + max_chunk_size: Option, + #[clap(long, value_enum)] backend: Option, @@ -271,6 +286,7 @@ impl ReplArgs { config: self.config, rc: self.rc, limit: self.limit, + max_chunk_size: self.max_chunk_size, backend: self.backend, field: self.field, public_params_dir: self.public_params_dir, @@ -302,17 +318,17 @@ fn get_store serde::de::Deserialize<'a>>( } macro_rules! new_repl { - ( $cli: expr, $rc: expr, $limit: expr, $field: path, $backend: expr ) => {{ + ( $cli: expr, $rc: expr, $limit: expr, $max_chunk_size: expr, $field: path, $backend: expr ) => {{ let store = get_store(&$cli.zstore).with_context(|| "reading store from file")?; - Repl::<$field>::new(store, $rc, $limit, $backend) + Repl::<$field>::new(store, $rc, $limit, $max_chunk_size, $backend) }}; } impl ReplCli { fn run(&self) -> Result<()> { macro_rules! repl { - ( $rc: expr, $limit: expr, $field: path, $backend: expr ) => {{ - let mut repl = new_repl!(self, $rc, $limit, $field, $backend); + ( $rc: expr, $limit: expr, $max_chunk_size: expr, $field: path, $backend: expr ) => {{ + let mut repl = new_repl!(self, $rc, $limit, $max_chunk_size, $field, $backend); if let Some(lurk_file) = &self.load { repl.load_file(lurk_file, false)?; } @@ -338,7 +354,8 @@ impl ReplCli { backend, field, rc, - limit + limit, + max_chunk_size ); // Initializes CLI config with CLI arguments as overrides @@ -348,12 +365,15 @@ impl ReplCli { let rc = config.rc; let limit = config.limit; + let max_chunk_size = config.max_chunk_size; let backend = &config.backend; let field = &config.field; validate_non_zero("rc", rc)?; backend.validate_field(field)?; match field { - LanguageField::Pallas => repl!(rc, limit, pallas::Scalar, backend.clone()), + LanguageField::Pallas => { + repl!(rc, limit, max_chunk_size, pallas::Scalar, backend.clone()) + } LanguageField::Vesta => todo!(), LanguageField::BN256 => todo!(), LanguageField::Grumpkin => todo!(), @@ -364,11 +384,11 @@ impl ReplCli { impl LoadCli { fn run(&self) -> Result<()> { macro_rules! load { - ( $rc: expr, $limit: expr, $field: path, $backend: expr ) => {{ - let mut repl = new_repl!(self, $rc, $limit, $field, $backend); + ( $rc: expr, $limit: expr, $max_chunk_size: expr, $field: path, $backend: expr ) => {{ + let mut repl = new_repl!(self, $rc, $limit, $max_chunk_size, $field, $backend); repl.load_file(&self.lurk_file, self.demo)?; if self.prove { - repl.prove_last_frames()?; + repl.prove_last_computation()?; } Ok(()) }}; @@ -392,7 +412,8 @@ impl LoadCli { backend, field, rc, - limit + limit, + max_chunk_size ); // Initializes CLI config with CLI arguments as overrides @@ -402,12 +423,15 @@ impl LoadCli { let rc = config.rc; let limit = config.limit; + let max_chunk_size = config.max_chunk_size; let backend = &config.backend; let field = &config.field; validate_non_zero("rc", rc)?; backend.validate_field(field)?; match field { - LanguageField::Pallas => load!(rc, limit, pallas::Scalar, backend.clone()), + LanguageField::Pallas => { + load!(rc, limit, max_chunk_size, pallas::Scalar, backend.clone()) + } LanguageField::Vesta => todo!(), LanguageField::BN256 => todo!(), LanguageField::Grumpkin => todo!(), diff --git a/src/cli/repl/meta_cmd.rs b/src/cli/repl/meta_cmd.rs index 409c3c8baf..e9fa05dfe9 100644 --- a/src/cli/repl/meta_cmd.rs +++ b/src/cli/repl/meta_cmd.rs @@ -16,7 +16,6 @@ use crate::{ eval::lang::Coproc, field::LurkField, lem::{ - eval::evaluate_with_env_and_cont, multiframe::MultiFrame, pointers::{Ptr, ZPtr}, Tag, @@ -370,9 +369,9 @@ impl MetaCmd { ], run: |repl, args| { if !args.is_nil() { - repl.eval_expr_and_memoize(repl.peek1(args)?)?; + repl.evaluate_with_env_and_cont_then_memoize(repl.peek1(args)?, repl.env, repl.store.cont_outermost())?; } - repl.prove_last_frames()?; + repl.prove_last_computation()?; Ok(()) } }; @@ -605,13 +604,11 @@ impl MetaCmd { ], run: |repl: &mut Repl, args: &Ptr| { Self::call(repl, args)?; - let ev = repl - .get_evaluation() + let result = &repl + .cache .as_ref() - .expect("evaluation must have been set"); - let result = ev - .get_result() - .expect("evaluation result must have been set"); + .expect("evaluation result must have been set") + .1[0]; let (_, comm) = repl.store.car_cdr(result)?; let Ptr::Atom(Tag::Expr(ExprTag::Comm), hash) = comm else { bail!("Second component of a chain must be a commitment") @@ -1047,27 +1044,23 @@ impl MetaCmd { Self::post_verify_check(repl, post_verify)?; - let (frames, iterations) = evaluate_with_env_and_cont::>( - None, + let (output, _) = repl.evaluate_with_env_and_cont_then_memoize( cek_io[0], cek_io[1], Self::get_cont_ptr(repl, &cek_io[2])?, - &repl.store, - repl.limit, )?; { // making sure the output matches expectation before proving - let res = &frames.last().expect("frames can't be empty").output; - if cek_io[3] != res[0] - || cek_io[4] != res[1] - || Self::get_cont_ptr(repl, &cek_io[5])? != res[2] + if cek_io[3] != output[0] + || cek_io[4] != output[1] + || Self::get_cont_ptr(repl, &cek_io[5])? != output[2] { bail!("Mismatch between expected output and computed output") } } - let proof_key = repl.prove_frames(&frames, iterations)?; + let proof_key = repl.prove_last_computation()?; let mut z_dag = ZDag::default(); let z_ptr = z_dag.populate_with(&args, &repl.store, &mut Default::default()); let args = LurkData { z_ptr, z_dag }; diff --git a/src/cli/repl/mod.rs b/src/cli/repl/mod.rs index b64fbf187b..1964a03730 100644 --- a/src/cli/repl/mod.rs +++ b/src/cli/repl/mod.rs @@ -23,7 +23,10 @@ use crate::{ eval::lang::{Coproc, Lang}, field::LurkField, lem::{ - eval::{evaluate_simple_with_env, evaluate_with_env}, + eval::{ + evaluate_simple_with_env, evaluate_simple_with_input, evaluate_with_env_and_cont, + evaluate_with_input, + }, interpreter::Frame, multiframe::MultiFrame, pointers::Ptr, @@ -63,19 +66,6 @@ impl Validator for InputValidator { } } -#[allow(dead_code)] -struct Evaluation { - frames: Vec, - iterations: usize, -} - -impl Evaluation { - #[inline] - fn get_result(&self) -> Option<&Ptr> { - self.frames.last().and_then(|frame| frame.output.first()) - } -} - #[allow(dead_code)] pub(crate) struct Repl { store: Store, @@ -84,8 +74,9 @@ pub(crate) struct Repl { lang: Arc>>, rc: usize, limit: usize, + max_chunk_size: usize, backend: Backend, - evaluation: Option, + cache: Option<(Vec, Vec, usize)>, pwd_path: Utf8PathBuf, meta: HashMap<&'static str, MetaCmd>, apply_fn: OnceCell, @@ -107,10 +98,6 @@ fn pad(a: usize, m: usize) -> usize { } impl Repl { - fn get_evaluation(&self) -> &Option { - &self.evaluation - } - fn peek1(&self, args: &Ptr) -> Result { let (first, rest) = self.store.car_cdr(args)?; if !rest.is_nil() { @@ -152,8 +139,15 @@ impl Repl { type F = pasta_curves::pallas::Scalar; // TODO: generalize this impl Repl { - pub(crate) fn new(store: Store, rc: usize, limit: usize, backend: Backend) -> Repl { + pub(crate) fn new( + store: Store, + rc: usize, + limit: usize, + max_chunk_size: usize, + backend: Backend, + ) -> Repl { let limit = pad(limit, rc); + let max_chunk_size = pad(max_chunk_size, rc); info!( "Launching REPL with backend {backend}, field {}, rc {rc} and limit {limit}", F::FIELD @@ -169,8 +163,9 @@ impl Repl { lang: Arc::new(Lang::new()), rc, limit, + max_chunk_size, backend, - evaluation: None, + cache: None, pwd_path, meta: MetaCmd::cmds(), apply_fn: OnceCell::new(), @@ -258,95 +253,146 @@ impl Repl { format!("{backend}_{field}_{rc}_{claim_hash}") } - /// Proves a computation and returns the proof key - pub(crate) fn prove_frames(&self, frames: &[Frame], iterations: usize) -> Result { - match self.backend { - Backend::Nova => { - info!("Hydrating the store"); - self.store.hydrate_z_cache(); - - let n_frames = frames.len(); - - // saving to avoid clones - let input = &frames[0].input; - let output = &frames[n_frames - 1].output; - let mut z_dag = ZDag::::default(); - let mut cache = HashMap::default(); - let expr = z_dag.populate_with(&input[0], &self.store, &mut cache); - let env = z_dag.populate_with(&input[1], &self.store, &mut cache); - let cont = z_dag.populate_with(&input[2], &self.store, &mut cache); - let expr_out = z_dag.populate_with(&output[0], &self.store, &mut cache); - let env_out = z_dag.populate_with(&output[1], &self.store, &mut cache); - let cont_out = z_dag.populate_with(&output[2], &self.store, &mut cache); - - let claim = Self::proof_claim( - &self.store, - (input[0], output[0]), - (input[1], output[1]), - (cont.parts(), cont_out.parts()), - ); - - let claim_comm = Commitment::new(None, claim, &self.store); - let claim_hash = &claim_comm.hash.hex_digits(); - let proof_key = Self::proof_key(&self.backend, &self.rc, claim_hash); - - let lurk_proof_meta = LurkProofMeta { - iterations, - expr_io: (expr, expr_out), - env_io: Some((env, env_out)), - cont_io: (cont, cont_out), - z_dag, - }; + /// Proves the last cached computation and returns the proof key + pub(crate) fn prove_last_computation(&mut self) -> Result { + // releasing ownership of the cache, which might be lost + let mut cache = None; + std::mem::swap(&mut cache, &mut self.cache); + match cache { + None => bail!("No evaluation to prove"), + Some((frames, output, iterations)) => { + match self.backend { + Backend::Nova => { + info!("Hydrating the store"); + self.store.hydrate_z_cache(); + + // saving to avoid clones + let input = &frames[0].input; + let mut z_dag = ZDag::::default(); + let mut cache = HashMap::default(); + let expr = z_dag.populate_with(&input[0], &self.store, &mut cache); + let env = z_dag.populate_with(&input[1], &self.store, &mut cache); + let cont = z_dag.populate_with(&input[2], &self.store, &mut cache); + let expr_out = z_dag.populate_with(&output[0], &self.store, &mut cache); + let env_out = z_dag.populate_with(&output[1], &self.store, &mut cache); + let cont_out = z_dag.populate_with(&output[2], &self.store, &mut cache); + + let claim = Self::proof_claim( + &self.store, + (input[0], output[0]), + (input[1], output[1]), + (cont.parts(), cont_out.parts()), + ); + + let claim_comm = Commitment::new(None, claim, &self.store); + let claim_hash = &claim_comm.hash.hex_digits(); + let proof_key = Self::proof_key(&self.backend, &self.rc, claim_hash); + + let lurk_proof_meta = LurkProofMeta { + iterations, + expr_io: (expr, expr_out), + env_io: Some((env, env_out)), + cont_io: (cont, cont_out), + z_dag, + }; + + if LurkProof::<_, _, MultiFrame<'_, _, Coproc>>::is_cached(&proof_key) { + info!("Proof already cached"); + } else { + info!("Proof not cached. Loading public parameters"); + let instance = Instance::new( + self.rc, + self.lang.clone(), + true, + Kind::NovaPublicParams, + ); + let pp = public_params(&instance)?; + + let prover = NovaProver::<_, _, MultiFrame<'_, F, Coproc>>::new( + self.rc, + self.lang.clone(), + ); + + info!("Proving"); + let (mut proof, public_inputs, mut public_outputs, mut num_steps) = + prover.prove(&pp, &frames, &self.store, None)?; + let last_frame_output = &frames.last().unwrap().output; + if last_frame_output != &output { + // we need to further evaluate and prove, updating `proof`, + // `public_outputs` and `num_steps`, until we reach the + // known (precomputed) output + let mut input = last_frame_output.clone(); + let mut current_proof = Some(proof); + let mut fuel = self.limit - frames.len(); + loop { + assert!(fuel > 0); + // the maximum number of iterations allowed in this chunk + let partial_fuel = self.max_chunk_size.min(fuel); + let frames = evaluate_with_input::>( + None, + input, + &self.store, + partial_fuel, + )?; + // we could reset the store here as well, keeping only + // the commitments and poseidon cache... but let's not + // do it just yet and wait for a real need to manifest + let ( + partial_proof, + _, + partial_public_outputs, + partial_num_steps, + ) = prover.prove(&pp, &frames, &self.store, current_proof)?; + public_outputs = partial_public_outputs; + num_steps += partial_num_steps; + let partial_iterations = frames.len(); + let last_frame_output = &frames[partial_iterations - 1].output; + if last_frame_output == &output { + proof = partial_proof; + break; + } + input = last_frame_output.clone(); + current_proof = Some(partial_proof); + fuel -= partial_iterations; + } + } else { + // we can recover the cache here because a new + // chunk of frames wasn't necessary + self.cache = Some((frames, output, iterations)); + } - if LurkProof::<_, _, MultiFrame<'_, _, Coproc>>::is_cached(&proof_key) { - info!("Proof already cached"); - } else { - info!("Proof not cached. Loading public parameters"); - let instance = - Instance::new(self.rc, self.lang.clone(), true, Kind::NovaPublicParams); - let pp = public_params(&instance)?; - - let prover = NovaProver::<_, _, MultiFrame<'_, F, Coproc>>::new( - self.rc, - self.lang.clone(), - ); - - info!("Proving"); - let (proof, public_inputs, public_outputs, num_steps) = - prover.prove(&pp, frames, &self.store)?; - info!("Compressing proof"); - let proof = proof.compress(&pp)?; - assert_eq!(self.rc * num_steps, pad(n_frames, self.rc)); - assert!(proof.verify(&pp, &public_inputs, &public_outputs, num_steps)?); - - let lurk_proof = LurkProof::Nova { - proof, - public_inputs, - public_outputs, - num_steps, - rc: self.rc, - lang: (*self.lang).clone(), - }; - - lurk_proof.persist(&proof_key)?; + info!("Compressing proof"); + let proof = proof.compress(&pp)?; + assert_eq!(self.rc * num_steps, pad(iterations, self.rc)); + assert!(proof.verify( + &pp, + &public_inputs, + &public_outputs, + num_steps + )?); + + let lurk_proof = LurkProof::Nova { + proof, + public_inputs, + public_outputs, + num_steps, + rc: self.rc, + lang: (*self.lang).clone(), + }; + + lurk_proof.persist(&proof_key)?; + } + lurk_proof_meta.persist(&proof_key)?; + claim_comm.persist()?; + println!("Claim hash: 0x{claim_hash}"); + println!("Proof key: \"{proof_key}\""); + Ok(proof_key) + } } - lurk_proof_meta.persist(&proof_key)?; - claim_comm.persist()?; - println!("Claim hash: 0x{claim_hash}"); - println!("Proof key: \"{proof_key}\""); - Ok(proof_key) } } } - /// Proves the last cached computation and returns the proof key - pub(crate) fn prove_last_frames(&self) -> Result { - match self.evaluation.as_ref() { - None => bail!("No evaluation to prove"), - Some(Evaluation { frames, iterations }) => self.prove_frames(frames, *iterations), - } - } - fn hide(&mut self, secret: F, payload: Ptr) -> Result<()> { let commitment = Commitment::new(Some(secret), payload, &self.store); let hash_str = &commitment.hash.hex_digits(); @@ -408,6 +454,11 @@ impl Repl { self.eval_expr_with_env(expr, self.env) } + #[inline] + fn halted(cek: &[Ptr]) -> bool { + matches!(cek[2].tag(), Tag::Cont(ContTag::Terminal | ContTag::Error)) + } + fn eval_expr_allowing_error_continuation( &mut self, expr_ptr: Ptr, @@ -419,7 +470,7 @@ impl Repl { &self.store, self.limit, )?; - if matches!(ptrs[2].tag(), Tag::Cont(ContTag::Terminal | ContTag::Error)) { + if Self::halted(&ptrs) { Ok((ptrs, iterations, emitted)) } else { bail!( @@ -429,11 +480,37 @@ impl Repl { } } - fn eval_expr_and_memoize(&mut self, expr_ptr: Ptr) -> Result<(Vec, usize)> { - let (frames, iterations) = - evaluate_with_env::>(None, expr_ptr, self.env, &self.store, self.limit)?; - let output = frames[frames.len() - 1].output.clone(); - self.evaluation = Some(Evaluation { frames, iterations }); + fn evaluate_with_env_and_cont_then_memoize( + &mut self, + expr: Ptr, + env: Ptr, + cont: Ptr, + ) -> Result<(Vec, usize)> { + self.cache = None; + let (frames, output, iterations) = { + let fuel = self.max_chunk_size.min(self.limit); + let frames = evaluate_with_env_and_cont::>( + None, + expr, + env, + cont, + &self.store, + fuel, + )?; + let iterations = frames.len(); + let output = frames[iterations - 1].output.clone(); + if iterations == self.limit || Self::halted(&output) { + // we can't or need to go any further + (frames, output, iterations) + } else { + // max_chunk_size was reached... move on without accumulating frames + let fuel = self.limit - iterations; + let (output, iterations_non_cached, _) = + evaluate_simple_with_input::>(None, output, &self.store, fuel)?; + (frames, output, iterations + iterations_non_cached) + } + }; + self.cache = Some((frames, output.clone(), iterations)); Ok((output, iterations)) } @@ -450,8 +527,12 @@ impl Repl { Ok(self.store.expect_f(*hash_idx)) } - pub(crate) fn handle_non_meta(&mut self, expr_ptr: Ptr) -> Result<()> { - let (output, iterations) = self.eval_expr_and_memoize(expr_ptr)?; + pub(crate) fn handle_non_meta(&mut self, expr: Ptr) -> Result<()> { + let (output, iterations) = self.evaluate_with_env_and_cont_then_memoize( + expr, + self.env, + self.store.cont_outermost(), + )?; let iterations_display = Self::pretty_iterations_display(iterations); match output[2].tag() { Tag::Cont(ContTag::Terminal) => { diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 755bd4bea7..0fe90d3640 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -92,7 +92,7 @@ fn build_frames< limit: usize, lang: &Lang, log_fmt: LogFmt, -) -> Result<(Vec, usize)> { +) -> Result> { let mut pc = 0; let mut frames = vec![]; let mut iterations = 0; @@ -113,7 +113,7 @@ fn build_frames< } pc = get_pc(&expr, store, lang); } - Ok((frames, iterations)) + Ok(frames) } /// Faster version of `build_frames` that doesn't accumulate frames @@ -143,14 +143,12 @@ fn traverse_frames>( Ok((input, iterations, emitted)) } -pub fn evaluate_with_env_and_cont>( +pub fn evaluate_with_input>( func_lang: Option<(&Func, &Lang)>, - expr: Ptr, - env: Ptr, - cont: Ptr, + input: Vec, store: &Store, limit: usize, -) -> Result<(Vec, usize)> { +) -> Result> { let state = initial_lurk_state(); let log_fmt = |i: usize, inp: &[Ptr], emit: &[Ptr], store: &Store| { let mut out = format!( @@ -165,8 +163,6 @@ pub fn evaluate_with_env_and_cont>( out }; - let input = vec![expr, env, cont]; - match func_lang { None => { let lang: Lang = Lang::new(); @@ -179,6 +175,18 @@ pub fn evaluate_with_env_and_cont>( } } +#[inline] +pub fn evaluate_with_env_and_cont>( + func_lang: Option<(&Func, &Lang)>, + expr: Ptr, + env: Ptr, + cont: Ptr, + store: &Store, + limit: usize, +) -> Result> { + evaluate_with_input(func_lang, vec![expr, env, cont], store, limit) +} + #[inline] pub fn evaluate_with_env>( func_lang: Option<(&Func, &Lang)>, @@ -186,7 +194,7 @@ pub fn evaluate_with_env>( env: Ptr, store: &Store, limit: usize, -) -> Result<(Vec, usize)> { +) -> Result> { evaluate_with_env_and_cont(func_lang, expr, env, store.cont_outermost(), store, limit) } @@ -196,7 +204,7 @@ pub fn evaluate>( expr: Ptr, store: &Store, limit: usize, -) -> Result<(Vec, usize)> { +) -> Result> { evaluate_with_env_and_cont( func_lang, expr, @@ -207,14 +215,12 @@ pub fn evaluate>( ) } -pub fn evaluate_simple_with_env>( +pub fn evaluate_simple_with_input>( func_lang: Option<(&Func, &Lang)>, - expr: Ptr, - env: Ptr, + input: Vec, store: &Store, limit: usize, ) -> Result<(Vec, usize, Vec)> { - let input = vec![expr, env, store.cont_outermost()]; match func_lang { None => { let lang: Lang = Lang::new(); @@ -227,6 +233,29 @@ pub fn evaluate_simple_with_env>( } } +pub fn evaluate_simple_with_env_and_cont>( + func_lang: Option<(&Func, &Lang)>, + expr: Ptr, + env: Ptr, + cont: Ptr, + store: &Store, + limit: usize, +) -> Result<(Vec, usize, Vec)> { + let input = vec![expr, env, cont]; + evaluate_simple_with_input(func_lang, input, store, limit) +} + +#[inline] +pub fn evaluate_simple_with_env>( + func_lang: Option<(&Func, &Lang)>, + expr: Ptr, + env: Ptr, + store: &Store, + limit: usize, +) -> Result<(Vec, usize, Vec)> { + evaluate_simple_with_env_and_cont(func_lang, expr, env, store.cont_outermost(), store, limit) +} + #[inline] pub fn evaluate_simple>( func_lang: Option<(&Func, &Lang)>, diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs index 6663f88854..6cd6f01346 100644 --- a/src/lem/multiframe.rs +++ b/src/lem/multiframe.rs @@ -667,8 +667,8 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul limit: usize, ec: &EvalConfig<'_, F, C>, ) -> Result, ProofError> { - let cont = store.cont_outermost(); let lurk_step = make_eval_step_from_config(ec); + let cont = store.cont_outermost(); match evaluate_with_env_and_cont( Some((&lurk_step, ec.lang())), expr, @@ -677,7 +677,7 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul store, limit, ) { - Ok((frames, _)) => Ok(frames), + Ok(frames) => Ok(frames), Err(e) => Err(ProofError::Reduction(ReductionError::Misc(e.to_string()))), } } @@ -994,7 +994,7 @@ mod tests { let expr = store.read_with_default_state("(if t (+ 5 5) 6)").unwrap(); - let (frames, _) = evaluate::>(None, expr, &store, 10).unwrap(); + let frames = evaluate::>(None, expr, &store, 10).unwrap(); let sequential_slots_witnesses = generate_slots_witnesses(&store, &frames, num_slots_per_frame, false); @@ -1101,7 +1101,7 @@ mod tests { let expr = store.read_with_default_state("(+ 1 2)").unwrap(); let lang = Arc::new(Lang::>::new()); - let (mut frames, _) = evaluate::>(None, expr, &store, 1).unwrap(); + let mut frames = evaluate::>(None, expr, &store, 1).unwrap(); assert_eq!(frames.len(), 1); let mut frame = frames.pop().unwrap(); diff --git a/src/lem/tests/nivc_steps.rs b/src/lem/tests/nivc_steps.rs index 39eac8bd88..64b3e11ba4 100644 --- a/src/lem/tests/nivc_steps.rs +++ b/src/lem/tests/nivc_steps.rs @@ -31,7 +31,7 @@ fn test_nivc_steps() { // 9^2 + 8 = 89 let expr = store.read_with_default_state("(cproc-dumb 9 8)").unwrap(); - let (frames, _) = evaluate(Some((&lurk_step, &lang)), expr, &store, 10).unwrap(); + let frames = evaluate(Some((&lurk_step, &lang)), expr, &store, 10).unwrap(); // Iteration 1: evaluate first argument // Iteration 2: evaluate second argument diff --git a/src/proof/mod.rs b/src/proof/mod.rs index ff679ef1a2..39365f61f4 100644 --- a/src/proof/mod.rs +++ b/src/proof/mod.rs @@ -57,8 +57,10 @@ pub trait EvaluationStore { /// interpreting a string representation of an expression fn read(&self, expr: &str) -> Result; + /// getting a pointer to the initial, empty environment fn initial_empty_env(&self) -> Self::Ptr; + /// getting the terminal continuation pointer fn get_cont_terminal(&self) -> Self::ContPtr; @@ -180,16 +182,13 @@ pub trait RecursiveSNARKTrait< /// Associated type for public parameters type PublicParams; - /// Main output of `prove_recursively`, encoding the actual proof - type ProveOutput; - /// Extra input for `verify` to be defined as needed type ExtraVerifyInput; /// Type for error potentially thrown during verification type ErrorType; - /// Generate the recursive SNARK, encoded in `ProveOutput` + /// Generate the recursive SNARK fn prove_recursively( pp: &Self::PublicParams, z0: &[F], @@ -197,7 +196,8 @@ pub trait RecursiveSNARKTrait< store: &'a M::Store, reduction_count: usize, lang: Arc>, - ) -> Result; + initial_snark: Option, + ) -> Result; /// Compress a proof fn compress(self, pp: &Self::PublicParams) -> Result; @@ -257,18 +257,8 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a, M: MultiFram /// Associated type for public parameters type PublicParams; - /// Main output of `prove`, encoding the actual proof - type ProveOutput; - /// Assiciated proof type, which must implement `RecursiveSNARKTrait` - type RecursiveSnark: RecursiveSNARKTrait< - 'a, - F, - C, - M, - PublicParams = Self::PublicParams, - ProveOutput = Self::ProveOutput, - >; + type RecursiveSNARK: RecursiveSNARKTrait<'a, F, C, M, PublicParams = Self::PublicParams>; /// Returns a reference to the prover's FoldingMode fn folding_mode(&self) -> &FoldingMode; @@ -285,7 +275,8 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a, M: MultiFram pp: &Self::PublicParams, frames: &[M::EvalFrame], store: &'a M::Store, - ) -> Result<(Self::ProveOutput, Vec, Vec, usize), ProofError> { + initial_snark: Option, + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError> { store.hydrate_z_cache(); let z0 = M::io_to_scalar_vector(store, frames[0].input()); let zi = M::io_to_scalar_vector(store, frames.last().unwrap().output()); @@ -298,13 +289,14 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a, M: MultiFram let steps = M::from_frames(frames, store, folding_config.into()); let num_steps = steps.len(); - let prove_output = Self::RecursiveSnark::prove_recursively( + let prove_output = Self::RecursiveSNARK::prove_recursively( pp, &z0, steps, store, self.reduction_count(), lang, + initial_snark, )?; Ok((prove_output, z0, zi, num_steps)) @@ -318,10 +310,10 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor + 'a, M: MultiFram env: M::Ptr, store: &'a M::Store, limit: usize, - ) -> Result<(Self::ProveOutput, Vec, Vec, usize), ProofError> { + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError> { let eval_config = self.folding_mode().eval_config(self.lang()); let frames = M::build_frames(expr, env, store, limit, &eval_config)?; - self.prove(pp, &frames, store) + self.prove(pp, &frames, store, None) } /// Returns the expected total number of steps for the prover given raw iterations. diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 642db31792..38ca7c7fe7 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -1,5 +1,3 @@ -#![allow(non_snake_case)] - use abomonation::Abomonation; use bellpepper_core::{num::AllocatedNum, ConstraintSystem}; use ff::PrimeField; @@ -235,8 +233,6 @@ where { type PublicParams = PublicParams; - type ProveOutput = Self; - /// The number of steps type ExtraVerifyInput = usize; @@ -250,6 +246,7 @@ where store: &'a ::Store, reduction_count: usize, lang: Arc>, + initial_snark: Option, ) -> Result { assert!(!steps.is_empty()); assert_eq!(steps[0].arity(), z0.len()); @@ -263,8 +260,11 @@ where tracing::debug!("steps.len: {}", steps.len()); - // produce a recursive SNARK - let mut recursive_snark: Option, E2, M, C2>> = None; + // produce a recursive SNARK, starting from the initial one + let mut recursive_snark = initial_snark.and_then(|p| match p { + Self::Recursive(p, _) => Some(*p), + Self::Compressed(..) => None, + }); // the shadowing here is voluntary let recursive_snark = if lurk_config(None, None) @@ -428,8 +428,7 @@ where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { type PublicParams = PublicParams; - type ProveOutput = Proof<'a, F, C, M>; - type RecursiveSnark = Proof<'a, F, C, M>; + type RecursiveSNARK = Proof<'a, F, C, M>; #[inline] fn reduction_count(&self) -> usize { diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs index 462adc5516..b1588e6bb5 100644 --- a/src/proof/supernova.rs +++ b/src/proof/supernova.rs @@ -1,5 +1,3 @@ -#![allow(non_snake_case)] - use abomonation::Abomonation; use ff::PrimeField; use nova::{ @@ -192,11 +190,7 @@ where { type PublicParams = PublicParams; - /// Proving with SuperNova outputs the proof and the index of the last circuit - type ProveOutput = (Self, usize); - - /// The index of the last circuit - type ExtraVerifyInput = usize; + type ExtraVerifyInput = PhantomData; type ErrorType = SuperNovaError; @@ -208,18 +202,20 @@ where _store: &'a ::Store, _reduction_count: usize, _lang: Arc>, - ) -> Result<(Self, usize), ProofError> { - let mut recursive_snark_option: Option, E2>> = None; + initial_snark: Option, + ) -> Result { + let mut recursive_snark_option = initial_snark.and_then(|p| match p { + Self::Recursive(p) => Some(*p), + Self::Compressed(..) => None, + }); let z0_primary = z0; let z0_secondary = Self::z0_secondary(); - let mut last_circuit_index = 0; - for (i, step) in steps.iter().enumerate() { info!("prove_recursively, step {i}"); - let mut recursive_snark = recursive_snark_option.clone().unwrap_or_else(|| { + let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { info!("RecursiveSnark::new {i}"); RecursiveSNARK::new( &pp.pp, @@ -239,17 +235,12 @@ where .unwrap(); recursive_snark_option = Some(recursive_snark); - - last_circuit_index = step.circuit_index(); } // This probably should be made unnecessary. - Ok(( - Self::Recursive(Box::new( - recursive_snark_option.expect("RecursiveSNARK missing"), - )), - last_circuit_index, - )) + Ok(Self::Recursive(Box::new( + recursive_snark_option.expect("RecursiveSNARK missing"), + ))) } fn compress(self, pp: &PublicParams) -> Result { @@ -271,7 +262,7 @@ where pp: &Self::PublicParams, z0: &[F], zi: &[F], - _last_circuit_idx: usize, + _phantom: PhantomData, ) -> Result { let (z0_primary, zi_primary) = (z0, zi); let z0_secondary = Self::z0_secondary(); @@ -297,8 +288,7 @@ where < as Engine>::Scalar as ff::PrimeField>::Repr: Abomonation, { type PublicParams = PublicParams; - type ProveOutput = (Proof<'a, F, C, M>, usize); - type RecursiveSnark = Proof<'a, F, C, M>; + type RecursiveSNARK = Proof<'a, F, C, M>; #[inline] fn reduction_count(&self) -> usize { diff --git a/src/proof/tests/mod.rs b/src/proof/tests/mod.rs index d14abbb005..7cc80b32af 100644 --- a/src/proof/tests/mod.rs +++ b/src/proof/tests/mod.rs @@ -131,14 +131,14 @@ where { let limit = limit.unwrap_or(10000); - let e = s.initial_empty_env(); + let env = s.initial_empty_env(); - let frames = M::build_frames(expr, e, s, limit, &EvalConfig::new_ivc(&lang)).unwrap(); + let frames = M::build_frames(expr, env, s, limit, &EvalConfig::new_ivc(&lang)).unwrap(); let nova_prover = NovaProver::<'a, F, C, M>::new(reduction_count, lang.clone()); if check_nova { let pp = public_params::<_, _, M>(reduction_count, lang.clone()); - let (proof, z0, zi, num_steps) = nova_prover.prove(&pp, &frames, s).unwrap(); + let (proof, z0, zi, num_steps) = nova_prover.prove(&pp, &frames, s, None).unwrap(); let res = proof.verify(&pp, &z0, &zi, num_steps); if res.is_err() {