diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh index 841f4f7..0481abc 100644 --- a/msm/pippenger.cuh +++ b/msm/pippenger.cuh @@ -504,7 +504,34 @@ class msm_t { } RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { - assert(this->nscalars <= nscalars); + // assert(this->nscalars <= nscalars); + + wbits = 17; + if (nscalars > 192) { + wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18); + if (wbits < 10) + wbits = 10; + } else if (nscalars > 0) { + wbits = 10; + } + nwins = (scalar_t::bit_length() - 1) / wbits + 1; + + uint32_t row_sz = 1U << (wbits-1); + + size_t d_buckets_sz = (nwins * row_sz) + + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); + size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0])) + + (nwins * row_sz * sizeof(uint32_t)); + + d_buckets = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); + d_hist = vec2d_t(&d_buckets[d_buckets_sz], row_sz); + + uint32_t lg_n = lg2(nscalars + nscalars/2); + size_t batch = 1 << (std::max(lg_n, wbits) - wbits); + batch >>= 6; + batch = batch ? batch : 1; + uint32_t stride = (nscalars + batch - 1) / batch; + stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ); std::vector res(nwins); std::vector ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); @@ -513,6 +540,20 @@ class msm_t { point_t p; try { + // |scalars| being nullptr means the scalars are pre-loaded to + // |d_scalars|, otherwise allocate stride. + size_t temp_sz = scalars ? sizeof(scalar_t) : 0; + temp_sz = stride * std::max(2*sizeof(uint2), temp_sz); + + size_t digits_sz = nwins * stride * sizeof(uint32_t); + + dev_ptr_t d_temp{temp_sz + digits_sz, gpu[2]}; + + vec2d_t d_temps{&d_temp[0], stride}; + vec2d_t d_digits{&d_temp[temp_sz], stride}; + + scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0] + : this->d_scalars; size_t d_off = 0; // device offset size_t h_off = 0; // host offset size_t num = stride > nscalars ? nscalars : stride; @@ -709,7 +750,7 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np size_t nscalars, bool mont = true) { try { msm_t msm{points, npoints, true}; - msm.setup_scratch(nscalars); + // msm.setup_scratch(nscalars); return msm.invoke(*out, scalars, nscalars, nullptr, mont); } catch (const cuda_error &e) { out->inf(); @@ -727,7 +768,7 @@ static RustError mult_pippenger_with(point_t *out, msm_context_t *msm_ const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) { try { msm_t msm{msm_context->d_points, msm_context->npoints}; - msm.setup_scratch(nscalars); + // msm.setup_scratch(nscalars); return msm.invoke(*out, scalars, nscalars, pidx, mont); } catch (const cuda_error &e) { out->inf();