diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index dec9398e3dd94..c2d7bc263f380 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -205,7 +205,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_exhaustive_tiling_search(false); - opts.set_xla_gpu_enable_priority_fusion(true); opts.set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(false); opts.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(false); @@ -1619,11 +1618,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_exhaustive_tiling_search), debug_options->xla_gpu_exhaustive_tiling_search(), "Enable (slow) search for the Triton GEMM fusion tilings.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_priority_fusion", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_priority_fusion), - debug_options->xla_gpu_enable_priority_fusion(), - "Enable priority queue for fusion order.")); + flag_list->push_back(tsl::Flag("xla_gpu_enable_priority_fusion", + noop_flag_setter, true, + "[Deprecated, do not use]")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_enable_triton_heroless_priority_fusion", bool_setter_for( diff --git a/xla/service/gpu/fusion_pipeline.cc b/xla/service/gpu/fusion_pipeline.cc index bf18a2c9413bd..4b56e92a952c2 100644 --- a/xla/service/gpu/fusion_pipeline.cc +++ b/xla/service/gpu/fusion_pipeline.cc @@ -59,21 +59,14 @@ HloPassPipeline FusionPipeline( std::make_unique(std::move(opts)), "hlo verifier (debug)"); - if (debug_options.xla_gpu_enable_priority_fusion()) { - GpuHloCostAnalysis::Options cost_analysis_options{ - shape_size_bytes_function, - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; - fusion.AddPass(thread_pool, gpu_device_info, - std::move(cost_analysis_options)); - } else { - fusion.AddPass(/*may_duplicate=*/false, - gpu_device_info); - fusion.AddPass(/*may_duplicate=*/true, - gpu_device_info); - fusion.AddPass(gpu_device_info, shape_size_bytes_function); - } + GpuHloCostAnalysis::Options cost_analysis_options{ + shape_size_bytes_function, + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}; + fusion.AddPass(thread_pool, gpu_device_info, + std::move(cost_analysis_options)); + // Running CSE affects how many users an op has. This plays a role in what // we detect as a tiled transpose fusion. fusion.AddPass(/*is_layout_sensitive=*/true, diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 2c1ffdfba277c..2e5dfefcd1b29 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1565,11 +1565,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( } pipeline.AddPass(); - // Do not split small reduction dimensions unless priority fusion is - // enabled, which handles such cases well. - bool ignore_small_reduce_dims = - !debug_options.xla_gpu_enable_priority_fusion(); - pipeline.AddPass>(ignore_small_reduce_dims); + pipeline.AddPass>( + /*ignore_small_reduce_dims=*/false); pipeline.AddPass>(gpu_version); // Normalization passes might have introduced s4 tensors without bit width // annotations, this pass will add the annotations. diff --git a/xla/service/gpu/model/gpu_performance_model_base.h b/xla/service/gpu/model/gpu_performance_model_base.h index dbca8d0adc21d..3f5cf9aeaf33d 100644 --- a/xla/service/gpu/model/gpu_performance_model_base.h +++ b/xla/service/gpu/model/gpu_performance_model_base.h @@ -157,9 +157,7 @@ struct GpuPerformanceModelOptions { } static GpuPerformanceModelOptions ForModule(const HloModule* module) { - return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion() // Only cache within priority fusion. - : Default(); + return PriorityFusion(); } }; diff --git a/xla/xla.proto b/xla/xla.proto index 578249b43bdd8..094c429311ec2 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -707,8 +707,6 @@ message DebugOptions { reserved 220; // Was xla_gpu_enable_triton_softmax_fusion - bool xla_gpu_enable_priority_fusion = 221; - reserved 286; // Was xla_gpu_enable_triton_softmax_priority_fusion // File to write autotune results to. It will be a binary file unless the name @@ -1056,7 +1054,8 @@ message DebugOptions { // xla_gpu_single_wave_autotuning // xla_gpu_enable_persistent_temp_buffers // xla_gpu_enable_triton_gemm_int4 - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206, 320; + // xla_gpu_enable_priority_fusion + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320; } // Contains flags which affects the GPU compilation result.