From 93e915d7ab43671b04557dd384a78b45006a8803 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?=
Date: Sat, 6 Jan 2024 15:59:01 -0500
Subject: [PATCH] refactor: Refactor and remove `MultiFrameTrait` across modules

- Removed the deprecated `MultiFrameTrait` from various files and switched to direct use of `MultiFrame` methods.
- Updated the `Provable` and `Prover` traits to accommodate the removal of `MultiFrameTrait`, with corresponding changes to the `proof` and `evaluate_and_prove` methods.
- Renamed the `io_to_scalar_vector` function to `to_scalar_vector` and applied the rename in the trait methods that used it.
- Introduced debugging and verification of the circuit's constraint system in the `prove_recursively` function in `nova.rs`.
- Reorganized the `MultiFrame` struct substantially to absorb `MultiFrameTrait`, adding inherent getter and builder methods and enabling potential performance improvements when synthesizing frames.
- Reorganized `Frame` allocation and handling in `MultiFrame` instances for more effective synthesis.
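As a sketch of the call-site migration this enables (taken from the
`src/proof/mod.rs` hunk below; `store` and `frames` are whatever the caller
already has in scope):

    // before: through the trait's associated function
    let z0 = C1LEM::<'a, F, C>::io_to_scalar_vector(store, frames[0].input());
    // after: directly on the store
    let z0 = store.to_scalar_vector(frames[0].input());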
---
 benches/synthesis.rs   |   2 +-
 src/lem/multiframe.rs  | 837 ++++++++++++++++++++---------------------
 src/proof/mod.rs       |  69 +---
 src/proof/nova.rs      |   4 +-
 src/proof/supernova.rs |   2 +-
 src/proof/tests/mod.rs |   2 +-
 6 files changed, 426 insertions(+), 490 deletions(-)

diff --git a/benches/synthesis.rs b/benches/synthesis.rs
index da264cac8a..ca6d7c5fce 100644
--- a/benches/synthesis.rs
+++ b/benches/synthesis.rs
@@ -12,7 +12,7 @@ use lurk::{
     eval::lang::{Coproc, Lang},
     field::LurkField,
     lem::{eval::evaluate, multiframe::MultiFrame, pointers::Ptr, store::Store},
-    proof::{supernova::FoldingConfig, MultiFrameTrait},
+    proof::supernova::FoldingConfig,
     state::State,
 };

diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs
index 9a5652d60f..37b43a6102 100644
--- a/src/lem/multiframe.rs
+++ b/src/lem/multiframe.rs
@@ -19,7 +19,7 @@ use crate::{
     proof::{
         nova::{CurveCycleEquipped, E1, E2},
         supernova::{FoldingConfig, C2},
-        CEKState, EvaluationStore, FrameLike, MultiFrameTrait, Provable,
+        CEKState, EvaluationStore, FrameLike, Provable,
     },
     tag::ContTag,
 };
@@ -69,444 +69,130 @@ impl<'a, F: LurkField, C: Coprocessor<F>> MultiFrame<'a, F, C> {
     fn get_lang(&self) -> &Arc<Lang<F, C>> {
         self.folding_config.lang()
     }
-}

-impl CEKState<Ptr> for Vec<Ptr> {
-    fn expr(&self) -> &Ptr {
-        &self[0]
-    }
-    fn env(&self) -> &Ptr {
-        &self[1]
-    }
-    fn cont(&self) -> &Ptr {
-        &self[2]
+    #[inline]
+    pub fn frames(&self) -> Option<&Vec<Frame>> {
+        self.frames.as_ref()
     }
-}

-impl FrameLike<Ptr> for Frame {
-    type FrameIO = Vec<Ptr>;
-    fn input(&self) -> &Self::FrameIO {
-        &self.input
-    }
-    fn output(&self) -> &Self::FrameIO {
+    #[inline]
+    pub fn output(&self) -> &Option<Vec<Ptr>> {
         &self.output
     }
-}
-
-impl<F: LurkField> EvaluationStore for Store<F> {
-    type Ptr = Ptr;
-    type Error = anyhow::Error;

-    fn read(&self, expr: &str) -> Result<Self::Ptr, Self::Error> {
-        self.read_with_default_state(expr)
+    pub fn emitted(_store: &Store<F>, eval_frame: &Frame) -> Vec<Ptr> {
+        eval_frame.emitted.clone()
     }

-    fn initial_empty_env(&self) -> Self::Ptr {
-        self.intern_nil()
-    }
+    pub fn cache_witness(&mut self, s: &Store<F>) -> Result<(), SynthesisError> {
+        let _ = self.cached_witness.get_or_try_init(|| {
+            let mut wcs = WitnessCS::new();

-    fn get_cont_terminal(&self) -> Self::Ptr {
-        self.cont_terminal()
-    }
+            let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap());

-    fn hydrate_z_cache(&self) {
-        self.hydrate_z_cache()
-    }
+            let mut bogus_cs = WitnessCS::<F>::new();
+            let z: Vec<AllocatedNum<F>> = z_scalar
+                .iter()
+                .map(|x| AllocatedNum::alloc_infallible(&mut bogus_cs, || *x))
+                .collect::<Vec<_>>();

-    fn ptr_eq(&self, left: &Self::Ptr, right: &Self::Ptr) -> bool {
-        self.ptr_eq(left, right)
+            let output =
+                nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice())?;
+            Ok::<_, SynthesisError>((wcs, output))
+        })?;
+        Ok(())
     }
-}

-/// Checks that a slice of pointers and a slice of allocated pointers have
-/// the same length. If `!blank`, asserts that the hashed pointers have tags
-/// and values corresponding to the ones from the respective allocated pointers
-fn assert_eq_ptrs_aptrs<F: LurkField>(
-    store: &Store<F>,
-    blank: bool,
-    ptrs: &[Ptr],
-    aptrs: &[AllocatedPtr<F>],
-) -> Result<(), SynthesisError> {
-    assert_eq!(ptrs.len(), aptrs.len());
-    if !blank {
-        for (aptr, ptr) in aptrs.iter().zip(ptrs) {
-            let z_ptr = store.hash_ptr(ptr);
-            let (Some(alloc_ptr_tag), Some(alloc_ptr_hash)) =
-                (aptr.tag().get_value(), aptr.hash().get_value())
-            else {
-                return Err(SynthesisError::AssignmentMissing);
-            };
-            assert_eq!(alloc_ptr_tag, z_ptr.tag().to_field());
-            assert_eq!(&alloc_ptr_hash, z_ptr.value());
-        }
+    #[inline]
+    pub fn precedes(&self, maybe_next: &Self) -> bool {
+        self.output == maybe_next.input
     }
-    Ok(())
-}
-
-// Hardcoded slot witness sizes, empirically collected
-const BIT_DECOMP_PALLAS_WITNESS_SIZE: usize = 298;
-const BIT_DECOMP_VESTA_WITNESS_SIZE: usize = 301;
-const BIT_DECOMP_BN256_WITNESS_SIZE: usize = 354;
-const BIT_DECOMP_GRUMPKIN_WITNESS_SIZE: usize = 364;
-/// Computes the witness size for a `SlotType`. Note that the witness size for
-/// bit decomposition depends on the field we're in.
-#[inline]
-fn compute_witness_size<F: LurkField>(slot_type: &SlotType, store: &Store<F>) -> usize {
-    match slot_type {
-        SlotType::Hash4 => store.hash4_cost() + 4, // 4 preimg elts
-        SlotType::Hash6 => store.hash6_cost() + 6, // 6 preimg elts
-        SlotType::Hash8 => store.hash8_cost() + 8, // 8 preimg elts
-        SlotType::Commitment => store.hash3_cost() + 3, // 3 preimg elts
-        SlotType::BitDecomp => match F::FIELD {
-            LanguageField::Pallas => BIT_DECOMP_PALLAS_WITNESS_SIZE,
-            LanguageField::Vesta => BIT_DECOMP_VESTA_WITNESS_SIZE,
-            LanguageField::BN256 => BIT_DECOMP_BN256_WITNESS_SIZE,
-            LanguageField::Grumpkin => BIT_DECOMP_GRUMPKIN_WITNESS_SIZE,
-        },
+    pub fn synthesize_frames<CS: ConstraintSystem<F>>(
+        &self,
+        cs: &mut CS,
+        store: &Store<F>,
+        input: Vec<AllocatedPtr<F>>,
+        frames: &[Frame],
+        g: &GlobalAllocator<F>,
+    ) -> Result<Vec<AllocatedPtr<F>>, SynthesisError> {
+        let func = self.get_func();
+        if cs.is_witness_generator() {
+            let num_slots_per_frame = func.slots_count.total();
+            let slots_witnesses = generate_slots_witnesses(
+                store,
+                frames,
+                num_slots_per_frame,
+                lurk_config(None, None)
+                    .perf
+                    .parallelism
+                    .poseidon_witnesses
+                    .is_parallel(),
+            );
+            if lurk_config(None, None)
+                .perf
+                .parallelism
+                .synthesis
+                .is_parallel()
+            {
+                Ok(synthesize_frames_parallel(
+                    cs,
+                    g,
+                    store,
+                    input,
+                    frames,
+                    func,
+                    self.get_lang(),
+                    &slots_witnesses,
+                    num_slots_per_frame,
+                ))
+            } else {
+                synthesize_frames_sequential(
+                    cs,
+                    g,
+                    store,
+                    &input,
+                    frames,
+                    func,
+                    self.get_lang(),
+                    Some((&slots_witnesses, num_slots_per_frame)),
+                )
+            }
+        } else {
+            synthesize_frames_sequential(cs, g, store, &input, frames, func, self.get_lang(), None)
+        }
     }
-}
-
-/// Generates the witnesses for all slots in `frames`. Since many slots are fed
-/// with dummy data, we cache their (dummy) witnesses for extra speed
-fn generate_slots_witnesses<F: LurkField>(
-    store: &Store<F>,
-    frames: &[Frame],
-    num_slots_per_frame: usize,
-    parallel: bool,
-) -> Vec<Arc<SlotWitness<F>>> {
-    let mut slots_data = Vec::with_capacity(frames.len() * num_slots_per_frame);
-    for frame in frames.iter() {
-        [
-            (&frame.hints.hash4, SlotType::Hash4),
-            (&frame.hints.hash6, SlotType::Hash6),
-            (&frame.hints.hash8, SlotType::Hash8),
-            (&frame.hints.commitment, SlotType::Commitment),
-            (&frame.hints.bit_decomp, SlotType::BitDecomp),
-        ]
-        .into_iter()
-        .for_each(|(sd_vec, st)| sd_vec.iter().for_each(|sd| slots_data.push((sd, st))));
-    }
-    // precompute these values
-    let hash4_witness_size = compute_witness_size(&SlotType::Hash4, store);
-    let hash6_witness_size = compute_witness_size(&SlotType::Hash6, store);
-    let hash8_witness_size = compute_witness_size(&SlotType::Hash8, store);
-    let commitment_witness_size = compute_witness_size(&SlotType::Commitment, store);
-    let bit_decomp_witness_size = compute_witness_size(&SlotType::BitDecomp, store);
-    // fast getter for the precomputed values
-    let get_witness_size = |slot_type| match slot_type {
-        SlotType::Hash4 => hash4_witness_size,
-        SlotType::Hash6 => hash6_witness_size,
-        SlotType::Hash8 => hash8_witness_size,
-        SlotType::Commitment => commitment_witness_size,
-        SlotType::BitDecomp => bit_decomp_witness_size,
-    };
-    // cache dummy slots witnesses with `Arc` for speedy clones
-    let dummy_witnesses_cache: FrozenMap<_, Box<Arc<SlotWitness<F>>>> = FrozenMap::default();
-    let gen_slot_witness = |(slot_idx, (slot_data, slot_type))| {
-        let mk_witness = || {
-            let mut witness = WitnessCS::with_capacity(1, get_witness_size(slot_type));
-            let allocations = allocate_slot(&mut witness, slot_data, slot_idx, slot_type, store)
-                .expect("slot allocations failed");
-            Arc::new(SlotWitness {
-                witness,
-                allocations,
-            })
+    pub fn blank(folding_config: Arc<FoldingConfig<F, C>>, pc: usize) -> Self {
+        let (lurk_step, cprocs, rc) = match &*folding_config {
+            FoldingConfig::IVC(lang, rc) => (
+                Arc::new(make_eval_step_from_config(&EvalConfig::new_ivc(lang))),
+                None,
+                *rc,
+            ),
+            FoldingConfig::NIVC(lang, rc) => (
+                Arc::new(make_eval_step_from_config(&EvalConfig::new_nivc(lang))),
+                Some(make_cprocs_funcs_from_lang(lang).into()),
+                *rc,
+            ),
         };
-        if Option::as_ref(slot_data).is_some() {
-            mk_witness()
-        } else {
-            // dummy witness
-            if let Some(sw) = dummy_witnesses_cache.get(&slot_type) {
-                // already computed
-                sw.clone()
-            } else {
-                // compute, cache and return
-                let ws = mk_witness();
-                dummy_witnesses_cache.insert(slot_type, Box::new(ws.clone()));
-                ws
-            }
+        let num_frames = if pc == 0 { rc } else { 1 };
+        Self {
+            store: None,
+            lurk_step,
+            cprocs,
+            input: None,
+            output: None,
+            frames: None,
+            cached_witness: OnceCell::new(),
+            num_frames,
+            folding_config,
+            pc,
+            next_pc: 0,
         }
-    };
-    if parallel {
-        slots_data
-            .into_par_iter()
-            .enumerate()
-            .map(gen_slot_witness)
-            .collect()
-    } else {
-        slots_data
-            .into_iter()
-            .enumerate()
-            .map(gen_slot_witness)
-            .collect()
     }
-}
-
-/// Synthesize frames sequentially, feeding the output of a frame as the input of
-/// the next
-fn synthesize_frames_sequential<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
-    cs: &mut CS,
-    g: &GlobalAllocator<F>,
-    store: &Store<F>,
-    input: &[AllocatedPtr<F>],
-    frames: &[Frame],
-    func: &Func,
-    lang: &Lang<F, C>,
-    slots_witnesses_num_slots_per_frame: Option<(&[Arc<SlotWitness<F>>], usize)>,
-) -> Result<Vec<AllocatedPtr<F>>, SynthesisError> {
-    let (_, output) = frames
-        .iter()
-        .try_fold((0, input.to_vec()), |(i, input), frame| {
-            let bound_allocations = &mut BoundAllocations::new();
-            func.bind_input(&input, bound_allocations);
-            let output = func
-                .synthesize_frame(
-                    &mut cs.namespace(|| format!("frame {i}")),
-                    store,
-                    frame,
-                    g,
-                    bound_allocations,
-                    lang,
-                    slots_witnesses_num_slots_per_frame.map(|(sws, num_slots_per_frame)| {
-                        let slots_witnesses_start = i * num_slots_per_frame;
-                        &sws[slots_witnesses_start..slots_witnesses_start + num_slots_per_frame]
-                    }),
-                )
-                .expect("failed to synthesize frame");
-            assert_eq!(input.len(), output.len());
-            assert_eq_ptrs_aptrs(store, frame.blank, &frame.output, &output)?;
-            Ok::<_, SynthesisError>((i + 1, output))
-        })?;
-    Ok(output)
-}
-
-/// Synthesize each frame in parallel, ideally one in each CPU. Each
-/// frame will produce its corresponding partial witness, which are then
-/// used to extend the final witness.
-fn synthesize_frames_parallel<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
-    cs: &mut CS,
-    g: &GlobalAllocator<F>,
-    store: &Store<F>,
-    input: Vec<AllocatedPtr<F>>,
-    frames: &[Frame],
-    func: &Func,
-    lang: &Lang<F, C>,
-    slots_witnesses: &[Arc<SlotWitness<F>>],
-    num_slots_per_frame: usize,
-) -> Vec<AllocatedPtr<F>> {
-    assert!(cs.is_witness_generator());
-    assert_eq!(frames.len() * num_slots_per_frame, slots_witnesses.len());
-
-    let mut css = frames
-        .par_iter()
-        .enumerate()
-        .map(|(i, frame)| {
-            let mut frame_cs = WitnessCS::new();
-            // The first frame will take as input the actual input of the circuit.
-            // Subsequent frames would have to take the output of the previous one as input.
-            // But since we know the values of each frame input and we are generating the
-            // witnesses separately and in parallel, we will allocate new variables for each
-            // frame.
-            let allocated_input = if i == 0 {
-                input.clone()
-            } else {
-                frame
-                    .input
-                    .iter()
-                    .map(|input_ptr| {
-                        let z_ptr = store.hash_ptr(input_ptr);
-                        AllocatedPtr::alloc(&mut frame_cs, || Ok(z_ptr)).expect("allocation failed")
-                    })
-                    .collect::<Vec<_>>()
-            };
-            let bound_allocations = &mut BoundAllocations::new();
-            func.bind_input(&allocated_input, bound_allocations);
-            let first_sw_idx = i * num_slots_per_frame;
-            let last_sw_idx = first_sw_idx + num_slots_per_frame;
-            let frame_slots_witnesses = &slots_witnesses[first_sw_idx..last_sw_idx];
-
-            let allocated_output = func
-                .synthesize_frame(
-                    &mut frame_cs,
-                    store,
-                    frame,
-                    g,
-                    bound_allocations,
-                    lang,
-                    Some(frame_slots_witnesses),
-                )
-                .expect("failed to synthesize frame");
-            assert_eq!(allocated_input.len(), allocated_output.len());
-            assert_eq_ptrs_aptrs(store, frame.blank, &frame.output, &allocated_output)
-                .expect("assertion failed");
-            (frame_cs, allocated_output)
-        })
-        .collect::<Vec<_>>();
-
-    // At last, we need to concatenate all the partial witnesses into a single witness.
-    // Since we have allocated the input for each frame (apart from the first) instead
-    // of using the output of the previous frame, we will have to ignore the allocated
-    // inputs before concatenating the witnesses
-    for (i, (frame_cs, _)) in css.iter().enumerate() {
-        let start = if i == 0 { 0 } else { input.len() * 2 };
-        cs.extend_aux(&frame_cs.aux_slice()[start..]);
-    }
-
-    if let Some((_, last_output)) = css.pop() {
-        // the final output is the output of the last chunk
-        last_output
-    } else {
-        // there were no frames so we just return the input, preserving the
-        // same behavior as the sequential version
-        input
-    }
-}
-
-/// Pads `frames` up to a certain `size`` with a frame generated with Lurk's step
-/// function. For efficiency, `frames` should have enough capacity to avoid
-/// reallocations
-fn pad_frames<F: LurkField, C: Coprocessor<F>>(
-    frames: &mut Vec<Frame>,
-    input: &[Ptr],
-    lurk_step: &Func,
-    lang: &Lang<F, C>,
-    size: usize,
-    store: &Store<F>,
-) {
-    let padding_frame = lurk_step
-        .call_simple(input, store, lang, 0)
-        .expect("reduction step failed");
-    assert_eq!(padding_frame.pc, 0);
-    assert_eq!(input, padding_frame.output);
-    frames.resize(size, padding_frame);
-}
-
-impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for MultiFrame<'a, F, C> {
-    fn emitted(_store: &Store<F>, eval_frame: &Frame) -> Vec<Ptr> {
-        eval_frame.emitted.clone()
-    }
-
-    fn io_to_scalar_vector(store: &Store<F>, io: &<Frame as FrameLike<Ptr>>::FrameIO) -> Vec<F> {
-        store.to_scalar_vector(io)
-    }
-
-    fn cache_witness(&mut self, s: &Store<F>) -> Result<(), SynthesisError> {
-        let _ = self.cached_witness.get_or_try_init(|| {
-            let mut wcs = WitnessCS::new();
-
-            let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap());
-
-            let mut bogus_cs = WitnessCS::<F>::new();
-            let z: Vec<AllocatedNum<F>> = z_scalar
-                .iter()
-                .map(|x| AllocatedNum::alloc_infallible(&mut bogus_cs, || *x))
-                .collect::<Vec<_>>();
-
-            let output =
-                nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice())?;
-            Ok::<_, SynthesisError>((wcs, output))
-        })?;
-        Ok(())
-    }
-
-    fn output(&self) -> &Option<<Frame as FrameLike<Ptr>>::FrameIO> {
-        &self.output
-    }
-
-    fn frames(&self) -> Option<&Vec<Frame>> {
-        self.frames.as_ref()
-    }
-
-    fn precedes(&self, maybe_next: &Self) -> bool {
-        self.output == maybe_next.input
-    }
-
-    fn synthesize_frames<CS: ConstraintSystem<F>>(
-        &self,
-        cs: &mut CS,
-        store: &Store<F>,
-        input: Vec<AllocatedPtr<F>>,
-        frames: &[Frame],
-        g: &GlobalAllocator<F>,
-    ) -> Result<Vec<AllocatedPtr<F>>, SynthesisError> {
-        let func = self.get_func();
-        if cs.is_witness_generator() {
-            let num_slots_per_frame = func.slots_count.total();
-            let slots_witnesses = generate_slots_witnesses(
-                store,
-                frames,
-                num_slots_per_frame,
-                lurk_config(None, None)
-                    .perf
-                    .parallelism
-                    .poseidon_witnesses
-                    .is_parallel(),
-            );
-            if lurk_config(None, None)
-                .perf
-                .parallelism
-                .synthesis
-                .is_parallel()
-            {
-                Ok(synthesize_frames_parallel(
-                    cs,
-                    g,
-                    store,
-                    input,
-                    frames,
-                    func,
-                    self.get_lang(),
-                    &slots_witnesses,
-                    num_slots_per_frame,
-                ))
-            } else {
-                synthesize_frames_sequential(
-                    cs,
-                    g,
-                    store,
-                    &input,
-                    frames,
-                    func,
-                    self.get_lang(),
-                    Some((&slots_witnesses, num_slots_per_frame)),
-                )
-            }
-        } else {
-            synthesize_frames_sequential(cs, g, store, &input, frames, func, self.get_lang(), None)
-        }
-    }
-
-    fn blank(folding_config: Arc<FoldingConfig<F, C>>, pc: usize) -> Self {
-        let (lurk_step, cprocs, rc) = match &*folding_config {
-            FoldingConfig::IVC(lang, rc) => (
-                Arc::new(make_eval_step_from_config(&EvalConfig::new_ivc(lang))),
-                None,
-                *rc,
-            ),
-            FoldingConfig::NIVC(lang, rc) => (
-                Arc::new(make_eval_step_from_config(&EvalConfig::new_nivc(lang))),
-                Some(make_cprocs_funcs_from_lang(lang).into()),
-                *rc,
-            ),
-        };
-        let num_frames = if pc == 0 { rc } else { 1 };
-        Self {
-            store: None,
-            lurk_step,
-            cprocs,
-            input: None,
-            output: None,
-            frames: None,
-            cached_witness: OnceCell::new(),
-            num_frames,
-            folding_config,
-            pc,
-            next_pc: 0,
-        }
-    }
-
-    fn from_frames(
+    pub fn from_frames(
         frames: &[Frame],
         store: &'a Store<F>,
         folding_config: Arc<FoldingConfig<F, C>>,
@@ -646,7 +332,7 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
         multi_frames
     }

-    fn build_frames(
+    pub fn build_frames(
         expr: Ptr,
         env: Ptr,
         store: &Store<F>,
@@ -668,7 +354,7 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
             .map_err(|e| ProofError::Reduction(ReductionError::Misc(e.to_string())))
     }

-    fn significant_frame_count(frames: &[Frame]) -> usize {
+    pub fn significant_frame_count(frames: &[Frame]) -> usize {
         let stop_cond = |output: &[Ptr]| {
             matches!(
                 output[2].tag(),
@@ -683,6 +369,317 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
     }
 }

+impl CEKState<Ptr> for Vec<Ptr> {
+    fn expr(&self) -> &Ptr {
+        &self[0]
+    }
+    fn env(&self) -> &Ptr {
+        &self[1]
+    }
+    fn cont(&self) -> &Ptr {
+        &self[2]
+    }
+}
+
+impl FrameLike<Ptr> for Frame {
+    type FrameIO = Vec<Ptr>;
+    fn input(&self) -> &Self::FrameIO {
+        &self.input
+    }
+    fn output(&self) -> &Self::FrameIO {
+        &self.output
+    }
+}
+
+impl<F: LurkField> EvaluationStore for Store<F> {
+    type Ptr = Ptr;
+    type Error = anyhow::Error;
+
+    fn read(&self, expr: &str) -> Result<Self::Ptr, Self::Error> {
+        self.read_with_default_state(expr)
+    }
+
+    fn initial_empty_env(&self) -> Self::Ptr {
+        self.intern_nil()
+    }
+
+    fn get_cont_terminal(&self) -> Self::Ptr {
+        self.cont_terminal()
+    }
+
+    fn hydrate_z_cache(&self) {
+        self.hydrate_z_cache()
+    }
+
+    fn ptr_eq(&self, left: &Self::Ptr, right: &Self::Ptr) -> bool {
+        self.ptr_eq(left, right)
+    }
+}
+
+/// Checks that a slice of pointers and a slice of allocated pointers have
+/// the same length. If `!blank`, asserts that the hashed pointers have tags
+/// and values corresponding to the ones from the respective allocated pointers
+fn assert_eq_ptrs_aptrs<F: LurkField>(
+    store: &Store<F>,
+    blank: bool,
+    ptrs: &[Ptr],
+    aptrs: &[AllocatedPtr<F>],
+) -> Result<(), SynthesisError> {
+    assert_eq!(ptrs.len(), aptrs.len());
+    if !blank {
+        for (aptr, ptr) in aptrs.iter().zip(ptrs) {
+            let z_ptr = store.hash_ptr(ptr);
+            let (Some(alloc_ptr_tag), Some(alloc_ptr_hash)) =
+                (aptr.tag().get_value(), aptr.hash().get_value())
+            else {
+                return Err(SynthesisError::AssignmentMissing);
+            };
+            assert_eq!(alloc_ptr_tag, z_ptr.tag().to_field());
+            assert_eq!(&alloc_ptr_hash, z_ptr.value());
+        }
+    }
+    Ok(())
+}
+
+// Hardcoded slot witness sizes, empirically collected
+const BIT_DECOMP_PALLAS_WITNESS_SIZE: usize = 298;
+const BIT_DECOMP_VESTA_WITNESS_SIZE: usize = 301;
+const BIT_DECOMP_BN256_WITNESS_SIZE: usize = 354;
+const BIT_DECOMP_GRUMPKIN_WITNESS_SIZE: usize = 364;
+
+/// Computes the witness size for a `SlotType`. Note that the witness size for
+/// bit decomposition depends on the field we're in.
+#[inline]
+fn compute_witness_size<F: LurkField>(slot_type: &SlotType, store: &Store<F>) -> usize {
+    match slot_type {
+        SlotType::Hash4 => store.hash4_cost() + 4, // 4 preimg elts
+        SlotType::Hash6 => store.hash6_cost() + 6, // 6 preimg elts
+        SlotType::Hash8 => store.hash8_cost() + 8, // 8 preimg elts
+        SlotType::Commitment => store.hash3_cost() + 3, // 3 preimg elts
+        SlotType::BitDecomp => match F::FIELD {
+            LanguageField::Pallas => BIT_DECOMP_PALLAS_WITNESS_SIZE,
+            LanguageField::Vesta => BIT_DECOMP_VESTA_WITNESS_SIZE,
+            LanguageField::BN256 => BIT_DECOMP_BN256_WITNESS_SIZE,
+            LanguageField::Grumpkin => BIT_DECOMP_GRUMPKIN_WITNESS_SIZE,
+        },
+    }
+}
+
+/// Generates the witnesses for all slots in `frames`. Since many slots are fed
+/// with dummy data, we cache their (dummy) witnesses for extra speed
+fn generate_slots_witnesses<F: LurkField>(
+    store: &Store<F>,
+    frames: &[Frame],
+    num_slots_per_frame: usize,
+    parallel: bool,
+) -> Vec<Arc<SlotWitness<F>>> {
+    let mut slots_data = Vec::with_capacity(frames.len() * num_slots_per_frame);
+    for frame in frames.iter() {
+        [
+            (&frame.hints.hash4, SlotType::Hash4),
+            (&frame.hints.hash6, SlotType::Hash6),
+            (&frame.hints.hash8, SlotType::Hash8),
+            (&frame.hints.commitment, SlotType::Commitment),
+            (&frame.hints.bit_decomp, SlotType::BitDecomp),
+        ]
+        .into_iter()
+        .for_each(|(sd_vec, st)| sd_vec.iter().for_each(|sd| slots_data.push((sd, st))));
+    }
+    // precompute these values
+    let hash4_witness_size = compute_witness_size(&SlotType::Hash4, store);
+    let hash6_witness_size = compute_witness_size(&SlotType::Hash6, store);
+    let hash8_witness_size = compute_witness_size(&SlotType::Hash8, store);
+    let commitment_witness_size = compute_witness_size(&SlotType::Commitment, store);
+    let bit_decomp_witness_size = compute_witness_size(&SlotType::BitDecomp, store);
+    // fast getter for the precomputed values
+    let get_witness_size = |slot_type| match slot_type {
+        SlotType::Hash4 => hash4_witness_size,
+        SlotType::Hash6 => hash6_witness_size,
+        SlotType::Hash8 => hash8_witness_size,
+        SlotType::Commitment => commitment_witness_size,
+        SlotType::BitDecomp => bit_decomp_witness_size,
+    };
+    // cache dummy slots witnesses with `Arc` for speedy clones
+    let dummy_witnesses_cache: FrozenMap<_, Box<Arc<SlotWitness<F>>>> = FrozenMap::default();
+    let gen_slot_witness = |(slot_idx, (slot_data, slot_type))| {
+        let mk_witness = || {
+            let mut witness = WitnessCS::with_capacity(1, get_witness_size(slot_type));
+            let allocations = allocate_slot(&mut witness, slot_data, slot_idx, slot_type, store)
+                .expect("slot allocations failed");
+            Arc::new(SlotWitness {
+                witness,
+                allocations,
+            })
+        };
+        if Option::as_ref(slot_data).is_some() {
+            mk_witness()
+        } else {
+            // dummy witness
+            if let Some(sw) = dummy_witnesses_cache.get(&slot_type) {
+                // already computed
+                sw.clone()
+            } else {
+                // compute, cache and return
+                let ws = mk_witness();
+                dummy_witnesses_cache.insert(slot_type, Box::new(ws.clone()));
+                ws
+            }
+        }
+    };
+    if parallel {
+        slots_data
+            .into_par_iter()
+            .enumerate()
+            .map(gen_slot_witness)
+            .collect()
+    } else {
+        slots_data
+            .into_iter()
+            .enumerate()
+            .map(gen_slot_witness)
+            .collect()
+    }
+}
+
+/// Synthesize frames sequentially, feeding the output of a frame as the input of
+/// the next
+fn synthesize_frames_sequential<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
+    cs: &mut CS,
+    g: &GlobalAllocator<F>,
+    store: &Store<F>,
+    input: &[AllocatedPtr<F>],
+    frames: &[Frame],
+    func: &Func,
+    lang: &Lang<F, C>,
+    slots_witnesses_num_slots_per_frame: Option<(&[Arc<SlotWitness<F>>], usize)>,
+) -> Result<Vec<AllocatedPtr<F>>, SynthesisError> {
+    let (_, output) = frames
+        .iter()
+        .try_fold((0, input.to_vec()), |(i, input), frame| {
+            let bound_allocations = &mut BoundAllocations::new();
+            func.bind_input(&input, bound_allocations);
+            let output = func
+                .synthesize_frame(
+                    &mut cs.namespace(|| format!("frame {i}")),
+                    store,
+                    frame,
+                    g,
+                    bound_allocations,
+                    lang,
+                    slots_witnesses_num_slots_per_frame.map(|(sws, num_slots_per_frame)| {
+                        let slots_witnesses_start = i * num_slots_per_frame;
+                        &sws[slots_witnesses_start..slots_witnesses_start + num_slots_per_frame]
+                    }),
+                )
+                .expect("failed to synthesize frame");
+            assert_eq!(input.len(), output.len());
+            assert_eq_ptrs_aptrs(store, frame.blank, &frame.output, &output)?;
+            Ok::<_, SynthesisError>((i + 1, output))
+        })?;
+    Ok(output)
+}
+
+/// Synthesize each frame in parallel, ideally one in each CPU. Each
+/// frame will produce its corresponding partial witness, which are then
+/// used to extend the final witness.
+fn synthesize_frames_parallel<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
+    cs: &mut CS,
+    g: &GlobalAllocator<F>,
+    store: &Store<F>,
+    input: Vec<AllocatedPtr<F>>,
+    frames: &[Frame],
+    func: &Func,
+    lang: &Lang<F, C>,
+    slots_witnesses: &[Arc<SlotWitness<F>>],
+    num_slots_per_frame: usize,
+) -> Vec<AllocatedPtr<F>> {
+    assert!(cs.is_witness_generator());
+    assert_eq!(frames.len() * num_slots_per_frame, slots_witnesses.len());
+
+    let mut css = frames
+        .par_iter()
+        .enumerate()
+        .map(|(i, frame)| {
+            let mut frame_cs = WitnessCS::new();
+            // The first frame will take as input the actual input of the circuit.
+            // Subsequent frames would have to take the output of the previous one as input.
+            // But since we know the values of each frame input and we are generating the
+            // witnesses separately and in parallel, we will allocate new variables for each
+            // frame.
+            let allocated_input = if i == 0 {
+                input.clone()
+            } else {
+                frame
+                    .input
+                    .iter()
+                    .map(|input_ptr| {
+                        let z_ptr = store.hash_ptr(input_ptr);
+                        AllocatedPtr::alloc(&mut frame_cs, || Ok(z_ptr)).expect("allocation failed")
+                    })
+                    .collect::<Vec<_>>()
+            };
+            let bound_allocations = &mut BoundAllocations::new();
+            func.bind_input(&allocated_input, bound_allocations);
+            let first_sw_idx = i * num_slots_per_frame;
+            let last_sw_idx = first_sw_idx + num_slots_per_frame;
+            let frame_slots_witnesses = &slots_witnesses[first_sw_idx..last_sw_idx];
+
+            let allocated_output = func
+                .synthesize_frame(
+                    &mut frame_cs,
+                    store,
+                    frame,
+                    g,
+                    bound_allocations,
+                    lang,
+                    Some(frame_slots_witnesses),
+                )
+                .expect("failed to synthesize frame");
+            assert_eq!(allocated_input.len(), allocated_output.len());
+            assert_eq_ptrs_aptrs(store, frame.blank, &frame.output, &allocated_output)
+                .expect("assertion failed");
+            (frame_cs, allocated_output)
+        })
+        .collect::<Vec<_>>();
+
+    // At last, we need to concatenate all the partial witnesses into a single witness.
+    // Since we have allocated the input for each frame (apart from the first) instead
+    // of using the output of the previous frame, we will have to ignore the allocated
+    // inputs before concatenating the witnesses
+    for (i, (frame_cs, _)) in css.iter().enumerate() {
+        let start = if i == 0 { 0 } else { input.len() * 2 };
+        cs.extend_aux(&frame_cs.aux_slice()[start..]);
+    }
+
+    if let Some((_, last_output)) = css.pop() {
+        // the final output is the output of the last chunk
+        last_output
+    } else {
+        // there were no frames so we just return the input, preserving the
+        // same behavior as the sequential version
+        input
+    }
+}
+
+/// Pads `frames` up to a certain `size` with a frame generated with Lurk's step
+/// function. For efficiency, `frames` should have enough capacity to avoid
+/// reallocations
+fn pad_frames<F: LurkField, C: Coprocessor<F>>(
+    frames: &mut Vec<Frame>,
+    input: &[Ptr],
+    lurk_step: &Func,
+    lang: &Lang<F, C>,
+    size: usize,
+    store: &Store<F>,
+) {
+    let padding_frame = lurk_step
+        .call_simple(input, store, lang, 0)
+        .expect("reduction step failed");
+    assert_eq!(padding_frame.pc, 0);
+    assert_eq!(input, padding_frame.output);
+    frames.resize(size, padding_frame);
+}
+
 impl<'a, F: LurkField, C: Coprocessor<F>> Circuit<F> for MultiFrame<'a, F, C> {
     fn synthesize<CS: ConstraintSystem<F>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
         let mut synth = |store: &Store<F>, frames: &[Frame], input: &[Ptr], output: &[Ptr]| {
diff --git a/src/proof/mod.rs b/src/proof/mod.rs
index 1483b7dead..d1daa03861 100644
--- a/src/proof/mod.rs
+++ b/src/proof/mod.rs
@@ -15,19 +15,15 @@ pub mod supernova;
 #[cfg(test)]
 mod tests;

-use ::nova::traits::{circuit::StepCircuit, Engine};
-use bellpepper_core::{Circuit, ConstraintSystem, SynthesisError};
+use ::nova::traits::Engine;
 use std::sync::Arc;

 use crate::{
-    circuit::gadgets::pointer::AllocatedPtr,
     coprocessor::Coprocessor,
     error::ProofError,
     eval::lang::Lang,
     field::LurkField,
-    lem::{
-        circuit::GlobalAllocator, eval::EvalConfig, interpreter::Frame, pointers::Ptr, store::Store,
-    },
+    lem::{eval::EvalConfig, interpreter::Frame, pointers::Ptr, store::Store},
     proof::nova::E2,
 };
@@ -77,63 +73,6 @@ pub trait EvaluationStore {
     fn ptr_eq(&self, left: &Self::Ptr, right: &Self::Ptr) -> bool;
 }

-/// Trait to support multiple `MultiFrame` implementations.
-pub trait MultiFrameTrait<'a, F: LurkField, C: Coprocessor<F> + 'a>:
-    Provable<F> + Circuit<F> + StepCircuit<F> + 'a
-{
-    /// the emitted frames
-    fn emitted(store: &Store<F>, eval_frame: &Frame) -> Vec<Ptr>;
-
-    /// Counting the number of non-trivial frames in the evaluation
-    fn significant_frame_count(frames: &[Frame]) -> usize;
-
-    /// Evaluates and generates the frames of the computation given the expression, environment, and store
-    fn build_frames(
-        expr: Ptr,
-        env: Ptr,
-        store: &Store<F>,
-        limit: usize,
-        ec: &EvalConfig<'_, F, C>,
-    ) -> Result<Vec<Frame>, ProofError>;
-
-    /// Returns a public IO vector when equipped with the local store, and the Self::Frame's IO
-    fn io_to_scalar_vector(store: &Store<F>, io: &<Frame as FrameLike<Ptr>>::FrameIO) -> Vec<F>;
-
-    /// Returns true if the supplied instance directly precedes this one in a sequential computation trace.
-    fn precedes(&self, maybe_next: &Self) -> bool;
-
-    /// Cache the witness internally, which can be used later during synthesis.
-    /// This function can be called in parallel to speed up the witness generation
-    /// for a series of `MultiFrameTrait` instances.
-    fn cache_witness(&mut self, s: &Store<F>) -> Result<(), SynthesisError>;
-
-    /// The output of the last frame
-    fn output(&self) -> &Option<<Frame as FrameLike<Ptr>>::FrameIO>;
-
-    /// Iterates through the Self::CircuitFrame instances
-    fn frames(&self) -> Option<&Vec<Frame>>;
-
-    /// Synthesize some frames.
-    fn synthesize_frames<CS: ConstraintSystem<F>>(
-        &self,
-        cs: &mut CS,
-        store: &Store<F>,
-        input: Vec<AllocatedPtr<F>>,
-        frames: &[Frame],
-        g: &GlobalAllocator<F>,
-    ) -> Result<Vec<AllocatedPtr<F>>, SynthesisError>;
-
-    /// Synthesize a blank circuit.
-    fn blank(folding_config: Arc<FoldingConfig<F, C>>, pc: usize) -> Self;
-
-    /// Create an instance from some `Self::Frame`s.
-    fn from_frames(
-        frames: &[Frame],
-        store: &'a Store<F>,
-        folding_config: Arc<FoldingConfig<F, C>>,
-    ) -> Vec<Self>;
-}
-
 /// A trait for provable structures over a field `F`.
 pub trait Provable<F: LurkField> {
     /// Returns the public inputs of the provable structure.
@@ -240,8 +179,8 @@ pub trait Prover<'a, F: CurveCycleEquipped, C: Coprocessor<F> + 'a> {
         store: &'a Store<F>,
     ) -> Result<(Self::RecursiveSnark, Vec<F>, Vec<F>, usize), ProofError> {
         store.hydrate_z_cache();
-        let z0 = C1LEM::<'a, F, C>::io_to_scalar_vector(store, frames[0].input());
-        let zi = C1LEM::<'a, F, C>::io_to_scalar_vector(store, frames.last().unwrap().output());
+        let z0 = store.to_scalar_vector(frames[0].input());
+        let zi = store.to_scalar_vector(frames.last().unwrap().output());
         let lang = self.lang().clone();

         let folding_config = self
diff --git a/src/proof/nova.rs b/src/proof/nova.rs
index 03219349d6..afbf9561f1 100644
--- a/src/proof/nova.rs
+++ b/src/proof/nova.rs
@@ -29,7 +29,7 @@ use crate::{
     eval::lang::Lang,
     field::LurkField,
     lem::store::Store,
-    proof::{supernova::FoldingConfig, FrameLike, MultiFrameTrait, Prover},
+    proof::{supernova::FoldingConfig, FrameLike, Prover},
 };

 use super::{FoldingMode, RecursiveSNARKTrait};
@@ -315,7 +315,7 @@ where

         // This is a CircuitFrame, not an EvalFrame
         let first_frame = circuit_primary.frames().unwrap().iter().next().unwrap();
-        let zi = C1LEM::<_, C>::io_to_scalar_vector(store, first_frame.input());
+        let zi = store.to_scalar_vector(first_frame.input());
         let zi_allocated: Vec<_> = zi
             .iter()
             .enumerate()
diff --git a/src/proof/supernova.rs b/src/proof/supernova.rs
index dfc83df48e..192ce853ff 100644
--- a/src/proof/supernova.rs
+++ b/src/proof/supernova.rs
@@ -27,7 +27,7 @@ use crate::{
     lem::store::Store,
     proof::{
         nova::{CurveCycleEquipped, NovaCircuitShape, E1, E2},
-        RecursiveSNARKTrait, {MultiFrameTrait, Prover},
+        Prover, RecursiveSNARKTrait,
     },
 };
diff --git a/src/proof/tests/mod.rs b/src/proof/tests/mod.rs
index 1a54d340a8..abbc388a8f 100644
--- a/src/proof/tests/mod.rs
+++ b/src/proof/tests/mod.rs
@@ -14,7 +14,7 @@ use crate::{
     proof::{
         nova::{public_params, CurveCycleEquipped, NovaProver, C1LEM, E1, E2},
         supernova::FoldingConfig,
-        CEKState, EvaluationStore, MultiFrameTrait, Provable, Prover, RecursiveSNARKTrait,
+        CEKState, EvaluationStore, Provable, Prover, RecursiveSNARKTrait,
     },
 };
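--
For reference, a minimal sketch of how the resulting inherent API reads at a
call site (not part of the applied diff; assumes `frames`, `store`, and an
`Arc`-wrapped `folding_config` are already in scope, and elides error
handling):

    // build circuit multiframes directly on `MultiFrame`, no trait import needed
    let multi_frames = MultiFrame::from_frames(&frames, &store, folding_config);
    // consecutive multiframes must chain: each output feeds the next input
    assert!(multi_frames.windows(2).all(|w| w[0].precedes(&w[1])));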