[XLA:GPU] Redirect some currently slow fusions to use the Triton emitter if possible.

Typically, when we choose to use a `kLoop` fusion to emit a transpose, we're
making a bad decision. This change identifies a class of especially slow fusions
and dispatches them to the Triton emitter if they can be both tiled and
codegen'd correctly.

PiperOrigin-RevId: 683112912
bchetioui authored and Google-ML-Automation committed Oct 10, 2024
1 parent d05087b commit 22afbbc
Showing 9 changed files with 347 additions and 25 deletions.
22 changes: 22 additions & 0 deletions xla/service/gpu/BUILD
@@ -1301,6 +1301,25 @@ cc_library(
],
)

cc_library(
name = "fusion_dispatch_pipeline",
srcs = ["fusion_dispatch_pipeline.cc"],
hdrs = ["fusion_dispatch_pipeline.h"],
deps = [
"//xla:shape_util",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass_pipeline",
"//xla/service:hlo_cost_analysis",
"//xla/service:pattern_matcher",
"//xla/service/gpu/transforms:fusion_block_level_rewriter",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
)

cc_library(
name = "fusion_pipeline",
srcs = ["fusion_pipeline.cc"],
@@ -1367,6 +1386,7 @@ cc_library(
":cublas_cudnn",
":executable_proto_cc",
":execution_stream_assignment",
":fusion_dispatch_pipeline",
":fusion_pipeline",
":gpu_constants",
":gpu_executable",
@@ -1408,6 +1428,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"//xla/hlo/experimental/auto_sharding:auto_sharding_option",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_module_group",
"//xla/hlo/pass:hlo_pass",
@@ -1607,6 +1628,7 @@ xla_test(
backends = ["gpu"],
data = ["gpu_compiler_test_autotune_db.textproto"],
deps = [
":backend_configs_cc",
":gpu_compiler",
":gpu_hlo_schedule",
":metrics",
138 changes: 138 additions & 0 deletions xla/service/gpu/fusion_dispatch_pipeline.cc
@@ -0,0 +1,138 @@
/* Copyright 2023 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/fusion_dispatch_pipeline.h"

#include <cstdint>
#include <functional>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/Support/MathExtras.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/layout_util.h"
#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/pattern_matcher.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla.pb.h"

namespace xla {
namespace gpu {

namespace {

namespace m = ::xla::match;

bool IsSlowLoopTransposeFusion(const HloFusionInstruction* fusion) {
const HloInstruction* root =
fusion->fused_instructions_computation()->root_instruction();

bool is_loop_transpose_fusion =
fusion->fusion_kind() == HloInstruction::FusionKind::kLoop &&
root->opcode() == HloOpcode::kTranspose;

if (!is_loop_transpose_fusion) {
return false;
}

// The slow transposes are those where the minormost dimension of the input
// is neither the minormost nor the second minormost dimension of the output,
// and the output's minormost dimension is swapped with the new minormost
// dimension.
int64_t rank = root->shape().rank();

// The transpose dimension grouper has run, so it should be enough to check
// that the minormost dimension's index within the result is smaller than
// rank - 2, and that the new minormost dimension is swapped with it.
// This only triggers for transposes with major-to-minor layout.
bool has_major_to_minor_layout =
LayoutUtil::IsMonotonicWithDim0Major(root->shape().layout());
absl::Span<int64_t const> transpose_dimensions = root->dimensions();
int64_t result_minormost_dim_in_operand = transpose_dimensions.back();
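// For example, a rank-3 transpose with dimensions={2,1,0} satisfies both
// conditions below: dimensions[2] == 0 is smaller than rank - 2, and
// dimensions[0] == rank - 1, i.e. the input's minormost dimension and the
// dimension that becomes the output's minormost swap places. By contrast,
// dimensions={0,2,1} does not qualify, since its output minormost dimension
// comes from the second minormost input dimension.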

return has_major_to_minor_layout &&
transpose_dimensions[result_minormost_dim_in_operand] == rank - 1 &&
transpose_dimensions[rank - 1] < rank - 2;
}

// Pattern-matches slow loop fusions that can likely be handled better by
// Triton than by other emitters.
// TODO(b/370690811,b/372187266): generalize this to other slow transposes.
bool FusionWillBeHandledBetterByTriton(
const HloFusionInstruction* fusion,
const se::DeviceDescription& device_description) {
if (!IsSlowLoopTransposeFusion(fusion)) {
return false;
}

const HloInstruction* root =
fusion->fused_instructions_computation()->root_instruction();

// Because of Triton's power-of-two restriction, we're only guaranteed to
// handle the bitcast case when the bitcast's minor dimension is a power of
// two. This ensures that we can tile it reasonably even if the bitcast's
// input has that dimension collapsed. (See comments in `symbolic_tile.cc`
// around destructuring summations to understand why this is important.)
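// For example, a bitcast producing a (hypothetical) shape of f32[16,256]
// would be accepted, since its minormost dimension 256 is a power of two,
// while a bitcast producing f32[16,384] would be rejected.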
auto can_bitcast_input_be_tiled_efficiently =
[](const HloInstruction* bitcast) {
return llvm::isPowerOf2_64(bitcast->shape().dimensions_minor(0));
};

bool is_pure_transpose = ::xla::Match(root, m::Transpose(m::Parameter()));
bool is_bitcasted_transpose_with_power_of_two_minor_dim = ::xla::Match(
root,
m::Transpose(m::Bitcast(m::Parameter())
.WithPredicate(can_bitcast_input_be_tiled_efficiently)));
return is_pure_transpose ||
is_bitcasted_transpose_with_power_of_two_minor_dim;
}

} // anonymous namespace

HloPassPipeline FusionDispatchPipeline(
const se::DeviceDescription& device_description,
HloCostAnalysis::ShapeSizeFunction shape_size_fn) {
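// Callback handed to `FusionBlockLevelRewriter` below; returning true means
// that the fusion should be rewritten into a block-level fusion and handled
// by the Triton emitter.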
std::function<absl::StatusOr<bool>(const HloFusionInstruction*)>
try_rewrite_fusion_if =
[&device_description](
const HloFusionInstruction* fusion) -> absl::StatusOr<bool> {
bool should_always_rewrite_to_block_level =
fusion->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_enable_fusion_block_level_rewriter();

// TODO(b/370690811): this rewrite may no longer be necessary once the MLIR
// emitters' transposes are faster.
return should_always_rewrite_to_block_level ||
FusionWillBeHandledBetterByTriton(fusion, device_description);
};

// Even though this is a single pass, we need to create a pipeline in order
// to make sure the pass's run is recorded in the `HloModuleMetadata`.
HloPassPipeline pipeline("fusion-dispatch-pipeline");
pipeline.AddPass<FusionBlockLevelRewriter>(device_description, shape_size_fn,
std::move(try_rewrite_fusion_if));
return pipeline;
}

} // namespace gpu
} // namespace xla
36 changes: 36 additions & 0 deletions xla/service/gpu/fusion_dispatch_pipeline.h
@@ -0,0 +1,36 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_
#define XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_

#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla.pb.h"

namespace xla {
namespace gpu {

// Returns a pipeline that attempts to redirect fusions to the most efficient
// emitter possible.
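//
// A minimal usage sketch (assuming a populated `device_description` and
// `shape_size_fn`, mirroring the call site in gpu_compiler.cc):
//
//   HloPassPipeline pipeline =
//       FusionDispatchPipeline(device_description, shape_size_fn);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module));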
HloPassPipeline FusionDispatchPipeline(
const se::DeviceDescription& device_description,
HloCostAnalysis::ShapeSizeFunction shape_size_fn);

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_
28 changes: 14 additions & 14 deletions xla/service/gpu/gpu_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <memory>
#include <new>
#include <optional>
#include <string>
#include <string_view>
@@ -61,7 +60,9 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LLVM.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
@@ -121,6 +122,7 @@ limitations under the License.
#include "xla/service/gpu/conv_layout_normalization.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/execution_stream_assignment.h"
#include "xla/service/gpu/fusion_dispatch_pipeline.h"
#include "xla/service/gpu/fusion_pipeline.h"
#include "xla/service/gpu/fusions/triton/triton_support.h"
#include "xla/service/gpu/gpu_executable.h"
@@ -160,7 +162,6 @@ limitations under the License.
#include "xla/service/gpu/transforms/dot_operand_converter.h"
#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h"
#include "xla/service/gpu/transforms/fusion_wrapper.h"
#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
@@ -1733,18 +1734,17 @@ absl::StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(

TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get()));

// This needs to run after every pass affecting fusions, which includes
// `CopyFusion`, which itself must run in the `PrepareHloModuleForIrEmitting`
// pipeline.
if (module->config()
.debug_options()
.xla_gpu_experimental_enable_fusion_block_level_rewriter()) {
// Even though this is a single pass, we need to create a pipeline in order
// to make sure the pass's run is recorded in the `HloModuleMetadata`.
HloPassPipeline pipeline("fusion-block-level-rewriter-pipeline");
pipeline.AddPass<FusionBlockLevelRewriter>(
gpu_target_config.device_description, ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(pipeline.Run(module.get()).status());
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
&gpu_target_config.device_description.gpu_compute_capability());
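// Gate the dispatch pipeline on Ampere-or-newer CUDA GPUs, where the
// Triton-backed block-level emitter it may dispatch to is supported.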
if (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) {
// This needs to run after every pass affecting fusions, which includes
// `CopyFusion`, which itself must run in the
// `PrepareHloModuleForIrEmitting` pipeline.
TF_RETURN_IF_ERROR(
FusionDispatchPipeline(gpu_target_config.device_description,
ShapeSizeBytesFunction())
.Run(module.get())
.status());
}

uint64_t end_usecs = tsl::Env::Default()->NowMicros();