diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index 1d00186fbb49c..ff731242a73e1 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -216,6 +216,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:computation_placer_hdr", "//xla/service:hlo_cost_analysis", "//xla/tsl/framework:allocator", diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index 21b7d1e382be6..ccc6e5dde1702 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -201,6 +201,11 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { std::shared_ptr kv_store, std::shared_ptr gpu_topology); + std::optional> key_value_store() + const override { + return kv_store_; + } + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index f9e34f0056308..ddcfe417e0c6d 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -44,6 +44,7 @@ limitations under the License. #include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" @@ -544,6 +545,12 @@ class PjRtClient { // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). virtual absl::string_view platform_version() const = 0; + // Returns the key value store used by the client. + virtual std::optional> + key_value_store() const { + return std::nullopt; + } + // Returns information about the underlying PJRT C API plugin if such a plugin // is being used, otherwise returns nullopt. virtual std::optional plugin_attributes() const { diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index 93c6282ae57a8..b7dace3445d07 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -430,6 +430,11 @@ PyClient::CompileIfrtProgram( *stats->bytes_limit); } } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } } std::unique_ptr ifrt_loaded_executable;