Skip to content

Commit

Permalink
[PJRT:Python] Pass key value store to XLA compilation options.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689698771
  • Loading branch information
Google-ML-Automation committed Oct 25, 2024
1 parent d881e8a commit f548642
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:computation_placer_hdr",
"//xla/service:hlo_cost_analysis",
"//xla/tsl/framework:allocator",
Expand Down
5 changes: 5 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient {
std::shared_ptr<KeyValueStoreInterface> kv_store,
std::shared_ptr<const GpuTopology> gpu_topology);

std::optional<std::shared_ptr<KeyValueStoreInterface>> key_value_store()
const override {
return kv_store_;
}

absl::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;

Expand Down
7 changes: 7 additions & 0 deletions xla/pjrt/pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_device_description.h"
Expand Down Expand Up @@ -544,6 +545,12 @@ class PjRtClient {
// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
virtual absl::string_view platform_version() const = 0;

// Returns the key value store used by the client.
virtual std::optional<std::shared_ptr<KeyValueStoreInterface>>
key_value_store() const {
return std::nullopt;
}

// Returns information about the underlying PJRT C API plugin if such a plugin
// is being used, otherwise returns nullopt.
virtual std::optional<PjRtPluginAttributes> plugin_attributes() const {
Expand Down
5 changes: 5 additions & 0 deletions xla/python/py_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ PyClient::CompileIfrtProgram(
*stats->bytes_limit);
}
}

if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) {
options.executable_build_options.set_key_value_store(
*pjrt_compatible_client->pjrt_client()->key_value_store());
}
}

std::unique_ptr<ifrt::LoadedExecutable> ifrt_loaded_executable;
Expand Down

0 comments on commit f548642

Please sign in to comment.