Commit

chunked multiply
Hanting Zhang committed Mar 18, 2024
1 parent ef677bf commit cde7927
Showing 1 changed file with 88 additions and 15 deletions.
103 changes: 88 additions & 15 deletions src/r1cs/sparse.rs
@@ -215,22 +215,95 @@ impl<F: PrimeField> SparseMatrix<F> {
  /// This does not check that the shape of the matrix/vector are compatible.
  pub fn multiply_witness_into_unchecked(&self, W: &[F], u: &F, X: &[F], sink: &mut Vec<F>) {
    let num_vars = W.len();
-    self
-      .indptr
-      .par_windows(2)
-      .map(|ptrs| {
-        self
-          .get_row_unchecked(ptrs.try_into().unwrap())
-          .fold(F::ZERO, |acc, (val, col_idx)| {
-            let val = match col_idx.cmp(&num_vars) {
-              Ordering::Less => *val * W[*col_idx],
-              Ordering::Equal => *val * *u,
-              Ordering::Greater => *val * X[*col_idx - num_vars - 1],
-            };
-            acc + val
-          })
-      })
-      .collect_into_vec(sink);
+    sink.clear();
+    // The parallelism strategy below splits the (row, column, value) tuples into `num_threads`
+    // chunks. The tuples are assumed to be ordered by (row, column); we exploit this ordering to
+    // guard each output chunk with its own mutex, expecting that only one thread writes to a
+    // given chunk at a time.
+
+    let num_threads = rayon::current_num_threads() * 4; // Oversubscribe to enable work stealing in case of thread work imbalance
+    let row_chunk_size = (self.num_rows() as f64 / num_threads as f64).ceil() as usize;
+
+    let mut chunks: Vec<std::sync::Mutex<Vec<F>>> = Vec::with_capacity(num_threads);
+    let mut remaining_rows = self.num_rows();
+    (0..num_threads).for_each(|i| {
+      if i == num_threads - 1 {
+        // the final chunk may be smaller
+        let inner = std::sync::Mutex::new(vec![F::ZERO; remaining_rows]);
+        chunks.push(inner);
+      } else {
+        let inner = std::sync::Mutex::new(vec![F::ZERO; row_chunk_size]);
+        chunks.push(inner);
+        remaining_rows -= row_chunk_size;
+      }
+    });
+
+    let get_chunk = |row_index: usize| -> usize { row_index / row_chunk_size };
+    let get_index = |row_index: usize| -> usize { row_index % row_chunk_size };
+    let get_value = |col_idx: usize| -> F {
+      match col_idx.cmp(&num_vars) {
+        Ordering::Less => W[col_idx],
+        Ordering::Equal => *u,
+        Ordering::Greater => X[col_idx - num_vars - 1],
+      }
+    };
+    let mul_row = |row: &RowData| -> F {
+      self.get_row(row).fold(F::ZERO, |acc, (&val, col_idx)| {
+        let col_val = get_value(*col_idx);
+        let val = if val == F::ONE {
+          col_val
+        } else if col_val == F::ONE {
+          val
+        } else {
+          val * col_val
+        };
+        acc + val
+      })
+    };
+
+    let span = tracing::span!(tracing::Level::TRACE, "all_chunks_multiplication");
+    let _enter = span.enter();
+    self
+      .par_iter_rows()
+      .enumerate()
+      .chunks(row_chunk_size)
+      .for_each(|sub_matrix| {
+        let (init_row_idx, init_row) = sub_matrix[0];
+        let mut prev_chunk_index = get_chunk(init_row_idx);
+        let curr_row_index = get_index(init_row_idx);
+        let mut curr_chunk = chunks[prev_chunk_index].lock().unwrap();
+
+        curr_chunk[curr_row_index] = mul_row(init_row);
+
+        let span_a = tracing::span!(tracing::Level::TRACE, "chunk_multiplication");
+        let _enter_b = span_a.enter();
+        for (row_idx, row) in sub_matrix {
+          let curr_chunk_index = get_chunk(row_idx);
+          if prev_chunk_index != curr_chunk_index {
+            // only take a new lock when this row falls in a different chunk
+            drop(curr_chunk); // release the current chunk before locking the next, so two chunk locks are never held at once
+            let new_chunk = chunks[curr_chunk_index].lock().unwrap();
+            curr_chunk = new_chunk;
+
+            prev_chunk_index = curr_chunk_index;
+          }
+
+          let curr_row_index = get_index(row_idx);
+          curr_chunk[curr_row_index] = mul_row(row);
+        }
+      });
+    drop(_enter);
+    drop(span);
+
+    let span_a = tracing::span!(tracing::Level::TRACE, "chunks_mutex_unwrap");
+    let _enter_a = span_a.enter();
+    // TODO(sragss): Mutex unwrap takes about 30% of the time due to clone, likely unnecessary.
+    for chunk in chunks {
+      let inner_vec = chunk.into_inner().unwrap();
+      sink.extend(inner_vec.iter());
+    }
+    drop(_enter_a);
+    drop(span_a);
  }

  /// number of non-zero entries
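
For readers unfamiliar with the indexing convention used by the `get_value` closure above: the full assignment vector is the concatenation (W, u, X), where W is the witness, u is the relaxation scalar, and X holds the public inputs. Below is a minimal standalone sketch of that lookup, using u64 in place of the field type; the function and variable names here are illustrative only and not part of the crate.

use std::cmp::Ordering;

// Illustrative stand-in for reading Z[col] where Z = (W, u, X),
// mirroring the `get_value` closure in the diff above.
fn z_value(col_idx: usize, w: &[u64], u: u64, x: &[u64]) -> u64 {
    let num_vars = w.len();
    match col_idx.cmp(&num_vars) {
        Ordering::Less => w[col_idx],                   // witness entries occupy 0..num_vars
        Ordering::Equal => u,                           // the single `u` entry sits at index num_vars
        Ordering::Greater => x[col_idx - num_vars - 1], // public inputs follow `u`
    }
}

fn main() {
    let (w, u, x) = (vec![7, 8, 9], 1, vec![4, 5]);
    // Columns 0..=2 hit W, column 3 hits u, columns 4..=5 hit X.
    let z: Vec<u64> = (0..6).map(|c| z_value(c, &w, u, &x)).collect();
    assert_eq!(z, vec![7, 8, 9, 1, 4, 5]);
    println!("{z:?}");
}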
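
The chunking strategy described in the comment above can also be illustrated with a hedged, self-contained sketch: rows are processed in order, each output chunk sits behind its own mutex, and the chunk buffers are concatenated at the end. This simplified version uses u64 instead of a field element, represents the matrix as a plain Vec of (column, value) rows, and locks per row rather than holding one lock across a whole chunk, so it demonstrates the index arithmetic rather than the exact locking pattern of the committed code; `multiply_chunked` and the row representation are assumptions for the example.

use rayon::prelude::*;
use std::sync::Mutex;

// Each row is a list of (column, value) pairs; `z` is the dense vector being multiplied.
fn multiply_chunked(rows: &[Vec<(usize, u64)>], z: &[u64]) -> Vec<u64> {
    let num_threads = rayon::current_num_threads() * 4; // oversubscribe for work stealing
    let row_chunk_size = ((rows.len() as f64 / num_threads as f64).ceil() as usize).max(1);

    // One mutex-guarded output buffer per chunk of rows (the last chunk may be shorter).
    let chunks: Vec<Mutex<Vec<u64>>> = rows
        .chunks(row_chunk_size)
        .map(|c| Mutex::new(vec![0u64; c.len()]))
        .collect();

    rows.par_iter().enumerate().for_each(|(row_idx, row)| {
        let dot: u64 = row.iter().map(|&(col, val)| val * z[col]).sum();
        // Map a global row index to (which chunk, offset inside that chunk).
        let (chunk, offset) = (row_idx / row_chunk_size, row_idx % row_chunk_size);
        let mut buf = chunks[chunk].lock().unwrap();
        buf[offset] = dot;
    });

    // Concatenate the per-chunk buffers, preserving row order.
    chunks
        .into_iter()
        .flat_map(|m| m.into_inner().unwrap())
        .collect()
}

fn main() {
    // 2x3 matrix [[1, 2, 0], [0, 0, 3]] times z = [1, 1, 1] gives [3, 3].
    let rows = vec![vec![(0, 1), (1, 2)], vec![(2, 3)]];
    assert_eq!(multiply_chunked(&rows, &[1, 1, 1]), vec![3, 3]);
}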
