Skip to content

Commit

Permalink
[ROCm] Disable gemm triton fusions for ROCm, until autotuner is functional.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoranjovanovic-ns committed Oct 1, 2024
1 parent 686aa6d commit 805b73d
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ class TritonTest : public GpuCodegenTest {

class TritonGemmTest : public TritonTest {
public:
// Returns the compute capability reported by the device that backs the
// test backend's default stream executor.
se::GpuComputeCapability GetGpuComputeCapability() {
  const auto& device_description =
      backend().default_stream_executor()->GetDeviceDescription();
  return device_description.gpu_compute_capability();
}

void SetUp() override {
if (std::holds_alternative<se::RocmComputeCapability>(GetGpuComputeCapability())) {
GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled.";
}
}

DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
// Do not fall back to cuBLAS, we are testing Triton.
Expand Down
13 changes: 13 additions & 0 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ namespace {

class TritonGemmTest : public GpuCodegenTest {
public:
// Fetches the GPU compute capability of the default stream executor's device.
se::GpuComputeCapability GetGpuComputeCapability() {
  se::StreamExecutor* executor = backend().default_stream_executor();
  return executor->GetDeviceDescription().gpu_compute_capability();
}

// Triton GEMM fusions are disabled on ROCm, so this whole fixture is
// skipped when the device reports a ROCm compute capability.
// NOTE(review): base class SetUp() is not invoked — verify that is intended.
void SetUp() override {
  const auto gpu_cc = GetGpuComputeCapability();
  const bool is_rocm =
      std::get_if<se::RocmComputeCapability>(&gpu_cc) != nullptr;
  if (is_rocm) {
    GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled.";
  }
}

DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_cublas_fallback(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ class MixedTypeTest : public GpuCodegenTest,
.cuda_compute_capability();
}

// Returns the generic (CUDA-or-ROCm) compute capability of the device
// behind the default stream executor.
se::GpuComputeCapability GetGpuComputeCapability() {
  const auto& description =
      backend().default_stream_executor()->GetDeviceDescription();
  return description.gpu_compute_capability();
}

void SetUp() override {
if (std::holds_alternative<se::RocmComputeCapability>(GetGpuComputeCapability())) {
GTEST_SKIP() << "Related fusions are not performed on ROCm without Triton.";
}
}

DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
// We are testing Triton, remove cuBLAS fallback for these tests.
Expand Down
5 changes: 2 additions & 3 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,9 +1387,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);

if (debug_options.xla_gpu_enable_triton_gemm() &&
((cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
rocm_cc != nullptr)) {
(cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE))) {
pipeline.AddPass<GemvRewriter>();
pipeline.AddPass<GemmFusion>(gpu_version);
}
Expand Down
7 changes: 7 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ class GpuCompilerTest : public HloTestBase {
return tensorflow::down_cast<GpuCompiler*>(compiler)
->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info);
}

// Returns, by const reference, the compute capability of the device that
// backs the test backend's default stream executor.
const stream_executor::GpuComputeCapability& GpuComputeComp() {
  const auto& device_description =
      backend().default_stream_executor()->GetDeviceDescription();
  return device_description.gpu_compute_capability();
}
};

TEST_F(GpuCompilerTest, CompiledProgramsCount) {
Expand Down

0 comments on commit 805b73d

Please sign in to comment.