From 489f0ddcf696b727b3ae2c0e4ac1f6abfc405382 Mon Sep 17 00:00:00 2001
From: Hanting Zhang
Date: Fri, 9 Feb 2024 19:26:44 +0000
Subject: [PATCH] fix slowdown

---
 msm/pippenger.cuh | 68 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 53 insertions(+), 15 deletions(-)

diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh
index f73155e..9bc0563 100644
--- a/msm/pippenger.cuh
+++ b/msm/pippenger.cuh
@@ -369,9 +369,11 @@ class msm_t {
 
         this->d_total_blob = nullptr;
 
-        d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
-        gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
-        CUDA_OK(cudaGetLastError());
+        if (points) {
+            d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
+            gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
+            CUDA_OK(cudaGetLastError());
+        }
     }
 
     msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) {
@@ -448,14 +450,29 @@ class msm_t {
   public:
     // Compute various constants (stride length, window size) based on the number of scalars.
     // Also allocate scratch space.
-    void setup_scratch(size_t nscalars, uint32_t *pidx) {
+    void setup_scratch(affine_t *&points, size_t npoints, size_t nscalars, uint32_t *pidx) {
+        this->npoints = npoints;
         this->nscalars = nscalars;
 
+        // if points are given without pidx, we must have npoints == nscalars
+        if (points && !pidx)
+            assert(npoints == nscalars);
+        else if (!points && pidx)
+            assert(d_points);
+        else if (points && pidx) {
+            // if both are non-null, move all the points onto the GPU at once, at a
+            // performance penalty; nulling `points` afterwards disables the streaming path
+            d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
+            gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
+            CUDA_OK(cudaGetLastError());
+            points = nullptr;
+        }
+
         uint32_t lg_n = lg2(nscalars + nscalars / 2);
 
         wbits = 17;
         if (nscalars > 192) {
-            wbits = std::min(lg_n, (uint32_t)18);
+            wbits = std::min(lg_n - 8, (uint32_t)18);
             if (wbits < 10)
                 wbits = 10;
         } else if (nscalars > 0) {
@@ -475,11 +492,17 @@ class msm_t {
         this->stride = (nscalars + batch - 1) / batch;
         stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);
 
+        // scratch space for scalars and sorting indexes
         size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
         size_t digits_sz = nwins * stride * sizeof(uint32_t);
+        // scratch space for either pidx or points: if points is non-null (pidx is
+        // nullptr by now), the points are streamed onto the GPU stride by stride,
+        // double-buffered across two halves when there is more than one batch
         size_t pidx_sz = pidx ? stride * sizeof(uint32_t) : 0;
+        size_t points_sz = points ? (batch > 1 ? 2 * stride : stride) : 0;
+        points_sz *= sizeof(affine_h);
 
-        size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz + pidx_sz;
+        size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz + pidx_sz + points_sz;
 
         d_total_blob = reinterpret_cast<decltype(d_total_blob)>(gpu.Dmalloc(d_blob_sz));
         size_t offset = 0;
@@ -493,12 +516,16 @@ class msm_t {
         offset += temp_sz;
         d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
         offset += digits_sz;
-        if (pidx) d_pidx = (uint32_t *)&d_total_blob[offset];
+        if (pidx)
+            d_pidx = (uint32_t *)&d_total_blob[offset];
+        if (points)
+            d_points = (affine_h *)&d_total_blob[offset];
     }
 
-    RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
-        // assert(this->nscalars <= nscalars);
-        setup_scratch(nscalars, pidx);
+    RustError invoke(point_t &out, const affine_t points_[], size_t npoints, const scalar_t *scalars, size_t nscalars,
+                     uint32_t pidx[], bool mont = true) {
+        affine_t *points = const_cast<affine_t *>(points_);
+        setup_scratch(points, npoints, nscalars, pidx);
 
         std::vector<result_t> res(nwins);
         std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
@@ -519,6 +546,9 @@ class msm_t {
             digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
             gpu[2].record(ev);
 
+            if (points)
+                gpu[0].HtoD(&d_points[0], &points[h_off], num, sizeof(affine_h));
+
             for (uint32_t i = 0; i < batch; i++) {
                 gpu[i & 1].wait(ev);
@@ -537,8 +567,8 @@ class msm_t {
                 CUDA_OK(cudaGetLastError());
 
                 if (i < batch - 1) {
-                    d_off += stride;
-                    num = d_off + stride <= nscalars ? stride : nscalars - d_off;
+                    h_off += stride;
+                    num = h_off + stride <= nscalars ? stride : nscalars - h_off;
 
-                    gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num);
+                    gpu[2].HtoD(&d_scalars[0], &scalars[h_off], num);
                     if (pidx)
@@ -547,6 +577,14 @@ class msm_t {
                     gpu[2].wait(ev);
                     digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
                     gpu[2].record(ev);
+
+                    if (points) {
+                        size_t j = (i + 1) & 1;
+                        d_off = j ? stride : 0;
+                        gpu[j].HtoD(&d_points[d_off], &points[h_off], num, sizeof(affine_h));
+                    } else {
+                        d_off = h_off;
+                    }
                 }
 
                 if (i > 0) {
@@ -704,9 +742,9 @@ template <class bucket_t, class point_t, class affine_t, class scalar_t>
 static RustError mult_pippenger(point_t *out, const affine_t points[], size_t npoints,
                                 const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) {
     try {
-        msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, true};
+        msm_t<bucket_t, point_t, affine_t, scalar_t> msm{nullptr, npoints, false};
         // msm.setup_scratch(nscalars);
-        return msm.invoke(*out, scalars, nscalars, pidx, mont);
+        return msm.invoke(*out, points, npoints, scalars, nscalars, pidx, mont);
     } catch (const cuda_error &e) {
         out->inf();
 #ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
@@ -724,7 +762,7 @@ static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_
     try {
         msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
        // msm.setup_scratch(nscalars);
-        return msm.invoke(*out, scalars, nscalars, pidx, mont);
+        return msm.invoke(*out, nullptr, msm_context->npoints, scalars, nscalars, pidx, mont);
     } catch (const cuda_error &e) {
         out->inf();
 #ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
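
The core of this patch is the double-buffered host-to-device point streaming in invoke(): while batch i is consumed out of one stride-sized half of d_points, the points for batch i+1 are copied into the other half (d_off = j ? stride : 0) on the alternate stream. Below is a minimal sketch of the same technique in plain CUDA runtime terms, independent of the gpu_t/stream wrappers used above; Point, process_batch, and stream_points are hypothetical stand-ins, not names from this codebase, and the 256-thread launch is an arbitrary choice.

    #include <cuda_runtime.h>
    #include <cstddef>

    // Hypothetical stand-in for affine_h; only its size matters here.
    struct Point { unsigned long long x[4], y[4]; };

    // Hypothetical stand-in for the per-batch bucket-accumulation kernel.
    __global__ void process_batch(const Point *pts, size_t n)
    {
        size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (i < n)
            (void)pts[i]; // real code would accumulate pts[i] into buckets
    }

    void stream_points(const Point *h_points, size_t npoints, size_t stride)
    {
        // Two stride-sized halves, mirroring "d_off = j ? stride : 0" above.
        Point *d_points;
        cudaMalloc(&d_points, 2 * stride * sizeof(Point));

        cudaStream_t s[2];
        cudaStreamCreate(&s[0]);
        cudaStreamCreate(&s[1]);

        size_t batch = (npoints + stride - 1) / stride;

        // Prime the pipeline: batch 0 goes into half 0 before the loop.
        size_t num = stride <= npoints ? stride : npoints;
        cudaMemcpyAsync(&d_points[0], &h_points[0], num * sizeof(Point),
                        cudaMemcpyHostToDevice, s[0]);

        for (size_t i = 0; i < batch; i++) {
            size_t j = i & 1;              // half of the buffer holding batch i
            size_t h_off = i * stride;     // host offset of batch i
            num = h_off + stride <= npoints ? stride : npoints - h_off;

            // Issued on the same stream as the copy that filled half j, so the
            // kernel is ordered after its own data transfer.
            unsigned blocks = (unsigned)((num + 255) / 256);
            process_batch<<<blocks, 256, 0, s[j]>>>(&d_points[j ? stride : 0], num);

            if (i + 1 < batch) {
                // Overlap: while batch i computes on s[j], copy batch i+1 into
                // the other half on the other stream. Per-stream ordering also
                // makes this copy wait for the kernel of batch i-1, which was
                // the last user of that half.
                size_t nj = (i + 1) & 1;
                size_t noff = (i + 1) * stride;
                size_t nnum = noff + stride <= npoints ? stride : npoints - noff;
                cudaMemcpyAsync(&d_points[nj ? stride : 0], &h_points[noff],
                                nnum * sizeof(Point), cudaMemcpyHostToDevice, s[nj]);
            }
        }

        cudaStreamSynchronize(s[0]);
        cudaStreamSynchronize(s[1]);
        cudaStreamDestroy(s[0]);
        cudaStreamDestroy(s[1]);
        cudaFree(d_points);
    }

Note that genuine copy/compute overlap requires the host buffer to be page-locked (cudaMallocHost or cudaHostRegister); with pageable memory the scheme stays correct but the transfers lose most of their asynchrony.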