diff --git a/examples/bulletproofs.rs b/examples/bulletproofs.rs index 4819925..9698a4f 100644 --- a/examples/bulletproofs.rs +++ b/examples/bulletproofs.rs @@ -6,39 +6,86 @@ use nimue::arkworks_plugins::{Absorbable, AlgebraicIO}; use nimue::IOPattern; use nimue::{ arkworks_plugins::{Absorbs, FieldChallenges}, - Duplexer, InvalidTag, Merlin, + Arthur, Duplexer, InvalidTag, Merlin, }; use rand::rngs::OsRng; +fn fold_generators( + a: &[G], + b: &[G], + x: &G::ScalarField, + y: &G::ScalarField, +) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(&a, &b)| (a * x + b * y).into_affine()) + .collect() +} + +/// Computes the inner prouct of vectors `a` and `b`. +/// +/// Useless once https://github.com/arkworks-rs/algebra/pull/665 gets merged. +fn inner_prod(a: &[F], b: &[F]) -> F { + a.iter().zip(b.iter()).map(|(&a, &b)| a * b).sum() +} + +/// Folds together `(a, b)` using challenges `x` and `y`. +fn fold(a: &[F], b: &[F], x: &F, y: &F) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(&a, &b)| a * x + b * y) + .collect() +} + +// The bulletproof proof. struct Bulletproof { - proof: Vec<(G, G)>, + /// the prover's messages + round_msgs: Vec<(G, G)>, + /// the last round message last: (G::ScalarField, G::ScalarField), } +/// The IO Pattern of a bulleproof. +/// +/// Defining this as a trait allows us to "attach" the bulletproof IO to +/// the base class [`nimue::IOPattern`] and have other protocol compose the IO pattern. trait BulletproofIOPattern { + fn bulletproof_statement(&self) -> Self + where + G: AffineRepr + Absorbable; + fn bulletproof_io(&self, len: usize) -> Self where G: AffineRepr + Absorbable; } impl BulletproofIOPattern for IOPattern { + /// The IO of the bulletproof statement (the sole commitment) + fn bulletproof_statement(&self) -> Self + where + G: AffineRepr + Absorbable, + { + AlgebraicIO::::from(self).absorb_point::(1).into() + } + + /// The IO of the bulletproof protocol fn bulletproof_io(&self, len: usize) -> Self where G: AffineRepr + Absorbable, - S: Duplexer + S: Duplexer, { - let mut pattern = AlgebraicIO::::from(self).absorb_point::(1); + let mut io_pattern = AlgebraicIO::::from(self); for _ in 0..log2(len) { - pattern = pattern.absorb_point::(2).squeeze_bytes(16); + io_pattern = io_pattern.absorb_point::(2).squeeze_bytes(16); } - pattern.into() + io_pattern.into() } } fn prove( - transcript: &mut Merlin, + transcript: &mut Arthur, generators: (&[G], &[G], &G), - statement: &G, + statement: &G, // the actual inner-roduct of the witness is not really needed witness: (&[G::ScalarField], &[G::ScalarField]), ) -> Result, InvalidTag> where @@ -63,9 +110,8 @@ where let c = a * b; let left = g * a + h * b + u * c; let right = *statement; - println!("{}", (left - right).is_zero()); return Ok(Bulletproof { - proof: vec![], + round_msgs: vec![], last: (witness.0[0], witness.1[0]), }); } @@ -102,7 +148,10 @@ where let new_statement = (*statement + left * x.square() + right * x_inv.square()).into_affine(); let mut bulletproof = prove(transcript, new_generators, &new_statement, new_witness)?; - bulletproof.proof.push((left_compressed, right_compressed)); + // proof will be reverse-order + bulletproof + .round_msgs + .push((left_compressed, right_compressed)); Ok(bulletproof) } @@ -121,9 +170,9 @@ where let u = *generators.2; let mut statement = *statement; - let mut n = 1 << bulletproof.proof.len(); + let mut n = 1 << bulletproof.round_msgs.len(); assert_eq!(g.len(), n); - for (left, right) in bulletproof.proof.iter().rev() { + for (left, right) in bulletproof.round_msgs.iter().rev() { n /= 2; let (g_left, g_right) = g.split_at(n); @@ -154,9 +203,23 @@ fn main() { use ark_std::UniformRand; type H = nimue::DefaultHash; - - let a = [1, 2, 3, 4, 5, 6, 7, 8].iter().map(|&x| F::from(x)).collect::>(); - let b = [1, 2, 3, 4, 5, 6, 7, 8].iter().map(|&x| F::from(x)).collect::>(); + // the vector size + let size = 8u64; + + // initialize the IO Pattern putting the domain separator ("example.com") + let io_pattern = IOPattern::new("example.com") + // add the IO of the bulletproof statement (the commitment) + .bulletproof_statement::() + // (optional) process the data so far, filling the block till the end. + .process() + // add the IO of the bulletproof protocol (the transcript) + .bulletproof_io::(size as usize); + + // the test vectors + let a = (0..size).map(|x| F::from(x)).collect::>(); + let b = (0..size).map(|x| F::from(x + 42)).collect::>(); + let ab = inner_prod(&a, &b); + // the generators to be used for respectively a, b, ip let g = (0..a.len()) .map(|_| G::rand(&mut OsRng)) .collect::>(); @@ -164,45 +227,28 @@ fn main() { .map(|_| G::rand(&mut OsRng)) .collect::>(); let u = G::rand(&mut OsRng); - let ip = inner_prod(&a, &b); let generators = (&g[..], &h[..], &u); let statement = - (G1Projective::msm(&g, &a).unwrap() + G1Projective::msm(&h, &b).unwrap() + u * ip) + (G1Projective::msm(&g, &a).unwrap() + G1Projective::msm(&h, &b).unwrap() + u * ab) .into_affine(); let witness = (&a[..], &b[..]); - let iop = IOPattern::new("example.com").bulletproof_io::(a.len()); - let mut transcript = Merlin::new(&iop); - transcript.append_element(&statement).unwrap(); + let mut prover_transcript = Arthur::new(&io_pattern, OsRng); + prover_transcript.append_element(&statement).unwrap(); + prover_transcript.process().unwrap(); let bulletproof = - prove::(&mut transcript, generators, &statement, witness) + prove::(&mut prover_transcript, generators, &statement, witness) .unwrap(); - let mut transcript = Merlin::::new(&iop); - transcript.append_element(&statement).unwrap(); - verify(&mut transcript, generators, &statement, &bulletproof).expect("Invalid proof"); -} - -fn fold(a: &[F], b: &[F], x: &F, y: &F) -> Vec { - a.iter() - .zip(b.iter()) - .map(|(&a, &b)| a * x + b * y) - .collect() -} - -fn fold_generators( - a: &[G], - b: &[G], - x: &G::ScalarField, - y: &G::ScalarField, -) -> Vec { - a.iter() - .zip(b.iter()) - .map(|(&a, &b)| (a * x + b * y).into_affine()) - .collect() -} - -fn inner_prod(a: &[F], b: &[F]) -> F { - a.iter().zip(b.iter()).map(|(&a, &b)| a * b).sum() + let mut verifier_transcript = Merlin::::new(&io_pattern); + verifier_transcript.append_element(&statement).unwrap(); + verifier_transcript.process().unwrap(); + verify( + &mut verifier_transcript, + generators, + &statement, + &bulletproof, + ) + .expect("Invalid proof"); } diff --git a/examples/schnorr.rs b/examples/schnorr.rs index 5641b78..4a69637 100644 --- a/examples/schnorr.rs +++ b/examples/schnorr.rs @@ -2,25 +2,36 @@ use ark_ec::{AffineRepr, CurveGroup}; use ark_serialize::CanonicalSerialize; use ark_std::UniformRand; use nimue::arkworks_plugins::{Absorbable, Absorbs, AlgebraicIO, FieldChallenges}; -use nimue::{Duplexer, IOPattern, InvalidTag, Merlin, Arthur}; +use nimue::{Arthur, Duplexer, IOPattern, InvalidTag, Merlin}; trait SchnorrIOPattern { + fn schnorr_statement(&self) -> Self + where + G: AffineRepr + Absorbable; + fn schnorr_io(&self) -> Self where G: AffineRepr + Absorbable; } impl SchnorrIOPattern for IOPattern { - /// A Schnorr signature's IO Pattern. - fn schnorr_io(&self) -> IOPattern + fn schnorr_statement(&self) -> Self where G: AffineRepr + Absorbable, { + // the statement: generator and public key AlgebraicIO::::from(self) - // the statement: generator and public key .absorb_point::(2) // (optional) allow for preprocessing of the generators - .process() + .into() + } + + /// A Schnorr signature's IO Pattern. + fn schnorr_io(&self) -> IOPattern + where + G: AffineRepr + Absorbable, + { + AlgebraicIO::::from(self) // absorb the commitment .absorb_point::(1) // challenge in bytes @@ -29,7 +40,7 @@ impl SchnorrIOPattern for IOPattern { } } -fn schnorr_proof>( +fn prove>( transcript: &mut Arthur, sk: G::ScalarField, g: G, @@ -80,14 +91,21 @@ fn main() { type G = ark_bls12_381::G1Affine; type F = ark_bls12_381::Fr; - let io_pattern = IOPattern::new("domsep").schnorr_io::(); + let io_pattern = IOPattern::new("the domain separator goes here") + // append the statement (generator, public key) + .schnorr_statement::() + // process the statement separating it from the rest of the protocol + .process() + // add the schnorr io pattern + .schnorr_io::(); let sk = F::rand(&mut OsRng); let g = G::generator(); let mut writer = Vec::new(); g.serialize_compressed(&mut writer).unwrap(); let pk = (g * &sk).into(); + let mut prover_transcript = Arthur::::from(io_pattern.clone()); - let proof = schnorr_proof::(&mut prover_transcript, sk, g, pk).expect("Valid proof"); + let proof = prove::(&mut prover_transcript, sk, g, pk).expect("Valid proof"); let mut verifier_transcript = Merlin::from(io_pattern.clone()); verify::(&mut verifier_transcript, g, pk, proof).expect("Valid proof"); diff --git a/src/arkworks_plugins/absorbs.rs b/src/arkworks_plugins/absorbs.rs index d12feb6..b293b13 100644 --- a/src/arkworks_plugins/absorbs.rs +++ b/src/arkworks_plugins/absorbs.rs @@ -1,6 +1,6 @@ use crate::{Lane, Merlin}; -use super::super::{Duplexer, InvalidTag, Arthur}; +use super::super::{Arthur, Duplexer, InvalidTag}; use super::Absorbable; use rand::{CryptoRng, RngCore}; diff --git a/src/arkworks_plugins/field_challenges.rs b/src/arkworks_plugins/field_challenges.rs index ceebe79..4fdf1bf 100644 --- a/src/arkworks_plugins/field_challenges.rs +++ b/src/arkworks_plugins/field_challenges.rs @@ -1,9 +1,7 @@ -use super::super::{Duplexer, InvalidTag, Merlin, Arthur}; +use super::super::{Arthur, Duplexer, InvalidTag, Merlin}; use ark_ff::PrimeField; use rand::{CryptoRng, RngCore}; - - pub trait FieldChallenges { /// Squeeze a field element challenge of `byte_count` bytes /// from the protocol transcript. diff --git a/src/arkworks_plugins/iopattern.rs b/src/arkworks_plugins/iopattern.rs index 764410d..e7abb14 100644 --- a/src/arkworks_plugins/iopattern.rs +++ b/src/arkworks_plugins/iopattern.rs @@ -3,7 +3,7 @@ use ark_ff::{Field, PrimeField}; use core::borrow::Borrow; use super::{ - super::{Duplexer, IOPattern, Lane, Merlin, Arthur}, + super::{Arthur, Duplexer, IOPattern, Lane, Merlin}, Absorbable, }; @@ -40,7 +40,7 @@ where } pub fn absorb_bytes(self, count: usize) -> Self { - let count = usize::div_ceil(count, S::L::compressed_size()); + let count = crate::div_ceil!(count, S::L::compressed_size()); self.iop.absorb(count).into() } @@ -60,13 +60,14 @@ where } pub fn squeeze_bytes(self, count: usize) -> Self { - let count = usize::div_ceil(count, S::L::extractable_bytelen()); + let count = crate::div_ceil!(count, S::L::extractable_bytelen()); self.iop.squeeze(count).into() } pub fn squeeze_field(self, count: usize) -> Self { // XXX. check if typeof::() == typeof::() and if so use native squeeze - self.squeeze_bytes(super::random_felt_bytelen::() * count).into() + self.squeeze_bytes(super::random_felt_bytelen::() * count) + .into() } } diff --git a/src/arkworks_plugins/mod.rs b/src/arkworks_plugins/mod.rs index c798acb..3f757e1 100644 --- a/src/arkworks_plugins/mod.rs +++ b/src/arkworks_plugins/mod.rs @@ -5,7 +5,7 @@ mod iopattern; pub use absorbs::Absorbs; use ark_ec::{ - short_weierstrass::{Affine, SWCurveConfig, Projective}, + short_weierstrass::{Affine, Projective, SWCurveConfig}, AffineRepr, CurveGroup, }; use ark_ff::{BigInteger, Fp, FpConfig, PrimeField}; @@ -62,10 +62,12 @@ impl, P: SWCurveConfig>> Abs } // this one little `where` trick avoids specifying in any implementation `Projective

: Absorbable`. -impl<'a, P: SWCurveConfig, L: Lane> Absorbable for Affine

where - Projective

: Absorbable { +impl<'a, P: SWCurveConfig, L: Lane> Absorbable for Affine

+where + Projective

: Absorbable, +{ fn absorb_size() -> usize { - usize::div_ceil(Self::default().compressed_size(), L::compressed_size()) + crate::div_ceil!(Self::default().compressed_size(), L::compressed_size()) } fn to_absorbable(&self) -> Vec { @@ -75,20 +77,16 @@ impl<'a, P: SWCurveConfig, L: Lane> Absorbable for Affine

where } } -impl Absorbable for Projective

{ +impl Absorbable for Projective

{ fn absorb_size() -> usize { as Absorbable>::absorb_size() } fn to_absorbable(&self) -> Vec { as Absorbable>::to_absorbable(&self.into_affine()) - } } - - - #[macro_export] macro_rules! impl_absorbable { ($t:ty) => { diff --git a/src/arthur.rs b/src/arthur.rs index 7f88a3f..25144ba 100644 --- a/src/arthur.rs +++ b/src/arthur.rs @@ -111,9 +111,6 @@ where pub(crate) merlin: Merlin, } - - - impl Arthur { pub fn new(io_pattern: &IOPattern, csrng: R) -> Self { ArthurBuilder::new(io_pattern).finalize_with_rng(csrng) diff --git a/src/lane.rs b/src/lane.rs index a54d8ec..e723881 100644 --- a/src/lane.rs +++ b/src/lane.rs @@ -1,5 +1,5 @@ +use serde::{Deserialize, Serialize}; use zeroize::Zeroize; -use serde::{Serialize, Deserialize}; /// A Lane is the basic unit a sponge function works on. /// We need only two things from a lane: the ability to convert it to bytes and back. diff --git a/src/lib.rs b/src/lib.rs index 218099e..5125a88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(int_roundings)] //! //! **This crate is work in progress, not suitable for production.** //! @@ -120,3 +119,12 @@ pub(crate) use sponge::DuplexSponge; pub type DefaultRng = rand::rngs::OsRng; pub type DefaultHash = keccak::Keccak; pub type DefaultTranscript = Arthur; + +/// Perform ceil division. +/// XXX. Remove once feature(int_roundings) is on stable. +macro_rules! div_ceil { + ($a: expr, $b: expr) => { + ($a + $b - 1) / $b + }; +} +pub(crate) use div_ceil; diff --git a/src/safe.rs b/src/safe.rs index fe76c27..482dc37 100644 --- a/src/safe.rs +++ b/src/safe.rs @@ -249,11 +249,20 @@ impl Safe { } None => { self.stack.clear(); - Err(format!("Invalid tag. Stack empty, got {:?}", Op::Absorb(input.len())).into()) + Err(format!( + "Invalid tag. Stack empty, got {:?}", + Op::Absorb(input.len()) + ) + .into()) } Some(op) => { self.stack.clear(); - Err(format!("Invalid tag. Expected {:?}, got {:?}", Op::Absorb(input.len()), op).into()) + Err(format!( + "Invalid tag. Expected {:?}, got {:?}", + Op::Absorb(input.len()), + op + ) + .into()) } } } @@ -265,8 +274,8 @@ impl Safe { /// This function provides no guarantee of streaming-friendliness. pub fn squeeze_bytes(&mut self, output: &mut [u8]) -> Result<(), InvalidTag> { match self.stack.pop_front() { - Some(Op::Squeeze(length)) if output.len()<= length => { - let squeeze_len = usize::div_ceil(length, D::L::extractable_bytelen()); + Some(Op::Squeeze(length)) if output.len() <= length => { + let squeeze_len = super::div_ceil!(length, D::L::extractable_bytelen()); let mut squeeze_lane = vec![D::L::default(); squeeze_len]; self.sponge.squeeze_unchecked(&mut squeeze_lane); let mut squeeze_bytes = vec![0u8; D::L::extractable_bytelen() * squeeze_len]; @@ -279,13 +288,21 @@ impl Safe { } None => { self.stack.clear(); - Err(format!("Invalid tag. Stack empty, got {:?}", Op::Squeeze(output.len())).into()) + Err(format!( + "Invalid tag. Stack empty, got {:?}", + Op::Squeeze(output.len()) + ) + .into()) } Some(op) => { self.stack.clear(); - Err(format!("Invalid tag. Expected {:?}, got {:?}", Op::Squeeze(output.len()), op).into()) + Err(format!( + "Invalid tag. Expected {:?}, got {:?}", + Op::Squeeze(output.len()), + op + ) + .into()) } - } } diff --git a/src/sponge.rs b/src/sponge.rs index 7430ddf..692988e 100644 --- a/src/sponge.rs +++ b/src/sponge.rs @@ -106,5 +106,4 @@ impl> Duplexer for DuplexSponge { self.squeeze_pos = C::RATE; self } - } diff --git a/src/tests.rs b/src/tests.rs index a144d6c..71c944f 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,4 @@ -use crate::{IOPattern, Merlin, keccak::Keccak}; +use crate::{keccak::Keccak, IOPattern, Merlin}; /// How should a protocol without IOPattern be handled? #[test] @@ -32,7 +32,6 @@ fn test_deterministic() { let mut first_merlin = Merlin::::new(&iop); let mut second_merlin = Merlin::::new(&iop); - let mut first = [0u8; 16]; let mut second = [0u8; 16]; @@ -44,20 +43,25 @@ fn test_deterministic() { assert_eq!(first, second); } - /// Basic scatistical test to check that the squeezed output looks random. /// XXX. #[test] fn test_statistics() { - let iop = IOPattern::new("example.com").absorb(4).process().squeeze(2048); + let iop = IOPattern::new("example.com") + .absorb(4) + .process() + .squeeze(2048); let mut merlin = Merlin::::new(&iop); merlin.append(b"seed").unwrap(); merlin.process().unwrap(); let mut output = [0u8; 2048]; merlin.challenge_bytes(&mut output).unwrap(); - - let frequencies = (0u8..=255).map(|i| output.iter().filter(|&&x| x == i).count()).collect::>(); + let frequencies = (0u8..=255) + .map(|i| output.iter().filter(|&&x| x == i).count()) + .collect::>(); // each element should appear roughly 8 times on average. Checking we're not too far from that. - assert!(frequencies.iter().all(|&x| x < frequencies[0] + 16 && x > 0)); -} \ No newline at end of file + assert!(frequencies + .iter() + .all(|&x| x < frequencies[0] + 16 && x > 0)); +}