diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index a74e228ffa72d3..07fb149d681d89 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1301,6 +1301,25 @@ cc_library( ], ) +cc_library( + name = "fusion_dispatch_pipeline", + srcs = ["fusion_dispatch_pipeline.cc"], + hdrs = ["fusion_dispatch_pipeline.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service/gpu/transforms:fusion_block_level_rewriter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "fusion_pipeline", srcs = ["fusion_pipeline.cc"], @@ -1367,6 +1386,7 @@ cc_library( ":cublas_cudnn", ":executable_proto_cc", ":execution_stream_assignment", + ":fusion_dispatch_pipeline", ":fusion_pipeline", ":gpu_constants", ":gpu_executable", @@ -1408,6 +1428,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "//xla/hlo/experimental/auto_sharding:auto_sharding_option", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", @@ -1607,6 +1628,7 @@ xla_test( backends = ["gpu"], data = ["gpu_compiler_test_autotune_db.textproto"], deps = [ + ":backend_configs_cc", ":gpu_compiler", ":gpu_hlo_schedule", ":metrics", diff --git a/xla/service/gpu/fusion_dispatch_pipeline.cc b/xla/service/gpu/fusion_dispatch_pipeline.cc new file mode 100644 index 00000000000000..435ad70c473282 --- /dev/null +++ b/xla/service/gpu/fusion_dispatch_pipeline.cc @@ -0,0 +1,138 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_dispatch_pipeline.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/MathExtras.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/pattern_matcher.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { + +namespace { + +namespace m = ::xla::match; + +bool IsSlowLoopTransposeFusion(const HloFusionInstruction* fusion) { + const HloInstruction* root = + fusion->fused_instructions_computation()->root_instruction(); + + bool is_loop_transpose_fusion = + fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && + root->opcode() == HloOpcode::kTranspose; + + if (!is_loop_transpose_fusion) { + return false; + } + + // The slow transposes are those when the minormost dimension in the input + // is neither the minormost nor the second minormost dimension in the output, + // and the output minormost dimension is swapped with the new minormost + // dimension. + int64_t rank = root->shape().rank(); + + // The transpose dimension grouper has run, so it should be enough to check + // that the minormost dimension's index within the result is smaller than + // rank - 2, and that the new minormost dimension is swapped with it. + // This only triggers for transposes with major-to-minor layout. + bool has_major_to_minor_layout = + LayoutUtil::IsMonotonicWithDim0Major(root->shape().layout()); + absl::Span transpose_dimensions = root->dimensions(); + int64_t result_minormost_dim_in_operand = transpose_dimensions.back(); + + return has_major_to_minor_layout && + transpose_dimensions[result_minormost_dim_in_operand] == rank - 1 && + transpose_dimensions[rank - 1] < rank - 2; +} + +// Pattern-matches slow loop fusions that can likely be handled better by +// Triton than by other emitters. +// TODO(b/370690811,b/372187266): generalize this to other slow transposes. +bool FusionWillBeHandledBetterByTriton( + const HloFusionInstruction* fusion, + const se::DeviceDescription& device_description) { + if (!IsSlowLoopTransposeFusion(fusion)) { + return false; + } + + const HloInstruction* root = + fusion->fused_instructions_computation()->root_instruction(); + + // Because of Triton's power-of-two restriction, we're only guaranteed to + // handle the bitcast case when the bitcast's minor dimension is a power of + // two. This ensures that we can tile it reasonably even if the bitcast's + // input has that dimension collapsed. (See comments in `symbolic_tile.cc` + // around destructuring summations to understand why this is important.) + auto can_bitcast_input_be_tiled_efficiently = + [](const HloInstruction* bitcast) { + return llvm::isPowerOf2_64(bitcast->shape().dimensions_minor(0)); + }; + + bool is_pure_transpose = ::xla::Match(root, m::Transpose(m::Parameter())); + bool is_bitcasted_transpose_with_power_of_two_minor_dim = ::xla::Match( + root, + m::Transpose(m::Bitcast(m::Parameter()) + .WithPredicate(can_bitcast_input_be_tiled_efficiently))); + return is_pure_transpose || + is_bitcasted_transpose_with_power_of_two_minor_dim; +} + +} // anonymous namespace + +HloPassPipeline FusionDispatchPipeline( + const se::DeviceDescription& device_description, + HloCostAnalysis::ShapeSizeFunction shape_size_fn) { + std::function(const HloFusionInstruction*)> + try_rewrite_fusion_if = + [&device_description]( + const HloFusionInstruction* fusion) -> absl::StatusOr { + bool should_always_rewrite_to_block_level = + fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_fusion_block_level_rewriter(); + + // TODO(b/370690811): this rewrite may no longer be necessary once MLIR + // emitters transposes are faster. + return should_always_rewrite_to_block_level || + FusionWillBeHandledBetterByTriton(fusion, device_description); + }; + + // Even though this is a single pass, we need to create a pipeline in order + // to make sure the pass's run is recorded in the `HloModuleMetadata`. + HloPassPipeline pipeline("fusion-dispatch-pipeline"); + pipeline.AddPass(device_description, shape_size_fn, + std::move(try_rewrite_fusion_if)); + return pipeline; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/fusion_dispatch_pipeline.h b/xla/service/gpu/fusion_dispatch_pipeline.h new file mode 100644 index 00000000000000..7256f9d2d9567d --- /dev/null +++ b/xla/service/gpu/fusion_dispatch_pipeline.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ +#define XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ + +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { + +// Returns a pipeline that attempts to redirect fusions to the most efficient +// emitter possible. +HloPassPipeline FusionDispatchPipeline( + const se::DeviceDescription& device_description, + HloCostAnalysis::ShapeSizeFunction shape_size_fn); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 7d742f96df9e75..82b282ec9bb201 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -61,7 +60,9 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LLVM.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -121,6 +122,7 @@ limitations under the License. #include "xla/service/gpu/conv_layout_normalization.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/execution_stream_assignment.h" +#include "xla/service/gpu/fusion_dispatch_pipeline.h" #include "xla/service/gpu/fusion_pipeline.h" #include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/gpu_executable.h" @@ -160,7 +162,6 @@ limitations under the License. #include "xla/service/gpu/transforms/dot_operand_converter.h" #include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" #include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" -#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h" #include "xla/service/gpu/transforms/gemm_fusion.h" @@ -1733,18 +1734,17 @@ absl::StatusOr> GpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); - // This needs to run after every pass affecting fusions, which includes - // `CopyFusion`, which itself must run in the `PrepareHloModuleForIrEmitting` - // pipeline. - if (module->config() - .debug_options() - .xla_gpu_experimental_enable_fusion_block_level_rewriter()) { - // Even though this is a single pass, we need to create a pipeline in order - // to make sure the pass's run is recorded in the `HloModuleMetadata`. - HloPassPipeline pipeline("fusion-block-level-rewriter-pipeline"); - pipeline.AddPass( - gpu_target_config.device_description, ShapeSizeBytesFunction()); - TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status()); + const auto* cuda_cc = std::get_if( + &gpu_target_config.device_description.gpu_compute_capability()); + if (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) { + // This needs to run after every pass affecting fusions, which includes + // `CopyFusion`, which itself must run in the + // `PrepareHloModuleForIrEmitting` pipeline. + TF_RETURN_IF_ERROR( + FusionDispatchPipeline(gpu_target_config.device_description, + ShapeSizeBytesFunction()) + .Run(module.get()) + .status()); } uint64_t end_usecs = tsl::Env::Default()->NowMicros(); diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 0883e5a0431ca5..4c0dc2795ba9a5 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/metrics.h" #include "xla/service/hlo_module_config.h" @@ -978,6 +979,104 @@ TEST_F(GpuCompilerTest, TestFlag_xla_gpu_unsafe_pipelined_loop_annotator) { EXPECT_TRUE(filecheck_matched); } +bool HasBlockLevelFusionConfig(const HloInstruction* fusion) { + return fusion->opcode() == HloOpcode::kFusion && + fusion->has_backend_config() && + fusion->backend_config().ok() && + fusion->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config(); +} + +TEST_F(GpuCompilerTest, + LoopFusionRootedInTransposeIsRewrittenToBlockLevelByDefaultPostAmpere) { + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + + constexpr absl::string_view transpose_fusion_module = R"( +transpose { + p0 = f32[1024,1024,1024] parameter(0) + ROOT transpose = f32[1024,1024,1024] transpose(p0), dimensions={2,1,0} +} + +ENTRY main { + p0 = f32[1024,1024,1024] parameter(0) + ROOT fusion = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=transpose +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(transpose_fusion_module)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(std::move(module))); + + if (cc.IsAtLeastAmpere()) { + EXPECT_TRUE(HasBlockLevelFusionConfig( + optimized_module->entry_computation()->root_instruction())); + } else { + EXPECT_FALSE(HasBlockLevelFusionConfig( + optimized_module->entry_computation()->root_instruction())); + } +} + +TEST_F( + GpuCompilerTest, + FusionBlockLevelRewriterRewritesKLoopTransposeWithBitcastIfTheSmallMinorDimIsAPowerOfTwo) { // NOLINT(whitespace/line_length) + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (!cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "FusionBlockLevelRewriter requires Ampere+ to run."; + } + + // If this test starts failing, then it's likely that this no longer generates + // a kLoop transpose. That's great---it probably means the rewrite in question + // is no longer necessary! + // + // The small minor dimension here is a power of two, so the rewrite should + // succeed. + constexpr absl::string_view rewritable_transpose_string = R"( +ENTRY main { + p0 = f32[1024,4096]{1,0} parameter(0) + reshape = f32[1024,1024,4]{2,1,0} reshape(p0) + ROOT transpose = f32[4,1024,1024]{2,1,0} transpose(reshape), dimensions={2,1,0} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rewritable_transpose_module, + ParseAndReturnVerifiedModule(rewritable_transpose_string)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rewritable_transpose_optimized_module, + GetOptimizedModule(std::move(rewritable_transpose_module))); + EXPECT_TRUE(HasBlockLevelFusionConfig( + rewritable_transpose_optimized_module->entry_computation() + ->root_instruction())); + + // The small minor dimension here is not a power of two, so the rewrite should + // fail. + constexpr absl::string_view unrewritable_transpose_string = R"( +ENTRY main { + p0 = f32[1024,6144]{1,0} parameter(0) + reshape = f32[1024,1024,6]{2,1,0} reshape(p0) + ROOT transpose = f32[6,1024,1024]{2,1,0} transpose(reshape), dimensions={2,1,0} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr unrewritable_transpose_module, + ParseAndReturnVerifiedModule(unrewritable_transpose_string)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr unrewritable_transpose_optimized_module, + GetOptimizedModule(std::move(unrewritable_transpose_module))); + EXPECT_FALSE(HasBlockLevelFusionConfig( + unrewritable_transpose_optimized_module->entry_computation() + ->root_instruction())); +} + using GpuCompilerPassTest = GpuCompilerTest; TEST_F(GpuCompilerPassTest, diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index f28812f3a8cdb8..d4f704912716ac 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -352,9 +352,9 @@ cc_library( "//xla/service/gpu/model:gpu_indexing_performance_model", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", diff --git a/xla/service/gpu/transforms/fusion_block_level_rewriter.cc b/xla/service/gpu/transforms/fusion_block_level_rewriter.cc index 2459cb216231f8..809183953d306d 100644 --- a/xla/service/gpu/transforms/fusion_block_level_rewriter.cc +++ b/xla/service/gpu/transforms/fusion_block_level_rewriter.cc @@ -135,11 +135,18 @@ absl::StatusOr FusionBlockLevelRewriter::Run( continue; } + HloFusionInstruction* fusion_instruction = + ::xla::Cast(computation->FusionInstruction()); + + TF_ASSIGN_OR_RETURN(bool should_try_rewrite, + should_try_rewrite_if_(fusion_instruction)); + if (!should_try_rewrite) { + continue; + } + TF_ASSIGN_OR_RETURN( - bool changed, - ProcessFusionInstruction( - ::xla::Cast(computation->FusionInstruction()), - device_info_, shape_size_, &ctx)); + bool changed, ProcessFusionInstruction(fusion_instruction, device_info_, + shape_size_, &ctx)); has_changed |= changed; } diff --git a/xla/service/gpu/transforms/fusion_block_level_rewriter.h b/xla/service/gpu/transforms/fusion_block_level_rewriter.h index 72cb17b098b8b9..6cf8f988242f97 100644 --- a/xla/service/gpu/transforms/fusion_block_level_rewriter.h +++ b/xla/service/gpu/transforms/fusion_block_level_rewriter.h @@ -16,9 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_BLOCK_LEVEL_REWRITER_H_ #define XLA_SERVICE_GPU_TRANSFORMS_FUSION_BLOCK_LEVEL_REWRITER_H_ +#include + #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/hlo_cost_analysis.h" @@ -31,8 +35,12 @@ class FusionBlockLevelRewriter : public HloModulePass { public: explicit FusionBlockLevelRewriter( const se::DeviceDescription& device_info, - HloCostAnalysis::ShapeSizeFunction shape_size) - : device_info_(device_info), shape_size_(shape_size) {} + HloCostAnalysis::ShapeSizeFunction shape_size, + absl::AnyInvocable(const HloFusionInstruction*)> + should_try_rewrite_if) + : device_info_(device_info), + shape_size_(shape_size), + should_try_rewrite_if_(std::move(should_try_rewrite_if)) {} absl::string_view name() const override { return "fusion-block-level-rewriter"; @@ -46,6 +54,8 @@ class FusionBlockLevelRewriter : public HloModulePass { private: const se::DeviceDescription& device_info_; HloCostAnalysis::ShapeSizeFunction shape_size_; + absl::AnyInvocable(const HloFusionInstruction*)> + should_try_rewrite_if_; }; } // namespace gpu diff --git a/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc b/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc index 37cb04445b5ea8..58da2aae00b9e3 100644 --- a/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc +++ b/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -72,6 +73,10 @@ class FusionBlockLevelRewriterTest : public HloTestBase { se::CudaComputeCapability::Ampere())}; }; +bool RewriteEverythingPossible(const HloFusionInstruction* fusion) { + return true; +} + TEST_F(FusionBlockLevelRewriterTest, DoesNotRewriteFusionThatIsAlreadyBlockLevel) { const absl::string_view hlo_text = R"( @@ -88,7 +93,8 @@ ENTRY entry { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction()) + EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction(), + RewriteEverythingPossible) .Run(module.get()), IsOkAndHolds(false)); } @@ -107,7 +113,9 @@ ENTRY entry { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction()) + + EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction(), + RewriteEverythingPossible) .Run(module.get()), IsOkAndHolds(true)); const HloInstruction* root = module->entry_computation()->root_instruction(); @@ -136,7 +144,8 @@ ENTRY entry { ASSERT_FALSE(std::holds_alternative( SymbolicTileAnalysis::AnalyzeComputation( *module->GetComputationWithName("fusion_computation"), &ctx))); - EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction()) + EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction(), + RewriteEverythingPossible) .Run(module.get()), IsOkAndHolds(false)); } @@ -159,7 +168,8 @@ ENTRY entry { ASSERT_FALSE(IsTritonSupportedComputation( *module->GetComputationWithName("fusion_computation"), device_info_.gpu_compute_capability())); - EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction()) + EXPECT_THAT(FusionBlockLevelRewriter(device_info_, ShapeSizeBytesFunction(), + RewriteEverythingPossible) .Run(module.get()), IsOkAndHolds(false)); }