diff --git a/Cargo.toml b/Cargo.toml index 629fe2e..25b170e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ cuda-mobile = [] [dependencies] blst = "~0.3.11" semolina = "~0.1.3" -sppark = "~0.1.2" +sppark = { git = "https://github.com/lurk-lab/sppark", branch = "preallocate-msm" } halo2curves = { version = "0.6.0" } pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } rand = "^0" diff --git a/benches/grumpkin_msm.rs b/benches/grumpkin_msm.rs index 1575f48..86b9edc 100644 --- a/benches/grumpkin_msm.rs +++ b/benches/grumpkin_msm.rs @@ -30,10 +30,21 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm(&points, &scalars); }) }); + let context = grumpkin_msm::bn256::init(&points); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::bn256::with(&context, &scalars); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -54,10 +65,22 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::bn256(&points, &scalars); + let _ = grumpkin_msm::bn256::msm(&points, &scalars); }) }); + let context = grumpkin_msm::bn256::init(&points); + + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = + grumpkin_msm::bn256::with(&context, &scalars); + }) + }, + ); + group.finish(); } } diff --git a/benches/pasta_msm.rs b/benches/pasta_msm.rs index f77610f..2ddac1e 100644 --- a/benches/pasta_msm.rs +++ b/benches/pasta_msm.rs @@ -30,10 +30,21 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let _ = 
grumpkin_msm::pasta::pallas::msm(&points, &scalars); }) }); + let context = grumpkin_msm::pasta::pallas::init(&points); + + group.bench_function( + format!("\"preallocate\" 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = grumpkin_msm::pasta::pallas::with(&context, &scalars); + }) + }, + ); + group.finish(); #[cfg(feature = "cuda")] @@ -54,10 +65,22 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(format!("2**{} points", bench_npow), |b| { b.iter(|| { - let _ = grumpkin_msm::pasta::pallas(&points, &scalars); + let _ = grumpkin_msm::pasta::pallas::msm(&points, &scalars); }) }); + let context = grumpkin_msm::pasta::pallas::init(&points); + + group.bench_function( + format!("preallocate 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = + grumpkin_msm::pasta::pallas::with(&context, &scalars); + }) + }, + ); + group.finish(); } } diff --git a/cuda/bn254.cu b/cuda/bn254.cu index 3c0c1c9..fc97057 100644 --- a/cuda/bn254.cu +++ b/cuda/bn254.cu @@ -17,8 +17,26 @@ typedef fr_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_bn254(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_bn254(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_bn254_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_bn254(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_bn254_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} #endif diff --git a/cuda/grumpkin.cu b/cuda/grumpkin.cu index 861610d..8403f92 100644 --- 
a/cuda/grumpkin.cu +++ b/cuda/grumpkin.cu @@ -17,8 +17,26 @@ typedef fp_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_grumpkin(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_grumpkin(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_grumpkin_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_grumpkin(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_grumpkin_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} #endif diff --git a/cuda/pallas.cu b/cuda/pallas.cu index f3897bb..62a375a 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -17,8 +17,27 @@ typedef vesta_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_pallas(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_pallas_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_pallas(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, 
npoints, scalars); +} + #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index a926c6d..63db638 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -17,8 +17,27 @@ typedef pallas_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +extern "C" void drop_msm_context_vesta(msm_context_t &ref) { + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_vesta_init(const affine_t points[], size_t npoints, msm_context_t *msm_context) +{ + return mult_pippenger_init(points, npoints, msm_context); +} + +extern "C" RustError cuda_vesta(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context, npoints, scalars); +} + #endif diff --git a/examples/grumpkin_msm.rs b/examples/grumpkin_msm.rs index 3ca237a..56eaed5 100644 --- a/examples/grumpkin_msm.rs +++ b/examples/grumpkin_msm.rs @@ -20,7 +20,7 @@ fn main() { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let res = grumpkin_msm::bn256(&points, &scalars).to_affine(); + let res = grumpkin_msm::bn256::msm(&points, &scalars).to_affine(); let native = naive_multiscalar_mul(&points, &scalars); assert_eq!(res, native); println!("success!") diff --git a/examples/pasta_msm.rs b/examples/pasta_msm.rs index a682c35..f382756 100644 --- a/examples/pasta_msm.rs +++ b/examples/pasta_msm.rs @@ -8,7 +8,7 @@ use pasta_curves::group::Curve; fn main() { let bench_npow: usize = std::env::var("BENCH_NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("22".to_string()) .parse() .unwrap(); let npoints: usize = 1 << bench_npow; @@ -22,8 +22,10 @@ fn main() { unsafe { grumpkin_msm::CUDA_OFF = false }; } - let 
res = grumpkin_msm::pasta::pallas(&points, &scalars).to_affine(); let native = naive_multiscalar_mul(&points, &scalars); + let context = grumpkin_msm::pasta::pallas::init(&points); + let res = grumpkin_msm::pasta::pallas::with(&context, &scalars).to_affine(); + assert_eq!(res, native); println!("success!") } diff --git a/src/lib.rs b/src/lib.rs index 16fb22a..cd51f3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,90 +18,239 @@ extern "C" { #[cfg(feature = "cuda")] pub static mut CUDA_OFF: bool = false; -use halo2curves::bn256; -use halo2curves::CurveExt; +pub mod bn256 { + use halo2curves::{ + bn256::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, + }; -extern "C" { - fn mult_pippenger_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, + use crate::impl_msm; + + impl_msm!( + cuda_bn254, + cuda_bn254_init, + cuda_bn254_with, + mult_pippenger_bn254, + Point, + Affine, + Scalar ); +} + +pub mod grumpkin { + use halo2curves::{ + grumpkin::{Fr as Scalar, G1Affine as Affine, G1 as Point}, + CurveExt, + }; + use crate::impl_msm; + + impl_msm!( + cuda_grumpkin, + cuda_grumpkin_init, + cuda_grumpkin_with, + mult_pippenger_grumpkin, + Point, + Affine, + Scalar + ); } -pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); +#[macro_export] +macro_rules! 
impl_msm { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_pippenger_bn254( - out: *mut bn256::G1, - points: *const bn256::G1Affine, - npoints: usize, - scalars: *const bn256::Fr, - ) -> cuda::Error; + #[repr(C)] + #[derive(Debug)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, + npoints: usize, + } + unsafe impl Send for CudaMSMContext {} + + unsafe impl Sync for CudaMSMContext {} + + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } } - let mut ret = bn256::G1::default(); - let err = unsafe { - cuda_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) - }; - assert!(err.code == 0, "{}", String::from(err)); - return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = bn256::G1::default(); - unsafe { mult_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) }; - bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() -} + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks. NOTE(review): Clone is intentionally not derived, since a copy would pass the same raw context pointer to the drop hook twice; the bn254 drop symbol is also hard-coded for every curve using this macro + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(&*self) + }; + self.context = core::ptr::null(); + } + } -use halo2curves::grumpkin; + #[derive(Default, Debug)] + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], + } -extern "C" { - fn mult_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, - npoints: usize, - scalars: *const grumpkin::Fr, - ); + unsafe impl<'a> Send for MSMContext<'a> {} -} + unsafe impl<'a> Sync for MSMContext<'a> 
{} + + impl<'a> MSMContext<'a> { + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } + } -pub fn grumpkin( - points: &[grumpkin::G1Affine], - scalars: &[grumpkin::Fr], -) -> grumpkin::G1 { - let npoints = points.len(); - assert!(npoints == scalars.len(), "length mismatch"); + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); + } + self.cpu_context.len() + } + + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context + } + + fn points(&self) -> &[$affine] { + &self.cpu_context + } + } - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_grumpkin( - out: *mut grumpkin::G1, - points: *const grumpkin::G1Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const grumpkin::Fr, - ) -> cuda::Error; + scalars: *const $scalar, + ); } - let mut ret = grumpkin::G1::default(); - let err = unsafe { - cuda_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) - }; - assert!(err.code == 0, "{}", String::from(err)); - return grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap(); - } - let mut ret = grumpkin::G1::default(); - unsafe { - mult_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0]) + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + let npoints = points.len(); + assert!(npoints == scalars.len(), "length mismatch"); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut $point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + ) -> cuda::Error; + + } + let mut ret = $point::default(); + let err = unsafe { + $name(&mut ret, &points[0], npoints, &scalars[0]) + }; + assert!(err.code == 0, "{}", String::from(err)); + + return $point::new_jacobian(ret.x, ret.y, 
ret.z).unwrap(); + } + let mut ret = $point::default(); + unsafe { $name_cpu(&mut ret, &points[0], npoints, &scalars[0]) }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } + + pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let npoints = points.len(); + let err = unsafe { + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + ret.on_gpu = true; + return ret; + } + + ret + } + + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); + + let mut ret = $point::default(); + + #[cfg(feature = "cuda")] + if nscalars >= 1 << 16 + && context.on_gpu + && unsafe { !CUDA_OFF && cuda_available() } + { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + npoints: usize, + scalars: *const $scalar, + ) -> cuda::Error; + } + + let err = unsafe { + $name_with(&mut ret, &context.cuda_context, nscalars, &scalars[0]) + }; + assert!(err.code == 0, "{}", String::from(err)); + return $point::new_jacobian(ret.x, ret.y, ret.z).unwrap(); + } + + unsafe { + $name_cpu( + &mut ret, + &context.cpu_context[0], + nscalars, + &scalars[0], + ) + }; + $point::new_jacobian(ret.x, ret.y, ret.z).unwrap() + } }; - grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap() } #[cfg(test)] @@ -123,9 +272,14 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = crate::bn256(&points, &scalars).to_affine(); + let ret = crate::bn256::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let 
context = crate::bn256::init(&points); + let ret_other = crate::bn256::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } } diff --git a/src/pasta.rs b/src/pasta.rs index c69f6be..de0e781 100644 --- a/src/pasta.rs +++ b/src/pasta.rs @@ -6,111 +6,239 @@ extern crate semolina; -use pasta_curves::pallas; - -#[cfg(feature = "cuda")] -use crate::{cuda, cuda_available, CUDA_OFF}; - -extern "C" { - fn mult_pippenger_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, +pub mod pallas { + use pasta_curves::pallas::{Affine, Point, Scalar}; + + use crate::impl_pasta; + + impl_pasta!( + cuda_pallas, + cuda_pallas_init, + cuda_pallas_with, + mult_pippenger_pallas, + Point, + Affine, + Scalar ); } -pub fn pallas( - points: &[pallas::Affine], - scalars: &[pallas::Scalar], -) -> pallas::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); +pub mod vesta { + use pasta_curves::vesta::{Affine, Point, Scalar}; - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { - extern "C" { - fn cuda_pippenger_pallas( - out: *mut pallas::Point, - points: *const pallas::Affine, - npoints: usize, - scalars: *const pallas::Scalar, - is_mont: bool, - ) -> cuda::Error; + use crate::impl_pasta; - } - let mut ret = pallas::Point::default(); - let err = unsafe { - cuda_pippenger_pallas( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) - }; - assert!(err.code == 0, "{}", String::from(err)); - - return ret; - } - let mut ret = pallas::Point::default(); - unsafe { - mult_pippenger_pallas(&mut ret, &points[0], npoints, &scalars[0], true) - }; - ret + impl_pasta!( + cuda_vesta, + cuda_vesta_init, + cuda_vesta_with, + mult_pippenger_vesta, + Point, + Affine, + Scalar + ); } -use pasta_curves::vesta; +#[macro_export] +macro_rules! 
impl_pasta { + ( + $name:ident, + $name_init:ident, + $name_with:ident, + $name_cpu:ident, + $point:ident, + $affine:ident, + $scalar:ident + ) => { + #[cfg(feature = "cuda")] + use $crate::{cuda, cuda_available, CUDA_OFF}; + + #[repr(C)] + #[derive(Debug)] + pub struct CudaMSMContext { + context: *const std::ffi::c_void, + npoints: usize, + } -extern "C" { - fn mult_pippenger_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, - npoints: usize, - scalars: *const vesta::Scalar, - is_mont: bool, - ); -} + unsafe impl Send for CudaMSMContext {} + + unsafe impl Sync for CudaMSMContext {} + + impl Default for CudaMSMContext { + fn default() -> Self { + Self { + context: std::ptr::null(), + npoints: 0, + } + } + } + + #[cfg(feature = "cuda")] + // TODO: check for device-side memory leaks. NOTE(review): Clone is intentionally not derived, since a copy would pass the same raw context pointer to the drop hook twice; the bn254 drop symbol is also hard-coded for every curve using this macro + impl Drop for CudaMSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_bn254(by_ref: &CudaMSMContext); + } + unsafe { + drop_msm_context_bn254(&*self) + }; + self.context = core::ptr::null(); + } + } + + #[derive(Default, Debug)] + pub struct MSMContext<'a> { + cuda_context: CudaMSMContext, + on_gpu: bool, + cpu_context: &'a [$affine], + } + + unsafe impl<'a> Send for MSMContext<'a> {} + + unsafe impl<'a> Sync for MSMContext<'a> {} -pub fn vesta( - points: &[vesta::Affine], - scalars: &[vesta::Scalar], -) -> vesta::Point { - let npoints = points.len(); - assert_eq!(npoints, scalars.len(), "length mismatch"); + impl<'a> MSMContext<'a> { + fn new(points: &'a [$affine]) -> Self { + Self { + cuda_context: CudaMSMContext::default(), + on_gpu: false, + cpu_context: points, + } + } + + fn npoints(&self) -> usize { + if self.on_gpu { + assert_eq!( + self.cpu_context.len(), + self.cuda_context.npoints + ); + } + self.cpu_context.len() + } + + fn cuda(&self) -> &CudaMSMContext { + &self.cuda_context + } + + fn points(&self) -> &[$affine] { + &self.cpu_context + } + } - #[cfg(feature = "cuda")] - if npoints >= 1 << 16 && 
unsafe { !CUDA_OFF && cuda_available() } { extern "C" { - fn cuda_pippenger_vesta( - out: *mut vesta::Point, - points: *const vesta::Affine, + fn $name_cpu( + out: *mut $point, + points: *const $affine, npoints: usize, - scalars: *const vesta::Scalar, + scalars: *const $scalar, is_mont: bool, - ) -> cuda::Error; + ); } - let mut ret = vesta::Point::default(); - let err = unsafe { - cuda_pippenger_vesta( - &mut ret, - &points[0], - npoints, - &scalars[0], - true, - ) - }; - assert!(err.code == 0, "{}", String::from(err)); - - return ret; - } - let mut ret = vesta::Point::default(); - unsafe { - mult_pippenger_vesta(&mut ret, &points[0], npoints, &scalars[0], true) + + pub fn msm(points: &[$affine], scalars: &[$scalar]) -> $point { + let npoints = points.len(); + assert!(npoints == scalars.len(), "length mismatch"); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name( + out: *mut $point, + points: *const $affine, + npoints: usize, + scalars: *const $scalar, + is_mont: bool, + ) -> cuda::Error; + + } + let mut ret = $point::default(); + let err = unsafe { + $name(&mut ret, &points[0], npoints, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + + return ret; + } + let mut ret = $point::default(); + unsafe { + $name_cpu(&mut ret, &points[0], npoints, &scalars[0], true) + }; + ret + } + + pub fn init(points: &[$affine]) -> MSMContext { + let npoints = points.len(); + + let mut ret = MSMContext::new(points); + + #[cfg(feature = "cuda")] + if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } { + extern "C" { + fn $name_init( + points: *const $affine, + npoints: usize, + msm_context: &mut CudaMSMContext, + ) -> cuda::Error; + } + + let npoints = points.len(); + let err = unsafe { + $name_init( + points.as_ptr() as *const _, + npoints, + &mut ret.cuda_context, + ) + }; + assert!(err.code == 0, "{}", String::from(err)); + ret.on_gpu = true; + return ret; + } + + ret 
+ } + + pub fn with(context: &MSMContext, scalars: &[$scalar]) -> $point { + let npoints = context.npoints(); + let nscalars = scalars.len(); + assert!(npoints >= nscalars, "not enough points"); + + let mut ret = $point::default(); + + #[cfg(feature = "cuda")] + if nscalars >= 1 << 16 && context.on_gpu + && unsafe { !CUDA_OFF && cuda_available() } + { + extern "C" { + fn $name_with( + out: *mut $point, + context: &CudaMSMContext, + npoints: usize, + scalars: *const $scalar, + is_mont: bool, + ) -> cuda::Error; + } + + let err = unsafe { + $name_with(&mut ret, context.cuda(), nscalars, &scalars[0], true) + }; + assert!(err.code == 0, "{}", String::from(err)); + return ret; + } + + unsafe { + $name_cpu( + &mut ret, + &context.cpu_context[0], + nscalars, + &scalars[0], + true, + ) + }; + + ret + } }; - ret } pub mod utils { @@ -246,9 +374,14 @@ mod tests { let naive = naive_multiscalar_mul(&points, &scalars); println!("{:?}", naive); - let ret = pallas(&points, &scalars).to_affine(); + let ret = pallas::msm(&points, &scalars).to_affine(); println!("{:?}", ret); + let context = pallas::init(&points); + let ret_other = pallas::with(&context, &scalars).to_affine(); + println!("{:?}", ret_other); + assert_eq!(ret, naive); + assert_eq!(ret, ret_other); } }