diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc
index f1d110ea31d28e..1b1f00a2853705 100644
--- a/xla/debug_options_flags.cc
+++ b/xla/debug_options_flags.cc
@@ -175,8 +175,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_pipelined_reduce_scatter(true);
   opts.set_xla_gpu_enable_pipelined_p2p(false);
 
-  opts.set_xla_gpu_run_post_layout_collective_pipeliner(false);
-
   opts.set_xla_gpu_collective_permute_decomposer_threshold(
       std::numeric_limits<int64_t>::max());
 
@@ -1559,12 +1557,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_p2p),
       debug_options->xla_gpu_enable_pipelined_p2p(),
       "Enable pipelinling of P2P instructions."));
-  flag_list->push_back(tsl::Flag(
-      "xla_gpu_run_post_layout_collective_pipeliner",
-      bool_setter_for(
-          &DebugOptions::set_xla_gpu_run_post_layout_collective_pipeliner),
-      debug_options->xla_gpu_run_post_layout_collective_pipeliner(),
-      "Move collective pipeliner after the post-layout optimization."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_collective_permute_decomposer_threshold",
       int64_setter_for(
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 7d742f96df9e75..c8a220fb306940 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -834,8 +834,32 @@ absl::Status RunOptimizationPasses(
   return pipeline.Run(hlo_module).status();
 }
 
-absl::Status AddCollectivePipelinerPasses(
-    const DebugOptions& debug_options, HloPassPipeline& collectives_pipeline) {
+absl::Status RunCollectiveOptimizationPasses(
+    HloModule* hlo_module,
+    const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts,
+    se::GpuComputeCapability gpu_version) {
+  // Optimize collectives generated by SPMD partitioning. Enable these passes
+  // otherwise as well so that all collectives can get these optimizations.
+  const DebugOptions& debug_options = hlo_module->config().debug_options();
+
+  HloPassPipeline collectives_pipeline("collective-optimizations");
+  collectives_pipeline.AddPass<AllReduceFolder>();
+  collectives_pipeline.AddPass<AllReduceSplitter>();
+  collectives_pipeline.AddPass<AllGatherOptimizer>();
+  collectives_pipeline.AddPass<AllReduceReassociate>(
+      debug_options.xla_gpu_enable_reassociation_for_converted_ar());
+  collectives_pipeline.AddPass<ReduceScatterReassociate>();
+
+  collectives_pipeline.AddPass<WhileLoopAllReduceCodeMotion>(
+      /*enable_reduce_scatter=*/debug_options
+          .xla_gpu_enable_while_loop_reduce_scatter_code_motion());
+
+  // Moves collectives' subsequent quantization before the collective to
+  // minimize data transfers.
+  collectives_pipeline.AddPass<CollectiveQuantizer>();
+  // Remove dead computations after collective quantization.
+  collectives_pipeline.AddPass<HloDCE>();
+
   if (debug_options.xla_gpu_enable_pipelined_collectives() ||
       debug_options.xla_gpu_enable_pipelined_all_reduce()) {
     CollectivePipeliner::Config config{
@@ -887,54 +911,7 @@
         /*reuse_pipelined_op_buffer=*/HloPredicateFalse};
     collectives_pipeline.AddPass<CollectivePipeliner>(config);
   }
-  return absl::OkStatus();
-}
-
-absl::Status RunPostLayoutCollectivePipelinerPasses(HloModule* hlo_module) {
-  const DebugOptions& debug_options = hlo_module->config().debug_options();
-  HloPassPipeline collectives_pipeline("collective-pipeliner-optimizations");
-  if (debug_options.xla_gpu_run_post_layout_collective_pipeliner()) {
-    TF_RETURN_IF_ERROR(
-        AddCollectivePipelinerPasses(debug_options, collectives_pipeline));
-    // We call WhileLoopTripCountAnnotator at the end of the collective
-    // pipeline, which might have changed the loop trip count.
-    collectives_pipeline.AddPass<WhileLoopTripCountAnnotator>();
-    // Flatten call graph after loop peeling.
-    collectives_pipeline.AddPass<FlattenCallGraph>();
-  }
-  return collectives_pipeline.Run(hlo_module).status();
-}
-
-absl::Status RunCollectiveOptimizationPasses(
-    HloModule* hlo_module,
-    const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts,
-    se::GpuComputeCapability gpu_version) {
-  // Optimize collectives generated by SPMD partitioning. Enable these passes
-  // otherwise as well so that all collectives can get these optimizations.
-  const DebugOptions& debug_options = hlo_module->config().debug_options();
-
-  HloPassPipeline collectives_pipeline("collective-optimizations");
-  collectives_pipeline.AddPass<AllReduceFolder>();
-  collectives_pipeline.AddPass<AllReduceSplitter>();
-  collectives_pipeline.AddPass<AllGatherOptimizer>();
-  collectives_pipeline.AddPass<AllReduceReassociate>(
-      debug_options.xla_gpu_enable_reassociation_for_converted_ar());
-  collectives_pipeline.AddPass<ReduceScatterReassociate>();
-
-  collectives_pipeline.AddPass<WhileLoopAllReduceCodeMotion>(
-      /*enable_reduce_scatter=*/debug_options
-          .xla_gpu_enable_while_loop_reduce_scatter_code_motion());
-
-  // Moves collectives' subsequent quantization before the collective to
-  // minimize data transfers.
-  collectives_pipeline.AddPass<CollectiveQuantizer>();
-  // Remove dead computations after collective quantization.
-  collectives_pipeline.AddPass<HloDCE>();
-
-  if (!debug_options.xla_gpu_run_post_layout_collective_pipeliner()) {
-    TF_RETURN_IF_ERROR(
-        AddCollectivePipelinerPasses(debug_options, collectives_pipeline));
-  }
 
   collectives_pipeline.AddPass();
   collectives_pipeline.AddPass(
@@ -1353,8 +1330,6 @@ absl::Status GpuCompiler::OptimizeHloModule(
       hlo_module, stream_exec, options, gpu_target_config,
       thread_pool.get_mutable()));
 
-  TF_RETURN_IF_ERROR(RunPostLayoutCollectivePipelinerPasses(hlo_module));
-
   // This is a "low effort, high impact" fusion that should be run first.
   TF_RETURN_IF_ERROR(RunDynamicSliceFusionPasses(hlo_module, PlatformId()));
 
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 77e9cd85a40166..ff876f80f74a53 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -1159,99 +1159,6 @@ ENTRY entry {
       absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
 }
 
-TEST_F(CollectiveOpsTestE2E,
-       PostLayoutCollectivePipelinerShouldFlattenCallGraph) {
-  // The allgather in the loop has a nested while loop as its operand,
-  // when the pipelining happens, the nested while loop will be peeled outside.
-  // However, when a while is cloned, its call sites are still preserved which
-  // will error out in alias analysis. When the graph is flattened, the error
-  // should not happen.
-  absl::string_view kModuleReplicatedStr = R"(
-HloModule module
-
-while_cond {
-  param = (s32[], f32[2,128], f32[8,128], f32[8,128]) parameter(0)
-  gte = s32[] get-tuple-element(param), index=0
-  constant.1 = s32[] constant(3)
-  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
-}
-
-while_nested_cond {
-  param.nested = (s32[], f32[2,128]) parameter(0)
-  gte.nested = s32[] get-tuple-element(param.nested), index=0
-  constant.nested = s32[] constant(3)
-  ROOT cmp.nested = pred[] compare(gte.nested, constant.nested), direction=LT
-}
-while_nested_body {
-  param.body_nested = (s32[], f32[2,128]) parameter(0)
-  gte.body_nested = s32[] get-tuple-element(param.body_nested), index=0
-  gte.2.body_nested = f32[2,128] get-tuple-element(param.body_nested), index=1
-
-  constant.body_nested = s32[] constant(1)
-  add.body_nested = s32[] add(gte.body_nested, constant.body_nested)
-  rsqrt.body_nested = f32[2,128] rsqrt(gte.2.body_nested)
-  ROOT tuple.body_nested = (s32[], f32[2,128]) tuple(add.body_nested, rsqrt.body_nested)
-}
-
-while_body {
-  param = (s32[], f32[2,128], f32[8,128], f32[8,128]) parameter(0)
-  get-tuple-element.394 = s32[] get-tuple-element(param), index=0
-  get-tuple-element.395 = f32[2,128] get-tuple-element(param), index=1
-  get-tuple-element.35 = f32[8,128] get-tuple-element(param), index=2
-  get-tuple-element.36 = f32[8,128] get-tuple-element(param), index=3
-
-  constant.2557 = s32[] constant(1)
-  add.230 = s32[] add(get-tuple-element.394, constant.2557)
-  mul = f32[2,128] multiply(get-tuple-element.395, get-tuple-element.395)
-  constant.while = s32[] constant(0)
-  tuple.1 = (s32[], f32[2,128]) tuple(constant.while, mul)
-  while.1 = (s32[], f32[2,128]) while(tuple.1), condition=while_nested_cond, body=while_nested_body
-  gte.while = f32[2,128] get-tuple-element(while.1), index=1
-  add.while = f32[2,128] add(gte.while, get-tuple-element.395)
-
-  ag.1 = f32[8,128] all-gather(add.while), replica_groups={}, dimensions={0}
-  add.ag = f32[8,128] add(ag.1, get-tuple-element.36)
-
-  ROOT tuple = (s32[], f32[2,128], f32[8,128], f32[8,128]) tuple(add.230, get-tuple-element.395, get-tuple-element.35, ag.1)
-}
-
-ENTRY entry {
-  c0 = s32[] constant(0)
-  p0 = f32[2,128] parameter(0)
-  p1 = f32[8,128] parameter(1)
-
-  tuple = (s32[], f32[2,128], f32[8,128], f32[8,128]) tuple(c0, p0, p1, p1)
-  while = (s32[], f32[2,128], f32[8,128], f32[8,128]) while(tuple), condition=while_cond, body=while_body
-  gte1 = f32[2,128] get-tuple-element(while), index=1
-  gte2 = f32[8,128] get-tuple-element(while), index=3
-  ROOT tuple.result = (f32[2,128], f32[8,128]) tuple(gte1, gte2)
-}
-)";
-
-  const int64_t kNumReplicas = 1;
-  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas);
-
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  auto opts = GetDebugOptionsForTest();
-  opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
-  opts.set_xla_gpu_enable_pipelined_all_reduce(true);
-  opts.set_xla_gpu_enable_pipelined_all_gather(true);
-  opts.set_xla_gpu_enable_pipelined_reduce_scatter(true);
-
-  opts.set_xla_gpu_enable_triton_gemm(false);
-  config.set_debug_options(opts);
-  config.set_use_spmd_partitioning(false);
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));
-
-  TF_ASSERT_OK_AND_ASSIGN(auto executable,
-                          CreateExecutable(std::move(module),
-                                           /*run_hlo_passes=*/true));
-  EXPECT_TRUE(executable->has_module());
-}
-
 TEST_F(CollectiveOpsTestE2E, AllToAllQuantizeCollectiveQuantizer) {
   absl::string_view kModuleReplicatedStr = R"(
 HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={()->bf16[2]}, num_partitions=2
diff --git a/xla/xla.proto b/xla/xla.proto
index 5e5311fddd1ef8..563eb1f92c4d3e 100644
--- a/xla/xla.proto
+++ b/xla/xla.proto
@@ -662,7 +662,8 @@ message DebugOptions {
   bool xla_gpu_enable_pipelined_all_gather = 227;
   bool xla_gpu_enable_pipelined_reduce_scatter = 231;
   bool xla_gpu_enable_pipelined_p2p = 246;
-  bool xla_gpu_run_post_layout_collective_pipeliner = 313;
+
+  reserved 313;  // Was xla_gpu_run_post_layout_collective_pipeliner.
 
   // The minimum data size in bytes to trigger collective-permute-decomposer
   // transformation.