Executor AllGather In-Place Support (#365)
caiomcbr authored Oct 21, 2024
1 parent 4136153 commit c6e06cf
Showing 4 changed files with 95 additions and 41 deletions.
25 changes: 20 additions & 5 deletions python/test/executor_test.py
@@ -77,6 +77,15 @@ def dtype_to_mscclpp_dtype(dtype):
raise ValueError(f"Unknown data type: {dtype}")


def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name):
if "allgather" in execution_plan_name:
return recvbuf
elif in_place:
return sendbuf
else:
return recvbuf


def main(
execution_plan_name: str,
execution_plan_path: str,
@@ -104,9 +113,11 @@ def main(

if "allgather" in execution_plan_name:
recvbuf = cp.zeros(nelems * mscclpp_group.nranks, dtype=dtype)
if in_place:
for i in range(nelems):
recvbuf[mscclpp_group.my_rank * nelems + i] = sendbuf[i]
expected = buffer
else:
cp.random.seed(seed)
recvbuf = cp.zeros(nelems, dtype=dtype)
expected = cp.zeros_like(sendbuf, dtype=dtype)
for i in range(mscclpp_group.nranks):
@@ -116,9 +127,9 @@ def main(
executor_func = lambda stream: executor.execute(
MPI.COMM_WORLD.rank,
sendbuf.data.ptr,
sendbuf.data.ptr if in_place else recvbuf.data.ptr,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr,
sendbuf.nbytes,
sendbuf.nbytes if in_place else recvbuf.nbytes,
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes,
dtype_to_mscclpp_dtype(dtype),
execution_plan,
stream.ptr,
@@ -129,10 +140,14 @@
executor_func(stream)
stream.synchronize()

assert cp.allclose(sendbuf if in_place else recvbuf, expected, atol=1e-2 * mscclpp_group.nranks)
assert cp.allclose(
determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name),
expected,
atol=1e-2 * mscclpp_group.nranks,
)

mscclpp_group.barrier()
execution_time = bench_time(100, 10, executor_func)
execution_time = bench_time(10, 10, executor_func)
if npkit_dump_dir is not None:
npkit.dump(npkit_dump_dir)
npkit.shutdown()
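As a quick illustration of the test change above: for allgather plans the result always lands in recvbuf, and the in-place variant only changes how that buffer is seeded before launch (each rank's own slice is pre-filled at its offset). A minimal, purely illustrative Python sketch with a hypothetical helper name:

    import cupy as cp

    def prepare_allgather_recvbuf(sendbuf, nelems, nranks, my_rank, in_place):
        # allgather always writes into recvbuf; "in-place" only pre-seeds it
        recvbuf = cp.zeros(nelems * nranks, dtype=sendbuf.dtype)
        if in_place:
            # this rank's contribution already sits in its slot of the output buffer
            recvbuf[my_rank * nelems:(my_rank + 1) * nelems] = sendbuf
        return recvbuf

determine_result_buf then returns recvbuf for allgather plans, sendbuf for other in-place plans, and recvbuf otherwise; both the result pointer and its nbytes passed to executor.execute are taken from that buffer.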
81 changes: 58 additions & 23 deletions src/executor/execution_plan.cc
@@ -148,12 +148,19 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
}
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
}
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize) const {
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {
size_t sizePerRank;
if (this->inputChunks.at(rank) != 0)
sizePerRank = inputSize / this->inputChunks.at(rank);
else if (this->outputChunks.at(rank) != 0)
sizePerRank = outputSize / this->outputChunks.at(rank);
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

if (this->isUsingPacket) {
return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank) * 2 /* data + flag*/ *
2 /*double buffer*/;
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
}
return inputSize / this->inputChunks.at(rank) * this->scratchChunks.at(rank);
return sizePerRank * this->scratchChunks.at(rank);
}
std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
return this->operations.at(rank)[threadblock];
@@ -163,7 +170,8 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper

int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
@@ -186,10 +194,12 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t contsSrcOff
this->setupChannels(gpus);

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset) {
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
@@ -210,6 +220,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t contsS
}

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
}

@@ -313,8 +324,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
// Get the relevant channel index in rank channelInfos
operation.inputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["i_cids"][i]["off"]) +
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) +
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
}
}
@@ -323,8 +335,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
operation.nInputs = op["srcs"].size();
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["srcs"][i]["off"]) +
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) +
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
}
}
@@ -335,8 +348,9 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
operation.outputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["o_cids"][i]["off"]) +
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) +
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
}
}
@@ -345,27 +359,29 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
operation.nOutputs = op["dsts"].size();
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
for (int i = 0; i < operation.nOutputs; i++) {
operation.outputOffsets[i] = this->getOffset(rank, this->inputSize, (uint32_t)op["dsts"][i]["off"]) +
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) +
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
}
}
if (op.contains("srcbuff")) {
operation.srcBufferType = convertToBufferType(op["srcbuff"]);
}
if (op.contains("srcoff")) {
operation.srcOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["srcoff"]);
operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
chunkIndexes.push_back((uint32_t)op["srcoff"]);
}
if (op.contains("dstbuff")) {
operation.dstBufferType = convertToBufferType(op["dstbuff"]);
}
if (op.contains("dstoff")) {
operation.dstOffset = this->getOffset(rank, this->inputSize, (uint32_t)op["dstoff"]);
operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
chunkIndexes.push_back((uint32_t)op["dstoff"]);
}
if (op.contains("cnt")) {
operation.size = this->getNChunkSize(rank, this->inputSize, (uint32_t)op["cnt"], chunkIndexes);
operation.size =
this->getNChunkSize(rank, this->inputSize, this->outputSize, (uint32_t)op["cnt"], chunkIndexes);
}
ops.push_back(operation);
}
Expand All @@ -374,14 +390,33 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
}
}

size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment) const {
std::pair<size_t, u_int32_t> ExecutionPlan::Impl::calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const {
std::pair<size_t, u_int32_t> sizePerRank;
if (this->inputChunks.at(rank) == 0 && this->outputChunks.at(rank) == 0) {
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
} else if (this->inputChunks.at(rank) != 0 && this->outputChunks.at(rank) != 0) {
if (inputSize / this->inputChunks.at(rank) != outputSize / this->outputChunks.at(rank))
throw mscclpp::Error("Size per chunks inconsistent", mscclpp::ErrorCode::ExecutorError);
else
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
} else if (this->inputChunks.at(rank) != 0) {
sizePerRank = std::make_pair(inputSize, this->inputChunks.at(rank));
} else if (this->outputChunks.at(rank) != 0) {
sizePerRank = std::make_pair(outputSize, this->outputChunks.at(rank));
}
return sizePerRank;
}

size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex,
uint32_t alignment) const {
if (inputSize % alignment != 0) {
throw Error("inputSize must be a multiple of alignment", ErrorCode::ExecutorError);
}

const int nGroups = this->chunkGroups.at(rank);
uint32_t nInputChunks = this->inputChunks.at(rank);
uint32_t nelems = inputSize / (alignment * sizeof(uint8_t));
auto sizePerRank = calcSizePerRank(rank, inputSize, outputSize);
uint32_t nInputChunks = sizePerRank.second;
uint32_t nelems = sizePerRank.first / (alignment * sizeof(uint8_t));
if (nelems % nGroups != 0) {
throw Error("Input size must be a multiple of nGroups", ErrorCode::ExecutorError);
}
@@ -397,12 +432,12 @@ size_t ExecutionPlan::Impl::getOffset(int rank, size_t inputSize, uint32_t chunk
return static_cast<size_t>(offset) * alignment;
}

size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, uint32_t nChunks,
size_t ExecutionPlan::Impl::getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
const std::vector<uint32_t> chunkIndexes) const {
size_t nChunkSize = 0;
for (uint32_t index : chunkIndexes) {
uint32_t beginOff = getOffset(rank, inputSize, index);
uint32_t endOff = getOffset(rank, inputSize, index + nChunks);
uint32_t beginOff = getOffset(rank, inputSize, outputSize, index);
uint32_t endOff = getOffset(rank, inputSize, outputSize, index + nChunks);
if (nChunkSize == 0) {
nChunkSize = endOff - beginOff;
} else if (nChunkSize != endOff - beginOff) {
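For readers skimming the hunks above: the new calcSizePerRank helper derives the per-chunk size from whichever of the input or output chunk counts is non-zero, and insists the two agree when both are set; getScratchBufferSize gains a matching fallback so output-only plans (such as in-place allgather) can still size their scratch space. A minimal Python sketch of that logic, assuming scalar per-rank values rather than the per-rank maps used in the C++:

    def calc_size_per_rank(input_size, output_size, input_chunks, output_chunks):
        # pick whichever side actually defines chunks; both sides must agree if present
        if input_chunks == 0 and output_chunks == 0:
            raise ValueError("Output or Input chunks must be greater than 0")
        if input_chunks != 0 and output_chunks != 0:
            if input_size // input_chunks != output_size // output_chunks:
                raise ValueError("Size per chunks inconsistent")
            return input_size, input_chunks
        if input_chunks != 0:
            return input_size, input_chunks
        return output_size, output_chunks

    def scratch_buffer_size(input_size, output_size, input_chunks, output_chunks,
                            scratch_chunks, using_packet):
        # same fallback: per-chunk size from the input if it has chunks, else the output
        if input_chunks != 0:
            size_per_chunk = input_size // input_chunks
        elif output_chunks != 0:
            size_per_chunk = output_size // output_chunks
        else:
            raise ValueError("Output or Input chunks must be greater than 0")
        if using_packet:
            return size_per_chunk * scratch_chunks * 2 * 2  # data + flag, double buffer
        return size_per_chunk * scratch_chunks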
17 changes: 9 additions & 8 deletions src/executor/executor.cc
@@ -80,13 +80,13 @@ struct Executor::Impl {
}
~Impl() = default;

ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t messageSize,
size_t contsSrcOffset, size_t constDstOffset, size_t sendBufferSize,
size_t recvBufferSize, const ExecutionPlan& plan) {
ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
size_t outputMessageSize, size_t contsSrcOffset, size_t constDstOffset,
size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
if (this->contexts.find(key) != this->contexts.end()) {
plan.impl_->operationsReset();
plan.impl_->lightLoadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);
this->setupDeviceExecutionPlan(this->contexts[key], rank, plan);
this->contexts[key].deviceExecutionPlansBuffer =
allocExtSharedCuda<char>(this->contexts[key].deviceExecutionPlans.size() * sizeof(DeviceExecutionPlan));
@@ -97,10 +97,10 @@ struct Executor::Impl {
}

plan.impl_->reset();
plan.impl_->loadExecutionPlan(messageSize, contsSrcOffset, constDstOffset);
plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, contsSrcOffset, constDstOffset);

ExecutionContext context;
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize);
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize);
std::shared_ptr<char> scratchBuffer = allocExtSharedCuda<char>(scratchBufferSize);
context.scratchBuffer = scratchBuffer;
context.scratchBufferSize = scratchBufferSize;
@@ -350,8 +350,9 @@ void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuff
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;

ExecutionContext context = this->impl_->setupExecutionContext(
rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, offsetIn, offsetOut, sendBytes, recvBytes, plan);
ExecutionContext context =
this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize,
offsetIn, offsetOut, sendBytes, recvBytes, plan);
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
}

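A rough sketch of how the executor-side plumbing above fits together, in Python pseudocode with hypothetical names (the real path is Executor::execute -> setupExecutionContext -> loadExecutionPlan / getScratchBufferSize):

    def setup_execution_context(rank, send_base, recv_base, input_msg_size, output_msg_size,
                                src_offset, dst_offset, send_buf_size, recv_buf_size, plan):
        # both message sizes are now threaded through, so plans whose rank has no
        # input chunks (e.g. allgather writing only to the output buffer) still get
        # correct offsets and scratch sizing
        plan.reset()
        plan.load_execution_plan(input_msg_size, output_msg_size, src_offset, dst_offset)
        scratch_bytes = plan.get_scratch_buffer_size(rank, send_buf_size, recv_buf_size)
        return scratch_bytes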
13 changes: 8 additions & 5 deletions src/include/execution_plan.hpp
@@ -65,13 +65,13 @@ struct ExecutionPlan::Impl {
std::vector<ChannelInfo> getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType);
std::vector<int> getConnectedPeers(int rank) const;
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
size_t getScratchBufferSize(int rank, size_t inputSize) const;
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
std::vector<Operation> getOperations(int rank, int threadblock) const;
int getThreadblockCount(int rank) const;
int getNThreadsPerBlock() const;

void loadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
void lightLoadExecutionPlan(size_t inputSize, size_t contsSrcOffset, size_t constDstOffset);
void loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
void lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset, size_t constDstOffset);
void setupChannels(const nlohmann::json& gpus);
void setupOperations(const nlohmann::json& gpus, size_t contsSrcOffset, size_t constDstOffset);

@@ -94,11 +94,14 @@ struct ExecutionPlan::Impl {
std::unordered_map<int, uint32_t> scratchChunks;
std::unordered_map<int, uint32_t> chunkGroups;
size_t inputSize;
size_t outputSize;
int nThreadsPerBlock;

private:
size_t getOffset(int rank, size_t inputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
size_t getNChunkSize(int rank, size_t inputSize, uint32_t nChunks, const std::vector<uint32_t> offsets) const;
std::pair<size_t, u_int32_t> calcSizePerRank(int rank, size_t inputSize, size_t outputSize) const;
size_t getOffset(int rank, size_t inputSize, size_t outputSize, uint32_t chunkIndex, uint32_t alignment = 16) const;
size_t getNChunkSize(int rank, size_t inputSize, size_t outputSize, uint32_t nChunks,
const std::vector<uint32_t> offsets) const;
};

} // namespace mscclpp
