diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 9a2d2d6dbb752..e623f320cf293 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -99,6 +99,19 @@ class TritonTest : public GpuCodegenTest { class TritonGemmTest : public TritonTest { public: + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative(GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // Do not fall back to cuBLAS, we are testing Triton. diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index 039b0d2d1863c..3006a7a4b001c 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -28,6 +28,19 @@ namespace { class TritonGemmTest : public GpuCodegenTest { public: + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative(GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(false); diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index f946cc4257e5e..36bf13ff0c1a5 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -59,6 +59,19 @@ class MixedTypeTest : public GpuCodegenTest, .cuda_compute_capability(); } + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative(GetGpuComputeCapability())) { + GTEST_SKIP() << "Related fusions are not performed on ROCm without Triton."; + } + } + DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // We are testing Triton, remove cuBLAS fallback for these tests. diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 676ed086bd163..bd25afbdaa9a3 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1387,9 +1387,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const auto* rocm_cc = std::get_if(&gpu_version); if (debug_options.xla_gpu_enable_triton_gemm() && - ((cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || - rocm_cc != nullptr)) { + (cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE))) { pipeline.AddPass(); pipeline.AddPass(gpu_version); } diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index d95b411c7ac70..648eb0d198f1f 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -76,6 +76,13 @@ class GpuCompilerTest : public HloTestBase { return tensorflow::down_cast(compiler) ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } }; TEST_F(GpuCompilerTest, CompiledProgramsCount) {