diff --git a/src/bellpepper/r1cs.rs b/src/bellpepper/r1cs.rs
index ae0d5013f..df36da799 100644
--- a/src/bellpepper/r1cs.rs
+++ b/src/bellpepper/r1cs.rs
@@ -112,9 +112,9 @@ fn add_constraint(
 ) {
   let (A, B, C, nn) = X;
   let n = **nn;
-  assert_eq!(n + 1, A.indptr.len(), "A: invalid shape");
-  assert_eq!(n + 1, B.indptr.len(), "B: invalid shape");
-  assert_eq!(n + 1, C.indptr.len(), "C: invalid shape");
+  assert_eq!(n, A.num_rows(), "A: invalid shape");
+  assert_eq!(n, B.num_rows(), "B: invalid shape");
+  assert_eq!(n, C.num_rows(), "C: invalid shape");
 
   let add_constraint_component = |index: Index, coeff: &S, M: &mut SparseMatrix<S>| {
     // we add constraints to the matrix only if the associated coefficient is non-zero
diff --git a/src/r1cs/sparse.rs b/src/r1cs/sparse.rs
index dae4f78c8..d95d0810a 100644
--- a/src/r1cs/sparse.rs
+++ b/src/r1cs/sparse.rs
@@ -13,6 +13,7 @@ use ff::PrimeField;
 use itertools::Itertools as _;
 use rand_core::{CryptoRng, RngCore};
 use rayon::prelude::*;
+use ref_cast::RefCast;
 use serde::{Deserialize, Serialize};
 
 /// CSR format sparse matrix, We follow the names used by scipy.
@@ -31,6 +32,11 @@ pub struct SparseMatrix<F: PrimeField> {
   pub cols: usize,
 }
 
+/// Wrapper type encoding a row of a [`SparseMatrix`] as its bounds in the underlying data
+#[derive(Debug, Clone, RefCast)]
+#[repr(transparent)]
+pub struct RowData([usize; 2]);
+
 /// [`SparseMatrix`]s are often large, and this helps with cloning bottlenecks
 impl<F: PrimeField> Clone for SparseMatrix<F> {
   fn clone(&self) -> Self {
@@ -111,6 +117,30 @@ impl<F: PrimeField> SparseMatrix<F> {
     Self::new(&matrix, rows, cols)
   }
 
+  /// Returns an iterator over the rows
+  pub fn iter_rows(&self) -> impl Iterator<Item = &RowData> {
+    self
+      .indptr
+      .windows(2)
+      .map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
+  }
+
+  /// Returns a parallel iterator over the rows
+  pub fn par_iter_rows(&self) -> impl IndexedParallelIterator<Item = &RowData> {
+    self
+      .indptr
+      .par_windows(2)
+      .map(|ptrs| RowData::ref_cast(ptrs.try_into().unwrap()))
+  }
+
+  /// Retrieves the data for the row referenced by `row`.
+  /// The [`RowData`] **must** have been obtained from this (unmodified) matrix to guarantee it indexes a valid row.
+  pub fn get_row(&self, row: &RowData) -> impl Iterator<Item = (&F, &usize)> {
+    self.data[row.0[0]..row.0[1]]
+      .iter()
+      .zip_eq(&self.indices[row.0[0]..row.0[1]])
+  }
+
   /// Retrieves the data for row slice [i..j] from `ptrs`.
   /// We assume that `ptrs` is indexed from `indptrs` and do not check if the
   /// returned slice is actually a valid row.
@@ -226,6 +256,16 @@ impl<F: PrimeField> SparseMatrix<F> {
       nnz: *self.indptr.last().unwrap(),
     }
   }
+
+  /// Returns the number of rows of the matrix
+  pub fn num_rows(&self) -> usize {
+    self.indptr.len() - 1
+  }
+
+  /// Returns the number of columns of the matrix
+  pub fn num_cols(&self) -> usize {
+    self.cols
+  }
 }
 
 /// Iterator for sparse matrix
diff --git a/src/spartan/batched.rs b/src/spartan/batched.rs
index ebcc5508d..60abafe8a 100644
--- a/src/spartan/batched.rs
+++ b/src/spartan/batched.rs
@@ -513,13 +513,11 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
                           r_y: &[E::Scalar]|
      -> Vec<E::Scalar> {
       let evaluate_with_table =
-        // TODO(@winston-h-zhang): review
         |M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
-          M.indptr
-            .par_windows(2)
+          M.par_iter_rows()
             .enumerate()
-            .map(|(row_idx, ptrs)| {
-              M.get_row_unchecked(ptrs.try_into().unwrap())
+            .map(|(row_idx, row)| {
+              M.get_row(row)
                 .map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
                 .sum::<E::Scalar>()
             })
diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs
index 9b38adb24..663addb98 100644
--- a/src/spartan/mod.rs
+++ b/src/spartan/mod.rs
@@ -174,8 +174,9 @@ fn compute_eval_table_sparse<E: Engine>(
   assert_eq!(rx.len(), S.num_cons);
 
   let inner = |M: &SparseMatrix<E::Scalar>, M_evals: &mut Vec<E::Scalar>| {
-    for (row_idx, ptrs) in M.indptr.windows(2).enumerate() {
-      for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) {
+    for (row_idx, row) in M.iter_rows().enumerate() {
+      for (val, col_idx) in M.get_row(row) {
+        // TODO(@winston-h-zhang): Parallelize? Will need more complicated locking
         M_evals[*col_idx] += rx[row_idx] * val;
       }
     }
diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs
index 3b82c3e57..3c41735e2 100644
--- a/src/spartan/snark.rs
+++ b/src/spartan/snark.rs
@@ -351,11 +351,10 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for Relax
      -> Vec<E::Scalar> {
       let evaluate_with_table =
         |M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
-          M.indptr
-            .par_windows(2)
+          M.par_iter_rows()
             .enumerate()
-            .map(|(row_idx, ptrs)| {
-              M.get_row_unchecked(ptrs.try_into().unwrap())
+            .map(|(row_idx, row)| {
+              M.get_row(row)
                 .map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
                 .sum::<E::Scalar>()
             })
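
For reference, a minimal sketch (not part of the patch) of how the new row API composes, here as a CSR matrix-vector product. Only `iter_rows`, `get_row`, and `num_cols` come from the diff above; the module path and the helper name `sparse_mat_vec` are assumptions for illustration.

    // Sketch only: exercises the row API introduced above.
    // Assumed: this lives inside the crate where `SparseMatrix` is visible
    // (e.g. via `crate::r1cs::sparse::SparseMatrix`); `sparse_mat_vec` is a
    // hypothetical helper, not part of the patch.
    use ff::PrimeField;

    use crate::r1cs::sparse::SparseMatrix;

    /// Computes `M * v` by walking each CSR row and accumulating its non-zero entries.
    fn sparse_mat_vec<F: PrimeField>(m: &SparseMatrix<F>, v: &[F]) -> Vec<F> {
      assert_eq!(m.num_cols(), v.len(), "dimension mismatch");
      m.iter_rows()
        .map(|row| {
          // `get_row` yields (&value, &column_index) pairs for this row's non-zeros.
          m.get_row(row)
            .map(|(val, col_idx)| *val * v[*col_idx])
            .sum::<F>()
        })
        .collect()
    }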