From 9b3793fdcf5a3ebba40f7194b97e1e471e5d8e9c Mon Sep 17 00:00:00 2001 From: Giacomo Fenzi Date: Tue, 3 Sep 2024 13:47:30 +0200 Subject: [PATCH] Add extension fields (#15) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Remco Bloemen Co-authored-by: Michele Orrù --- Cargo.toml | 48 ++++---- src/plugins/ark/common.rs | 35 ++++-- src/plugins/ark/iopattern.rs | 25 +++-- src/plugins/ark/tests.rs | 71 +++++++++++- src/plugins/ark/writer.rs | 4 +- src/plugins/pow.rs | 206 +++++++++++++++++++++++++++-------- 6 files changed, 301 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ce36fd..ade6a2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,37 +8,42 @@ license = "MIT/Apache-2.0" resolver = "2" [patch.crates-io] -ark-std = {git = "https://github.com/arkworks-rs/utils"} -ark-ec = {git = "https://github.com/arkworks-rs/algebra"} -ark-ff = {git = "https://github.com/arkworks-rs/algebra"} -ark-serialize = {git = "https://github.com/arkworks-rs/algebra"} -ark-bls12-381 = {git = "https://github.com/arkworks-rs/algebra"} -ark-curve25519 = {git = "https://github.com/arkworks-rs/algebra"} -ark-pallas = {git = "https://github.com/arkworks-rs/algebra"} -ark-vesta = {git = "https://github.com/arkworks-rs/algebra"} +ark-std = { git = "https://github.com/arkworks-rs/utils" } +ark-ec = { git = "https://github.com/arkworks-rs/algebra" } +ark-ff = { git = "https://github.com/arkworks-rs/algebra" } +ark-serialize = { git = "https://github.com/arkworks-rs/algebra" } +ark-bls12-381 = { git = "https://github.com/arkworks-rs/algebra" } +ark-curve25519 = { git = "https://github.com/arkworks-rs/algebra" } +ark-pallas = { git = "https://github.com/arkworks-rs/algebra" } +ark-vesta = { git = "https://github.com/arkworks-rs/algebra" } [dependencies] -zeroize = {version="1.6.0", features=["zeroize_derive"]} -rand = {version="0.8.5", features=["getrandom"]} +zeroize = { version = "1.6.0", features = ["zeroize_derive"] } 
+rand = { version = "0.8.5", features = ["getrandom"] } digest = "0.10.7" generic-array = "0.14.7" # used as default hasher for the prover -keccak = "0.1.4" +keccak = { version = "0.1.4", features = ["asm", "simd"] } log = "0.4.20" # optional dependencies -ark-ff = {version="0.4.0", optional=true} -ark-ec = {version="0.4.0", optional=true} -ark-serialize = {version="0.4.2", optional=true, features=["std"]} +ark-ff = { version = "0.4.0", optional = true } +ark-ec = { version = "0.4.0", optional = true } +ark-serialize = { version = "0.4.2", optional = true, features = ["std"] } # anemoi = {git = "https://github.com/anemoi-hash/anemoi-rust", optional=true} -group = {version="0.13.0", optional=true} -ark-bls12-381 = {version="0.4.0", optional=true} +group = { version = "0.13.0", optional = true } +ark-bls12-381 = { version = "0.4.0", optional = true } +rayon = { version = "1.10.0", optional = true } +bytemuck = "1.17.1" +blake3 = "1.5.4" [features] -default = [] +default = ["parallel"] +parallel = ["dep:rayon"] ark = ["dep:ark-ff", "dep:ark-ec", "dep:ark-serialize"] group = ["dep:group"] ark-bls12-381 = ["ark", "dep:ark-bls12-381"] +rayon = ["dep:rayon"] # anemoi = ["dep:anemoi"] [dev-dependencies] @@ -47,7 +52,7 @@ sha2 = "0.10.7" blake2 = "0.10.6" hex = "0.4.3" # test curve25519 compatibility -curve25519-dalek = {version="4.0.0", features=["group"]} +curve25519-dalek = { version = "4.0.0", features = ["group"] } ark-curve25519 = "0.4.0" # test algebraic hashers bls12_381 = "0.8.0" @@ -58,10 +63,7 @@ pasta_curves = "0.5.1" ark-vesta = { version = "0.4.0", features = ["std"] } [package.metadata.docs.rs] -rustdoc-args = [ - "--html-in-header", "doc/katex-header.html", - "--cfg", "docsrs", -] +rustdoc-args = ["--html-in-header", "doc/katex-header.html", "--cfg", "docsrs"] features = ["ark", "group"] [[example]] @@ -70,7 +72,7 @@ required-features = ["ark"] [[example]] name = "schnorr_algebraic_hash" -required-features = ["ark", "ark-bls12-381"] +required-features = 
["ark", "ark-bls12-381"] [[example]] name = "bulletproof" diff --git a/src/plugins/ark/common.rs b/src/plugins/ark/common.rs index 1117aba..cd501db 100644 --- a/src/plugins/ark/common.rs +++ b/src/plugins/ark/common.rs @@ -77,15 +77,20 @@ where impl FieldChallenges for T where - F: PrimeField, + F: Field, T: ByteChallenges, { fn fill_challenge_scalars(&mut self, output: &mut [F]) -> ProofResult<()> { - let mut buf = vec![0u8; bytes_uniform_modp(F::MODULUS_BIT_SIZE)]; + let base_field_size = bytes_uniform_modp(F::BasePrimeField::MODULUS_BIT_SIZE); + let mut buf = vec![0u8; F::extension_degree() as usize * base_field_size]; for o in output.iter_mut() { self.fill_challenge_bytes(&mut buf)?; - *o = F::from_be_bytes_mod_order(&buf).into(); + *o = F::from_base_prime_field_elems( + buf.chunks(base_field_size) + .map(F::BasePrimeField::from_be_bytes_mod_order), + ) + .expect("Could not convert"); } Ok(()) } @@ -93,16 +98,22 @@ where // Field <-> Field interactions: -impl FieldPublic> for Merlin, R> +impl FieldPublic for Merlin, R> where + F: Field>, H: DuplexHash>, R: RngCore + CryptoRng, C: FpConfig, { type Repr = (); - fn public_scalars(&mut self, input: &[Fp]) -> ProofResult { - self.public_units(input)?; + fn public_scalars(&mut self, input: &[F]) -> ProofResult { + let flattened: Vec<_> = input + .into_iter() + .map(|f| f.to_base_prime_field_elements()) + .flatten() + .collect(); + self.public_units(&flattened)?; Ok(()) } } @@ -126,15 +137,21 @@ where // // -impl FieldPublic> for Arthur<'_, H, Fp> +impl FieldPublic for Arthur<'_, H, Fp> where + F: Field>, H: DuplexHash>, C: FpConfig, { type Repr = (); - fn public_scalars(&mut self, input: &[Fp]) -> ProofResult { - self.public_units(input)?; + fn public_scalars(&mut self, input: &[F]) -> ProofResult { + let flattened: Vec<_> = input + .into_iter() + .map(|f| f.to_base_prime_field_elements()) + .flatten() + .collect(); + self.public_units(&flattened)?; Ok(()) } } diff --git a/src/plugins/ark/iopattern.rs
b/src/plugins/ark/iopattern.rs index 4995454..f26d879 100644 --- a/src/plugins/ark/iopattern.rs +++ b/src/plugins/ark/iopattern.rs @@ -1,34 +1,45 @@ use ark_ec::CurveGroup; -use ark_ff::{Fp, FpConfig, PrimeField}; +use ark_ff::{Field, Fp, FpConfig, PrimeField}; use super::*; use crate::plugins::{bytes_modp, bytes_uniform_modp}; impl FieldIOPattern for IOPattern where - F: PrimeField, + F: Field, H: DuplexHash, { fn add_scalars(self, count: usize, label: &str) -> Self { - self.add_bytes(count * bytes_modp(F::MODULUS_BIT_SIZE), label) + self.add_bytes( + count + * F::extension_degree() as usize + * bytes_modp(F::BasePrimeField::MODULUS_BIT_SIZE), + label, + ) } fn challenge_scalars(self, count: usize, label: &str) -> Self { - self.challenge_bytes(count * bytes_uniform_modp(F::MODULUS_BIT_SIZE), label) + self.challenge_bytes( + count + * F::extension_degree() as usize + * bytes_uniform_modp(F::BasePrimeField::MODULUS_BIT_SIZE), + label, + ) } } -impl FieldIOPattern> for IOPattern> +impl FieldIOPattern for IOPattern> where + F: Field>, C: FpConfig, H: DuplexHash>, { fn add_scalars(self, count: usize, label: &str) -> Self { - self.absorb(count, label) + self.absorb(count * F::extension_degree() as usize, label) } fn challenge_scalars(self, count: usize, label: &str) -> Self { - self.squeeze(count, label) + self.squeeze(count * F::extension_degree() as usize, label) } } diff --git a/src/plugins/ark/tests.rs b/src/plugins/ark/tests.rs index 336864b..652965a 100644 --- a/src/plugins/ark/tests.rs +++ b/src/plugins/ark/tests.rs @@ -1,8 +1,12 @@ #[cfg(feature = "ark-bls12-381")] use super::poseidon::PoseidonHash; -use crate::{DefaultHash, DuplexHash, IOPattern, Unit, UnitTranscript}; +use crate::{ + ByteChallenges, ByteIOPattern, ByteReader, ByteWriter, DefaultHash, DuplexHash, IOPattern, + ProofResult, Unit, UnitTranscript, +}; #[cfg(feature = "ark-bls12-381")] -use ark_bls12_381::Fr; +use ark_bls12_381::{Fq2, Fr}; +use ark_ff::Field; /// Test that the algebraic hashes do 
use the IV generated from the IO Pattern. fn check_iv_is_used, F: Unit + Copy + Default + Eq + core::fmt::Debug>() { @@ -43,3 +47,66 @@ fn test_poseidon_basic() { assert_ne!(challenge, F::from(0)); } } + +fn ark_iopattern() -> IOPattern +where + F: Field, + H: DuplexHash, + IOPattern: super::FieldIOPattern + ByteIOPattern, +{ + use super::{ByteIOPattern, FieldIOPattern}; + + IOPattern::new("github.com/mmaker/nimue") + .add_scalars(3, "com") + .challenge_bytes(16, "chal") + .add_bytes(16, "resp") + .challenge_scalars(2, "chal") +} + +fn test_arkworks_end_to_end() -> ProofResult<()> { + use crate::plugins::ark::{FieldChallenges, FieldReader, FieldWriter}; + use rand::Rng; + + let mut rng = ark_std::test_rng(); + // Generate elements for the transcript + let (f0, f1, f2) = (F::rand(&mut rng), F::rand(&mut rng), F::rand(&mut rng)); + let mut b0 = [0; 16]; + let mut c0 = [0; 16]; + let b1: [u8; 16] = rng.gen(); + let mut f3 = [F::ZERO; 2]; + let mut g3 = [F::ZERO; 2]; + + let io_pattern = ark_iopattern::(); + + let mut merlin = io_pattern.to_merlin(); + + merlin.add_scalars(&[f0, f1, f2])?; + merlin.fill_challenge_bytes(&mut b0)?; + merlin.add_bytes(&b1)?; + merlin.fill_challenge_scalars(&mut f3)?; + + let mut arthur = io_pattern.to_arthur(merlin.transcript()); + let [g0, g1, g2]: [F; 3] = arthur.next_scalars()?; + arthur.fill_challenge_bytes(&mut c0)?; + let c1: [u8; 16] = arthur.next_bytes()?; + arthur.fill_challenge_scalars(&mut g3)?; + + assert_eq!(f0, g0); + assert_eq!(f1, g1); + assert_eq!(f2, g2); + assert_eq!(f3, g3); + assert_eq!(b0, c0); + assert_eq!(b1, c1); + + Ok(()) +} + +#[cfg(feature = "ark-bls12-381")] +#[test] +fn test_arkworks() { + type F = Fr; + type F2 = Fq2; + + test_arkworks_end_to_end::().unwrap(); + test_arkworks_end_to_end::().unwrap(); +} diff --git a/src/plugins/ark/writer.rs b/src/plugins/ark/writer.rs index e646140..8d3e3eb 100644 --- a/src/plugins/ark/writer.rs +++ b/src/plugins/ark/writer.rs @@ -1,12 +1,12 @@ use ark_ec::CurveGroup; -use 
ark_ff::{Fp, FpConfig, PrimeField}; +use ark_ff::{Field, Fp, FpConfig}; use ark_serialize::CanonicalSerialize; use rand::{CryptoRng, RngCore}; use super::{FieldPublic, FieldWriter, GroupPublic, GroupWriter}; use crate::{DuplexHash, Merlin, ProofResult, UnitTranscript}; -impl FieldWriter for Merlin { +impl FieldWriter for Merlin { fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> { let serialized = self.public_scalars(input); self.transcript.extend(serialized?); diff --git a/src/plugins/pow.rs b/src/plugins/pow.rs index 7236392..117b1e0 100644 --- a/src/plugins/pow.rs +++ b/src/plugins/pow.rs @@ -1,11 +1,18 @@ use crate::{ - hash::Keccak, Arthur, ByteChallenges, ByteIOPattern, ByteReader, ByteWriter, DuplexHash, - IOPattern, Merlin, ProofError, ProofResult, + Arthur, ByteChallenges, ByteIOPattern, ByteReader, ByteWriter, IOPattern, Merlin, ProofError, + ProofResult, +}; +use { + blake3::{ + guts::BLOCK_LEN, + platform::{Platform, MAX_SIMD_DEGREE}, + IncrementCounter, OUT_LEN, + }, + std::sync::atomic::{AtomicU64, Ordering}, }; -/// Wrapper type for a challenge generated via a proof-of-work. -/// The challenge is a 128-bit integer. -pub struct PoWChal(pub u128); +#[cfg(feature = "parallel")] +use rayon::broadcast; /// [`IOPattern`] for proof-of-work challenges. pub trait PoWIOPattern { @@ -31,36 +38,20 @@ impl PoWIOPattern for IOPattern { pub trait PoWChallenge { /// Extension trait for generating a proof-of-work challenge. - fn challenge_pow(&mut self, bits: usize) -> ProofResult; + fn challenge_pow(&mut self, bits: f64) -> ProofResult<()>; } impl PoWChallenge for Merlin where Merlin: ByteWriter, { - fn challenge_pow(&mut self, bits: usize) -> ProofResult { - // Seed a new hash with the 32-byte challenge. - let mut challenge = [0u8; 32]; - self.fill_challenge_bytes(&mut challenge)?; - let hash = Keccak::new(challenge); - - // Output buffer for the hash - let mut chal_bytes = [0u8; 16]; - - // Loop over a 64-bit integer to find a PoWChal sufficiently small. 
- for nonce in 0u64.. { - hash.clone() - .absorb_unchecked(&nonce.to_be_bytes()) - .squeeze_unchecked(&mut chal_bytes); - let chal = u128::from_be_bytes(chal_bytes); - if (chal << bits) >> bits == chal { - self.add_bytes(&nonce.to_be_bytes())?; - return Ok(PoWChal(chal)); - } - } - - // Congratulations, you wasted 2^64 Keccak calls. You're a winner. - Err(ProofError::InvalidProof) + fn challenge_pow(&mut self, bits: f64) -> ProofResult<()> { + let challenge = self.challenge_bytes()?; + let nonce = Pow::new(challenge, bits) + .solve() + .ok_or(ProofError::InvalidProof)?; + self.add_bytes(&nonce.to_be_bytes())?; + Ok(()) } } @@ -68,38 +59,163 @@ impl<'a> PoWChallenge for Arthur<'a> where Arthur<'a>: ByteReader, { - fn challenge_pow(&mut self, bits: usize) -> ProofResult { - // Re-compute the challenge and store it in chal_bytes - let mut chal_bytes = [0u8; 16]; - let iv = self.challenge_bytes::<32>()?; - let nonce = self.next_bytes::<8>()?; - Keccak::new(iv) - .absorb_unchecked(&nonce) - .squeeze_unchecked(&mut chal_bytes); - - // Check if the challenge is valid - let chal = u128::from_be_bytes(chal_bytes); - if (chal << bits) >> bits == chal { - Ok(PoWChal(chal)) + fn challenge_pow(&mut self, bits: f64) -> ProofResult<()> { + let challenge = self.challenge_bytes()?; + let nonce = u64::from_be_bytes(self.next_bytes()?); + if Pow::new(challenge, bits).check(nonce) { + Ok(()) } else { Err(ProofError::InvalidProof) } } } +#[derive(Clone, Copy)] +struct Pow { + challenge: [u8; 32], + threshold: u64, + platform: Platform, + inputs: [u8; BLOCK_LEN * MAX_SIMD_DEGREE], + outputs: [u8; OUT_LEN * MAX_SIMD_DEGREE], +} + +impl Pow { + /// Default Blake3 initialization vector. Copied here because it is not publicly exported. + const BLAKE3_IV: [u32; 8] = [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, + 0x5BE0CD19, + ]; + const BLAKE3_FLAGS: u8 = 0x0B; // CHUNK_START | CHUNK_END | ROOT + + /// Creates a new proof-of-work challenge. 
+ /// The `challenge` is a 32-byte array that represents the challenge. + /// The `bits` is the binary logarithm of the expected amount of work. + /// When `bits` is large (i.e. close to 64), a valid solution may not be found. + fn new(challenge: [u8; 32], bits: f64) -> Self { + assert_eq!(BLOCK_LEN, 64); + assert_eq!(OUT_LEN, 32); + assert!((0.0..60.0).contains(&bits), "bits must be smaller than 60"); + let threshold = (64.0 - bits).exp2().ceil() as u64; + let platform = Platform::detect(); + let mut inputs = [0; BLOCK_LEN * MAX_SIMD_DEGREE]; + for input in inputs.chunks_exact_mut(BLOCK_LEN) { + input[..challenge.len()].copy_from_slice(&challenge); + } + let outputs = [0; OUT_LEN * MAX_SIMD_DEGREE]; + Self { + challenge, + threshold, + platform, + inputs, + outputs, + } + } + + /// Check if the `nonce` satisfies the challenge. + /// This deliberately uses the high level interface to guarantee + /// compatibility with standard Blake3. + fn check(&mut self, nonce: u64) -> bool { + // Ingest the challenge and the nonce. + let mut hasher = blake3::Hasher::new(); + hasher.update(&self.challenge); + hasher.update(&nonce.to_le_bytes()); + hasher.update(&[0; 24]); // Nonce is zero extended to 32 bytes. + + // Check if the hash is below the threshold. + let mut result_bytes = [0; 8]; + hasher.finalize_xof().fill(&mut result_bytes); + let result = u64::from_le_bytes(result_bytes); + result < self.threshold + } + + /// Find the minimal nonce that satisfies the challenge (if any) in a + /// length `MAX_SIMD_DEGREE` sequence of nonces starting from `nonce`. + fn check_many(&mut self, nonce: u64) -> Option { + for (i, input) in self.inputs.chunks_exact_mut(BLOCK_LEN).enumerate() { + input[32..40].copy_from_slice(&(nonce + i as u64).to_le_bytes()) + } + // `hash_many` requires an array of references. We need to construct this fresh + // each call as we cannot store the references and mutate the array. 
+ let inputs: [&[u8; BLOCK_LEN]; MAX_SIMD_DEGREE] = std::array::from_fn(|i| { + self.inputs[(i * BLOCK_LEN)..((i + 1) * BLOCK_LEN)] + .try_into() + .unwrap() + }); + let counter = 0; + let flags_start = 0; + let flags_end = 0; + self.platform.hash_many::( + &inputs, + &Self::BLAKE3_IV, + counter, + IncrementCounter::No, + Self::BLAKE3_FLAGS, + flags_start, + flags_end, + &mut self.outputs, + ); + for (i, input) in self.outputs.chunks_exact_mut(OUT_LEN).enumerate() { + let result = u64::from_le_bytes(input[..8].try_into().unwrap()); + if result < self.threshold { + return Some(nonce + i as u64); + } + } + None + } + + /// Finds the minimal `nonce` that satisfies the challenge. + #[cfg(not(feature = "parallel"))] + fn solve(&mut self) -> Option { + (0u64..) + .step_by(MAX_SIMD_DEGREE) + .find_map(|nonce| self.check_many(nonce)) + } + + /// Finds the minimal `nonce` that satisfies the challenge. + #[cfg(feature = "parallel")] + fn solve(&mut self) -> Option { + // Split the work across all available threads. + // Use atomics to find the unique deterministic lowest satisfying nonce. + let global_min = AtomicU64::new(u64::MAX); + let _ = broadcast(|ctx| { + let mut worker = self.clone(); + let nonces = ((MAX_SIMD_DEGREE * ctx.index()) as u64..) + .step_by(MAX_SIMD_DEGREE * ctx.num_threads()); + for nonce in nonces { + // Use relaxed ordering to eventually get notified of another thread's solution. + // (Propagation delay should be in the order of tens of nanoseconds.) + if nonce >= global_min.load(Ordering::Relaxed) { + break; + } + if let Some(nonce) = worker.check_many(nonce) { + // We found a solution, store it in the global_min. + // Use fetch_min to solve race condition with simultaneous solutions. 
+ global_min.fetch_min(nonce, Ordering::SeqCst); + break; + } + } + }); + match global_min.load(Ordering::SeqCst) { + u64::MAX => self.check(u64::MAX).then_some(u64::MAX), + nonce => Some(nonce), + } + } +} + #[test] fn test_pow() { + const BITS: f64 = 10.0; + let iopattern = IOPattern::new("the proof of work lottery 🎰") .add_bytes(1, "something") .challenge_pow("rolling dices"); let mut prover = iopattern.to_merlin(); prover.add_bytes(b"\0").expect("Invalid IOPattern"); - let expected = prover.challenge_pow(5).unwrap(); + prover.challenge_pow(BITS).unwrap(); let mut verifier = iopattern.to_arthur(prover.transcript()); let byte = verifier.next_bytes::<1>().unwrap(); assert_eq!(&byte, b"\0"); - let got = verifier.challenge_pow(5).unwrap(); - assert_eq!(expected.0, got.0); + verifier.challenge_pow(BITS).unwrap(); }