
Commit

Hanting Zhang committed Feb 8, 2024
1 parent ed444ea commit cd04b5e
Showing 1 changed file with 38 additions and 36 deletions.
msm/pippenger.cuh: 38 additions & 36 deletions (74 changes)
@@ -318,7 +318,7 @@ template <class bucket_t, class point_t, class affine_t, class scalar_t, class a
           class bucket_h = class bucket_t::mem_t>
 class msm_t {
     const gpu_t &gpu;
-
+    // main data
     bool owned;
     affine_h *d_points;
@@ -346,7 +346,7 @@ class msm_t {
 
     public:
         result_t() {}
-        inline operator decltype(ret) &() { return ret; }
+        inline operator decltype(ret) & () { return ret; }
         inline const bucket_t *operator[](size_t i) const { return ret[i]; }
     };

@@ -359,8 +359,7 @@ class msm_t {
 
   public:
     // Initialize the MSM by moving the points to the device
-    msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1)
-        : gpu(select_gpu(device_id)) {
+    msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) : gpu(select_gpu(device_id)) {
         // set default values for fields
         this->d_points = nullptr;
         this->d_scalars = nullptr;
@@ -375,8 +374,7 @@ class msm_t {
         CUDA_OK(cudaGetLastError());
     }
 
-    msm_t(affine_h *d_points, size_t npoints, int device_id = -1)
-        : gpu(select_gpu(device_id)) {
+    msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) {
         // set default values for fields
         this->d_points = d_points;
         this->d_scalars = nullptr;
@@ -479,7 +477,7 @@ class msm_t {
         batch = batch ? batch : 1;
         stride = (nscalars + batch - 1) / batch;
         stride = (stride + WARP_SZ - 1) & ~(WARP_SZ - 1);
 
         // Allocate the memory required for each batch
         size_t scalars_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
         size_t pidx_sz = sizeof(uint32_t) * stride;
@@ -505,33 +503,49 @@ class msm_t {
 
     RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
         // assert(this->nscalars <= nscalars);
 
+        uint32_t lg_n = lg2(nscalars + nscalars / 2);
+
         wbits = 17;
         if (nscalars > 192) {
-            wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18);
+            wbits = std::min(lg_n, (uint32_t)18);
             if (wbits < 10)
                 wbits = 10;
         } else if (nscalars > 0) {
             wbits = 10;
         }
         nwins = (scalar_t::bit_length() - 1) / wbits + 1;
 
-        uint32_t row_sz = 1U << (wbits-1);
+        uint32_t row_sz = 1U << (wbits - 1);
 
-        size_t d_buckets_sz = (nwins * row_sz)
-                            + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
-        size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0]))
-                         + (nwins * row_sz * sizeof(uint32_t));
+        size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
+        d_buckets_sz *= sizeof(d_buckets[0]);
+        size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
+        size_t d_blob_sz = d_buckets_sz + d_hist_sz;
 
-        d_buckets = reinterpret_cast<decltype(d_buckets)>(gpu.Dmalloc(d_blob_sz));
-        d_hist = vec2d_t<uint32_t>(&d_buckets[d_buckets_sz], row_sz);
+        char *d_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
+        size_t offset = 0;
+        d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_blob[offset]);
+        offset += d_buckets_sz;
+        d_hist = vec2d_t<uint32_t>((uint32_t *)&d_blob[offset], row_sz);
 
-        uint32_t lg_n = lg2(nscalars + nscalars/2);
         size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
         batch >>= 6;
         batch = batch ? batch : 1;
         uint32_t stride = (nscalars + batch - 1) / batch;
-        stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ);
+        stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);
 
+        size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
+        temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz);
+
+        size_t digits_sz = nwins * stride * sizeof(uint32_t);
+
+        uint8_t *d_temp = (uint8_t *)gpu[2].Dmalloc(temp_sz + digits_sz);
+
+        vec2d_t<uint2> d_temps{(uint2 *)&d_temp[0], stride};
+        vec2d_t<uint32_t> d_digits{(uint32_t *)&d_temp[temp_sz], stride};
+
+        scalar_t *d_scalars = (scalar_t *)&d_temp[0];
+
         std::vector<result_t> res(nwins);
         std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
@@ -540,27 +554,14 @@ class msm_t {
         point_t p;
 
         try {
-            // |scalars| being nullptr means the scalars are pre-loaded to
-            // |d_scalars|, otherwise allocate stride.
-            size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
-            temp_sz = stride * std::max(2*sizeof(uint2), temp_sz);
-
-            size_t digits_sz = nwins * stride * sizeof(uint32_t);
-
-            dev_ptr_t<uint8_t> d_temp{temp_sz + digits_sz, gpu[2]};
-
-            vec2d_t<uint2> d_temps{&d_temp[0], stride};
-            vec2d_t<uint32_t> d_digits{&d_temp[temp_sz], stride};
-
-            scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0]
-                                          : this->d_scalars;
             size_t d_off = 0; // device offset
             size_t h_off = 0; // host offset
             size_t num = stride > nscalars ? nscalars : stride;
             event_t ev;
 
             gpu[2].HtoD(&d_scalars[0], &scalars[h_off], num);
-            if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);
+            if (pidx)
+                gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);
 
             digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
             gpu[2].record(ev);
@@ -587,7 +588,8 @@ class msm_t {
                 num = d_off + stride <= nscalars ? stride : nscalars - d_off;
 
                 gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num);
-                if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);
+                if (pidx)
+                    gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);
 
                 gpu[2].wait(ev);
                 digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
@@ -764,8 +766,8 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np
 
 template <class bucket_t, class point_t, class affine_t, class scalar_t, class affine_h = class affine_t::mem_t,
           class bucket_h = class bucket_t::mem_t>
-static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context,
-                                     const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) {
+static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context, const scalar_t scalars[],
+                                     size_t nscalars, uint32_t pidx[], bool mont = true) {
     try {
         msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
         // msm.setup_scratch(nscalars);
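
Notes

The hunk at @@ -505,33 +503,49 @@ hoists lg_n out of the window-size formula and drops the old "- 8" adjustment, so the window width is now lg2(1.5 * nscalars) clamped to the range [10, 18]. A minimal standalone sketch of the new selection, assuming lg2() is floor(log2(n)) like the helper used by the surrounding code:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // floor(log2(n)) for n > 0; stands in for the lg2() helper assumed above
    static uint32_t lg2(size_t n) {
        uint32_t r = 0;
        while (n >>= 1)
            r++;
        return r;
    }

    // mirrors the new selection: lg2(1.5 * nscalars) clamped to [10, 18]
    static uint32_t pick_wbits(size_t nscalars) {
        uint32_t lg_n = lg2(nscalars + nscalars / 2);
        uint32_t wbits = 17; // only survives when nscalars == 0
        if (nscalars > 192) {
            wbits = std::min(lg_n, (uint32_t)18);
            if (wbits < 10)
                wbits = 10;
        } else if (nscalars > 0) {
            wbits = 10;
        }
        return wbits;
    }

With a 255-bit scalar field this yields nwins = (255 - 1) / wbits + 1 windows, e.g. 15 windows when wbits lands at 18.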
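
The same hunk also merges the bucket storage and the histogram into one device blob carved up with a running offset, equivalent to the old two-piece layout but expressed in bytes. A sketch of the pattern with plain cudaMalloc standing in for gpu.Dmalloc; the sizes and the bucket type here are illustrative, not the real ones:

    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>

    struct bucket_t { uint64_t limbs[12]; }; // placeholder for the real bucket type

    int main() {
        // illustrative sizes: nwins windows, 2^(wbits-1) buckets per row, plus
        // the per-SM batch-addition slots that share the same allocation
        size_t nwins = 15, row_sz = (size_t)1 << 17, batch_add_slots = 864;

        size_t d_buckets_sz = (nwins * row_sz + batch_add_slots) * sizeof(bucket_t);
        size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
        size_t d_blob_sz = d_buckets_sz + d_hist_sz;

        char *d_blob = nullptr;
        if (cudaMalloc((void **)&d_blob, d_blob_sz) != cudaSuccess)
            return 1;

        size_t offset = 0;
        bucket_t *d_buckets = (bucket_t *)&d_blob[offset];
        offset += d_buckets_sz;
        uint32_t *d_hist = (uint32_t *)&d_blob[offset];

        // both views alias one allocation, so a single cudaFree releases them
        printf("buckets %p, hist %p\n", (void *)d_buckets, (void *)d_hist);
        cudaFree(d_blob);
        return 0;
    }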
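
The stride for each batch is rounded up to a multiple of WARP_SZ. The two masks that appear in the diff, ~(WARP_SZ - 1) and ((size_t)0 - WARP_SZ), are identical for a power-of-two warp size, which this small check confirms (WARP_SZ = 32 assumed, as on current NVIDIA GPUs):

    #include <cassert>
    #include <cstddef>

    constexpr size_t WARP_SZ = 32;

    int main() {
        for (size_t n = 1; n <= 4096; n++) {
            size_t a = (n + WARP_SZ - 1) & ~(WARP_SZ - 1);        // mask used at -479
            size_t b = (n + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); // mask used in invoke()
            assert(a == b && a >= n && a % WARP_SZ == 0 && a - n < WARP_SZ);
        }
        return 0;
    }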
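
Finally, the restructured invoke() keeps the existing copy/compute overlap: the next batch's scalars are staged through gpu[2] while the current batch is processed, with an event guarding reuse of the staging buffer. A bare-CUDA sketch of that pattern; the streams, events, and names below are illustrative stand-ins for the gpu_t/event_t wrappers, not the library's API:

    #include <cuda_runtime.h>
    #include <vector>

    int main() {
        const size_t stride = 1 << 16, nbatches = 4;
        std::vector<unsigned> h_scalars(stride * nbatches, 1u);

        unsigned *d_stage = nullptr; // device staging buffer, reused per batch
        cudaMalloc((void **)&d_stage, stride * sizeof(unsigned));

        cudaStream_t copy_s, exec_s;
        cudaStreamCreate(&copy_s);
        cudaStreamCreate(&exec_s);
        cudaEvent_t uploaded, consumed;
        cudaEventCreate(&uploaded);
        cudaEventCreate(&consumed); // a never-recorded event reads as complete

        for (size_t b = 0; b < nbatches; b++) {
            // don't overwrite d_stage until the previous batch was consumed
            cudaStreamWaitEvent(copy_s, consumed, 0);
            cudaMemcpyAsync(d_stage, &h_scalars[b * stride],
                            stride * sizeof(unsigned), cudaMemcpyHostToDevice, copy_s);
            cudaEventRecord(uploaded, copy_s);

            cudaStreamWaitEvent(exec_s, uploaded, 0);
            // digit-decomposition / bucket-accumulation kernels would run on exec_s
            cudaEventRecord(consumed, exec_s);
        }

        cudaDeviceSynchronize();
        cudaEventDestroy(uploaded);
        cudaEventDestroy(consumed);
        cudaStreamDestroy(copy_s);
        cudaStreamDestroy(exec_s);
        cudaFree(d_stage);
        return 0;
    }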
