Skip to content

Commit

Permalink
BLS fast sums (#840)
Browse files Browse the repository at this point in the history
* Group elements to and from uncompressed form

* Add ToFromUncompressedBytes impl for G2Elements

* docs + tests

* add sum function

* clippy

* benchmarks

* add fast sum function

* very fast version works

* clean up

* keep only 'safe' version

* clean up

* more test cases

* name

* add to_bytes method

* Clean up

* benchmarks + docs

* regression test

* review comments #1

* test

* tests

* test

* test

* use vector of ptrs

* remove redundant comment
  • Loading branch information
jonas-lj authored Oct 2, 2024
1 parent c050ffc commit 2f502fd
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 11 deletions.
57 changes: 53 additions & 4 deletions fastcrypto/benches/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ mod group_benches {
use criterion::measurement::Measurement;
use criterion::{measurement, BenchmarkGroup, BenchmarkId, Criterion};
use fastcrypto::groups::bls12381::{
G1Element, G2Element, GTElement, Scalar as BlsScalar, G1_ELEMENT_BYTE_LENGTH,
G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH,
G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar as BlsScalar,
G1_ELEMENT_BYTE_LENGTH, G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH,
};
use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier;
use fastcrypto::groups::multiplier::ScalarMultiplier;
use fastcrypto::groups::ristretto255::RistrettoPoint;
use fastcrypto::groups::secp256r1::ProjectivePoint;
use fastcrypto::groups::{
secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing,
Scalar,
bls12381, secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement,
MultiScalarMul, Pairing, Scalar,
};
use fastcrypto::serde_helpers::ToFromByteArray;
use rand::thread_rng;
Expand Down Expand Up @@ -211,6 +211,54 @@ mod group_benches {
pairing_single::<G1Element, _>("BLS12381-G1", &mut group);
}

fn sum(c: &mut Criterion) {
static NUMBER_OF_TERMS: [usize; 4] = [10, 100, 500, 1000];

for n in NUMBER_OF_TERMS {
let terms: Vec<G1Element> = (0..n)
.map(|_| G1Element::generator() * bls12381::Scalar::rand(&mut thread_rng()))
.collect();

let terms_uncompressed = terms
.iter()
.map(G1ElementUncompressed::from)
.map(G1ElementUncompressed::into_byte_array)
.collect::<Vec<_>>();

let terms_compressed = terms
.iter()
.map(G1Element::to_byte_array)
.collect::<Vec<_>>();

c.bench_function(&format!("Sum/BLS12381-G1/{} uncompressed", n), move |b| {
b.iter_batched(
|| terms_uncompressed.clone(),
|t| {
let terms_deserialized = t
.into_iter()
.map(G1ElementUncompressed::from_trusted_byte_array)
.collect::<Vec<_>>();
G1ElementUncompressed::sum(terms_deserialized.as_slice())
},
criterion::BatchSize::SmallInput,
)
});

c.bench_function(&format!("Sum/BLS12381-G1/{} compressed", n), move |b| {
b.iter_batched(
|| terms_compressed.clone(),
|t| {
t.iter()
.map(G1Element::from_trusted_byte_array)
.map(Result::unwrap)
.reduce(|a, b| a + b)
},
criterion::BatchSize::SmallInput,
)
});
}
}

/// Implementation of a `Multiplier` where scalar multiplication is done without any pre-computation by
/// simply calling the GroupElement implementation. Only used for benchmarking.
struct DefaultMultiplier<G: GroupElement>(G);
Expand Down Expand Up @@ -294,6 +342,7 @@ mod group_benches {
pairing,
double_scale,
blst_msm,
sum,
}
}

Expand Down
112 changes: 106 additions & 6 deletions fastcrypto/src/groups/bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::bls12381::min_pk::DST_G2;
use crate::bls12381::min_sig::DST_G1;
use crate::encoding::{Encoding, Hex};
use crate::error::{FastCryptoError, FastCryptoResult};
use crate::error::{FastCryptoError, FastCryptoError::InvalidInput, FastCryptoResult};
use crate::groups::{
FiatShamirChallenge, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul,
Pairing, Scalar as ScalarType,
Expand All @@ -20,11 +20,12 @@ use blst::{
blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_rshift,
blst_fr_sub, blst_hash_to_g1, blst_hash_to_g2, blst_lendian_from_scalar, blst_miller_loop,
blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_cneg, blst_p1_compress,
blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_to_affine, blst_p1_uncompress,
blst_p2, blst_p2_add_or_double, blst_p2_affine, blst_p2_cneg, blst_p2_compress,
blst_p2_from_affine, blst_p2_in_g2, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress,
blst_scalar, blst_scalar_fr_check, blst_scalar_from_be_bytes, blst_scalar_from_bendian,
blst_scalar_from_fr, p1_affines, p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
blst_p1_deserialize, blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_serialize,
blst_p1_to_affine, blst_p1_uncompress, blst_p1s_add, blst_p2, blst_p2_add_or_double,
blst_p2_affine, blst_p2_cneg, blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2,
blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_scalar, blst_scalar_fr_check,
blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr, p1_affines,
p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
};
use fastcrypto_derive::GroupOpsExtend;
use hex_literal::hex;
Expand Down Expand Up @@ -333,6 +334,105 @@ impl Debug for G1Element {
serialize_deserialize_with_to_from_byte_array!(G1Element);
generate_bytes_representation!(G1Element, G1_ELEMENT_BYTE_LENGTH, G1ElementAsBytes);

/// An uncompressed serialization of a G1 element. This format is two times longer than the compressed
/// format used by `G1Element::serialize`, but is much faster to deserialize.
///
/// The intended use of this struct is to deserialize and sum a large number of G1 elements without
/// having to decompress them first.
#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct G1ElementUncompressed(pub(crate) [u8; 2 * G1_ELEMENT_BYTE_LENGTH]);

impl From<&G1Element> for G1ElementUncompressed {
fn from(element: &G1Element) -> Self {
let mut bytes = [0u8; 2 * G1_ELEMENT_BYTE_LENGTH];
unsafe {
blst_p1_serialize(bytes.as_mut_ptr(), &element.0);
}
G1ElementUncompressed(bytes)
}
}

impl TryFrom<&G1ElementUncompressed> for G1Element {
type Error = FastCryptoError;

fn try_from(value: &G1ElementUncompressed) -> Result<Self, Self::Error> {
// See https://github.com/supranational/blst for details on the serialization format.

// Note that `blst_p1_deserialize` accepts both compressed and uncompressed serializations,
// so we check that the compressed bit flag (the 1st) is not set. The third is used for
// compressed points to indicate sign of the y-coordinate and should also not be set.
if value.0[0] & 0x20 != 0 || value.0[0] & 0x80 != 0 {
return Err(InvalidInput);
}

let mut ret = blst_p1::default();
unsafe {
let mut affine = blst_p1_affine::default();
if blst_p1_deserialize(&mut affine, value.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS {
return Err(InvalidInput);
}
blst_p1_from_affine(&mut ret, &affine);

if !blst_p1_in_g1(&ret) {
return Err(InvalidInput);
}
}
Ok(G1Element(ret))
}
}

impl G1ElementUncompressed {
/// Create a new `G1ElementUncompressed` from a byte array.
/// The input is not validated so it should come from a trusted source.
///
/// See [the blst docs](https://github.com/supranational/blst/tree/master?tab=readme-ov-file#serialization-format) for details about the uncompressed serialization format.
pub fn from_trusted_byte_array(bytes: [u8; 2 * G1_ELEMENT_BYTE_LENGTH]) -> Self {
Self(bytes)
}

/// Get the byte array representation of this element.
pub fn into_byte_array(self) -> [u8; 2 * G1_ELEMENT_BYTE_LENGTH] {
self.0
}

/// This will never fail if the input is a valid G1 element.
fn to_blst_p1_affine(&self) -> FastCryptoResult<blst_p1_affine> {
let mut affine = blst_p1_affine::default();
unsafe {
// This fails if the point is not on the curve or if it is (0, ±2) which is on the curve
// but not in the G1 subgroup. See https://github.com/supranational/blst/blob/6f3136ffb636974166a93f2f25436854fe8d10ff/src/e1.c#L296-L326.
// A subgroup check is not performed here.
if blst_p1_deserialize(&mut affine, self.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS {
return Err(InvalidInput);
}
}
Ok(affine)
}

/// Compute the sum of a slice of uncompressed G1 elements.
///
/// This function will never fail if the inputs are valid G1 element.
pub fn sum(terms: &[G1ElementUncompressed]) -> FastCryptoResult<G1Element> {
if terms.is_empty() {
return Ok(G1Element::zero());
}

let affine_points: Vec<blst_p1_affine> = terms
.iter()
.map(G1ElementUncompressed::to_blst_p1_affine)
.collect::<FastCryptoResult<Vec<_>>>()?;

let mut ret = blst_p1::default();
let p = affine_points
.iter()
.map(|p| p as *const _)
.collect::<Vec<_>>();
unsafe { blst_p1s_add(&mut ret, p.as_ptr(), p.len()) };
Ok(G1Element(ret))
}
}

impl Add for G2Element {
type Output = Self;

Expand Down
108 changes: 107 additions & 1 deletion fastcrypto/src/tests/bls12381_group_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// SPDX-License-Identifier: Apache-2.0

use crate::bls12381::min_pk::{BLS12381KeyPair, BLS12381Signature};
use crate::groups::bls12381::{reduce_mod_uniform_buffer, G1Element, G2Element, GTElement, Scalar};
use crate::groups::bls12381::{
reduce_mod_uniform_buffer, G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar,
G1_ELEMENT_BYTE_LENGTH,
};
use crate::groups::{
FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing,
Scalar as ScalarTrait,
Expand Down Expand Up @@ -653,3 +656,106 @@ fn test_serialization_gt() {
assert!(GTElement::from_trusted_byte_array(&bytes).is_ok());
assert!(GTElement::from_byte_array(&bytes).is_err());
}

#[test]
fn test_g1_to_uncompressed() {
let a = G1Element::generator() * Scalar::from(7u128);

let uncompressed_bytes = G1ElementUncompressed::from(&a);

// Compressed bit flags (1 and 3) should not be set.
assert_eq!(uncompressed_bytes.0[0] & 0xA0, 0);

// Infinity bit flag (2) should not be set.
assert_eq!(uncompressed_bytes.0[0] & 0x40, 0);

// Regression test
assert_eq!(&uncompressed_bytes.0, hex::decode("1928f3beb93519eecf0145da903b40a4c97dca00b21f12ac0df3be9116ef2ef27b2ae6bcd4c5bc2d54ef5a70627efcb7108dadbaa4b636445639d5ae3089b3c43a8a1d47818edd1839d7383959a41c10fdc66849cfa1b08c5a11ec7e28981a1c").unwrap().as_slice());

// Check round-trip
let b = G1Element::try_from(&uncompressed_bytes).unwrap();
assert_eq!(a, b);

// Simply padding a compressed serialization with 0's will fail
let mut padded = b.to_byte_array().to_vec();
padded.extend_from_slice(&[0u8; G1_ELEMENT_BYTE_LENGTH]);
assert_eq!(padded.len(), 2 * G1_ELEMENT_BYTE_LENGTH);
let uncompressed = G1ElementUncompressed::from_trusted_byte_array(padded.try_into().unwrap());
assert!(G1Element::try_from(&uncompressed).is_err());

// A point not on the curve fails
let mut bytes = uncompressed_bytes.into_byte_array();
bytes[1] += 1;
let uncompressed_bytes = G1ElementUncompressed::from_trusted_byte_array(bytes);
assert!(G1Element::try_from(&uncompressed_bytes).is_err());

// Serialize the point-at-infinity
let a = G1Element::zero();
let uncompressed_bytes = G1ElementUncompressed::from(&a);

// Only the point at infinity flag should be set.
assert_eq!(uncompressed_bytes.0[0], 0x40);

// The remaining bytes should all be zero
assert_eq!(
uncompressed_bytes.0[1..],
[0u8; G1_ELEMENT_BYTE_LENGTH * 2 - 1]
);

// All zeros
let uncompressed =
G1ElementUncompressed::from_trusted_byte_array([0u8; 2 * G1_ELEMENT_BYTE_LENGTH]);
assert!(G1Element::try_from(&uncompressed).is_err());
}

#[test]
fn test_g1_sum() {
// Empty sum
assert_eq!(G1ElementUncompressed::sum(&[]).unwrap(), G1Element::zero());

// Non-trivial sum
let a = G1Element::generator();
let b = G1Element::generator() * Scalar::from(2u128);
let c = G1Element::generator() * Scalar::from(3u128);
let mut bytes: Vec<G1ElementUncompressed> = vec![(&a).into(), (&b).into(), (&c).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(6u128));

// Adding zeros doesn't change anything
bytes.push(G1ElementUncompressed::from(&G1Element::zero()));
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(6u128));

// Equal elements in sum
let bytes = vec![(&b).into(), (&b).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(4u128));

// Singleton sum
let bytes = [(&b).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, b);

// Adding zero's
let mut bytes = vec![G1ElementUncompressed::from(&G1Element::zero())];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::zero());
bytes.push(G1ElementUncompressed::from(&G1Element::zero()));
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::zero());
}

#[test]
fn test_g1_large_sum() {
let mut rng = thread_rng();
let n: usize = 100;
let points: Vec<G1Element> = (0..n)
.map(|_| G1Element::generator() * Scalar::rand(&mut rng))
.collect();
let expected = points.iter().fold(G1Element::zero(), |acc, p| acc + p);

let as_uncompressed: Vec<G1ElementUncompressed> =
points.iter().map(G1ElementUncompressed::from).collect();
let sum = G1ElementUncompressed::sum(as_uncompressed.as_slice()).unwrap();
assert_eq!(expected, sum);
}

0 comments on commit 2f502fd

Please sign in to comment.