From 08a0cec2eb0d3d2c2abde2b8cb599a9180ccc72c Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 24 Sep 2024 23:01:41 -0700 Subject: [PATCH] Fixing RegisterMemory Allocation for ProxyChannels (#353) Co-authored-by: Binyang Li Co-authored-by: Changho Hwang --- python/test/executor_test.py | 34 +++++++++++++----- src/connection.cc | 17 +++++---- src/executor/executor.cc | 67 +++++++++++++++++++----------------- 3 files changed, 71 insertions(+), 47 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 53c11eb1..3e0c369d 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -78,12 +78,13 @@ def dtype_to_mscclpp_dtype(dtype): def main( - execution_paln_name: str, + execution_plan_name: str, execution_plan_path: str, size: int, + in_place: bool = True, dtype: cp.dtype = cp.float16, packet_type: PacketType = PacketType.LL16, - seed: int = 42 + MPI.COMM_WORLD.rank, + seed: int = 42, ): mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() @@ -91,21 +92,33 @@ def main( npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") if npkit_dump_dir is not None: npkit.init(mscclpp_group.my_rank) - execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path) + execution_plan = ExecutionPlan(execution_plan_name, execution_plan_path) cp.random.seed(seed) nelems = size // cp.dtype(dtype).itemsize - sendbuf = cp.random.random(nelems).astype(dtype) - expected = cp.asnumpy(sendbuf) - expected = MPI.COMM_WORLD.allreduce(expected, op=MPI.SUM) + buffer = cp.random.random(nelems * mscclpp_group.nranks, dtype=cp.float32).astype(dtype) + sub_arrays = cp.split(buffer, MPI.COMM_WORLD.size) + sendbuf = cp.zeros(nelems, dtype=dtype) + for i in range(nelems): + sendbuf[i] = sub_arrays[MPI.COMM_WORLD.rank][i] + + if "allgather" in execution_plan_name: + recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype) + expected = buffer + else: + cp.random.seed(seed) + recvbuf = cp.zeros(nelems, dtype=dtype) + expected = cp.zeros_like(sendbuf, dtype=dtype) + for i in range(mscclpp_group.nranks): + expected += sub_arrays[i] mscclpp_group.barrier() executor_func = lambda stream: executor.execute( MPI.COMM_WORLD.rank, sendbuf.data.ptr, - sendbuf.data.ptr, - sendbuf.nbytes, + sendbuf.data.ptr if in_place else recvbuf.data.ptr, sendbuf.nbytes, + sendbuf.nbytes if in_place else recvbuf.nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan, stream.ptr, @@ -115,7 +128,8 @@ def main( stream = cp.cuda.Stream(non_blocking=True) executor_func(stream) stream.synchronize() - assert cp.allclose(sendbuf, expected, atol=1e-2 * mscclpp_group.nranks) + + assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks) mscclpp_group.barrier() execution_time = bench_time(100, 10, executor_func) @@ -136,6 +150,7 @@ def main( parser.add_argument("-n", "--execution_plan_name", type=str, required=True) parser.add_argument("-path", "--execution_plan_path", type=str, required=True) parser.add_argument("--size", type=str, required=True) + parser.add_argument("--in_place", action="store_true", help="flag to define an in-place operation") parser.add_argument("--dtype", type=str, default="float16", help="Choose from float16, float32, int32") parser.add_argument("--packet_type", type=str, default="LL16", help="Choose from LL8, LL16") parser.add_argument("--seed", type=int, default=42) @@ -151,6 +166,7 @@ def main( args.execution_plan_name, args.execution_plan_path, buffer_size, + args.in_place, dtype, packet_type, args.seed, diff --git a/src/connection.cc b/src/connection.cc index 57e77b40..79c4c963 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -16,10 +16,13 @@ namespace mscclpp { -void validateTransport(RegisteredMemory mem, Transport transport) { +void validateTransport(RegisteredMemory mem, Transport transport, uint64_t offset = 0, uint64_t size = 0) { if (!mem.transports().has(transport)) { throw Error("RegisteredMemory does not support this transport", ErrorCode::InvalidUsage); } + if (offset + size > mem.size()) { + throw Error("RegisteredMemory out of bounds", ErrorCode::InvalidUsage); + } } // Connection @@ -59,8 +62,8 @@ Transport CudaIpcConnection::remoteTransport() { return Transport::CudaIpc; } void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); char* dstPtr = (char*)dst.data(); char* srcPtr = (char*)src.data(); @@ -115,8 +118,8 @@ Transport IBConnection::remoteTransport() { return remoteTransport_; } void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); auto dstTransportInfo = getImpl(dst)->getTransportInfo(remoteTransport()); if (dstTransportInfo.ibLocal) { @@ -231,8 +234,8 @@ Transport EthernetConnection::remoteTransport() { return Transport::Ethernet; } void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { // Validating Transport Protocol - validateTransport(dst, remoteTransport()); - validateTransport(src, transport()); + validateTransport(dst, remoteTransport(), dstOffset, size); + validateTransport(src, transport(), srcOffset, size); // Initializing Variables char* srcPtr = reinterpret_cast(src.data()) + srcOffset / sizeof(char); diff --git a/src/executor/executor.cc b/src/executor/executor.cc index f0da0e97..54986d5d 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -108,7 +108,7 @@ struct Executor::Impl { context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock(); this->setupConnections(context, rank, plan); this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); - this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, rank, plan); + this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan); this->setupDeviceExecutionPlan(context, rank, plan); context.deviceExecutionPlansBuffer = allocExtSharedCuda(context.deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan)); @@ -119,6 +119,23 @@ struct Executor::Impl { return context; } + TransportFlags getTransportFlags(std::vector& infos, int rank) { + TransportFlags flags; + for (ChannelInfo& info : infos) { + if (info.channelType == ChannelType::SM) { + flags |= Transport::CudaIpc; + } else if (info.channelType == ChannelType::PROXY) { + for (int peer : info.connectedPeers) { + if (!inSameNode(rank, peer, this->nranksPerNode)) { + flags |= IBs[rank % this->nranksPerNode]; + } else + flags |= Transport::CudaIpc; + } + } + } + return flags; + }; + void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) { std::vector connectedPeers = plan.impl_->getConnectedPeers(rank); std::vector>> connectionFutures; @@ -135,22 +152,6 @@ struct Executor::Impl { void setupRegisteredMemories(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, size_t recvBufferSize, int rank, const ExecutionPlan& plan) { - auto getTransportFlags = [&](std::vector& infos, int rank) { - TransportFlags flags; - for (ChannelInfo& info : infos) { - if (info.channelType == ChannelType::SM) { - flags |= Transport::CudaIpc; - } else if (info.channelType == ChannelType::PROXY) { - for (int peer : info.connectedPeers) { - if (!inSameNode(rank, peer, this->nranksPerNode)) { - flags |= IBs[rank % this->nranksPerNode]; - } else - flags |= Transport::CudaIpc; - } - } - } - return flags; - }; auto getBufferInfo = [&](BufferType type) { switch (type) { case BufferType::INPUT: @@ -192,22 +193,12 @@ struct Executor::Impl { comm->setup(); for (size_t i = 0; i < remoteRegMemoryFutures.size(); i++) { context.registeredMemories[{bufferType, connectedPeers[i]}] = std::move(remoteRegMemoryFutures[i].get()); - CUdeviceptr myRegBaseAdr, peerRegBaseAdr; - size_t temp; - MSCCLPP_CUTHROW(cuMemGetAddressRange(&myRegBaseAdr, &temp, (CUdeviceptr)(char*)memory.data())); - MSCCLPP_CUTHROW(cuMemGetAddressRange( - &peerRegBaseAdr, &temp, - (CUdeviceptr)(char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data())); - size_t myRegOffset = (char*)memory.data() - (char*)myRegBaseAdr; - size_t peerRegOffset = - (char*)context.registeredMemories[{bufferType, connectedPeers[i]}].data() - (char*)peerRegBaseAdr; - if (myRegOffset != peerRegOffset) throw Error("Divergent data offset between peers", ErrorCode::ExecutorError); } } } - void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, int rank, - const ExecutionPlan& plan) { + void setupChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize, + size_t recvBufferSize, int rank, const ExecutionPlan& plan) { const auto channelTypes = {ChannelType::SM, ChannelType::PROXY}; std::vector> smSemaphores; std::vector proxySemaphores; @@ -251,13 +242,27 @@ struct Executor::Impl { throw Error("Invalid buffer type", ErrorCode::ExecutorError); } }; + auto getBufferSize = [&](BufferType type) { + switch (type) { + case BufferType::INPUT: + return sendBufferSize; + case BufferType::OUTPUT: + return recvBufferSize; + case BufferType::SCRATCH: + return context.scratchBufferSize; + default: + throw Error("Invalid buffer type", ErrorCode::ExecutorError); + } + }; + for (ChannelType channelType : channelTypes) { std::vector channelInfos = plan.impl_->getChannelInfos(rank, channelType); int index = 0; for (ChannelInfo& info : channelInfos) { void* src = getBuffer(info.srcBufferType); - TransportFlags transport = context.registeredMemories.begin()->second.transports(); - RegisteredMemory localMemory = this->comm->registerMemory(src, sendBufferSize, transport); + size_t bufferSize = getBufferSize(info.srcBufferType); + TransportFlags transport = getTransportFlags(channelInfos, rank); + RegisteredMemory localMemory = this->comm->registerMemory(src, bufferSize, transport); for (int peer : info.connectedPeers) { if (channelType == ChannelType::SM) { context.smChannels.emplace_back(context.smSemaphores[index++],