From 3c3249067f65d76c13a01830dffe96e47b75a8e6 Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Tue, 6 Aug 2024 12:46:22 -0400 Subject: [PATCH] Do not panic on wire set twice or generator not run issues --- plonky2/examples/bench_recursion.rs | 8 +- plonky2/examples/factorial.rs | 2 +- plonky2/examples/fibonacci.rs | 4 +- plonky2/examples/fibonacci_serialization.rs | 4 +- plonky2/examples/range_check.rs | 2 +- plonky2/examples/square_root.rs | 10 +- plonky2/src/batch_fri/oracle.rs | 2 +- plonky2/src/fri/witness_util.rs | 20 ++- plonky2/src/gadgets/arithmetic.rs | 12 +- plonky2/src/gadgets/arithmetic_extension.rs | 16 +- plonky2/src/gadgets/range_check.rs | 12 +- plonky2/src/gadgets/select.rs | 4 +- plonky2/src/gadgets/split_base.rs | 9 +- plonky2/src/gadgets/split_join.rs | 22 ++- plonky2/src/gates/arithmetic_base.rs | 8 +- plonky2/src/gates/arithmetic_extension.rs | 8 +- plonky2/src/gates/base_sum.rs | 12 +- plonky2/src/gates/coset_interpolation.rs | 16 +- plonky2/src/gates/exponentiation.rs | 12 +- plonky2/src/gates/gate_testing.rs | 8 +- plonky2/src/gates/lookup.rs | 18 ++- plonky2/src/gates/lookup_table.rs | 15 +- plonky2/src/gates/multiplication_extension.rs | 8 +- plonky2/src/gates/poseidon.rs | 53 ++++--- plonky2/src/gates/poseidon_mds.rs | 12 +- plonky2/src/gates/random_access.rs | 13 +- plonky2/src/gates/reducing.rs | 13 +- plonky2/src/gates/reducing_extension.rs | 12 +- plonky2/src/hash/merkle_proofs.rs | 6 +- plonky2/src/iop/challenger.rs | 3 +- plonky2/src/iop/generator.rs | 77 ++++++---- plonky2/src/iop/witness.rs | 144 +++++++++++------- plonky2/src/lookup_test.rs | 35 +++-- plonky2/src/plonk/circuit_data.rs | 2 +- plonky2/src/plonk/prover.rs | 14 +- .../conditional_recursive_verifier.rs | 10 +- plonky2/src/recursion/cyclic_recursion.rs | 18 +-- plonky2/src/recursion/dummy_circuit.rs | 13 +- plonky2/src/recursion/recursive_verifier.rs | 18 +-- plonky2/src/util/reducing.rs | 4 +- starky/src/fibonacci_stark.rs | 2 +- starky/src/permutation_stark.rs | 2 +- starky/src/recursive_verifier.rs | 20 +-- starky/src/stark_testing.rs | 14 +- starky/src/unconstrained_stark.rs | 2 +- 45 files changed, 462 insertions(+), 257 deletions(-) diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 8201c96de0..ec4a67f9ab 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -138,7 +138,7 @@ fn dummy_lookup_proof, C: GenericConfig, let data = builder.build::(); let mut inputs = PartialWitness::::new(); - inputs.set_target(initial_a, F::ONE); + inputs.set_target(initial_a, F::ONE)?; let mut timing = TimingTree::new("prove with one lookup", Level::Debug); let proof = prove(&data.prover_only, &data.common, inputs, &mut timing)?; timing.print(); @@ -189,7 +189,7 @@ fn dummy_many_rows_proof< builder.register_public_input(output); let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::ONE); + pw.set_target(initial_a, F::ONE)?; let data = builder.build::(); let mut timing = TimingTree::new("prove with many lookups", Level::Debug); let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; @@ -235,8 +235,8 @@ where let data = builder.build::(); let mut pw = PartialWitness::new(); - pw.set_proof_with_pis_target(&pt, inner_proof); - pw.set_verifier_data_target(&inner_data, inner_vd); + pw.set_proof_with_pis_target(&pt, inner_proof)?; + pw.set_verifier_data_target(&inner_data, inner_vd)?; let mut timing = TimingTree::new("prove", Level::Debug); let proof = prove::(&data.prover_only, &data.common, pw, &mut timing)?; diff --git a/plonky2/examples/factorial.rs b/plonky2/examples/factorial.rs index 019a7aa5f3..0f0cdad13c 100644 --- a/plonky2/examples/factorial.rs +++ b/plonky2/examples/factorial.rs @@ -29,7 +29,7 @@ fn main() -> Result<()> { builder.register_public_input(cur_target); let mut pw = PartialWitness::new(); - pw.set_target(initial, F::ONE); + pw.set_target(initial, F::ONE)?; let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/examples/fibonacci.rs b/plonky2/examples/fibonacci.rs index 28403862ee..578dc24245 100644 --- a/plonky2/examples/fibonacci.rs +++ b/plonky2/examples/fibonacci.rs @@ -34,8 +34,8 @@ fn main() -> Result<()> { // Provide initial values. let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::ZERO); - pw.set_target(initial_b, F::ONE); + pw.set_target(initial_a, F::ZERO)?; + pw.set_target(initial_b, F::ONE)?; let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/examples/fibonacci_serialization.rs b/plonky2/examples/fibonacci_serialization.rs index 0a6ad30f2c..033ce39b41 100644 --- a/plonky2/examples/fibonacci_serialization.rs +++ b/plonky2/examples/fibonacci_serialization.rs @@ -37,8 +37,8 @@ fn main() -> Result<()> { // Provide initial values. let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::ZERO); - pw.set_target(initial_b, F::ONE); + pw.set_target(initial_a, F::ZERO)?; + pw.set_target(initial_b, F::ONE)?; let data = builder.build::(); diff --git a/plonky2/examples/range_check.rs b/plonky2/examples/range_check.rs index d9351e1b1c..95df7439c2 100644 --- a/plonky2/examples/range_check.rs +++ b/plonky2/examples/range_check.rs @@ -24,7 +24,7 @@ fn main() -> Result<()> { builder.range_check(value, log_max); let mut pw = PartialWitness::new(); - pw.set_target(value, F::from_canonical_usize(42)); + pw.set_target(value, F::from_canonical_usize(42))?; let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/examples/square_root.rs b/plonky2/examples/square_root.rs index fb970a67c5..ad1b1a6f80 100644 --- a/plonky2/examples/square_root.rs +++ b/plonky2/examples/square_root.rs @@ -41,13 +41,17 @@ impl, const D: usize> SimpleGenerator vec![self.x_squared] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let x_squared = witness.get_target(self.x_squared); let x = x_squared.sqrt().unwrap(); println!("Square root: {x}"); - out_buffer.set_target(self.x, x); + out_buffer.set_target(self.x, x) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -121,7 +125,7 @@ fn main() -> Result<()> { }; let mut pw = PartialWitness::new(); - pw.set_target(x_squared, x_squared_value); + pw.set_target(x_squared, x_squared_value)?; let data = builder.build::(); let proof = data.prove(pw.clone())?; diff --git a/plonky2/src/batch_fri/oracle.rs b/plonky2/src/batch_fri/oracle.rs index 71d808ed9b..192e374451 100644 --- a/plonky2/src/batch_fri/oracle.rs +++ b/plonky2/src/batch_fri/oracle.rs @@ -599,7 +599,7 @@ mod test { ); let mut pw = PartialWitness::new(); - set_fri_proof_target(&mut pw, &fri_proof_target, &proof); + set_fri_proof_target(&mut pw, &fri_proof_target, &proof)?; let data = builder.build::(); let proof = prove::(&data.prover_only, &data.common, pw, &mut timing)?; diff --git a/plonky2/src/fri/witness_util.rs b/plonky2/src/fri/witness_util.rs index ce47e7cba5..041d43c134 100644 --- a/plonky2/src/fri/witness_util.rs +++ b/plonky2/src/fri/witness_util.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use itertools::Itertools; use crate::field::extension::Extendable; @@ -11,12 +12,13 @@ pub fn set_fri_proof_target( witness: &mut W, fri_proof_target: &FriProofTarget, fri_proof: &FriProof, -) where +) -> Result<()> +where F: RichField + Extendable, W: WitnessWrite + ?Sized, H: AlgebraicHasher, { - witness.set_target(fri_proof_target.pow_witness, fri_proof.pow_witness); + witness.set_target(fri_proof_target.pow_witness, fri_proof.pow_witness)?; for (&t, &x) in fri_proof_target .final_poly @@ -24,7 +26,7 @@ pub fn set_fri_proof_target( .iter() .zip_eq(&fri_proof.final_poly.coeffs) { - witness.set_extension_target(t, x); + witness.set_extension_target(t, x)?; } for (t, x) in fri_proof_target @@ -32,7 +34,7 @@ pub fn set_fri_proof_target( .iter() .zip_eq(&fri_proof.commit_phase_merkle_caps) { - witness.set_cap_target(t, x); + witness.set_cap_target(t, x)?; } for (qt, q) in fri_proof_target @@ -47,16 +49,16 @@ pub fn set_fri_proof_target( .zip_eq(&q.initial_trees_proof.evals_proofs) { for (&t, &x) in at.0.iter().zip_eq(&a.0) { - witness.set_target(t, x); + witness.set_target(t, x)?; } for (&t, &x) in at.1.siblings.iter().zip_eq(&a.1.siblings) { - witness.set_hash_target(t, x); + witness.set_hash_target(t, x)?; } } for (st, s) in qt.steps.iter().zip_eq(&q.steps) { for (&t, &x) in st.evals.iter().zip_eq(&s.evals) { - witness.set_extension_target(t, x); + witness.set_extension_target(t, x)?; } for (&t, &x) in st .merkle_proof @@ -64,8 +66,10 @@ pub fn set_fri_proof_target( .iter() .zip_eq(&s.merkle_proof.siblings) { - witness.set_hash_target(t, x); + witness.set_hash_target(t, x)?; } } } + + Ok(()) } diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index e162f1116f..5da59352c6 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -6,6 +6,8 @@ use alloc::{ }; use core::borrow::Borrow; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::field::types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; @@ -397,14 +399,18 @@ impl, const D: usize> SimpleGenerator for Equ vec![self.x, self.y] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let x = witness.get_target(self.x); let y = witness.get_target(self.y); let inv = if x != y { (x - y).inverse() } else { F::ZERO }; - out_buffer.set_bool_target(self.equal, x == y); - out_buffer.set_target(self.inv, inv); + out_buffer.set_bool_target(self.equal, x == y)?; + out_buffer.set_target(self.inv, inv) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gadgets/arithmetic_extension.rs b/plonky2/src/gadgets/arithmetic_extension.rs index afea71df39..6026a8cb16 100644 --- a/plonky2/src/gadgets/arithmetic_extension.rs +++ b/plonky2/src/gadgets/arithmetic_extension.rs @@ -6,6 +6,8 @@ use alloc::{ }; use core::borrow::Borrow; +use anyhow::Result; + use crate::field::extension::{Extendable, FieldExtension, OEF}; use crate::field::types::{Field, Field64}; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; @@ -519,7 +521,11 @@ impl, const D: usize> SimpleGenerator deps } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let num = witness.get_extension_target(self.numerator); let dem = witness.get_extension_target(self.denominator); let quotient = num / dem; @@ -621,7 +627,7 @@ mod tests { let vs = FF::rand_vec(3); let ts = builder.add_virtual_extension_targets(3); for (&v, &t) in vs.iter().zip(&ts) { - pw.set_extension_target(t, v); + pw.set_extension_target(t, v)?; } let mul0 = builder.mul_many_extension(&ts); let mul1 = { @@ -696,9 +702,9 @@ mod tests { let y = ExtensionAlgebra::(FF::rand_array()); let z = x * y; for i in 0..D { - pw.set_extension_target(xt.0[i], x.0[i]); - pw.set_extension_target(yt.0[i], y.0[i]); - pw.set_extension_target(zt.0[i], z.0[i]); + pw.set_extension_target(xt.0[i], x.0[i])?; + pw.set_extension_target(yt.0[i], y.0[i])?; + pw.set_extension_target(zt.0[i], z.0[i])?; } let data = builder.build::(); diff --git a/plonky2/src/gadgets/range_check.rs b/plonky2/src/gadgets/range_check.rs index 9a66a6a6c6..6b99c550b6 100644 --- a/plonky2/src/gadgets/range_check.rs +++ b/plonky2/src/gadgets/range_check.rs @@ -5,6 +5,8 @@ use alloc::{ vec::Vec, }; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::hash::hash_types::RichField; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; @@ -74,13 +76,17 @@ impl, const D: usize> SimpleGenerator for Low vec![self.integer] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let integer_value = witness.get_target(self.integer).to_canonical_u64(); let low = integer_value & ((1 << self.n_log) - 1); let high = integer_value >> self.n_log; - out_buffer.set_target(self.low, F::from_canonical_u64(low)); - out_buffer.set_target(self.high, F::from_canonical_u64(high)); + out_buffer.set_target(self.low, F::from_canonical_u64(low))?; + out_buffer.set_target(self.high, F::from_canonical_u64(high)) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gadgets/select.rs b/plonky2/src/gadgets/select.rs index b34092ed7e..623880951b 100644 --- a/plonky2/src/gadgets/select.rs +++ b/plonky2/src/gadgets/select.rs @@ -63,8 +63,8 @@ mod tests { let truet = builder._true(); let falset = builder._false(); - pw.set_extension_target(xt, x); - pw.set_extension_target(yt, y); + pw.set_extension_target(xt, x)?; + pw.set_extension_target(yt, y)?; let should_be_x = builder.select_ext(truet, xt, yt); let should_be_y = builder.select_ext(falset, xt, yt); diff --git a/plonky2/src/gadgets/split_base.rs b/plonky2/src/gadgets/split_base.rs index 1cdec86203..273ab65772 100644 --- a/plonky2/src/gadgets/split_base.rs +++ b/plonky2/src/gadgets/split_base.rs @@ -2,6 +2,7 @@ use alloc::{format, string::String, vec, vec::Vec}; use core::borrow::Borrow; +use anyhow::Result; use itertools::Itertools; use crate::field::extension::Extendable; @@ -97,7 +98,11 @@ impl, const B: usize, const D: usize> SimpleGenerat self.limbs.iter().map(|b| b.target).collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let sum = self .limbs .iter() @@ -107,7 +112,7 @@ impl, const B: usize, const D: usize> SimpleGenerat acc * F::from_canonical_usize(B) + F::from_bool(limb) }); - out_buffer.set_target(Target::wire(self.row, BaseSumGate::::WIRE_SUM), sum); + out_buffer.set_target(Target::wire(self.row, BaseSumGate::::WIRE_SUM), sum) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gadgets/split_join.rs b/plonky2/src/gadgets/split_join.rs index d20c09bc5f..904a3e5fae 100644 --- a/plonky2/src/gadgets/split_join.rs +++ b/plonky2/src/gadgets/split_join.rs @@ -5,6 +5,8 @@ use alloc::{ vec::Vec, }; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::gates::base_sum::BaseSumGate; use crate::hash::hash_types::RichField; @@ -75,12 +77,16 @@ impl, const D: usize> SimpleGenerator for Spl vec![self.integer] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); for &b in &self.bits { let b_value = integer_value & 1; - out_buffer.set_target(b, F::from_canonical_u64(b_value)); + out_buffer.set_target(b, F::from_canonical_u64(b_value))?; integer_value >>= 1; } @@ -88,6 +94,8 @@ impl, const D: usize> SimpleGenerator for Spl integer_value, 0, "Integer too large to fit in given number of bits" ); + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -118,7 +126,11 @@ impl, const D: usize> SimpleGenerator for Wir vec![self.integer] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let mut integer_value = witness.get_target(self.integer).to_canonical_u64(); for &gate in &self.gates { @@ -134,7 +146,7 @@ impl, const D: usize> SimpleGenerator for Wir integer_value = 0; }; - out_buffer.set_target(sum, F::from_canonical_u64(truncated_value)); + out_buffer.set_target(sum, F::from_canonical_u64(truncated_value))?; } debug_assert_eq!( @@ -143,6 +155,8 @@ impl, const D: usize> SimpleGenerator for Wir "Integer too large to fit in {} many `BaseSumGate`s", self.gates.len() ); + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/arithmetic_base.rs b/plonky2/src/gates/arithmetic_base.rs index ab3ac0bd57..1f46db5e5f 100644 --- a/plonky2/src/gates/arithmetic_base.rs +++ b/plonky2/src/gates/arithmetic_base.rs @@ -5,6 +5,8 @@ use alloc::{ vec::Vec, }; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::field::packed::PackedField; use crate::gates::gate::Gate; @@ -209,7 +211,11 @@ impl, const D: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let get_wire = |wire: usize| -> F { witness.get_target(Target::wire(self.row, wire)) }; let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i)); diff --git a/plonky2/src/gates/arithmetic_extension.rs b/plonky2/src/gates/arithmetic_extension.rs index e549f97045..bb110cd946 100644 --- a/plonky2/src/gates/arithmetic_extension.rs +++ b/plonky2/src/gates/arithmetic_extension.rs @@ -6,6 +6,8 @@ use alloc::{ }; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::{Extendable, FieldExtension}; use crate::gates::gate::Gate; use crate::gates::util::StridedConstraintConsumer; @@ -192,7 +194,11 @@ impl, const D: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.row, range); witness.get_extension_target(t) diff --git a/plonky2/src/gates/base_sum.rs b/plonky2/src/gates/base_sum.rs index a169aa170a..1101bd967a 100644 --- a/plonky2/src/gates/base_sum.rs +++ b/plonky2/src/gates/base_sum.rs @@ -2,6 +2,8 @@ use alloc::{format, string::String, vec, vec::Vec}; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::field::packed::PackedField; use crate::field::types::{Field, Field64}; @@ -185,7 +187,11 @@ impl, const B: usize, const D: usize> SimpleGenerat vec![Target::wire(self.row, BaseSumGate::::WIRE_SUM)] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let sum_value = witness .get_target(Target::wire(self.row, BaseSumGate::::WIRE_SUM)) .to_canonical_u64() as usize; @@ -206,8 +212,10 @@ impl, const B: usize, const D: usize> SimpleGenerat .collect::>(); for (b, b_value) in limbs.zip(limbs_value) { - out_buffer.set_target(b, b_value); + out_buffer.set_target(b, b_value)?; } + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/coset_interpolation.rs b/plonky2/src/gates/coset_interpolation.rs index eb133f173a..39ff1f23d8 100644 --- a/plonky2/src/gates/coset_interpolation.rs +++ b/plonky2/src/gates/coset_interpolation.rs @@ -8,6 +8,8 @@ use alloc::{ use core::marker::PhantomData; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::algebra::ExtensionAlgebra; use crate::field::extension::{Extendable, FieldExtension, OEF}; use crate::field::interpolation::barycentric_weights; @@ -442,7 +444,11 @@ impl, const D: usize> SimpleGenerator deps } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let local_wire = |column| Wire { row: self.row, column, @@ -465,7 +471,7 @@ impl, const D: usize> SimpleGenerator out_buffer.set_ext_wires( self.gate.wires_shifted_evaluation_point().map(local_wire), shifted_evaluation_point, - ); + )?; let domain = &self.interpolation_domain; let values = (0..self.gate.num_points()) @@ -485,8 +491,8 @@ impl, const D: usize> SimpleGenerator for i in 0..self.gate.num_intermediates() { let intermediate_eval_wires = self.gate.wires_intermediate_eval(i).map(local_wire); let intermediate_prod_wires = self.gate.wires_intermediate_prod(i).map(local_wire); - out_buffer.set_ext_wires(intermediate_eval_wires, computed_eval); - out_buffer.set_ext_wires(intermediate_prod_wires, computed_prod); + out_buffer.set_ext_wires(intermediate_eval_wires, computed_eval)?; + out_buffer.set_ext_wires(intermediate_prod_wires, computed_prod)?; let start_index = 1 + (degree - 1) * (i + 1); let end_index = (start_index + degree - 1).min(self.gate.num_points()); @@ -501,7 +507,7 @@ impl, const D: usize> SimpleGenerator } let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - out_buffer.set_ext_wires(evaluation_value_wires, computed_eval); + out_buffer.set_ext_wires(evaluation_value_wires, computed_eval) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/exponentiation.rs b/plonky2/src/gates/exponentiation.rs index aab9322db4..a93d1af7d1 100644 --- a/plonky2/src/gates/exponentiation.rs +++ b/plonky2/src/gates/exponentiation.rs @@ -7,6 +7,8 @@ use alloc::{ }; use core::marker::PhantomData; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::field::ops::Square; use crate::field::packed::PackedField; @@ -265,7 +267,11 @@ impl, const D: usize> SimpleGenerator deps } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let local_wire = |column| Wire { row: self.row, column, @@ -292,11 +298,11 @@ impl, const D: usize> SimpleGenerator for i in 0..num_power_bits { let intermediate_value_wire = local_wire(self.gate.wire_intermediate_value(i)); - out_buffer.set_wire(intermediate_value_wire, intermediate_values[i]); + out_buffer.set_wire(intermediate_value_wire, intermediate_values[i])?; } let output_wire = local_wire(self.gate.wire_output()); - out_buffer.set_wire(output_wire, intermediate_values[num_power_bits - 1]); + out_buffer.set_wire(output_wire, intermediate_values[num_power_bits - 1]) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/gate_testing.rs b/plonky2/src/gates/gate_testing.rs index c71e96dff7..fe8c1cee7a 100644 --- a/plonky2/src/gates/gate_testing.rs +++ b/plonky2/src/gates/gate_testing.rs @@ -137,10 +137,10 @@ pub fn test_eval_fns< let wires_t = builder.add_virtual_extension_targets(wires.len()); let constants_t = builder.add_virtual_extension_targets(constants.len()); - pw.set_extension_targets(&wires_t, &wires); - pw.set_extension_targets(&constants_t, &constants); + pw.set_extension_targets(&wires_t, &wires)?; + pw.set_extension_targets(&constants_t, &constants)?; let public_inputs_hash_t = builder.add_virtual_hash(); - pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); + pw.set_hash_target(public_inputs_hash_t, public_inputs_hash)?; let vars = EvaluationVars { local_constants: &constants, @@ -155,7 +155,7 @@ pub fn test_eval_fns< public_inputs_hash: &public_inputs_hash_t, }; let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); - pw.set_extension_targets(&evals_t, &evals); + pw.set_extension_targets(&evals_t, &evals)?; let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/src/gates/lookup.rs b/plonky2/src/gates/lookup.rs index 498221708e..6a28745c21 100644 --- a/plonky2/src/gates/lookup.rs +++ b/plonky2/src/gates/lookup.rs @@ -6,6 +6,7 @@ use alloc::{ vec::Vec, }; +use anyhow::{anyhow, Result}; use itertools::Itertools; use keccak_hash::keccak; @@ -188,7 +189,11 @@ impl, const D: usize> SimpleGenerator for Loo )] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let get_wire = |wire: usize| -> F { witness.get_target(Target::wire(self.row, wire)) }; let input_val = get_wire(LookupGate::wire_ith_looking_inp(self.slot_nb)); @@ -197,7 +202,7 @@ impl, const D: usize> SimpleGenerator for Loo let output_val = F::from_canonical_u16(output); let out_wire = Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); - out_buffer.set_target(out_wire, output_val); + out_buffer.set_target(out_wire, output_val) } else { for (input, output) in self.lut.iter() { if input_val == F::from_canonical_u16(*input) { @@ -205,12 +210,13 @@ impl, const D: usize> SimpleGenerator for Loo let out_wire = Target::wire(self.row, LookupGate::wire_ith_looking_out(self.slot_nb)); - out_buffer.set_target(out_wire, output_val); - return; + out_buffer.set_target(out_wire, output_val)?; + + return Ok(()); } } - panic!("Incorrect input value provided"); - }; + Err(anyhow!("Incorrect input value provided")) + } } fn serialize(&self, dst: &mut Vec, common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/lookup_table.rs b/plonky2/src/gates/lookup_table.rs index e7ead9aff3..97b49595ee 100644 --- a/plonky2/src/gates/lookup_table.rs +++ b/plonky2/src/gates/lookup_table.rs @@ -9,6 +9,7 @@ use alloc::{ #[cfg(feature = "std")] use std::sync::Arc; +use anyhow::Result; use itertools::Itertools; use keccak_hash::keccak; @@ -205,7 +206,11 @@ impl, const D: usize> SimpleGenerator for Loo vec![] } - fn run_once(&self, _witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + _witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let first_row = self.last_lut_row + self.lut.len().div_ceil(self.num_slots) - 1; let slot = (first_row - self.row) * self.num_slots + self.slot_nb; @@ -216,12 +221,12 @@ impl, const D: usize> SimpleGenerator for Loo if slot < self.lut.len() { let (input, output) = self.lut[slot]; - out_buffer.set_target(slot_input_target, F::from_canonical_usize(input as usize)); - out_buffer.set_target(slot_output_target, F::from_canonical_usize(output as usize)); + out_buffer.set_target(slot_input_target, F::from_canonical_usize(input as usize))?; + out_buffer.set_target(slot_output_target, F::from_canonical_usize(output as usize)) } else { // Pad with zeros. - out_buffer.set_target(slot_input_target, F::ZERO); - out_buffer.set_target(slot_output_target, F::ZERO); + out_buffer.set_target(slot_input_target, F::ZERO)?; + out_buffer.set_target(slot_output_target, F::ZERO) } } diff --git a/plonky2/src/gates/multiplication_extension.rs b/plonky2/src/gates/multiplication_extension.rs index 633a4b21ca..f604c9ebb8 100644 --- a/plonky2/src/gates/multiplication_extension.rs +++ b/plonky2/src/gates/multiplication_extension.rs @@ -6,6 +6,8 @@ use alloc::{ }; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::{Extendable, FieldExtension}; use crate::gates::gate::Gate; use crate::gates::util::StridedConstraintConsumer; @@ -175,7 +177,11 @@ impl, const D: usize> SimpleGenerator .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.row, range); witness.get_extension_target(t) diff --git a/plonky2/src/gates/poseidon.rs b/plonky2/src/gates/poseidon.rs index fa880753ae..cf34e7bc1b 100644 --- a/plonky2/src/gates/poseidon.rs +++ b/plonky2/src/gates/poseidon.rs @@ -7,6 +7,8 @@ use alloc::{ }; use core::marker::PhantomData; +use anyhow::Result; + use crate::field::extension::Extendable; use crate::field::types::Field; use crate::gates::gate::Gate; @@ -439,7 +441,11 @@ impl + Poseidon, const D: usize> SimpleGenerator, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let local_wire = |column| Wire { row: self.row, column, @@ -454,7 +460,7 @@ impl + Poseidon, const D: usize> SimpleGenerator::wire_delta(i)), delta_i); + out_buffer.set_wire(local_wire(PoseidonGate::::wire_delta(i)), delta_i)?; } if swap_value == F::ONE { @@ -473,7 +479,7 @@ impl + Poseidon, const D: usize> SimpleGenerator::wire_full_sbox_0(r, i)), state[i], - ); + )?; } } ::sbox_layer_field(&mut state); @@ -487,7 +493,7 @@ impl + Poseidon, const D: usize> SimpleGenerator::wire_partial_sbox(r)), state[0], - ); + )?; state[0] = ::sbox_monomial(state[0]); state[0] += F::from_canonical_u64(::FAST_PARTIAL_ROUND_CONSTANTS[r]); state = ::mds_partial_layer_fast_field(&state, r); @@ -497,7 +503,7 @@ impl + Poseidon, const D: usize> SimpleGenerator::sbox_monomial(state[0]); state = ::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); @@ -509,7 +515,7 @@ impl + Poseidon, const D: usize> SimpleGenerator::wire_full_sbox_1(r, i)), state[i], - ); + )?; } ::sbox_layer_field(&mut state); state = ::mds_layer_field(&state); @@ -517,8 +523,10 @@ impl + Poseidon, const D: usize> SimpleGenerator::wire_output(i)), state[i]); + out_buffer.set_wire(local_wire(PoseidonGate::::wire_output(i)), state[i])? } + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -589,24 +597,29 @@ mod tests { .collect::>(); let mut inputs = PartialWitness::new(); - inputs.set_wire( - Wire { - row, - column: Gate::WIRE_SWAP, - }, - F::ZERO, - ); - for i in 0..SPONGE_WIDTH { - inputs.set_wire( + inputs + .set_wire( Wire { row, - column: Gate::wire_input(i), + column: Gate::WIRE_SWAP, }, - permutation_inputs[i], - ); + F::ZERO, + ) + .unwrap(); + for i in 0..SPONGE_WIDTH { + inputs + .set_wire( + Wire { + row, + column: Gate::wire_input(i), + }, + permutation_inputs[i], + ) + .unwrap(); } - let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); + let witness = + generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap(); let expected_outputs: [F; SPONGE_WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); diff --git a/plonky2/src/gates/poseidon_mds.rs b/plonky2/src/gates/poseidon_mds.rs index b692e2834e..801879a3b9 100644 --- a/plonky2/src/gates/poseidon_mds.rs +++ b/plonky2/src/gates/poseidon_mds.rs @@ -8,6 +8,8 @@ use alloc::{ use core::marker::PhantomData; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::algebra::ExtensionAlgebra; use crate::field::extension::{Extendable, FieldExtension}; use crate::field::types::Field; @@ -238,7 +240,11 @@ impl + Poseidon, const D: usize> SimpleGenerator, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let get_local_get_target = |wire_range| ExtensionTarget::from_range(self.row, wire_range); let get_local_ext = |wire_range| witness.get_extension_target(get_local_get_target(wire_range)); @@ -255,8 +261,10 @@ impl + Poseidon, const D: usize> SimpleGenerator::wires_output(i)), out, - ); + )?; } + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 24abb58837..5815c908a6 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -7,6 +7,7 @@ use alloc::{ }; use core::marker::PhantomData; +use anyhow::Result; use itertools::Itertools; use crate::field::extension::Extendable; @@ -366,7 +367,11 @@ impl, const D: usize> SimpleGenerator deps } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let local_wire = |column| Wire { row: self.row, column, @@ -390,12 +395,14 @@ impl, const D: usize> SimpleGenerator set_local_wire( self.gate.wire_claimed_element(copy), get_local_wire(self.gate.wire_list_item(access_index, copy)), - ); + )?; for i in 0..self.gate.bits { let bit = F::from_bool(((access_index >> i) & 1) != 0); - set_local_wire(self.gate.wire_bit(i, copy), bit); + set_local_wire(self.gate.wire_bit(i, copy), bit)?; } + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/reducing.rs b/plonky2/src/gates/reducing.rs index 0030550514..9636850939 100644 --- a/plonky2/src/gates/reducing.rs +++ b/plonky2/src/gates/reducing.rs @@ -7,6 +7,8 @@ use alloc::{ }; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::{Extendable, FieldExtension}; use crate::gates::gate::Gate; use crate::gates::util::StridedConstraintConsumer; @@ -201,7 +203,11 @@ impl, const D: usize> SimpleGenerator for Red .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let extract_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.row, range); witness.get_extension_target(t) @@ -224,10 +230,11 @@ impl, const D: usize> SimpleGenerator for Red let mut acc = old_acc; for i in 0..self.gate.num_coeffs { let computed_acc = acc * alpha + coeffs[i].into(); - out_buffer.set_extension_target(accs[i], computed_acc); + out_buffer.set_extension_target(accs[i], computed_acc)?; acc = computed_acc; } - out_buffer.set_extension_target(output, acc); + + out_buffer.set_extension_target(output, acc) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/gates/reducing_extension.rs b/plonky2/src/gates/reducing_extension.rs index 3b7bc9e26c..1fd0865e49 100644 --- a/plonky2/src/gates/reducing_extension.rs +++ b/plonky2/src/gates/reducing_extension.rs @@ -7,6 +7,8 @@ use alloc::{ }; use core::ops::Range; +use anyhow::Result; + use crate::field::extension::{Extendable, FieldExtension}; use crate::gates::gate::Gate; use crate::gates::util::StridedConstraintConsumer; @@ -201,7 +203,11 @@ impl, const D: usize> SimpleGenerator for Red .collect() } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let local_extension = |range: Range| -> F::Extension { let t = ExtensionTarget::from_range(self.row, range); witness.get_extension_target(t) @@ -219,9 +225,11 @@ impl, const D: usize> SimpleGenerator for Red let mut acc = old_acc; for i in 0..self.gate.num_coeffs { let computed_acc = acc * alpha + coeffs[i]; - out_buffer.set_extension_target(accs[i], computed_acc); + out_buffer.set_extension_target(accs[i], computed_acc)?; acc = computed_acc; } + + Ok(()) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index 95a347ae55..021240a98e 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -292,18 +292,18 @@ mod tests { siblings: builder.add_virtual_hashes(proof.siblings.len()), }; for i in 0..proof.siblings.len() { - pw.set_hash_target(proof_t.siblings[i], proof.siblings[i]); + pw.set_hash_target(proof_t.siblings[i], proof.siblings[i])?; } let cap_t = builder.add_virtual_cap(cap_height); - pw.set_cap_target(&cap_t, &tree.cap); + pw.set_cap_target(&cap_t, &tree.cap)?; let i_c = builder.constant(F::from_canonical_usize(i)); let i_bits = builder.split_le(i_c, log_n); let data = builder.add_virtual_targets(tree.leaves[i].len()); for j in 0..data.len() { - pw.set_target(data[j], tree.leaves[i][j]); + pw.set_target(data[j], tree.leaves[i][j])?; } builder.verify_merkle_proof_to_cap::<>::InnerHasher>( diff --git a/plonky2/src/iop/challenger.rs b/plonky2/src/iop/challenger.rs index 57660fd487..27c827ab0d 100644 --- a/plonky2/src/iop/challenger.rs +++ b/plonky2/src/iop/challenger.rs @@ -365,7 +365,8 @@ mod tests { } let circuit = builder.build::(); let inputs = PartialWitness::new(); - let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); + let witness = + generate_partial_witness(inputs, &circuit.prover_only, &circuit.common).unwrap(); let recursive_output_values_per_round: Vec> = recursive_outputs_per_round .iter() .map(|outputs| witness.get_targets(outputs)) diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 6cdd75dcf6..f81508b7a3 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -8,6 +8,8 @@ use alloc::{ use core::fmt::Debug; use core::marker::PhantomData; +use anyhow::{anyhow, Result}; + use crate::field::extension::Extendable; use crate::field::types::Field; use crate::hash::hash_types::RichField; @@ -30,7 +32,7 @@ pub fn generate_partial_witness< inputs: PartialWitness, prover_data: &'a ProverOnlyCircuitData, common_data: &'a CommonCircuitData, -) -> PartitionWitness<'a, F> { +) -> Result> { let config = &common_data.config; let generators = &prover_data.generators; let generator_indices_by_watches = &prover_data.generator_indices_by_watches; @@ -42,7 +44,7 @@ pub fn generate_partial_witness< ); for (t, v) in inputs.target_values.into_iter() { - witness.set_target(t, v); + witness.set_target(t, v)?; } // Build a list of "pending" generators which are queued to be run. Initially, all generators @@ -72,10 +74,11 @@ pub fn generate_partial_witness< // Merge any generated values into our witness, and get a list of newly-populated // targets' representatives. - let new_target_reps = buffer - .target_values - .drain(..) - .flat_map(|(t, v)| witness.set_target_returning_rep(t, v)); + let mut new_target_reps = Vec::with_capacity(buffer.target_values.len()); + for (t, v) in buffer.target_values.drain(..) { + let reps = witness.set_target_returning_rep(t, v)?; + new_target_reps.extend(reps); + } // Enqueue unfinished generators that were watching one of the newly populated targets. for watch in new_target_reps { @@ -93,13 +96,11 @@ pub fn generate_partial_witness< pending_generator_indices = next_pending_generator_indices; } - assert_eq!( - remaining_generators, 0, - "{} generators weren't run", - remaining_generators, - ); + if remaining_generators != 0 { + return Err(anyhow!("{} generators weren't run", remaining_generators)); + } - witness + Ok(witness) } /// A generator participates in the generation of the witness. @@ -163,8 +164,10 @@ impl From> for GeneratedValues { } impl WitnessWrite for GeneratedValues { - fn set_target(&mut self, target: Target, value: F) { + fn set_target(&mut self, target: Target, value: F) -> Result<()> { self.target_values.push((target, value)); + + Ok(()) } } @@ -188,13 +191,14 @@ impl GeneratedValues { pub fn singleton_extension_target( et: ExtensionTarget, value: F::Extension, - ) -> Self + ) -> Result where F: RichField + Extendable, { let mut witness = Self::with_capacity(D); - witness.set_extension_target(et, value); - witness + witness.set_extension_target(et, value)?; + + Ok(witness) } } @@ -206,7 +210,11 @@ pub trait SimpleGenerator, const D: usize>: fn dependencies(&self) -> Vec; - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues); + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()>; fn adapter(self) -> SimpleGeneratorAdapter where @@ -248,8 +256,7 @@ impl, SG: SimpleGenerator, const D: usize> Wi fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { if witness.contains_all(&self.inner.dependencies()) { - self.inner.run_once(witness, out_buffer); - true + self.inner.run_once(witness, out_buffer).is_ok() } else { false } @@ -283,9 +290,13 @@ impl, const D: usize> SimpleGenerator for Cop vec![self.src] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let value = witness.get_target(self.src); - out_buffer.set_target(self.dst, value); + out_buffer.set_target(self.dst, value) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -315,9 +326,13 @@ impl, const D: usize> SimpleGenerator for Ran Vec::new() } - fn run_once(&self, _witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + _witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let random_value = F::rand(); - out_buffer.set_target(self.target, random_value); + out_buffer.set_target(self.target, random_value) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -346,7 +361,11 @@ impl, const D: usize> SimpleGenerator for Non vec![self.to_test] } - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + fn run_once( + &self, + witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { let to_test_value = witness.get_target(self.to_test); let dummy_value = if to_test_value == F::ZERO { @@ -355,7 +374,7 @@ impl, const D: usize> SimpleGenerator for Non to_test_value.inverse() }; - out_buffer.set_target(self.dummy, dummy_value); + out_buffer.set_target(self.dummy, dummy_value) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { @@ -394,8 +413,12 @@ impl, const D: usize> SimpleGenerator for Con vec![] } - fn run_once(&self, _witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - out_buffer.set_target(Target::wire(self.row, self.wire_index), self.constant); + fn run_once( + &self, + _witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { + out_buffer.set_target(Target::wire(self.row, self.wire_index), self.constant) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index 85af6ca41b..0db9186d9a 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,6 +1,8 @@ #[cfg(not(feature = "std"))] use alloc::{vec, vec::Vec}; +use std::iter::zip; +use anyhow::{anyhow, Result}; use hashbrown::HashMap; use itertools::{zip_eq, Itertools}; @@ -18,52 +20,67 @@ use crate::plonk::config::{AlgebraicHasher, GenericConfig}; use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; pub trait WitnessWrite { - fn set_target(&mut self, target: Target, value: F); + fn set_target(&mut self, target: Target, value: F) -> Result<()>; - fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { - ht.elements - .iter() - .zip(value.elements) - .for_each(|(&t, x)| self.set_target(t, x)); + fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) -> Result<()> { + for (t, x) in zip(ht.elements, value.elements) { + self.set_target(t, x)?; + } + + Ok(()) } fn set_cap_target>( &mut self, ct: &MerkleCapTarget, value: &MerkleCap, - ) where + ) -> Result<()> + where F: RichField, { for (ht, h) in ct.0.iter().zip(&value.0) { - self.set_hash_target(*ht, *h); + self.set_hash_target(*ht, *h)?; } + + Ok(()) } - fn set_extension_target(&mut self, et: ExtensionTarget, value: F::Extension) + fn set_extension_target( + &mut self, + et: ExtensionTarget, + value: F::Extension, + ) -> Result<()> where F: RichField + Extendable, { - self.set_target_arr(&et.0, &value.to_basefield_array()); + self.set_target_arr(&et.0, &value.to_basefield_array()) } - fn set_target_arr(&mut self, targets: &[Target], values: &[F]) { - zip_eq(targets, values).for_each(|(&target, &value)| self.set_target(target, value)); + fn set_target_arr(&mut self, targets: &[Target], values: &[F]) -> Result<()> { + for (&target, &value) in zip_eq(targets, values) { + self.set_target(target, value)?; + } + + Ok(()) } fn set_extension_targets( &mut self, ets: &[ExtensionTarget], values: &[F::Extension], - ) where + ) -> Result<()> + where F: RichField + Extendable, { debug_assert_eq!(ets.len(), values.len()); - ets.iter() - .zip(values) - .for_each(|(&et, &v)| self.set_extension_target(et, v)); + for (&et, &v) in zip(ets, values) { + self.set_extension_target(et, v)?; + } + + Ok(()) } - fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + fn set_bool_target(&mut self, target: BoolTarget, value: bool) -> Result<()> { self.set_target(target.target, F::from_bool(value)) } @@ -73,7 +90,8 @@ pub trait WitnessWrite { &mut self, proof_with_pis_target: &ProofWithPublicInputsTarget, proof_with_pis: &ProofWithPublicInputs, - ) where + ) -> Result<()> + where F: RichField + Extendable, C::Hasher: AlgebraicHasher, { @@ -88,10 +106,10 @@ pub trait WitnessWrite { // Set public inputs. for (&pi_t, &pi) in pi_targets.iter().zip_eq(public_inputs) { - self.set_target(pi_t, pi); + self.set_target(pi_t, pi)?; } - self.set_proof_target(pt, proof); + self.set_proof_target(pt, proof) } /// Set the targets in a `ProofTarget` to their corresponding values in a `Proof`. @@ -99,30 +117,32 @@ pub trait WitnessWrite { &mut self, proof_target: &ProofTarget, proof: &Proof, - ) where + ) -> Result<()> + where F: RichField + Extendable, C::Hasher: AlgebraicHasher, { - self.set_cap_target(&proof_target.wires_cap, &proof.wires_cap); + self.set_cap_target(&proof_target.wires_cap, &proof.wires_cap)?; self.set_cap_target( &proof_target.plonk_zs_partial_products_cap, &proof.plonk_zs_partial_products_cap, - ); - self.set_cap_target(&proof_target.quotient_polys_cap, &proof.quotient_polys_cap); + )?; + self.set_cap_target(&proof_target.quotient_polys_cap, &proof.quotient_polys_cap)?; self.set_fri_openings( &proof_target.openings.to_fri_openings(), &proof.openings.to_fri_openings(), - ); + )?; - set_fri_proof_target(self, &proof_target.opening_proof, &proof.opening_proof); + set_fri_proof_target(self, &proof_target.opening_proof, &proof.opening_proof) } fn set_fri_openings( &mut self, fri_openings_target: &FriOpeningsTarget, fri_openings: &FriOpenings, - ) where + ) -> Result<()> + where F: RichField + Extendable, { for (batch_target, batch) in fri_openings_target @@ -130,48 +150,55 @@ pub trait WitnessWrite { .iter() .zip_eq(&fri_openings.batches) { - self.set_extension_targets(&batch_target.values, &batch.values); + self.set_extension_targets(&batch_target.values, &batch.values)?; } + + Ok(()) } fn set_verifier_data_target, const D: usize>( &mut self, vdt: &VerifierCircuitTarget, vd: &VerifierOnlyCircuitData, - ) where + ) -> Result<()> + where F: RichField + Extendable, C::Hasher: AlgebraicHasher, { - self.set_cap_target(&vdt.constants_sigmas_cap, &vd.constants_sigmas_cap); - self.set_hash_target(vdt.circuit_digest, vd.circuit_digest); + self.set_cap_target(&vdt.constants_sigmas_cap, &vd.constants_sigmas_cap)?; + self.set_hash_target(vdt.circuit_digest, vd.circuit_digest) } - fn set_wire(&mut self, wire: Wire, value: F) { + fn set_wire(&mut self, wire: Wire, value: F) -> Result<()> { self.set_target(Target::Wire(wire), value) } - fn set_wires(&mut self, wires: W, values: &[F]) + fn set_wires(&mut self, wires: W, values: &[F]) -> Result<()> where W: IntoIterator, { // If we used itertools, we could use zip_eq for extra safety. for (wire, &value) in wires.into_iter().zip(values) { - self.set_wire(wire, value); + self.set_wire(wire, value)?; } + + Ok(()) } - fn set_ext_wires(&mut self, wires: W, value: F::Extension) + fn set_ext_wires(&mut self, wires: W, value: F::Extension) -> Result<()> where F: RichField + Extendable, W: IntoIterator, { - self.set_wires(wires, &value.to_basefield_array()); + self.set_wires(wires, &value.to_basefield_array()) } - fn extend>(&mut self, pairs: I) { + fn extend>(&mut self, pairs: I) -> Result<()> { for (t, v) in pairs { - self.set_target(t, v); + self.set_target(t, v)?; } + + Ok(()) } } @@ -277,15 +304,20 @@ impl PartialWitness { } impl WitnessWrite for PartialWitness { - fn set_target(&mut self, target: Target, value: F) { + fn set_target(&mut self, target: Target, value: F) -> Result<()> { let opt_old_value = self.target_values.insert(target, value); if let Some(old_value) = opt_old_value { - assert_eq!( - value, old_value, - "Target {:?} was set twice with different values: {} != {}", - target, old_value, value - ); + if value != old_value { + return Err(anyhow!( + "Target {:?} was set twice with different values: {} != {}", + target, + old_value, + value + )); + } } + + Ok(()) } } @@ -317,19 +349,23 @@ impl<'a, F: Field> PartitionWitness<'a, F> { /// Set a `Target`. On success, returns the representative index of the newly-set target. If the /// target was already set, returns `None`. - pub fn set_target_returning_rep(&mut self, target: Target, value: F) -> Option { + pub fn set_target_returning_rep(&mut self, target: Target, value: F) -> Result> { let rep_index = self.representative_map[self.target_index(target)]; let rep_value = &mut self.values[rep_index]; if let Some(old_value) = *rep_value { - assert_eq!( - value, old_value, - "Partition containing {:?} was set twice with different values: {} != {}", - target, old_value, value - ); - None + if value != old_value { + return Err(anyhow!( + "Partition containing {:?} was set twice with different values: {} != {}", + target, + old_value, + value + )); + } + + Ok(None) } else { *rep_value = Some(value); - Some(rep_index) + Ok(Some(rep_index)) } } @@ -353,8 +389,8 @@ impl<'a, F: Field> PartitionWitness<'a, F> { } impl<'a, F: Field> WitnessWrite for PartitionWitness<'a, F> { - fn set_target(&mut self, target: Target, value: F) { - self.set_target_returning_rep(target, value); + fn set_target(&mut self, target: Target, value: F) -> Result<()> { + self.set_target_returning_rep(target, value).map(|_| ()) } } diff --git a/plonky2/src/lookup_test.rs b/plonky2/src/lookup_test.rs index cb6b53f86b..1e6a943e42 100644 --- a/plonky2/src/lookup_test.rs +++ b/plonky2/src/lookup_test.rs @@ -103,8 +103,8 @@ fn test_one_lookup() -> anyhow::Result<()> { let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); - pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let data = builder.build::(); let mut timing = TimingTree::new("prove one lookup", Level::Debug); @@ -171,8 +171,8 @@ fn test_two_luts() -> anyhow::Result<()> { builder.register_public_input(output_final); let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); - pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let data = builder.build::(); let mut timing = TimingTree::new("prove two_luts", Level::Debug); let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; @@ -241,8 +241,8 @@ fn test_different_inputs() -> anyhow::Result<()> { let look_val_a = table[init_a].0; let look_val_b = table[init_b].0; - pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); - pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + pw.set_target(initial_a, F::from_canonical_u16(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_u16(look_val_b))?; let data = builder.build::(); let mut timing = TimingTree::new("prove different lookups", Level::Debug); @@ -327,8 +327,8 @@ fn test_many_lookups() -> anyhow::Result<()> { let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); - pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let data = builder.build::(); let mut timing = TimingTree::new("prove different lookups", Level::Debug); @@ -404,8 +404,8 @@ fn test_same_luts() -> anyhow::Result<()> { let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); - pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let data = builder.build::(); let mut timing = TimingTree::new("prove two_luts", Level::Debug); @@ -443,8 +443,8 @@ fn test_big_lut() -> anyhow::Result<()> { let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); - pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + pw.set_target(initial_a, F::from_canonical_u16(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_u16(look_val_b))?; let proof = data.prove(pw)?; assert_eq!( @@ -494,12 +494,11 @@ fn test_many_lookups_on_big_lut() -> anyhow::Result<()> { let mut pw = PartialWitness::new(); - inputs - .into_iter() - .enumerate() - .for_each(|(i, t)| pw.set_target(t, F::from_canonical_usize(i))); - pw.set_target(initial_a, F::from_canonical_u16(look_val_a)); - pw.set_target(initial_b, F::from_canonical_u16(look_val_b)); + for (i, t) in inputs.into_iter().enumerate() { + pw.set_target(t, F::from_canonical_usize(i))? + } + pw.set_target(initial_a, F::from_canonical_u16(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_u16(look_val_b))?; let proof = data.prove(pw)?; assert_eq!( diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 3f848ad2cb..441e810374 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -152,7 +152,7 @@ impl, C: GenericConfig, const D: usize> MockCircuitData { pub fn generate_witness(&self, inputs: PartialWitness) -> PartitionWitness { - generate_partial_witness::(inputs, &self.prover_only, &self.common) + generate_partial_witness::(inputs, &self.prover_only, &self.common).unwrap() } } diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 3001e60eb6..fcd784f326 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -46,7 +46,7 @@ pub fn set_lookup_wires< prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, pw: &mut PartitionWitness, -) { +) -> Result<()> { for ( lut_index, &LookupWire { @@ -88,8 +88,8 @@ pub fn set_lookup_wires< Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_inp(slot)); let out_target = Target::wire(last_lut_gate - 1, LookupGate::wire_ith_looking_out(slot)); - pw.set_target(inp_target, F::from_canonical_u16(first_inp_value)); - pw.set_target(out_target, F::from_canonical_u16(first_out_value)); + pw.set_target(inp_target, F::from_canonical_u16(first_inp_value))?; + pw.set_target(out_target, F::from_canonical_u16(first_out_value))?; multiplicities[0] += 1; } @@ -104,9 +104,11 @@ pub fn set_lookup_wires< pw.set_target( mul_target, F::from_canonical_usize(multiplicities[lut_entry]), - ); + )?; } } + + Ok(()) } pub fn prove, C: GenericConfig, const D: usize>( @@ -122,7 +124,7 @@ where let partition_witness = timed!( timing, &format!("run {} generators", prover_data.generators.len()), - generate_partial_witness(inputs, prover_data, common_data) + generate_partial_witness(inputs, prover_data, common_data)? ); prove_with_partition_witness(prover_data, common_data, partition_witness, timing) @@ -148,7 +150,7 @@ where let quotient_degree = common_data.quotient_degree(); let degree = common_data.degree(); - set_lookup_wires(prover_data, common_data, &mut partition_witness); + set_lookup_wires(prover_data, common_data, &mut partition_witness)?; let public_inputs = partition_witness.get_targets(&prover_data.public_inputs); let public_inputs_hash = C::InnerHasher::hash_no_pad(&public_inputs); diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index c41a9b583a..0bca23f7ff 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -370,7 +370,7 @@ mod tests { let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); let t = builder.add_virtual_target(); - pw.set_target(t, F::rand()); + pw.set_target(t, F::rand())?; builder.register_public_input(t); let _t2 = builder.square(t); for _ in 0..64 { @@ -388,15 +388,15 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::new(); let pt = builder.add_virtual_proof_with_pis(&data.common); - pw.set_proof_with_pis_target(&pt, &proof); + pw.set_proof_with_pis_target(&pt, &proof)?; let dummy_pt = builder.add_virtual_proof_with_pis(&data.common); - pw.set_proof_with_pis_target::(&dummy_pt, &dummy_proof); + pw.set_proof_with_pis_target::(&dummy_pt, &dummy_proof)?; let inner_data = builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); - pw.set_verifier_data_target(&inner_data, &data.verifier_only); + pw.set_verifier_data_target(&inner_data, &data.verifier_only)?; let dummy_inner_data = builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); - pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only); + pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only)?; let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof::( b, diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index 7be554176c..df0fb95cdd 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -312,7 +312,7 @@ mod tests { let mut pw = PartialWitness::new(); let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)]; let initial_hash_pis = initial_hash.into_iter().enumerate().collect(); - pw.set_bool_target(condition, false); + pw.set_bool_target(condition, false)?; pw.set_proof_with_pis_target::( &inner_cyclic_proof_with_pis, &cyclic_base_proof( @@ -320,8 +320,8 @@ mod tests { &cyclic_circuit_data.verifier_only, initial_hash_pis, ), - ); - pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only); + )?; + pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only)?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, @@ -332,9 +332,9 @@ mod tests { // 1st recursive layer. let mut pw = PartialWitness::new(); - pw.set_bool_target(condition, true); - pw.set_proof_with_pis_target(&inner_cyclic_proof_with_pis, &proof); - pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only); + pw.set_bool_target(condition, true)?; + pw.set_proof_with_pis_target(&inner_cyclic_proof_with_pis, &proof)?; + pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only)?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, @@ -345,9 +345,9 @@ mod tests { // 2nd recursive layer. let mut pw = PartialWitness::new(); - pw.set_bool_target(condition, true); - pw.set_proof_with_pis_target(&inner_cyclic_proof_with_pis, &proof); - pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only); + pw.set_bool_target(condition, true)?; + pw.set_proof_with_pis_target(&inner_cyclic_proof_with_pis, &proof)?; + pw.set_verifier_data_target(&verifier_data_target, &cyclic_circuit_data.verifier_only)?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, diff --git a/plonky2/src/recursion/dummy_circuit.rs b/plonky2/src/recursion/dummy_circuit.rs index 1dc872db5a..115a4c32fd 100644 --- a/plonky2/src/recursion/dummy_circuit.rs +++ b/plonky2/src/recursion/dummy_circuit.rs @@ -5,6 +5,7 @@ use alloc::{ vec::Vec, }; +use anyhow::Result; use hashbrown::HashMap; use plonky2_field::extension::Extendable; use plonky2_field::polynomial::PolynomialCoeffs; @@ -76,7 +77,7 @@ where let mut pw = PartialWitness::new(); for i in 0..circuit.common.num_public_inputs { let pi = nonzero_public_inputs.get(&i).copied().unwrap_or_default(); - pw.set_target(circuit.prover_only.public_inputs[i], pi); + pw.set_target(circuit.prover_only.public_inputs[i], pi)?; } circuit.prove(pw) } @@ -238,9 +239,13 @@ where vec![] } - fn run_once(&self, _witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - out_buffer.set_proof_with_pis_target(&self.proof_with_pis_target, &self.proof_with_pis); - out_buffer.set_verifier_data_target(&self.verifier_data_target, &self.verifier_data); + fn run_once( + &self, + _witness: &PartitionWitness, + out_buffer: &mut GeneratedValues, + ) -> Result<()> { + out_buffer.set_proof_with_pis_target(&self.proof_with_pis_target, &self.proof_with_pis)?; + out_buffer.set_verifier_data_target(&self.verifier_data_target, &self.verifier_data) } fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index 4635d56de2..16a8ba85b2 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -473,8 +473,8 @@ mod tests { let data = builder.build::(); let mut inputs = PartialWitness::new(); - inputs.set_target(initial_a, F::from_canonical_usize(look_val_a)); - inputs.set_target(initial_b, F::from_canonical_usize(look_val_b)); + inputs.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + inputs.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let proof = data.prove(inputs)?; data.verify(proof.clone())?; @@ -540,8 +540,8 @@ mod tests { builder.register_public_input(output_final); let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::ONE); - pw.set_target(initial_b, F::TWO); + pw.set_target(initial_a, F::ONE)?; + pw.set_target(initial_b, F::TWO)?; let data = builder.build::(); let proof = data.prove(pw)?; @@ -606,8 +606,8 @@ mod tests { let mut pw = PartialWitness::new(); - pw.set_target(initial_a, F::from_canonical_usize(look_val_a)); - pw.set_target(initial_b, F::from_canonical_usize(look_val_b)); + pw.set_target(initial_a, F::from_canonical_usize(look_val_a))?; + pw.set_target(initial_b, F::from_canonical_usize(look_val_b))?; let data = builder.build::(); let proof = data.prove(pw)?; @@ -646,14 +646,14 @@ mod tests { let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); let pt = builder.add_virtual_proof_with_pis(&inner_cd); - pw.set_proof_with_pis_target(&pt, &inner_proof); + pw.set_proof_with_pis_target(&pt, &inner_proof)?; let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap, - ); - pw.set_hash_target(inner_data.circuit_digest, inner_vd.circuit_digest); + )?; + pw.set_hash_target(inner_data.circuit_digest, inner_vd.circuit_digest)?; builder.verify_proof::(&pt, &inner_data, &inner_cd); diff --git a/plonky2/src/util/reducing.rs b/plonky2/src/util/reducing.rs index b99da32e6a..89c8bd96c1 100644 --- a/plonky2/src/util/reducing.rs +++ b/plonky2/src/util/reducing.rs @@ -303,7 +303,7 @@ mod tests { let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); let vs_t = builder.add_virtual_targets(vs.len()); for (&v, &v_t) in vs.iter().zip(&vs_t) { - pw.set_target(v_t, v); + pw.set_target(v_t, v)?; } let circuit_reduce = alpha_t.reduce_base(&vs_t, &mut builder); @@ -334,7 +334,7 @@ mod tests { let mut alpha_t = ReducingFactorTarget::new(builder.constant_extension(alpha)); let vs_t = builder.add_virtual_extension_targets(vs.len()); - pw.set_extension_targets(&vs_t, &vs); + pw.set_extension_targets(&vs_t, &vs)?; let circuit_reduce = alpha_t.reduce(&vs_t, &mut builder); builder.connect_extension(manual_reduce, circuit_reduce); diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index 7aa40b6ed9..b66f48f88b 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -255,7 +255,7 @@ mod tests { let degree_bits = inner_proof.proof.recover_degree_bits(inner_config); let pt = add_virtual_stark_proof_with_pis(&mut builder, &stark, inner_config, degree_bits, 0, 0); - set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero()); + set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero())?; verify_stark_proof_circuit::(&mut builder, stark, pt, inner_config); diff --git a/starky/src/permutation_stark.rs b/starky/src/permutation_stark.rs index 62290b658d..8a897d792b 100644 --- a/starky/src/permutation_stark.rs +++ b/starky/src/permutation_stark.rs @@ -217,7 +217,7 @@ mod tests { let degree_bits = inner_proof.proof.recover_degree_bits(inner_config); let pt = add_virtual_stark_proof_with_pis(&mut builder, &stark, inner_config, degree_bits, 0, 0); - set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero()); + set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero())?; verify_stark_proof_circuit::(&mut builder, stark, pt, inner_config); diff --git a/starky/src/recursive_verifier.rs b/starky/src/recursive_verifier.rs index d6c2f9fec8..2c485d0d0d 100644 --- a/starky/src/recursive_verifier.rs +++ b/starky/src/recursive_verifier.rs @@ -325,7 +325,8 @@ pub fn set_stark_proof_with_pis_target, W, const D stark_proof_with_pis_target: &StarkProofWithPublicInputsTarget, stark_proof_with_pis: &StarkProofWithPublicInputs, zero: Target, -) where +) -> Result<()> +where F: RichField + Extendable, C::Hasher: AlgebraicHasher, W: WitnessWrite, @@ -341,10 +342,10 @@ pub fn set_stark_proof_with_pis_target, W, const D // Set public inputs. for (&pi_t, &pi) in pi_targets.iter().zip_eq(public_inputs) { - witness.set_target(pi_t, pi); + witness.set_target(pi_t, pi)?; } - set_stark_proof_target(witness, pt, proof, zero); + set_stark_proof_target(witness, pt, proof, zero) } /// Set the targets in a [`StarkProofTarget`] to their corresponding values in a @@ -354,31 +355,32 @@ pub fn set_stark_proof_target, W, const D: usize>( proof_target: &StarkProofTarget, proof: &StarkProof, zero: Target, -) where +) -> Result<()> +where F: RichField + Extendable, C::Hasher: AlgebraicHasher, W: WitnessWrite, { - witness.set_cap_target(&proof_target.trace_cap, &proof.trace_cap); + witness.set_cap_target(&proof_target.trace_cap, &proof.trace_cap)?; if let (Some(quotient_polys_cap_target), Some(quotient_polys_cap)) = (&proof_target.quotient_polys_cap, &proof.quotient_polys_cap) { - witness.set_cap_target(quotient_polys_cap_target, quotient_polys_cap); + witness.set_cap_target(quotient_polys_cap_target, quotient_polys_cap)?; } witness.set_fri_openings( &proof_target.openings.to_fri_openings(zero), &proof.openings.to_fri_openings(), - ); + )?; if let (Some(auxiliary_polys_cap_target), Some(auxiliary_polys_cap)) = ( &proof_target.auxiliary_polys_cap, &proof.auxiliary_polys_cap, ) { - witness.set_cap_target(auxiliary_polys_cap_target, auxiliary_polys_cap); + witness.set_cap_target(auxiliary_polys_cap_target, auxiliary_polys_cap)?; } - set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof); + set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof) } /// Utility function to check that all lookups data wrapped in `Option`s are `Some` iff diff --git a/starky/src/stark_testing.rs b/starky/src/stark_testing.rs index bbe1c840c9..42b8e0ebce 100644 --- a/starky/src/stark_testing.rs +++ b/starky/src/stark_testing.rs @@ -109,19 +109,19 @@ pub fn test_stark_circuit_constraints< let mut pw = PartialWitness::::new(); let locals_t = builder.add_virtual_extension_targets(S::COLUMNS); - pw.set_extension_targets(&locals_t, vars.get_local_values()); + pw.set_extension_targets(&locals_t, vars.get_local_values())?; let nexts_t = builder.add_virtual_extension_targets(S::COLUMNS); - pw.set_extension_targets(&nexts_t, vars.get_next_values()); + pw.set_extension_targets(&nexts_t, vars.get_next_values())?; let pis_t = builder.add_virtual_extension_targets(S::PUBLIC_INPUTS); - pw.set_extension_targets(&pis_t, vars.get_public_inputs()); + pw.set_extension_targets(&pis_t, vars.get_public_inputs())?; let alphas_t = builder.add_virtual_targets(1); - pw.set_target(alphas_t[0], alphas[0]); + pw.set_target(alphas_t[0], alphas[0])?; let z_last_t = builder.add_virtual_extension_target(); - pw.set_extension_target(z_last_t, z_last); + pw.set_extension_target(z_last_t, z_last)?; let lagrange_first_t = builder.add_virtual_extension_target(); - pw.set_extension_target(lagrange_first_t, lagrange_first); + pw.set_extension_target(lagrange_first_t, lagrange_first)?; let lagrange_last_t = builder.add_virtual_extension_target(); - pw.set_extension_target(lagrange_last_t, lagrange_last); + pw.set_extension_target(lagrange_last_t, lagrange_last)?; let vars = S::EvaluationFrameTarget::from_values(&locals_t, &nexts_t, &pis_t); let mut consumer = RecursiveConstraintConsumer::::new( diff --git a/starky/src/unconstrained_stark.rs b/starky/src/unconstrained_stark.rs index 179010cc7a..350ea6bf31 100644 --- a/starky/src/unconstrained_stark.rs +++ b/starky/src/unconstrained_stark.rs @@ -182,7 +182,7 @@ mod tests { let degree_bits = inner_proof.proof.recover_degree_bits(inner_config); let pt = add_virtual_stark_proof_with_pis(&mut builder, &stark, inner_config, degree_bits, 0, 0); - set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero()); + set_stark_proof_with_pis_target(&mut pw, &pt, &inner_proof, builder.zero())?; verify_stark_proof_circuit::(&mut builder, stark, pt, inner_config);