Skip to content

Commit

Permalink
Fix clang tidy, change constness of rendezvous results (#1778)
Browse files Browse the repository at this point in the history
These results should not be modified by threads once returned, thus the
pointer can be to const data. Making the shared pointer const doesn't
add much, just means that the thing being pointed at cant be changed,
which doesn't guarantee much.
  • Loading branch information
GleasonK authored Sep 19, 2023
1 parent 887e069 commit 7a4b80e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion stablehlo/reference/Process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ProcessId Process::getId() { return id_; }

void Process::outfeed(ArrayRef<Tensor> inputs) { grid_->outfeed(inputs); }

const std::shared_ptr<RendezvousResult> Process::rendezvous(
std::shared_ptr<RendezvousResult const> Process::rendezvous(
ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) {
return grid_->rendezvous(processGroup, channelId, getId(), operand);
}
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/reference/Process.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Process {
void outfeed(ArrayRef<Tensor> inputs);

/// See `ProcessGrid::rendezvous`.
const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
std::shared_ptr<RendezvousResult const> rendezvous(ProcessGroup processGroup,
ChannelId channelId,
const Tensor &operand);

Expand Down
8 changes: 4 additions & 4 deletions stablehlo/reference/ProcessGrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,20 @@ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
// RendezvousResult.
//===----------------------------------------------------------------------===//

RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> result)
RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> const &result)
: result_(result) {}

void RendezvousResult::insert(ProcessId processId, Tensor tensor) {
result_[processId] = tensor;
}

Tensor RendezvousResult::lookup(ProcessId processId) {
Tensor RendezvousResult::lookup(ProcessId processId) const {
auto it = result_.find(processId);
if (it != result_.end()) return it->second;
return {};
}

SmallVector<Tensor> RendezvousResult::getSortedTensors() {
SmallVector<Tensor> RendezvousResult::getSortedTensors() const {
return llvm::to_vector(
llvm::map_range(result_, [](const auto &pair) { return pair.second; }));
}
Expand Down Expand Up @@ -162,7 +162,7 @@ ProcessGroups ProcessGrid::flattenedIds(

void ProcessGrid::outfeed(ArrayRef<Tensor> inputs) { outfeed_.push(inputs); }

const std::shared_ptr<RendezvousResult> ProcessGrid::rendezvous(
std::shared_ptr<RendezvousResult const> ProcessGrid::rendezvous(
ProcessGroup processGroup, ChannelId channelId, ProcessId processId,
const Tensor &operand) {
std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/reference/ProcessGrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ class ProcessGroups : public SmallVector<ProcessGroup> {
/// map-like API.
class RendezvousResult {
public:
RendezvousResult(std::map<ProcessId, Tensor> result);
RendezvousResult(std::map<ProcessId, Tensor> const &result);

/// Iterates through the (ProcessId, Tensor) map entires and returns a vector
/// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
/// lexicographical order.
SmallVector<Tensor> getSortedTensors();
SmallVector<Tensor> getSortedTensors() const;

/// Inserts `tensor` into the map using the key `processId`.
void insert(ProcessId processId, Tensor tensor);

/// Iterates through the map and returns the value associated with the key
/// `processId`. If key is not found, return an empty `Tensor`.
Tensor lookup(ProcessId processId);
Tensor lookup(ProcessId processId) const;

private:
/// Internal map representation of the result of `ProcessGrid::rendezvous`.
Expand Down Expand Up @@ -134,7 +134,7 @@ class ProcessGrid {
/// tensors are accumulated in `RendezvousResult` whose shard pointer is
/// returned to all callers once the barrier has been reached by all StableHLO
/// processes.
const std::shared_ptr<RendezvousResult> rendezvous(ProcessGroup processGroup,
std::shared_ptr<RendezvousResult const> rendezvous(ProcessGroup processGroup,
ChannelId channelId,
ProcessId processId,
const Tensor &operand);
Expand Down

0 comments on commit 7a4b80e

Please sign in to comment.