diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh index 0481abc..24c5f70 100644 --- a/msm/pippenger.cuh +++ b/msm/pippenger.cuh @@ -318,7 +318,7 @@ template class msm_t { const gpu_t &gpu; - + // main data bool owned; affine_h *d_points; @@ -346,7 +346,7 @@ class msm_t { public: result_t() {} - inline operator decltype(ret) &() { return ret; } + inline operator decltype(ret) & () { return ret; } inline const bucket_t *operator[](size_t i) const { return ret[i]; } }; @@ -359,8 +359,7 @@ class msm_t { public: // Initialize the MSM by moving the points to the device - msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) - : gpu(select_gpu(device_id)) { + msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) : gpu(select_gpu(device_id)) { // set default values for fields this->d_points = nullptr; this->d_scalars = nullptr; @@ -375,8 +374,7 @@ class msm_t { CUDA_OK(cudaGetLastError()); } - msm_t(affine_h *d_points, size_t npoints, int device_id = -1) - : gpu(select_gpu(device_id)) { + msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) { // set default values for fields this->d_points = d_points; this->d_scalars = nullptr; @@ -453,13 +451,57 @@ class msm_t { void setup_scratch(size_t nscalars) { this->nscalars = nscalars; - // nscalars = (nscalars + WARP_SZ - 1) & ~(WARP_SZ - 1); + // uint32_t lg_n = lg2(nscalars + nscalars / 2); + + // wbits = 17; + // if (nscalars > 192) { + // wbits = std::min(lg_n, (uint32_t)18); + // if (wbits < 10) + // wbits = 10; + // } else if (nscalars > 0) { + // wbits = 10; + // } + // nwins = (scalar_t::bit_length() - 1) / wbits + 1; + + // uint32_t row_sz = 1U << (wbits - 1); + + // size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); + // d_buckets_sz *= sizeof(d_buckets[0]); + // size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); + // size_t temp_sz = scalars ? sizeof(scalar_t) : 0; + + // size_t batch = 1 << (std::max(lg_n, wbits) - wbits); + // batch >>= 6; + // batch = batch ? batch : 1; + // uint32_t stride = (nscalars + batch - 1) / batch; + // stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); + + // temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz); + // size_t digits_sz = nwins * stride * sizeof(uint32_t); + + // size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz; + + // d_total_blob = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); + // size_t offset = 0; + // d_buckets = reinterpret_cast(&d_total_blob[offset]); + // offset += d_buckets_sz; + // d_hist = vec2d_t((uint32_t *)&d_total_blob[offset], row_sz); + // offset += d_hist_sz; + + // d_temps = vec2d_t((uint2 *)&d_total_blob[offset], stride); + // d_scalars = (scalar_t *)&d_total_blob[offset]; + // offset += temp_sz; + // d_digits = vec2d_t((uint32_t *)&d_total_blob[offset], stride); + } + + RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { + // assert(this->nscalars <= nscalars); + uint32_t lg_n = lg2(nscalars + nscalars / 2); - // Compute window size wbits = 17; if (nscalars > 192) { - wbits = std::min(lg_n - 8, (uint32_t)18); + wbits = std::min(lg_n, (uint32_t)18); if (wbits < 10) wbits = 10; } else if (nscalars > 0) { @@ -467,71 +509,35 @@ class msm_t { } nwins = (scalar_t::bit_length() - 1) / wbits + 1; - // Allocate the buckets and histogram uint32_t row_sz = 1U << (wbits - 1); + size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); - d_buckets_sz *= sizeof(bucket_h); + d_buckets_sz *= sizeof(d_buckets[0]); size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); + size_t temp_sz = sizeof(scalar_t); + temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz); - // Compute how big each batch should be - batch = 1 << (std::max(lg_n, wbits) - wbits); + size_t batch = 1 << (std::max(lg_n, wbits) - wbits); batch >>= 6; batch = batch ? batch : 1; - stride = (nscalars + batch - 1) / batch; - stride = (stride + WARP_SZ - 1) & ~(WARP_SZ - 1); - - // Allocate the memory required for each batch - size_t scalars_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); - size_t pidx_sz = sizeof(uint32_t) * stride; + uint32_t stride = (nscalars + batch - 1) / batch; + stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); + size_t digits_sz = nwins * stride * sizeof(uint32_t); - size_t d_blob_sz = d_buckets_sz + d_hist_sz + scalars_sz + pidx_sz + digits_sz; + size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz; + d_total_blob = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); size_t offset = 0; - - d_buckets = reinterpret_cast(d_total_blob); + d_buckets = reinterpret_cast(&d_total_blob[offset]); offset += d_buckets_sz; - d_hist = vec2d_t(reinterpret_cast(&d_total_blob[offset]), row_sz); + d_hist = vec2d_t((uint32_t *)&d_total_blob[offset], row_sz); offset += d_hist_sz; - d_temps = vec2d_t(reinterpret_cast(&d_total_blob[offset]), stride); - d_scalars = reinterpret_cast(&d_total_blob[offset]); - offset += scalars_sz; - d_pidx = reinterpret_cast(&d_total_blob[offset]); - offset += pidx_sz; - d_digits = vec2d_t(reinterpret_cast(&d_total_blob[offset]), stride); - offset += digits_sz; - } - - RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { - // assert(this->nscalars <= nscalars); - - wbits = 17; - if (nscalars > 192) { - wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18); - if (wbits < 10) - wbits = 10; - } else if (nscalars > 0) { - wbits = 10; - } - nwins = (scalar_t::bit_length() - 1) / wbits + 1; - - uint32_t row_sz = 1U << (wbits-1); - - size_t d_buckets_sz = (nwins * row_sz) - + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); - size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0])) - + (nwins * row_sz * sizeof(uint32_t)); - - d_buckets = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); - d_hist = vec2d_t(&d_buckets[d_buckets_sz], row_sz); - - uint32_t lg_n = lg2(nscalars + nscalars/2); - size_t batch = 1 << (std::max(lg_n, wbits) - wbits); - batch >>= 6; - batch = batch ? batch : 1; - uint32_t stride = (nscalars + batch - 1) / batch; - stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ); + d_temps = vec2d_t((uint2 *)&d_total_blob[offset], stride); + d_scalars = (scalar_t *)&d_total_blob[offset]; + offset += temp_sz; + d_digits = vec2d_t((uint32_t *)&d_total_blob[offset], stride); std::vector res(nwins); std::vector ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); @@ -540,27 +546,14 @@ class msm_t { point_t p; try { - // |scalars| being nullptr means the scalars are pre-loaded to - // |d_scalars|, otherwise allocate stride. - size_t temp_sz = scalars ? sizeof(scalar_t) : 0; - temp_sz = stride * std::max(2*sizeof(uint2), temp_sz); - - size_t digits_sz = nwins * stride * sizeof(uint32_t); - - dev_ptr_t d_temp{temp_sz + digits_sz, gpu[2]}; - - vec2d_t d_temps{&d_temp[0], stride}; - vec2d_t d_digits{&d_temp[temp_sz], stride}; - - scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0] - : this->d_scalars; size_t d_off = 0; // device offset size_t h_off = 0; // host offset size_t num = stride > nscalars ? nscalars : stride; event_t ev; gpu[2].HtoD(&d_scalars[0], &scalars[h_off], num); - if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num); + if (pidx) + gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num); digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx); gpu[2].record(ev); @@ -587,7 +580,8 @@ class msm_t { num = d_off + stride <= nscalars ? stride : nscalars - d_off; gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num); - if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num); + if (pidx) + gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num); gpu[2].wait(ev); digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx); @@ -764,8 +758,8 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np template -static RustError mult_pippenger_with(point_t *out, msm_context_t *msm_context, - const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) { +static RustError mult_pippenger_with(point_t *out, msm_context_t *msm_context, const scalar_t scalars[], + size_t nscalars, uint32_t pidx[], bool mont = true) { try { msm_t msm{msm_context->d_points, msm_context->npoints}; // msm.setup_scratch(nscalars);