Skip to content

Commit

Permalink
no pidx
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 8, 2024
1 parent 06f92bd commit c8720f3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ class msm_t {
public:
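// Compute various constants (stride length, window size) based on the number of scalars.
// Also allocate scratch space; the `pidx` portion of the device blob (`pidx_sz`) is
// reserved only when a non-null `pidx` is passed, and is zero-sized otherwise.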
void setup_scratch(size_t nscalars) {
void setup_scratch(size_t nscalars, uint32_t *pidx) {
this->nscalars = nscalars;

uint32_t lg_n = lg2(nscalars + nscalars / 2);
Expand Down Expand Up @@ -477,7 +477,7 @@ class msm_t {

size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
size_t digits_sz = nwins * stride * sizeof(uint32_t);
size_t pidx_sz = stride * sizeof(uint32_t);
size_t pidx_sz = pidx ? stride * sizeof(uint32_t) : 0;

size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz;

Expand All @@ -498,7 +498,7 @@ class msm_t {

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);
setup_scratch(nscalars);
setup_scratch(nscalars, pidx);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand Down

0 comments on commit c8720f3

Please sign in to comment.