From f57003479af73c70ccd27c84ba84fbb14a130b2f Mon Sep 17 00:00:00 2001
From: Seunghwa Kang
Date: Wed, 6 Nov 2024 15:21:07 -0800
Subject: [PATCH] misc updates

---
 cpp/src/prims/detail/prim_utils.cuh | 104 ++++++++++++++++++++++++++++
 cpp/tests/c_api/mg_test_utils.cpp   |   9 ++-
 cpp/tests/c_api/mg_test_utils.h     |  12 ++++
 3 files changed, 124 insertions(+), 1 deletion(-)
 create mode 100644 cpp/src/prims/detail/prim_utils.cuh

diff --git a/cpp/src/prims/detail/prim_utils.cuh b/cpp/src/prims/detail/prim_utils.cuh
new file mode 100644
index 00000000000..3d8f5626042
--- /dev/null
+++ b/cpp/src/prims/detail/prim_utils.cuh
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+
+namespace cugraph {
+
+namespace detail {
+
+template <typename priority_t, typename vertex_t>
+__host__ __device__ priority_t
+rank_to_priority(int rank,
+                 int root,
+                 int subgroup_size /* faster interconnect within a subgroup */,
+                 int comm_size,
+                 vertex_t offset /* to evenly distribute traffic */)
+{
+  static_assert(sizeof(priority_t) == 1 || sizeof(priority_t) == 2 || sizeof(priority_t) == 4);
+  using cast_t = std::conditional_t<
+    sizeof(priority_t) == 1,
+    int16_t,
+    std::conditional_t<sizeof(priority_t) == 2, int32_t, int64_t>>;  // to prevent overflow
+
+  if (rank == root) {
+    return priority_t{0};
+  } else if (rank / subgroup_size ==
+             root / subgroup_size) {  // intra-subgroup communication is sufficient (priorities in
+                                      // [1, subgroup_size))
+    auto rank_dist =
+      static_cast<int>(((static_cast<cast_t>(rank) + subgroup_size) - root) % subgroup_size);
+    int modulo = subgroup_size - 1;
+    return static_cast<priority_t>(1 + (static_cast<cast_t>(rank_dist - 1) + (offset % modulo)) %
+                                         modulo);
+  } else {  // inter-subgroup communication is necessary (priorities in [subgroup_size, comm_size))
+    auto subgroup_dist =
+      static_cast<int>(((static_cast<cast_t>(rank / subgroup_size) + (comm_size / subgroup_size)) -
+                        (root / subgroup_size)) %
+                       (comm_size / subgroup_size));
+    auto intra_subgroup_rank_dist = static_cast<int>(
+      ((static_cast<cast_t>(rank % subgroup_size) + subgroup_size) - (root % subgroup_size)) %
+      subgroup_size);
+    auto rank_dist = subgroup_dist * subgroup_size + intra_subgroup_rank_dist;
+    int modulo     = comm_size - subgroup_size;
+    return static_cast<priority_t>(
+      subgroup_size +
+      (static_cast<cast_t>(rank_dist - subgroup_size) + (offset % modulo)) % modulo);
+  }
+}
+
+template <typename priority_t, typename vertex_t>
+__host__ __device__ int priority_to_rank(
+  priority_t priority,
+  int root,
+  int subgroup_size /* faster interconnect within a subgroup */,
+  int comm_size,
+  vertex_t offset /* to evenly distribute traffic */)
+{
+  static_assert(sizeof(priority_t) == 1 || sizeof(priority_t) == 2 || sizeof(priority_t) == 4);
+  using cast_t = std::conditional_t<
+    sizeof(priority_t) == 1,
+    int16_t,
+    std::conditional_t<sizeof(priority_t) == 2, int32_t, int64_t>>;  // to prevent overflow
+
+  if (priority == priority_t{0}) {
+    return root;
+  } else if (priority < static_cast<priority_t>(subgroup_size)) {
+    int modulo     = subgroup_size - 1;
+    auto rank_dist = static_cast<int>(
+      1 + ((static_cast<cast_t>(priority - 1) + modulo) - (offset % modulo)) % modulo);
+    return static_cast<int>((root - (root % subgroup_size)) +
+                            ((static_cast<cast_t>(root) + rank_dist) % subgroup_size));
+  } else {
+    int modulo     = comm_size - subgroup_size;
+    auto rank_dist = static_cast<int>(
+      subgroup_size +
+      ((static_cast<cast_t>(priority) - subgroup_size) + (modulo - (offset % modulo))) % modulo);
+    auto subgroup_dist            = rank_dist / subgroup_size;
+    auto intra_subgroup_rank_dist = rank_dist % subgroup_size;
+    return static_cast<int>(
+      ((static_cast<cast_t>((root / subgroup_size) * subgroup_size) +
+        subgroup_dist * subgroup_size) +
+       (static_cast<cast_t>(root) + intra_subgroup_rank_dist) % subgroup_size) %
+      comm_size);
+  }
+}
+
+} // namespace detail
+
+} // namespace cugraph
diff --git a/cpp/tests/c_api/mg_test_utils.cpp b/cpp/tests/c_api/mg_test_utils.cpp
index 58c5e59c16f..18807b00a6b 100644
--- a/cpp/tests/c_api/mg_test_utils.cpp
+++ b/cpp/tests/c_api/mg_test_utils.cpp
@@ -95,9 +95,16 @@ extern "C" void* create_mg_raft_handle(int argc, char** argv)
   C_MPI_TRY(MPI_Comm_size(MPI_COMM_WORLD, &comm_size));
   C_CUDA_TRY(cudaGetDeviceCount(&num_gpus_per_node));
   C_CUDA_TRY(cudaSetDevice(comm_rank % num_gpus_per_node));
+  ncclUniqueId id{};
+  if (comm_rank == 0) {
+    C_NCCL_TRY(ncclGetUniqueId(&id));
+  }
+  C_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
+  ncclComm_t nccl_comm{};
+  C_NCCL_TRY(ncclCommInitRank(&nccl_comm, comm_size, id, comm_rank));
 
   raft::handle_t* handle = new raft::handle_t{};
-  raft::comms::initialize_mpi_comms(handle, MPI_COMM_WORLD);
+  raft::comms::initialize_mpi_comms(handle, MPI_COMM_WORLD, nccl_comm);
 
 #if 1
   int gpu_row_comm_size = 1;
diff --git a/cpp/tests/c_api/mg_test_utils.h b/cpp/tests/c_api/mg_test_utils.h
index 7461d402b5b..a79c74675d2 100644
--- a/cpp/tests/c_api/mg_test_utils.h
+++ b/cpp/tests/c_api/mg_test_utils.h
@@ -36,6 +36,18 @@
     }                                                     \
   } while (0)
 
+#define C_NCCL_TRY(call)                                  \
+  do {                                                    \
+    ncclResult_t status = call;                           \
+    if (ncclSuccess != status) {                          \
+      printf("NCCL call='%s' at file=%s line=%d failed.", \
+             #call,                                       \
+             __FILE__,                                    \
+             __LINE__);                                   \
+      exit(1);                                            \
+    }                                                     \
+  } while (0)
+
 #define C_CUDA_TRY(call)                                  \
   do {                                                    \
     cudaError_t const status = call;                      \
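The two helpers in prim_utils.cuh are intended to be inverses of each other: rank_to_priority gives the root priority 0, ranks in the root's subgroup priorities in [1, subgroup_size), and all other ranks priorities in [subgroup_size, comm_size), while priority_to_rank maps a priority back to the rank that produced it; offset only rotates the assignment so repeated calls spread traffic evenly. The sketch below is not part of the patch and checks that round trip on the host. It assumes comm_size is a multiple of subgroup_size with 1 < subgroup_size < comm_size, an 8-bit priority type, the template-parameter order <priority_t, vertex_t> as reconstructed above, and that the new header is reachable as <prims/detail/prim_utils.cuh> when compiled as CUDA C++ (the helpers are __host__ __device__).

// round_trip_check.cu (hypothetical file name, illustrative only; not part of the patch)
#include <prims/detail/prim_utils.cuh>

#include <cassert>
#include <cstdint>
#include <cstdio>

int main()
{
  constexpr int comm_size     = 16;  // assumed to be a multiple of subgroup_size
  constexpr int subgroup_size = 4;   // e.g. GPUs sharing a faster interconnect

  for (int root = 0; root < comm_size; ++root) {
    for (std::int32_t offset = 0; offset < 8; ++offset) {  // offset only rotates the assignment
      for (int rank = 0; rank < comm_size; ++rank) {
        auto priority = cugraph::detail::rank_to_priority<std::uint8_t, std::int32_t>(
          rank, root, subgroup_size, comm_size, offset);
        if (rank == root) { assert(priority == std::uint8_t{0}); }  // the root always wins
        auto recovered = cugraph::detail::priority_to_rank<std::uint8_t, std::int32_t>(
          priority, root, subgroup_size, comm_size, offset);
        assert(recovered == rank);  // priority_to_rank inverts rank_to_priority
      }
    }
  }
  std::printf("rank <-> priority round trip OK\n");
  return 0;
}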
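The change to create_mg_raft_handle follows the usual single-process-group NCCL bootstrap: rank 0 creates an ncclUniqueId, every rank receives the same id via MPI_Bcast, and each rank then calls ncclCommInitRank before handing the communicator to raft. A self-contained sketch of the same pattern outside the test harness is below; the file name, build line, and plain error handling are placeholders (the patch uses C_NCCL_TRY instead), and the ncclCommDestroy teardown at the end is not shown in the hunks above.

// nccl_bootstrap_example.cpp (hypothetical, illustrative only)
// e.g. mpicxx nccl_bootstrap_example.cpp -lnccl -lcudart && mpirun -n <N> ./a.out
#include <cuda_runtime.h>
#include <mpi.h>
#include <nccl.h>

#include <cstdio>

int main(int argc, char** argv)
{
  MPI_Init(&argc, &argv);

  int comm_rank{};
  int comm_size{};
  MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &comm_size);

  int num_gpus_per_node{};
  cudaGetDeviceCount(&num_gpus_per_node);
  cudaSetDevice(comm_rank % num_gpus_per_node);  // same device selection as create_mg_raft_handle

  // Rank 0 generates the unique id; all ranks must call ncclCommInitRank with
  // the same id, so it is broadcast over MPI before NCCL is initialized.
  ncclUniqueId id{};
  if (comm_rank == 0) { ncclGetUniqueId(&id); }
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);

  ncclComm_t nccl_comm{};
  if (ncclCommInitRank(&nccl_comm, comm_size, id, comm_rank) != ncclSuccess) {
    std::fprintf(stderr, "ncclCommInitRank failed on rank %d\n", comm_rank);
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  // ... use nccl_comm, e.g. pass it to raft::comms::initialize_mpi_comms() as in the patch ...

  ncclCommDestroy(nccl_comm);  // matching teardown; not part of the patch hunks above
  MPI_Finalize();
  return 0;
}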