From 96a1199814acda68f3d2cf51dbfd1007ad5ae0b4 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Fri, 15 Mar 2024 22:36:06 +0000 Subject: [PATCH] add witness struct --- spmvm/spmvm.cuh | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/spmvm/spmvm.cuh b/spmvm/spmvm.cuh index 1203851..2d4acae 100644 --- a/spmvm/spmvm.cuh +++ b/spmvm/spmvm.cuh @@ -99,6 +99,17 @@ struct spmvm_host_t size_t block_size; }; +template +struct witness_t +{ + const scalar_t *W; + const scalar_t *u; + const scalar_t *U; + + size_t nW; + size_t nU; +}; + template class spmvm_t { @@ -209,16 +220,21 @@ public: } public: - RustError invoke(spmvm_host_t *csr, const scalar_t scalars[], scalar_t out[], size_t nthreads) + RustError invoke(spmvm_host_t *csr, const witness_t *witness, scalar_t out[], size_t nthreads) { assert(csr->num_rows == context->num_rows); assert(csr->num_cols == context->num_cols); assert(csr->nnz == context->nnz); + assert(witness->nW + witness->nU + 1 == csr->num_cols); try { - if (scalars) - gpu[2].HtoD(&context->d_scalars[0], &scalars[0], context->num_cols); + if (witness->W) + gpu[2].HtoD(&context->d_scalars[0], &witness->W[0], witness->nW); + gpu[2].HtoD(&context->d_scalars[witness->nW], witness->u, 1); + if (witness->U) + gpu[2].HtoD(&context->d_scalars[witness->nW + 1], &witness->U[0], witness->nU); + cudaMemsetAsync(&context->d_out[0], 0, context->num_rows * sizeof(scalar_t), gpu[2]); size_t start_row = 0; @@ -258,6 +274,8 @@ public: csr_vector_mul<<>>(d_context); CUDA_OK(cudaGetLastError()); + + gpu[i&1].sync(); } gpu[2].DtoH(&out[0], &context->d_out[0], csr->num_rows); @@ -278,12 +296,12 @@ public: }; template -static RustError sparse_matrix_mul(spmvm_host_t *csr, const scalar_t *scalars, scalar_t *out, size_t nthreads) +static RustError sparse_matrix_mul(spmvm_host_t *csr, const witness_t *witness, scalar_t *out, size_t nthreads) { try { spmvm_t spmvm{csr}; - return spmvm.invoke(csr, scalars, out, nthreads); + return spmvm.invoke(csr, witness, out, nthreads); } catch (const cuda_error &e) {