Fix TSAN violations for Distribution ops (#1768)
The TSAN violations are due to an unprotected insert in `rendezvous` (see the snippet below) and to the non-thread-safe reference counting of `llvm::RefCountedBase`; the latter is fixed by replacing it with its thread-safe alternative, `llvm::ThreadSafeRefCountedBase`.
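
For context, the following is a minimal standalone sketch (not the StableHLO `Buffer` class; the file name and build command are illustrative) of the kind of report ThreadSanitizer produces when a reference count is incremented without synchronization, and how an atomic count avoids it:

```
// refcount_race.cpp -- illustrative only; build with:
//   clang++ -std=c++17 -fsanitize=thread -pthread refcount_race.cpp
#include <atomic>
#include <thread>

struct RacyCounted {
  int refCount = 0;  // unprotected: concurrent increments are a data race
  void retain() { ++refCount; }
};

struct ThreadSafeCounted {
  std::atomic<int> refCount{0};  // atomic: increments are race-free
  void retain() { refCount.fetch_add(1, std::memory_order_relaxed); }
};

int main() {
  RacyCounted racy;
  std::thread t1([&] { racy.retain(); });
  std::thread t2([&] { racy.retain(); });  // TSAN reports a race here
  t1.join();
  t2.join();

  ThreadSafeCounted safe;
  std::thread t3([&] { safe.retain(); });
  std::thread t4([&] { safe.retain(); });  // no report: atomic ref count
  t3.join();
  t4.join();
  return 0;
}
```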

The original implementation of `rendezvous` also had some undesirable nuances. For example, it is preferable to clear out the map entry once all processes have read the data, instead of leaving it for the next distribution op to clear. The existing implementation is also incorrect, since there is no guarantee that the process group size is the same as that of the previous op calling `rendezvous`:
```
if (channels_[channelKey].size() == processGroup.size())
  channels_[channelKey].clear();

channels_[channelKey].insert(processId, operand);
```
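
To make the hazard concrete, here is a minimal standalone sketch (simplified stand-in types, not the interpreter sources): when a later op reuses the same channel key with a smaller process group, the size check above never fires, so stale entries from the previous op remain visible.

```
#include <cassert>
#include <cstddef>
#include <map>

int main() {
  std::map<int, int> channel;  // stand-in for channels_[channelKey]

  // Op A: a process group of size 4 fills the channel and is never cleared.
  for (int p = 0; p < 4; ++p) channel[p] = p;

  // Op B: a process group of size 2 reuses the same channel key.
  std::size_t groupSize = 2;
  if (channel.size() == groupSize) channel.clear();  // 4 != 2: never clears
  channel[0] = 42;  // op B contributes its own value

  // Stale entries from op A are still visible to op B.
  assert(channel.size() == 4);
  assert(channel.at(3) == 3);
  return 0;
}
```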

The proposed update reworks the `rendezvous` logic so that each process receives a shared pointer to the result object; once all of these pointers go out of scope, the result is automatically deleted.
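
As an illustration of that ownership model, here is a minimal standalone sketch (assumed, simplified types, not the actual `ProcessGrid` sources): every participant receives a `std::shared_ptr` to the same result object, and the object is destroyed as soon as the last participant releases its pointer, so no explicit cleanup by the next distribution op is needed.

```
#include <cassert>
#include <map>
#include <memory>
#include <vector>

struct Tensor {
  int value = 0;
};
using ProcessId = int;

struct RendezvousResult {
  std::map<ProcessId, Tensor> perProcess;
};

int main() {
  // The grid builds one result and hands a shared_ptr to every process.
  auto result = std::make_shared<RendezvousResult>();
  result->perProcess[0] = Tensor{1};
  result->perProcess[1] = Tensor{2};

  std::vector<std::shared_ptr<RendezvousResult>> participants(2, result);
  result.reset();  // the grid keeps no reference of its own

  assert(participants[0]->perProcess.size() == 2);
  participants.clear();  // last pointer released: the result is deleted here
  return 0;
}
```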

The PR also reenables tests for distribution ops.

closes #1755
ghpvnist authored Sep 12, 2023
1 parent 371f514 commit d17a137
Showing 12 changed files with 157 additions and 91 deletions.
14 changes: 7 additions & 7 deletions stablehlo/reference/Ops.cpp
@@ -859,8 +859,8 @@ Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
->getSortedTensors();

return evalConcatenateOp(groupOperands, allGatherDim, resultType);
}
@@ -888,8 +888,8 @@ Tensor evalAllReduceOp(const Tensor &operand,
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
->getSortedTensors();

Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
@@ -929,8 +929,8 @@ Tensor evalAllToAllOp(const Tensor &operand, Axis splitDimension,
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
auto groupOperands = process->rendezvous(*processGroup, channelId, operand)
->getSortedTensors();

SmallVector<Tensor> scatteredParts;
for (const auto &groupOperand : groupOperands) {
@@ -1080,7 +1080,7 @@ Tensor evalCollectivePermuteOp(
if (from != process->getId() && to != process->getId()) continue;

auto rendezvousResult =
process->rendezvous(processGroup, channelId, operand);
*process->rendezvous(processGroup, channelId, operand);
if (to != process->getId()) continue;
result = rendezvousResult.lookup(from);
}
5 changes: 2 additions & 3 deletions stablehlo/reference/Process.cpp
@@ -44,9 +44,8 @@ ProcessId Process::getId() { return id_; }

void Process::outfeed(ArrayRef<Tensor> inputs) { grid_->outfeed(inputs); }

RendezvousResult Process::rendezvous(ProcessGroup processGroup,
ChannelId channelId,
const Tensor &operand) {
const std::shared_ptr<RendezvousResult> Process::rendezvous(
ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
return grid_->rendezvous(processGroup, channelId, getId(), operand);
}

5 changes: 3 additions & 2 deletions stablehlo/reference/Process.h
@@ -54,8 +54,9 @@ class Process {
void outfeed(ArrayRef<Tensor> inputs);

/// See `ProcessGrid::rendezvous`.
RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
const Tensor &operand);
const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
ChannelId channelId,
const Tensor &operand);

private:
/// StableHLO `process_id`.
113 changes: 79 additions & 34 deletions stablehlo/reference/ProcessGrid.cpp
@@ -59,7 +59,8 @@ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
// RendezvousResult.
//===----------------------------------------------------------------------===//

void RendezvousResult::clear() { result_.clear(); }
RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> result)
: result_(result) {}

void RendezvousResult::insert(ProcessId processId, Tensor tensor) {
result_[processId] = tensor;
@@ -76,7 +77,24 @@ SmallVector<Tensor> RendezvousResult::getSortedTensors() {
llvm::map_range(result_, [](const auto &pair) { return pair.second; }));
}

size_t RendezvousResult::size() { return result_.size(); }
//===----------------------------------------------------------------------===//
// ThreadSafeMap.
//===----------------------------------------------------------------------===//

template <typename K, typename V>
V &ProcessGrid::ThreadSafeMap<K, V>::operator[](const K &key) {
std::lock_guard<std::mutex> lock(lock_);
return map_[key];
}

//===----------------------------------------------------------------------===//
// ThreadSafeQueue.
//===----------------------------------------------------------------------===//

void ProcessGrid::ThreadSafeQueue::push(ArrayRef<Tensor> inputs) {
std::lock_guard<std::mutex> lock(lock_);
queue_.emplace(inputs);
}

//===----------------------------------------------------------------------===//
// ProcessGrid.
@@ -142,45 +160,72 @@ ProcessGroups ProcessGrid::flattenedIds(
return processGroups;
}

std::mutex &ProcessGrid::getRendezvousLock(ProcessGroup processGroup,
ChannelId channelId) {
std::lock_guard<std::mutex> lock(rendezvousLock_);
std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
return channelLocks_[channelKey];
}
void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) { outfeed_.push(inputs); }

void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) {
std::lock_guard<std::mutex> lock(outfeedLock_);
outfeed_.emplace(inputs);
}

RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,
ChannelId channelId,
ProcessId processId,
const Tensor &operand) {
const std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
const Tensor &operand) {
std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
{
std::lock_guard<std::mutex> lock(
getRendezvousLock(processGroup, channelId));
if (channels_[channelKey].size() == processGroup.size())
channels_[channelKey].clear();

channels_[channelKey].insert(processId, operand);
}
{
std::unique_lock<std::mutex> lock(
getRendezvousLock(processGroup, channelId));
if (channels_[channelKey].size() == processGroup.size())
channelConditions_[channelKey].notify_all();

// Process wait/notify logic below doesn't work for single process.
if (processGroup.size() == 1)
return std::make_shared<RendezvousResult>(
RendezvousResult({std::pair{processId, operand}}));

auto &state = channels_[channelKey];

std::unique_lock<std::mutex> lock(state.mutex);
state.values[processId] = operand;

if (state.values.size() == processGroup.size()) {
// If values are full, that means all other processes are currently waiting.
// The last process to contribute moves the values into the result
// then waits for each process to return a copy of the result before
// cleaning up the state variable for future computations in this process
// grid.
state.result = std::make_shared<RendezvousResult>(state.values);
state.values.clear();
channelConditions_[channelKey].notify_one();

// The last process to contribute waits until the rest of the processes have
// read the values.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3), [&] {
return channels_[channelKey].size() == processGroup.size();
return state.result.use_count() >=
static_cast<int64_t>(processGroup.size());
}))
llvm::report_fatal_error("rendezvous timed out");
llvm::report_fatal_error(
"rendezvous timed out: not all processes have contributed yet");

return channels_[channelKey];
if (state.result.use_count() > static_cast<int64_t>(processGroup.size()))
llvm::report_fatal_error(
"Each process should have only one shared access to the result.");

// The last process to contribute takes the result from the state to allow
// the process that contributed last to exit the function.
channelConditions_[channelKey].notify_one();
return std::move(state.result);
}

// Wait for all processes to contribute values.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3),
[&] { return state.result != nullptr; }))
llvm::report_fatal_error(
"rendezvous timed out: not all process has received the results yet");

// Copy result from the state before notifying.
auto result = state.result;
channelConditions_[channelKey].notify_one();

// Wait for the remaining processes to have retrieved the result. In other
// words, wait until the last process to contribute exit the function.
if (!channelConditions_[channelKey].wait_for(
lock, std::chrono::seconds(3),
[&] { return state.result == nullptr; }))
llvm::report_fatal_error(
"rendezvous timed out: not all process has received the results yet");

return result;
}

} // namespace stablehlo
91 changes: 59 additions & 32 deletions stablehlo/reference/ProcessGrid.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <condition_variable>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
@@ -69,8 +70,7 @@ class ProcessGroups : public SmallVector<ProcessGroup> {
/// map-like API.
class RendezvousResult {
public:
/// Erases all elements in the map.
void clear();
RendezvousResult(std::map<ProcessId, Tensor> result);

/// Iterates through the (ProcessId, Tensor) map entries and returns a vector
/// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
@@ -84,9 +84,6 @@ class RendezvousResult {
/// `processId`. If key is not found, return an empty `Tensor`.
Tensor lookup(ProcessId processId);

/// Returns the number of elements in the map.
size_t size();

private:
/// Internal map representation of the result of `ProcessGrid::rendezvous`.
std::map<ProcessId, Tensor> result_;
@@ -134,45 +131,75 @@ class ProcessGrid {
/// deadlock the interpreter.
///
/// At the barrier, each StableHLO process contributes a tensor, and these
/// tensors are accumulated in `RendezvousResult` which is returned to all
/// callers once the barrier has been reached by all StableHLO processes.
RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
ProcessId processId, const Tensor &operand);
/// tensors are accumulated in `RendezvousResult` whose shared pointer is
/// returned to all callers once the barrier has been reached by all StableHLO
/// processes.
const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
ChannelId channelId,
ProcessId processId,
const Tensor &operand);

private:
/// Obtain a mutex that is shared between all processes participating in
/// a call to `rendezvous` for a given combination of `processGroup` and
/// `channelId`.
std::mutex &getRendezvousLock(ProcessGroup processGroup, ChannelId channelId);
/// Internal storage used in `rendezvous` to manage concurrent access to the
/// shared resource. Processes contribute their data to `values` concurrently.
/// Once all processes have added their data, the data in `values` is moved to
/// `result` that multiple processes can concurrently read from.
struct RendezvousState {
/// Synchronization primitive used to manage concurrent access to this
/// object.
std::mutex mutex;
/// Internal storage used to store data contributed by the processes.
std::map<ProcessId, Tensor> values;
/// Shared pointer to the result of `rendezvous`.
std::shared_ptr<RendezvousResult> result;
};

/// Stores the result of `rendezvous` represented as a map that allows
/// concurrent access.
/// Each call to `rendezvous`, i.e. each combination `processGroup` and
/// `channelId`, has its own key in the map. Within the implementation of
/// `rendezvous`, the value corresponding to this key is gradually populated
/// with tensors arriving from different processes in the process group.
template <typename K, typename V>
class ThreadSafeMap {
public:
/// Returns a reference to the data associated with the `key`.
V &operator[](const K &key);

private:
/// Synchronization primitive used to manage concurrent access to the map.
std::mutex lock_;
/// Internal storage used to implement `rendezvous`.
std::map<K, V> map_;
};

/// StableHLO `outfeed` represented as a queue that allows concurrent access.
class ThreadSafeQueue {
public:
/// Add `inputs` to the end of the queue.
void push(ArrayRef<Tensor> inputs);

private:
/// Synchronization primitive used to manage concurrent access to the queue.
std::mutex lock_;
/// Internal storage used to implement StableHLO `outfeed`.
std::queue<SmallVector<Tensor>> queue_;
};

/// StableHLO `num_replicas`.
const uint32_t numReplicas_;

/// StableHLO `num_partitions`.
const uint32_t numPartitions_;

/// StableHLO `outfeed` represented as a queue.
std::queue<SmallVector<Tensor>> outfeed_;

std::mutex outfeedLock_;

/// Synchronization primitive used to manage concurrent access to
/// `channelLocks_`.
std::mutex rendezvousLock_;
/// See `ThreadSafeQueue`.
ThreadSafeQueue outfeed_;

/// Internal storage used to implement `rendezvous`.
/// Each call to `rendezvous`, i.e. each combination `processGroup` and
/// `channelId`, has its own key in the map.
/// Within the implementation of `rendezvous`, the value corresponding to
/// this key is gradually populated with tensors arriving from different
/// processes in the process group.
std::map<std::pair<ProcessGroup, ChannelId>, RendezvousResult> channels_;

/// Synchronization primitive used to manage concurrent access to `channels_`.
std::map<std::pair<ProcessGroup, ChannelId>, std::mutex> channelLocks_;
/// See `ThreadSafeMap`.
ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, RendezvousState> channels_;

/// Synchronization primitive used to manage concurrent access to `channels_`.
std::map<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
ThreadSafeMap<std::pair<ProcessGroup, ChannelId>, std::condition_variable>
channelConditions_;
};

2 changes: 1 addition & 1 deletion stablehlo/reference/Tensor.h
@@ -36,7 +36,7 @@ namespace stablehlo {
namespace detail {

/// Underlying storage class for Tensor objects.
class Buffer : public llvm::RefCountedBase<Buffer> {
class Buffer : public llvm::ThreadSafeRefCountedBase<Buffer> {
public:
/// \name Constructors
/// @{
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_all_gather.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_all_reduce.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_all_to_all.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_collective_permute.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_outfeed.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @distribution_ops {
func.func public @outfeed(%inputs0 : tensor<2x2x2xi64>, %token : !stablehlo.token) -> !stablehlo.token {
3 changes: 1 addition & 2 deletions stablehlo/tests/interpret_reduce_scatter.mlir
@@ -1,5 +1,4 @@
// RUN: echo 'DISABLED'
// RUN-DISABLED: stablehlo-translate --interpret -split-input-file %s
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
