Skip to content

Commit

Permalink
like change backl????
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 8, 2024
1 parent 663ac05 commit 06f92bd
Showing 1 changed file with 7 additions and 46 deletions.
53 changes: 7 additions & 46 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -469,17 +469,17 @@ class msm_t {
d_buckets_sz *= sizeof(d_buckets[0]);
size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);

size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
this->batch = 1 << (std::max(lg_n, wbits) - wbits);
batch >>= 6;
batch = batch ? batch : 1;
uint32_t stride = (nscalars + batch - 1) / batch;
this->stride = (nscalars + batch - 1) / batch;
stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);

size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));

size_t digits_sz = nwins * stride * sizeof(uint32_t);
size_t pidx_sz = stride * sizeof(uint32_t);

size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz;
size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz;

d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
size_t offset = 0;
Expand All @@ -491,53 +491,14 @@ class msm_t {
d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
d_scalars = (scalar_t *)&d_total_blob[offset];
offset += temp_sz;
d_pidx = (uint32_t *)&d_total_blob[offset];
offset += pidx_sz;
d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
}

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);

uint32_t lg_n = lg2(nscalars + nscalars / 2);

wbits = 17;
if (nscalars > 192) {
wbits = std::min(lg_n, (uint32_t)18);
if (wbits < 10)
wbits = 10;
} else if (nscalars > 0) {
wbits = 10;
}
nwins = (scalar_t::bit_length() - 1) / wbits + 1;

uint32_t row_sz = 1U << (wbits - 1);

size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
d_buckets_sz *= sizeof(d_buckets[0]);
size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);

size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
batch >>= 6;
batch = batch ? batch : 1;
uint32_t stride = (nscalars + batch - 1) / batch;
stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);

size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));

size_t digits_sz = nwins * stride * sizeof(uint32_t);

size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz;

d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
size_t offset = 0;
d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_total_blob[offset]);
offset += d_buckets_sz;
d_hist = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], row_sz);
offset += d_hist_sz;

d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
d_scalars = (scalar_t *)&d_total_blob[offset];
offset += temp_sz;
d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
setup_scratch(nscalars);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand Down

0 comments on commit 06f92bd

Please sign in to comment.