From d65a64fb9715c0424728fa33f4b67e9dccf196d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= <4142+huitseeker@users.noreply.github.com> Date: Thu, 29 Jun 2023 16:29:16 -0400 Subject: [PATCH] Unsafe Serialization for Faster Public Parameter Caching (#474) * checkpoint3 * fix missing args, todo unhack later * add timing features, mmap attempt * fix * attempt to use closures * attempt to use closures * utlra fast param decoding * match APIs, bugfix * cargo fmt * add a Clone * refactor: Enhance error handling in `get_with_timing` method - Refactored `get_with_timing` method in `src/public_parameters/file_map.rs` for more streamlined functionality * chore: clippy --------- Co-authored-by: Hanting Zhang --- Cargo.lock | 33 ++++++----------- Cargo.toml | 5 +++ benches/end2end.rs | 12 ++++--- benches/fibonacci.rs | 4 +-- clutch/src/lib.rs | 6 ++-- examples/sha256.rs | 60 ++++++++++++++++--------------- fcomm/src/bin/fcomm.rs | 6 ++-- fcomm/src/lib.rs | 2 +- src/cli/lurk_proof.rs | 2 +- src/cli/repl.rs | 2 +- src/hash.rs | 15 ++++---- src/proof/nova.rs | 38 ++++++++++++++++---- src/public_parameters/file_map.rs | 47 ++++++++++++++++++++++-- src/public_parameters/mod.rs | 59 +++++++++++++++++++++++++++--- src/public_parameters/registry.rs | 55 ++++++++++++++++++++-------- 15 files changed, 243 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2f1546dfc1..7305261b50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -608,7 +608,7 @@ dependencies = [ "fcomm", "ff", "lurk", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "pretty_env_logger", "serde", ] @@ -977,7 +977,7 @@ dependencies = [ "num_cpus", "once_cell", "pairing", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "predicates 2.1.5", "pretty_env_logger", "proptest", @@ -1159,7 +1159,7 @@ dependencies = [ "lazy_static", "num-bigint 0.4.3", "num-traits", - "pasta_curves 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "pasta_curves", "paste", "rand", "rand_core", @@ -1431,6 +1431,8 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" name = "lurk" version = "0.2.0" dependencies = [ + "abomonation", + "abomonation_derive", "ahash 0.7.6", "anyhow", "anymap", @@ -1466,7 +1468,7 @@ dependencies = [ "once_cell", "pairing", "pasta-msm", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "peekmore", "pprof", "pretty_env_logger", @@ -1496,7 +1498,7 @@ version = "0.1.0" dependencies = [ "bincode", "lurk", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "proc-macro2 1.0.66", "proptest", "proptest-derive", @@ -1561,7 +1563,7 @@ dependencies = [ "generic-array", "itertools 0.8.2", "log", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "serde", "trait-set", ] @@ -1628,7 +1630,7 @@ dependencies = [ "num-integer", "num-traits", "pasta-msm", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "rand_chacha", "rand_core", "rayon", @@ -1796,27 +1798,12 @@ version = "0.1.4" source = "git+https://github.com/lurk-lab/pasta-msm?branch=dev#182b971dd0f6dcc1a9a6bd5db8646bdd4600ed7e" dependencies = [ "cc", - "pasta_curves 0.5.1 (git+https://github.com/lurk-lab/pasta_curves?branch=dev)", + "pasta_curves", "semolina", "sppark", "which", ] -[[package]] -name = "pasta_curves" -version = "0.5.1" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" -dependencies = [ - "blake2b_simd", - "ff", - "group", - "lazy_static", - "rand", - "static_assertions", - "subtle", -] - [[package]] name = "pasta_curves" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 35a6724e4d..350ec58bca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,8 @@ serde_repr = "0.1.14" tap = "1.0.1" stable_deref_trait = "1.2.0" thiserror = { workspace = true } +abomonation = "0.7.3" +abomonation_derive = { git = "https://github.com/winston-h-zhang/abomonation_derive.git" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] memmap = { version = "0.5.10", package = "memmap2" } @@ -152,3 +154,6 @@ harness = false [patch.crates-io] sppark = { git = "https://github.com/supranational/sppark", rev="5fea26f43cc5d12a77776c70815e7c722fd1f8a7" } +# This is needed to ensure halo2curves, which imports pasta-curves, uses the *same* traits in bn256_grumpkin +pasta_curves = { git="https://github.com/lurk-lab/pasta_curves", branch="dev" } + diff --git a/benches/end2end.rs b/benches/end2end.rs index 5d22a81f60..ea4f4e2f5d 100644 --- a/benches/end2end.rs +++ b/benches/end2end.rs @@ -61,7 +61,8 @@ fn end2end_benchmark(c: &mut Criterion) { let prover = NovaProver::new(reduction_count, lang_pallas); // use cached public params - let pp = public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap(); + let pp = + public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()).unwrap(); let size = (10, 0); let benchmark_id = BenchmarkId::new("end2end_go_base_nova", format!("_{}_{}", size.0, size.1)); @@ -265,7 +266,8 @@ fn prove_benchmark(c: &mut Criterion) { group.bench_with_input(benchmark_id, &size, |b, &s| { let ptr = go_base::(&mut store, s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas.clone()); - let pp = public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap(); + let pp = public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()) + .unwrap(); let frames = prover .get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas) .unwrap(); @@ -343,7 +345,8 @@ fn verify_benchmark(c: &mut Criterion) { let ptr = go_base(&mut store, s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas.clone()); let pp = - public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap(); + public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()) + .unwrap(); let frames = prover .get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas) .unwrap(); @@ -388,7 +391,8 @@ fn verify_compressed_benchmark(c: &mut Criterion) { let ptr = go_base(&mut store, s.0, s.1); let prover = NovaProver::new(reduction_count, lang_pallas.clone()); let pp = - public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap(); + public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()) + .unwrap(); let frames = prover .get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas) .unwrap(); diff --git a/benches/fibonacci.rs b/benches/fibonacci.rs index f05c7732d7..bc2130e9ff 100644 --- a/benches/fibonacci.rs +++ b/benches/fibonacci.rs @@ -48,7 +48,7 @@ fn fibo_total(name: &str, iterations: u64, c: &mut let reduction_count = DEFAULT_REDUCTION_COUNT; // use cached public params - let pp = 
public_params(reduction_count, lang_rc.clone()).unwrap(); + let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap(); c.bench_with_input( BenchmarkId::new(name.to_string(), iterations), @@ -99,7 +99,7 @@ fn fibo_prove(name: &str, iterations: u64, c: &mut let lang_pallas = Lang::>::new(); let lang_rc = Arc::new(lang_pallas.clone()); let reduction_count = DEFAULT_REDUCTION_COUNT; - let pp = public_params(reduction_count, lang_rc.clone()).unwrap(); + let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap(); c.bench_with_input( BenchmarkId::new(name.to_string(), iterations), diff --git a/clutch/src/lib.rs b/clutch/src/lib.rs index e70d74b697..fa9943bcc3 100644 --- a/clutch/src/lib.rs +++ b/clutch/src/lib.rs @@ -137,7 +137,7 @@ impl ReplTrait> for ClutchState> { let lang_rc = Arc::new(lang.clone()); // Load params from disk cache, or generate them in the background. - thread::spawn(move || public_params(reduction_count, lang_rc)); + thread::spawn(move || public_params(reduction_count, true, lang_rc)); Self { repl_state: ReplState::new(s, limit, command, lang), @@ -497,7 +497,7 @@ impl ClutchState> { let (proof_in_expr, _rest1) = store.car_cdr(&rest)?; let prover = NovaProver::>::new(self.reduction_count, (*self.lang()).clone()); - let pp = public_params(self.reduction_count, self.lang())?; + let pp = public_params(self.reduction_count, true, self.lang())?; let proof = if rest.is_nil() { self.last_claim @@ -556,7 +556,7 @@ impl ClutchState> { .get(&zptr_string) .ok_or_else(|| anyhow!("proof not found: {zptr_string}"))?; - let pp = public_params(self.reduction_count, self.lang())?; + let pp = public_params(self.reduction_count, true, self.lang())?; let result = proof.verify(&pp, &self.lang()).unwrap(); if result.verified { diff --git a/examples/sha256.rs b/examples/sha256.rs index 849082eced..b4a3539a21 100644 --- a/examples/sha256.rs +++ b/examples/sha256.rs @@ -11,7 +11,7 @@ use lurk::eval::{empty_sym_env, lang::Lang}; use lurk::field::LurkField; use lurk::proof::{nova::NovaProver, Prover}; use lurk::ptr::Ptr; -use lurk::public_parameters::public_params; +use lurk::public_parameters::with_public_params; use lurk::store::Store; use lurk::sym; use lurk_macros::Coproc; @@ -26,7 +26,7 @@ use pasta_curves::pallas::Scalar as Fr; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -const REDUCTION_COUNT: usize = 10; +const REDUCTION_COUNT: usize = 100; #[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct Sha256Coprocessor { @@ -179,47 +179,49 @@ fn main() { Sha256Coprocessor::new(input_size, u).into(), )], ); + let lang_rc = Arc::new(lang.clone()); let coproc_expr = format!("({})", sym_str); let ptr = store.read(&coproc_expr).unwrap(); let nova_prover = NovaProver::>::new(REDUCTION_COUNT, lang.clone()); - let lang_rc = Arc::new(lang); - println!("Setting up public parameters..."); + println!("Setting up public parameters (rc = {REDUCTION_COUNT})..."); let pp_start = Instant::now(); - let pp = public_params::>(REDUCTION_COUNT, lang_rc.clone()).unwrap(); - let pp_end = pp_start.elapsed(); - println!("Public parameters took {:?}", pp_end); + // see the documentation on `with_public_params` + let _res = with_public_params(REDUCTION_COUNT, lang_rc.clone(), |pp| { + let pp_end = pp_start.elapsed(); + println!("Public parameters took {:?}", pp_end); - if setup_only { - return; - } - - println!("Beginning proof step..."); + if setup_only { + return; + } - let proof_start = Instant::now(); - let (proof, z0, zi, num_steps) = nova_prover - 
.evaluate_and_prove(&pp, ptr, empty_sym_env(store), store, 10000, lang_rc) - .unwrap(); - let proof_end = proof_start.elapsed(); + println!("Beginning proof step..."); + let proof_start = Instant::now(); + let (proof, z0, zi, num_steps) = nova_prover + .evaluate_and_prove(pp, ptr, empty_sym_env(store), store, 10000, lang_rc) + .unwrap(); + let proof_end = proof_start.elapsed(); - println!("Proofs took {:?}", proof_end); + println!("Proofs took {:?}", proof_end); - println!("Verifying proof..."); + println!("Verifying proof..."); - let verify_start = Instant::now(); - let res = proof.verify(&pp, num_steps, &z0, &zi).unwrap(); - let verify_end = verify_start.elapsed(); + let verify_start = Instant::now(); + let res = proof.verify(&pp, num_steps, &z0, &zi).unwrap(); + let verify_end = verify_start.elapsed(); - println!("Verify took {:?}", verify_end); + println!("Verify took {:?}", verify_end); - if res { - println!( - "Congratulations! You proved and verified a SHA256 hash calculation in {:?} time!", - pp_end + proof_end + verify_end - ); - } + if res { + println!( + "Congratulations! You proved and verified a SHA256 hash calculation in {:?} time!", + pp_end + proof_end + verify_end + ); + } + }) + .unwrap(); } diff --git a/fcomm/src/bin/fcomm.rs b/fcomm/src/bin/fcomm.rs index d217e5315a..1c8013a9e8 100644 --- a/fcomm/src/bin/fcomm.rs +++ b/fcomm/src/bin/fcomm.rs @@ -228,7 +228,7 @@ impl Open { let rc = ReductionCount::try_from(self.reduction_count).expect("reduction count"); let prover = NovaProver::>::new(rc.count(), lang.clone()); let lang_rc = Arc::new(lang.clone()); - let pp = public_params(rc.count(), lang_rc).expect("public params"); + let pp = public_params(rc.count(), true, lang_rc).expect("public params"); let function_map = committed_expression_store(); let handle_proof = |out_path, proof: Proof| { @@ -332,7 +332,7 @@ impl Prove { let rc = ReductionCount::try_from(self.reduction_count).unwrap(); let prover = NovaProver::>::new(rc.count(), lang.clone()); let lang_rc = Arc::new(lang.clone()); - let pp = public_params(rc.count(), lang_rc.clone()).unwrap(); + let pp = public_params(rc.count(), true, lang_rc.clone()).unwrap(); let proof = match &self.claim { Some(claim) => { @@ -378,7 +378,7 @@ impl Verify { fn verify(&self, cli_error: bool, lang: &Lang>) { let proof = proof(Some(&self.proof)).unwrap(); let lang_rc = Arc::new(lang.clone()); - let pp = public_params(proof.reduction_count.count(), lang_rc).unwrap(); + let pp = public_params(proof.reduction_count.count(), true, lang_rc).unwrap(); let result = proof.verify(&pp, lang).unwrap(); serde_json::to_writer(io::stdout(), &result).unwrap(); diff --git a/fcomm/src/lib.rs b/fcomm/src/lib.rs index be22c7d343..f52f389ca9 100644 --- a/fcomm/src/lib.rs +++ b/fcomm/src/lib.rs @@ -1135,7 +1135,7 @@ mod test { let lang = Lang::new(); let lang_rc = Arc::new(lang.clone()); let rc = ReductionCount::One; - let pp = public_params(rc.count(), lang_rc.clone()).expect("public params"); + let pp = public_params(rc.count(), true, lang_rc.clone()).expect("public params"); let chained = true; let s = &mut Store::::default(); diff --git a/src/cli/lurk_proof.rs b/src/cli/lurk_proof.rs index d9c5fab6a4..4c8f9bc432 100644 --- a/src/cli/lurk_proof.rs +++ b/src/cli/lurk_proof.rs @@ -106,7 +106,7 @@ mod non_wasm { lang, } => { log::info!("Loading public parameters"); - let pp = public_params(rc, std::sync::Arc::new(lang))?; + let pp = public_params(rc, true, std::sync::Arc::new(lang))?; Ok(proof.verify(&pp, num_steps, &public_inputs, &public_outputs)?) 
} } diff --git a/src/cli/repl.rs b/src/cli/repl.rs index 2427eeb294..abbfec8cfd 100644 --- a/src/cli/repl.rs +++ b/src/cli/repl.rs @@ -237,7 +237,7 @@ impl Repl { } info!("Loading public parameters"); - let pp = public_params(self.rc, self.lang.clone())?; + let pp = public_params(self.rc, true, self.lang.clone())?; let prover = NovaProver::new(self.rc, (*self.lang).clone()); diff --git a/src/hash.rs b/src/hash.rs index 6de8e89f21..4816712ff3 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::hash::Hash; +use std::sync::Arc; use crate::cache_map::CacheMap; use crate::field::{FWrap, LurkField}; @@ -36,7 +37,7 @@ pub enum HashConst<'a, F: LurkField> { } /// Holds the constants needed for poseidon hashing. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct HashConstants { c3: OnceCell>, c4: OnceCell>, @@ -82,12 +83,12 @@ impl HashConstants { } } -#[derive(Default, Debug)] +#[derive(Clone, Default, Debug)] pub struct PoseidonCache { - a3: CacheMap, F>, - a4: CacheMap, F>, - a6: CacheMap, F>, - a8: CacheMap, F>, + a3: Arc, F>>, + a4: Arc, F>>, + a6: Arc, F>>, + a8: Arc, F>>, pub constants: HashConstants, } @@ -112,7 +113,7 @@ impl PoseidonCache { } } -#[derive(Default, Debug)] +#[derive(Clone, Default, Debug)] pub struct InversePoseidonCache { a3: HashMap, [F; 3]>, a4: HashMap, [F; 4]>, diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 52d7da9483..13ce88fc8b 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; +use abomonation::Abomonation; use bellperson::{gadgets::num::AllocatedNum, ConstraintSystem, SynthesisError}; use nova::{ @@ -47,9 +48,9 @@ pub type EE1 = nova::provider::ipa_pc::EvaluationEngine; pub type EE2 = nova::provider::ipa_pc::EvaluationEngine; /// Type alias for the Relaxed R1CS Spartan SNARK using G1 group elements, EE1. -pub type SS1 = nova::spartan::RelaxedR1CSSNARK; +pub type SS1 = nova::spartan::snark::RelaxedR1CSSNARK; /// Type alias for the Relaxed R1CS Spartan SNARK using G2 group elements, EE2. -pub type SS2 = nova::spartan::RelaxedR1CSSNARK; +pub type SS2 = nova::spartan::snark::RelaxedR1CSSNARK; /// Type alias for a MultiFrame with S1 field elements. pub type C1<'a, C> = MultiFrame<'a, S1, IO, Witness, C>; @@ -60,12 +61,35 @@ pub type C2 = TrivialTestCircuit<::Scalar>; pub type NovaPublicParams<'a, C> = nova::PublicParams, C2>; /// A struct that contains public parameters for the Nova proving system. -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] -pub struct PublicParams<'a, C: Coprocessor> { - pp: NovaPublicParams<'a, C>, - pk: ProverKey, C2, SS1, SS2>, - vk: VerifierKey, C2, SS1, SS2>, +pub struct PublicParams<'c, C: Coprocessor> { + pp: NovaPublicParams<'c, C>, + pk: ProverKey, C2, SS1, SS2>, + vk: VerifierKey, C2, SS1, SS2>, +} + +impl<'c, C: Coprocessor> Abomonation for PublicParams<'c, C> { + unsafe fn entomb(&self, bytes: &mut W) -> std::io::Result<()> { + self.pp.entomb(bytes)?; + self.pk.entomb(bytes)?; + self.vk.entomb(bytes)?; + Ok(()) + } + + unsafe fn exhume<'b>(&mut self, mut bytes: &'b mut [u8]) -> Option<&'b mut [u8]> { + let temp = bytes; + bytes = self.pp.exhume(temp)?; + let temp = bytes; + bytes = self.pk.exhume(temp)?; + let temp = bytes; + bytes = self.vk.exhume(temp)?; + Some(bytes) + } + + fn extent(&self) -> usize { + self.pp.extent() + self.pk.extent() + self.vk.extent() + } } /// An enum representing the two types of proofs that can be generated and verified. 
diff --git a/src/public_parameters/file_map.rs b/src/public_parameters/file_map.rs index 8bff103f13..6bfc9267bf 100644 --- a/src/public_parameters/file_map.rs +++ b/src/public_parameters/file_map.rs @@ -1,8 +1,13 @@ -use std::fs::create_dir_all; -use std::io::Error; +use std::fs::{create_dir_all, File}; +use std::io::{BufReader, Read}; use std::marker::PhantomData; use std::path::{Path, PathBuf}; +use std::time::Instant; +use tap::TapFallible; +use abomonation::{encode, Abomonation}; + +use crate::public_parameters::error::Error; use crate::public_parameters::FileStore; pub(crate) fn data_dir() -> PathBuf { @@ -34,12 +39,48 @@ impl FileIndex { } pub(crate) fn get(&self, key: &K) -> Option { - self.key_path(key); V::read_from_path(self.key_path(key)).ok() } + pub(crate) fn get_raw_bytes(&self, key: &K) -> Result, Error> { + let file = File::open(self.key_path(key))?; + let mut reader = BufReader::new(file); + let mut bytes = Vec::new(); + reader.read_to_end(&mut bytes)?; + Ok(bytes) + } + + #[allow(dead_code)] + pub(crate) fn get_with_timing(&self, key: &K, discr: &String) -> Option { + let start = Instant::now(); + let result = V::read_from_path(self.key_path(key)).tap_err(|e| eprintln!("{e}")); + let end = start.elapsed(); + eprintln!("Reading {discr} from disk-cache in {:?}", end); + result.ok() + } + pub(crate) fn set(&self, key: &K, data: &V) -> Result<(), Error> { + data.write_to_path(self.key_path(&key)); + Ok(()) + } + + pub(crate) fn set_abomonated(&self, key: &K, data: &V) -> Result<(), Error> { + let mut file = File::create(self.key_path(key))?; + unsafe { encode(data, &mut file).expect("failed to encode") }; + Ok(()) + } + + #[allow(dead_code)] + pub(crate) fn set_with_timing( + &self, + key: &K, + data: &V, + discr: &String, + ) -> Result<(), Error> { + let start = Instant::now(); data.write_to_path(self.key_path(key)); + let end = start.elapsed(); + eprintln!("Writing {discr} to disk-cache in {:?}", end); Ok(()) } } diff --git a/src/public_parameters/mod.rs b/src/public_parameters/mod.rs index 095f40f4ff..312f2d1cfc 100644 --- a/src/public_parameters/mod.rs +++ b/src/public_parameters/mod.rs @@ -1,3 +1,4 @@ +use abomonation::decode; use std::fs::File; use std::io::{self, BufReader, BufWriter}; use std::path::Path; @@ -21,10 +22,61 @@ pub type S1 = pallas::Scalar; pub fn public_params + 'static>( rc: usize, + abomonated: bool, lang: Arc>, -) -> Result>, Error> { +) -> Result>, Error> +where + C: Coprocessor + 'static, +{ let f = |lang: Arc>| Arc::new(nova::public_params(rc, lang)); - registry::CACHE_REG.get_coprocessor_or_update_with(rc, f, lang) + registry::CACHE_REG.get_coprocessor_or_update_with(rc, abomonated, f, lang) +} + +/// Attempts to extract abomonated public parameters. +/// To avoid all copying overhead, we zerocopy all of the data within the file; +/// this leads to extremely high performance, but restricts the lifetime of the data +/// to the lifetime of the file. Thus, we cannot pass a reference out and must +/// rely on a closure to capture the data and continue the computation in `bind`. +pub fn with_public_params(rc: usize, lang: Arc>, bind: F) -> Result +where + C: Coprocessor + 'static, + F: FnOnce(&PublicParams<'static, C>) -> T, +{ + let disk_cache = file_map::FileIndex::new("public_params").unwrap(); + // use the cached language key + let lang_key = lang.key(); + // Sanity-check: we're about to use a lang-dependent disk cache, which should be specialized + // for this lang/coprocessor. 
+ let key = format!("public-params-rc-{rc}-coproc-{lang_key}-abomonated"); + + match disk_cache.get_raw_bytes(&key) { + Ok(mut bytes) => { + if let Some((pp, remaining)) = unsafe { decode(&mut bytes) } { + assert!(remaining.is_empty()); + eprintln!("Using disk-cached public params for lang {}", lang_key); + Ok(bind(pp)) + } else { + eprintln!("failed to decode bytes"); + let pp = nova::public_params(rc, lang); + let mut bytes = Vec::new(); + unsafe { abomonation::encode(&pp, &mut bytes)? }; + // maybe just directly write + disk_cache + .set_abomonated(&key, &pp) + .map_err(|e| Error::CacheError(format!("Disk write error: {e}")))?; + Ok(bind(&pp)) + } + } + Err(e) => { + eprintln!("{e}"); + let pp = nova::public_params(rc, lang); + // maybe just directly write + disk_cache + .set_abomonated(&key, &pp) + .map_err(|e| Error::CacheError(format!("Disk write error: {e}")))?; + Ok(bind(&pp)) + } + } } pub trait FileStore where @@ -58,8 +110,7 @@ where fn read_from_path>(path: P) -> Result { let file = File::open(path)?; let reader = BufReader::new(file); - bincode::deserialize_from(reader) - .map_err(|e| Error::CacheError(format!("Cache deserialization error: {}", e))) + bincode::deserialize_from(reader).map_err(|e| Error::CacheError(format!("{}", e))) } fn read_from_json_path>(path: P) -> Result { diff --git a/src/public_parameters/registry.rs b/src/public_parameters/registry.rs index e4c45b5ba6..ea7463ad42 100644 --- a/src/public_parameters/registry.rs +++ b/src/public_parameters/registry.rs @@ -3,6 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; +use abomonation::decode; use log::info; use once_cell::sync::Lazy; use pasta_curves::pallas; @@ -15,7 +16,7 @@ use super::file_map::FileIndex; type S1 = pallas::Scalar; type AnyMap = anymap::Map; -type PublicParamMemCache = HashMap>>; +type PublicParamMemCache = HashMap<(usize, bool), Arc>>; /// This is a global registry for Coproc-specific parameters. /// It is used to cache parameters for each Coproc, so that they are not @@ -38,6 +39,7 @@ impl Registry { >( &'static self, rc: usize, + abomonated: bool, default: F, lang: Arc>, ) -> Result>, Error> { @@ -45,20 +47,42 @@ impl Registry { let disk_cache = FileIndex::new("public_params").unwrap(); // use the cached language key let lang_key = lang.key(); + let quick_suffix = if abomonated { "-abomonated" } else { "" }; // Sanity-check: we're about to use a lang-dependent disk cache, which should be specialized // for this lang/coprocessor. 
- let key = format!("public-params-rc-{rc}-coproc-{lang_key}"); - // read the file if it exists, otherwise initialize - if let Some(pp) = disk_cache.get::>(&key) { - info!("Using disk-cached public params for lang {lang_key}"); - Ok(Arc::new(pp)) + let key = format!("public-params-rc-{rc}-coproc-{lang_key}{quick_suffix}"); + if abomonated { + match disk_cache.get_raw_bytes(&key) { + Ok(mut bytes) => { + info!("Using abomonated public params for lang {lang_key}"); + let (pp, rest) = unsafe { decode::>(&mut bytes).unwrap() }; + assert!(rest.is_empty()); + Ok(Arc::new(pp.clone())) // this clone is VERY expensive + } + Err(e) => { + eprintln!("{e}"); + let pp = default(lang); + // maybe just directly write + disk_cache + .set_abomonated(&key, &*pp) + .tap_ok(|_| info!("Writing public params to disk-cache: {}", lang_key)) + .map_err(|e| Error::CacheError(format!("Disk write error: {e}")))?; + Ok(pp) + } + } } else { - let pp = default(lang); - disk_cache - .set(&key, &*pp) - .tap_ok(|_| info!("Writing public params to disk-cache for lang {lang_key}")) - .map_err(|e| Error::CacheError(format!("Disk write error: {e}")))?; - Ok(pp) + // read the file if it exists, otherwise initialize + if let Some(pp) = disk_cache.get::>(&key) { + info!("Using disk-cached public params for lang {lang_key}"); + Ok(Arc::new(pp)) + } else { + let pp = default(lang); + disk_cache + .set(&key, &*pp) + .tap_ok(|_| info!("Writing public params to disk-cache: {}", lang_key)) + .map_err(|e| Error::CacheError(format!("Disk write error: {e}")))?; + Ok(pp) + } } } @@ -70,6 +94,7 @@ impl Registry { >( &'static self, rc: usize, + abomonated: bool, default: F, lang: Arc>, ) -> Result>, Error> { @@ -79,13 +104,13 @@ impl Registry { let entry = registry.entry::>(); // deduce the map and populate it if needed let param_entry = entry.or_insert_with(HashMap::new); - match param_entry.entry(rc) { + match param_entry.entry((rc, abomonated)) { Entry::Occupied(o) => Ok(o.into_mut()), Entry::Vacant(v) => { - let val = self.get_from_file_cache_or_update_with(rc, default, lang)?; + let val = self.get_from_file_cache_or_update_with(rc, abomonated, default, lang)?; Ok(v.insert(val)) } } - .cloned() + .cloned() // this clone is VERY expensive } }