diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 63dae171003..8875ad5e800 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -859,8 +859,8 @@ Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
         "Failed to find process group with process_id: (%d, %d)",
         process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   return evalConcatenateOp(groupOperands, allGatherDim, resultType);
 }
@@ -888,8 +888,8 @@ Tensor evalAllReduceOp(const Tensor &operand,
         "Failed to find process group with process_id: (%d, %d)",
         process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   Tensor result(resultType);
   for (auto resultIt = result.index_begin(); resultIt != result.index_end();
@@ -929,8 +929,8 @@ Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension,
         "Failed to find process group with process_id: (%d, %d)",
         process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   SmallVector<Tensor> scatteredParts;
   for (const auto &groupOperand : groupOperands) {
@@ -1080,7 +1080,7 @@ Tensor evalCollectivePermuteOp(
     if (from != process->getId() && to != process->getId()) continue;
 
     auto rendezvousResult =
-        process->rendezvous(processGroup, channelId, operand);
+        *process->rendezvous(processGroup, channelId, operand);
     if (to != process->getId()) continue;
     result = rendezvousResult.lookup(from);
   }
diff --git a/stablehlo/reference/Process.cpp b/stablehlo/reference/Process.cpp
index a29ed5f96b3..f87583a613f 100644
--- a/stablehlo/reference/Process.cpp
+++ b/stablehlo/reference/Process.cpp
@@ -44,9 +44,8 @@ ProcessId Process::getId() { return id_; }
 
 void Process::outfeed(ArrayRef<Tensor> inputs) { grid_->outfeed(inputs); }
 
-RendezvousResult Process::rendezvous(ProcessGroup processGroup,
-                                     ChannelId channelId,
-                                     const Tensor &operand) {
+const std::shared_ptr<RendezvousResult> Process::rendezvous(
+    ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
   return grid_->rendezvous(processGroup, channelId, getId(), operand);
 }
diff --git a/stablehlo/reference/Process.h b/stablehlo/reference/Process.h
index 11da144fbe7..e40a031ee51 100644
--- a/stablehlo/reference/Process.h
+++ b/stablehlo/reference/Process.h
@@ -54,8 +54,9 @@ class Process {
   void outfeed(ArrayRef<Tensor> inputs);
 
   /// See `ProcessGrid::rendezvous`.
-  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
-                              const Tensor &operand);
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     const Tensor &operand);
 
  private:
   /// StableHLO `process_id`.
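The call-site changes above are the visible half of the fix: `rendezvous` now returns `const std::shared_ptr<RendezvousResult>`, so every process shares one immutable result instead of copying a map that another thread may still be writing. A minimal standalone sketch of that ownership pattern (illustrative types, not the interpreter's):

```cpp
// Sketch: a producer publishes one immutable result; consumers copy the
// shared pointer (cheap, thread-safe) instead of deep-copying the map,
// mirroring process->rendezvous(...)->getSortedTensors() in the patch.
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Result {
  std::map<int, std::string> perProcess;
};

int main() {
  auto published = std::make_shared<const Result>(
      Result{{{0, "tensor from process 0"}, {1, "tensor from process 1"}}});
  std::shared_ptr<const Result> reader = published;  // no deep copy
  std::cout << reader->perProcess.at(1) << "\n";
}
```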
diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp
index 5f221221dba..a6b38f50dea 100644
--- a/stablehlo/reference/ProcessGrid.cpp
+++ b/stablehlo/reference/ProcessGrid.cpp
@@ -59,7 +59,8 @@ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
 // RendezvousResult.
 //===----------------------------------------------------------------------===//
 
-void RendezvousResult::clear() { result_.clear(); }
+RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> result)
+    : result_(result) {}
 
 void RendezvousResult::insert(ProcessId processId, Tensor tensor) {
   result_[processId] = tensor;
@@ -76,7 +77,24 @@ SmallVector<Tensor> RendezvousResult::getSortedTensors() {
       llvm::map_range(result_, [](const auto &pair) { return pair.second; }));
 }
 
-size_t RendezvousResult::size() { return result_.size(); }
+//===----------------------------------------------------------------------===//
+// ThreadSafeMap.
+//===----------------------------------------------------------------------===//
+
+template <typename K, typename V>
+V &ProcessGrid::ThreadSafeMap<K, V>::operator[](const K &key) {
+  std::lock_guard<std::mutex> lock(lock_);
+  return map_[key];
+}
+
+//===----------------------------------------------------------------------===//
+// ThreadSafeQueue.
+//===----------------------------------------------------------------------===//
+
+void ProcessGrid::ThreadSafeQueue::push(ArrayRef<Tensor> inputs) {
+  std::lock_guard<std::mutex> lock(lock_);
+  queue_.emplace(inputs);
+}
 
 //===----------------------------------------------------------------------===//
 // ProcessGrid.
@@ -142,45 +160,72 @@ ProcessGroups ProcessGrid::flattenedIds(
   return processGroups;
 }
 
-std::mutex &ProcessGrid::getRendezvousLock(ProcessGroup processGroup,
-                                           ChannelId channelId) {
-  std::lock_guard<std::mutex> lock(rendezvousLock_);
-  std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
-  return channelLocks_[channelKey];
-}
+void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) { outfeed_.push(inputs); }
 
-void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) {
-  std::lock_guard<std::mutex> lock(outfeedLock_);
-  outfeed_.emplace(inputs);
-}
-
-RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,
-                                         ChannelId channelId,
-                                         ProcessId processId,
-                                         const Tensor &operand) {
+const std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
+    ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
+    const Tensor &operand) {
   std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
-  {
-    std::lock_guard<std::mutex> lock(
-        getRendezvousLock(processGroup, channelId));
-    if (channels_[channelKey].size() == processGroup.size())
-      channels_[channelKey].clear();
-
-    channels_[channelKey].insert(processId, operand);
-  }
-  {
-    std::unique_lock<std::mutex> lock(
-        getRendezvousLock(processGroup, channelId));
-    if (channels_[channelKey].size() == processGroup.size())
-      channelConditions_[channelKey].notify_all();
-
+  // The wait/notify logic below doesn't work for a single process.
+  if (processGroup.size() == 1)
+    return std::make_shared<RendezvousResult>(
+        RendezvousResult({std::pair{processId, operand}}));
+
+  auto &state = channels_[channelKey];
+
+  std::unique_lock<std::mutex> lock(state.mutex);
+  state.values[processId] = operand;
+
+  if (state.values.size() == processGroup.size()) {
+    // If `values` is full, all other processes are already waiting. The last
+    // process to contribute moves the values into the result, then waits for
+    // each process to take a copy of the result before cleaning up the state
+    // for future computations in this process grid.
+    state.result = std::make_shared<RendezvousResult>(state.values);
+    state.values.clear();
+    channelConditions_[channelKey].notify_one();
+
+    // The last process to contribute waits until the rest of the processes
+    // have read the values.
     if (!channelConditions_[channelKey].wait_for(
             lock, std::chrono::seconds(3), [&] {
-          return channels_[channelKey].size() == processGroup.size();
+              return state.result.use_count() >=
+                     static_cast<long>(processGroup.size());
            }))
-    llvm::report_fatal_error("rendezvous timed out");
+      llvm::report_fatal_error(
+          "rendezvous timed out: not all processes have read the result yet");
 
-  return channels_[channelKey];
+    if (state.result.use_count() > static_cast<long>(processGroup.size()))
+      llvm::report_fatal_error(
+          "Each process should have only one shared access to the result.");
+
+    // The last process to contribute takes the result out of the state,
+    // resetting it to null so that the processes waiting below can exit.
+    channelConditions_[channelKey].notify_one();
+    return std::move(state.result);
   }
+
+  // Wait for all processes to contribute values.
+  if (!channelConditions_[channelKey].wait_for(
+          lock, std::chrono::seconds(3),
+          [&] { return state.result != nullptr; }))
+    llvm::report_fatal_error(
+        "rendezvous timed out: not all processes have contributed yet");
+
+  // Copy the result from the state before notifying.
+  auto result = state.result;
+  channelConditions_[channelKey].notify_one();
+
+  // Wait for the remaining processes to retrieve the result. In other words,
+  // wait until the last process to contribute exits the function.
+  if (!channelConditions_[channelKey].wait_for(
+          lock, std::chrono::seconds(3),
+          [&] { return state.result == nullptr; }))
+    llvm::report_fatal_error(
+        "rendezvous timed out: not all processes have retrieved the result "
+        "yet");
+
+  return result;
 }
 
 }  // namespace stablehlo
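The heart of the patch is this last-to-arrive protocol: the final contributor publishes a `shared_ptr`, uses its `use_count()` to confirm that every peer holds a reference, and resets the state so the channel can be reused. A standalone sketch of the same protocol, simplified relative to the patch (it uses `notify_all` where the patch chains `notify_one`, and plain `int` payloads; the `Rendezvous` type here is hypothetical, not the interpreter's):

```cpp
// Sketch of the rendezvous barrier: N threads each contribute a value;
// the last arrival publishes a shared result, waits until every peer
// holds a reference (checked via use_count), then resets the state.
#include <condition_variable>
#include <cstdio>
#include <map>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

struct Rendezvous {
  std::mutex mutex;
  std::condition_variable cv;
  std::map<int, int> values;
  std::shared_ptr<const std::map<int, int>> result;

  std::shared_ptr<const std::map<int, int>> arrive(int id, int value,
                                                   size_t groupSize) {
    std::unique_lock<std::mutex> lock(mutex);
    values[id] = value;
    if (values.size() == groupSize) {
      // Last arrival: publish the result and wake the waiters.
      result = std::make_shared<const std::map<int, int>>(std::move(values));
      values.clear();
      cv.notify_all();
      // Wait until the state's reference plus one copy per peer exist.
      cv.wait(lock, [&] {
        return result.use_count() >= static_cast<long>(groupSize);
      });
      // Moving the last reference out resets `result` to null, which is
      // the signal the peers below wait for before returning.
      auto out = std::move(result);
      cv.notify_all();
      return out;
    }
    // Peers: wait for the result, take a reference, then wait for reset.
    cv.wait(lock, [&] { return result != nullptr; });
    auto out = result;
    cv.notify_all();
    cv.wait(lock, [&] { return result == nullptr; });
    return out;
  }
};

int main() {
  constexpr size_t kGroupSize = 4;
  Rendezvous rendezvous;
  std::vector<std::thread> threads;
  for (int id = 0; id < static_cast<int>(kGroupSize); ++id)
    threads.emplace_back([&rendezvous, id] {
      auto result = rendezvous.arrive(id, id * 10, kGroupSize);
      std::printf("thread %d sees %zu contributions\n", id, result->size());
    });
  for (auto &t : threads) t.join();
}
```

The final wait matters: without it, a fast peer could start a second rendezvous on the same channel and see stale state before the last contributor has cleaned up.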
diff --git a/stablehlo/reference/ProcessGrid.h b/stablehlo/reference/ProcessGrid.h
index 33b8ae5d762..bcbc6512ec5 100644
--- a/stablehlo/reference/ProcessGrid.h
+++ b/stablehlo/reference/ProcessGrid.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <condition_variable>
 #include <cstdint>
 #include <map>
+#include <memory>
 #include <mutex>
 #include <optional>
 #include <queue>
@@ -69,8 +70,7 @@ class ProcessGroups : public SmallVector<ProcessGroup> {
 /// map-like API.
 class RendezvousResult {
  public:
-  /// Erases all elements in the map.
-  void clear();
+  RendezvousResult(std::map<ProcessId, Tensor> result);
 
   /// Iterates through the (ProcessId, Tensor) map entries and returns a vector
   /// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
@@ -84,9 +84,6 @@ class RendezvousResult {
   /// `processId`. If the key is not found, returns an empty `Tensor`.
   Tensor lookup(ProcessId processId);
 
-  /// Returns the number of elements in the map.
-  size_t size();
-
  private:
   /// Internal map representation of the result of `ProcessGrid::rendezvous`.
   std::map<ProcessId, Tensor> result_;
@@ -134,16 +131,60 @@ class ProcessGrid {
   /// deadlock the interpreter.
   ///
   /// At the barrier, each StableHLO process contributes a tensor, and these
-  /// tensors are accumulated in `RendezvousResult` which is returned to all
-  /// callers once the barrier has been reached by all StableHLO processes.
-  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
-                              ProcessId processId, const Tensor &operand);
+  /// tensors are accumulated in `RendezvousResult`, a shared pointer to which
+  /// is returned to all callers once the barrier has been reached by all
+  /// StableHLO processes.
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     ProcessId processId,
+                                                     const Tensor &operand);
 
  private:
-  /// Obtain a mutex that is shared between all processes participating in
-  /// a call to `rendezvous` for a given combination of `processGroup` and
-  /// `channelId`.
-  std::mutex &getRendezvousLock(ProcessGroup processGroup, ChannelId channelId);
+  /// Internal storage used in `rendezvous` to manage concurrent access to the
+  /// shared resource. Processes contribute their data to `values`
+  /// concurrently. Once all processes have added their data, the data in
+  /// `values` is moved to `result`, which multiple processes can concurrently
+  /// read from.
+  struct RendezvousState {
+    /// Synchronization primitive used to manage concurrent access to this
+    /// object.
+    std::mutex mutex;
+    /// Internal storage used to store data contributed by the processes.
+    std::map<ProcessId, Tensor> values;
+    /// Shared pointer to the result of `rendezvous`.
+    std::shared_ptr<RendezvousResult> result;
+  };
+
+  /// Stores the result of `rendezvous` represented as a map that allows
+  /// concurrent access.
+  /// Each call to `rendezvous`, i.e. each combination of `processGroup` and
+  /// `channelId`, has its own key in the map. Within the implementation of
+  /// `rendezvous`, the value corresponding to this key is gradually populated
+  /// with tensors arriving from different processes in the process group.
+  template <typename K, typename V>
+  class ThreadSafeMap {
+   public:
+    /// Returns a reference to the data associated with the `key`.
+    V &operator[](const K &key);
+
+   private:
+    /// Synchronization primitive used to manage concurrent access to the map.
+    std::mutex lock_;
+    /// Internal storage used to implement `rendezvous`.
+    std::map<K, V> map_;
+  };
+
+  /// StableHLO `outfeed` represented as a queue that allows concurrent access.
+  class ThreadSafeQueue {
+   public:
+    /// Adds `inputs` to the end of the queue.
+    void push(ArrayRef<Tensor> inputs);
+
+   private:
+    /// Synchronization primitive used to manage concurrent access to the
+    /// queue.
+    std::mutex lock_;
+    /// Internal storage used to implement StableHLO `outfeed`.
+    std::queue<SmallVector<Tensor>> queue_;
+  };
 
   /// StableHLO `num_replicas`.
   const uint32_t numReplicas_;
@@ -151,28 +192,14 @@ class ProcessGrid {
   /// StableHLO `num_partitions`.
   const uint32_t numPartitions_;
 
-  /// StableHLO `outfeed` represented as a queue.
-  std::queue<SmallVector<Tensor>> outfeed_;
-
-  std::mutex outfeedLock_;
-
-  /// Synchronization primitive used to manage concurrent access to
-  /// `channelLocks_`.
-  std::mutex rendezvousLock_;
+  /// See `ThreadSafeQueue`.
+  ThreadSafeQueue outfeed_;
 
-  /// Internal storage used to implement `rendezvous`.
-  /// Each call to `rendezvous`, i.e. each combination `processGroup` and
-  /// `channelId`, has its own key in the map.
-  /// Within the implementation of `rendezvous`, the value corresponding to
-  /// this key is gradually populated with tensors arriving from different
-  /// processes in the process group.
-  std::map<std::pair<ProcessGroup, ChannelId>, RendezvousResult> channels_;
-
-  /// Synchronization primitive used to manage concurrent access to `channels_`.
-  std::map<std::pair<ProcessGroup, ChannelId>, std::mutex> channelLocks_;
+  /// See `ThreadSafeMap`.
+  ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, RendezvousState> channels_;
 
   /// Synchronization primitive used to manage concurrent access to `channels_`.
-  std::map<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
+  ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
       channelConditions_;
 };
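One design point worth noting: `ThreadSafeMap::operator[]` locks only while creating or finding the map node; the returned `V&` escapes the lock. That is safe here for two reasons: `std::map` never invalidates references on insertion, and each `RendezvousState` carries its own mutex for access after the reference escapes. A standalone sketch of the split (illustrative `State` type, not the interpreter's):

```cpp
// Sketch of the two-level locking in the patch: the map's mutex guards
// only the container's structure, while each value carries its own mutex
// that the caller takes after the reference escapes. std::map references
// stay valid across inserts, which is what makes returning V& safe.
#include <map>
#include <mutex>
#include <string>

template <typename K, typename V>
class ThreadSafeMap {
 public:
  V &operator[](const K &key) {
    std::lock_guard<std::mutex> lock(lock_);  // protects map_ structure only
    return map_[key];
  }

 private:
  std::mutex lock_;
  std::map<K, V> map_;
};

struct State {
  std::mutex mutex;  // per-value lock, taken by the caller
  int counter = 0;
};

int main() {
  ThreadSafeMap<std::string, State> states;
  State &s = states["channel-0"];           // map lock held only here
  std::lock_guard<std::mutex> guard(s.mutex);  // value guarded separately
  ++s.counter;
}
```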
diff --git a/stablehlo/reference/Tensor.h b/stablehlo/reference/Tensor.h
index 205977dd798..b4ea0c6ecfe 100644
--- a/stablehlo/reference/Tensor.h
+++ b/stablehlo/reference/Tensor.h
@@ -36,7 +36,7 @@ namespace stablehlo {
 namespace detail {
 
 /// Underlying storage class for Tensor objects.
-class Buffer : public llvm::RefCountedBase<Buffer> {
+class Buffer : public llvm::ThreadSafeRefCountedBase<Buffer> {
  public:
   /// \name Constructors
   /// @{
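The `Buffer` base-class swap complements the rendezvous change: once tensors flow between interpreter threads, their intrusive reference counts are bumped and dropped concurrently, so the count must be atomic. A minimal sketch of the difference, assuming LLVM's `llvm/ADT/IntrusiveRefCntPtr.h` is available to build against (the `Payload` type is hypothetical):

```cpp
// Sketch: ThreadSafeRefCountedBase keeps the embedded reference count
// atomic, so IntrusiveRefCntPtr copies and destructions are safe across
// threads; with plain RefCountedBase this program would race on the count.
#include "llvm/ADT/IntrusiveRefCntPtr.h"

#include <thread>
#include <vector>

struct Payload : llvm::ThreadSafeRefCountedBase<Payload> {
  int data = 42;
};

int main() {
  llvm::IntrusiveRefCntPtr<Payload> ptr(new Payload());
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i)
    threads.emplace_back([ptr] {  // each copy bumps the atomic count
      volatile int x = ptr->data;
      (void)x;
    });  // each destruction drops the count, possibly on another thread
  for (auto &t : threads) t.join();
}
```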
diff --git a/stablehlo/tests/interpret_all_gather.mlir b/stablehlo/tests/interpret_all_gather.mlir
index f206586d79b..97c1494ea67 100644
--- a/stablehlo/tests/interpret_all_gather.mlir
+++ b/stablehlo/tests/interpret_all_gather.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
diff --git a/stablehlo/tests/interpret_all_reduce.mlir b/stablehlo/tests/interpret_all_reduce.mlir
index 9a072339935..4581c2759e5 100644
--- a/stablehlo/tests/interpret_all_reduce.mlir
+++ b/stablehlo/tests/interpret_all_reduce.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
diff --git a/stablehlo/tests/interpret_all_to_all.mlir b/stablehlo/tests/interpret_all_to_all.mlir
index f83a4c28001..1665ba28322 100644
--- a/stablehlo/tests/interpret_all_to_all.mlir
+++ b/stablehlo/tests/interpret_all_to_all.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
diff --git a/stablehlo/tests/interpret_collective_permute.mlir b/stablehlo/tests/interpret_collective_permute.mlir
index db3b9140cd8..605f276f962 100644
--- a/stablehlo/tests/interpret_collective_permute.mlir
+++ b/stablehlo/tests/interpret_collective_permute.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
diff --git a/stablehlo/tests/interpret_outfeed.mlir b/stablehlo/tests/interpret_outfeed.mlir
index bebf281b38f..5d7d943ca32 100644
--- a/stablehlo/tests/interpret_outfeed.mlir
+++ b/stablehlo/tests/interpret_outfeed.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @distribution_ops {
   func.func public @outfeed(%inputs0 : tensor<2x2x2xi64>, %token : !stablehlo.token) -> !stablehlo.token {
diff --git a/stablehlo/tests/interpret_reduce_scatter.mlir b/stablehlo/tests/interpret_reduce_scatter.mlir
index b12619e8d61..3b05497adbc 100644
--- a/stablehlo/tests/interpret_reduce_scatter.mlir
+++ b/stablehlo/tests/interpret_reduce_scatter.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {