Skip to content

Commit

Permalink
fix slowdown
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 9, 2024
1 parent c985f76 commit 489f0dd
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,11 @@ class msm_t {

this->d_total_blob = nullptr;

d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
CUDA_OK(cudaGetLastError());
if (points) {
d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
CUDA_OK(cudaGetLastError());
}
}

msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) {
Expand Down Expand Up @@ -448,14 +450,29 @@ class msm_t {
public:
// Compute various constants (stride length, window size) based on the number of scalars.
// Also allocate scratch space.
void setup_scratch(size_t nscalars, uint32_t *pidx) {
void setup_scratch(affine_t *&points, size_t npoints, size_t nscalars, uint32_t *pidx) {
this->npoints = npoints;
this->nscalars = nscalars;

// if pidx is nullptr, then we nust have npoints == nscalars
if (points && !pidx)
assert(npoints == nscalars);
else if (!points && pidx)
assert(d_points);
else if (points && pidx) {
// if both are not null, then we move all the points onto the GPU at once,
// at a performance penalty
points = nullptr;
d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
CUDA_OK(cudaGetLastError());
}

uint32_t lg_n = lg2(nscalars + nscalars / 2);

wbits = 17;
if (nscalars > 192) {
wbits = std::min(lg_n, (uint32_t)18);
wbits = std::min(lg_n - 8, (uint32_t)18);
if (wbits < 10)
wbits = 10;
} else if (nscalars > 0) {
Expand All @@ -475,11 +492,17 @@ class msm_t {
this->stride = (nscalars + batch - 1) / batch;
stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);

// scratch space for scalars and sorting indexes
size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
size_t digits_sz = nwins * stride * sizeof(uint32_t);
// scratch space for either pidx or points
// the logic is that if pidx is nullptr, then we should load the points
// stride by stride
size_t pidx_sz = pidx ? stride * sizeof(uint32_t) : 0;
size_t points_sz = points ? (batch > 1 ? 2 * stride : stride) : 0;
points_sz *= sizeof(affine_h);

size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz + pidx_sz;
size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz + pidx_sz + points_sz;

d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
size_t offset = 0;
Expand All @@ -493,12 +516,16 @@ class msm_t {
offset += temp_sz;
d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
offset += digits_sz;
if (pidx) d_pidx = (uint32_t *)&d_total_blob[offset];
if (pidx)
d_pidx = (uint32_t *)&d_total_blob[offset];
if (points)
d_points = (affine_h *)&d_total_blob[offset];
}

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);
setup_scratch(nscalars, pidx);
RustError invoke(point_t &out, const affine_t points_[], size_t npoints, const scalar_t *scalars, size_t nscalars,
uint32_t pidx[], bool mont = true) {
affine_t *points = points_;
setup_scratch(points, npoints, nscalars, pidx);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand All @@ -519,6 +546,9 @@ class msm_t {
digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
gpu[2].record(ev);

if (points)
gpu[0].HtoD(&d_points[0], &points[h_off], num, sizeof(affine_h));

for (uint32_t i = 0; i < batch; i++) {
gpu[i & 1].wait(ev);

Expand All @@ -537,8 +567,8 @@ class msm_t {
CUDA_OK(cudaGetLastError());

if (i < batch - 1) {
d_off += stride;
num = d_off + stride <= nscalars ? stride : nscalars - d_off;
h_off += stride;
num = h_off + stride <= npoints ? stride : npoints - h_off;

gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num);
if (pidx)
Expand All @@ -547,6 +577,14 @@ class msm_t {
gpu[2].wait(ev);
digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
gpu[2].record(ev);

if (points) {
size_t j = (i + 1) & 1;
d_off = j ? stride : 0;
gpu[j].HtoD(&d_points[d_off], &points[h_off], num, sizeof(affine_h));
} else {
d_off = h_off;
}
}

if (i > 0) {
Expand Down Expand Up @@ -704,9 +742,9 @@ template <class bucket_t, class point_t, class affine_t, class scalar_t>
static RustError mult_pippenger(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[],
size_t nscalars, uint32_t pidx[], bool mont = true) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, true};
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{nullptr, npoints, false};
// msm.setup_scratch(nscalars);
return msm.invoke(*out, scalars, nscalars, pidx, mont);
return msm.invoke(*out, points, npoints, scalars, nscalars, pidx, mont);
} catch (const cuda_error &e) {
out->inf();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
Expand All @@ -724,7 +762,7 @@ static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
// msm.setup_scratch(nscalars);
return msm.invoke(*out, scalars, nscalars, pidx, mont);
return msm.invoke(*out, nullptr, msm_context->npoints, scalars, nscalars, pidx, mont);
} catch (const cuda_error &e) {
out->inf();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
Expand Down

0 comments on commit 489f0dd

Please sign in to comment.