Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly tell the GpuCompiler which stream to use from PJRT during the build step. #18705

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions xla/client/executable_build_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ se::DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
return device_allocator_;
}

// Registers the stream the compiler should use (e.g. for autotuning) during
// the build step. Ownership is NOT transferred; the caller must keep the
// stream alive for the duration of compilation. Returns *this so the setter
// can be chained in the usual builder style.
ExecutableBuildOptions& ExecutableBuildOptions::set_compute_stream(se::Stream* stream) {
  this->compute_stream_ = stream;
  return *this;
}

// Returns the stream previously registered via set_compute_stream(), or
// nullptr when no stream has been set.
se::Stream* ExecutableBuildOptions::compute_stream() const { return compute_stream_; }

ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
int device_ordinal) {
CHECK_GE(device_ordinal, 0);
Expand Down
6 changes: 6 additions & 0 deletions xla/client/executable_build_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace stream_executor {

// Forward-declared to avoid StreamExecutor dependency.
class DeviceMemoryAllocator;
class Stream;

} // namespace stream_executor

Expand Down Expand Up @@ -91,6 +92,10 @@ class ExecutableBuildOptions {
se::DeviceMemoryAllocator* allocator);
se::DeviceMemoryAllocator* device_allocator() const;

// If set, this specifies a stream that can be used for autotuning.
ExecutableBuildOptions& set_compute_stream(se::Stream* stream);
se::Stream* compute_stream() const;

// The number of replicas of this computation that are to be executed.
// Defaults to 1.
int num_replicas() const { return num_replicas_; }
Expand Down Expand Up @@ -287,6 +292,7 @@ class ExecutableBuildOptions {
std::optional<CompilationEnvironments> comp_envs_;
std::optional<DebugOptions> debug_options_;
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
se::Stream* compute_stream_ = nullptr;
int num_replicas_ = 1;
int num_partitions_ = 1;
bool use_spmd_partitioning_ = false;
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3482,6 +3482,8 @@ PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
build_options.set_device_ordinal(
addressable_devices.front()->local_hardware_id().value());
}
build_options.set_compute_stream(
device_state(build_options.device_ordinal()).compute_stream());
}
return extras;
}
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,7 @@ cc_library(
"//xla/hlo/ir:hlo_module_group",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/stream_executor:dnn",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
Expand Down
5 changes: 5 additions & 0 deletions xla/service/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/metrics_hook_interface.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/threadpool.h"
Expand Down Expand Up @@ -158,6 +159,10 @@ class Compiler {
std::optional<TargetConfig> target_config;

MultiProcessKeyValueStore key_value_store;

// If compute_stream is set, this is the stream used for all autotuning
// during compilation.
se::Stream* compute_stream = nullptr;
};

virtual ~Compiler() = default;
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2983,6 +2983,7 @@ xla_test(
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor:mock_stream",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:mock_gpu_executor",
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_h",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
Expand Down
21 changes: 15 additions & 6 deletions xla/service/gpu/autotuning/autotuner_compile_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ ENTRY main {
se::Platform* platform = PlatformUtil::GetDefaultPlatform().value();
TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
PlatformUtil::GetStreamExecutors(platform));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executors.at(0)->CreateStream());

AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
GetDebugOptionsForTest()};
AutotuneConfig autotune_config{
DeviceConfig{executors.at(0), nullptr, stream.get()},
GetDebugOptionsForTest()};

auto& root = *module->entry_computation()->root_instruction();

Expand Down Expand Up @@ -101,8 +104,11 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
PlatformUtil::GetStreamExecutors(platform));

AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
GetDebugOptionsForTest()};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executors.at(0)->CreateStream());
AutotuneConfig autotune_config{
DeviceConfig{executors.at(0), nullptr, stream.get()},
GetDebugOptionsForTest()};

auto& root = *module->entry_computation()->root_instruction();

Expand Down Expand Up @@ -154,8 +160,11 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::vector<se::StreamExecutor*> executors,
PlatformUtil::GetStreamExecutors(platform));

AutotuneConfig autotune_config{DeviceConfig{executors.at(0), nullptr},
GetDebugOptionsForTest()};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executors.at(0)->CreateStream());
AutotuneConfig autotune_config{
DeviceConfig{executors.at(0), nullptr, stream.get()},
GetDebugOptionsForTest()};

auto& root = *module->entry_computation()->root_instruction();

Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/autotuning/autotuner_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ struct DeviceConfig {
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
se::DeviceMemoryAllocator* allocator = nullptr; // may be null

se::Stream* compute_stream = nullptr;
};

struct DevicelessConfig {
Expand Down Expand Up @@ -177,7 +179,8 @@ class AutotuneConfig {

// Returns the stream to be used for autotuning measurements.
// NOTE(review): this span is a unified-diff hunk, not final source — the
// GetAllocator() line below is the PRE-change code being removed by this PR,
// and the two lines after it are its replacement, which reads the stream
// directly from the DeviceConfig instead of deriving it from the allocator.
absl::StatusOr<se::Stream*> GetStream() const {
// Only valid for a device-backed config; DevicelessConfig has no stream.
CHECK(std::holds_alternative<DeviceConfig>(config_));
// (removed by this PR:)
return GetAllocator()->GetStream(GetExecutor()->device_ordinal());
// (added by this PR:) may be nullptr if the caller never set compute_stream
// on the DeviceConfig — TODO confirm callers handle that.
se::Stream* stream = std::get<DeviceConfig>(config_).compute_stream;
return stream;
}

const se::GpuComputeCapability& GetGpuComputeCapability() const {
Expand Down
8 changes: 6 additions & 2 deletions xla/service/gpu/autotuning/conv_algorithm_picker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ ENTRY main {
PlatformUtil::GetStreamExecutors(platform));
ASSERT_GT(executors.size(), 0);
se::StreamExecutor* stream_exec = executors[0];
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec->CreateStream());

const se::GpuComputeCapability& cc = backend()
.default_stream_executor()
Expand All @@ -88,7 +90,7 @@ ENTRY main {
changed = false;
DebugOptions opts = DefaultDebugOptionsIgnoringFlags();

AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr, stream.get()}, opts};
TF_ASSERT_OK_AND_ASSIGN(changed,
RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
ASSERT_TRUE(changed);
Expand Down Expand Up @@ -200,7 +202,9 @@ ENTRY main {
ASSERT_TRUE(changed);

DebugOptions opts = DefaultDebugOptionsIgnoringFlags();
AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr}, opts};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec->CreateStream());
AutotuneConfig cfg{DeviceConfig{stream_exec, nullptr, stream.get()}, opts};
TF_ASSERT_OK_AND_ASSIGN(changed,
RunHloPass(GpuConvAlgorithmPicker(cfg), m.get()));
ASSERT_TRUE(changed);
Expand Down
13 changes: 10 additions & 3 deletions xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ TEST_F(CustomKernelFusionAutotunerTest, DontRunOnNonCustomFusions) {

HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
DebugOptions debug_options;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
backend().default_stream_executor()->CreateStream());

AutotuneConfig autotune_config =
AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
backend().memory_allocator(), stream.get()},
debug_options};
pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);

Expand Down Expand Up @@ -100,9 +103,11 @@ TEST_F(CustomKernelFusionAutotunerTest,

HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
DebugOptions debug_options;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
backend().default_stream_executor()->CreateStream());
AutotuneConfig autotune_config =
AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
backend().memory_allocator(), stream.get()},
debug_options};
pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);
ASSERT_TRUE(pipeline.Run(hlo_module.get()).ok());
Expand Down Expand Up @@ -131,9 +136,11 @@ TEST_F(CustomKernelFusionAutotunerTest,

HloPassPipeline pipeline("custom_kernel_fusion_autotuner");
DebugOptions debug_options;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
backend().default_stream_executor()->CreateStream());
AutotuneConfig autotune_config =
AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
backend().memory_allocator(), stream.get()},
debug_options};
pipeline.AddPass<CustomKernelFusionAutotuner>(autotune_config);

Expand Down
18 changes: 14 additions & 4 deletions xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ ENTRY main {
/*toolkit_version=*/stream_executor::SemanticVersion{12, 4, 0}),
module.get()));

AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec()->CreateStream());
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr, stream.get()},
debug_opts};
GemmAlgorithmPicker gpicker(cfg);
// Note that, we do not care if the algorithm index has been changed:
// the thing matters is the # of algorithms left after sorting out.
Expand Down Expand Up @@ -175,7 +178,10 @@ ENTRY main {
/*toolkit_version=*/stream_executor::SemanticVersion{12, 4, 0}),
module.get()));

AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, debug_opts};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec()->CreateStream());
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr, stream.get()},
debug_opts};
GemmAlgorithmPicker gpicker(cfg);
TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(gpicker, module.get()));
num_left2 = gpicker.num_algorithms_left();
Expand Down Expand Up @@ -208,7 +214,9 @@ ENTRY main {
m.get()));
changed = false;
DebugOptions opts;
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec()->CreateStream());
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr, stream.get()}, opts};
TF_ASSERT_OK_AND_ASSIGN(changed,
RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
ASSERT_TRUE(changed);
Expand Down Expand Up @@ -273,7 +281,9 @@ ENTRY main {
changed = false;

DebugOptions opts;
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr}, opts};
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
stream_exec()->CreateStream());
AutotuneConfig cfg{DeviceConfig{stream_exec(), nullptr, stream.get()}, opts};

TF_ASSERT_OK_AND_ASSIGN(changed,
RunHloPass(GemmAlgorithmPicker(cfg), m.get()));
Expand Down
36 changes: 25 additions & 11 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
Expand Down Expand Up @@ -192,8 +193,10 @@ class StatelessAutotunerTest : public HloTestBase {
ccc->set_major(compute_capability.major);
ccc->set_minor(compute_capability.minor);

static se::Stream* stream =
backend().default_stream_executor()->CreateStream().value().release();
DeviceConfig test_config{backend().default_stream_executor(),
backend().memory_allocator()};
backend().memory_allocator(), stream};
AutotuneConfig autotune_config{test_config, debug_options};
GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version,
debug_options, nullptr);
Expand All @@ -210,8 +213,12 @@ class StatelessAutotunerTest : public HloTestBase {
// Returns the config for the current device.
absl::StatusOr<std::vector<GemmFusionAutotunerImpl::BackendConfig>>
GetPossibleMatmulAutotuneConfigs(const HloModule& module) {
static se::Stream* stream =
backend().default_stream_executor()->CreateStream().value().release();

DeviceConfig device_config{backend().default_stream_executor(),
backend().memory_allocator()};
device_config.compute_stream = stream;
AutotuneConfig autotune_config{device_config, GetDebugOptionsForTest()};
GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(),
GetDebugOptionsForTest(), nullptr);
Expand Down Expand Up @@ -317,11 +324,14 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
tsl::port::MaxParallelism());
DebugOptions opts;
MultiProcessKeyValueStore key_value_store;
pipeline.AddPass<GemmFusionAutotuner>(
AutotuneConfig{DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
opts},
GetToolkitVersion(), &thread_pool, key_value_store);
static se::Stream* stream =
backend().default_stream_executor()->CreateStream().value().release();
DeviceConfig device_config{backend().default_stream_executor(),
backend().memory_allocator()};
device_config.compute_stream = stream;
pipeline.AddPass<GemmFusionAutotuner>(AutotuneConfig{device_config, opts},
GetToolkitVersion(), &thread_pool,
key_value_store);

RunAndFilecheckHloRewrite(
hlo, std::move(pipeline), expected, [](const HloModule* m) {
Expand Down Expand Up @@ -703,9 +713,12 @@ ENTRY main {
ParseAndReturnVerifiedModule(kHloText));

DebugOptions opts;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
backend().default_stream_executor()->CreateStream());

AutotuneConfig autotune_config{
DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
backend().memory_allocator(), stream.get()},
opts};
AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
*module->entry_computation()->root_instruction());
Expand Down Expand Up @@ -1254,11 +1267,12 @@ TEST_F(GemmFusionAutotunerTest, RewritesGemmFusionToCustomKernelFusion) {
std::unique_ptr<VerifiedHloModule> module =
ParseAndReturnVerifiedModule(kHlo).value();

static se::Stream* stream =
backend().default_stream_executor()->CreateStream().value().release();
DebugOptions opts;
AutotuneConfig autotune_config{
DeviceConfig{backend().default_stream_executor(),
backend().memory_allocator()},
opts};
DeviceConfig device_config{backend().default_stream_executor(),
backend().memory_allocator(), stream};
AutotuneConfig autotune_config{device_config, opts};
AutotuneCacheKey cache_key(autotune_config.GetModelStr(),
*module->entry_computation()->root_instruction());
TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override,
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/determinism_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/service/platform_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/gpu/mock_gpu_executor.h"
#include "xla/stream_executor/mock_stream.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tests/filecheck.h"
Expand Down Expand Up @@ -111,6 +112,9 @@ class DeterminismTest : public GpuCodegenTest {
TF_ASSERT_OK_AND_ASSIGN(stream_executor::Platform * default_platform,
PlatformUtil::GetDefaultPlatform());
stream_executor::gpu::MockGpuExecutor executor(default_platform, 0);
EXPECT_CALL(executor, CreateStream).WillRepeatedly([&]() {
return backend().default_stream_executor()->CreateStream();
});
EXPECT_CALL(executor, CreateEventBasedTimer).Times(0);
EXPECT_CALL(executor, GetDeviceDescription)
.WillRepeatedly([this]() -> const se::DeviceDescription& {
Expand Down
Loading
Loading