Skip to content

Commit

Permalink
[XLA:GPU] Deprecate xla_gpu_enable_priority_fusion flag.
Browse files Browse the repository at this point in the history
This flag has been enabled by default since Feb 2024, so all users should have migrated to Priority Fusion by now.

PiperOrigin-RevId: 689673947
  • Loading branch information
olegshyshkov authored and Google-ML-Automation committed Oct 25, 2024
1 parent 31e7e36 commit 318aeb1
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 32 deletions.
9 changes: 3 additions & 6 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_exhaustive_tiling_search(false);

opts.set_xla_gpu_enable_priority_fusion(true);
opts.set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(false);
opts.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(false);

Expand Down Expand Up @@ -1619,11 +1618,9 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_gpu_exhaustive_tiling_search),
debug_options->xla_gpu_exhaustive_tiling_search(),
"Enable (slow) search for the Triton GEMM fusion tilings."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_priority_fusion",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_priority_fusion),
debug_options->xla_gpu_enable_priority_fusion(),
"Enable priority queue for fusion order."));
flag_list->push_back(tsl::Flag("xla_gpu_enable_priority_fusion",
noop_flag_setter<bool>, true,
"[Deprecated, do not use]"));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_triton_heroless_priority_fusion",
bool_setter_for(
Expand Down
23 changes: 8 additions & 15 deletions xla/service/gpu/fusion_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,14 @@ HloPassPipeline FusionPipeline(
std::make_unique<CpuGpuVerifierMetadata>(std::move(opts)),
"hlo verifier (debug)");

if (debug_options.xla_gpu_enable_priority_fusion()) {
GpuHloCostAnalysis::Options cost_analysis_options{
shape_size_bytes_function,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
fusion.AddPass<PriorityFusion>(thread_pool, gpu_device_info,
std::move(cost_analysis_options));
} else {
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false,
gpu_device_info);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true,
gpu_device_info);
fusion.AddPass<FusionMerger>(gpu_device_info, shape_size_bytes_function);
}
GpuHloCostAnalysis::Options cost_analysis_options{
shape_size_bytes_function,
/*per_second_rates=*/{},
/*min_latencies_seconds=*/{},
/*count_multiple_input_accesses=*/true};
fusion.AddPass<PriorityFusion>(thread_pool, gpu_device_info,
std::move(cost_analysis_options));

// Running CSE affects how many users an op has. This plays a role in what
// we detect as a tiled transpose fusion.
fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
Expand Down
7 changes: 2 additions & 5 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1565,11 +1565,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
}

pipeline.AddPass<ReductionDimensionGrouper>();
// Do not split small reduction dimensions unless priority fusion is
// enabled, which handles such cases well.
bool ignore_small_reduce_dims =
!debug_options.xla_gpu_enable_priority_fusion();
pipeline.AddPass<HloPassFix<ReductionSplitter>>(ignore_small_reduce_dims);
pipeline.AddPass<HloPassFix<ReductionSplitter>>(
/*ignore_small_reduce_dims=*/false);
pipeline.AddPass<HloPassFix<TreeReductionRewriter>>(gpu_version);
// Normalization passes might have introduced s4 tensors without bit width
// annotations, this pass will add the annotations.
Expand Down
4 changes: 1 addition & 3 deletions xla/service/gpu/model/gpu_performance_model_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ struct GpuPerformanceModelOptions {
}

static GpuPerformanceModelOptions ForModule(const HloModule* module) {
return module->config().debug_options().xla_gpu_enable_priority_fusion()
? PriorityFusion() // Only cache within priority fusion.
: Default();
return PriorityFusion();
}
};

Expand Down
5 changes: 2 additions & 3 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,6 @@ message DebugOptions {

reserved 220; // Was xla_gpu_enable_triton_softmax_fusion

bool xla_gpu_enable_priority_fusion = 221;

reserved 286; // Was xla_gpu_enable_triton_softmax_priority_fusion

// File to write autotune results to. It will be a binary file unless the name
Expand Down Expand Up @@ -1056,7 +1054,8 @@ message DebugOptions {
// xla_gpu_single_wave_autotuning
// xla_gpu_enable_persistent_temp_buffers
// xla_gpu_enable_triton_gemm_int4
reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206, 320;
// xla_gpu_enable_priority_fusion
reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320;
}

// Contains flags which affects the GPU compilation result.
Expand Down

0 comments on commit 318aeb1

Please sign in to comment.