Associates communication_resharding_costs and memory_resharding_costs with input sharding combinations instead of strategies.

PiperOrigin-RevId: 689374211
Google-ML-Automation committed Oct 24, 2024
1 parent a6eeef2 commit dfaeff3
Showing 6 changed files with 151 additions and 111 deletions.
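
The change is a data-structure move: the two per-operand resharding cost tables that used to live on every ShardingStrategy now live on InputShardings, so each table is keyed by the input sharding combination it was computed for, and a single strategy can be shared across several combinations. Below is a minimal, self-contained sketch of that shape, with toy stand-in types (the real structs live in the auto_sharding headers and carry more fields; HloSharding is reduced to a string tag, and the field orders mirror the aggregate initializers seen in the diff):

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for the XLA types touched by this commit.
using HloSharding = std::string;  // stand-in for xla::HloSharding
// One cost vector per operand; each entry is the cost of resharding from one
// of that operand's candidate strategies.
using ReshardingCosts = std::vector<std::vector<double>>;

// After this commit, ShardingStrategy carries only its own costs...
struct ShardingStrategy {
  HloSharding output_sharding;
  double compute_cost;
  double communication_cost;
  double memory_cost;
};

// ...while the resharding tables ride along with the input shardings.
struct InputShardings {
  std::string name;
  std::vector<HloSharding> shardings;
  ReshardingCosts communication_resharding_costs;
  ReshardingCosts memory_resharding_costs;
};

// Toy strategy group: strategies and input sharding combinations are stored
// side by side and paired up by AddStrategy.
struct StrategyGroup {
  std::vector<ShardingStrategy> strategies;
  std::vector<InputShardings> input_shardings;
  void AddStrategy(ShardingStrategy s, InputShardings is) {
    strategies.push_back(std::move(s));
    input_shardings.push_back(std::move(is));
  }
};

int main() {
  StrategyGroup group;
  ReshardingCosts comm{{0.0, 2.5}};  // one operand, two source strategies
  ReshardingCosts mem{{0.0, 1.0}};
  group.AddStrategy(
      ShardingStrategy{"S0", /*compute=*/0, /*communication=*/0,
                       /*memory=*/128},
      InputShardings{"follow", {"S0"}, std::move(comm), std::move(mem)});
  return 0;
}

Every AddStrategy call site in the diff below applies exactly this reshuffle: the cost tables are built as before, but handed to the InputShardings instead of the ShardingStrategy initializer.
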
132 changes: 80 additions & 52 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -334,7 +334,6 @@ void FollowArrayOrTokenStrategyGroup(
double compute_cost = 0, communication_cost = 0;
double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec);
size_t num_in_nodes = strategy_group.in_nodes.size();
InputShardings input_shardings{name, {num_in_nodes, *output_spec}};
ReshardingCosts communication_resharding_costs;
ReshardingCosts memory_resharding_costs;
for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) {
@@ -344,11 +343,14 @@
memory_resharding_costs.push_back(MemoryReshardingCostVector(
*strategy_group.in_nodes[i], shape, *output_spec, cluster_env));
}
InputShardings input_shardings{name,
{num_in_nodes, *output_spec},
communication_resharding_costs,
memory_resharding_costs};

strategy_group.AddStrategy(
ShardingStrategy({*output_spec, compute_cost, communication_cost,
memory_cost, communication_resharding_costs,
memory_resharding_costs}),
ShardingStrategy(
{*output_spec, compute_cost, communication_cost, memory_cost}),
input_shardings);
}
}
@@ -404,12 +406,14 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_spec, strategy_map, cluster_env, call_graph,
input_shardings);
input_shardings.communication_resharding_costs =
std::move(resharding_costs.first);
input_shardings.memory_resharding_costs =
std::move(resharding_costs.second);

child_strategy_group->AddStrategy(
ShardingStrategy({std::move(output_spec), compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
communication_cost, memory_cost}),
std::move(input_shardings));
}

@@ -554,9 +558,11 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
}
}
const ShardingStrategy strategy = ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost,
communication_resharding_costs, memory_resharding_costs});
strategy_group->AddStrategy(strategy, {name, {input_sharding}});
{output_spec, compute_cost, communication_cost, memory_cost});
strategy_group->AddStrategy(strategy, {name,
{input_sharding},
communication_resharding_costs,
memory_resharding_costs});
}
} else {
LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString();
@@ -697,11 +703,13 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
}
communication_resharding_costs.push_back({});
memory_resharding_costs.push_back({});
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs = std::move(memory_resharding_costs);
double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
strategy_group.AddStrategy(
ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
memory_cost, std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
ShardingStrategy(
{HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
input_shardings);
}

@@ -802,15 +810,18 @@ void AddReplicatedStrategy(
}
}

for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
possible_input_shardings[j].communication_resharding_costs =
std::move(possible_communication_resharding_costs[j]);
possible_input_shardings[j].memory_resharding_costs =
std::move(possible_memory_resharding_costs[j]);
}
for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
double communication_cost = ComputeCommunicationCost(
ins, possible_input_shardings[j], cluster_env);
strategy_group.AddStrategy(
ShardingStrategy(
{replicated_strategy, replicated_penalty, communication_cost,
memory_cost,
std::move(possible_communication_resharding_costs[j]),
std::move(possible_memory_resharding_costs[j])}),
ShardingStrategy({replicated_strategy, replicated_penalty,
communication_cost, memory_cost}),
std::move(possible_input_shardings[j]));
}
} else {
@@ -848,11 +859,13 @@ void AddReplicatedStrategy(
}
}
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
ShardingStrategy(
{HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
input_shardings);
}
}
@@ -939,11 +952,13 @@ void EnumerateAll1DPartition(
communication_cost = ComputeSortCommunicationCost(
ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env);
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
input_shardings);
}
}
@@ -1051,10 +1066,12 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
}
}

input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs = std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost, std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
input_shardings);
}

@@ -1102,11 +1119,12 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
operand_strategy_group, operand_shape, *input_spec, cluster_env)};
strategy_group.AddStrategy(
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
{name, {*input_spec}});
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
{name,
{*input_spec},
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)});
}
}
}
@@ -1421,19 +1439,21 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
strategy_group.GetStrategies();
}
strategy_group.ClearStrategies();
input_shardings.communication_resharding_costs =
communication_resharding_costs;
input_shardings.memory_resharding_costs = memory_resharding_costs;
strategy_group.AddStrategy(
ShardingStrategy({existing_sharding, 0, 0, memory_cost,
communication_resharding_costs,
memory_resharding_costs}),
ShardingStrategy({existing_sharding, 0, 0, memory_cost}),
input_shardings);
}
// If there is only one option for resharding, and the cost computed for
// that option is kInfinityCost, set the cost to zero. This is okay
// because there is only one option anyway, and having the costs set to
// kInfinityCost is problematic for the solver.
if (strategy_group.GetStrategies().size() == 1) {
if (strategy_group.GetStrategyInputShardings().size() == 1) {
for (auto& operand_communication_resharding_costs :
strategy_group.GetStrategy(0).communication_resharding_costs) {
strategy_group.GetMutableInputShardings(0)
.communication_resharding_costs) {
if (operand_communication_resharding_costs.size() == 1 &&
operand_communication_resharding_costs[0] >= kInfinityCost) {
operand_communication_resharding_costs[0] = 0;
@@ -1561,10 +1581,17 @@ void ScaleCostsWithExecutionCounts(const int64_t execution_count,
ShardingStrategy& strategy = leaf_strategy_group.GetStrategy(sid);
scale_cost(strategy.compute_cost);
scale_cost(strategy.communication_cost);
for (int i = 0; i < strategy.communication_resharding_costs.size(); ++i) {
for (int j = 0; j < strategy.communication_resharding_costs[i].size();
}
for (int iid = 0;
iid < leaf_strategy_group.GetStrategyInputShardings().size(); ++iid) {
InputShardings& input_shardings =
leaf_strategy_group.GetMutableInputShardings(iid);
for (int i = 0; i < input_shardings.communication_resharding_costs.size();
++i) {
for (int j = 0;
j < input_shardings.communication_resharding_costs[i].size();
++j) {
scale_cost(strategy.communication_resharding_costs[i][j]);
scale_cost(input_shardings.communication_resharding_costs[i][j]);
}
}
}
@@ -1676,11 +1703,13 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
memory_resharding_costs.push_back(zeros);
}
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group->AddStrategy(
ShardingStrategy({HloSharding::Replicate(), 0, 0,
static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
static_cast<double>(ShapeUtil::ByteSizeOf(shape))}),
std::move(input_shardings));
} else {
LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString();
@@ -1728,13 +1757,12 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
operand_strategy_group, operand->shape(),
operand_strategy.output_sharding, cluster_env);
strategy_group->AddStrategy(
ShardingStrategy({*output_sharding,
compute_cost,
communication_cost,
memory_cost,
{communication_resharding_costs},
{memory_resharding_costs}}),
{name, {operand_strategy.output_sharding}});
ShardingStrategy(
{*output_sharding, compute_cost, communication_cost, memory_cost}),
{name,
{operand_strategy.output_sharding},
{communication_resharding_costs},
{memory_resharding_costs}});
}

if (strategy_group->GetStrategies().empty()) {
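A knock-on effect, visible in the ScaleCostsWithExecutionCounts hunk above: scaling costs by execution count now walks the input sharding combinations as well as the strategies, since that is where the communication resharding table moved. A sketch of the split on the toy types from the first example (scale_cost mirrors the diff's lambda and is assumed to be a plain multiply):

// Reuses the toy StrategyGroup/InputShardings from the first sketch.
void ScaleCostsWithExecutionCountsExample(StrategyGroup& g,
                                          int64_t execution_count) {
  auto scale_cost = [&](double& cost) { cost *= execution_count; };
  for (ShardingStrategy& strategy : g.strategies) {
    scale_cost(strategy.compute_cost);
    scale_cost(strategy.communication_cost);
  }
  // Formerly strategy.communication_resharding_costs; the table now hangs
  // off each input sharding combination.
  for (InputShardings& input_shardings : g.input_shardings) {
    for (auto& per_operand : input_shardings.communication_resharding_costs) {
      for (double& cost : per_operand) scale_cost(cost);
    }
  }
}
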
21 changes: 12 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -163,25 +163,28 @@ EdgeReshardingCostMatrix CostGraph::CreateEdgeCost(
CHECK_LT(src_idx, node_lens_.size());
CHECK_LT(dst_idx, node_lens_.size());
EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]);
const auto& strategies = strategy_group->GetStrategies();
for (NodeStrategyIdx k = 0; k < strategies.size(); ++k) {
const ShardingStrategy& strategy = strategies[k];
const auto& strategy_input_shardings =
strategy_group->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const NodeStrategyIdx k =
strategy_group->GetStrategyIdxForInputShardings(iid);
size_t start_idx = 0;
CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size())
CHECK_LT(in_node_idx, input_shardings.memory_resharding_costs.size())
<< strategy_group->node_idx;
if (strategy.memory_resharding_costs[in_node_idx].size() >
if (input_shardings.memory_resharding_costs[in_node_idx].size() >
node_lens_[src_idx]) {
start_idx = strategy.memory_resharding_costs[in_node_idx].size() -
start_idx = input_shardings.memory_resharding_costs[in_node_idx].size() -
node_lens_[src_idx];
}
for (size_t j = start_idx;
j < strategy.memory_resharding_costs[in_node_idx].size(); ++j) {
j < input_shardings.memory_resharding_costs[in_node_idx].size(); ++j) {
double communication_cost = 0;
double memory_cost = 0;
if (!zero_cost) {
communication_cost =
strategy.communication_resharding_costs[in_node_idx][j];
memory_cost = strategy.memory_resharding_costs[in_node_idx][j];
input_shardings.communication_resharding_costs[in_node_idx][j];
memory_cost = input_shardings.memory_resharding_costs[in_node_idx][j];
}
edge_cost(j - start_idx, k) =
EdgeReshardingCost(communication_cost, memory_cost);
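On the cost-graph side, CreateEdgeCost now iterates over input sharding combinations rather than strategies, using GetStrategyIdxForInputShardings to recover the strategy column each combination belongs to. A simplified sketch of that loop on the toy types (the identity index mapping is an assumption made for brevity; in the real group several combinations can map to one strategy):

// Assumed 1:1 lookup on top of the toy StrategyGroup from the first sketch.
size_t GetStrategyIdxForInputShardings(const StrategyGroup& g, size_t iid) {
  return iid;
}

// Simplified CreateEdgeCost: rows index the source node's strategies (j),
// columns index this node's strategies (k); values are read from the tables
// that now live on InputShardings.
std::vector<std::vector<double>> CreateEdgeCommCost(const StrategyGroup& g,
                                                    size_t in_node_idx,
                                                    size_t src_len) {
  std::vector<std::vector<double>> edge_cost(
      src_len, std::vector<double>(g.strategies.size(), 0.0));
  for (size_t iid = 0; iid < g.input_shardings.size(); ++iid) {
    const size_t k = GetStrategyIdxForInputShardings(g, iid);
    const auto& costs =
        g.input_shardings[iid].communication_resharding_costs[in_node_idx];
    for (size_t j = 0; j < costs.size() && j < src_len; ++j) {
      edge_cost[j][k] = costs[j];
    }
  }
  return edge_cost;
}
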
xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc
@@ -337,10 +337,11 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
strategy_group_->AddStrategy(
ShardingStrategy({output_spec, compute_cost, communication_cost,
static_cast<double>(ByteSizeOfShapeWithSharding(
ins_->shape(), output_spec)),
communication_resharding_costs,
memory_resharding_costs}),
{name, {input_specs.begin(), input_specs.end()}});
ins_->shape(), output_spec))}),
{name,
{input_specs.begin(), input_specs.end()},
communication_resharding_costs,
memory_resharding_costs});
}

// Given lhs and rhs dim maps, infers a sharding for the output by relying
37 changes: 21 additions & 16 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
@@ -351,12 +351,14 @@ BuildStrategyAndCost(
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, scatter_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);
input_shardings_optional.communication_resharding_costs =
std::move(resharding_costs.first);
input_shardings_optional.memory_resharding_costs =
std::move(resharding_costs.second);

strategy_group->AddStrategy(
ShardingStrategy({scatter_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
communication_cost, memory_cost}),
input_shardings_optional);
};

@@ -402,12 +404,14 @@ BuildStrategyAndCost(
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_sharding, strategy_map, cluster_env, call_graph,
input_shardings_optional);
input_shardings_optional.communication_resharding_costs =
std::move(resharding_costs.first);
input_shardings_optional.memory_resharding_costs =
std::move(resharding_costs.second);

strategy_group->AddStrategy(
ShardingStrategy({output_sharding, compute_cost,
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
communication_cost, memory_cost}),
input_shardings_optional);
};

@@ -558,13 +562,12 @@ BuildStrategyAndCost(
MemoryReshardingCostVector(src_strategy_group, operand->shape(),
input_spec, cluster_env);
strategy_group->AddStrategy(
ShardingStrategy({output_spec,
compute_cost,
communication_cost,
memory_cost,
{communication_resharding_costs},
{memory_resharding_costs}}),
{name, {input_spec}});
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
{name,
{input_spec},
{communication_resharding_costs},
{memory_resharding_costs}});
}
break;
}
@@ -696,9 +699,11 @@ BuildStrategyAndCost(

strategy_group->AddStrategy(
ShardingStrategy({*output_spec, compute_cost, communication_cost,
memory_cost, std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
{name, {input_spec}});
memory_cost}),
{name,
{input_spec},
std::move(resharding_costs.first),
std::move(resharding_costs.second)});
}

if (strategy_group->GetStrategies().empty()) {
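The auto_sharding_strategy.cc hunks all perform the same rewrite: the pair of tables returned by GenerateReshardingCostsAndMissingShardingsForAllOperands is moved onto the InputShardings before AddStrategy rather than spliced into the ShardingStrategy initializer. Sketched below with the toy types; GenerateCosts and its placeholder return value are hypothetical stand-ins for the real generator:

// Hypothetical stand-in for
// GenerateReshardingCostsAndMissingShardingsForAllOperands.
std::pair<ReshardingCosts, ReshardingCosts> GenerateCosts() {
  return {{{0.0}}, {{0.0}}};  // placeholder communication/memory tables
}

void AddStrategyForSharding(StrategyGroup& g, const HloSharding& sharding,
                            double compute_cost, double communication_cost,
                            double memory_cost) {
  auto resharding_costs = GenerateCosts();
  InputShardings input_shardings{"example", {sharding}};
  // The new home for both tables:
  input_shardings.communication_resharding_costs =
      std::move(resharding_costs.first);
  input_shardings.memory_resharding_costs =
      std::move(resharding_costs.second);
  g.AddStrategy(
      ShardingStrategy{sharding, compute_cost, communication_cost,
                       memory_cost},
      std::move(input_shardings));
}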