Skip to content

Commit

Permalink
sigh
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 8, 2024
1 parent 574cc81 commit ed444ea
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,34 @@ class msm_t {
}

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
assert(this->nscalars <= nscalars);
// assert(this->nscalars <= nscalars);

wbits = 17;
if (nscalars > 192) {
wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18);
if (wbits < 10)
wbits = 10;
} else if (nscalars > 0) {
wbits = 10;
}
nwins = (scalar_t::bit_length() - 1) / wbits + 1;

uint32_t row_sz = 1U << (wbits-1);

size_t d_buckets_sz = (nwins * row_sz)
+ (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0]))
+ (nwins * row_sz * sizeof(uint32_t));

d_buckets = reinterpret_cast<decltype(d_buckets)>(gpu.Dmalloc(d_blob_sz));
d_hist = vec2d_t<uint32_t>(&d_buckets[d_buckets_sz], row_sz);

uint32_t lg_n = lg2(nscalars + nscalars/2);
size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
batch >>= 6;
batch = batch ? batch : 1;
uint32_t stride = (nscalars + batch - 1) / batch;
stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand All @@ -513,6 +540,20 @@ class msm_t {
point_t p;

try {
// |scalars| being nullptr means the scalars are pre-loaded to
// |d_scalars|, otherwise allocate stride.
size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
temp_sz = stride * std::max(2*sizeof(uint2), temp_sz);

size_t digits_sz = nwins * stride * sizeof(uint32_t);

dev_ptr_t<uint8_t> d_temp{temp_sz + digits_sz, gpu[2]};

vec2d_t<uint2> d_temps{&d_temp[0], stride};
vec2d_t<uint32_t> d_digits{&d_temp[temp_sz], stride};

scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0]
: this->d_scalars;
size_t d_off = 0; // device offset
size_t h_off = 0; // host offset
size_t num = stride > nscalars ? nscalars : stride;
Expand Down Expand Up @@ -709,7 +750,7 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np
size_t nscalars, bool mont = true) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, true};
msm.setup_scratch(nscalars);
// msm.setup_scratch(nscalars);
return msm.invoke(*out, scalars, nscalars, nullptr, mont);
} catch (const cuda_error &e) {
out->inf();
Expand All @@ -727,7 +768,7 @@ static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_
const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
msm.setup_scratch(nscalars);
// msm.setup_scratch(nscalars);
return msm.invoke(*out, scalars, nscalars, pidx, mont);
} catch (const cuda_error &e) {
out->inf();
Expand Down

0 comments on commit ed444ea

Please sign in to comment.