From 7a4b80ed8da88fda3f8a7da1f29d45cb5c00929e Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 19 Sep 2023 17:03:54 -0400 Subject: [PATCH] Fix clang tidy, change constness of rendezvous results (#1778) These results should not be modified by threads once returned, thus the pointer can be to const data. Making the shared pointer const doesn't add much, just means that the thing being pointed at cant be changed, which doesn't guarantee much. --- stablehlo/reference/Process.cpp | 2 +- stablehlo/reference/Process.h | 2 +- stablehlo/reference/ProcessGrid.cpp | 8 ++++---- stablehlo/reference/ProcessGrid.h | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/stablehlo/reference/Process.cpp b/stablehlo/reference/Process.cpp index f87583a613f..fc408f257c7 100644 --- a/stablehlo/reference/Process.cpp +++ b/stablehlo/reference/Process.cpp @@ -44,7 +44,7 @@ ProcessId Process::getId() { return id_; } void Process::outfeed(ArrayRef inputs) { grid_->outfeed(inputs); } -const std::shared_ptr Process::rendezvous( +std::shared_ptr Process::rendezvous( ProcessGroup processGroup, ChannelId channelId, const Tensor &operand) { return grid_->rendezvous(processGroup, channelId, getId(), operand); } diff --git a/stablehlo/reference/Process.h b/stablehlo/reference/Process.h index e40a031ee51..0651d2d651f 100644 --- a/stablehlo/reference/Process.h +++ b/stablehlo/reference/Process.h @@ -54,7 +54,7 @@ class Process { void outfeed(ArrayRef inputs); /// See `ProcessGrid::rendezvous`. - const std::shared_ptr rendezvous(ProcessGroup processGroup, + std::shared_ptr rendezvous(ProcessGroup processGroup, ChannelId channelId, const Tensor &operand); diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp index a6b38f50dea..c4c09bb1c38 100644 --- a/stablehlo/reference/ProcessGrid.cpp +++ b/stablehlo/reference/ProcessGrid.cpp @@ -59,20 +59,20 @@ std::optional ProcessGroups::findGroup(ProcessId processId) { // RendezvousResult. //===----------------------------------------------------------------------===// -RendezvousResult::RendezvousResult(std::map result) +RendezvousResult::RendezvousResult(std::map const &result) : result_(result) {} void RendezvousResult::insert(ProcessId processId, Tensor tensor) { result_[processId] = tensor; } -Tensor RendezvousResult::lookup(ProcessId processId) { +Tensor RendezvousResult::lookup(ProcessId processId) const { auto it = result_.find(processId); if (it != result_.end()) return it->second; return {}; } -SmallVector RendezvousResult::getSortedTensors() { +SmallVector RendezvousResult::getSortedTensors() const { return llvm::to_vector( llvm::map_range(result_, [](const auto &pair) { return pair.second; })); } @@ -162,7 +162,7 @@ ProcessGroups ProcessGrid::flattenedIds( void ProcessGrid::outfeed(ArrayRef inputs) { outfeed_.push(inputs); } -const std::shared_ptr ProcessGrid::rendezvous( +std::shared_ptr ProcessGrid::rendezvous( ProcessGroup processGroup, ChannelId channelId, ProcessId processId, const Tensor &operand) { std::pair channelKey(processGroup, channelId); diff --git a/stablehlo/reference/ProcessGrid.h b/stablehlo/reference/ProcessGrid.h index bcbc6512ec5..3ea26dc5197 100644 --- a/stablehlo/reference/ProcessGrid.h +++ b/stablehlo/reference/ProcessGrid.h @@ -70,19 +70,19 @@ class ProcessGroups : public SmallVector { /// map-like API. class RendezvousResult { public: - RendezvousResult(std::map result); + RendezvousResult(std::map const &result); /// Iterates through the (ProcessId, Tensor) map entires and returns a vector /// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in /// lexicographical order. - SmallVector getSortedTensors(); + SmallVector getSortedTensors() const; /// Inserts `tensor` into the map using the key `processId`. void insert(ProcessId processId, Tensor tensor); /// Iterates through the map and returns the value associated with the key /// `processId`. If key is not found, return an empty `Tensor`. - Tensor lookup(ProcessId processId); + Tensor lookup(ProcessId processId) const; private: /// Internal map representation of the result of `ProcessGrid::rendezvous`. @@ -134,7 +134,7 @@ class ProcessGrid { /// tensors are accumulated in `RendezvousResult` whose shard pointer is /// returned to all callers once the barrier has been reached by all StableHLO /// processes. - const std::shared_ptr rendezvous(ProcessGroup processGroup, + std::shared_ptr rendezvous(ProcessGroup processGroup, ChannelId channelId, ProcessId processId, const Tensor &operand);