Skip to content

Commit

Permalink
[XLA:TPU:MSA]
Browse files Browse the repository at this point in the history
* Add support for overriding cross program prefetch behavior
* Add support for filtering buffer intervals based on the uses of the buffer.
* Refactor some functions into msa/utils

PiperOrigin-RevId: 651595697
  • Loading branch information
subhankarshah authored and Google-ML-Automation committed Oct 10, 2024
1 parent 28a4ebf commit 148dfb7
Show file tree
Hide file tree
Showing 9 changed files with 453 additions and 248 deletions.
14 changes: 11 additions & 3 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,22 @@ cc_library(
srcs = ["utils.cc"],
hdrs = ["utils.h"],
deps = [
":memory_space_assignment_proto_cc",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_live_range",
"//xla/service:hlo_value",
"//xla/service/heap_simulator",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_googlesource_code_re2//:re2",
"@tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -544,7 +554,6 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_googlesource_code_re2//:re2",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
Expand All @@ -558,6 +567,7 @@ cc_library(
deps = [
":cost_analysis",
":memory_space_assignment_proto_cc",
":utils",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
Expand All @@ -568,10 +578,8 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_googlesource_code_re2//:re2",
],
)

Expand Down
146 changes: 13 additions & 133 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "re2/re2.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -282,20 +281,24 @@ std::vector<MsaBufferInterval> FindCrossProgramPrefetchCandidates(
for (const HloBuffer& buffer : alias_analysis.buffers()) {
CHECK_GE(buffer.values().size(), 1);
const HloValue* value = buffer.values().at(0);
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) {
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
candidates.emplace_back(interval);
} else if (MemorySpaceAssignmentUtils::
DoesCrossProgramPrefetchBufferMatchAnyFilter(
options.msa_sort_order_overrides, interval)) {
candidates.emplace_back(interval);
}
}

DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator(
hlo_live_range);
hlo_live_range, options.msa_sort_order_overrides);
BufferIntervalComparator* comparator =
(options.default_cross_program_prefetch_heuristic &&
options.buffer_interval_comparator
Expand All @@ -313,129 +316,6 @@ std::vector<MsaBufferInterval> FindCrossProgramPrefetchCandidates(
return candidates;
}

// Looks up the logical schedule time of the instruction called `name`.
//
// Scans `schedule` linearly and returns the time stored for the first entry
// whose instruction name equals `name`. Returns a NotFound error when no entry
// in the schedule carries that name.
absl::StatusOr<xla::HloLiveRange::LogicalTime>
GetScheduleTimeFromInstructionName(
    absl::string_view name,
    const absl::flat_hash_map<const xla::HloInstruction*,
                              xla::HloLiveRange::LogicalTime>& schedule) {
  for (const auto& [instruction, logical_time] : schedule) {
    if (instruction->name() == name) {
      return logical_time;
    }
  }
  return NotFound("Reference instruction %s was not found in the schedule.",
                  name);
}

// Returns true iff the operand described by (`operand_size`, `hlo_use`)
// satisfies every criterion that is set on `filter`. Criteria that are not set
// on the filter are ignored, so an empty filter matches everything.
bool DoesOperandMatchFilter(const HloOperandFilter& filter,
                            int64_t operand_size, const HloUse& hlo_use) {
  // Size bounds: size_gte is a lower bound, size_lte an upper bound.
  const bool too_small =
      filter.has_size_gte() && operand_size < filter.size_gte();
  const bool too_large =
      filter.has_size_lte() && operand_size > filter.size_lte();
  if (too_small || too_large) {
    return false;
  }

  const bool wrong_operand_number =
      filter.has_operand_number() &&
      hlo_use.operand_number != filter.operand_number();
  if (wrong_operand_number) {
    return false;
  }

  // The regex has to match the whole name of the using instruction.
  if (filter.has_instruction_name_regex()) {
    if (!RE2::FullMatch(hlo_use.instruction->name(),
                        filter.instruction_name_regex())) {
      return false;
    }
  }

  // The use's operand index must equal the shape index listed in the filter.
  if (filter.has_tuple_index()) {
    const auto& wanted_index = filter.tuple_index().index();
    if (hlo_use.operand_index !=
        ShapeIndex(wanted_index.begin(), wanted_index.end())) {
      return false;
    }
  }
  return true;
}

// Picks a prefetch time inside [earliest_prefetch_time, latest_prefetch_time]
// by linear interpolation: an eagerness of 0.0 yields the earliest time and 1.0
// the latest.
//
// `prefetch_eagerness` is CHECK-ed to lie in [0.0, 1.0]. If the window is empty
// (earliest > latest) an empty optional is returned.
//
// Note: `int64_t * float` is evaluated in single precision, so the
// interpolated value is computed as a float and then converted back to
// int64_t (truncating any fractional part).
absl::StatusOr<std::optional<int64_t>> GetPrefetchTimeByEagerness(
    float prefetch_eagerness, int64_t earliest_prefetch_time,
    int64_t latest_prefetch_time) {
  CHECK_GE(prefetch_eagerness, 0.0);
  CHECK_LE(prefetch_eagerness, 1.0);
  const bool window_is_empty = earliest_prefetch_time > latest_prefetch_time;
  if (window_is_empty) {
    return static_cast<std::optional<int64_t>>(std::nullopt);
  }
  const int64_t window_length = latest_prefetch_time - earliest_prefetch_time;
  return static_cast<std::optional<int64_t>>(
      earliest_prefetch_time + window_length * prefetch_eagerness);
}

// Returns the schedule time of the instruction named `after_instruction_name`
// as the prefetch time. Propagates the lookup error if no instruction with
// that name is present in `schedule`.
absl::StatusOr<std::optional<int64_t>> GetPrefetchTimeAfterInstruction(
    const std::string& after_instruction_name,
    const absl::flat_hash_map<const xla::HloInstruction*,
                              xla::HloLiveRange::LogicalTime>& schedule) {
  absl::StatusOr<xla::HloLiveRange::LogicalTime> anchor_time =
      GetScheduleTimeFromInstructionName(after_instruction_name, schedule);
  if (!anchor_time.ok()) {
    return anchor_time.status();
  }
  return static_cast<std::optional<int64_t>>(anchor_time.value());
}

// Returns the schedule time immediately preceding the instruction named
// `before_instruction_name` (its schedule time minus one) as the prefetch
// time. Propagates the lookup error if no instruction with that name is
// present in `schedule`.
absl::StatusOr<std::optional<int64_t>> GetPrefetchTimeBeforeInstruction(
    const std::string& before_instruction_name,
    const absl::flat_hash_map<const xla::HloInstruction*,
                              xla::HloLiveRange::LogicalTime>& schedule) {
  absl::StatusOr<xla::HloLiveRange::LogicalTime> anchor_time =
      GetScheduleTimeFromInstructionName(before_instruction_name, schedule);
  if (!anchor_time.ok()) {
    return anchor_time.status();
  }
  return static_cast<std::optional<int64_t>>(anchor_time.value() - 1);
}

// Computes the prefetch time requested by `override_options`, dispatching on
// which member of its `options` oneof is set:
//   - prefetch_eagerness:      interpolate within the prefetch window.
//   - after_instruction_name:  use that instruction's schedule time.
//   - before_instruction_name: use the time just before that instruction.
//   - nothing set:             no override (std::nullopt).
// Errors from the name-based lookups are propagated to the caller.
absl::StatusOr<std::optional<int64_t>> GetPrefetchTime(
    const PreferredPrefetchOverrideOptions& override_options,
    int64_t earliest_prefetch_time, int64_t latest_prefetch_time,
    const absl::flat_hash_map<const HloInstruction*, HloLiveRange::LogicalTime>&
        instruction_schedule) {
  using Options = PreferredPrefetchOverrideOptions;
  switch (override_options.options_case()) {
    case Options::OPTIONS_NOT_SET:
      // Nothing requested; fall through to the "no override" result below.
      break;
    case Options::kBeforeInstructionName:
      return GetPrefetchTimeBeforeInstruction(
          override_options.before_instruction_name(), instruction_schedule);
    case Options::kAfterInstructionName:
      return GetPrefetchTimeAfterInstruction(
          override_options.after_instruction_name(), instruction_schedule);
    case Options::kPrefetchEagerness:
      return GetPrefetchTimeByEagerness(override_options.prefetch_eagerness(),
                                        earliest_prefetch_time,
                                        latest_prefetch_time);
  }
  return static_cast<absl::StatusOr<std::optional<int64_t>>>(std::nullopt);
}

// Walks `preferred_prefetch_overrides` in order and returns the prefetch time
// of the first override that (a) has an operand filter matching
// (`operand_size`, `hlo_use`) and (b) produces a time inside
// [earliest_prefetch_time, latest_prefetch_time].
//
// Overrides whose filter matches but whose time is absent or outside the
// window are skipped and the search continues. Returns std::nullopt if no
// override applies; propagates any error raised while computing an override's
// prefetch time.
absl::StatusOr<std::optional<int64_t>> GetOverriddenPreferredPrefetchTime(
    const PreferredPrefetchOverrides& preferred_prefetch_overrides,
    int64_t operand_size, const HloUse& hlo_use,
    const absl::flat_hash_map<const HloInstruction*, HloLiveRange::LogicalTime>&
        instruction_schedule,
    int64_t earliest_prefetch_time, int64_t latest_prefetch_time) {
  const auto is_within_window = [&](int64_t time) {
    return time >= earliest_prefetch_time && time <= latest_prefetch_time;
  };

  for (const auto& candidate : preferred_prefetch_overrides.overrides()) {
    const bool filter_matches = DoesOperandMatchFilter(
        candidate.hlo_operand_filter(), operand_size, hlo_use);
    if (filter_matches) {
      LOG(INFO) << "Config match for instruction "
                << hlo_use.instruction->name() << " operand number "
                << hlo_use.operand_number << " operand index "
                << hlo_use.operand_index.ToString() << " size " << operand_size
                << " live range (" << earliest_prefetch_time << ", "
                << latest_prefetch_time << ")";
      TF_ASSIGN_OR_RETURN(
          auto chosen_time,
          GetPrefetchTime(candidate.override_options(), earliest_prefetch_time,
                          latest_prefetch_time, instruction_schedule));
      if (chosen_time.has_value() && is_within_window(chosen_time.value())) {
        return chosen_time;
      }
    }
  }
  return static_cast<absl::StatusOr<std::optional<int64_t>>>(std::nullopt);
}

} // namespace

std::string AllocationValue::ToString() const {
Expand Down Expand Up @@ -2507,7 +2387,7 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest(
? earliest_prefetch_time.value()
: std::min(definition_time, use_time));
auto overridden_preferred_prefetch_time =
GetOverriddenPreferredPrefetchTime(
MemorySpaceAssignmentUtils::GetOverriddenPreferredPrefetchTime(
options_.preferred_prefetch_overrides, allocation_value.size(),
hlo_use, instruction_schedule, live_range_start_time,
latest_prefetch_time);
Expand Down
83 changes: 16 additions & 67 deletions xla/service/memory_space_assignment/buffer_interval_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,86 +17,27 @@ limitations under the License.

#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>
#include <tuple>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "re2/re2.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/service/hlo_value.h"
#include "xla/service/memory_space_assignment/cost_analysis.h"
#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h"
#include "xla/service/memory_space_assignment/utils.h"
#include "xla/shape_util.h"
#include "xla/util.h"

namespace xla {
namespace memory_space_assignment {
namespace {

bool DoesResultMatchFilter(const HloPositionMatcher& filter,
const MsaBufferInterval& buffer_interval) {
HloInstruction* instruction = buffer_interval.buffer->instruction();
if (filter.has_instruction_regex() &&
!RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) {
return false;
}
if (filter.has_instruction_name_regex() &&
!RE2::FullMatch(instruction->name(), filter.instruction_name_regex())) {
return false;
}
if (filter.has_tuple_index() &&
buffer_interval.buffer->index() !=
ShapeIndex(filter.tuple_index().index().begin(),
filter.tuple_index().index().end())) {
return false;
}
if (filter.has_size_gte() && filter.size_gte() > buffer_interval.size) {
return false;
}
if (filter.has_size_lte() && filter.size_lte() < buffer_interval.size) {
return false;
}
return true;
}

// Computes the assignment priority of `buffer_interval` under
// `msa_sort_order_overrides`; a smaller number means a higher priority.
//
// The overrides are scanned in order. For the first override whose position
// matcher accepts the interval and whose option is set:
//   - assign_first -> int64 lowest + i (earlier overrides sort first),
//   - assign_last  -> int64 max - i.
// where i is the override's position in the list. A matching override with no
// option set is skipped and the scan continues. Returns 0 when nothing applies
// (including when there are no overrides at all).
int64_t GetBufferIntervalOverridePriority(
    const MsaSortOrderOverrides& msa_sort_order_overrides,
    const MsaBufferInterval& buffer_interval) {
  const int64_t num_overrides = msa_sort_order_overrides.overrides_size();
  for (int64_t index = 0; index < num_overrides; ++index) {
    const auto& entry = msa_sort_order_overrides.overrides(index);
    if (DoesResultMatchFilter(entry.hlo_position_matcher(), buffer_interval)) {
      LOG(INFO) << "Override Sort Order Config " << index << " matches "
                << buffer_interval.buffer->instruction()->ToString();
      const auto choice = entry.override_options().options_case();
      if (choice == MsaSortOrderOverrideOptions::kAssignFirst) {
        return std::numeric_limits<int64_t>::lowest() + index;
      }
      if (choice == MsaSortOrderOverrideOptions::kAssignLast) {
        return std::numeric_limits<int64_t>::max() - index;
      }
      // No option set on this override: keep looking at later overrides.
    }
  }
  return 0;
}

} // namespace

MemoryBoundednessBufferIntervalComparator::
MemoryBoundednessBufferIntervalComparator(
Expand Down Expand Up @@ -156,8 +97,9 @@ int64_t MemoryBoundednessBufferIntervalComparator::GetLatestUseTime(
MemoryBoundednessBufferIntervalComparator::ComparisonTuple
MemoryBoundednessBufferIntervalComparator::GetTuple(
const MsaBufferInterval& buffer_interval) {
int64_t priority = GetBufferIntervalOverridePriority(
msa_sort_order_overrides_, buffer_interval);
int64_t priority =
MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority(
msa_sort_order_overrides_, buffer_interval);
float inverse_memory_boundedness =
-1.0 * cost_analysis_.GetMemoryBoundedness(buffer_interval,
cost_analysis_cache_);
Expand All @@ -173,8 +115,11 @@ MemoryBoundednessBufferIntervalComparator::GetTuple(

DefaultCrossProgramPrefetchBufferIntervalComparator::
DefaultCrossProgramPrefetchBufferIntervalComparator(
const HloLiveRange& hlo_live_range)
: BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {}
const HloLiveRange& hlo_live_range,
const MsaSortOrderOverrides& msa_sort_order_overrides)
: BufferIntervalComparator(),
hlo_live_range_(hlo_live_range),
msa_sort_order_overrides_(msa_sort_order_overrides) {}

std::string DefaultCrossProgramPrefetchBufferIntervalComparator::
DescribeComparisonCriteria() const {
Expand All @@ -196,6 +141,9 @@ bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan(
DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple
DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
const MsaBufferInterval& buffer_interval) {
int64_t priority =
MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority(
msa_sort_order_overrides_, buffer_interval);
auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer);
if (sort_data_it == additional_sort_data_.end()) {
AdditionalSortData sort_data;
Expand All @@ -213,9 +161,10 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
.first;
}

return std::make_tuple(
-1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use, buffer_interval.buffer->id());
return std::make_tuple(priority, -1 * buffer_interval.size,
-1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use,
buffer_interval.buffer->id());
}

} // namespace memory_space_assignment
Expand Down
Loading

0 comments on commit 148dfb7

Please sign in to comment.