diff --git a/Cargo.lock b/Cargo.lock index 9712deb2da..d86b4cd5e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -161,6 +161,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "ascii_table" +version = "4.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75054ce561491263d7b80dc2f6f6c6f8cdfd0c7a7c17c5cf3b8117829fa72ae1" + [[package]] name = "assert_cmd" version = "2.0.12" @@ -725,6 +731,20 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + [[package]] name = "crossbeam-channel" version = "0.5.8" @@ -759,6 +779,16 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.16" @@ -1496,6 +1526,7 @@ dependencies = [ "ahash 0.7.6", "anyhow", "anymap", + "ascii_table", "assert_cmd", "base-x", "base32ct", @@ -1507,6 +1538,7 @@ dependencies = [ "clap 4.3.17", "config", "criterion", + "crossbeam", "dashmap", "ff", "generic-array", @@ -1527,6 +1559,7 @@ dependencies = [ "num-bigint 0.4.3", "num-integer", "num-traits", + "num_cpus", "once_cell", "pairing", "pasta-msm", @@ -1552,6 +1585,7 @@ dependencies = [ "tap", "tempfile", "thiserror", + "vergen", ] [[package]] @@ -1800,6 +1834,15 @@ dependencies = [ "libc", ] +[[package]] +name = "num_threads" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" +dependencies = [ + "libc", +] + [[package]] name = "object" version = "0.31.1" @@ -2819,6 +2862,35 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "time" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +dependencies = [ + "itoa", + "libc", + "num_threads", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" + +[[package]] +name = "time-macros" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +dependencies = [ + "time-core", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -2903,6 +2975,17 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d023da39d1fde5a8a3fe1f3e01ca9632ada0a63e9797de55a879d6e2236277be" +[[package]] +name = "vergen" +version = "8.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce38fc503fa57441ac2539c3e723b5adf76601eb4f1ad24025c6660d27f355b7" +dependencies = [ + "anyhow", + "rustversion", + "time", +] + [[package]] name = "version_check" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index 31e64dab2c..da3a1a152b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ nom = "7.1.3" nom_locate = "4.1.0" nova = { workspace = true, default-features = false } num-bigint = "0.4.3" +num_cpus = "1.10.1" num-integer = "0.1.45" num-traits = "0.2.15" once_cell = { workspace = true } @@ -57,6 +58,7 @@ stable_deref_trait = "1.2.0" thiserror = { workspace = true } abomonation = "0.7.3" abomonation_derive = { git = "https://github.com/winston-h-zhang/abomonation_derive.git" } +crossbeam = "0.8.2" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] memmap = { version = "0.5.10", package = "memmap2" } @@ -74,7 +76,7 @@ rustyline = { version = "11.0", features = ["derive"], default-features = false [features] default = [] opencl = ["neptune/opencl"] -cuda = ["neptune/cuda"] +cuda = ["neptune/cuda", "nova/cuda"] # compile without ISA extensions portable = ["blstrs/portable", "pasta-msm/portable"] flamegraph = ["pprof/flamegraph", "pprof/criterion"] @@ -82,6 +84,7 @@ flamegraph = ["pprof/flamegraph", "pprof/criterion"] [dev-dependencies] assert_cmd = "2.0.12" cfg-if = "1.0.0" +ascii_table = "4.0.2" criterion = "0.4" hex = "0.4.3" pprof = { version = "0.11" } @@ -90,6 +93,9 @@ structopt = { version = "0.3", default-features = false } tap = "1.0.1" tempfile = "3.6.0" +[build-dependencies] +vergen = { version = "8", features = ["build", "git", "gitcl"] } + [workspace] resolver = "2" members = [ diff --git a/benches/fibonacci.rs b/benches/fibonacci.rs index bc2130e9ff..65aeef04f6 100644 --- a/benches/fibonacci.rs +++ b/benches/fibonacci.rs @@ -11,7 +11,6 @@ use lurk::{ eval::{ empty_sym_env, lang::{Coproc, Lang}, - Evaluator, }, field::LurkField, proof::nova::NovaProver, @@ -21,104 +20,73 @@ use lurk::{ store::Store, }; -const DEFAULT_REDUCTION_COUNT: usize = 100; -fn fib(store: &mut Store, a: u64) -> Ptr { - let program = format!( - r#" -(let ((fib (lambda (target) - (letrec ((next (lambda (a b target) - (if (= 0 target) - a - (next b - (+ a b) - (- target 1)))))) - (next 0 1 target))))) - (fib {a})) -"# - ); +fn fib(store: &mut Store, _a: u64) -> Ptr { + let program = r#" +(letrec ((next (lambda (a b) (next b (+ a b)))) + (fib (next 0 1))) + (fib)) +"#; - store.read(&program).unwrap() + store.read(program).unwrap() } -#[allow(dead_code)] -fn fibo_total(name: &str, iterations: u64, c: &mut BenchmarkGroup) { - let limit: usize = 10_000_000_000; - let lang_pallas = Lang::>::new(); - let lang_rc = Arc::new(lang_pallas.clone()); - let reduction_count = DEFAULT_REDUCTION_COUNT; - - // use cached public params - let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap(); - - c.bench_with_input( - BenchmarkId::new(name.to_string(), iterations), - &(iterations), - |b, iterations| { - let mut store = Store::default(); - let env = empty_sym_env(&store); - let ptr = fib::(&mut store, black_box(*iterations)); - let prover = NovaProver::new(reduction_count, lang_pallas.clone()); +// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation will contain a binding of the +// nth Fibonacci number to `a`. +// means of computing it.] +fn fib_frame(n: usize) -> usize { + 11 + 16 * n +} - b.iter_batched( - || lang_rc.clone(), - |lang_rc| { - let result = prover - .evaluate_and_prove(&pp, ptr, env, &mut store, limit, lang_rc) - .unwrap(); - black_box(result); - }, - BatchSize::SmallInput, - ) - }, - ); +// Set the limit so the last step will be filled exactly, since Lurk currently only pads terminal/error continuations. +fn fib_limit(n: usize, rc: usize) -> usize { + let frame = fib_frame(n); + rc * (frame / rc + (frame % rc != 0) as usize) } -#[allow(dead_code)] -fn fibo_eval(name: &str, iterations: u64, c: &mut BenchmarkGroup) { - let limit = 10_000_000_000; - let lang_pallas = Lang::>::new(); +struct ProveParams { + fib_n: usize, + reduction_count: usize, +} - c.bench_with_input( - BenchmarkId::new(name.to_string(), iterations), - &(iterations), - |b, iterations| { - let mut store = Store::default(); - let ptr = fib::(&mut store, black_box(*iterations)); - b.iter(|| { - let result = - Evaluator::new(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas) - .eval(); - black_box(result) - }); - }, - ); +impl ProveParams { + fn name(&self) -> String { + let date = env!("VERGEN_GIT_COMMIT_DATE"); + let sha = env!("VERGEN_GIT_SHA"); + format!("{date}:{sha}:Fibonacci-rc={}", self.reduction_count) + } } -fn fibo_prove(name: &str, iterations: u64, c: &mut BenchmarkGroup) { - let limit = 10_000_000_000; +fn fibo_prove(prove_params: ProveParams, c: &mut BenchmarkGroup) { + let ProveParams { + fib_n, + reduction_count, + } = prove_params; + + let limit = fib_limit(fib_n, reduction_count); let lang_pallas = Lang::>::new(); let lang_rc = Arc::new(lang_pallas.clone()); - let reduction_count = DEFAULT_REDUCTION_COUNT; - let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap(); + + let pp = public_params(prove_params.reduction_count, true, lang_rc.clone()).unwrap(); c.bench_with_input( - BenchmarkId::new(name.to_string(), iterations), - &iterations, - |b, iterations| { + BenchmarkId::new(prove_params.name(), fib_n), + &prove_params, + |b, _prove_params| { let mut store = Store::default(); + let env = empty_sym_env(&store); - let ptr = fib::(&mut store, black_box(*iterations)); + let ptr = fib::(&mut store, black_box(fib_n as u64)); let prover = NovaProver::new(reduction_count, lang_pallas.clone()); - let frames = prover + let frames = &prover .get_evaluation_frames(ptr, env, &mut store, limit, &lang_pallas) .unwrap(); b.iter_batched( - || (frames.clone(), lang_rc.clone()), // avoid cloning the frames in the benchmark + || (frames, lang_rc.clone()), |(frames, lang_rc)| { - let result = prover.prove(&pp, &frames, &mut store, lang_rc).unwrap(); - black_box(result); + let result = prover.prove(&pp, &frames, &mut store, lang_rc); + let _ = black_box(result); }, BatchSize::LargeInput, ) @@ -126,45 +94,27 @@ fn fibo_prove(name: &str, iterations: u64, c: &mut ); } -#[allow(dead_code)] -fn fibonacci_eval(c: &mut Criterion) { - static BATCH_SIZES: [u64; 2] = [100, 1000]; - let mut group: BenchmarkGroup<_> = c.benchmark_group("Evaluate"); - for size in BATCH_SIZES.iter() { - fibo_eval("Fibonacci", *size, &mut group); - } -} - fn fibonacci_prove(c: &mut Criterion) { - static BATCH_SIZES: [u64; 2] = [100, 1000]; + let _ = dbg!(&*lurk::config::CONFIG); + let reduction_counts = vec![100, 600, 700, 800, 900]; + let batch_sizes = vec![100, 200]; let mut group: BenchmarkGroup<_> = c.benchmark_group("Prove"); group.sampling_mode(SamplingMode::Flat); // This can take a *while* group.sample_size(10); - for size in BATCH_SIZES.iter() { - fibo_prove("Fibonacci", *size, &mut group); - } -} - -#[allow(dead_code)] -fn fibonacci_total(c: &mut Criterion) { - static BATCH_SIZES: [u64; 2] = [100, 1000]; - let mut group: BenchmarkGroup<_> = c.benchmark_group("Total"); - group.sampling_mode(SamplingMode::Flat); // This can take a *while* - group.sample_size(10); - - for size in BATCH_SIZES.iter() { - fibo_total("Fibonacci", *size, &mut group); + for fib_n in batch_sizes.iter() { + for reduction_count in reduction_counts.iter() { + let prove_params = ProveParams { + fib_n: *fib_n, + reduction_count: *reduction_count, + }; + fibo_prove(prove_params, &mut group); + } } } cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { - // In order to collect a flamegraph, you need to indicate a profile time, see - // https://github.com/tikv/pprof-rs#integrate-with-criterion - // Example usage : - // cargo criterion --bench fibonacci --features flamegraph -- --profile-time 5 - // Warning: it is not recommended to run this on an M1 Mac, as making pprof work well there is hard. criterion_group! { name = benches; config = Criterion::default() @@ -172,8 +122,8 @@ cfg_if::cfg_if! { .sample_size(10) .with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); targets = - fibonacci_prove, - } + fibonacci_prove, + } } else { criterion_group! { name = benches; @@ -181,8 +131,8 @@ cfg_if::cfg_if! { .measurement_time(Duration::from_secs(120)) .sample_size(10); targets = - fibonacci_prove, - } + fibonacci_prove, + } } } diff --git a/benches/synthesis.rs b/benches/synthesis.rs index 0332d1eec5..facb419ac5 100644 --- a/benches/synthesis.rs +++ b/benches/synthesis.rs @@ -1,6 +1,7 @@ use std::{sync::Arc, time::Duration}; -use bellperson::{util_cs::test_cs::TestConstraintSystem, Circuit}; +use bellperson::util_cs::witness_cs::WitnessCS; +use bellperson::{Circuit, ConstraintSystem}; use criterion::{ black_box, criterion_group, criterion_main, measurement, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, SamplingMode, @@ -61,12 +62,13 @@ fn synthesize( .unwrap(); let multiframe = - MultiFrame::from_frames(*reduction_count, &frames, &store, &lang_rc)[0].clone(); + MultiFrame::from_frames(*reduction_count, &frames, &store, lang_rc.clone())[0] + .clone(); b.iter_batched( || (multiframe.clone()), // avoid cloning the frames in the benchmark |multiframe| { - let mut cs = TestConstraintSystem::new(); + let mut cs = WitnessCS::new(); let result = multiframe.synthesize(&mut cs); let _ = black_box(result); }, diff --git a/build.rs b/build.rs new file mode 100644 index 0000000000..35aa44a6d8 --- /dev/null +++ b/build.rs @@ -0,0 +1,8 @@ +use std::error::Error; +use vergen::EmitBuilder; + +fn main() -> Result<(), Box> { + // Emit the instructions + EmitBuilder::builder().all_git().emit()?; + Ok(()) +} diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs new file mode 100644 index 0000000000..5af21a9c55 --- /dev/null +++ b/examples/fibonacci.rs @@ -0,0 +1,72 @@ +use lurk::eval::{ + empty_sym_env, + lang::{Coproc, Lang}, + Evaluator, +}; +use lurk::field::LurkField; +use lurk::ptr::Ptr; +use lurk::store::Store; +use lurk::writer::Write; +use pasta_curves::pallas::Scalar; + +fn fib_expr(store: &mut Store) -> Ptr { + let program = r#" +(letrec ((next (lambda (a b) (next b (+ a b)))) + (fib (next 0 1))) + (fib)) +"#; + + store.read(program).unwrap() +} + +// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation contains a binding of the +// nth Fibonacci number to `a`. +fn fib_frame(n: usize) -> usize { + 11 + 16 * n +} + +// Set the limit so the last step will be filled exactly, since Lurk currently only pads terminal/error continuations. +#[allow(dead_code)] +fn fib_limit(n: usize, rc: usize) -> usize { + let frame = fib_frame(n); + rc * (frame / rc + (frame % rc != 0) as usize) +} + +fn lurk_fib(store: &mut Store, n: usize, _rc: usize) -> Ptr { + let lang = Lang::>::new(); + let frame_idx = fib_frame(n); + // let limit = fib_limit(n, rc); + let limit = frame_idx; + let fib_expr = fib_expr(store); + + let frames = Evaluator::new(fib_expr, empty_sym_env(store), store, limit, &lang) + .get_frames() + .unwrap(); + + let target_frame = frames.last().unwrap(); + + let target_env = target_frame.output.env; + + // The result is the value of the second binding (of `A`), in the target env. + // See relevant excerpt of execution trace below: + // + // INFO lurk::eval > Frame: 11 + // Expr: (NEXT B (+ A B)) + // Env: ((B . 1) (A . 0) ((NEXT . ))) + // Cont: Tail{ saved_env: (((NEXT . ))), continuation: LetRec{var: FIB, + // saved_env: (((NEXT . ))), body: (FIB), continuation: Tail{ saved_env: + // NIL, continuation: Outermost } } } + + let rest_bindings = store.cdr(&target_env).unwrap(); + let second_binding = store.car(&rest_bindings).unwrap(); + store.cdr(&second_binding).unwrap() +} + +fn main() { + let store = &mut Store::::new(); + let n: usize = std::env::args().collect::>()[1].parse().unwrap(); + + let fib = lurk_fib(store, n, 100); + + println!("Fib({n}) = {}", fib.fmt_to_string(store)); +} diff --git a/examples/itcalc.rs b/examples/itcalc.rs new file mode 100644 index 0000000000..438cfd515f --- /dev/null +++ b/examples/itcalc.rs @@ -0,0 +1,86 @@ +use ascii_table::AsciiTable; + +#[derive(Debug, Clone, Copy)] +struct Prog { + setup_iterations: usize, + loop_iterations: usize, +} + +fn real_iterations(prog: Prog, n: usize) -> usize { + prog.setup_iterations + prog.loop_iterations * n +} + +fn ceiling(n: usize, m: usize) -> usize { + n / m + (n % m != 0) as usize +} + +enum Opt { + Some(T), + None, + Empty, +} + +impl core::fmt::Display for Opt { + fn fmt(&self, fmt: &mut core::fmt::Formatter) -> Result<(), std::fmt::Error> { + match self { + Opt::None => "-".fmt(fmt), + Opt::Some(x) => x.fmt(fmt), + Opt::Empty => "".fmt(fmt), + } + } +} + +fn total_iterations(real_iterations: usize, rc: usize) -> Opt { + let steps = ceiling(real_iterations, rc); + let total_iterations = steps * rc; + + if real_iterations < rc { + Opt::None + } else { + Opt::Some(total_iterations) + } +} +fn rc_total_iterations(prog: Prog, n: usize, rc: usize) -> Opt { + let real_iterations = real_iterations(prog, n); + total_iterations(real_iterations, rc) +} + +fn analyze_rcs(prog: Prog, n: usize, rcs: &[usize]) -> Vec> { + let mut analysis = Vec::with_capacity(rcs.len() + 2); + analysis.push(Opt::Some(n)); + analysis.push(Opt::Empty); + analysis.extend(rcs.iter().map(|rc| rc_total_iterations(prog, n, *rc))); + analysis +} + +fn analyze_ncs_rcs(prog: Prog, ns: &[usize], rcs: &[usize]) -> Vec>> { + ns.iter().map(|n| analyze_rcs(prog, *n, rcs)).collect() +} + +/// Produces a table of 'real Lurk iterations' proved per loop-iteration/rc combination. +/// If the program has fewer real iterations than rc, no value is produced. +/// Otherwise, the number of total iterations (including padding) is used. +fn main() { + let args = std::env::args().collect::>(); + + let setup_iterations: usize = args[1].parse().unwrap(); + let loop_iterations: usize = args[2].parse().unwrap(); + let ns = [10, 20, 30, 40, 50, 60, 100, 200]; + let rcs = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]; + + let prog = Prog { + setup_iterations, + loop_iterations, + }; + let analysis = analyze_ncs_rcs(prog, &ns, &rcs); + let mut table = AsciiTable::default(); + + table.column(0).set_header("n"); + table.column(1).set_header("rc"); + for (i, rc) in rcs.into_iter().enumerate() { + table.column(i + 2).set_header(rc.to_string()); + } + + println!("\nSetup iterations: {setup_iterations}; Iterations per loop: {loop_iterations}."); + table.print(analysis); +} diff --git a/examples/sha256.rs b/examples/sha256.rs index b4a3539a21..689b046fd0 100644 --- a/examples/sha256.rs +++ b/examples/sha256.rs @@ -184,14 +184,14 @@ fn main() { let coproc_expr = format!("({})", sym_str); let ptr = store.read(&coproc_expr).unwrap(); - let nova_prover = NovaProver::>::new(REDUCTION_COUNT, lang.clone()); + let nova_prover = NovaProver::>::new(REDUCTION_COUNT, lang); println!("Setting up public parameters (rc = {REDUCTION_COUNT})..."); let pp_start = Instant::now(); // see the documentation on `with_public_params` - let _res = with_public_params(REDUCTION_COUNT, lang_rc.clone(), |pp| { + with_public_params(REDUCTION_COUNT, lang_rc.clone(), |pp| { let pp_end = pp_start.elapsed(); println!("Public parameters took {:?}", pp_end); @@ -211,7 +211,7 @@ fn main() { println!("Verifying proof..."); let verify_start = Instant::now(); - let res = proof.verify(&pp, num_steps, &z0, &zi).unwrap(); + let res = proof.verify(pp, num_steps, &z0, &zi).unwrap(); let verify_end = verify_start.elapsed(); println!("Verify took {:?}", verify_end); diff --git a/src/circuit/circuit_frame.rs b/src/circuit/circuit_frame.rs index f7754b1474..7500a49a4b 100644 --- a/src/circuit/circuit_frame.rs +++ b/src/circuit/circuit_frame.rs @@ -1,11 +1,13 @@ +use std::collections::HashMap; use std::fmt::Debug; use std::marker::PhantomData; use bellperson::{ gadgets::{boolean::Boolean, num::AllocatedNum}, - util_cs::Comparable, + util_cs::{witness_cs::WitnessCS, Comparable}, Circuit, ConstraintSystem, SynthesisError, }; +use rayon::prelude::*; use crate::{ circuit::gadgets::{ @@ -13,8 +15,12 @@ use crate::{ data::GlobalAllocations, pointer::{AllocatedContPtr, AllocatedPtr, AsAllocatedHashComponents}, }, + config::CONFIG, field::LurkField, - hash_witness::{ConsName, ContName}, + hash::HashConst, + hash_witness::{ + ConsCircuitWitness, ConsName, ContCircuitWitness, ContName, HashCircuitWitnessCache, + }, store::NamedConstants, tag::Tag, }; @@ -51,12 +57,13 @@ pub struct CircuitFrame<'a, F: LurkField, T, W, C: Coprocessor> { } #[derive(Clone)] -pub struct MultiFrame<'a, F: LurkField, T: Copy, W, C: Coprocessor> { +pub struct MultiFrame<'a, F: LurkField, T: Copy + Sync, W: Sync, C: Coprocessor> { pub store: Option<&'a Store>, pub lang: Option>>, pub input: Option, pub output: Option, pub frames: Option>>, + pub cached_witness: Option>, pub count: usize, } @@ -82,8 +89,14 @@ impl<'a, F: LurkField, T: Clone + Copy, W: Copy, C: Coprocessor> CircuitFrame } } -impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coprocessor> - MultiFrame<'a, F, T, W, C> +impl< + 'a, + F: LurkField, + // T: Clone + Copy + std::cmp::PartialEq + Sync, + //W: Copy + Sync, + C: Coprocessor, + // > MultiFrame<'a, F, T, W, C> + > MultiFrame<'a, F, IO, Witness, C> { pub fn blank(count: usize, lang: Arc>) -> Self { Self { @@ -92,6 +105,7 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc input: None, output: None, frames: None, + cached_witness: None, count, } } @@ -102,9 +116,9 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc pub fn from_frames( count: usize, - frames: &[Frame], + frames: &[Frame, Witness, C>], store: &'a Store, - lang: &Arc>, + lang: Arc>, ) -> Vec { // `count` is the number of `Frames` to include per `MultiFrame`. let total_frames = frames.len(); @@ -139,6 +153,7 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc input: Some(chunk[0].input), output: Some(output), frames: Some(inner_frames), + cached_witness: None, count, }; @@ -151,7 +166,7 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc /// Make a dummy `MultiFrame`, duplicating `self`'s final `CircuitFrame`. pub(crate) fn make_dummy( count: usize, - circuit_frame: Option>, + circuit_frame: Option, Witness, C>>, store: &'a Store, lang: Arc>, ) -> Self { @@ -170,6 +185,7 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc input, output, frames, + cached_witness: None, count, } } @@ -184,6 +200,28 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc frames: &[CircuitFrame<'_, F, IO, Witness, C>], g: &GlobalAllocations, ) -> (AllocatedPtr, AllocatedPtr, AllocatedContPtr) { + if cs.is_witness_generator() && CONFIG.parallelism.synthesis.is_parallel() { + self.synthesize_frames_parallel(cs, store, input_expr, input_env, input_cont, frames, g) + } else { + self.synthesize_frames_sequential( + cs, store, input_expr, input_env, input_cont, frames, None, g, + ) + } + } + + pub fn synthesize_frames_sequential>( + &self, + cs: &mut CS, + store: &Store, + input_expr: AllocatedPtr, + input_env: AllocatedPtr, + input_cont: AllocatedContPtr, + frames: &[CircuitFrame<'_, F, IO, Witness, C>], + cons_and_cont_witnesses: Option, ContCircuitWitness)>>, + g: &GlobalAllocations, + ) -> (AllocatedPtr, AllocatedPtr, AllocatedContPtr) { + let mut hash_circuit_witness_cache = HashMap::new(); + let acc = (input_expr, input_env, input_cont); let (_, (new_expr, new_env, new_cont)) = @@ -221,22 +259,147 @@ impl<'a, F: LurkField, T: Clone + Copy + std::cmp::PartialEq, W: Copy, C: Coproc "cont mismatch" ); }; - ( - i + 1, - frame - .synthesize( - cs, - i, - allocated_io, - self.lang.as_ref().expect("Lang missing"), - g, + let (cons_witnesses, cont_witnesses) = + if let Some(cons_and_cont_witnesses) = &cons_and_cont_witnesses { + ( + Some(cons_and_cont_witnesses[i].0.clone()), + Some(cons_and_cont_witnesses[i].1.clone()), ) - .unwrap(), - ) + } else { + (None, None) + }; + + let new_allocated_io = frame + .synthesize( + cs, + i, + allocated_io, + self.lang.as_ref().expect("Lang missing"), + g, + &mut hash_circuit_witness_cache, + cons_witnesses, + cont_witnesses, + ) + .unwrap(); + + (i + 1, new_allocated_io) }); (new_expr, new_env, new_cont) } + + pub fn synthesize_frames_parallel>( + &self, + cs: &mut CS, + store: &Store, + input_expr: AllocatedPtr, + input_env: AllocatedPtr, + input_cont: AllocatedContPtr, + frames: &[CircuitFrame<'_, F, IO, Witness, C>], + g: &GlobalAllocations, + ) -> (AllocatedPtr, AllocatedPtr, AllocatedContPtr) { + assert!(cs.is_witness_generator()); + assert!(CONFIG.parallelism.synthesis.is_parallel()); + + // TODO: this probably belongs in config, perhaps per-Flow. + const MIN_CHUNK_SIZE: usize = 10; + + let num_frames = frames.len(); + + let chunk_size = CONFIG + .parallelism + .synthesis + .chunk_size(num_frames, MIN_CHUNK_SIZE); + + let css = frames + .par_chunks(chunk_size) + .enumerate() + .map(|(i, chunk)| { + let (input_expr, input_env, input_cont) = if i == 0 { + (input_expr.clone(), input_env.clone(), input_cont.clone()) + } else { + let previous_frame = &frames[i * chunk_size]; + let mut bogus_cs = WitnessCS::new(); + let x = previous_frame.input.unwrap().expr; + let input_expr = + AllocatedPtr::alloc_ptr(&mut bogus_cs, store, || Ok(&x)).unwrap(); + let y = previous_frame.input.unwrap().env; + let input_env = + AllocatedPtr::alloc_ptr(&mut bogus_cs, store, || Ok(&y)).unwrap(); + let z = previous_frame.input.unwrap().cont; + let input_cont = + AllocatedContPtr::alloc_cont_ptr(&mut bogus_cs, store, || Ok(&z)).unwrap(); + (input_expr, input_env, input_cont) + }; + + let cons_and_cont_witnesses = { + macro_rules! f { + () => { + |frame| { + let cons_circuit_witness: ConsCircuitWitness = frame + .witness + .map(|x| x.conses) + .unwrap_or_else(|| HashWitness::new_blank()) + .into(); + + let cons_constants: HashConst<'_, F> = + store.poseidon_constants().constants(4.into()); + + // Force generating the witness. This is the important part! + cons_circuit_witness.circuit_witness_blocks(store, cons_constants); + + let cont_circuit_witness: ContCircuitWitness = frame + .witness + .map(|x| x.conts) + .unwrap_or_else(|| HashWitness::new_blank()) + .into(); + + let cont_constants: HashConst<'_, F> = + store.poseidon_constants().constants(8.into()); + + // Force generating the witness. This is the important part! + cont_circuit_witness.circuit_witness_blocks(store, cont_constants); + + (cons_circuit_witness, cont_circuit_witness) + } + }; + } + + if CONFIG.parallelism.poseidon_witnesses.is_parallel() { + chunk.par_iter().map(f!()).collect::>() + } else { + chunk.iter().map(f!()).collect::>() + } + }; + + let mut cs = WitnessCS::new(); + + let output = self.synthesize_frames_sequential( + &mut cs, + store, + input_expr, + input_env, + input_cont, + chunk, + Some(cons_and_cont_witnesses), + g, + ); + + (cs, output) + }) + .collect::>(); + + let mut final_output = None; + + for (frames_cs, output) in css.into_iter() { + final_output = Some(output); + + let aux = frames_cs.aux_slice(); + cs.extend_aux(aux); + } + + final_output.unwrap() + } } impl> CircuitFrame<'_, F, T, W, C> { @@ -245,13 +408,19 @@ impl> CircuitFrame<'_, } } -impl> MultiFrame<'_, F, T, W, C> { +impl> + MultiFrame<'_, F, T, W, C> +{ pub fn precedes(&self, maybe_next: &Self) -> bool { self.output == maybe_next.input } } -impl> Provable for MultiFrame<'_, F, IO, W, C> { +impl< + F: LurkField, // W: Copy + Sync, + C: Coprocessor, + > Provable for MultiFrame<'_, F, IO, Witness, C> +{ fn public_inputs(&self) -> Vec { let mut inputs: Vec<_> = Vec::with_capacity(Self::public_input_size()); @@ -289,29 +458,44 @@ impl> CircuitFrame<'_, F, IO, Witness, C> inputs: AllocatedIO, lang: &Lang, g: &GlobalAllocations, + _hash_circuit_witness_cache: &mut HashCircuitWitnessCache, // Currently unused. + cons_circuit_witness: Option>, + cont_circuit_witness: Option>, ) -> Result, SynthesisError> { let (input_expr, input_env, input_cont) = inputs; - let mut reduce = |store| { - let cons_witness = self - .witness - .map(|x| x.conses) - .unwrap_or_else(|| HashWitness::new_blank()); + let reduce = |store| { + let cons_circuit_witness = if let Some(ccw) = cons_circuit_witness { + ccw + } else { + let cons_witness = self + .witness + .map(|x| x.conses) + .unwrap_or_else(|| HashWitness::new_blank()); + + (cons_witness).into() + }; + let mut allocated_cons_witness = AllocatedConsWitness::from_cons_witness( &mut cs.namespace(|| format!("allocated_cons_witness {i}")), store, - &cons_witness, + &cons_circuit_witness, )?; - let cont_witness = self - .witness - .map(|x| x.conts) - .unwrap_or_else(|| HashWitness::new_blank()); + let cont_circuit_witness = if let Some(ccw) = cont_circuit_witness { + ccw + } else { + let cont_witness = self + .witness + .map(|x| x.conts) + .unwrap_or_else(|| HashWitness::new_blank()); + (cont_witness).into() + }; let mut allocated_cont_witness = AllocatedContWitness::from_cont_witness( &mut cs.namespace(|| format!("allocated_cont_witness {i}")), store, - &cont_witness, + &cont_circuit_witness, )?; reduce_expression( @@ -939,7 +1123,6 @@ fn reduce_expression, C: Coprocessor>( )?; allocated_cons_witness.assert_final_invariants(); - allocated_cont_witness.witness.all_names(); allocated_cont_witness.assert_final_invariants(); // dbg!(&result_expr.fetch_and_write_str(store)); @@ -5292,7 +5475,7 @@ mod tests { _p: Default::default(), }], store, - &lang, + lang.clone(), ); let multiframe = &multiframes[0]; @@ -5303,6 +5486,7 @@ mod tests { .expect("failed to synthesize"); let delta = cs.delta(&cs_blank, false); + dbg!(&delta); assert!(delta == Delta::Equal); //println!("{}", print_cs(&cs)); @@ -5408,7 +5592,7 @@ mod tests { DEFAULT_REDUCTION_COUNT, &[frame], store, - &lang, + lang.clone(), )[0] .clone() .synthesize(&mut cs) @@ -5488,7 +5672,7 @@ mod tests { DEFAULT_REDUCTION_COUNT, &[frame], store, - &lang, + lang.clone(), )[0] .clone() .synthesize(&mut cs) @@ -5570,7 +5754,7 @@ mod tests { DEFAULT_REDUCTION_COUNT, &[frame], store, - &lang, + lang.clone(), )[0] .clone() .synthesize(&mut cs) @@ -5652,7 +5836,7 @@ mod tests { DEFAULT_REDUCTION_COUNT, &[frame], store, - &lang, + lang, )[0] .clone() .synthesize(&mut cs) diff --git a/src/circuit/gadgets/data.rs b/src/circuit/gadgets/data.rs index d487f21580..d97c0b4d29 100644 --- a/src/circuit/gadgets/data.rs +++ b/src/circuit/gadgets/data.rs @@ -4,6 +4,7 @@ use bellperson::{ }; use neptune::{ circuit2::poseidon_hash_allocated as poseidon_hash, + circuit2_witness::poseidon_hash_allocated_witness, poseidon::{Arity, PoseidonConstants}, }; @@ -285,11 +286,15 @@ impl GlobalAllocations { } pub(crate) fn hash_poseidon, F: LurkField, A: Arity>( - cs: CS, + mut cs: CS, preimage: Vec>, constants: &PoseidonConstants, ) -> Result, SynthesisError> { - poseidon_hash(cs, preimage, constants) + if cs.is_witness_generator() { + poseidon_hash_allocated_witness(&mut cs, &preimage, constants) + } else { + poseidon_hash(cs, preimage, constants) + } } impl Ptr { diff --git a/src/circuit/gadgets/hashes.rs b/src/circuit/gadgets/hashes.rs index 01db245ca2..952ee736eb 100644 --- a/src/circuit/gadgets/hashes.rs +++ b/src/circuit/gadgets/hashes.rs @@ -1,17 +1,21 @@ +use std::collections::HashMap; use std::fmt::Debug; use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use neptune::circuit2::poseidon_hash_allocated as poseidon_hash; +use neptune::circuit2_witness::{poseidon_hash_allocated_witness, poseidon_hash_scalar_witness}; use crate::circuit::gadgets::pointer::{AllocatedPtr, AsAllocatedHashComponents}; -use crate::field::LurkField; +use crate::config::CONFIG; +use crate::field::{FWrap, LurkField}; use crate::hash::{HashConst, HashConstants}; -use crate::hash_witness::{ConsName, ConsWitness, ContName, ContWitness, HashName, Stub}; +use crate::hash_witness::{ + ConsCircuitWitness, ConsName, ContCircuitWitness, ContName, Digest, HashName, WitnessBlock, +}; +use crate::ptr::ContPtr; use crate::store::Store; -use crate::tag::ExprTag; -use crate::z_ptr::ZExprPtr; #[derive(Clone)] pub struct AllocatedHash { @@ -60,13 +64,14 @@ impl Slot { - pub(crate) witness: &'a VanillaWitness, // Sometimes used for debugging. +pub struct AllocatedWitness { + #[allow(dead_code)] + // pub(crate) witness: &'a VanillaWitness, // Sometimes used for debugging. slots: Vec>, } -impl<'a, VanillaWitness, Name: Debug, F: LurkField, PreimageType> - AllocatedWitness<'a, VanillaWitness, Name, AllocatedHash> +impl + AllocatedWitness> { pub fn assert_final_invariants(&self) { if self.slots[0].is_blank() { @@ -87,16 +92,36 @@ impl<'a, VanillaWitness, Name: Debug, F: LurkField, PreimageType> } } -pub(crate) type AllocatedConsWitness<'a, F> = - AllocatedWitness<'a, ConsWitness, ConsName, AllocatedPtrHash>; -pub(crate) type AllocatedContWitness<'a, F> = - AllocatedWitness<'a, ContWitness, ContName, AllocatedNumHash>; +pub(crate) type AllocatedConsWitness<'a, F> = AllocatedWitness>; +pub(crate) type AllocatedContWitness<'a, F> = AllocatedWitness>; + +type HashCircuitWitnessCache = HashMap>, (Vec, F)>; impl AllocatedPtrHash { fn alloc>( cs: &mut CS, constants: &HashConstants, preimage: Vec>, + hash_circuit_witness_cache: Option<&mut HashCircuitWitnessCache>, + ) -> Result { + let constants = constants.constants((2 * preimage.len()).into()); + + let pr: Vec> = preimage + .iter() + .flat_map(|x| x.as_allocated_hash_components()) + .cloned() + .collect(); + + let digest = constants.hash(cs, pr, hash_circuit_witness_cache)?; + + Ok(Self { preimage, digest }) + } + + fn alloc_with_witness>( + cs: &mut CS, + constants: &HashConstants, + preimage: Vec>, + block: &(WitnessBlock, Digest), ) -> Result { let constants = constants.constants((2 * preimage.len()).into()); @@ -106,7 +131,7 @@ impl AllocatedPtrHash { .cloned() .collect(); - let digest = constants.hash(cs, pr)?; + let digest = constants.hash_with_witness(cs, pr, Some(block))?; Ok(Self { preimage, digest }) } @@ -117,71 +142,182 @@ impl AllocatedNumHash { cs: &mut CS, constants: &HashConstants, preimage: Vec>, + hash_circuit_witness_cache: Option<&mut HashCircuitWitnessCache>, + ) -> Result { + let constants = constants.constants(preimage.len().into()); + + let pr: Vec> = preimage.to_vec(); + + let digest = constants.hash(cs, pr, hash_circuit_witness_cache)?; + + Ok(Self { preimage, digest }) + } + fn alloc_with_witness>( + cs: &mut CS, + constants: &HashConstants, + preimage: Vec>, + block: &(WitnessBlock, Digest), ) -> Result { let constants = constants.constants(preimage.len().into()); let pr: Vec> = preimage.to_vec(); - let digest = constants.hash(cs, pr)?; + let digest = constants.hash_with_witness(cs, pr, Some(block))?; Ok(Self { preimage, digest }) } } +impl<'a, F: LurkField> HashConst<'a, F> { + #[allow(dead_code)] + fn cache_hash_witness>( + &self, + cs: &mut CS, + preimage: Vec, + hash_circuit_witness_cache: &mut HashCircuitWitnessCache, + ) { + macro_rules! hash { + ($c:ident) => {{ + assert!(cs.is_witness_generator()); + let key: Vec> = preimage.iter().map(|f| FWrap(*f)).collect(); + + let _ = hash_circuit_witness_cache + .entry(key) + .or_insert_with(|| poseidon_hash_scalar_witness(&preimage, $c)); + }}; + } + match self { + HashConst::A3(c) => hash!(c), + HashConst::A4(c) => hash!(c), + HashConst::A6(c) => hash!(c), + HashConst::A8(c) => hash!(c), + } + } +} + +impl<'a, F: LurkField> HashConst<'a, F> { + pub fn cache_hash_witness_aux(&self, preimage: Vec) -> (Vec, F) { + macro_rules! hash { + ($c:ident) => {{ + poseidon_hash_scalar_witness(&preimage, $c) + }}; + } + match self { + HashConst::A3(c) => hash!(c), + HashConst::A4(c) => hash!(c), + HashConst::A6(c) => hash!(c), + HashConst::A8(c) => hash!(c), + } + } +} + impl<'a, F: LurkField> HashConst<'a, F> { fn hash>( &self, cs: &mut CS, preimage: Vec>, + hash_circuit_witness_cache: Option<&mut HashCircuitWitnessCache>, + ) -> Result, SynthesisError> { + let witness_block = if cs.is_witness_generator() { + hash_circuit_witness_cache.map(|cache| { + let key = preimage + .iter() + .map(|allocated| FWrap(allocated.get_value().unwrap())) + .collect::>(); + + let cached = cache.get(&key).unwrap(); + cached + }) + } else { + None + }; + + self.hash_with_witness(cs, preimage, witness_block) + } + + fn hash_with_witness>( + &self, + cs: &mut CS, + preimage: Vec>, + circuit_witness: Option<&(WitnessBlock, Digest)>, ) -> Result, SynthesisError> { + macro_rules! hash { + ($c:ident) => { + if cs.is_witness_generator() { + if let Some((aux_buf, res)) = circuit_witness { + cs.extend_aux(aux_buf); + + AllocatedNum::alloc(cs, || Ok(*res)) + } else { + // We have no cache, just allocate the witness. + poseidon_hash_allocated_witness(cs, &preimage, $c) + } + } else { + // CS is not a witness generator, just hash. + poseidon_hash(cs, preimage, $c) + } + }; + } match self { - HashConst::A3(c) => poseidon_hash(cs, preimage, c), - HashConst::A4(c) => poseidon_hash(cs, preimage, c), - HashConst::A6(c) => poseidon_hash(cs, preimage, c), - HashConst::A8(c) => poseidon_hash(cs, preimage, c), + HashConst::A3(c) => hash!(c), + HashConst::A4(c) => hash!(c), + HashConst::A6(c) => hash!(c), + HashConst::A8(c) => hash!(c), } } } impl<'a, F: LurkField> AllocatedConsWitness<'a, F> { pub fn from_cons_witness>( - cs0: &mut CS, + cs: &mut CS, s: &Store, - cons_witness: &'a ConsWitness, + cons_circuit_witness: &'a ConsCircuitWitness, ) -> Result { + let cons_witness = cons_circuit_witness.hash_witness; let mut slots = Vec::with_capacity(cons_witness.slots.len()); - for (i, (name, p)) in cons_witness.slots.iter().enumerate() { - let cs = &mut cs0.namespace(|| format!("slot-{i}")); - let (car_ptr, cdr_ptr, cons_hash) = match p { - Stub::Dummy => ( - Some(ZExprPtr::from_parts(ExprTag::Nil, F::ZERO)), - Some(ZExprPtr::from_parts(ExprTag::Nil, F::ZERO)), - None, - ), - Stub::Blank => (None, None, None), - Stub::Value(hash) => ( - s.hash_expr(&hash.car), - s.hash_expr(&hash.cdr), - s.hash_expr(&hash.cons), - ), + let names_and_ptrs = cons_circuit_witness.names_and_ptrs(s); + let cons_constants: HashConst<'_, F> = s.poseidon_constants().constants(4.into()); + + let circuit_witness_blocks = + if cs.is_witness_generator() && CONFIG.witness_generation.precompute_neptune { + Some(cons_circuit_witness.circuit_witness_blocks(s, cons_constants)) + } else { + None }; + for (i, (name, spr)) in names_and_ptrs.iter().enumerate() { + let cs = &mut cs.namespace(|| format!("slot-{i}")); + let allocated_car = AllocatedPtr::alloc(&mut cs.namespace(|| "car"), || { - car_ptr.ok_or(SynthesisError::AssignmentMissing) + spr.as_ref() + .map(|x| x.car) + .ok_or(SynthesisError::AssignmentMissing) })?; let allocated_cdr = AllocatedPtr::alloc(&mut cs.namespace(|| "cdr"), || { - cdr_ptr.ok_or(SynthesisError::AssignmentMissing) + spr.as_ref() + .map(|x| x.cdr) + .ok_or(SynthesisError::AssignmentMissing) })?; - let allocated_hash = AllocatedPtrHash::alloc( - &mut cs.namespace(|| "cons"), - s.poseidon_constants(), - vec![allocated_car, allocated_cdr], - )?; + let allocated_hash = if let Some(blocks) = circuit_witness_blocks { + AllocatedPtrHash::alloc_with_witness( + &mut cs.namespace(|| "cons"), + s.poseidon_constants(), + vec![allocated_car, allocated_cdr], + &blocks[i], + )? + } else { + AllocatedPtrHash::alloc( + &mut cs.namespace(|| "cons"), + s.poseidon_constants(), + vec![allocated_car, allocated_cdr], + None, + )? + }; - if cons_hash.is_some() { + if spr.is_some() { slots.push(Slot::new(*name, allocated_hash)); } else { slots.push(Slot::new_dummy(allocated_hash)); @@ -189,8 +325,7 @@ impl<'a, F: LurkField> AllocatedConsWitness<'a, F> { } Ok(Self { - witness: cons_witness, - slots, + slots: slots.to_vec(), }) } @@ -225,36 +360,59 @@ impl<'a, F: LurkField> AllocatedConsWitness<'a, F> { } impl<'a, F: LurkField> AllocatedContWitness<'a, F> { + // Currently unused, but not necessarily useless. + #[allow(dead_code)] + fn make_hash_cache>( + cs: &mut CS, + names_and_ptrs: &[(ContName, (Option>, Option>))], + hash_constants: HashConst<'_, F>, + ) -> Option> { + if cs.is_witness_generator() { + let mut c = HashMap::new(); + + let results = names_and_ptrs + .iter() + .map(|(_, (_, p))| { + let preimage = p.as_ref().unwrap(); + ( + preimage.clone(), + hash_constants.cache_hash_witness_aux(preimage.to_vec()), + ) + }) + .collect::>(); + + for (preimage, x) in results.iter() { + let key: Vec> = preimage.iter().map(|f| FWrap(*f)).collect(); + c.insert(key, x.clone()); + } + Some(c) + } else { + None + } + } + pub fn from_cont_witness>( - cs0: &mut CS, + cs: &mut CS, s: &Store, - cont_witness: &'a ContWitness, + cont_circuit_witness: &'a ContCircuitWitness, ) -> Result { + let cont_witness = cont_circuit_witness.hash_witness; let mut slots = Vec::with_capacity(cont_witness.slots.len()); - for (i, (name, p)) in cont_witness.slots.iter().enumerate() { - let cs = &mut cs0.namespace(|| format!("slot-{i}")); - let (cont_ptr, components) = match p { - Stub::Dummy => ( - None, - Some([ - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - ]), - ), - Stub::Blank => (None, None), - Stub::Value(cont) => ( - Some(cont.cont_ptr), - s.get_hash_components_cont(&cont.cont_ptr), - ), + let names_and_ptrs = cont_circuit_witness.names_and_ptrs(s); + let cont_constants: HashConst<'_, F> = s.poseidon_constants().constants(8.into()); + + let circuit_witness_blocks = + if cs.is_witness_generator() && CONFIG.witness_generation.precompute_neptune { + Some(cont_circuit_witness.circuit_witness_blocks(s, cont_constants)) + } else { + None }; + for (i, (name, spr)) in names_and_ptrs.iter().enumerate() { + let cs = &mut cs.namespace(|| format!("slot-{i}")); + + let components = spr.as_ref().map(|spr| spr.components); let allocated_components = if let Some(components) = components { components .iter() @@ -279,23 +437,30 @@ impl<'a, F: LurkField> AllocatedContWitness<'a, F> { .collect::>() }; - let allocated_hash = AllocatedNumHash::alloc( - &mut cs.namespace(|| "cont"), - s.poseidon_constants(), - allocated_components, - )?; + let allocated_hash = if let Some(blocks) = circuit_witness_blocks { + AllocatedNumHash::alloc_with_witness( + &mut cs.namespace(|| "cont"), + s.poseidon_constants(), + allocated_components, + &blocks[i], + )? + } else { + AllocatedNumHash::alloc( + &mut cs.namespace(|| "cont"), + s.poseidon_constants(), + allocated_components, + None, + )? + }; - if cont_ptr.is_some() { + if spr.as_ref().map(|spr| spr.cont).is_some() { slots.push(Slot::new(*name, allocated_hash)); } else { slots.push(Slot::new_dummy(allocated_hash)); } } - Ok(Self { - witness: cont_witness, - slots, - }) + Ok(Self { slots }) } pub fn get_components( @@ -312,7 +477,6 @@ impl<'a, F: LurkField> AllocatedContWitness<'a, F> { if !expect_dummy { match allocated_name { Err(_) => { - dbg!(&self.witness); panic!("requested {:?} but found a dummy allocation", name) } Ok(alloc_name) => { diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000000..e95e89c9e3 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,152 @@ +//! Global config for parallelism. +use anyhow::bail; +use once_cell::sync::Lazy; + +pub static CONFIG: Lazy = Lazy::new(init_config); + +fn canned_config_from_env() -> Option { + if let Ok(x) = std::env::var("LURK_CANNED_CONFIG") { + let canned = CannedConfig::try_from(x.as_str()).ok(); + + dbg!(&canned); + + canned + } else { + None + } +} + +#[derive(Default, Debug)] +pub enum Flow { + #[default] + Sequential, + Parallel, // Try to be smart. + ParallelN(usize), // How many threads to use? (Advisory, might be ignored.) +} + +impl Flow { + pub fn is_sequential(&self) -> bool { + matches!(self, Self::Sequential) + } + + pub fn is_parallel(&self) -> bool { + !self.is_sequential() + } + + pub fn num_threads(&self) -> usize { + match self { + Self::Sequential => 1, + Self::Parallel => num_cpus::get(), + Self::ParallelN(threads) => *threads, + } + } + + pub fn chunk_size(&self, total_n: usize, min_chunk_size: usize) -> usize { + if self.is_sequential() { + total_n + } else { + let num_threads = self.num_threads(); + let divides_evenly = total_n % num_threads == 0; + + ((total_n / num_threads) + !divides_evenly as usize).max(min_chunk_size) + } + } +} + +#[derive(Default, Debug)] +pub struct ParallelConfig { + pub recursive_steps: Flow, // Multiple `StepCircuit`s. + pub synthesis: Flow, // Synthesis (within one `StepCircuit`) + pub poseidon_witnesses: Flow, // The poseidon witness part of synthesis. +} + +/// Should we use optimized witness-generation when possible? +#[derive(Debug, Default)] +pub struct WitnessGeneration { + // NOTE: Neptune itself *will* do this transparently at the level of individual hashes, where possible. + // so this configuration is only required for higher-level decisions. + pub precompute_neptune: bool, +} + +#[derive(Default, Debug)] +pub struct Config { + pub parallelism: ParallelConfig, + pub witness_generation: WitnessGeneration, +} + +impl Config { + fn fully_sequential() -> Self { + Self { + parallelism: ParallelConfig { + recursive_steps: Flow::Sequential, + synthesis: Flow::Sequential, + poseidon_witnesses: Flow::Sequential, + }, + witness_generation: WitnessGeneration { + precompute_neptune: false, + }, + } + } + + fn max_parallel_simple() -> Self { + Self { + parallelism: ParallelConfig { + recursive_steps: Flow::Parallel, + synthesis: Flow::Parallel, + poseidon_witnesses: Flow::Parallel, + }, + witness_generation: WitnessGeneration { + precompute_neptune: true, + }, + } + } + + fn parallel_steps_only() -> Self { + Self { + parallelism: ParallelConfig { + recursive_steps: Flow::Parallel, + synthesis: Flow::Sequential, + poseidon_witnesses: Flow::Sequential, + }, + witness_generation: WitnessGeneration { + precompute_neptune: true, + }, + } + } +} + +#[derive(Debug)] +enum CannedConfig { + FullySequential, + MaxParallelSimple, + ParallelStepsOnly, +} + +impl From for Config { + fn from(canned: CannedConfig) -> Self { + match canned { + CannedConfig::FullySequential => Self::fully_sequential(), + CannedConfig::MaxParallelSimple => Self::max_parallel_simple(), + CannedConfig::ParallelStepsOnly => Self::parallel_steps_only(), + } + } +} + +impl TryFrom<&str> for CannedConfig { + type Error = anyhow::Error; + + fn try_from(s: &str) -> Result { + match s { + "FULLY-SEQUENTIAL" => Ok(Self::FullySequential), + "MAX-PARALLEL-SIMPLE" => Ok(Self::MaxParallelSimple), + "PARALLEL-STEPS-ONLY" => Ok(Self::ParallelStepsOnly), + _ => bail!("Invalid CannedConfig: {s}"), + } + } +} + +fn init_config() -> Config { + canned_config_from_env() + .map(|x| x.into()) + .unwrap_or_else(Config::fully_sequential) +} diff --git a/src/hash_witness.rs b/src/hash_witness.rs index 3385236855..151c373c98 100644 --- a/src/hash_witness.rs +++ b/src/hash_witness.rs @@ -2,12 +2,17 @@ use std::collections::HashMap; use std::fmt::Debug; use std::marker::PhantomData; +use anyhow::{anyhow, Result}; +use once_cell::sync::OnceCell; + use crate::cont::Continuation; use crate::error::ReductionError; -use crate::field::LurkField; +use crate::field::{FWrap, LurkField}; +use crate::hash::HashConst; use crate::ptr::{ContPtr, Ptr}; use crate::store::{self, Store}; use crate::tag::ExprTag; +use crate::z_ptr::{ZContPtr, ZExprPtr}; pub const MAX_CONSES_PER_REDUCTION: usize = 11; pub const MAX_CONTS_PER_REDUCTION: usize = 2; @@ -25,6 +30,27 @@ impl Stub { } } +pub trait ContentAddressed +where + Self::ScalarPtrRepr: CAddr, +{ + type ScalarPtrRepr; + + fn preimage(&self, s: &Store) -> Result> { + self.to_scalar_ptr_repr(s) + .map(|x| x.preimage()) + .ok_or_else(|| anyhow!("failed to get preimage")) + } + fn to_scalar_ptr_repr(&self, s: &Store) -> Option; + fn to_dummy_scalar_ptr_repr() -> Option { + unimplemented!() + } +} + +pub trait CAddr { + fn preimage(&self) -> Preimage; +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct Cons { pub car: Ptr, @@ -32,6 +58,106 @@ pub struct Cons { pub cons: Ptr, } +#[derive(Clone, Debug)] +pub struct ScalarCons { + pub car: ZExprPtr, + pub cdr: ZExprPtr, + pub cons: Option>, +} + +#[derive(Clone, Debug)] +pub struct ScalarCont { + pub components: [F; 8], + pub cont: Option>, +} + +impl, T: CAddr> ContentAddressed + for Stub +{ + type ScalarPtrRepr = T; + + fn to_scalar_ptr_repr(&self, s: &Store) -> Option { + match self { + Stub::Dummy => C::to_dummy_scalar_ptr_repr(), + Stub::Blank => None, + Stub::Value(v) => v.to_scalar_ptr_repr(s), + } + } +} + +impl ContentAddressed for Cons { + type ScalarPtrRepr = ScalarCons; + + fn preimage(&self, s: &Store) -> Result> { + let spr = self.to_scalar_ptr_repr(s).ok_or(anyhow!("missing"))?; + Ok(spr.preimage()) + } + + fn to_scalar_ptr_repr(&self, s: &Store) -> Option { + let car = s.hash_expr(&self.car)?; + let cdr = s.hash_expr(&self.cdr)?; + let cons = Some(s.hash_expr(&self.cons)?); + Some(ScalarCons { car, cdr, cons }) + } + + fn to_dummy_scalar_ptr_repr() -> Option { + let car = ZExprPtr::from_parts(ExprTag::Nil, F::ZERO); + let cdr = ZExprPtr::from_parts(ExprTag::Nil, F::ZERO); + let cons = None; + Some(ScalarCons { car, cdr, cons }) + } +} + +impl ContentAddressed for Cont { + type ScalarPtrRepr = ScalarCont; + + fn preimage(&self, s: &Store) -> Result> { + let spr = self.to_scalar_ptr_repr(s).ok_or(anyhow!("missing"))?; + Ok(spr.preimage()) + } + + fn to_scalar_ptr_repr(&self, s: &Store) -> Option { + let cont = s.hash_cont(&self.cont_ptr)?; + let components = s.get_hash_components_cont(&self.cont_ptr).unwrap(); + Some(ScalarCont { + cont: Some(cont), + components, + }) + } + + fn to_dummy_scalar_ptr_repr() -> Option { + let cont = None; + let components = [ + F::ZERO, + F::ZERO, + F::ZERO, + F::ZERO, + F::ZERO, + F::ZERO, + F::ZERO, + F::ZERO, + ]; + Some(ScalarCont { cont, components }) + } +} + +impl CAddr for ScalarCons { + fn preimage(&self) -> Preimage { + vec![ + self.car.tag_field(), + *self.car.value(), + self.cdr.tag_field(), + *self.cdr.value(), + ] + } +} + +impl CAddr for ScalarCont { + fn preimage(&self) -> Preimage { + self.components.to_vec() + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct Cont { pub cont_ptr: ContPtr, @@ -76,7 +202,7 @@ pub enum ConsName { Expanded, } -pub trait HashName { +pub trait HashName: Copy { fn index(&self) -> usize; } @@ -218,12 +344,76 @@ impl ConsStub { impl ContStub {} +pub type Preimage = Vec; +pub type PreimageKey = Vec>; +pub type WitnessBlock = Vec; +pub type Digest = F; +pub type HashCircuitWitnessCache = HashMap, (WitnessBlock, Digest)>; +pub type HashCircuitWitnessBlocks = Vec<(WitnessBlock, Digest)>; + #[derive(Clone, Copy, Debug, PartialEq)] pub struct HashWitness { pub slots: [(Name, Stub); L], _f: PhantomData, } +#[derive(Clone, Debug, PartialEq)] +pub struct CircuitHashWitness, const L: usize, F: LurkField> +{ + pub hash_witness: HashWitness, + pub names_and_ptrs: OnceCell)>>, + pub circuit_witness_blocks: OnceCell>, +} + +impl, const L: usize, F: LurkField> + From> for CircuitHashWitness +{ + fn from(hash_witness: HashWitness) -> Self { + Self { + hash_witness, + names_and_ptrs: OnceCell::new(), + circuit_witness_blocks: OnceCell::new(), + } + } +} + +impl, const L: usize, F: LurkField> + CircuitHashWitness +where + T::ScalarPtrRepr: Debug, +{ + pub fn names_and_ptrs(&self, s: &Store) -> &Vec<(Name, Option)> { + self.names_and_ptrs.get_or_init(|| { + self.hash_witness + .slots + .iter() + .map(|(name, x)| (*name, (*x).to_scalar_ptr_repr(s))) + .collect::>() + }) + } + + /// Precompute the witness blocks for all the named hashes. + pub fn circuit_witness_blocks( + &self, + s: &Store, + hash_constants: HashConst<'_, F>, + ) -> &HashCircuitWitnessBlocks { + self.circuit_witness_blocks.get_or_init(|| { + // TODO: In order to be interesting or useful, this should call a Neptune + // API function (which doesn't exist yet) to perform batched witness-generation. + // That code could be optimized and parallelized, eventually even performed on GPU. + self.names_and_ptrs(s) + .iter() + .map(|(_, scalar_ptr_repr)| { + let scalar_ptr_repr = scalar_ptr_repr.as_ref().unwrap(); + let preimage = scalar_ptr_repr.preimage(); + hash_constants.cache_hash_witness_aux(preimage) + }) + .collect::>() + }) + } +} + impl HashWitness { pub fn length() -> usize { L @@ -233,6 +423,9 @@ impl HashWitness pub type ConsWitness = HashWitness, MAX_CONSES_PER_REDUCTION, F>; pub type ContWitness = HashWitness, MAX_CONTS_PER_REDUCTION, F>; +pub type ConsCircuitWitness = CircuitHashWitness, MAX_CONSES_PER_REDUCTION, F>; +pub type ContCircuitWitness = CircuitHashWitness, MAX_CONTS_PER_REDUCTION, F>; + impl HashWitness, MAX_CONSES_PER_REDUCTION, F> { #[allow(dead_code)] fn assert_specific_invariants(&self, store: &Store) { diff --git a/src/lib.rs b/src/lib.rs index fe4f206e89..430457954c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ extern crate alloc; pub mod cache_map; pub mod circuit; +pub mod config; pub mod cont; pub mod coprocessor; pub mod eval; diff --git a/src/main.rs b/src/main.rs index b1079ccc8c..61b0763038 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,5 +7,10 @@ fn main() -> Result<()> { // do not replace by let _ = ... let _metrics_handle = lurk_metrics::MetricsSink::init(); pretty_env_logger::init(); + println!( + "commit: {} {}", + env!("VERGEN_GIT_COMMIT_DATE"), + env!("VERGEN_GIT_SHA") + ); cli::parse_and_run() } diff --git a/src/proof/groth16.rs b/src/proof/groth16.rs index 4b65a8bfd4..c405ce4ab2 100644 --- a/src/proof/groth16.rs +++ b/src/proof/groth16.rs @@ -145,7 +145,8 @@ impl> Groth16Prover { let frames = Evaluator::generate_frames(expr, env, store, limit, padding_predicate, &lang)?; store.hydrate_scalar_cache(); - let multiframes = MultiFrame::from_frames(self.reduction_count(), &frames, store, &lang); + let multiframes = + MultiFrame::from_frames(self.reduction_count(), &frames, store, lang.clone()); let mut proofs = Vec::with_capacity(multiframes.len()); let mut statements = Vec::with_capacity(multiframes.len()); @@ -404,7 +405,7 @@ mod tests { s.hydrate_scalar_cache(); let multi_frames = - MultiFrame::from_frames(DEFAULT_REDUCTION_COUNT, &frames, s, &lang_rc); + MultiFrame::from_frames(DEFAULT_REDUCTION_COUNT, &frames, s, lang_rc.clone()); let cs = groth_prover.outer_synthesize(&multi_frames).unwrap(); diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 13ce88fc8b..82a955cb5b 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -1,9 +1,13 @@ #![allow(non_snake_case)] use std::marker::PhantomData; +use std::sync::Mutex; use abomonation::Abomonation; -use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; +use bellperson::{ + gadgets::num::AllocatedNum, util_cs::witness_cs::WitnessCS, ConstraintSystem, SynthesisError, +}; +use rayon::prelude::*; use nova::{ errors::NovaError, @@ -24,6 +28,8 @@ use crate::circuit::{ }, CircuitFrame, MultiFrame, }; +use crate::config::CONFIG; + use crate::coprocessor::Coprocessor; use crate::error::ProofError; use crate::eval::{lang::Lang, Evaluator, Frame, Witness, IO}; @@ -180,10 +186,12 @@ impl> NovaProver { ) -> Result<(Proof<'_, C>, Vec, Vec, usize), ProofError> { let z0 = frames[0].input.to_vector(store)?; let zi = frames.last().unwrap().output.to_vector(store)?; - let circuits = MultiFrame::from_frames(self.reduction_count(), frames, store, &lang); + let circuits = + MultiFrame::from_frames(self.reduction_count(), &frames, store, lang.clone()); + let num_steps = circuits.len(); let proof = - Proof::prove_recursively(pp, store, &circuits, self.reduction_count, z0.clone(), lang)?; + Proof::prove_recursively(pp, store, circuits, self.reduction_count, z0.clone(), lang)?; Ok((proof, z0, zi, num_steps)) } @@ -203,6 +211,38 @@ impl> NovaProver { } } +impl<'a, F: LurkField, C: Coprocessor> MultiFrame<'a, F, IO, Witness, C> { + fn compute_witness(&self, s: &Store) -> WitnessCS { + let mut wcs = WitnessCS::new(); + + let input = self.input.unwrap(); + + use crate::tag::Tag; + let expr = s.hash_expr(&input.expr).unwrap(); + let env = s.hash_expr(&input.env).unwrap(); + let cont = s.hash_cont(&input.cont).unwrap(); + + let z_scalar = vec![ + expr.tag().to_field(), + *expr.value(), + env.tag().to_field(), + *env.value(), + cont.tag().to_field(), + *cont.value(), + ]; + + let mut bogus_cs = WitnessCS::::new(); + let z: Vec> = z_scalar + .iter() + .map(|x| AllocatedNum::alloc(&mut bogus_cs, || Ok(*x)).unwrap()) + .collect::>(); + + let _ = self.clone().synthesize(&mut wcs, z.as_slice()); + + wcs + } +} + impl<'a, F: LurkField, C: Coprocessor> StepCircuit for MultiFrame<'a, F, IO, Witness, C> { @@ -220,6 +260,29 @@ impl<'a, F: LurkField, C: Coprocessor> StepCircuit { assert_eq!(self.arity(), z.len()); + if cs.is_witness_generator() { + if let Some(w) = &self.cached_witness { + let aux = w.aux_slice(); + let end = aux.len() - 6; + let inputs = &w.inputs_slice()[1..]; + + cs.extend_aux(aux); + cs.extend_inputs(inputs); + + let scalars = &aux[end..]; + + let allocated = { + let mut bogus_cs = WitnessCS::new(); + + scalars + .iter() + .map(|scalar| AllocatedNum::alloc(&mut bogus_cs, || Ok(*scalar)).unwrap()) + .collect::>() + }; + + return Ok(allocated); + } + }; let input_expr = AllocatedPtr::by_index(0, z); let input_env = AllocatedPtr::by_index(1, z); let input_cont = AllocatedContPtr::by_index(2, z); @@ -230,6 +293,7 @@ impl<'a, F: LurkField, C: Coprocessor> StepCircuit Some(frames) => { let s = self.store.expect("store missing"); let g = GlobalAllocations::new(&mut cs.namespace(|| "global_allocations"), s)?; + self.synthesize_frames(cs, s, input_expr, input_env, input_cont, frames, &g) } None => { @@ -239,6 +303,7 @@ impl<'a, F: LurkField, C: Coprocessor> StepCircuit let frames = vec![blank_frame; count]; let g = GlobalAllocations::new(&mut cs.namespace(|| "global_allocations"), &s)?; + self.synthesize_frames(cs, &s, input_expr, input_env, input_cont, &frames, &g) } }; @@ -269,7 +334,7 @@ impl<'a: 'b, 'b, C: Coprocessor> Proof<'a, C> { pub fn prove_recursively( pp: &'a PublicParams<'_, C>, store: &'a Store, - circuits: &[C1<'a, C>], + circuits: Vec>, num_iters_per_step: usize, z0: Vec, lang: Arc>, @@ -288,55 +353,114 @@ impl<'a: 'b, 'b, C: Coprocessor> Proof<'a, C> { MultiFrame<'_, S1, IO, Witness, C>, TrivialTestCircuit, ) = C1::<'a>::circuits(num_iters_per_step, lang); + + dbg!(circuits.len()); + // produce a recursive SNARK let mut recursive_snark: Option, C2>> = None; - for circuit_primary in circuits.iter() { - assert_eq!( - num_iters_per_step, - circuit_primary.frames.as_ref().unwrap().len() - ); - if debug { - // For debugging purposes, synthesize the circuit and check that the constraint system is satisfied. - use bellperson::util_cs::test_cs::TestConstraintSystem; - let mut cs = TestConstraintSystem::<::Scalar>::new(); - - let zi = circuit_primary.frames.as_ref().unwrap()[0] - .input - .unwrap() - .to_vector(store)?; - let zi_allocated: Vec<_> = zi - .iter() - .enumerate() - .map(|(i, x)| { - AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x)) - }) - .collect::>()?; - - circuit_primary.synthesize(&mut cs, zi_allocated.as_slice())?; - - assert!(cs.is_satisfied()); + // the shadowing here is voluntary + let recursive_snark = if CONFIG.parallelism.recursive_steps.is_parallel() { + let cc = circuits + .iter() + .map(|c| Mutex::new(c.clone())) + .collect::>(); + + crossbeam::thread::scope(|s| { + s.spawn(|_| { + // Skip the very first circuit's witness, so `prove_step` can begin immediately. + // That circuit's witness will not be cached and will just be computed on-demand. + cc.par_iter().skip(1).for_each(|mf| { + let witness = { + let mf1 = mf.lock().unwrap(); + mf1.compute_witness(store) + }; + let mut mf2 = mf.lock().unwrap(); + + mf2.cached_witness = Some(witness); + }); + }); + + for circuit_primary in cc.iter() { + let circuit_primary = circuit_primary.lock().unwrap(); + assert_eq!( + num_iters_per_step, + circuit_primary.frames.as_ref().unwrap().len() + ); + + let mut r_snark = recursive_snark.unwrap_or_else(|| { + RecursiveSNARK::new( + &pp.pp, + &circuit_primary, + &circuit_secondary, + z0_primary.clone(), + z0_secondary.clone(), + ) + }); + r_snark + .prove_step( + &pp.pp, + &circuit_primary, + &circuit_secondary, + z0_primary.clone(), + z0_secondary.clone(), + ) + .expect("failure to prove Nova step"); + recursive_snark = Some(r_snark); + } + recursive_snark + }) + .unwrap() + } else { + for circuit_primary in circuits.iter() { + assert_eq!( + num_iters_per_step, + circuit_primary.frames.as_ref().unwrap().len() + ); + if debug { + // For debugging purposes, synthesize the circuit and check that the constraint system is satisfied. + use bellperson::util_cs::test_cs::TestConstraintSystem; + let mut cs = TestConstraintSystem::<::Scalar>::new(); + + let zi = circuit_primary.frames.as_ref().unwrap()[0] + .input + .unwrap() + .to_vector(store)?; + let zi_allocated: Vec<_> = zi + .iter() + .enumerate() + .map(|(i, x)| { + AllocatedNum::alloc(cs.namespace(|| format!("z{i}_1")), || Ok(*x)) + }) + .collect::>()?; + + circuit_primary.synthesize(&mut cs, zi_allocated.as_slice())?; + + assert!(cs.is_satisfied()); + } + + let mut r_snark = recursive_snark.unwrap_or_else(|| { + RecursiveSNARK::new( + &pp.pp, + circuit_primary, + &circuit_secondary, + z0_primary.clone(), + z0_secondary.clone(), + ) + }); + r_snark + .prove_step( + &pp.pp, + circuit_primary, + &circuit_secondary, + z0_primary.clone(), + z0_secondary.clone(), + ) + .expect("failure to prove Nova step"); + recursive_snark = Some(r_snark); } - let mut r_snark = recursive_snark.unwrap_or_else(|| { - RecursiveSNARK::new( - &pp.pp, - circuit_primary, - &circuit_secondary, - z0_primary.clone(), - z0_secondary.clone(), - ) - }); - r_snark - .prove_step( - &pp.pp, - circuit_primary, - &circuit_secondary, - z0_primary.clone(), - z0_secondary.clone(), - ) - .expect("failure to prove Nova step"); - recursive_snark = Some(r_snark); - } + recursive_snark + }; Ok(Self::Recursive(Box::new(recursive_snark.unwrap()))) } @@ -396,6 +520,7 @@ pub mod tests { use crate::ptr::ContPtr; use crate::tag::{Op, Op1, Op2}; + use bellperson::util_cs::witness_cs::WitnessCS; use bellperson::{ util_cs::{metric_cs::MetricCS, test_cs::TestConstraintSystem, Comparable, Delta}, Circuit, @@ -404,7 +529,22 @@ pub mod tests { const DEFAULT_REDUCTION_COUNT: usize = 5; const REDUCTION_COUNTS_TO_TEST: [usize; 3] = [1, 2, 5]; - /// fake docs + + // Returns index of first mismatch, along with the mismatched elements if they exist. + fn mismatch(a: &[T], b: &[T]) -> Option<(usize, (Option, Option))> { + let min_len = a.len().min(b.len()); + for i in 0..min_len { + if a[i] != b[i] { + return Some((i, (Some(a[i]), Some(b[i])))); + } + } + match (a.get(min_len), b.get(min_len)) { + (Some(&a_elem), None) => Some((min_len, (Some(a_elem), None))), + (None, Some(&b_elem)) => Some((min_len, (None, Some(b_elem)))), + _ => None, + } + } + pub fn test_aux>( s: &mut Store, expr: &str, @@ -512,7 +652,8 @@ pub mod tests { .get_evaluation_frames(expr, e, s, limit, &lang) .unwrap(); - let multiframes = MultiFrame::from_frames(nova_prover.reduction_count(), &frames, s, &lang); + let multiframes = + MultiFrame::from_frames(nova_prover.reduction_count(), &frames, s, lang.clone()); let len = multiframes.len(); let adjusted_iterations = nova_prover.expected_total_iterations(expected_iterations); @@ -527,7 +668,12 @@ pub mod tests { for (_i, multiframe) in multiframes.iter().enumerate() { let mut cs = TestConstraintSystem::new(); + let mut wcs = WitnessCS::new(); + + dbg!("synthesizing test cs"); multiframe.clone().synthesize(&mut cs).unwrap(); + dbg!("synthesizing witness cs"); + multiframe.clone().synthesize(&mut wcs).unwrap(); if let Some(prev) = previous_frame { assert!(prev.precedes(multiframe)); @@ -541,6 +687,15 @@ pub mod tests { } assert!(cs.is_satisfied()); assert!(cs.verify(&multiframe.public_inputs())); + dbg!("cs is satisfied!"); + let cs_inputs = cs.scalar_inputs(); + let cs_aux = cs.scalar_aux(); + + let wcs_inputs = wcs.scalar_inputs(); + let wcs_aux = wcs.scalar_aux(); + + assert_eq!(None, mismatch(&cs_inputs, &wcs_inputs)); + assert_eq!(None, mismatch(&cs_aux, &wcs_aux)); previous_frame = Some(multiframe.clone()); @@ -559,6 +714,11 @@ pub mod tests { } if let Some(expected_result) = expected_result { + use crate::writer::Write; + dbg!( + &expected_result.fmt_to_string(s), + &output.expr.fmt_to_string(s) + ); assert!(s.ptr_eq(&expected_result, &output.expr).unwrap()); } if let Some(expected_env) = expected_env { @@ -3563,7 +3723,7 @@ pub mod tests { #[test] #[ignore] - fn test_eval_non_symbol_binding_error() { + fn test_prove_non_symbol_binding_error() { let s = &mut Store::::default(); let error = s.get_cont_error();