
Commit 32f9822

Hanting Zhang committed Feb 8, 2024
1 parent ed444ea commit 32f9822
Showing 1 changed file with 73 additions and 79 deletions.
152 changes: 73 additions & 79 deletions msm/pippenger.cuh
@@ -318,7 +318,7 @@ template <class bucket_t, class point_t, class affine_t, class scalar_t, class affine_h = class affine_t::mem_t,
class bucket_h = class bucket_t::mem_t>
class msm_t {
const gpu_t &gpu;

// main data
bool owned;
affine_h *d_points;
@@ -346,7 +346,7 @@ class msm_t {

public:
result_t() {}
inline operator decltype(ret) &() { return ret; }
inline operator decltype(ret) & () { return ret; }
inline const bucket_t *operator[](size_t i) const { return ret[i]; }
};
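
The two operator lines above are a whitespace-only reformat of result_t's implicit conversion operator. For readers unfamiliar with the idiom, here is a minimal self-contained sketch of what that wrapper does (the template parameters and sizes are illustrative assumptions, not the diff's actual dimensions):

#include <cstddef>

// Sketch of the result_t idiom: a wrapper owning a fixed 2-D array that
// converts implicitly to a reference to that array type, so the wrapper
// can be passed wherever the raw array is expected.
template <class bucket_t, size_t ROWS, size_t COLS>
struct result_sketch_t {
    bucket_t ret[ROWS][COLS];

    result_sketch_t() {}
    // Implicit conversion to a reference to the underlying array.
    inline operator decltype(ret) &() { return ret; }
    // Row access, mirroring result_t::operator[].
    inline const bucket_t *operator[](size_t i) const { return ret[i]; }
};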

@@ -359,8 +359,7 @@ class msm_t {

public:
// Initialize the MSM by moving the points to the device
msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1)
: gpu(select_gpu(device_id)) {
msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) : gpu(select_gpu(device_id)) {
// set default values for fields
this->d_points = nullptr;
this->d_scalars = nullptr;
@@ -375,8 +374,7 @@
CUDA_OK(cudaGetLastError());
}

msm_t(affine_h *d_points, size_t npoints, int device_id = -1)
: gpu(select_gpu(device_id)) {
msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) {
// set default values for fields
this->d_points = d_points;
this->d_scalars = nullptr;
@@ -453,85 +451,93 @@ class msm_t {
void setup_scratch(size_t nscalars) {
this->nscalars = nscalars;

// nscalars = (nscalars + WARP_SZ - 1) & ~(WARP_SZ - 1);
// uint32_t lg_n = lg2(nscalars + nscalars / 2);

// wbits = 17;
// if (nscalars > 192) {
// wbits = std::min(lg_n, (uint32_t)18);
// if (wbits < 10)
// wbits = 10;
// } else if (nscalars > 0) {
// wbits = 10;
// }
// nwins = (scalar_t::bit_length() - 1) / wbits + 1;

// uint32_t row_sz = 1U << (wbits - 1);

// size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
// d_buckets_sz *= sizeof(d_buckets[0]);
// size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
// size_t temp_sz = scalars ? sizeof(scalar_t) : 0;

// size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
// batch >>= 6;
// batch = batch ? batch : 1;
// uint32_t stride = (nscalars + batch - 1) / batch;
// stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);

// temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz);
// size_t digits_sz = nwins * stride * sizeof(uint32_t);

// size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz;

// d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
// size_t offset = 0;
// d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_total_blob[offset]);
// offset += d_buckets_sz;
// d_hist = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], row_sz);
// offset += d_hist_sz;

// d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
// d_scalars = (scalar_t *)&d_total_blob[offset];
// offset += temp_sz;
// d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
}

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);

uint32_t lg_n = lg2(nscalars + nscalars / 2);

// Compute window size
wbits = 17;
if (nscalars > 192) {
wbits = std::min(lg_n - 8, (uint32_t)18);
wbits = std::min(lg_n, (uint32_t)18);
if (wbits < 10)
wbits = 10;
} else if (nscalars > 0) {
wbits = 10;
}
nwins = (scalar_t::bit_length() - 1) / wbits + 1;
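
To make the window-size arithmetic above concrete, here is one worked case as a sketch; the 255-bit scalar width and the floor-log2 behavior of lg2() are assumptions, not taken from this diff:

#include <algorithm>
#include <cstdint>

// Worked example of the windowing math above, assuming lg2() behaves as
// floor(log2) and scalar_t::bit_length() == 255. For nscalars = 1 << 20:
//   lg_n  = lg2(nscalars + nscalars / 2) = lg2(1572864) = 20
//   wbits = min(lg_n, 18) = 18, already >= 10, so the clamp is a no-op
//   nwins = (255 - 1) / 18 + 1 = 15 windows covering all 255 bits
uint32_t example_nwins() {
    uint32_t lg_n = 20;
    uint32_t wbits = std::min(lg_n, (uint32_t)18); // 18
    if (wbits < 10)
        wbits = 10;
    return (255 - 1) / wbits + 1; // 15
}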

// Allocate the buckets and histogram
uint32_t row_sz = 1U << (wbits - 1);

size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
d_buckets_sz *= sizeof(bucket_h);
d_buckets_sz *= sizeof(d_buckets[0]);
size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
size_t temp_sz = scalars ? sizeof(scalar_t) : 0;

// Compute how big each batch should be
batch = 1 << (std::max(lg_n, wbits) - wbits);
size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
batch >>= 6;
batch = batch ? batch : 1;
stride = (nscalars + batch - 1) / batch;
stride = (stride + WARP_SZ - 1) & ~(WARP_SZ - 1);

// Allocate the memory required for each batch
size_t scalars_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
size_t pidx_sz = sizeof(uint32_t) * stride;
uint32_t stride = (nscalars + batch - 1) / batch;
stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);

temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz);
size_t digits_sz = nwins * stride * sizeof(uint32_t);

size_t d_blob_sz = d_buckets_sz + d_hist_sz + scalars_sz + pidx_sz + digits_sz;
size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz;

d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
size_t offset = 0;

d_buckets = reinterpret_cast<decltype(d_buckets)>(d_total_blob);
d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_total_blob[offset]);
offset += d_buckets_sz;
d_hist = vec2d_t<uint32_t>(reinterpret_cast<uint32_t *>(&d_total_blob[offset]), row_sz);
d_hist = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], row_sz);
offset += d_hist_sz;

d_temps = vec2d_t<uint2>(reinterpret_cast<uint2 *>(&d_total_blob[offset]), stride);
d_scalars = reinterpret_cast<scalar_t *>(&d_total_blob[offset]);
offset += scalars_sz;
d_pidx = reinterpret_cast<uint32_t *>(&d_total_blob[offset]);
offset += pidx_sz;
d_digits = vec2d_t<uint32_t>(reinterpret_cast<uint32_t *>(&d_total_blob[offset]), stride);
offset += digits_sz;
}
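
The allocation above follows a single-blob pattern: every sub-buffer size is summed, Dmalloc is called once, and typed views are carved out of the blob at a running offset. A host-side sketch of the same pattern, with malloc standing in for gpu.Dmalloc and illustrative buffer names:

#include <cstdint>
#include <cstdlib>

// One allocation, then typed pointers handed out at running offsets.
// The three buffers are placeholders for d_buckets / d_hist / d_digits.
struct views_t {
    uint64_t *buckets;
    uint32_t *hist;
    uint32_t *digits;
};

views_t carve(size_t buckets_sz, size_t hist_sz, size_t digits_sz) {
    char *blob = (char *)std::malloc(buckets_sz + hist_sz + digits_sz);
    size_t offset = 0;
    views_t v;
    v.buckets = (uint64_t *)&blob[offset]; offset += buckets_sz;
    v.hist    = (uint32_t *)&blob[offset]; offset += hist_sz;
    v.digits  = (uint32_t *)&blob[offset]; // last view ends the blob
    return v;
}

A single device allocation keeps the Dmalloc/Dfree count, and the synchronization those calls imply, constant per invocation; placing the most strictly aligned buffer first helps keep every carved view suitably aligned.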

RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
// assert(this->nscalars <= nscalars);

wbits = 17;
if (nscalars > 192) {
wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18);
if (wbits < 10)
wbits = 10;
} else if (nscalars > 0) {
wbits = 10;
}
nwins = (scalar_t::bit_length() - 1) / wbits + 1;

uint32_t row_sz = 1U << (wbits-1);

size_t d_buckets_sz = (nwins * row_sz)
+ (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0]))
+ (nwins * row_sz * sizeof(uint32_t));

d_buckets = reinterpret_cast<decltype(d_buckets)>(gpu.Dmalloc(d_blob_sz));
d_hist = vec2d_t<uint32_t>(&d_buckets[d_buckets_sz], row_sz);

uint32_t lg_n = lg2(nscalars + nscalars/2);
size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
batch >>= 6;
batch = batch ? batch : 1;
uint32_t stride = (nscalars + batch - 1) / batch;
stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ);
d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
d_scalars = (scalar_t *)&d_total_blob[offset];
offset += temp_sz;
d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
@@ -540,27 +546,14 @@ class msm_t {
point_t p;

try {
// |scalars| being nullptr means the scalars are pre-loaded to
// |d_scalars|, otherwise allocate stride.
size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
temp_sz = stride * std::max(2*sizeof(uint2), temp_sz);

size_t digits_sz = nwins * stride * sizeof(uint32_t);

dev_ptr_t<uint8_t> d_temp{temp_sz + digits_sz, gpu[2]};

vec2d_t<uint2> d_temps{&d_temp[0], stride};
vec2d_t<uint32_t> d_digits{&d_temp[temp_sz], stride};

scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0]
: this->d_scalars;
size_t d_off = 0; // device offset
size_t h_off = 0; // host offset
size_t num = stride > nscalars ? nscalars : stride;
event_t ev;

gpu[2].HtoD(&d_scalars[0], &scalars[h_off], num);
if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);
if (pidx)
gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);

digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
gpu[2].record(ev);
@@ -587,7 +580,8 @@ class msm_t {
num = d_off + stride <= nscalars ? stride : nscalars - d_off;

gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num);
if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);
if (pidx)
gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);

gpu[2].wait(ev);
digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
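
The loop above overlaps transfer and compute: while the kernels for batch i run, the scalars for batch i+1 are staged with HtoD on the auxiliary stream gpu[2], and an event keeps the two ordered. A plain-CUDA sketch of that pattern under stated assumptions (two streams, ping-pong buffers, and a placeholder process kernel; none of these names are sppark's):

#include <cuda_runtime.h>

// process() stands in for the per-batch digit-decomposition and
// bucket-accumulation kernels.
__global__ void process(const char *buf, size_t n) { /* ... */ }

void pipeline(const char *host, size_t stride, size_t total) {
    char *dev[2];
    cudaMalloc(&dev[0], stride);
    cudaMalloc(&dev[1], stride);
    cudaStream_t copy_s, work_s;
    cudaStreamCreate(&copy_s);
    cudaStreamCreate(&work_s);
    cudaEvent_t copied, done[2];
    cudaEventCreate(&copied);
    cudaEventCreate(&done[0]);
    cudaEventCreate(&done[1]);

    int cur = 0;
    for (size_t off = 0; off < total; off += stride, cur ^= 1) {
        size_t num = off + stride <= total ? stride : total - off;
        // Don't overwrite a buffer a previous kernel may still be reading;
        // waiting on a never-recorded event is a documented no-op.
        cudaStreamWaitEvent(copy_s, done[cur], 0);
        cudaMemcpyAsync(dev[cur], host + off, num, cudaMemcpyHostToDevice, copy_s);
        cudaEventRecord(copied, copy_s);
        // The work stream starts only once this batch's copy has landed.
        cudaStreamWaitEvent(work_s, copied, 0);
        process<<<256, 256, 0, work_s>>>(dev[cur], num);
        cudaEventRecord(done[cur], work_s);
    }
    cudaStreamSynchronize(work_s);
    cudaFree(dev[0]);
    cudaFree(dev[1]);
}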
@@ -764,8 +758,8 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t npoints,

template <class bucket_t, class point_t, class affine_t, class scalar_t, class affine_h = class affine_t::mem_t,
class bucket_h = class bucket_t::mem_t>
static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context,
const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) {
static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context, const scalar_t scalars[],
size_t nscalars, uint32_t pidx[], bool mont = true) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
// msm.setup_scratch(nscalars);
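
For orientation, this is roughly how the wrapper above might be driven from a call site; the concrete curve types and the error-checking convention are assumptions about the surrounding sppark-style codebase, not guarantees from this diff:

// Hypothetical call site (a sketch, not a documented API). bucket_t,
// point_t, affine_t, scalar_t stand for one curve's concrete types; ctx
// is assumed to already hold device points, as the msm_context->d_points
// and msm_context->npoints usage above implies.
msm_context_t<affine_t::mem_t> ctx; // filled by the points-upload path
point_t result;
RustError err = mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(
    &result, &ctx, scalars, nscalars, /*pidx=*/nullptr, /*mont=*/true);
// A nonzero error code is assumed to signal the failure paths visible in
// the try/catch blocks above.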
