diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh index 3261ecf..2911a38 100644 --- a/msm/pippenger.cuh +++ b/msm/pippenger.cuh @@ -469,17 +469,17 @@ class msm_t { d_buckets_sz *= sizeof(d_buckets[0]); size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); - size_t batch = 1 << (std::max(lg_n, wbits) - wbits); + this->batch = 1 << (std::max(lg_n, wbits) - wbits); batch >>= 6; batch = batch ? batch : 1; - uint32_t stride = (nscalars + batch - 1) / batch; + this->stride = (nscalars + batch - 1) / batch; stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); - size_t digits_sz = nwins * stride * sizeof(uint32_t); + size_t pidx_sz = stride * sizeof(uint32_t); - size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz; + size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz; d_total_blob = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); size_t offset = 0; @@ -491,53 +491,14 @@ class msm_t { d_temps = vec2d_t((uint2 *)&d_total_blob[offset], stride); d_scalars = (scalar_t *)&d_total_blob[offset]; offset += temp_sz; + d_pidx = (uint32_t *)&d_total_blob[offset]; + offset += pidx_sz; d_digits = vec2d_t((uint32_t *)&d_total_blob[offset], stride); } RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { // assert(this->nscalars <= nscalars); - - uint32_t lg_n = lg2(nscalars + nscalars / 2); - - wbits = 17; - if (nscalars > 192) { - wbits = std::min(lg_n, (uint32_t)18); - if (wbits < 10) - wbits = 10; - } else if (nscalars > 0) { - wbits = 10; - } - nwins = (scalar_t::bit_length() - 1) / wbits + 1; - - uint32_t row_sz = 1U << (wbits - 1); - - size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); - d_buckets_sz *= sizeof(d_buckets[0]); - size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); - - size_t batch = 1 << (std::max(lg_n, wbits) - wbits); - batch >>= 6; - batch = batch ? batch : 1; - uint32_t stride = (nscalars + batch - 1) / batch; - stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); - - size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); - - size_t digits_sz = nwins * stride * sizeof(uint32_t); - - size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz; - - d_total_blob = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); - size_t offset = 0; - d_buckets = reinterpret_cast(&d_total_blob[offset]); - offset += d_buckets_sz; - d_hist = vec2d_t((uint32_t *)&d_total_blob[offset], row_sz); - offset += d_hist_sz; - - d_temps = vec2d_t((uint2 *)&d_total_blob[offset], stride); - d_scalars = (scalar_t *)&d_total_blob[offset]; - offset += temp_sz; - d_digits = vec2d_t((uint32_t *)&d_total_blob[offset], stride); + setup_scratch(nscalars); std::vector res(nwins); std::vector ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);