Skip to content

Commit

Permalink
Simplified Func and Cproc call (#966)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett authored Dec 18, 2023
1 parent e04178f commit 2bd16c2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 55 deletions.
29 changes: 7 additions & 22 deletions src/lem/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use bellpepper_core::{
},
};
use std::{
collections::{HashMap, HashSet, VecDeque},
collections::{HashMap, HashSet},
sync::Arc,
};

Expand Down Expand Up @@ -472,9 +472,7 @@ struct RecursiveContext<'a, F: LurkField, C: Coprocessor<F>> {
commitment_slots: &'a [&'a (Vec<AllocatedNum<F>>, AllocatedVal<F>)],
bit_decomp_slots: &'a [&'a (Vec<AllocatedNum<F>>, AllocatedVal<F>)],
blank: bool,
call_outputs: &'a VecDeque<Vec<Ptr>>,
call_idx: usize,
cproc_outputs: &'a [Vec<Ptr>],
bindings: &'a VarMap<Val>,
}

fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
Expand All @@ -485,7 +483,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations: &mut BoundAllocations<F>,
preallocated_outputs: &Vec<AllocatedPtr<F>>,
ctx: &mut RecursiveContext<'_, F, C>,
mut cproc_idx: usize,
) -> Result<()> {
for (op_idx, op) in block.ops.iter().enumerate() {
let mut cs = cs.namespace(|| format!("op {op_idx}"));
Expand Down Expand Up @@ -577,7 +574,7 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
.ok_or_else(|| anyhow!("Coprocessor for {sym} not found"))?;
let not_dummy_and_not_blank = not_dummy.get_value() == Some(true) && !ctx.blank;
let collected_z_ptrs = if not_dummy_and_not_blank {
let collected_ptrs = &ctx.cproc_outputs[cproc_idx];
let collected_ptrs = ctx.bindings.get_many_ptr(out)?;
if out.len() != collected_ptrs.len() {
bail!("Incompatible output length for coprocessor {sym}")
}
Expand Down Expand Up @@ -633,7 +630,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
);
}
}
cproc_idx += 1;
}
Op::Call(out, func, inp) => {
// Allocate the output pointers that the `func` will return to.
Expand All @@ -644,11 +640,11 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
// add the results of the call to the witness, or recompute them.
let not_dummy_and_not_blank = not_dummy.get_value() == Some(true) && !ctx.blank;
let output_z_ptrs = if not_dummy_and_not_blank {
let z_ptrs = ctx.call_outputs[ctx.call_idx]
let ptrs = ctx.bindings.get_many_ptr(out)?;
let z_ptrs = ptrs
.iter()
.map(|ptr| ctx.store.hash_ptr(ptr))
.collect::<Vec<_>>();
ctx.call_idx += 1;
assert_eq!(z_ptrs.len(), out.len());
z_ptrs
} else {
Expand Down Expand Up @@ -680,7 +676,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations,
&output_ptrs,
ctx,
cproc_idx,
)?;
}
Op::Cons2(img, tag, preimg) => {
Expand Down Expand Up @@ -1082,7 +1077,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations,
preallocated_outputs,
ctx,
cproc_idx,
)?;
branch_slots.push(branch_slot);
}
Expand Down Expand Up @@ -1118,7 +1112,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations,
preallocated_outputs,
ctx,
cproc_idx,
)?;

// Pushing `is_default` to `selector` to enforce summation = 1
Expand Down Expand Up @@ -1167,7 +1160,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations,
preallocated_outputs,
ctx,
cproc_idx,
)?;
synthesize_block(
&mut cs.namespace(|| "if_eq.false"),
Expand All @@ -1177,7 +1169,6 @@ fn synthesize_block<F: LurkField, CS: ConstraintSystem<F>, C: Coprocessor<F>>(
bound_allocations,
preallocated_outputs,
ctx,
cproc_idx,
)?;
*next_slot = next_slot.cmp_max(branch_slot);
Ok(())
Expand Down Expand Up @@ -1344,11 +1335,8 @@ impl Func {
commitment_slots,
bit_decomp_slots,
blank: frame.blank,
call_outputs: &frame.hints.call_outputs,
call_idx: 0,
cproc_outputs: &frame.hints.cproc_outputs,
bindings: &frame.hints.bindings,
},
0,
)?;
} else {
assert!(!cs.is_witness_generator());
Expand All @@ -1375,11 +1363,8 @@ impl Func {
commitment_slots,
bit_decomp_slots,
blank: frame.blank,
call_outputs: &frame.hints.call_outputs,
call_idx: 0,
cproc_outputs: &frame.hints.cproc_outputs,
bindings: &frame.hints.bindings,
},
0,
)?;
}

Expand Down
49 changes: 16 additions & 33 deletions src/lem/interpreter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use anyhow::{anyhow, bail, Result};
use std::collections::VecDeque;

use super::{
path::Path,
Expand All @@ -20,7 +19,7 @@ use crate::{
};

impl VarMap<Val> {
fn get_many_ptr(&self, args: &[Var]) -> Result<Vec<Ptr>> {
pub fn get_many_ptr(&self, args: &[Var]) -> Result<Vec<Ptr>> {
args.iter().map(|arg| self.get_ptr(arg)).collect()
}

Expand Down Expand Up @@ -58,8 +57,7 @@ pub struct Hints {
pub hash8: Vec<Option<SlotData>>,
pub commitment: Vec<Option<SlotData>>,
pub bit_decomp: Vec<Option<SlotData>>,
pub call_outputs: VecDeque<Vec<Ptr>>,
pub cproc_outputs: Vec<Vec<Ptr>>,
pub bindings: VarMap<Val>,
}

impl Hints {
Expand All @@ -70,16 +68,14 @@ impl Hints {
let hash8 = Vec::with_capacity(slot.hash8);
let commitment = Vec::with_capacity(slot.commitment);
let bit_decomp = Vec::with_capacity(slot.bit_decomp);
let call_outputs = VecDeque::new();
let cproc_outputs = Vec::new();
let bindings = VarMap::new();
Hints {
hash4,
hash6,
hash8,
commitment,
bit_decomp,
call_outputs,
cproc_outputs,
bindings,
}
}

Expand All @@ -90,16 +86,14 @@ impl Hints {
let hash8 = vec![None; slot.hash8];
let commitment = vec![None; slot.commitment];
let bit_decomp = vec![None; slot.bit_decomp];
let call_outputs = VecDeque::new();
let cproc_outputs = Vec::new();
let bindings = VarMap::new();
Hints {
hash4,
hash6,
hash8,
commitment,
bit_decomp,
call_outputs,
cproc_outputs,
bindings,
}
}
}
Expand Down Expand Up @@ -161,35 +155,24 @@ impl Block {
if out.len() != out_ptrs.len() {
bail!("Incompatible output length for coprocessor {sym}")
}
for (var, ptr) in out.iter().zip(&out_ptrs) {
bindings.insert(var.clone(), Val::Pointer(*ptr));
for (var, ptr) in out.iter().zip(out_ptrs.into_iter()) {
bindings.insert(var.clone(), Val::Pointer(ptr));
hints.bindings.insert(var.clone(), Val::Pointer(ptr));
}
hints.cproc_outputs.push(out_ptrs);
}
Op::Call(out, func, inp) => {
// Get the argument values
let inp_ptrs = bindings.get_many_ptr(inp)?;

// To save lexical order of `call_outputs` we need to push the output
// of the call *before* the inner calls of the `func`. To do this, we
// save all the inner call outputs, push the output of the call in front
// of it, then extend `call_outputs`
let mut inner_call_outputs = VecDeque::new();
std::mem::swap(&mut inner_call_outputs, &mut hints.call_outputs);
let (mut frame, func_path) =
let (frame, func_path) =
func.call(&inp_ptrs, store, hints, emitted, lang, pc)?;
std::mem::swap(&mut inner_call_outputs, &mut frame.hints.call_outputs);

// Extend the path and bind the output variables to the output values
// Extend the path
path.extend_from_path(&func_path);
for (var, ptr) in out.iter().zip(frame.output.iter()) {
bindings.insert_ptr(var.clone(), *ptr);
}

// Update `hints` correctly
inner_call_outputs.push_front(frame.output);
// Bind the output variables to the output values
hints = frame.hints;
hints.call_outputs.extend(inner_call_outputs);
for (var, ptr) in out.iter().zip(frame.output.into_iter()) {
bindings.insert_ptr(var.clone(), ptr);
hints.bindings.insert_ptr(var.clone(), ptr);
}
}
Op::Copy(tgt, src) => {
bindings.insert(tgt.clone(), bindings.get_cloned(src)?);
Expand Down
6 changes: 6 additions & 0 deletions src/lem/var_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ use super::Var;
#[derive(Clone, Debug)]
pub struct VarMap<V>(FxHashMap<Var, V>);

impl<V> Default for VarMap<V> {
fn default() -> VarMap<V> {
VarMap(FxHashMap::default())
}
}

impl<V> VarMap<V> {
/// Creates an empty `VarMap`
#[inline]
Expand Down

1 comment on commit 2bd16c2

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
125.78 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7250691926

Benchmark Results

LEM Fibonacci Prove - rc = 100

fib-ref=e04178f4ab9cd661a1f6a19684ca8ebb85ac2342 fib-ref=2bd16c29eb80ee9aa95d6821b434ae3e84a14036
num-100 3.85 s (✅ 1.00x) 3.86 s (✅ 1.00x slower)
num-200 7.70 s (✅ 1.00x) 7.71 s (✅ 1.00x slower)

LEM Fibonacci Prove - rc = 600

fib-ref=e04178f4ab9cd661a1f6a19684ca8ebb85ac2342 fib-ref=2bd16c29eb80ee9aa95d6821b434ae3e84a14036
num-100 3.32 s (✅ 1.00x) 3.32 s (✅ 1.00x faster)
num-200 7.28 s (✅ 1.00x) 7.25 s (✅ 1.00x faster)

Made with criterion-table

Please sign in to comment.