From d17a137a5d48916f6f58ca3dd18b791b0d99a0c5 Mon Sep 17 00:00:00 2001
From: Gunhyun Park
Date: Tue, 12 Sep 2023 10:48:52 -0700
Subject: [PATCH] Fix TSAN violations for Distribution ops (#1768)

The TSAN violation is due to the unprotected insert shown below; it is fixed
by guarding the shared rendezvous state with a mutex and by replacing
`llvm::RefCountedBase` with its thread-safe alternative,
`llvm::ThreadSafeRefCountedBase`.

The original implementation of `rendezvous` had some undesirable nuances. For
example, it is preferable to clear out the entry once all processes have read
the data, instead of letting the next distribution op clear out the map. The
original implementation is also incorrect, because there is no guarantee that
the size of `processGroup` matches that of the previous op calling
`rendezvous`:

```
if (channels_[channelKey].size() == processGroup.size())
  channels_[channelKey].clear();
channels_[channelKey].insert(processId, operand);
```

The proposed update reworks the `rendezvous` logic so that each process holds
a shared pointer to the result object; once all pointers go out of scope, the
result is automatically deleted.

The PR also re-enables tests for distribution ops.

closes #1755
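For orientation, here is a minimal, standalone sketch of the barrier pattern
this patch adopts. It is illustrative only: `Rendezvous`, `arrive`, and the
`int` payloads are placeholders, not StableHLO code. Each participant deposits
its value under a mutex; the last arrival publishes a `std::shared_ptr` to the
combined result and wakes the waiters, so every caller ends up sharing
ownership of a single result object:

```
#include <condition_variable>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>

// Illustrative one-shot rendezvous in the spirit of this patch. Each of
// `size` participants calls arrive(id, value); every caller receives a
// shared_ptr to the same map once all contributions are in.
class Rendezvous {
 public:
  explicit Rendezvous(std::size_t size) : size_(size) {}

  std::shared_ptr<const std::map<int, int>> arrive(int id, int value) {
    std::unique_lock<std::mutex> lock(mutex_);
    values_[id] = value;
    if (values_.size() == size_) {
      // Last arrival: publish the result and wake everyone else.
      result_ = std::make_shared<const std::map<int, int>>(std::move(values_));
      values_.clear();
      cv_.notify_all();
    } else {
      cv_.wait(lock, [&] { return result_ != nullptr; });
    }
    return result_;  // Shared ownership; freed when the last holder drops it.
  }

 private:
  std::size_t size_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::map<int, int> values_;
  std::shared_ptr<const std::map<int, int>> result_;
};
```

Unlike this one-shot toy, the patched `rendezvous` below must reuse the
per-channel state across ops, which is why the last contributor additionally
waits on `use_count()` and moves the pointer out of the shared state before
returning.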
---
 stablehlo/reference/Ops.cpp                   |  14 +--
 stablehlo/reference/Process.cpp               |   5 +-
 stablehlo/reference/Process.h                 |   5 +-
 stablehlo/reference/ProcessGrid.cpp           | 113 ++++++++++++------
 stablehlo/reference/ProcessGrid.h             |  91 +++++++++-----
 stablehlo/reference/Tensor.h                  |   2 +-
 stablehlo/tests/interpret_all_gather.mlir     |   3 +-
 stablehlo/tests/interpret_all_reduce.mlir     |   3 +-
 stablehlo/tests/interpret_all_to_all.mlir     |   3 +-
 .../tests/interpret_collective_permute.mlir   |   3 +-
 stablehlo/tests/interpret_outfeed.mlir        |   3 +-
 stablehlo/tests/interpret_reduce_scatter.mlir |   3 +-
 12 files changed, 157 insertions(+), 91 deletions(-)

diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 63dae171003..8875ad5e800 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -859,8 +859,8 @@ Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
           "Failed to find process group with process_id: (%d, %d)",
           process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   return evalConcatenateOp(groupOperands, allGatherDim, resultType);
 }
@@ -888,8 +888,8 @@ Tensor evalAllReduceOp(const Tensor &operand,
           "Failed to find process group with process_id: (%d, %d)",
           process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   Tensor result(resultType);
   for (auto resultIt = result.index_begin(); resultIt != result.index_end();
@@ -929,8 +929,8 @@ Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension,
           "Failed to find process group with process_id: (%d, %d)",
           process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands =
-      process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
+  auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
+                           ->getSortedTensors();
 
   SmallVector<Tensor> scatteredParts;
   for (const auto &groupOperand : groupOperands) {
@@ -1080,7 +1080,7 @@ Tensor evalCollectivePermuteOp(
     if (from != process->getId() && to != process->getId()) continue;
 
     auto rendezvousResult =
-        process->rendezvous(processGroup, channelId, operand);
+        *process->rendezvous(processGroup, channelId, operand);
     if (to != process->getId()) continue;
     result = rendezvousResult.lookup(from);
   }
diff --git a/stablehlo/reference/Process.cpp b/stablehlo/reference/Process.cpp
index a29ed5f96b3..f87583a613f 100644
--- a/stablehlo/reference/Process.cpp
+++ b/stablehlo/reference/Process.cpp
@@ -44,9 +44,8 @@ ProcessId Process::getId() { return id_; }
 
 void Process::outfeed(ArrayRef<Tensor> inputs) { grid_->outfeed(inputs); }
 
-RendezvousResult Process::rendezvous(ProcessGroup processGroup,
-                                     ChannelId channelId,
-                                     const Tensor &operand) {
+const std::shared_ptr<RendezvousResult> Process::rendezvous(
+    ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
   return grid_->rendezvous(processGroup, channelId, getId(), operand);
 }
 
diff --git a/stablehlo/reference/Process.h b/stablehlo/reference/Process.h
index 11da144fbe7..e40a031ee51 100644
--- a/stablehlo/reference/Process.h
+++ b/stablehlo/reference/Process.h
@@ -54,8 +54,9 @@ class Process {
   void outfeed(ArrayRef<Tensor> inputs);
 
   /// See `ProcessGrid::rendezvous`.
-  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
-                              const Tensor &operand);
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     const Tensor &operand);
 
  private:
   /// StableHLO `process_id`.
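A note on the call sites in Ops.cpp above: because `rendezvous` now returns a
`std::shared_ptr<RendezvousResult>`, ops either chain `->getSortedTensors()`
on the temporary pointer or dereference it to copy the result out, and in both
cases the temporary shared_ptr is destroyed at the end of the full expression.
That prompt release is what lets the grid reclaim the per-channel result. A
self-contained illustration of this lifetime (hypothetical `Result` and
`rendezvous` stand-ins, not StableHLO code):

```
#include <cassert>
#include <memory>
#include <vector>

// When a function returns shared_ptr<T> and the caller immediately chains
// ->member(), the temporary shared_ptr dies at the end of the statement, so
// no ownership of the result outlives it.
struct Result {
  std::vector<int> tensors;
  std::vector<int> getSortedTensors() const { return tensors; }
};

std::shared_ptr<Result> rendezvous(std::weak_ptr<Result> &observer) {
  auto r = std::make_shared<Result>(Result{{1, 2, 3}});
  observer = r;  // Lets the caller watch the object's lifetime.
  return r;
}

int main() {
  std::weak_ptr<Result> observer;
  auto tensors = rendezvous(observer)->getSortedTensors();
  // The temporary shared_ptr died at the semicolon above, releasing the
  // result even though we still hold the extracted tensors.
  assert(observer.expired());
  assert(tensors.size() == 3);
}
```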
diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp
index 5f221221dba..a6b38f50dea 100644
--- a/stablehlo/reference/ProcessGrid.cpp
+++ b/stablehlo/reference/ProcessGrid.cpp
@@ -59,7 +59,8 @@ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
 // RendezvousResult.
 //===----------------------------------------------------------------------===//
 
-void RendezvousResult::clear() { result_.clear(); }
+RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> result)
+    : result_(result) {}
 
 void RendezvousResult::insert(ProcessId processId, Tensor tensor) {
   result_[processId] = tensor;
@@ -76,7 +77,24 @@ SmallVector<Tensor> RendezvousResult::getSortedTensors() {
       llvm::map_range(result_, [](const auto &pair) { return pair.second; }));
 }
 
-size_t RendezvousResult::size() { return result_.size(); }
+//===----------------------------------------------------------------------===//
+// ThreadSafeMap.
+//===----------------------------------------------------------------------===//
+
+template <typename K, typename V>
+V &ProcessGrid::ThreadSafeMap<K, V>::operator[](const K &key) {
+  std::lock_guard<std::mutex> lock(lock_);
+  return map_[key];
+}
+
+//===----------------------------------------------------------------------===//
+// ThreadSafeQueue.
+//===----------------------------------------------------------------------===//
+
+void ProcessGrid::ThreadSafeQueue::push(ArrayRef<Tensor> inputs) {
+  std::lock_guard<std::mutex> lock(lock_);
+  queue_.emplace(inputs);
+}
 
 //===----------------------------------------------------------------------===//
 // ProcessGrid.
@@ -142,45 +160,72 @@ ProcessGroups ProcessGrid::flattenedIds(
   return processGroups;
 }
 
-std::mutex &ProcessGrid::getRendezvousLock(ProcessGroup processGroup,
-                                           ChannelId channelId) {
-  std::lock_guard<std::mutex> lock(rendezvousLock_);
-  std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
-  return channelLocks_[channelKey];
-}
+void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) { outfeed_.push(inputs); }
 
-void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) {
-  std::lock_guard<std::mutex> lock(outfeedLock_);
-  outfeed_.emplace(inputs);
-}
-
-RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,
-                                         ChannelId channelId,
-                                         ProcessId processId,
-                                         const Tensor &operand) {
+const std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
+    ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
+    const Tensor &operand) {
   std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
-  {
-    std::lock_guard<std::mutex> lock(
-        getRendezvousLock(processGroup, channelId));
-    if (channels_[channelKey].size() == processGroup.size())
-      channels_[channelKey].clear();
-
-    channels_[channelKey].insert(processId, operand);
-  }
-  {
-    std::unique_lock<std::mutex> lock(
-        getRendezvousLock(processGroup, channelId));
-    if (channels_[channelKey].size() == processGroup.size())
-      channelConditions_[channelKey].notify_all();
-
+  // The wait/notify logic below doesn't work for a single process.
+  if (processGroup.size() == 1)
+    return std::make_shared<RendezvousResult>(
+        RendezvousResult({std::pair{processId, operand}}));
+
+  auto &state = channels_[channelKey];
+
+  std::unique_lock<std::mutex> lock(state.mutex);
+  state.values[processId] = operand;
+
+  if (state.values.size() == processGroup.size()) {
+    // If `values` is full, all other processes are currently waiting. The
+    // last process to contribute moves the values into the result, then
+    // waits for each process to take a copy of the result before cleaning up
+    // the state for future computations in this process grid.
+    state.result = std::make_shared<RendezvousResult>(state.values);
+    state.values.clear();
+    channelConditions_[channelKey].notify_one();
+
+    // The last process to contribute waits until the rest of the processes
+    // have read the values.
     if (!channelConditions_[channelKey].wait_for(
             lock, std::chrono::seconds(3),
             [&] {
-              return channels_[channelKey].size() == processGroup.size();
+              return state.result.use_count() >=
+                     static_cast<int64_t>(processGroup.size());
             }))
-      llvm::report_fatal_error("rendezvous timed out");
+      llvm::report_fatal_error(
+          "rendezvous timed out: not all processes have read the result yet");
 
-    return channels_[channelKey];
+    if (state.result.use_count() > static_cast<int64_t>(processGroup.size()))
+      llvm::report_fatal_error(
+          "Each process should have only one shared access to the result.");
+
+    // Moving the result out of the state resets it to null, which signals
+    // the waiting processes that they may exit the function.
+    channelConditions_[channelKey].notify_one();
+    return std::move(state.result);
   }
+
+  // Wait for all processes to contribute values.
+  if (!channelConditions_[channelKey].wait_for(
+          lock, std::chrono::seconds(3),
+          [&] { return state.result != nullptr; }))
+    llvm::report_fatal_error(
+        "rendezvous timed out: not all processes have contributed yet");
+
+  // Copy the result out of the state before notifying.
+  auto result = state.result;
+  channelConditions_[channelKey].notify_one();
+
+  // Wait for the remaining processes to retrieve the result. In other words,
+  // wait until the last process to contribute exits the function.
+  if (!channelConditions_[channelKey].wait_for(
+          lock, std::chrono::seconds(3),
+          [&] { return state.result == nullptr; }))
+    llvm::report_fatal_error(
+        "rendezvous timed out: not all processes have retrieved the result "
+        "yet");
+
+  return result;
 }
 
 } // namespace stablehlo
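One subtlety in `rendezvous` above is worth spelling out. In general,
`std::shared_ptr::use_count()` is only advisory in multithreaded code, since a
count observed without synchronization can be stale by the time it is acted
on. The pattern here works (on my reading of the patch, not as a guarantee
from the standard) because every copy of `state.result` is taken while
`state.mutex` is held, and the condition variable re-acquires that same lock
before evaluating the predicate, so the waiter reliably observes the
increments it is waiting for. A minimal, single-threaded illustration of the
counting behavior itself:

```
#include <cassert>
#include <memory>

int main() {
  auto p = std::make_shared<int>(42);
  assert(p.use_count() == 1);
  {
    auto q = p;  // Copying bumps the count; in the patch this happens only
                 // while state.mutex is held.
    assert(p.use_count() == 2);
  }  // q is destroyed; the count drops back.
  assert(p.use_count() == 1);
}
```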
diff --git a/stablehlo/reference/ProcessGrid.h b/stablehlo/reference/ProcessGrid.h
index 33b8ae5d762..bcbc6512ec5 100644
--- a/stablehlo/reference/ProcessGrid.h
+++ b/stablehlo/reference/ProcessGrid.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <condition_variable>
 #include <cstdint>
 #include <map>
+#include <memory>
 #include <mutex>
 #include <queue>
 #include <utility>
@@ -69,8 +70,7 @@ class ProcessGroups : public SmallVector<ProcessGroup> {
 /// map-like API.
 class RendezvousResult {
  public:
-  /// Erases all elements in the map.
-  void clear();
+  RendezvousResult(std::map<ProcessId, Tensor> result);
 
   /// Iterates through the (ProcessId, Tensor) map entries and returns a
   /// vector of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
@@ -84,9 +84,6 @@ class RendezvousResult {
   /// `processId`. If key is not found, return an empty `Tensor`.
   Tensor lookup(ProcessId processId);
 
-  /// Returns the number of elements in the map.
-  size_t size();
-
  private:
   /// Internal map representation of the result of `ProcessGrid::rendezvous`.
   std::map<ProcessId, Tensor> result_;
@@ -134,16 +131,60 @@ class ProcessGrid {
   /// deadlock the interpreter.
   ///
   /// At the barrier, each StableHLO process contributes a tensor, and these
-  /// tensors are accumulated in `RendezvousResult` which is returned to all
-  /// callers once the barrier has been reached by all StableHLO processes.
-  RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
-                              ProcessId processId, const Tensor &operand);
+  /// tensors are accumulated in `RendezvousResult` whose shared pointer is
+  /// returned to all callers once the barrier has been reached by all
+  /// StableHLO processes.
+  const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
+                                                     ChannelId channelId,
+                                                     ProcessId processId,
+                                                     const Tensor &operand);
 
  private:
-  /// Obtain a mutex that is shared between all processes participating in
-  /// a call to `rendezvous` for a given combination of `processGroup` and
-  /// `channelId`.
-  std::mutex &getRendezvousLock(ProcessGroup processGroup, ChannelId channelId);
+  /// Internal storage used in `rendezvous` to manage concurrent access to
+  /// the shared resource. Processes contribute their data to `values`
+  /// concurrently. Once all processes have added their data, the data in
+  /// `values` is moved to `result`, which multiple processes can
+  /// concurrently read from.
+  struct RendezvousState {
+    /// Synchronization primitive used to manage concurrent access to this
+    /// object.
+    std::mutex mutex;
+    /// Internal storage used to store data contributed by the processes.
+    std::map<ProcessId, Tensor> values;
+    /// Shared pointer to the result of `rendezvous`.
+    std::shared_ptr<RendezvousResult> result;
+  };
+
+  /// Stores the result of `rendezvous` represented as a map that allows
+  /// concurrent access.
+  /// Each call to `rendezvous`, i.e. each combination of `processGroup` and
+  /// `channelId`, has its own key in the map. Within the implementation of
+  /// `rendezvous`, the value corresponding to this key is gradually
+  /// populated with tensors arriving from different processes in the process
+  /// group.
+  template <typename K, typename V>
+  class ThreadSafeMap {
+   public:
+    /// Returns a reference to the data associated with the `key`.
+    V &operator[](const K &key);
+
+   private:
+    /// Synchronization primitive used to manage concurrent access to the
+    /// map.
+    std::mutex lock_;
+    /// Internal storage used to implement `rendezvous`.
+    std::map<K, V> map_;
+  };
+
+  /// StableHLO `outfeed` represented as a queue that allows concurrent
+  /// access.
+  class ThreadSafeQueue {
+   public:
+    /// Add `inputs` to the end of the queue.
+    void push(ArrayRef<Tensor> inputs);
+
+   private:
+    /// Synchronization primitive used to manage concurrent access to the
+    /// queue.
+    std::mutex lock_;
+    /// Internal storage used to implement StableHLO `outfeed`.
+    std::queue<SmallVector<Tensor>> queue_;
+  };
 
   /// StableHLO `num_replicas`.
   const uint32_t numReplicas_;
@@ -151,28 +192,14 @@ class ProcessGrid {
   /// StableHLO `num_partitions`.
   const uint32_t numPartitions_;
 
-  /// StableHLO `outfeed` represented as a queue.
-  std::queue<SmallVector<Tensor>> outfeed_;
-
-  std::mutex outfeedLock_;
-
-  /// Synchronization primitive used to manage concurrent access to
-  /// `channelLocks_`.
-  std::mutex rendezvousLock_;
+  /// See `ThreadSafeQueue`.
+  ThreadSafeQueue outfeed_;
 
-  /// Internal storage used to implement `rendezvous`.
-  /// Each call to `rendezvous`, i.e. each combination of `processGroup` and
-  /// `channelId`, has its own key in the map.
-  /// Within the implementation of `rendezvous`, the value corresponding to
-  /// this key is gradually populated with tensors arriving from different
-  /// processes in the process group.
-  std::map<std::pair<ProcessGroup, ChannelId>, RendezvousResult> channels_;
-
-  /// Synchronization primitive used to manage concurrent access to
-  /// `channels_`.
-  std::map<std::pair<ProcessGroup, ChannelId>, std::mutex> channelLocks_;
+  /// See `ThreadSafeMap`.
+  ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, RendezvousState> channels_;
 
   /// Synchronization primitive used to manage concurrent access to
   /// `channels_`.
-  std::map<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
+  ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
       channelConditions_;
 };
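A design note on `ThreadSafeMap` above: `operator[]` holds the lock only while
locating or default-constructing the slot, then hands back a bare `V &`. That
is safe here because `std::map` is node-based, so insertions never invalidate
references to existing elements, and this map never erases entries; making the
value itself thread-safe is the value's own job (`RendezvousState` carries its
own `mutex`). A self-contained sketch of the idea, mirroring the declaration
above plus the out-of-line definition from ProcessGrid.cpp:

```
#include <map>
#include <mutex>

// The lock guards only the map structure (lookup and insertion). The
// returned reference remains valid after the lock is released because
// std::map never invalidates references on insert and no entry is erased.
template <typename K, typename V>
class ThreadSafeMap {
 public:
  V &operator[](const K &key) {
    std::lock_guard<std::mutex> lock(lock_);
    return map_[key];
  }

 private:
  std::mutex lock_;
  std::map<K, V> map_;
};
```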
diff --git a/stablehlo/reference/Tensor.h b/stablehlo/reference/Tensor.h
index 205977dd798..b4ea0c6ecfe 100644
--- a/stablehlo/reference/Tensor.h
+++ b/stablehlo/reference/Tensor.h
@@ -36,7 +36,7 @@ namespace stablehlo {
 namespace detail {
 
 /// Underlying storage class for Tensor objects.
-class Buffer : public llvm::RefCountedBase<Buffer> {
+class Buffer : public llvm::ThreadSafeRefCountedBase<Buffer> {
  public:
   /// \name Constructors
   /// @{
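Why this one-line change matters: `llvm::RefCountedBase` bumps a plain integer
count, so two threads retaining or releasing the same `Buffer` through
`llvm::IntrusiveRefCntPtr` race on that count, which is exactly the kind of
access TSAN flags. `llvm::ThreadSafeRefCountedBase` keeps the count in an
atomic instead. A minimal analogue of the idea, illustrative only and not
LLVM's actual implementation:

```
#include <atomic>

// Sketch of a thread-safe intrusive refcount base: Retain/Release adjust an
// atomic counter, so concurrent copies of intrusive pointers to the same
// object no longer race on the count.
template <typename Derived>
class ThreadSafeRefCountedBase {
 public:
  void Retain() const { refCount.fetch_add(1, std::memory_order_relaxed); }
  void Release() const {
    // The acq_rel ordering ensures the deletion is safely ordered after all
    // prior uses of the object on other threads.
    if (refCount.fetch_sub(1, std::memory_order_acq_rel) == 1)
      delete static_cast<const Derived *>(this);
  }

 protected:
  ThreadSafeRefCountedBase() = default;
  ~ThreadSafeRefCountedBase() = default;

 private:
  mutable std::atomic<int> refCount{0};
};

struct Buffer : ThreadSafeRefCountedBase<Buffer> {};

int main() {
  auto *b = new Buffer();
  b->Retain();   // First reference.
  b->Retain();   // Another thread could safely do this concurrently.
  b->Release();
  b->Release();  // Count hits zero; Buffer deletes itself.
}
```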
diff --git a/stablehlo/tests/interpret_all_gather.mlir b/stablehlo/tests/interpret_all_gather.mlir
index f206586d79b..97c1494ea67 100644
--- a/stablehlo/tests/interpret_all_gather.mlir
+++ b/stablehlo/tests/interpret_all_gather.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
diff --git a/stablehlo/tests/interpret_all_reduce.mlir b/stablehlo/tests/interpret_all_reduce.mlir
index 9a072339935..4581c2759e5 100644
--- a/stablehlo/tests/interpret_all_reduce.mlir
+++ b/stablehlo/tests/interpret_all_reduce.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
diff --git a/stablehlo/tests/interpret_all_to_all.mlir b/stablehlo/tests/interpret_all_to_all.mlir
index f83a4c28001..1665ba28322 100644
--- a/stablehlo/tests/interpret_all_to_all.mlir
+++ b/stablehlo/tests/interpret_all_to_all.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
diff --git a/stablehlo/tests/interpret_collective_permute.mlir b/stablehlo/tests/interpret_collective_permute.mlir
index db3b9140cd8..605f276f962 100644
--- a/stablehlo/tests/interpret_collective_permute.mlir
+++ b/stablehlo/tests/interpret_collective_permute.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
diff --git a/stablehlo/tests/interpret_outfeed.mlir b/stablehlo/tests/interpret_outfeed.mlir
index bebf281b38f..5d7d943ca32 100644
--- a/stablehlo/tests/interpret_outfeed.mlir
+++ b/stablehlo/tests/interpret_outfeed.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @distribution_ops {
   func.func public @outfeed(%inputs0 : tensor<2x2x2xi64>, %token : !stablehlo.token) -> !stablehlo.token {
diff --git a/stablehlo/tests/interpret_reduce_scatter.mlir b/stablehlo/tests/interpret_reduce_scatter.mlir
index b12619e8d61..3b05497adbc 100644
--- a/stablehlo/tests/interpret_reduce_scatter.mlir
+++ b/stablehlo/tests/interpret_reduce_scatter.mlir
@@ -1,5 +1,4 @@
-// RUN: echo 'DISABLED'
-// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
+// RUN: stablehlo-translate --interpret -split-input-file %s
 
 module @cross_replica {
   func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {