Commit: blocks
Hanting Zhang committed Mar 15, 2024
1 parent 1d86d4a commit e1acefc
Showing 1 changed file with 82 additions and 28 deletions.
spmvm/spmvm.cuh: 110 changes (82 additions & 28 deletions)
@@ -40,6 +40,9 @@ struct spmvm_context_t
scalar_t *d_scalars;
// output scalars
scalar_t *d_out;

+ size_t start_row;
+ size_t start_data;
};

template <typename scalar_t>
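The two new fields describe where a block sits inside the full CSR matrix: start_row is the first global row the block covers, and start_data is the offset into data/col_idx at which the block's non-zeros begin. A minimal host-side sketch of that window, derived the same way the invoke() loop further down derives it from the blocks array (block_window and make_window are illustrative names, not part of this header):

#include <cstddef>
#include <vector>

// Illustrative only: the row/data window of the k-th block of a CSR matrix.
struct block_window {
    size_t start_row;   // first global row handled by the block
    size_t num_rows;    // rows in the block
    size_t start_data;  // offset of the block's first non-zero in data/col_idx
    size_t num_data;    // non-zeros in the block
};

static block_window make_window(const std::vector<size_t> &row_ptr,
                                const std::vector<size_t> &blocks, size_t k)
{
    block_window w;
    w.start_row  = blocks[k];
    w.num_rows   = blocks[k + 1] - blocks[k];
    w.start_data = row_ptr[w.start_row];
    w.num_data   = row_ptr[blocks[k + 1]] - w.start_data;
    return w;
}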
@@ -57,17 +60,27 @@ template <typename scalar_t>
__global__ void csr_vector_mul(spmvm_context_t<scalar_t> *d_context)
{
size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (idx < d_context->num_rows) {
+ printf("idx: %d\n", idx);
+ }
while (idx < d_context->num_rows)
{
- for (size_t i = d_context->d_row_ptr[idx]; i < d_context->d_row_ptr[idx + 1]; i++)
+ printf("idx: %d\n", idx);
+ size_t row_start = d_context->d_row_ptr[idx] - d_context->start_data;
+ size_t row_end = d_context->d_row_ptr[idx + 1] - d_context->start_data;
+ size_t row_idx = d_context->start_row + idx;
+ printf("start_row: %d\n", d_context->start_row);
+ printf("d_row_ptr[idx]: %d\n", d_context->d_row_ptr[idx]);
+ printf("d_row_ptr[idx + 1]: %d\n", d_context->d_row_ptr[idx + 1]);
+ printf("start_data: %d\n", d_context->start_data);
+ for (size_t i = row_start; i < row_end; i++)
{
- d_context->d_out[idx] = d_context->d_out[idx] + d_context->d_scalars[d_context->d_col_idx[i]] * d_context->d_data[i];
+ d_context->d_out[row_idx] = d_context->d_out[row_idx] + d_context->d_scalars[d_context->d_col_idx[i]] * d_context->d_data[i];
}
idx += gridDim.x * blockDim.x;
}
__syncthreads();
}

#undef asm

#ifndef SPPARK_DONT_INSTANTIATE_TEMPLATES
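For reference, the same index arithmetic on the host: the kernel's d_row_ptr holds the block's slice of the global row_ptr (still global offsets), so subtracting start_data turns each entry into an index into the block-local d_data/d_col_idx windows, while start_row + idx selects the global output row. A CPU sketch under those assumptions (double stands in for scalar_t; this mirrors the kernel above rather than defining any API):

#include <cstddef>

// Host-side mirror of csr_vector_mul for one block.
// row_ptr points at the block's slice of the global row_ptr (num_rows + 1
// entries, still holding global offsets); data and col_idx point at the
// block's non-zero window; scalars spans all columns and out all rows of
// the full matrix.
static void spmv_block_reference(const size_t *row_ptr, const size_t *col_idx,
                                 const double *data, const double *scalars,
                                 double *out, size_t num_rows,
                                 size_t start_row, size_t start_data)
{
    for (size_t idx = 0; idx < num_rows; idx++) {
        size_t row_start = row_ptr[idx] - start_data;      // block-local begin
        size_t row_end   = row_ptr[idx + 1] - start_data;  // block-local end
        size_t row_idx   = start_row + idx;                // global output row
        for (size_t i = row_start; i < row_end; i++)
            out[row_idx] += scalars[col_idx[i]] * data[i];
    }
}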
@@ -84,10 +97,14 @@ struct spmvm_host_t
const scalar_t *data;
const size_t *col_idx;
const size_t *row_ptr;
+ const size_t *blocks;

size_t num_rows;
size_t num_cols;
size_t nnz;

+ size_t num_blocks;
+ size_t block_size;
};

template <typename scalar_t>
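blocks is read by invoke() further down as a list of num_blocks row boundaries, one block per consecutive pair, and the device buffers are sized against block_size, which suggests no block is meant to carry more than block_size non-zeros. One way such boundaries could be produced, as a sketch only (build_blocks is a hypothetical helper, not something this header provides):

#include <cstddef>
#include <vector>

// Hypothetical helper: cut the row range into blocks whose non-zero count
// stays at or below block_size (a single oversized row still gets its own
// block). The result plays the role of spmvm_host_t::blocks, with
// num_blocks equal to the number of boundaries returned.
static std::vector<size_t> build_blocks(const std::vector<size_t> &row_ptr,
                                        size_t num_rows, size_t block_size)
{
    std::vector<size_t> boundaries{0};
    size_t nnz_in_block = 0;
    for (size_t r = 0; r < num_rows; r++) {
        size_t row_nnz = row_ptr[r + 1] - row_ptr[r];
        if (nnz_in_block > 0 && nnz_in_block + row_nnz > block_size) {
            boundaries.push_back(r);   // close the current block before row r
            nnz_in_block = 0;
        }
        nnz_in_block += row_nnz;
    }
    boundaries.push_back(num_rows);    // final boundary
    return boundaries;
}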
@@ -115,21 +132,24 @@ public:
// scalar_t *d_out;

public:
- spmvm_t(size_t num_rows, size_t num_cols, size_t nnz, int device_id = -1)
+ spmvm_t(spmvm_host_t<scalar_t> *csr, int device_id = -1)
: gpu(select_gpu(device_id))
{
this->context = reinterpret_cast<spmvm_context_t<scalar_t> *>(malloc(sizeof(spmvm_context_t<scalar_t>)));

- context->d_data = reinterpret_cast<scalar_t *>(gpu.Dmalloc(nnz * sizeof(scalar_t)));
- context->d_col_idx = reinterpret_cast<size_t *>(gpu.Dmalloc(nnz * sizeof(size_t)));
- context->d_row_ptr = reinterpret_cast<size_t *>(gpu.Dmalloc((num_rows + 1) * sizeof(size_t)));
+ context->d_data = reinterpret_cast<scalar_t *>(gpu.Dmalloc(2 * csr->block_size * sizeof(scalar_t)));
+ context->d_col_idx = reinterpret_cast<size_t *>(gpu.Dmalloc(2 * csr->block_size * sizeof(size_t)));
+ context->d_row_ptr = reinterpret_cast<size_t *>(gpu.Dmalloc((csr->num_rows + 1) * sizeof(size_t)));

+ context->num_rows = csr->num_rows;
+ context->num_cols = csr->num_cols;
+ context->nnz = csr->nnz;

- context->num_rows = num_rows;
- context->num_cols = num_cols;
- context->nnz = nnz;
+ context->d_scalars = reinterpret_cast<scalar_t *>(gpu.Dmalloc(csr->num_cols * sizeof(scalar_t)));
+ context->d_out = reinterpret_cast<scalar_t *>(gpu.Dmalloc(csr->num_rows * sizeof(scalar_t)));

- context->d_scalars = reinterpret_cast<scalar_t *>(gpu.Dmalloc(num_cols * sizeof(scalar_t)));
- context->d_out = reinterpret_cast<scalar_t *>(gpu.Dmalloc(num_rows * sizeof(scalar_t)));
+ context->start_row = 0;
+ context->start_data = 0;

this->owned = true;
}
@@ -150,6 +170,9 @@ public:
spmvm_context->d_scalars = reinterpret_cast<scalar_t *>(gpu.Dmalloc(csr->num_cols * sizeof(scalar_t)));
spmvm_context->d_out = reinterpret_cast<scalar_t *>(gpu.Dmalloc(csr->num_rows * sizeof(scalar_t)));

+ spmvm_context->start_row = 0;
+ spmvm_context->start_data = 0;

// move data into allocated memory
if (csr->data)
gpu[2].HtoD(&spmvm_context->d_data[0], &csr->data[0], csr->nnz);
@@ -202,21 +225,55 @@ public:

try
{
- if (csr->data)
- gpu[2].HtoD(&context->d_data[0], &csr->data[0], context->nnz);
- if (csr->col_idx)
- gpu[2].HtoD(&context->d_col_idx[0], &csr->col_idx[0], context->nnz);
- if (csr->row_ptr)
- gpu[2].HtoD(&context->d_row_ptr[0], &csr->row_ptr[0], context->num_rows + 1);

if (scalars)
gpu[2].HtoD(&context->d_scalars[0], &scalars[0], context->num_cols);

- spmvm_context_t<scalar_t> *d_context = reinterpret_cast<spmvm_context_t<scalar_t> *>(gpu[2].Dmalloc(sizeof(spmvm_context_t<scalar_t>)));
- gpu[2].HtoD(d_context, context, 1);
- cudaMemsetAsync(&context->d_out[0], 0, context->num_rows * sizeof(scalar_t), gpu[2]);
- csr_vector_mul<scalar_t><<<gpu.sm_count(), nthreads, 0, gpu[2]>>>(d_context);
- CUDA_OK(cudaGetLastError());

+ size_t start_row = 0;
+ size_t end_row = 0;
+ size_t num_rows = 0;

+ size_t start_data = 0;
+ size_t end_data = 0;
+ size_t num_data = 0;

+ for (size_t i = 0; i < csr->num_blocks - 1; ++i) {
+ start_row = csr->blocks[i];
+ end_row = csr->blocks[i + 1];
+ num_rows = end_row - start_row;

+ start_data = csr->row_ptr[start_row];
+ end_data = csr->row_ptr[end_row];
+ num_data = end_data - start_data;

+ if (csr->data) {
+ gpu[i&1].HtoD(&context->d_data[0], &csr->data[start_data], num_data);
+ }
+ if (csr->col_idx) {
+ gpu[i&1].HtoD(&context->d_col_idx[0], &csr->col_idx[start_data], num_data);
+ }
+ if (csr->row_ptr) {
+ gpu[i&1].HtoD(&context->d_row_ptr[0], &csr->row_ptr[start_row], num_rows + 1);
+ }
+ printf("ROW %d: ", i);
+ for (int j = 0; j < num_rows + 1; ++j) {
+ printf("%d ", csr->row_ptr[start_row + j]);
+ }
+ printf("\n");

+ gpu[i&1].sync();

+ context->num_rows = num_rows;
+ context->start_row = start_row;
+ context->start_data = start_data;
+ spmvm_context_t<scalar_t> *d_context = reinterpret_cast<spmvm_context_t<scalar_t> *>(gpu[i&1].Dmalloc(sizeof(spmvm_context_t<scalar_t>)));
+ gpu[i&1].HtoD(d_context, context, 1);

+ csr_vector_mul<scalar_t><<<gpu.sm_count(), nthreads, 0, gpu[i&1]>>>(d_context);
+ CUDA_OK(cudaGetLastError());

+ gpu[i&1].sync();
+ }

gpu[2].DtoH(&out[0], &context->d_out[0], context->num_rows);
gpu.sync();
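The loop above stages one block at a time at offset 0 of the shared device buffers, launches the kernel on alternating streams (gpu[i&1]), and lets d_out accumulate across launches; that is sound because the blocks partition the row range, so each launch writes a disjoint slice of the output. A host-side sketch of that invariant, reusing block_window, make_window and spmv_block_reference from the sketches above (illustrative, not part of this header):

// Running the per-block reference once per block fills disjoint slices of
// out; the concatenation equals a single-pass CSR multiply.
static void spmv_blocked_reference(const std::vector<size_t> &row_ptr,
                                   const std::vector<size_t> &col_idx,
                                   const std::vector<double> &data,
                                   const std::vector<size_t> &blocks,
                                   const std::vector<double> &scalars,
                                   std::vector<double> &out)
{
    for (size_t k = 0; k + 1 < blocks.size(); k++) {
        block_window w = make_window(row_ptr, blocks, k);
        spmv_block_reference(&row_ptr[w.start_row], &col_idx[w.start_data],
                             &data[w.start_data], scalars.data(), out.data(),
                             w.num_rows, w.start_row, w.start_data);
    }
}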
@@ -240,10 +297,7 @@ static RustError sparse_matrix_mul(spmvm_host_t<scalar_t> *csr, const scalar_t *
{
try
{
- size_t num_rows = csr->num_rows;
- size_t num_cols = csr->num_cols;
- size_t nnz = csr->nnz;
- spmvm_t<scalar_t> spmvm{num_rows, num_cols, nnz};
+ spmvm_t<scalar_t> spmvm{csr};
return spmvm.invoke(csr, scalars, out, nthreads);
}
catch (const cuda_error &e)
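Putting the pieces together, a caller fills a spmvm_host_t with the CSR arrays plus the new blocks/num_blocks/block_size fields and passes it to sparse_matrix_mul. A hedged call-site sketch, assuming this header is included, that fr_t stands in for the concrete scalar_t of the instantiation, and that the trailing parameters are (scalars, out, nthreads) as forwarded to spmvm_t::invoke above; build_blocks is the hypothetical helper sketched earlier:

#include <vector>

// Assumed usage, not taken from the repository's own tests.
static RustError multiply(const std::vector<fr_t> &data,
                          const std::vector<size_t> &col_idx,
                          const std::vector<size_t> &row_ptr,
                          const std::vector<fr_t> &scalars,
                          std::vector<fr_t> &out,
                          size_t block_size, size_t nthreads)
{
    size_t num_rows = row_ptr.size() - 1;
    std::vector<size_t> boundaries = build_blocks(row_ptr, num_rows, block_size);

    spmvm_host_t<fr_t> csr{};
    csr.data       = data.data();
    csr.col_idx    = col_idx.data();
    csr.row_ptr    = row_ptr.data();
    csr.blocks     = boundaries.data();
    csr.num_rows   = num_rows;
    csr.num_cols   = scalars.size();
    csr.nnz        = data.size();
    csr.num_blocks = boundaries.size();
    csr.block_size = block_size;

    out.resize(num_rows);
    return sparse_matrix_mul(&csr, scalars.data(), out.data(), nthreads);
}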
