Skip to content

Commit

Permalink
Unsafe Serialization for Faster Public Parameter Caching (#474)
Browse files Browse the repository at this point in the history
* checkpoint3

* fix missing args, todo unhack later

* add timing features, mmap attempt

* fix

* attempt to use closures

* attempt to use closures

* utlra fast param decoding

* match APIs, bugfix

* cargo fmt

* add a Clone

* refactor: Enhance error handling in `get_with_timing` method

- Refactored `get_with_timing` method in `src/public_parameters/file_map.rs` for more streamlined functionality

* chore: clippy

---------

Co-authored-by: Hanting Zhang <hantingz@usc.edu>
  • Loading branch information
huitseeker and winston-h-zhang committed Jul 28, 2023
1 parent 6f10b40 commit 0e83a45
Show file tree
Hide file tree
Showing 15 changed files with 244 additions and 104 deletions.
33 changes: 10 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ serde_repr = "0.1.14"
tap = "1.0.1"
stable_deref_trait = "1.2.0"
thiserror = { workspace = true }
abomonation = "0.7.3"
abomonation_derive = { git = "https://github.com/winston-h-zhang/abomonation_derive.git" }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
memmap = { version = "0.5.10", package = "memmap2" }
Expand Down Expand Up @@ -158,3 +160,6 @@ harness = false

[patch.crates-io]
sppark = { git = "https://github.com/supranational/sppark", rev="5fea26f43cc5d12a77776c70815e7c722fd1f8a7" }
# This is needed to ensure halo2curves, which imports pasta-curves, uses the *same* traits in bn256_grumpkin
pasta_curves = { git="https://github.com/lurk-lab/pasta_curves", branch="dev" }

14 changes: 9 additions & 5 deletions benches/end2end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ fn end2end_benchmark(c: &mut Criterion) {
let prover = NovaProver::new(reduction_count, lang_pallas);

// use cached public params
let pp = public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap();
let pp =
public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()).unwrap();

let size = (10, 0);
let benchmark_id = BenchmarkId::new("end2end_go_base_nova", format!("_{}_{}", size.0, size.1));
Expand Down Expand Up @@ -265,7 +266,8 @@ fn prove_benchmark(c: &mut Criterion) {
group.bench_with_input(benchmark_id, &size, |b, &s| {
let ptr = go_base::<pallas::Scalar>(&mut store, s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas.clone());
let pp = public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap();
let pp = public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone())
.unwrap();
let frames = prover
.get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas)
.unwrap();
Expand Down Expand Up @@ -304,7 +306,7 @@ fn prove_compressed_benchmark(c: &mut Criterion) {
group.bench_with_input(benchmark_id, &size, |b, &s| {
let ptr = go_base::<pallas::Scalar>(&mut store, s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas.clone());
let pp = public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap();
let pp = public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone()).unwrap();
let frames = prover
.get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas)
.unwrap();
Expand Down Expand Up @@ -343,7 +345,8 @@ fn verify_benchmark(c: &mut Criterion) {
let ptr = go_base(&mut store, s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas.clone());
let pp =
public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap();
public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone())
.unwrap();
let frames = prover
.get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas)
.unwrap();
Expand Down Expand Up @@ -388,7 +391,8 @@ fn verify_compressed_benchmark(c: &mut Criterion) {
let ptr = go_base(&mut store, s.0, s.1);
let prover = NovaProver::new(reduction_count, lang_pallas.clone());
let pp =
public_parameters::public_params(reduction_count, lang_pallas_rc.clone()).unwrap();
public_parameters::public_params(reduction_count, true, lang_pallas_rc.clone())
.unwrap();
let frames = prover
.get_evaluation_frames(ptr, empty_sym_env(&store), &mut store, limit, &lang_pallas)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn fibo_total<M: measurement::Measurement>(name: &str, iterations: u64, c: &mut
let reduction_count = DEFAULT_REDUCTION_COUNT;

// use cached public params
let pp = public_params(reduction_count, lang_rc.clone()).unwrap();
let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap();

c.bench_with_input(
BenchmarkId::new(name.to_string(), iterations),
Expand Down Expand Up @@ -99,7 +99,7 @@ fn fibo_prove<M: measurement::Measurement>(name: &str, iterations: u64, c: &mut
let lang_pallas = Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new();
let lang_rc = Arc::new(lang_pallas.clone());
let reduction_count = DEFAULT_REDUCTION_COUNT;
let pp = public_params(reduction_count, lang_rc.clone()).unwrap();
let pp = public_params(reduction_count, true, lang_rc.clone()).unwrap();

c.bench_with_input(
BenchmarkId::new(name.to_string(), iterations),
Expand Down
6 changes: 3 additions & 3 deletions clutch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl ReplTrait<F, Coproc<F>> for ClutchState<F, Coproc<F>> {

let lang_rc = Arc::new(lang.clone());
// Load params from disk cache, or generate them in the background.
thread::spawn(move || public_params(reduction_count, lang_rc));
thread::spawn(move || public_params(reduction_count, true, lang_rc));

Self {
repl_state: ReplState::new(s, limit, command, lang),
Expand Down Expand Up @@ -497,7 +497,7 @@ impl ClutchState<F, Coproc<F>> {
let (proof_in_expr, _rest1) = store.car_cdr(&rest)?;

let prover = NovaProver::<F, Coproc<F>>::new(self.reduction_count, (*self.lang()).clone());
let pp = public_params(self.reduction_count, self.lang())?;
let pp = public_params(self.reduction_count, true, self.lang())?;

let proof = if rest.is_nil() {
self.last_claim
Expand Down Expand Up @@ -556,7 +556,7 @@ impl ClutchState<F, Coproc<F>> {
.get(&zptr_string)
.ok_or_else(|| anyhow!("proof not found: {zptr_string}"))?;

let pp = public_params(self.reduction_count, self.lang())?;
let pp = public_params(self.reduction_count, true, self.lang())?;
let result = proof.verify(&pp, &self.lang()).unwrap();

if result.verified {
Expand Down
60 changes: 31 additions & 29 deletions examples/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use lurk::eval::{empty_sym_env, lang::Lang};
use lurk::field::LurkField;
use lurk::proof::{nova::NovaProver, Prover};
use lurk::ptr::Ptr;
use lurk::public_parameters::public_params;
use lurk::public_parameters::with_public_params;
use lurk::store::Store;
use lurk::sym;
use lurk_macros::Coproc;
Expand All @@ -26,7 +26,7 @@ use pasta_curves::pallas::Scalar as Fr;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

const REDUCTION_COUNT: usize = 10;
const REDUCTION_COUNT: usize = 100;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct Sha256Coprocessor<F: LurkField> {
Expand Down Expand Up @@ -179,47 +179,49 @@ fn main() {
Sha256Coprocessor::new(input_size, u).into(),
)],
);
let lang_rc = Arc::new(lang.clone());

let coproc_expr = format!("({})", sym_str);
let ptr = store.read(&coproc_expr).unwrap();

let nova_prover = NovaProver::<Fr, Sha256Coproc<Fr>>::new(REDUCTION_COUNT, lang.clone());
let lang_rc = Arc::new(lang);

println!("Setting up public parameters...");
println!("Setting up public parameters (rc = {REDUCTION_COUNT})...");

let pp_start = Instant::now();
let pp = public_params::<Sha256Coproc<Fr>>(REDUCTION_COUNT, lang_rc.clone()).unwrap();
let pp_end = pp_start.elapsed();

println!("Public parameters took {:?}", pp_end);
// see the documentation on `with_public_params`
let _res = with_public_params(REDUCTION_COUNT, lang_rc.clone(), |pp| {
let pp_end = pp_start.elapsed();
println!("Public parameters took {:?}", pp_end);

if setup_only {
return;
}

println!("Beginning proof step...");
if setup_only {
return;
}

let proof_start = Instant::now();
let (proof, z0, zi, num_steps) = nova_prover
.evaluate_and_prove(&pp, ptr, empty_sym_env(store), store, 10000, lang_rc)
.unwrap();
let proof_end = proof_start.elapsed();
println!("Beginning proof step...");
let proof_start = Instant::now();
let (proof, z0, zi, num_steps) = nova_prover
.evaluate_and_prove(pp, ptr, empty_sym_env(store), store, 10000, lang_rc)
.unwrap();
let proof_end = proof_start.elapsed();

println!("Proofs took {:?}", proof_end);
println!("Proofs took {:?}", proof_end);

println!("Verifying proof...");
println!("Verifying proof...");

let verify_start = Instant::now();
let res = proof.verify(&pp, num_steps, &z0, &zi).unwrap();
let verify_end = verify_start.elapsed();
let verify_start = Instant::now();
let res = proof.verify(&pp, num_steps, &z0, &zi).unwrap();
let verify_end = verify_start.elapsed();

println!("Verify took {:?}", verify_end);
println!("Verify took {:?}", verify_end);

if res {
println!(
"Congratulations! You proved and verified a SHA256 hash calculation in {:?} time!",
pp_end + proof_end + verify_end
);
}
if res {
println!(
"Congratulations! You proved and verified a SHA256 hash calculation in {:?} time!",
pp_end + proof_end + verify_end
);
}
})
.unwrap();
}
6 changes: 3 additions & 3 deletions fcomm/src/bin/fcomm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl Open {
let rc = ReductionCount::try_from(self.reduction_count).expect("reduction count");
let prover = NovaProver::<S1, Coproc<S1>>::new(rc.count(), lang.clone());
let lang_rc = Arc::new(lang.clone());
let pp = public_params(rc.count(), lang_rc).expect("public params");
let pp = public_params(rc.count(), true, lang_rc).expect("public params");
let function_map = committed_expression_store();

let handle_proof = |out_path, proof: Proof<S1>| {
Expand Down Expand Up @@ -332,7 +332,7 @@ impl Prove {
let rc = ReductionCount::try_from(self.reduction_count).unwrap();
let prover = NovaProver::<S1, Coproc<S1>>::new(rc.count(), lang.clone());
let lang_rc = Arc::new(lang.clone());
let pp = public_params(rc.count(), lang_rc.clone()).unwrap();
let pp = public_params(rc.count(), true, lang_rc.clone()).unwrap();

let proof = match &self.claim {
Some(claim) => {
Expand Down Expand Up @@ -378,7 +378,7 @@ impl Verify {
fn verify(&self, cli_error: bool, lang: &Lang<S1, Coproc<S1>>) {
let proof = proof(Some(&self.proof)).unwrap();
let lang_rc = Arc::new(lang.clone());
let pp = public_params(proof.reduction_count.count(), lang_rc).unwrap();
let pp = public_params(proof.reduction_count.count(), true, lang_rc).unwrap();
let result = proof.verify(&pp, lang).unwrap();

serde_json::to_writer(io::stdout(), &result).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion fcomm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,7 @@ mod test {
let lang = Lang::new();
let lang_rc = Arc::new(lang.clone());
let rc = ReductionCount::One;
let pp = public_params(rc.count(), lang_rc.clone()).expect("public params");
let pp = public_params(rc.count(), true, lang_rc.clone()).expect("public params");
let chained = true;
let s = &mut Store::<S1>::default();

Expand Down
2 changes: 1 addition & 1 deletion src/cli/lurk_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ mod non_wasm {
lang,
} => {
log::info!("Loading public parameters");
let pp = public_params(rc, std::sync::Arc::new(lang))?;
let pp = public_params(rc, true, std::sync::Arc::new(lang))?;
Ok(proof.verify(&pp, num_steps, &public_inputs, &public_outputs)?)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl Repl<F> {
}

info!("Loading public parameters");
let pp = public_params(self.rc, self.lang.clone())?;
let pp = public_params(self.rc, true, self.lang.clone())?;

let prover = NovaProver::new(self.rc, (*self.lang).clone());

Expand Down
Loading

0 comments on commit 0e83a45

Please sign in to comment.