Skip to content

Commit

Permalink
yes pidx
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 8, 2024
1 parent 06f92bd commit c985f76
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 72 deletions.
64 changes: 0 additions & 64 deletions .vscode/settings.json

This file was deleted.

16 changes: 8 additions & 8 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ class msm_t {
public:
// Compute various constants (stride length, window size) based on the number of scalars.
// Also allocate scratch space.
void setup_scratch(size_t nscalars) {
void setup_scratch(size_t nscalars, uint32_t *pidx) {
this->nscalars = nscalars;

uint32_t lg_n = lg2(nscalars + nscalars / 2);
Expand Down Expand Up @@ -477,9 +477,9 @@ class msm_t {

size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
size_t digits_sz = nwins * stride * sizeof(uint32_t);
size_t pidx_sz = stride * sizeof(uint32_t);
size_t pidx_sz = pidx ? stride * sizeof(uint32_t) : 0;

size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz;
size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz + pidx_sz;

d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
size_t offset = 0;
Expand All @@ -491,14 +491,14 @@ class msm_t {
d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
d_scalars = (scalar_t *)&d_total_blob[offset];
offset += temp_sz;
d_pidx = (uint32_t *)&d_total_blob[offset];
offset += pidx_sz;
d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
offset += digits_sz;
if (pidx) d_pidx = (uint32_t *)&d_total_blob[offset];
}

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);
setup_scratch(nscalars);
setup_scratch(nscalars, pidx);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand Down Expand Up @@ -702,11 +702,11 @@ static RustError mult_pippenger_init(const affine_t points[], size_t npoints, ms

template <class bucket_t, class point_t, class affine_t, class scalar_t>
static RustError mult_pippenger(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[],
size_t nscalars, bool mont = true) {
size_t nscalars, uint32_t pidx[], bool mont = true) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, true};
// msm.setup_scratch(nscalars);
return msm.invoke(*out, scalars, nscalars, nullptr, mont);
return msm.invoke(*out, scalars, nscalars, pidx, mont);
} catch (const cuda_error &e) {
out->inf();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
Expand Down

0 comments on commit c985f76

Please sign in to comment.