diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index 438160f60c..5b7d5c177f 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -585,10 +585,35 @@ impl Func { let allocated_ptr = AllocatedPtr::from_parts(tag, src.hash().clone()); bound_allocations.insert(tgt.clone(), allocated_ptr); } + Op::EqTag(tgt, a, b) => { + let a = bound_allocations.get(a)?; + let b = bound_allocations.get(b)?; + let a_num = a.tag(); + let b_num = b.tag(); + let eq = alloc_equal(&mut cs.namespace(|| "equal_tag"), a_num, b_num)?; + let c_num = boolean_to_num(&mut cs.namespace(|| "equal_tag.to_num"), &eq)?; + let tag = g + .global_allocator + .get_or_alloc_const(cs, Tag::Expr(Num).to_field())?; + let c = AllocatedPtr::from_parts(tag, c_num); + bound_allocations.insert(tgt.clone(), c); + } + Op::EqVal(tgt, a, b) => { + let a = bound_allocations.get(a)?; + let b = bound_allocations.get(b)?; + let a_num = a.hash(); + let b_num = b.hash(); + let eq = alloc_equal(&mut cs.namespace(|| "equal_val"), a_num, b_num)?; + let c_num = boolean_to_num(&mut cs.namespace(|| "equal_val.to_num"), &eq)?; + let tag = g + .global_allocator + .get_or_alloc_const(cs, Tag::Expr(Num).to_field())?; + let c = AllocatedPtr::from_parts(tag, c_num); + bound_allocations.insert(tgt.clone(), c); + } Op::Add(tgt, a, b) => { let a = bound_allocations.get(a)?; let b = bound_allocations.get(b)?; - // TODO check that the tags are correct let a_num = a.hash(); let b_num = b.hash(); let c_num = add(&mut cs.namespace(|| "add"), a_num, b_num)?; @@ -601,7 +626,6 @@ impl Func { Op::Sub(tgt, a, b) => { let a = bound_allocations.get(a)?; let b = bound_allocations.get(b)?; - // TODO check that the tags are correct let a_num = a.hash(); let b_num = b.hash(); let c_num = sub(&mut cs.namespace(|| "sub"), a_num, b_num)?; @@ -614,7 +638,6 @@ impl Func { Op::Mul(tgt, a, b) => { let a = bound_allocations.get(a)?; let b = bound_allocations.get(b)?; - // TODO check that the tags are correct let a_num = a.hash(); let b_num = b.hash(); let c_num = mul(&mut cs.namespace(|| "mul"), a_num, b_num)?; @@ -627,7 +650,6 @@ impl Func { Op::Div(tgt, a, b) => { let a = bound_allocations.get(a)?; let b = bound_allocations.get(b)?; - // TODO check that the tags are correct let a_num = a.hash(); let b_num = b.hash(); @@ -652,7 +674,6 @@ impl Func { Op::Lt(tgt, a, b) => { let a = bound_allocations.get(a)?; let b = bound_allocations.get(b)?; - // TODO check that the tags are correct let tag = g .global_allocator .get_or_alloc_const(cs, Tag::Expr(Num).to_field())?; @@ -782,10 +803,8 @@ impl Func { Ctrl::IfEq(x, y, eq_block, else_block) => { let x = bound_allocations.get(x)?.hash(); let y = bound_allocations.get(y)?.hash(); - // TODO should we check whether the tags are equal too? let eq = alloc_equal(&mut cs.namespace(|| "if_eq.alloc_equal"), x, y)?; let not_eq = eq.not(); - // TODO is this the most efficient way of doing if statements? let not_dummy_and_eq = and(&mut cs.namespace(|| "if_eq.and"), not_dummy, &eq)?; let not_dummy_and_not_eq = and(&mut cs.namespace(|| "if_eq.and.2"), not_dummy, ¬_eq)?; @@ -1027,6 +1046,10 @@ impl Func { Op::Cast(_tgt, tag, _src) => { globals.insert(FWrap(tag.to_field())); } + Op::EqTag(_, _, _) | Op::EqVal(_, _, _) => { + globals.insert(FWrap(Tag::Expr(Num).to_field())); + num_constraints += 5; + } Op::Add(_, _, _) | Op::Sub(_, _, _) | Op::Mul(_, _, _) => { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 1; diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 3c0ba02b00..d9cd0f62f8 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -725,11 +725,17 @@ fn apply_cont() -> Func { return(hidden, env, continuation, makethunk) } Symbol("eq") => { - // TODO should we check whether the tags are also equal? - if evaled_arg == result { - return (t, env, continuation, makethunk) + let eq_tag = eq_tag(evaled_arg, result); + let eq_val = eq_val(evaled_arg, result); + let eq = mul(eq_tag, eq_val); + match eq.val { + Num(0) => { + return (nil, env, continuation, makethunk) + } + Num(1) => { + return (t, env, continuation, makethunk) + } } - return (nil, env, continuation, makethunk) } Symbol("+") => { match args_num_type.val { @@ -924,8 +930,8 @@ mod tests { use blstrs::Scalar as Fr; const NUM_INPUTS: usize = 1; - const NUM_AUX: usize = 8781; - const NUM_CONSTRAINTS: usize = 10875; + const NUM_AUX: usize = 8868; + const NUM_CONSTRAINTS: usize = 11096; const NUM_SLOTS: SlotsCounter = SlotsCounter { hash2: 16, hash3: 4, diff --git a/src/lem/interpreter.rs b/src/lem/interpreter.rs index 52641a6263..cfff4d2ced 100644 --- a/src/lem/interpreter.rs +++ b/src/lem/interpreter.rs @@ -111,6 +111,30 @@ impl Block { let tgt_ptr = src_ptr.cast(*tag); bindings.insert(tgt.clone(), tgt_ptr); } + Op::EqTag(tgt, a, b) => { + let a = bindings.get(a)?; + let b = bindings.get(b)?; + let c = if a.tag() == b.tag() { + Ptr::Leaf(Tag::Expr(Num), F::ONE) + } else { + Ptr::Leaf(Tag::Expr(Num), F::ZERO) + }; + bindings.insert(tgt.clone(), c); + } + Op::EqVal(tgt, a, b) => { + let a = bindings.get(a)?; + let b = bindings.get(b)?; + // In order to compare Ptrs, we *must* resolve the hashes. Otherwise, we risk failing to recognize equality of + // compound data with opaque data in either element's transitive closure. + let a_hash = store.hash_ptr(a)?.hash; + let b_hash = store.hash_ptr(b)?.hash; + let c = if a_hash == b_hash { + Ptr::Leaf(Tag::Expr(Num), F::ONE) + } else { + Ptr::Leaf(Tag::Expr(Num), F::ZERO) + }; + bindings.insert(tgt.clone(), c); + } Op::Add(tgt, a, b) => { let a = bindings.get(a)?; let b = bindings.get(b)?; diff --git a/src/lem/macros.rs b/src/lem/macros.rs index a691b381af..977fdfe600 100644 --- a/src/lem/macros.rs +++ b/src/lem/macros.rs @@ -58,6 +58,20 @@ macro_rules! op { $crate::var!($src), ) }; + ( let $tgt:ident = eq_tag($a:ident, $b:ident) ) => { + $crate::lem::Op::EqTag( + $crate::var!($tgt), + $crate::var!($a), + $crate::var!($b), + ) + }; + ( let $tgt:ident = eq_val($a:ident, $b:ident) ) => { + $crate::lem::Op::EqVal( + $crate::var!($tgt), + $crate::var!($a), + $crate::var!($b), + ) + }; ( let $tgt:ident = add($a:ident, $b:ident) ) => { $crate::lem::Op::Add( $crate::var!($tgt), @@ -250,6 +264,26 @@ macro_rules! block { $($tail)* ) }; + (@seq {$($limbs:expr)*}, let $tgt:ident = eq_tag($a:ident, $b:ident) ; $($tail:tt)*) => { + $crate::block! ( + @seq + { + $($limbs)* + $crate::op!(let $tgt = eq_tag($a, $b)) + }, + $($tail)* + ) + }; + (@seq {$($limbs:expr)*}, let $tgt:ident = eq_val($a:ident, $b:ident) ; $($tail:tt)*) => { + $crate::block! ( + @seq + { + $($limbs)* + $crate::op!(let $tgt = eq_val($a, $b)) + }, + $($tail)* + ) + }; (@seq {$($limbs:expr)*}, let $tgt:ident = add($a:ident, $b:ident) ; $($tail:tt)*) => { $crate::block! ( @seq diff --git a/src/lem/mod.rs b/src/lem/mod.rs index 41ea92320c..230c99258a 100644 --- a/src/lem/mod.rs +++ b/src/lem/mod.rs @@ -238,6 +238,10 @@ pub enum Op { /// `Cast(y, t, x)` binds `y` to a pointer with tag `t` and the hash of `x` Cast(Var, Tag, Var), /// `Add(y, a, b)` binds `y` to the sum of `a` and `b` + EqTag(Var, Var, Var), + /// `EqVal(y, a, b)` binds `y` to `1` if `a.val != b.val`, or to `0` otherwise + EqVal(Var, Var, Var), + /// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise Add(Var, Var, Var), /// `Sub(y, a, b)` binds `y` to the sum of `a` and `b` Sub(Var, Var, Var), @@ -245,7 +249,7 @@ pub enum Op { Mul(Var, Var, Var), /// `Div(y, a, b)` binds `y` to the sum of `a` and `b` Div(Var, Var, Var), - /// `Lt(y, a, b)` binds `y` to `t` if `a < b`, or to `nil` otherwise + /// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise Lt(Var, Var, Var), /// `Emit(v)` simply prints out the value of `v` when interpreting the code Emit(Var), @@ -342,7 +346,9 @@ impl Func { is_bound(src, map)?; is_unique(tgt, map); } - Op::Add(tgt, a, b) + Op::EqTag(tgt, a, b) + | Op::EqVal(tgt, a, b) + | Op::Add(tgt, a, b) | Op::Sub(tgt, a, b) | Op::Mul(tgt, a, b) | Op::Div(tgt, a, b) @@ -554,6 +560,18 @@ impl Block { let tgt = insert_one(map, uniq, &tgt); ops.push(Op::Cast(tgt, tag, src)) } + Op::EqTag(tgt, a, b) => { + let a = map.get_cloned(&a)?; + let b = map.get_cloned(&b)?; + let tgt = insert_one(map, uniq, &tgt); + ops.push(Op::EqTag(tgt, a, b)) + } + Op::EqVal(tgt, a, b) => { + let a = map.get_cloned(&a)?; + let b = map.get_cloned(&b)?; + let tgt = insert_one(map, uniq, &tgt); + ops.push(Op::EqVal(tgt, a, b)) + } Op::Add(tgt, a, b) => { let a = map.get_cloned(&a)?; let b = map.get_cloned(&b)?;