Simplify rendezvous logic
ghpvnist committed Sep 12, 2023
1 parent 0f6bb87 commit a480caa
Showing 4 changed files with 45 additions and 35 deletions.
5 changes: 2 additions & 3 deletions stablehlo/reference/Process.cpp
@@ -44,9 +44,8 @@ ProcessId Process::getId() { return id_; }

void Process::outfeed(ArrayRef<Tensor> inputs) { grid_->outfeed(inputs); }

-std::shared_ptr<RendezvousResult> Process::rendezvous(ProcessGroup processGroup,
-                                                      ChannelId channelId,
-                                                      const Tensor &operand) {
+const std::shared_ptr<RendezvousResult> Process::rendezvous(
+    ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
  return grid_->rendezvous(processGroup, channelId, getId(), operand);
}

6 changes: 3 additions & 3 deletions stablehlo/reference/Process.h
@@ -54,9 +54,9 @@ class Process {
void outfeed(ArrayRef<Tensor> inputs);

/// See `ProcessGrid::rendezvous`.
-  std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
-                                               ChannelId channelId,
-                                               const Tensor &operand);
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     const Tensor &operand);

private:
/// StableHLO `process_id`.
61 changes: 36 additions & 25 deletions stablehlo/reference/ProcessGrid.cpp
@@ -162,12 +162,11 @@ ProcessGroups ProcessGrid::flattenedIds(

void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) { outfeed_.push(inputs); }

-std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
+const std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
    ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
    const Tensor &operand) {
  std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
-  // Immediately return the result. The logic below doesn't work for a single
-  // process.
+  // Process wait/notify logic below doesn't work for single process.
  if (processGroup.size() == 1)
    return std::make_shared<RendezvousResult>(
        RendezvousResult({std::pair{processId, operand}}));
@@ -177,37 +176,49 @@ std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
std::unique_lock<std::mutex> lock(state.mutex);
state.values[processId] = operand;

std::shared_ptr<RendezvousResult> result;
if (state.values.size() == processGroup.size()) {
// The last process to contribute moves the values into the result.
result = std::make_shared<RendezvousResult>(state.values);
state.result = result;
// If values are full, that means all other processes are currently waiting.
// The last process to contribute moves the values into the result
// then waits for each process to return a copy of the result before
// cleaning up the state variable for future computations in this process
// grid.
state.result = std::make_shared<RendezvousResult>(state.values);
state.values.clear();
channelConditions_[channelKey].notify_one();
} else {
// The remaining processes wait for the last process to contribute to move
// the values into the shared result.

// The last process to contribute waits until the rest of the processes have
// read the values.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3),
[&] { return state.result != nullptr; }))
lock, std::chrono::seconds(3), [&] {
return state.result.use_count() >=
static_cast<int64_t>(processGroup.size());
}))
llvm::report_fatal_error(
"rendezvous timed out: not all processes have contributed yet");

// The shared result from the state owns one, the last process to contribute
// owns one, and the remaining processes (except the last) owns one here.
if (state.result.use_count() < static_cast<int64_t>(processGroup.size())) {
result = state.result;
channelConditions_[channelKey].notify_one();
} else {
// Of the remaining processes, the last remaining process to arrive takes
// the result from the state to allow the process that contributed last to
// exit the function.
channelConditions_[channelKey].notify_one();
return std::move(state.result);
}
if (state.result.use_count() > static_cast<int64_t>(processGroup.size()))
llvm::report_fatal_error(
"Each process should have only one shared access to the result.");

// The last process to contribute takes the result from the state to allow
// the process that contributed last to exit the function.
channelConditions_[channelKey].notify_one();
return std::move(state.result);
}

// Wait for the remaining processes to have retrieved the result.
// Wait for all processes to contribute values.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3),
[&] { return state.result != nullptr; }))
llvm::report_fatal_error(
"rendezvous timed out: not all process has received the results yet");

// Copy result from the state before notifying.
auto result = state.result;
channelConditions_[channelKey].notify_one();

// Wait for the remaining processes to have retrieved the result. In other
// words, wait until the last process to contribute exit the function.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3),
[&] { return state.result == nullptr; }))
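To make the wait/notify flow in `ProcessGrid::rendezvous` easier to follow, here is a minimal, self-contained sketch of the same pattern: the last thread to contribute publishes the result behind a `shared_ptr`, and `shared_ptr::use_count()` tells it when every other participant holds a copy, after which the channel state is cleared for reuse. This is an illustration written for this summary, not the StableHLO implementation; the `Rendezvous` struct, its `arrive` method, and the `int` payloads are invented for the example, timeouts are omitted, and it uses `notify_all` where the real code chains `notify_one` calls.

```cpp
#include <condition_variable>
#include <cstdio>
#include <map>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// One rendezvous "channel": each participant contributes a value and every
// participant receives a shared pointer to the full map of contributions.
struct Rendezvous {
  std::mutex mutex;
  std::condition_variable cv;
  std::map<int, int> values;                   // processId -> contributed value
  std::shared_ptr<std::map<int, int>> result;  // published by the last contributor

  std::shared_ptr<std::map<int, int>> arrive(int processId, int value,
                                             size_t groupSize) {
    std::unique_lock<std::mutex> lock(mutex);
    values[processId] = value;

    if (values.size() == groupSize) {
      // Last contributor: publish the result, then wait until the state's
      // owner plus one copy per other participant raises use_count() to
      // groupSize, i.e. everyone else has taken their copy.
      result = std::make_shared<std::map<int, int>>(std::move(values));
      values.clear();
      cv.notify_all();
      cv.wait(lock, [&] {
        return result.use_count() >= static_cast<long>(groupSize);
      });
      // Take the state's copy (leaving the channel empty for the next round)
      // and wake the others, which wait for the cleared state before returning.
      auto out = std::move(result);
      cv.notify_all();
      return out;
    }

    // Everyone else: wait for the result to be published, take a copy, and
    // wake the last contributor so it can observe the new use_count().
    cv.wait(lock, [&] { return result != nullptr; });
    auto copy = result;
    cv.notify_all();
    // Do not return until the last contributor has emptied the state; this
    // keeps our copy alive so use_count() stays at groupSize long enough for
    // the last contributor's predicate to observe it.
    cv.wait(lock, [&] { return result == nullptr; });
    return copy;
  }
};

int main() {
  Rendezvous channel;
  const size_t kGroupSize = 4;
  std::vector<std::thread> workers;
  for (int id = 0; id < static_cast<int>(kGroupSize); ++id)
    workers.emplace_back([&, id] {
      auto result = channel.arrive(id, /*value=*/id * 10, kGroupSize);
      std::printf("process %d sees %zu contributions\n", id, result->size());
    });
  for (auto &t : workers) t.join();
  return 0;
}
```

The second wait in the reader path mirrors the final `wait_for` in the diff above: readers hold on to their copy until the last contributor has moved the result out of the channel state, which is what keeps `use_count()` at the expected value when the last contributor checks its predicate.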
8 changes: 4 additions & 4 deletions stablehlo/reference/ProcessGrid.h
@@ -134,10 +134,10 @@ class ProcessGrid {
/// tensors are accumulated in `RendezvousResult` whose shared pointer is
/// returned to all callers once the barrier has been reached by all StableHLO
/// processes.
-  std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
-                                               ChannelId channelId,
-                                               ProcessId processId,
-                                               const Tensor &operand);
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     ProcessId processId,
+                                                     const Tensor &operand);

private:
/// Internal storage used in `rendezvous` to manage concurrent access to the

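The header comment above says the accumulated tensors are handed back through a shared pointer to every caller; the wait predicate in `ProcessGrid::rendezvous` then relies on nothing more than `shared_ptr` ownership counting. A tiny stand-alone check of that arithmetic (plain standard C++, with a placeholder `int` payload instead of the StableHLO types):

```cpp
#include <cassert>
#include <memory>
#include <vector>

int main() {
  const long groupSize = 4;
  // One owner lives in the channel state (state.result in the diff above).
  auto stateResult = std::make_shared<int>(42);
  // Every process except the last contributor takes a copy of the pointer.
  std::vector<std::shared_ptr<int>> processCopies;
  for (long i = 0; i < groupSize - 1; ++i) processCopies.push_back(stateResult);
  // One owner in the state plus (groupSize - 1) process copies is exactly the
  // use_count() threshold the last contributor waits for before cleaning up.
  assert(stateResult.use_count() == groupSize);
  return 0;
}
```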