Skip to content

Commit

Permalink
add witness struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Mar 15, 2024
1 parent 0daa74b commit 96a1199
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions spmvm/spmvm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ struct spmvm_host_t
size_t block_size;
};

template <typename scalar_t>
struct witness_t
{
const scalar_t *W;
const scalar_t *u;
const scalar_t *U;

size_t nW;
size_t nU;
};

template <typename scalar_t>
class spmvm_t
{
Expand Down Expand Up @@ -209,16 +220,21 @@ public:
}

public:
RustError invoke(spmvm_host_t<scalar_t> *csr, const scalar_t scalars[], scalar_t out[], size_t nthreads)
RustError invoke(spmvm_host_t<scalar_t> *csr, const witness_t<scalar_t> *witness, scalar_t out[], size_t nthreads)
{
assert(csr->num_rows == context->num_rows);
assert(csr->num_cols == context->num_cols);
assert(csr->nnz == context->nnz);
assert(witness->nW + witness->nU + 1 == csr->num_cols);

try
{
if (scalars)
gpu[2].HtoD(&context->d_scalars[0], &scalars[0], context->num_cols);
if (witness->W)
gpu[2].HtoD(&context->d_scalars[0], &witness->W[0], witness->nW);
gpu[2].HtoD(&context->d_scalars[witness->nW], witness->u, 1);
if (witness->U)
gpu[2].HtoD(&context->d_scalars[witness->nW + 1], &witness->U[0], witness->nU);

cudaMemsetAsync(&context->d_out[0], 0, context->num_rows * sizeof(scalar_t), gpu[2]);

size_t start_row = 0;
Expand Down Expand Up @@ -258,6 +274,8 @@ public:

csr_vector_mul<scalar_t><<<gpu.sm_count(), nthreads, 0, gpu[i&1]>>>(d_context);
CUDA_OK(cudaGetLastError());

gpu[i&1].sync();
}

gpu[2].DtoH(&out[0], &context->d_out[0], csr->num_rows);
Expand All @@ -278,12 +296,12 @@ public:
};

template <typename scalar_t>
static RustError sparse_matrix_mul(spmvm_host_t<scalar_t> *csr, const scalar_t *scalars, scalar_t *out, size_t nthreads)
static RustError sparse_matrix_mul(spmvm_host_t<scalar_t> *csr, const witness_t<scalar_t> *witness, scalar_t *out, size_t nthreads)
{
try
{
spmvm_t<scalar_t> spmvm{csr};
return spmvm.invoke(csr, scalars, out, nthreads);
return spmvm.invoke(csr, witness, out, nthreads);
}
catch (const cuda_error &e)
{
Expand Down

0 comments on commit 96a1199

Please sign in to comment.