diff --git a/stablehlo/reference/Process.cpp b/stablehlo/reference/Process.cpp index f87583a613f..fc408f257c7 100644 --- a/stablehlo/reference/Process.cpp +++ b/stablehlo/reference/Process.cpp @@ -44,7 +44,7 @@ ProcessId Process::getId() { return id_; } void Process::outfeed(ArrayRef inputs) { grid_->outfeed(inputs); } -const std::shared_ptr Process::rendezvous( +std::shared_ptr Process::rendezvous( ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) { return grid_->rendezvous(processGroup, channelId, getId(), operand); } diff --git a/stablehlo/reference/Process.h b/stablehlo/reference/Process.h index e40a031ee51..0651d2d651f 100644 --- a/stablehlo/reference/Process.h +++ b/stablehlo/reference/Process.h @@ -54,7 +54,7 @@ class Process { void outfeed(ArrayRef inputs); /// See `ProcessGrid::rendezvous`. - const std::shared_ptr rendezvous(ProcessGroup processGroup, + std::shared_ptr rendezvous(ProcessGroup processGroup, ChannelId channelId, const Tensor &operand); diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp index a6b38f50dea..c4c09bb1c38 100644 --- a/stablehlo/reference/ProcessGrid.cpp +++ b/stablehlo/reference/ProcessGrid.cpp @@ -59,20 +59,20 @@ std::optional ProcessGroups::findGroup(ProcessId processId) { // RendezvousResult. //===----------------------------------------------------------------------===// -RendezvousResult::RendezvousResult(std::map result) +RendezvousResult::RendezvousResult(std::map const &result) : result_(result) {} void RendezvousResult::insert(ProcessId processId, Tensor tensor) { result_[processId] = tensor; } -Tensor RendezvousResult::lookup(ProcessId processId) { +Tensor RendezvousResult::lookup(ProcessId processId) const { auto it = result_.find(processId); if (it != result_.end()) return it->second; return {}; } -SmallVector RendezvousResult::getSortedTensors() { +SmallVector RendezvousResult::getSortedTensors() const { return llvm::to_vector( llvm::map_range(result_, [](const auto &pair) { return pair.second; })); } @@ -162,7 +162,7 @@ ProcessGroups ProcessGrid::flattenedIds( void ProcessGrid::outfeed(ArrayRef inputs) { outfeed_.push(inputs); } -const std::shared_ptr ProcessGrid::rendezvous( +std::shared_ptr ProcessGrid::rendezvous( ProcessGroup processGroup, ChannelId channelId, ProcessId processId, const Tensor &operand) { std::pair channelKey(processGroup, channelId); diff --git a/stablehlo/reference/ProcessGrid.h b/stablehlo/reference/ProcessGrid.h index bcbc6512ec5..3ea26dc5197 100644 --- a/stablehlo/reference/ProcessGrid.h +++ b/stablehlo/reference/ProcessGrid.h @@ -70,19 +70,19 @@ class ProcessGroups : public SmallVector { /// map-like API. class RendezvousResult { public: - RendezvousResult(std::map result); + RendezvousResult(std::map const &result); /// Iterates through the (ProcessId, Tensor) map entires and returns a vector /// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in /// lexicographical order. - SmallVector getSortedTensors(); + SmallVector getSortedTensors() const; /// Inserts `tensor` into the map using the key `processId`. void insert(ProcessId processId, Tensor tensor); /// Iterates through the map and returns the value associated with the key /// `processId`. If key is not found, return an empty `Tensor`. - Tensor lookup(ProcessId processId); + Tensor lookup(ProcessId processId) const; private: /// Internal map representation of the result of `ProcessGrid::rendezvous`. @@ -134,7 +134,7 @@ class ProcessGrid { /// tensors are accumulated in `RendezvousResult` whose shard pointer is /// returned to all callers once the barrier has been reached by all StableHLO /// processes. - const std::shared_ptr rendezvous(ProcessGroup processGroup, + std::shared_ptr rendezvous(ProcessGroup processGroup, ChannelId channelId, ProcessId processId, const Tensor &operand);