From dfaeff31e781a046f5b11563b556b242edea7e10 Mon Sep 17 00:00:00 2001
From: xla authors
Date: Thu, 24 Oct 2024 07:32:52 -0700
Subject: [PATCH] Associates communication_resharding_costs and
 memory_resharding_costs with input sharding combinations instead of
 strategies.

PiperOrigin-RevId: 689374211
---
 .../auto_sharding/auto_sharding.cc            | 132 +++++++++++-------
 .../auto_sharding/auto_sharding_cost_graph.cc |  21 +--
 .../auto_sharding_dot_handler.cc              |   9 +-
 .../auto_sharding/auto_sharding_strategy.cc   |  37 ++---
 .../auto_sharding/auto_sharding_strategy.h    |  61 ++++----
 .../auto_sharding/auto_sharding_util.cc       |   2 +-
 6 files changed, 151 insertions(+), 111 deletions(-)

diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index 21f49fc8498e63..c299a0ff1c41ff 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -334,7 +334,6 @@ void FollowArrayOrTokenStrategyGroup(
     double compute_cost = 0, communication_cost = 0;
     double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec);
     size_t num_in_nodes = strategy_group.in_nodes.size();
-    InputShardings input_shardings{name, {num_in_nodes, *output_spec}};
     ReshardingCosts communication_resharding_costs;
     ReshardingCosts memory_resharding_costs;
     for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) {
@@ -344,11 +343,14 @@ void FollowArrayOrTokenStrategyGroup(
       memory_resharding_costs.push_back(MemoryReshardingCostVector(
           *strategy_group.in_nodes[i], shape, *output_spec, cluster_env));
     }
+    InputShardings input_shardings{name,
+                                   {num_in_nodes, *output_spec},
+                                   communication_resharding_costs,
+                                   memory_resharding_costs};
     strategy_group.AddStrategy(
-        ShardingStrategy({*output_spec, compute_cost, communication_cost,
-                          memory_cost, communication_resharding_costs,
-                          memory_resharding_costs}),
+        ShardingStrategy(
+            {*output_spec, compute_cost, communication_cost, memory_cost}),
         input_shardings);
   }
 }
@@ -404,12 +406,14 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
         GenerateReshardingCostsAndMissingShardingsForAllOperands(
             ins, output_spec, strategy_map, cluster_env, call_graph,
             input_shardings);
+    input_shardings.communication_resharding_costs =
+        std::move(resharding_costs.first);
+    input_shardings.memory_resharding_costs =
+        std::move(resharding_costs.second);
 
     child_strategy_group->AddStrategy(
         ShardingStrategy({std::move(output_spec), compute_cost,
-                          communication_cost, memory_cost,
-                          std::move(resharding_costs.first),
-                          std::move(resharding_costs.second)}),
+                          communication_cost, memory_cost}),
         std::move(input_shardings));
   }
@@ -554,9 +558,11 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
         }
       }
       const ShardingStrategy strategy = ShardingStrategy(
-          {output_spec, compute_cost, communication_cost, memory_cost,
-           communication_resharding_costs, memory_resharding_costs});
-      strategy_group->AddStrategy(strategy, {name, {input_sharding}});
+          {output_spec, compute_cost, communication_cost, memory_cost});
+      strategy_group->AddStrategy(strategy, {name,
+                                             {input_sharding},
+                                             communication_resharding_costs,
+                                             memory_resharding_costs});
     }
   } else {
     LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString();
@@ -697,11 +703,13 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
   }
   communication_resharding_costs.push_back({});
   memory_resharding_costs.push_back({});
+  input_shardings.communication_resharding_costs =
+      std::move(communication_resharding_costs);
+  input_shardings.memory_resharding_costs =
+      std::move(memory_resharding_costs);
   double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
   strategy_group.AddStrategy(
-      ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
-                        memory_cost, std::move(communication_resharding_costs),
-                        std::move(memory_resharding_costs)}),
+      ShardingStrategy(
+          {HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
       input_shardings);
 }
@@ -802,15 +810,18 @@ void AddReplicatedStrategy(
       }
     }
 
+    for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
+      possible_input_shardings[j].communication_resharding_costs =
+          std::move(possible_communication_resharding_costs[j]);
+      possible_input_shardings[j].memory_resharding_costs =
+          std::move(possible_memory_resharding_costs[j]);
+    }
     for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
       double communication_cost = ComputeCommunicationCost(
           ins, possible_input_shardings[j], cluster_env);
       strategy_group.AddStrategy(
-          ShardingStrategy(
-              {replicated_strategy, replicated_penalty, communication_cost,
-               memory_cost,
-               std::move(possible_communication_resharding_costs[j]),
-               std::move(possible_memory_resharding_costs[j])}),
+          ShardingStrategy({replicated_strategy, replicated_penalty,
+                            communication_cost, memory_cost}),
           std::move(possible_input_shardings[j]));
     }
   } else {
@@ -848,11 +859,13 @@ void AddReplicatedStrategy(
         }
       }
     }
+    input_shardings.communication_resharding_costs =
+        std::move(communication_resharding_costs);
+    input_shardings.memory_resharding_costs =
+        std::move(memory_resharding_costs);
     strategy_group.AddStrategy(
-        ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
-                          memory_cost,
-                          std::move(communication_resharding_costs),
-                          std::move(memory_resharding_costs)}),
+        ShardingStrategy(
+            {HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
         input_shardings);
   }
 }
@@ -939,11 +952,13 @@ void EnumerateAll1DPartition(
         communication_cost = ComputeSortCommunicationCost(
             ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env);
       }
+      input_shardings.communication_resharding_costs =
+          std::move(communication_resharding_costs);
+      input_shardings.memory_resharding_costs =
+          std::move(memory_resharding_costs);
       strategy_group.AddStrategy(
-          ShardingStrategy({output_spec, compute_cost, communication_cost,
-                            memory_cost,
-                            std::move(communication_resharding_costs),
-                            std::move(memory_resharding_costs)}),
+          ShardingStrategy(
+              {output_spec, compute_cost, communication_cost, memory_cost}),
           input_shardings);
     }
   }
@@ -1051,10 +1066,12 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
     }
   }
 
+  input_shardings.communication_resharding_costs =
+      std::move(communication_resharding_costs);
+  input_shardings.memory_resharding_costs = std::move(memory_resharding_costs);
   strategy_group.AddStrategy(
-      ShardingStrategy({output_spec, compute_cost, communication_cost,
-                        memory_cost, std::move(communication_resharding_costs),
-                        std::move(memory_resharding_costs)}),
+      ShardingStrategy(
+          {output_spec, compute_cost, communication_cost, memory_cost}),
       input_shardings);
 }
@@ -1102,11 +1119,12 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
     ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
         operand_strategy_group, operand_shape, *input_spec, cluster_env)};
     strategy_group.AddStrategy(
-        ShardingStrategy({output_spec, compute_cost, communication_cost,
-                          memory_cost,
-                          std::move(communication_resharding_costs),
-                          std::move(memory_resharding_costs)}),
-        {name, {*input_spec}});
+        ShardingStrategy(
+            {output_spec, compute_cost, communication_cost, memory_cost}),
+        {name,
+         {*input_spec},
+         std::move(communication_resharding_costs),
+         std::move(memory_resharding_costs)});
     }
   }
 }
@@ -1421,19 +1439,21 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
             strategy_group.GetStrategies();
       }
       strategy_group.ClearStrategies();
+      input_shardings.communication_resharding_costs =
+          communication_resharding_costs;
+      input_shardings.memory_resharding_costs = memory_resharding_costs;
       strategy_group.AddStrategy(
-          ShardingStrategy({existing_sharding, 0, 0, memory_cost,
-                            communication_resharding_costs,
-                            memory_resharding_costs}),
+          ShardingStrategy({existing_sharding, 0, 0, memory_cost}),
           input_shardings);
     }
     // If there is only one option for resharding, and the cost computed for
     // that option is kInfinityCost, set the cost to zero. This is okay
     // because there is only one option anyway, and having the costs set to
     // kInfinityCost is problematic for the solver.
-    if (strategy_group.GetStrategies().size() == 1) {
+    if (strategy_group.GetStrategyInputShardings().size() == 1) {
       for (auto& operand_communication_resharding_costs :
-           strategy_group.GetStrategy(0).communication_resharding_costs) {
+           strategy_group.GetMutableInputShardings(0)
+               .communication_resharding_costs) {
         if (operand_communication_resharding_costs.size() == 1 &&
             operand_communication_resharding_costs[0] >= kInfinityCost) {
           operand_communication_resharding_costs[0] = 0;
@@ -1561,10 +1581,17 @@ void ScaleCostsWithExecutionCounts(const int64_t execution_count,
     ShardingStrategy& strategy = leaf_strategy_group.GetStrategy(sid);
     scale_cost(strategy.compute_cost);
     scale_cost(strategy.communication_cost);
-    for (int i = 0; i < strategy.communication_resharding_costs.size(); ++i) {
-      for (int j = 0; j < strategy.communication_resharding_costs[i].size();
+  }
+  for (int iid = 0;
+       iid < leaf_strategy_group.GetStrategyInputShardings().size(); ++iid) {
+    InputShardings& input_shardings =
+        leaf_strategy_group.GetMutableInputShardings(iid);
+    for (int i = 0; i < input_shardings.communication_resharding_costs.size();
+         ++i) {
+      for (int j = 0;
+           j < input_shardings.communication_resharding_costs[i].size();
            ++j) {
-        scale_cost(strategy.communication_resharding_costs[i][j]);
+        scale_cost(input_shardings.communication_resharding_costs[i][j]);
       }
     }
   }
@@ -1676,11 +1703,13 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
         memory_resharding_costs.push_back(zeros);
       }
     }
+    input_shardings.communication_resharding_costs =
+        std::move(communication_resharding_costs);
+    input_shardings.memory_resharding_costs =
+        std::move(memory_resharding_costs);
     strategy_group->AddStrategy(
         ShardingStrategy({HloSharding::Replicate(), 0, 0,
-                          static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
-                          std::move(communication_resharding_costs),
-                          std::move(memory_resharding_costs)}),
+                          static_cast<double>(ShapeUtil::ByteSizeOf(shape))}),
         std::move(input_shardings));
   } else {
     LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString();
@@ -1728,13 +1757,12 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
             operand_strategy_group, operand->shape(),
             operand_strategy.output_sharding, cluster_env);
     strategy_group->AddStrategy(
-        ShardingStrategy({*output_sharding,
-                          compute_cost,
-                          communication_cost,
-                          memory_cost,
-                          {communication_resharding_costs},
-                          {memory_resharding_costs}}),
-        {name, {operand_strategy.output_sharding}});
+        ShardingStrategy(
+            {*output_sharding, compute_cost, communication_cost, memory_cost}),
+        {name,
+         {operand_strategy.output_sharding},
+         {communication_resharding_costs},
+         {memory_resharding_costs}});
   }
 
   if (strategy_group->GetStrategies().empty()) {
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
index 5d1016830c1c63..183e25cacbc1ac 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -163,25 +163,28 @@ EdgeReshardingCostMatrix CostGraph::CreateEdgeCost(
   CHECK_LT(src_idx, node_lens_.size());
   CHECK_LT(dst_idx, node_lens_.size());
   EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]);
-  const auto& strategies = strategy_group->GetStrategies();
-  for (NodeStrategyIdx k = 0; k < strategies.size(); ++k) {
-    const ShardingStrategy& strategy = strategies[k];
+  const auto& strategy_input_shardings =
+      strategy_group->GetStrategyInputShardings();
+  for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
+    const InputShardings& input_shardings = strategy_input_shardings[iid];
+    const NodeStrategyIdx k =
+        strategy_group->GetStrategyIdxForInputShardings(iid);
     size_t start_idx = 0;
-    CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size())
+    CHECK_LT(in_node_idx, input_shardings.memory_resharding_costs.size())
         << strategy_group->node_idx;
-    if (strategy.memory_resharding_costs[in_node_idx].size() >
+    if (input_shardings.memory_resharding_costs[in_node_idx].size() >
         node_lens_[src_idx]) {
-      start_idx = strategy.memory_resharding_costs[in_node_idx].size() -
+      start_idx = input_shardings.memory_resharding_costs[in_node_idx].size() -
                   node_lens_[src_idx];
     }
     for (size_t j = start_idx;
-         j < strategy.memory_resharding_costs[in_node_idx].size(); ++j) {
+         j < input_shardings.memory_resharding_costs[in_node_idx].size(); ++j) {
       double communication_cost = 0;
       double memory_cost = 0;
       if (!zero_cost) {
         communication_cost =
-            strategy.communication_resharding_costs[in_node_idx][j];
-        memory_cost = strategy.memory_resharding_costs[in_node_idx][j];
+            input_shardings.communication_resharding_costs[in_node_idx][j];
+        memory_cost = input_shardings.memory_resharding_costs[in_node_idx][j];
       }
       edge_cost(j - start_idx, k) =
           EdgeReshardingCost(communication_cost, memory_cost);
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
index e161674473157e..031a47009b25a8 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -337,10 +337,11 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
   strategy_group_->AddStrategy(
       ShardingStrategy({output_spec, compute_cost, communication_cost,
                         static_cast<double>(ByteSizeOfShapeWithSharding(
-                            ins_->shape(), output_spec)),
-                        communication_resharding_costs,
-                        memory_resharding_costs}),
-      {name, {input_specs.begin(), input_specs.end()}});
+                            ins_->shape(), output_spec))}),
+      {name,
+       {input_specs.begin(), input_specs.end()},
+       communication_resharding_costs,
+       memory_resharding_costs});
 }
 
 // Given lhs and rhs dim maps, infers a sharding for the output by relying
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
index e1a81dbf33772c..dea54c0bb9e7a7 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -351,12 +351,14 @@ BuildStrategyAndCost(
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, scatter_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);
+      input_shardings_optional.communication_resharding_costs =
+          std::move(resharding_costs.first);
+      input_shardings_optional.memory_resharding_costs =
+          std::move(resharding_costs.second);
 
       strategy_group->AddStrategy(
           ShardingStrategy({scatter_sharding, compute_cost,
-                            communication_cost, memory_cost,
-                            std::move(resharding_costs.first),
-                            std::move(resharding_costs.second)}),
+                            communication_cost, memory_cost}),
           input_shardings_optional);
     };
@@ -402,12 +404,14 @@ BuildStrategyAndCost(
               GenerateReshardingCostsAndMissingShardingsForAllOperands(
                   ins, output_sharding, strategy_map, cluster_env, call_graph,
                   input_shardings_optional);
+      input_shardings_optional.communication_resharding_costs =
+          std::move(resharding_costs.first);
+      input_shardings_optional.memory_resharding_costs =
+          std::move(resharding_costs.second);
 
       strategy_group->AddStrategy(
           ShardingStrategy({output_sharding, compute_cost,
-                            communication_cost, memory_cost,
-                            std::move(resharding_costs.first),
-                            std::move(resharding_costs.second)}),
+                            communication_cost, memory_cost}),
           input_shardings_optional);
     };
@@ -558,13 +562,12 @@ BuildStrategyAndCost(
               MemoryReshardingCostVector(src_strategy_group, operand->shape(),
                                          input_spec, cluster_env);
           strategy_group->AddStrategy(
-              ShardingStrategy({output_spec,
-                                compute_cost,
-                                communication_cost,
-                                memory_cost,
-                                {communication_resharding_costs},
-                                {memory_resharding_costs}}),
-              {name, {input_spec}});
+              ShardingStrategy(
+                  {output_spec, compute_cost, communication_cost, memory_cost}),
+              {name,
+               {input_spec},
+               {communication_resharding_costs},
+               {memory_resharding_costs}});
         }
         break;
       }
@@ -696,9 +699,11 @@ BuildStrategyAndCost(
 
           strategy_group->AddStrategy(
               ShardingStrategy({*output_spec, compute_cost, communication_cost,
-                                memory_cost, std::move(resharding_costs.first),
-                                std::move(resharding_costs.second)}),
-              {name, {input_spec}});
+                                memory_cost}),
+              {name,
+               {input_spec},
+               std::move(resharding_costs.first),
+               std::move(resharding_costs.second)});
         }
 
         if (strategy_group->GetStrategies().empty()) {
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
index 49212fe84ce655..f40b5aebb3410a 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h
@@ -82,6 +82,12 @@ using ReshardingCosts = std::vector<std::vector<double>>;
 struct InputShardings {
   std::string name;
   std::vector<std::optional<HloSharding>> shardings;
+  // resharding_costs[i][j] is the resharding cost from the output of the
+  // i-th operand's j-th strategy to this strategy.
+  // If there is only one tuple operand, resharding_costs[i][j] is the
+  // resharding cost from the i-th tuple element's j-th strategy.
+  ReshardingCosts communication_resharding_costs;
+  ReshardingCosts memory_resharding_costs;
 
   std::string ToString() const {
     std::string str = absl::StrCat(name, " ");
@@ -102,26 +108,7 @@ struct InputShardings {
         }
       }
     }
-    return str;
-  }
-};
-
-// One sharding strategy
-struct ShardingStrategy {
-  HloSharding output_sharding;
-  double compute_cost;
-  double communication_cost;
-  double memory_cost;
-  // resharding_costs[i][j] is the resharding cost from the output of
-  // i-th operand's j-th strategy to this strategy.
-  // If there is only one tuple operand,resharding_costs[i][j] is the resharding
-  // cost from i-th tuple element's j-th strategy.
-  ReshardingCosts communication_resharding_costs;
-  ReshardingCosts memory_resharding_costs;
-
-  std::string ToString() const { return output_sharding.ToString(); }
-
-  std::string ToStringLong() const {
     std::vector<std::string> communication_resharding_vector_strings;
     communication_resharding_vector_strings.reserve(
         communication_resharding_costs.size());
@@ -140,23 +127,35 @@ struct ShardingStrategy {
     }
     std::string memory_resharding_cost_str = absl::StrCat(
         "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}");
+    absl::StrAppend(&str, ", communication_resharding_costs=",
+                    communication_resharding_cost_str,
+                    ", memory_resharding_costs=", memory_resharding_cost_str);
-    return absl::StrCat(
-        output_sharding.ToString(), ", compute_cost=", compute_cost,
-        ", communication_cost=", communication_cost,
-        ", memory_cost=", memory_cost,
-        ", communication_resharding_costs=", communication_resharding_cost_str,
-        ", memory_resharding_costs=", memory_resharding_cost_str);
+    return str;
+  }
+};
+
+// One sharding strategy
+struct ShardingStrategy {
+  HloSharding output_sharding;
+  double compute_cost;
+  double communication_cost;
+  double memory_cost;
+
+  std::string ToString() const { return output_sharding.ToString(); }
+
+  std::string ToStringLong() const {
+    return absl::StrCat(output_sharding.ToString(),
+                        ", compute_cost=", compute_cost,
+                        ", communication_cost=", communication_cost,
+                        ", memory_cost=", memory_cost);
   }
 
   bool operator==(const ShardingStrategy& other) const {
     return output_sharding == other.output_sharding &&
            compute_cost == other.compute_cost &&
            communication_cost == other.communication_cost &&
-           memory_cost == other.memory_cost &&
-           communication_resharding_costs ==
-               other.communication_resharding_costs &&
-           memory_resharding_costs == other.memory_resharding_costs;
+           memory_cost == other.memory_cost;
   }
 };
@@ -327,6 +326,10 @@ struct StrategyGroup {
     return strategy_input_shardings[input_sharding_idx];
   }
 
+  InputShardings& GetMutableInputShardings(size_t input_sharding_idx) {
+    return strategy_input_shardings[input_sharding_idx];
+  }
+
   const InputShardings& GetInputShardingsForStrategy(
       size_t strategy_idx) const {
     const size_t input_sharding_idx =
diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
index ccdbbe4c0ff137..7de63f8a2f67c6 100644
--- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
+++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -846,7 +846,7 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) {
     const InputShardings& input_shardings = strategy_input_shardings[iid];
     const ShardingStrategy& strategy =
         strategy_group.GetStrategyForInputShardings(iid);
-    if (AllInfinityCosts(strategy.communication_resharding_costs)) {
+    if (AllInfinityCosts(input_shardings.communication_resharding_costs)) {
      num_skipped_due_to_infinity_costs++;
      continue;
    }
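
Reviewer note: the gist of the change, as a minimal before/after sketch of one
call site. Identifiers are taken from the hunks above; the surrounding setup
(output_spec, the scalar costs, and the two ReshardingCosts matrices) is
assumed to be computed exactly as before.

    // Before: the per-operand resharding cost matrices traveled with the
    // strategy itself.
    strategy_group.AddStrategy(
        ShardingStrategy({output_spec, compute_cost, communication_cost,
                          memory_cost, communication_resharding_costs,
                          memory_resharding_costs}),
        {name, {input_spec}});

    // After: ShardingStrategy carries only the output sharding and the three
    // scalar costs; the matrices ride along with the InputShardings, one pair
    // of matrices per input sharding combination.
    strategy_group.AddStrategy(
        ShardingStrategy(
            {output_spec, compute_cost, communication_cost, memory_cost}),
        {name,
         {input_spec},
         communication_resharding_costs,
         memory_resharding_costs});

Readers of the costs move the same way: strategy.communication_resharding_costs
becomes input_shardings.communication_resharding_costs, obtained through the
new StrategyGroup::GetMutableInputShardings(iid) accessor where mutation is
needed (see ScaleCostsWithExecutionCounts and CostGraph::CreateEdgeCost above).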