NVLS support for msccl++ executor #375

Draft · wants to merge 13 commits into base: main

2 changes: 1 addition & 1 deletion apps/nccl/src/allreduce.hpp
@@ -7,12 +7,12 @@
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/packet_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"
#include "gpu_data_types.hpp"

template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
1 change: 1 addition & 0 deletions include/mscclpp/gpu.hpp
@@ -86,6 +86,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri
#define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__)
#define cuMemMap(...) hipMemMap(__VA_ARGS__)
#define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__)
#define cuMemRetainAllocationHandle(...) hipMemRetainAllocationHandle(__VA_ARGS__)

#else

File renamed without changes.
75 changes: 75 additions & 0 deletions include/mscclpp/gpu_utils.hpp
@@ -120,6 +120,47 @@ PhysicalCudaMemory<T>* cudaPhysicalCalloc(size_t nelem, size_t gran) {
return new PhysicalCudaMemory<T>(memHandle, devicePtr, bufferSize);
}

template <class T>
T* cudaPhysicalCallocPtr(size_t nelem, size_t gran) {
AvoidCudaGraphCaptureGuard cgcGuard;

int deviceId = -1;
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId));

CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = deviceId;
#if defined(__HIP_PLATFORM_AMD__)
// TODO: revisit when HIP fixes this typo in the field name
prop.requestedHandleType = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#else
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#endif

CUmemGenericAllocationHandle memHandle;
size_t bufferSize = sizeof(T) * nelem;
// allocate physical memory
MSCCLPP_CUTHROW(cuMemCreate(&memHandle, bufferSize, &prop, 0 /*flags*/));

CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = deviceId;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

T* devicePtr = nullptr;
// Reserve a virtual address range, map the physical memory into it, and enable read/write access
MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, bufferSize, gran, 0U, 0));
MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, bufferSize, 0, memHandle, 0));
MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)devicePtr, bufferSize, &accessDesc, 1));
CudaStreamWithFlags stream(cudaStreamNonBlocking);
MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, bufferSize, stream));

MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream));

return devicePtr;
}

template <class T>
T* cudaExtCalloc(size_t nelem) {
AvoidCudaGraphCaptureGuard cgcGuard;
@@ -214,6 +255,22 @@ struct CudaPhysicalDeleter {
}
};

template <class T>
struct CudaPhysicalPtrDeleter {
static_assert(!std::is_array_v<T>, "T must not be an array");
void operator()(T* ptr) {
AvoidCudaGraphCaptureGuard cgcGuard;
CUmemGenericAllocationHandle handle;
size_t size = 0;
// Look up the allocation handle backing this pointer; the lookup itself adds a reference.
MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr));
// Drop the reference added by the retain above.
MSCCLPP_CUTHROW(cuMemRelease(handle));
MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size));
// Drop the reference held since cuMemCreate; the physical memory is freed once unmapped.
MSCCLPP_CUTHROW(cuMemRelease(handle));
MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, size));
}
};

/// A deleter that calls cudaFreeHost for use with std::unique_ptr or std::shared_ptr.
/// @tparam T Type of each element in the allocated memory.
template <class T>
@@ -246,6 +303,24 @@ std::shared_ptr<PhysicalCudaMemory<T>> allocSharedPhysicalCuda(size_t count, siz
std::shared_ptr<PhysicalCudaMemory<T>>>(count, gran);
}

#if (USE_NVLS)
template <class T>
std::shared_ptr<T> allocSharedPhysicalCudaPtr(size_t count, size_t gran = 0) {
if (gran == 0) {
CUmulticastObjectProp mcProp = {};
int numDevices = 0;
// The device count is a placeholder here; it may affect the granularity in the future.
MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices));
mcProp.size = count * sizeof(T);
mcProp.numDevices = numDevices;
mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &mcProp, CU_MULTICAST_GRANULARITY_MINIMUM));
}
return detail::safeAlloc<T, detail::cudaPhysicalCallocPtr<T>, CudaPhysicalPtrDeleter<T>, std::shared_ptr<T>>(count,
gran);
}
#endif

/// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out.
/// @tparam T Type of each element in the allocated memory.
/// @param count Number of elements to allocate.
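Note (usage sketch, not part of the diff): the new allocator composes with std::shared_ptr like its allocSharedPhysicalCuda sibling. This assumes a build with USE_NVLS and that the function is exposed in the mscclpp namespace; the element count is illustrative.

#include <mscclpp/gpu_utils.hpp>

void example() {
  constexpr size_t count = 1 << 20;  // element count; illustrative only
  // Passing gran == 0 lets the allocator query the minimum NVLS multicast
  // granularity via cuMulticastGetGranularity before allocating.
  std::shared_ptr<float> buf = mscclpp::allocSharedPhysicalCudaPtr<float>(count);
  // buf points at zeroed, physically-backed device memory; CudaPhysicalPtrDeleter
  // unmaps and releases the allocation when the last reference drops.
}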
7 changes: 6 additions & 1 deletion include/mscclpp/nvls.hpp
@@ -26,14 +26,17 @@ class NvlsConnection
struct DeviceMulticastPointer {
private:
std::shared_ptr<PhysicalCudaMemory<char>> deviceMem_;
void* devicePtr_;
std::shared_ptr<char> mcPtr_;
size_t bufferSize_;

public:
using DeviceHandle = DeviceMulticastPointerDeviceHandle;
DeviceMulticastPointer(std::shared_ptr<PhysicalCudaMemory<char>> deviceMem, std::shared_ptr<char> mcPtr,
size_t bufferSize)
: deviceMem_(deviceMem), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
: deviceMem_(deviceMem), devicePtr_(nullptr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceMulticastPointer(void* devicePtr, std::shared_ptr<char> mcPtr, size_t bufferSize)
: deviceMem_(nullptr), devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
DeviceHandle deviceHandle();
char* getDevicePtr();

@@ -46,6 +49,8 @@
/// and the \p size of the allocation.
std::shared_ptr<char> bindAllocatedCuda(CUmemGenericAllocationHandle memHandle, size_t size);

std::shared_ptr<char> bindAllocatedCudaWithPtr(CUdeviceptr devicePtr, size_t size);

size_t getMultiCastMinGranularity();

private:
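Note (wiring sketch, not part of the diff): one way the raw-pointer constructor and bindAllocatedCudaWithPtr could be combined; conn (a connected NvlsConnection) and bufferSize are illustrative names, not from this PR.

// Allocate a physically-backed buffer by pointer, bind it to the NVLS
// multicast group, then wrap the device and multicast pointers together.
auto buf = mscclpp::allocSharedPhysicalCudaPtr<char>(bufferSize);
std::shared_ptr<char> mcPtr = conn->bindAllocatedCudaWithPtr((CUdeviceptr)buf.get(), bufferSize);
mscclpp::NvlsConnection::DeviceMulticastPointer ptr(buf.get(), mcPtr, bufferSize);
auto handle = ptr.deviceHandle();  // device handle consumed by kernels using multimem ops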
27 changes: 25 additions & 2 deletions include/mscclpp/nvls_device.hpp
@@ -11,6 +11,8 @@
#include <cuda_fp16.h>
#endif // defined(MSCCLPP_DEVICE_CUDA)

#include <mscclpp/gpu_data_types.hpp>

#include "device.hpp"

namespace mscclpp {
@@ -27,7 +29,11 @@
#if defined(MSCCLPP_DEVICE_CUDA)
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
if constexpr (std::is_same_v<TValue, int32_t> && std::is_same_v<T, int32_t>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.s32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint32_t> && std::is_same_v<T, uint32_t>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.u32 %0, [%1];" : "=r"(val) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
@@ -51,14 +57,25 @@
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __bfloat162>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __bfloat162>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.bf16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
if constexpr (std::is_same_v<TValue, int32_t> && std::is_same_v<T, int32_t>) {
asm volatile("multimem.st.relaxed.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
} else if constexpr (std::is_same_v<TValue, uint32_t> && std::is_same_v<T, uint32_t>) {
asm volatile("multimem.st.relaxed.sys.global.u32 [%0], %1;" ::"l"(ptr), "r"(val) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
@@ -76,6 +93,12 @@
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __bfloat162>) {
asm volatile("multimem.st.relaxed.sys.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __bfloat162>) {
asm volatile("multimem.st.relaxed.sys.global.bf16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
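Note (device-side sketch, not part of the diff): the new bf16x2 branches follow the same pattern as the float path in allreduce6_helper below. Kernel name and launch geometry are illustrative; only the multimemLoadReduce/multimemStore calls come from this PR.

// Each thread load-reduces four __bfloat162 values (eight bf16 lanes) across
// the multicast group, then broadcasts the reduced vector back through the
// same multicast pointer.
__global__ void nvlsAllReduceBf16(__bfloat162* mc_ptr, size_t nPairs) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (size_t i = tid * 4; i + 4 <= nPairs; i += (size_t)gridDim.x * blockDim.x * 4) {
    uint4 val;  // four packed bf16x2 values
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + i);
    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + i);
  }
}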
2 changes: 1 addition & 1 deletion python/mscclpp_benchmark/allreduce.cu
@@ -839,7 +839,7 @@ MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDe
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 1)) {
uint1 val; // fits 1 float element
float val; // fits 1 float element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 8)) {