From 54e92922b8de887c547e48320fd5c6b8bc622652 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Mon, 18 Mar 2024 13:08:45 -0300 Subject: [PATCH] feat: support for streamed computations and their proofs (#1209) * feat: support for streamed computations and their proofs Introduce the `Op::Recv` LEM operator, which awaits for data to be received by the channel terminal at LEM interpretation time. This operator is used to implement the logic for the `StreamIn` continuation, which receives the argument to be applied to the (chaining) callable object. The result will be available once the (new) `StreamOut` continuation is reached. When that happens, the computation can be resumed by supplying two pieces of data through the channel: * A flag to tell if the machine should stutter while in `StreamOut` * The next argument to be consumed in the incoming `StreamIn` state Note: the second message should not be be sent if the machine is set to stutter with the first message. There is a test to show that we're now able to construct proofs of computations that grow incrementally by starting from where we had stopped. We don't need to start the folding process from the beginning. * optimize: skip new `StreamIn` ocurrences Prepare the next call directly from `StreamOut` without having to go through `StreamIn` again. Since `StreamIn` only occurs in the very first frame after this optimization, it was renamed to `StreamStart`. And `StreamOut`, which now serves as literal point for pause while supporting both input and output, was renamed to `StreamPause`. --- benches/common/fib.rs | 4 +- benches/end2end.rs | 22 ++- benches/fibonacci.rs | 2 +- benches/sha256.rs | 7 +- chain-server/src/server.rs | 2 +- examples/keccak.rs | 2 +- examples/sha256_nivc.rs | 2 +- examples/tp_table.rs | 2 +- src/cli/repl/mod.rs | 4 +- src/coroutine/memoset/prove.rs | 7 +- src/lem/circuit.rs | 17 +- src/lem/coroutine/eval.rs | 2 + src/lem/coroutine/synthesis.rs | 1 + src/lem/eval.rs | 272 ++++++++++++++++++++++++----- src/lem/interpreter.rs | 10 +- src/lem/macros.rs | 13 ++ src/lem/mod.rs | 13 +- src/lem/multiframe.rs | 17 +- src/lem/slot.rs | 9 + src/lem/store.rs | 10 +- src/lem/tests/mod.rs | 1 + src/lem/tests/nivc_steps.rs | 14 +- src/lem/tests/stream.rs | 127 ++++++++++++++ src/proof/mod.rs | 16 +- src/proof/nova.rs | 14 +- src/proof/supernova.rs | 22 ++- src/proof/tests/mod.rs | 5 +- src/proof/tests/stream.rs | 93 ++++++++++ src/proof/tests/supernova_tests.rs | 2 +- src/tag.rs | 9 + 30 files changed, 623 insertions(+), 98 deletions(-) create mode 100644 src/lem/tests/stream.rs create mode 100644 src/proof/tests/stream.rs diff --git a/benches/common/fib.rs b/benches/common/fib.rs index 17b0d8a9a3..aee157ca13 100644 --- a/benches/common/fib.rs +++ b/benches/common/fib.rs @@ -90,7 +90,9 @@ fn compute_coeffs(store: &Store) -> (usize, usize) { } } } - let frame = step_func.call_simple(&input, store, &lang, 0).unwrap(); + let frame = step_func + .call_simple(&input, store, &lang, 0, &dummy_terminal()) + .unwrap(); input = frame.output.clone(); iteration += 1; } diff --git a/benches/end2end.rs b/benches/end2end.rs index 5dda99573e..6d3ababcee 100644 --- a/benches/end2end.rs +++ b/benches/end2end.rs @@ -80,7 +80,9 @@ fn end2end_benchmark(c: &mut Criterion) { let ptr = go_base::(&store, state.clone(), s.0, s.1); let frames = evaluate::>(None, ptr, &store, limit, &dummy_terminal()).unwrap(); - let _result = prover.prove_from_frames(&pp, &frames, &store).unwrap(); + let _result = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); }) }); @@ -220,7 +222,9 @@ fn prove_benchmark(c: &mut Criterion) { evaluate::>(None, ptr, &store, limit, &dummy_terminal()).unwrap(); b.iter(|| { - let result = prover.prove_from_frames(&pp, &frames, &store).unwrap(); + let result = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); black_box(result); }) }); @@ -268,7 +272,9 @@ fn prove_compressed_benchmark(c: &mut Criterion) { evaluate::>(None, ptr, &store, limit, &dummy_terminal()).unwrap(); b.iter(|| { - let (proof, _, _, _) = prover.prove_from_frames(&pp, &frames, &store).unwrap(); + let (proof, _, _, _) = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); let compressed_result = proof.compress(&pp).unwrap(); black_box(compressed_result); @@ -313,8 +319,9 @@ fn verify_benchmark(c: &mut Criterion) { let prover = NovaProver::new(reduction_count, lang_rc.clone()); let frames = evaluate::>(None, ptr, &store, limit, &dummy_terminal()).unwrap(); - let (proof, z0, zi, _num_steps) = - prover.prove_from_frames(&pp, &frames, &store).unwrap(); + let (proof, z0, zi, _num_steps) = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); b.iter_batched( || z0.clone(), @@ -367,8 +374,9 @@ fn verify_compressed_benchmark(c: &mut Criterion) { let prover = NovaProver::new(reduction_count, lang_rc.clone()); let frames = evaluate::>(None, ptr, &store, limit, &dummy_terminal()).unwrap(); - let (proof, z0, zi, _num_steps) = - prover.prove_from_frames(&pp, &frames, &store).unwrap(); + let (proof, z0, zi, _num_steps) = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); let compressed_proof = proof.compress(&pp).unwrap(); diff --git a/benches/fibonacci.rs b/benches/fibonacci.rs index 581980e792..790a8ad313 100644 --- a/benches/fibonacci.rs +++ b/benches/fibonacci.rs @@ -114,7 +114,7 @@ fn fibonacci_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove_from_frames(&pp, frames, &store); + let result = prover.prove_from_frames(&pp, frames, &store, None); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/benches/sha256.rs b/benches/sha256.rs index 5e25b9ca2f..3082062604 100644 --- a/benches/sha256.rs +++ b/benches/sha256.rs @@ -146,7 +146,7 @@ fn sha256_ivc_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove_from_frames(&pp, frames, store); + let result = prover.prove_from_frames(&pp, frames, store, None); let _ = black_box(result); }, BatchSize::LargeInput, @@ -234,7 +234,8 @@ fn sha256_ivc_prove_compressed( b.iter_batched( || frames, |frames| { - let (proof, _, _, _) = prover.prove_from_frames(&pp, frames, store).unwrap(); + let (proof, _, _, _) = + prover.prove_from_frames(&pp, frames, store, None).unwrap(); let compressed_result = proof.compress(&pp).unwrap(); let _ = black_box(compressed_result); @@ -325,7 +326,7 @@ fn sha256_nivc_prove( b.iter_batched( || frames, |frames| { - let result = prover.prove_from_frames(&pp, frames, store); + let result = prover.prove_from_frames(&pp, frames, store, None); let _ = black_box(result); }, BatchSize::LargeInput, diff --git a/chain-server/src/server.rs b/chain-server/src/server.rs index f77b003ce0..ac0062b669 100644 --- a/chain-server/src/server.rs +++ b/chain-server/src/server.rs @@ -156,7 +156,7 @@ where // prove then compress the proof let (proof, ..) = self .prover - .prove_from_frames(pp, &frames, &self.store) + .prove_from_frames(pp, &frames, &self.store, None) .map_err(|e| Status::internal(e.to_string()))?; let proof = proof .compress(pp) diff --git a/examples/keccak.rs b/examples/keccak.rs index 4d040f83ae..a373081d3e 100644 --- a/examples/keccak.rs +++ b/examples/keccak.rs @@ -395,7 +395,7 @@ fn main() { let proof_start = Instant::now(); let (proof, z0, zi, _) = supernova_prover - .prove_from_frames(&pp, &frames, store) + .prove_from_frames(&pp, &frames, store, None) .unwrap(); let proof_end = proof_start.elapsed(); diff --git a/examples/sha256_nivc.rs b/examples/sha256_nivc.rs index ae3d86089a..c4d46183a7 100644 --- a/examples/sha256_nivc.rs +++ b/examples/sha256_nivc.rs @@ -104,7 +104,7 @@ fn main() { let (proof, z0, zi, _num_steps) = tracing_texray::examine(tracing::info_span!("bang!")) .in_scope(|| { supernova_prover - .prove_from_frames(&pp, &frames, store) + .prove_from_frames(&pp, &frames, store, None) .unwrap() }); let proof_end = proof_start.elapsed(); diff --git a/examples/tp_table.rs b/examples/tp_table.rs index a9198fa6bf..36a3c78809 100644 --- a/examples/tp_table.rs +++ b/examples/tp_table.rs @@ -176,7 +176,7 @@ fn main() { let mut timings = Vec::with_capacity(n_samples); for _ in 0..n_samples { let start = Instant::now(); - let result = prover.prove_from_frames(&pp, frames, &store); + let result = prover.prove_from_frames(&pp, frames, &store, None); let _ = black_box(result); let end = start.elapsed().as_secs_f64(); timings.push(end); diff --git a/src/cli/repl/mod.rs b/src/cli/repl/mod.rs index 6aff3567b8..c86f091107 100644 --- a/src/cli/repl/mod.rs +++ b/src/cli/repl/mod.rs @@ -355,7 +355,7 @@ where info!("Proving with NovaProver"); let (proof, public_inputs, public_outputs, num_steps) = - prover.prove_from_frames(&pp, frames, &self.store)?; + prover.prove_from_frames(&pp, frames, &self.store, None)?; info!("Compressing Nova proof"); let proof = proof.compress(&pp)?; assert_eq!(self.rc * num_steps, pad(n_frames, self.rc)); @@ -370,7 +370,7 @@ where info!("Proving with SuperNovaProver"); let (proof, public_inputs, public_outputs, _num_steps) = - prover.prove_from_frames(&pp, frames, &self.store)?; + prover.prove_from_frames(&pp, frames, &self.store, None)?; info!("Compressing SuperNova proof"); let proof = proof.compress(&pp)?; assert!(proof.verify(&pp, &public_inputs, &public_outputs)?); diff --git a/src/coroutine/memoset/prove.rs b/src/coroutine/memoset/prove.rs index 1cb006f84a..8689577d92 100644 --- a/src/coroutine/memoset/prove.rs +++ b/src/coroutine/memoset/prove.rs @@ -96,7 +96,7 @@ impl<'a, F: CurveCycleEquipped, Q: Query + Send + Sync> RecursiveSNARKTrait> for Proof> { type PublicParams = PublicParams; - + type BaseRecursiveSNARK = RecursiveSNARK>; type ErrorType = SuperNovaError; #[tracing::instrument(skip_all, name = "supernova::prove_recursively")] @@ -105,8 +105,9 @@ impl<'a, F: CurveCycleEquipped, Q: Query + Send + Sync> z0: &[F], steps: Vec>, _store: &Store, + init: Option>>, ) -> Result { - let mut recursive_snark_option: Option>> = None; + let mut recursive_snark_option = init; let z0_primary = z0; let z0_secondary = Self::z0_secondary(); @@ -251,7 +252,7 @@ impl<'a, F: CurveCycleEquipped, Q: Query + Send + Sync> MemosetProver<'a, F, let num_steps = steps.len(); - let prove_output = Proof::prove_recursively(pp, &z0, steps, store)?; + let prove_output = Proof::prove_recursively(pp, &z0, steps, store, None)?; let zi = match prove_output { Proof::Recursive(ref snark, ..) => snark.zi_primary().clone(), Proof::Compressed(..) => unreachable!(), diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index e54c767ffc..7eb86b97f0 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -1121,6 +1121,16 @@ fn synthesize_block, C: Coprocessor>( bound_allocations.insert_ptr(tgt[1].clone(), rem_ptr); } Op::Emit(_) | Op::Unit(_) => (), + Op::Recv(tgt) => { + let ptr = if let Ok(val) = ctx.bindings.get(tgt) { + *val.get_ptr().expect("Received data must be a pointer") + } else { + ctx.store.dummy() + }; + let z_ptr = || ctx.store.hash_ptr(&ptr); + let a_ptr = AllocatedPtr::alloc_infallible(ns!(cs, format!("recv {tgt}")), z_ptr); + bound_allocations.insert_ptr(tgt.clone(), a_ptr); + } Op::Hide(tgt, sec, pay) => { let sec = bound_allocations.get_ptr(sec)?; let pay = bound_allocations.get_ptr(pay)?; @@ -1593,7 +1603,12 @@ impl Func { // three implies_u64, one sub and one linear num_constraints += 197; } - Op::Not(..) | Op::Emit(_) | Op::Cproc(..) | Op::Copy(..) | Op::Unit(_) => (), + Op::Not(..) + | Op::Emit(_) + | Op::Recv(_) + | Op::Cproc(..) + | Op::Copy(..) + | Op::Unit(_) => (), Op::Cons2(_, tag, _) => { // tag for the image globals.insert(FWrap(tag.to_field())); diff --git a/src/lem/coroutine/eval.rs b/src/lem/coroutine/eval.rs index a06eb7e403..a2e6b1bf19 100644 --- a/src/lem/coroutine/eval.rs +++ b/src/lem/coroutine/eval.rs @@ -200,9 +200,11 @@ fn run( bindings.insert_ptr(tgt[1].clone(), c2); } Op::Emit(a) => { + // TODO: send `a` through a channel as in the original interpreter let a = bindings.get_ptr(a)?; println!("{}", a.fmt_to_string_simple(&scope.store)); } + Op::Recv(_) => todo!("not supported yet"), Op::Cons2(img, tag, preimg) => { let preimg_ptrs = bindings.get_many_ptr(preimg)?; let tgt_ptr = intern_ptrs!(scope.store, *tag, preimg_ptrs[0], preimg_ptrs[1]); diff --git a/src/lem/coroutine/synthesis.rs b/src/lem/coroutine/synthesis.rs index 1f25e13756..3d23274ab4 100644 --- a/src/lem/coroutine/synthesis.rs +++ b/src/lem/coroutine/synthesis.rs @@ -458,6 +458,7 @@ fn synthesize_run<'a, F: LurkField, CS: ConstraintSystem>( bound_allocations.insert_ptr(tgt[1].clone(), rem_ptr); } Op::Emit(_) | Op::Unit(_) => (), + Op::Recv(_) => todo!("not supported yet"), Op::Hide(tgt, sec, pay) => { let sec = bound_allocations.get_ptr(sec)?; let pay = bound_allocations.get_ptr(pay)?; diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 097414d2a3..f13fd4bb7d 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -15,7 +15,7 @@ use crate::{ proof::FoldingMode, state::initial_lurk_state, tag::{ - ContTag::{Error, Terminal}, + ContTag::{Error, StreamPause, Terminal}, ExprTag::Cproc, }, Symbol, @@ -74,12 +74,25 @@ fn compute_frame>( assert_eq!(func.input_params.len(), input.len()); let preimages = Hints::new_from_func(func); let frame = func.call(input, store, preimages, ch_terminal, lang, pc)?; - let must_break = matches!(frame.output[2].tag(), Tag::Cont(Terminal | Error)); + let must_break = matches!( + frame.output[2].tag(), + Tag::Cont(Terminal | Error | StreamPause) + ); Ok((frame, must_break)) } +fn log_fmt(i: usize, input: &[Ptr], store: &Store) -> String { + let state = initial_lurk_state(); + format!( + "Frame: {i}\n\tExpr: {}\n\tEnv: {}\n\tCont: {}", + input[0].fmt_to_string(store, state), + input[1].fmt_to_string(store, state), + input[2].fmt_to_string(store, state) + ) +} + // Builds frames for IVC or NIVC scheme -fn build_frames, LogFmt: Fn(usize, &[Ptr], &Store) -> String>( +fn build_frames>( lurk_step: &Func, cprocs: &[Func], mut input: Vec, @@ -87,7 +100,6 @@ fn build_frames, LogFmt: Fn(usize, &[Ptr], &Stor limit: usize, lang: &Lang, ch_terminal: &ChannelTerminal, - log_fmt: LogFmt, ) -> Result> { let mut pc = 0; let mut frames = vec![]; @@ -147,42 +159,15 @@ pub fn evaluate_with_env_and_cont>( limit: usize, ch_terminal: &ChannelTerminal, ) -> Result> { - let state = initial_lurk_state(); - let log_fmt = |i: usize, inp: &[Ptr], store: &Store| { - format!( - "Frame: {i}\n\tExpr: {}\n\tEnv: {}\n\tCont: {}", - inp[0].fmt_to_string(store, state), - inp[1].fmt_to_string(store, state), - inp[2].fmt_to_string(store, state) - ) - }; - let input = vec![expr, env, cont]; - match lang_setup { None => { let lang: Lang = Lang::new(); - build_frames( - eval_step(), - &[], - input, - store, - limit, - &lang, - ch_terminal, - log_fmt, - ) + build_frames(eval_step(), &[], input, store, limit, &lang, ch_terminal) + } + Some((lurk_step, cprocs, lang)) => { + build_frames(lurk_step, cprocs, input, store, limit, lang, ch_terminal) } - Some((lurk_step, cprocs, lang)) => build_frames( - lurk_step, - cprocs, - input, - store, - limit, - lang, - ch_terminal, - log_fmt, - ), } } @@ -225,15 +210,16 @@ pub fn evaluate>( ) } -pub fn evaluate_simple_with_env>( +pub fn evaluate_simple_with_env_and_cont>( lang_setup: Option<(&Func, &[Func], &Lang)>, expr: Ptr, env: Ptr, + cont: Ptr, store: &Store, limit: usize, ch_terminal: &ChannelTerminal, ) -> Result<(Vec, usize)> { - let input = vec![expr, env, store.cont_outermost()]; + let input = vec![expr, env, cont]; match lang_setup { None => { let lang: Lang = Lang::new(); @@ -245,6 +231,26 @@ pub fn evaluate_simple_with_env>( } } +#[inline] +pub fn evaluate_simple_with_env>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + expr: Ptr, + env: Ptr, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result<(Vec, usize)> { + evaluate_simple_with_env_and_cont( + lang_setup, + expr, + env, + store.cont_outermost(), + store, + limit, + ch_terminal, + ) +} + #[inline] pub fn evaluate_simple>( lang_setup: Option<(&Func, &[Func], &Lang)>, @@ -263,6 +269,122 @@ pub fn evaluate_simple>( ) } +#[inline] +pub fn start_stream_with_env>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + callable: Ptr, + env: Ptr, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result> { + evaluate_with_env_and_cont( + lang_setup, + callable, + env, + store.cont_stream_start(), + store, + limit, + ch_terminal, + ) +} + +#[inline] +pub fn start_stream>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + callable: Ptr, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result> { + start_stream_with_env( + lang_setup, + callable, + store.intern_nil(), + store, + limit, + ch_terminal, + ) +} + +#[inline] +pub fn resume_stream>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + input: Vec, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result> { + assert!(matches!(input[2].tag(), Tag::Cont(StreamPause))); + match lang_setup { + None => { + let lang: Lang = Lang::new(); + build_frames(eval_step(), &[], input, store, limit, &lang, ch_terminal) + } + Some((lurk_step, cprocs, lang)) => { + build_frames(lurk_step, cprocs, input, store, limit, lang, ch_terminal) + } + } +} + +#[inline] +pub fn start_stream_simple_with_env>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + callable: Ptr, + env: Ptr, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result<(Vec, usize)> { + evaluate_simple_with_env_and_cont( + lang_setup, + callable, + env, + store.cont_stream_start(), + store, + limit, + ch_terminal, + ) +} + +#[inline] +pub fn start_stream_simple>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + callable: Ptr, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result<(Vec, usize)> { + start_stream_simple_with_env( + lang_setup, + callable, + store.intern_nil(), + store, + limit, + ch_terminal, + ) +} + +#[inline] +pub fn resume_stream_simple>( + lang_setup: Option<(&Func, &[Func], &Lang)>, + input: Vec, + store: &Store, + limit: usize, + ch_terminal: &ChannelTerminal, +) -> Result<(Vec, usize)> { + assert!(matches!(input[2].tag(), Tag::Cont(StreamPause))); + match lang_setup { + None => { + let lang: Lang = Lang::new(); + traverse_frames(eval_step(), &[], input, store, limit, &lang, ch_terminal) + } + Some((lurk_step, cprocs, lang)) => { + traverse_frames(lurk_step, cprocs, input, store, limit, lang, ch_terminal) + } + } +} + pub struct EvalConfig<'a, F, C> { lang: &'a Lang, folding_mode: FoldingMode, @@ -868,12 +990,27 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func { return (expr, smaller_env, not_found) }); + // 1. receive data from channel; + // 2. build the list of arguments with it (just one argument!) + // 3. setup a call cycle with a `StreamDispatch` stacked underneath + let mk_stream_call_cont = aux_func!(mk_stream_call_cont(env): 1 => { + let nil = Symbol("nil"); + let nil = cast(nil, Expr::Nil); + let foo: Expr::Nil; + let arg =! recv(); + let arg_list: Expr::Cons = cons2(arg, nil); + let cont: Cont::StreamDispatch = HASH_8_ZEROS; + let cont: Cont::Call = cons4(arg_list, env, cont, foo); + return (cont); + }); + aux_func!(reduce(expr, env, cont): 4 => { let ret = Symbol("return"); let term: Cont::Terminal = HASH_8_ZEROS; let err: Cont::Error = HASH_8_ZEROS; let cproc: Expr::Cproc; + // stuttering condition when not in `StreamPause` let cont_is_term = eq_tag(cont, term); let cont_is_err = eq_tag(cont, err); let expr_is_cproc = eq_tag(expr, cproc); @@ -882,11 +1019,42 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func { if acc_ret { return (expr, env, cont, ret) } + + let errctrl = Symbol("error"); + + match cont.tag { + Cont::StreamStart => { + let (cont) = mk_stream_call_cont(env); + return (expr, env, cont, ret); + } + Cont::StreamPause => { + let stutter =! recv(); + match stutter.tag { + Expr::Nil => { + // 1. make sure the resulting expression is a cons + // 2. deconstruct it to acquire the next callable + // 3. setup the next call cycle and start it ASAP + match expr.tag { + Expr::Cons => { + let (_result, callable) = decons2(expr); + let (cont) = mk_stream_call_cont(env); + return(callable, env, cont, ret); + } + }; + return (expr, env, err, errctrl); + } + }; + // `stutter != nil` is the stuttering condition when in `StreamPause` + return (expr, env, cont, ret); + } + }; + let apply = Symbol("apply-continuation"); let thunk: Expr::Thunk; let sym: Expr::Sym; let cons: Expr::Cons; + // non self-evaluating condition let expr_is_thunk = eq_tag(expr, thunk); let expr_is_sym = eq_tag(expr, sym); let expr_is_cons = eq_tag(expr, cons); @@ -895,11 +1063,11 @@ fn reduce(cprocs: &[(&Symbol, usize)]) -> Func { if !acc_not_apply { return (expr, env, cont, apply) } - let errctrl = Symbol("error"); - let t = Symbol("t"); + let nil = Symbol("nil"); let nil = cast(nil, Expr::Nil); let foo: Expr::Nil; + let t = Symbol("t"); match expr.tag { Expr::Thunk => { @@ -1251,7 +1419,6 @@ fn apply_cont(cprocs: &[(&Symbol, usize)], ivc: bool) -> Func { match ctrl.value { Symbol("apply-continuation") => { let makethunk = Symbol("make-thunk"); - let errctrl = Symbol("error"); let ret = Symbol("return"); let t = Symbol("t"); @@ -1264,12 +1431,21 @@ fn apply_cont(cprocs: &[(&Symbol, usize)], ivc: bool) -> Func { let char: Expr::Char; let u64: Expr::U64; let err: Cont::Error = HASH_8_ZEROS; - let term: Cont::Terminal = HASH_8_ZEROS; match cont.tag { Cont::Outermost => { + let term: Cont::Terminal = HASH_8_ZEROS; // We erase the environment as to not leak any information about internal variables. return (result, empty_env, term, ret) } + Cont::StreamDispatch => { + match result.tag { + Expr::Cons => { + let cont: Cont::StreamPause = HASH_8_ZEROS; + return (result, env, cont, ret); + } + }; + return (result, env, err, errctrl); + } Cont::Emit => { let (cont, _rest, _foo, _foo) = decons4(cont); return (result, env, cont, makethunk) @@ -1762,10 +1938,15 @@ fn make_thunk() -> Func { Symbol("make-thunk") => { match cont.tag { Cont::Outermost => { - let empty_env: Expr::Env; let cont: Cont::Terminal = HASH_8_ZEROS; + // We erase the environment as to not leak any information about internal variables. + let empty_env: Expr::Env; return (expr, empty_env, cont) } + Cont::StreamDispatch => { + let cont: Cont::StreamPause = HASH_8_ZEROS; + return (expr, env, cont); + } }; let thunk: Expr::Thunk = cons2(expr, cont); let cont: Cont::Dummy = HASH_8_ZEROS; @@ -1791,7 +1972,8 @@ mod tests { let frame = Frame::blank(func, 0, &store); let mut cs = TestConstraintSystem::::new(); let lang: Lang = Lang::new(); - let _ = func.synthesize_frame_aux(&mut cs, &store, &frame, &lang); + func.synthesize_frame_aux(&mut cs, &store, &frame, &lang) + .unwrap(); let expect_eq = |computed: usize, expected: Expect| { expected.assert_eq(&computed.to_string()); }; @@ -1801,8 +1983,8 @@ mod tests { expect_eq(func.slots_count.commitment, expect!["1"]); expect_eq(func.slots_count.bit_decomp, expect!["3"]); expect_eq(cs.num_inputs(), expect!["1"]); - expect_eq(cs.aux().len(), expect!["9094"]); - expect_eq(cs.num_constraints(), expect!["11032"]); + expect_eq(cs.aux().len(), expect!["9119"]); + expect_eq(cs.num_constraints(), expect!["11141"]); assert_eq!(func.num_constraints(&store), cs.num_constraints()); } } diff --git a/src/lem/interpreter.rs b/src/lem/interpreter.rs index aa6b900950..3d47122038 100644 --- a/src/lem/interpreter.rs +++ b/src/lem/interpreter.rs @@ -11,7 +11,7 @@ use super::{ use crate::{ coprocessor::Coprocessor, - dual_channel::{dummy_terminal, ChannelTerminal}, + dual_channel::ChannelTerminal, field::LurkField, lang::Lang, num::Num as BaseNum, @@ -332,6 +332,11 @@ impl Block { Op::Emit(a) => { ch_terminal.send(bindings.get_ptr(a)?)?; } + Op::Recv(a) => { + let ptr = ch_terminal.recv()?; + hints.bindings.insert_ptr(a.clone(), ptr); + bindings.insert_ptr(a.clone(), ptr); + } Op::Cons2(img, tag, preimg) => { let preimg_ptrs = bindings.get_many_ptr(preimg)?; let tgt_ptr = intern_ptrs!(store, *tag, preimg_ptrs[0], preimg_ptrs[1]); @@ -579,12 +584,13 @@ impl Func { store: &Store, lang: &Lang, pc: usize, + ch_terminal: &ChannelTerminal, ) -> Result { self.call( args, store, Hints::new_from_func(self), - &dummy_terminal(), + ch_terminal, lang, pc, ) diff --git a/src/lem/macros.rs b/src/lem/macros.rs index 7511d56b09..5db96e4b8a 100644 --- a/src/lem/macros.rs +++ b/src/lem/macros.rs @@ -159,6 +159,9 @@ macro_rules! op { ( emit($v:ident) ) => { $crate::lem::Op::Emit($crate::var!($v)) }; + ( let $v:ident =! recv() ) => { + $crate::lem::Op::Recv($crate::var!($v)) + }; ( let $tgt:ident : $kind:ident::$tag:ident = cons2($src1:ident, $src2:ident) ) => { $crate::lem::Op::Cons2( $crate::var!($tgt), @@ -502,6 +505,16 @@ macro_rules! block { $($tail)* ) }; + (@seq {$($limbs:expr)*}, let $v:ident =! recv() ; $($tail:tt)*) => { + $crate::block! ( + @seq + { + $($limbs)* + $crate::op!(let $v =! recv()) + }, + $($tail)* + ) + }; (@seq {$($limbs:expr)*}, let $tgt:ident = Num($sym:literal) ; $($tail:tt)*) => { $crate::block! ( @seq diff --git a/src/lem/mod.rs b/src/lem/mod.rs index 6b151e1603..bccdb2dbd9 100644 --- a/src/lem/mod.rs +++ b/src/lem/mod.rs @@ -258,6 +258,12 @@ pub enum Op { DivRem64([Var; 2], Var, Var), /// `Emit(v)` sends the value of `v` through the channel during interpretation Emit(Var), + /// `Recv(v)` binds `v` to a variable received from the channel + /// + /// # Warnings + /// * This will lock the interpretation thread until a message is received + /// * This will be an unconstrained allocation in the circuit + Recv(Var), /// `Cons2(x, t, ys)` binds `x` to a `Ptr` with tag `t` and 2 children `ys` Cons2(Var, Tag, [Var; 2]), /// `Cons3(x, t, ys)` binds `x` to a `Ptr` with tag `t` and 3 children `ys` @@ -390,7 +396,8 @@ impl Func { | Op::Hash4Zeros(tgt, _) | Op::Hash6Zeros(tgt, _) | Op::Hash8Zeros(tgt, _) - | Op::Lit(tgt, _) => { + | Op::Lit(tgt, _) + | Op::Recv(tgt) => { is_unique(tgt, map); } Op::Cast(tgt, _tag, src) => { @@ -728,6 +735,10 @@ impl Block { let a = map.get_cloned(&a)?; ops.push(Op::Emit(a)) } + Op::Recv(tgt) => { + let tgt = insert_one(map, uniq, &tgt); + ops.push(Op::Recv(tgt)) + } Op::Cons2(img, tag, preimg) => { let preimg = map.get_many_cloned(&preimg)?.try_into().unwrap(); let img = insert_one(map, uniq, &img); diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs index 8e5bc47083..8f4728a4b3 100644 --- a/src/lem/multiframe.rs +++ b/src/lem/multiframe.rs @@ -11,7 +11,7 @@ use crate::{ circuit::gadgets::pointer::AllocatedPtr, config::lurk_config, coprocessor::Coprocessor, - dual_channel::ChannelTerminal, + dual_channel::{dummy_terminal, pair_terminals, ChannelTerminal}, error::{ProofError, ReductionError}, field::{LanguageField, LurkField}, lang::Lang, @@ -720,9 +720,18 @@ fn pad_frames>( size: usize, store: &Store, ) { - let padding_frame = lurk_step - .call_simple(input, store, lang, 0) - .expect("reduction step failed"); + let padding_frame = if matches!(input[2].tag(), Tag::Cont(ContTag::StreamPause)) { + // we need to allow stuttering for the padding frame + let (t1, t2) = pair_terminals(); + t2.send(store.intern_t()).unwrap(); // anything but `nil` to allow stuttering + lurk_step + .call_simple(input, store, lang, 0, &t1) + .expect("reduction step failed") + } else { + lurk_step + .call_simple(input, store, lang, 0, &dummy_terminal()) + .expect("reduction step failed") + }; assert_eq!(padding_frame.pc, 0); assert_eq!(input, padding_frame.output); frames.resize(size, padding_frame); diff --git a/src/lem/slot.rs b/src/lem/slot.rs index 2f24ddbff3..7fa336e8ef 100644 --- a/src/lem/slot.rs +++ b/src/lem/slot.rs @@ -103,6 +103,8 @@ //! STEP 2 will need as many iterations as it takes to evaluate the Lurk //! expression and so will STEP 3. +use match_opt::match_opt; + use super::{ pointers::{Ptr, RawPtr}, Block, Ctrl, Op, @@ -248,6 +250,13 @@ pub enum Val { Boolean(bool), } +impl Val { + #[inline] + pub(crate) fn get_ptr(&self) -> Option<&Ptr> { + match_opt!(self, Self::Pointer(ptr) => ptr) + } +} + /// Holds data to feed the slots #[derive(Clone, Debug)] pub struct SlotData { diff --git a/src/lem/store.rs b/src/lem/store.rs index a99b88772e..1c57f0c7b0 100644 --- a/src/lem/store.rs +++ b/src/lem/store.rs @@ -22,7 +22,7 @@ use crate::{ syntax::Syntax, tag::ContTag::{ self, Binop, Binop2, Call, Call0, Call2, Dummy, Emit, If, Let, LetRec, Lookup, Outermost, - Tail, Terminal, Unop, + StreamDispatch, StreamPause, StreamStart, Tail, Terminal, Unop, }, tag::ExprTag::{ Char, Comm, Cons, Cproc, Env, Fun, Key, Nil, Num, Prov, Rec, Str, Sym, Thunk, U64, @@ -756,6 +756,11 @@ impl Store { Ptr::new(Tag::Cont(Terminal), RawPtr::Atom(self.hash8zeros_idx)) } + #[inline] + pub fn cont_stream_start(&self) -> Ptr { + Ptr::new(Tag::Cont(StreamStart), RawPtr::Atom(self.hash8zeros_idx)) + } + /// Function specialized on deconstructing `Cons` pointers into their car/cdr pub fn fetch_cons(&self, ptr: &Ptr) -> Option<(Ptr, Ptr)> { match_opt!((ptr.tag(), ptr.raw()), (Tag::Expr(Cons), RawPtr::Hash4(idx)) => { @@ -1374,6 +1379,9 @@ impl Ptr { store, state, ), + StreamStart => "StreamStart".into(), + StreamDispatch => "StreamDispatch".into(), + StreamPause => "StreamPause".into(), }, Tag::Op1(op) => op.to_string(), Tag::Op2(op) => op.to_string(), diff --git a/src/lem/tests/mod.rs b/src/lem/tests/mod.rs index ab19c43ea5..35fa0e7cda 100644 --- a/src/lem/tests/mod.rs +++ b/src/lem/tests/mod.rs @@ -1,3 +1,4 @@ mod eval_tests; mod misc; mod nivc_steps; +mod stream; diff --git a/src/lem/tests/nivc_steps.rs b/src/lem/tests/nivc_steps.rs index 85bbd93715..f31b6ff097 100644 --- a/src/lem/tests/nivc_steps.rs +++ b/src/lem/tests/nivc_steps.rs @@ -54,9 +54,13 @@ fn test_nivc_steps() { Tag::Cont(ContTag::Terminal) )); + let dt = &dummy_terminal(); + // `cproc` can't reduce the first input, which is meant for `lurk_step` let first_input = &frames[0].input; - assert!(cproc.call_simple(first_input, &store, &lang, 0).is_err()); + assert!(cproc + .call_simple(first_input, &store, &lang, 0, dt) + .is_err()); // the fourth frame is the one reduced by the coprocessor let cproc_frame = &frames[3]; @@ -66,14 +70,14 @@ fn test_nivc_steps() { // `lurk_step` stutters on the cproc input let output = &lurk_step - .call_simple(&cproc_input, &store, &lang, 0) + .call_simple(&cproc_input, &store, &lang, 0, dt) .unwrap() .output; assert_eq!(&cproc_input, output); // `cproc` *can* reduce the cproc input let output = &cproc - .call_simple(&cproc_input, &store, &lang, 1) + .call_simple(&cproc_input, &store, &lang, 1, dt) .unwrap() .output; assert_ne!(&cproc_input, output); @@ -97,5 +101,7 @@ fn test_nivc_steps() { // `cproc` can't reduce the altered cproc input (with the wrong name) let cproc_input = vec![new_expr, env, cont]; - assert!(cproc.call_simple(&cproc_input, &store, &lang, 0).is_err()); + assert!(cproc + .call_simple(&cproc_input, &store, &lang, 0, dt) + .is_err()); } diff --git a/src/lem/tests/stream.rs b/src/lem/tests/stream.rs new file mode 100644 index 0000000000..d6ff7ea219 --- /dev/null +++ b/src/lem/tests/stream.rs @@ -0,0 +1,127 @@ +use expect_test::{expect, Expect}; +use halo2curves::bn256::Fr; + +use crate::{ + dual_channel::{dummy_terminal, pair_terminals}, + lang::Coproc, + lem::{ + eval::{evaluate_simple, resume_stream_simple, start_stream_simple}, + pointers::Ptr, + store::Store, + }, +}; + +const LIMIT: usize = 200; + +fn get_callable(callable_str: &str, store: &Store) -> Ptr { + let callable = store.read_with_default_state(callable_str).unwrap(); + let (io, _) = + evaluate_simple::>(None, callable, store, LIMIT, &dummy_terminal()).unwrap(); + io[0] +} + +#[inline] +fn expect_eq(computed: usize, expected: &Expect) { + expected.assert_eq(&computed.to_string()); +} + +fn assert_start_stream( + callable: Ptr, + arg: Ptr, + store: &Store, + expected_result: Ptr, + expected_iterations: &Expect, +) -> Vec { + let (t1, t2) = pair_terminals(); + t2.send(arg).unwrap(); + let (output, iterations) = + start_stream_simple::>(None, callable, store, LIMIT, &t1).unwrap(); + let (result, _) = store.fetch_cons(&output[0]).unwrap(); + assert_eq!(result, expected_result); + expect_eq(iterations, expected_iterations); + output +} + +fn assert_resume_stream( + input: Vec, + arg: Ptr, + store: &Store, + expected_result: Ptr, + expected_iterations: &Expect, +) -> Vec { + let (t1, t2) = pair_terminals(); + t2.send(store.intern_nil()).unwrap(); // send nil to skip stuttering + t2.send(arg).unwrap(); + let (output, iterations) = + resume_stream_simple::>(None, input, store, LIMIT, &t1).unwrap(); + let (result, _) = store.fetch_cons(&output[0]).unwrap(); + assert_eq!(result, expected_result); + expect_eq(iterations, expected_iterations); + output +} + +#[test] +fn test_comm_callable() { + let callable_str = "(commit (letrec ((add (lambda (counter x) + (let ((counter (+ counter x))) + (cons counter (commit (add counter))))))) + (add 0)))"; + let store = Store::::default(); + let callable = get_callable(callable_str, &store); + let expected_iterations = &expect!["16"]; + + let output = assert_start_stream( + callable, + store.num_u64(123), + &store, + store.num_u64(123), + expected_iterations, + ); + let output = assert_resume_stream( + output, + store.num_u64(321), + &store, + store.num_u64(444), + expected_iterations, + ); + assert_resume_stream( + output, + store.num_u64(111), + &store, + store.num_u64(555), + expected_iterations, + ); +} + +#[test] +fn test_fun_callable() { + let callable_str = "(letrec ((add (lambda (counter x) + (let ((counter (+ counter x))) + (cons counter (add counter)))))) + (add 0))"; + let store = Store::::default(); + let callable = get_callable(callable_str, &store); + let expected_iterations = &expect!["14"]; + + let output = assert_start_stream( + callable, + store.num_u64(123), + &store, + store.num_u64(123), + expected_iterations, + ); + let output = assert_resume_stream( + output, + store.num_u64(321), + &store, + store.num_u64(444), + expected_iterations, + ); + assert_resume_stream( + output, + store.num_u64(111), + &store, + store.num_u64(555), + expected_iterations, + ); +} diff --git a/src/proof/mod.rs b/src/proof/mod.rs index e65f42ba05..dda7800d8e 100644 --- a/src/proof/mod.rs +++ b/src/proof/mod.rs @@ -93,6 +93,10 @@ where /// Associated type for public parameters type PublicParams; + /// Type for the base recursive SNARK that can be used as a starting point + /// in `Self::prove_recursively` + type BaseRecursiveSNARK; + /// Type for error potentially thrown during verification type ErrorType; @@ -102,6 +106,7 @@ where z0: &[F], steps: Vec, store: &Store, + init: Option, ) -> Result; /// Compress a proof @@ -158,7 +163,7 @@ pub trait Prover<'a, F: CurveCycleEquipped> { type PublicParams; /// Associated proof type, which must implement `RecursiveSNARKTrait` - type RecursiveSnark: RecursiveSNARKTrait; + type RecursiveSNARK: RecursiveSNARKTrait; /// Returns a reference to the prover's FoldingMode fn folding_mode(&self) -> &FoldingMode; @@ -172,14 +177,17 @@ pub trait Prover<'a, F: CurveCycleEquipped> { pp: &Self::PublicParams, steps: Vec, store: &'a Store, - ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { + init: Option< + >::BaseRecursiveSNARK, + >, + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError> { store.hydrate_z_cache(); let z0 = store.to_scalar_vector(steps[0].input()); let zi = store.to_scalar_vector(steps.last().unwrap().output()); let num_steps = steps.len(); - let prove_output = Self::RecursiveSnark::prove_recursively(pp, &z0, steps, store)?; + let prove_output = Self::RecursiveSNARK::prove_recursively(pp, &z0, steps, store, init)?; Ok((prove_output, z0, zi, num_steps)) } @@ -193,7 +201,7 @@ pub trait Prover<'a, F: CurveCycleEquipped> { store: &'a Store, limit: usize, ch_terminal: &ChannelTerminal, - ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError>; + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError>; /// Returns the expected total number of steps for the prover given raw iterations. fn expected_num_steps(&self, raw_iterations: usize) -> usize { diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 056bcd58e3..821edf67d0 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -253,7 +253,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> { type PublicParams = PublicParams; - + type BaseRecursiveSNARK = RecursiveSNARK>; type ErrorType = NovaError; #[tracing::instrument(skip_all, name = "nova::prove_recursively")] @@ -262,6 +262,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>, store: &Store, + init: Option>>, ) -> Result { let debug = false; assert_eq!(steps[0].arity(), z0.len()); @@ -271,7 +272,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>> = None; + let mut recursive_snark_option = init; let prove_step = |i: usize, step: &C1LEM<'a, F, C>, rs: &mut Option>>| { @@ -391,12 +392,13 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> NovaProver<'a, F, C> { pp: &PublicParams, frames: &[Frame], store: &'a Store, + init: Option>>, ) -> Result<(Proof>, Vec, Vec, usize), ProofError> { let folding_config = self .folding_mode() .folding_config(self.lang().clone(), self.reduction_count()); let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into()); - self.prove(pp, steps, store) + self.prove(pp, steps, store, init) } #[inline] @@ -409,7 +411,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> NovaProver<'a, F, C> { impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> Prover<'a, F> for NovaProver<'a, F, C> { type Frame = C1LEM<'a, F, C>; type PublicParams = PublicParams; - type RecursiveSnark = Proof>; + type RecursiveSNARK = Proof>; #[inline] fn reduction_count(&self) -> usize { @@ -429,10 +431,10 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> Prover<'a, F> for NovaPr store: &'a Store, limit: usize, ch_terminal: &ChannelTerminal, - ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError> { let eval_config = self.folding_mode().eval_config(self.lang()); let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config, ch_terminal)?; - self.prove_from_frames(pp, &frames, store) + self.prove_from_frames(pp, &frames, store, None) } } diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs index fff0452bfa..f5e6d4a7fe 100644 --- a/src/proof/supernova.rs +++ b/src/proof/supernova.rs @@ -147,11 +147,17 @@ pub enum Proof { } impl Proof { - /// todo + /// Extracts the original `CompressedSNARK` #[inline] pub fn get_compressed(self) -> Option, SS1, SS2>> { match_opt::match_opt!(self, Self::Compressed(proof, _) => *proof) } + + /// Extracts the original `RecursiveSNARK` + #[inline] + pub fn get_recursive(self) -> Option>> { + match_opt::match_opt!(self, Self::Recursive(proof, _) => *proof) + } } /// A struct for the Nova prover that operates on field elements of type `F`. @@ -183,12 +189,13 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor + 'a> SuperNovaProver<'a, F, C pp: &PublicParams, frames: &[Frame], store: &'a Store, + init: Option>>, ) -> Result<(Proof>, Vec, Vec, usize), ProofError> { let folding_config = self .folding_mode() .folding_config(self.lang().clone(), self.reduction_count()); let steps = C1LEM::<'a, F, C>::from_frames(frames, store, &folding_config.into()); - self.prove(pp, steps, store) + self.prove(pp, steps, store, init) } #[inline] @@ -202,7 +209,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> { type PublicParams = PublicParams; - + type BaseRecursiveSNARK = RecursiveSNARK>; type ErrorType = SuperNovaError; #[tracing::instrument(skip_all, name = "supernova::prove_recursively")] @@ -211,12 +218,13 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait>, store: &Store, + init: Option>>, ) -> Result { let debug = false; info!("proving {} steps", steps.len()); - let mut recursive_snark_option: Option>> = None; + let mut recursive_snark_option = init; let prove_step = |i: usize, step: &C1LEM<'a, F, C>, rs: &mut Option>>| { @@ -336,7 +344,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> RecursiveSNARKTrait> Prover<'a, F> for SuperNovaProver<'a, F, C> { type Frame = C1LEM<'a, F, C>; type PublicParams = PublicParams; - type RecursiveSnark = Proof>; + type RecursiveSNARK = Proof>; #[inline] fn reduction_count(&self) -> usize { @@ -356,11 +364,11 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor> Prover<'a, F> for SuperNovaPr store: &'a Store, limit: usize, ch_terminal: &ChannelTerminal, - ) -> Result<(Self::RecursiveSnark, Vec, Vec, usize), ProofError> { + ) -> Result<(Self::RecursiveSNARK, Vec, Vec, usize), ProofError> { let eval_config = self.folding_mode().eval_config(self.lang()); let frames = C1LEM::<'a, F, C>::build_frames(expr, env, store, limit, &eval_config, ch_terminal)?; - self.prove_from_frames(pp, &frames, store) + self.prove_from_frames(pp, &frames, store, None) } } diff --git a/src/proof/tests/mod.rs b/src/proof/tests/mod.rs index 0274eae23f..3e4e1131d2 100644 --- a/src/proof/tests/mod.rs +++ b/src/proof/tests/mod.rs @@ -1,4 +1,5 @@ mod nova_tests; +mod stream; mod supernova_tests; use bellpepper::util_cs::{metric_cs::MetricCS, witness_cs::WitnessCS, Comparable}; @@ -183,7 +184,9 @@ fn nova_test_full_aux2<'a, F: CurveCycleEquipped, C: Coprocessor + 'a>( if check_nova { let pp = public_params(reduction_count, lang.clone()); - let (proof, z0, zi, _num_steps) = nova_prover.prove_from_frames(&pp, &frames, s).unwrap(); + let (proof, z0, zi, _num_steps) = nova_prover + .prove_from_frames(&pp, &frames, s, None) + .unwrap(); let res = proof.verify(&pp, &z0, &zi); if res.is_err() { diff --git a/src/proof/tests/stream.rs b/src/proof/tests/stream.rs new file mode 100644 index 0000000000..fad6f29ce5 --- /dev/null +++ b/src/proof/tests/stream.rs @@ -0,0 +1,93 @@ +use expect_test::{expect, Expect}; +use halo2curves::bn256::Fr; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use std::sync::Arc; + +use crate::{ + dual_channel::{dummy_terminal, pair_terminals}, + lang::{Coproc, Lang}, + lem::{ + eval::{evaluate_simple, resume_stream, start_stream}, + pointers::Ptr, + store::Store, + }, + proof::{supernova::SuperNovaProver, RecursiveSNARKTrait}, + public_parameters::{instance::Instance, supernova_public_params}, +}; + +const LIMIT: usize = 200; + +fn get_callable(callable_str: &str, store: &Store) -> Ptr { + let callable = store.read_with_default_state(callable_str).unwrap(); + let (io, _) = + evaluate_simple::>(None, callable, store, LIMIT, &dummy_terminal()).unwrap(); + io[0] +} + +#[inline] +fn expect_eq(computed: usize, expected: &Expect) { + expected.assert_eq(&computed.to_string()); +} + +#[test] +fn test_continued_proof() { + let callable_str = "(letrec ((add (lambda (counter x) + (let ((counter (+ counter x))) + (cons counter (add counter)))))) + (add 0))"; + let store = Store::::default(); + let callable = get_callable(callable_str, &store); + let expected_iterations = &expect!["14"]; + + let lang = Arc::new(Lang::>::new()); + + [1, 3, 5].into_par_iter().for_each(|rc| { + let prover = SuperNovaProver::new(rc, lang.clone()); + let instance = Instance::new_supernova(&prover, true); + let pp = supernova_public_params(&instance).unwrap(); + + let (t1, t2) = pair_terminals(); + t2.send(store.num_u64(123)).unwrap(); + let frames = start_stream::>(None, callable, &store, LIMIT, &t1).unwrap(); + + // this input will be used to construct the public input of every proof + let z0 = store.to_scalar_vector(&frames.first().unwrap().input); + + expect_eq(frames.len(), expected_iterations); + let output = &frames.last().unwrap().output; + let (result, _) = store.fetch_cons(&output[0]).unwrap(); + assert_eq!(result, store.num_u64(123)); + + let (proof, ..) = prover + .prove_from_frames(&pp, &frames, &store, None) + .unwrap(); + + proof + .verify(&pp, &z0, &store.to_scalar_vector(output)) + .unwrap(); + + let base_snark = proof.get_recursive(); + assert!(base_snark.is_some()); + + // into the next stream cycle + t2.send(store.intern_nil()).unwrap(); // send nil to skip stuttering + t2.send(store.num_u64(321)).unwrap(); + let frames = + resume_stream::>(None, output.clone(), &store, LIMIT, &t1).unwrap(); + + expect_eq(frames.len(), expected_iterations); + let output = &frames.last().unwrap().output; + let (result, _) = store.fetch_cons(&output[0]).unwrap(); + assert_eq!(result, store.num_u64(444)); + + let (proof, ..) = prover + .prove_from_frames(&pp, &frames, &store, base_snark) + .unwrap(); + + let zi = store.to_scalar_vector(output); + proof.verify(&pp, &z0, &zi).unwrap(); + + let proof = proof.compress(&pp).unwrap(); + proof.verify(&pp, &z0, &zi).unwrap(); + }); +} diff --git a/src/proof/tests/supernova_tests.rs b/src/proof/tests/supernova_tests.rs index 7431fef19f..e17858f557 100644 --- a/src/proof/tests/supernova_tests.rs +++ b/src/proof/tests/supernova_tests.rs @@ -55,7 +55,7 @@ fn test_nil_nil_lang() { let pp = supernova_public_params(&instance).unwrap(); let (proof, ..) = supernova_prover - .prove_from_frames(&pp, &frames, &store) + .prove_from_frames(&pp, &frames, &store, None) .unwrap(); let input_scalar = store.to_scalar_vector(&first_frame.input); diff --git a/src/tag.rs b/src/tag.rs index c045589880..949cda5bd0 100644 --- a/src/tag.rs +++ b/src/tag.rs @@ -142,6 +142,9 @@ pub enum ContTag { Terminal, Emit, Cproc, + StreamStart, + StreamDispatch, + StreamPause, } impl From for u16 { @@ -193,6 +196,9 @@ impl fmt::Display for ContTag { ContTag::Terminal => write!(f, "terminal#"), ContTag::Emit => write!(f, "emit#"), ContTag::Cproc => write!(f, "cproc#"), + ContTag::StreamStart => write!(f, "stream-start#"), + ContTag::StreamDispatch => write!(f, "stream-dispatch#"), + ContTag::StreamPause => write!(f, "stream-pause#"), } } } @@ -539,6 +545,9 @@ pub(crate) mod tests { (ContTag::Terminal, 4110), (ContTag::Emit, 4111), (ContTag::Cproc, 4112), + (ContTag::StreamStart, 4113), + (ContTag::StreamDispatch, 4114), + (ContTag::StreamPause, 4115), ]); assert_eq!(map.len(), ContTag::COUNT); assert_tags_u16s(map)