diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index da4c57af71537f..045cfbb8dc2054 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -165,12 +165,22 @@ cc_library( srcs = ["utils.cc"], hdrs = ["utils.h"], deps = [ + ":memory_space_assignment_proto_cc", "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_live_range", "//xla/service:hlo_value", "//xla/service/heap_simulator", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_googlesource_code_re2//:re2", + "@tsl//tsl/platform:statusor", ], ) @@ -544,7 +554,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@com_googlesource_code_re2//:re2", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", @@ -558,6 +567,7 @@ cc_library( deps = [ ":cost_analysis", ":memory_space_assignment_proto_cc", + ":utils", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -568,10 +578,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_googlesource_code_re2//:re2", ], ) diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index 7a5ff4073692c5..e27aabc53e36df 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -47,7 +47,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -282,20 +281,24 @@ std::vector FindCrossProgramPrefetchCandidates( for (const HloBuffer& buffer : alias_analysis.buffers()) { CHECK_GE(buffer.values().size(), 1); const HloValue* value = buffer.values().at(0); + MsaBufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) { - MsaBufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + candidates.emplace_back(interval); + } else if (MemorySpaceAssignmentUtils:: + DoesCrossProgramPrefetchBufferMatchAnyFilter( + options.msa_sort_order_overrides, interval)) { candidates.emplace_back(interval); } } DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator( - hlo_live_range); + hlo_live_range, options.msa_sort_order_overrides); BufferIntervalComparator* comparator = (options.default_cross_program_prefetch_heuristic && options.buffer_interval_comparator @@ -313,129 +316,6 @@ std::vector FindCrossProgramPrefetchCandidates( return candidates; } -absl::StatusOr -GetScheduleTimeFromInstructionName( - absl::string_view name, - const absl::flat_hash_map& schedule) { - for (auto schedule_entry : schedule) { - if (schedule_entry.first->name() == name) { - return schedule_entry.second; - } - } - return NotFound("Reference instruction %s was not found in the schedule.", - name); -} - -bool DoesOperandMatchFilter(const HloOperandFilter& filter, - int64_t operand_size, const HloUse& hlo_use) { - if (filter.has_size_gte() && operand_size < filter.size_gte()) { - return false; - } - if (filter.has_size_lte() && operand_size > filter.size_lte()) { - return false; - } - if (filter.has_operand_number() && - hlo_use.operand_number != filter.operand_number()) { - return false; - } - if (filter.has_instruction_name_regex() && - !RE2::FullMatch(hlo_use.instruction->name(), - filter.instruction_name_regex())) { - return false; - } - if (filter.has_tuple_index() && - hlo_use.operand_index != ShapeIndex(filter.tuple_index().index().begin(), - filter.tuple_index().index().end())) { - return false; - } - return true; -} - -absl::StatusOr> GetPrefetchTimeByEagerness( - float prefetch_eagerness, int64_t earliest_prefetch_time, - int64_t latest_prefetch_time) { - CHECK_GE(prefetch_eagerness, 0.0); - CHECK_LE(prefetch_eagerness, 1.0); - if (earliest_prefetch_time > latest_prefetch_time) { - return static_cast>(std::nullopt); - } - return static_cast>( - earliest_prefetch_time + - (latest_prefetch_time - earliest_prefetch_time) * prefetch_eagerness); -} - -absl::StatusOr> GetPrefetchTimeAfterInstruction( - const std::string& after_instruction_name, - const absl::flat_hash_map& schedule) { - TF_ASSIGN_OR_RETURN( - auto reference_instruction_time, - GetScheduleTimeFromInstructionName(after_instruction_name, schedule)); - return static_cast>(reference_instruction_time); -} - -absl::StatusOr> GetPrefetchTimeBeforeInstruction( - const std::string& before_instruction_name, - const absl::flat_hash_map& schedule) { - TF_ASSIGN_OR_RETURN( - auto reference_instruction_time, - GetScheduleTimeFromInstructionName(before_instruction_name, schedule)); - return static_cast>(reference_instruction_time - 1); -} - -absl::StatusOr> GetPrefetchTime( - const PreferredPrefetchOverrideOptions& override_options, - int64_t earliest_prefetch_time, int64_t latest_prefetch_time, - const absl::flat_hash_map& - instruction_schedule) { - switch (override_options.options_case()) { - case PreferredPrefetchOverrideOptions::kPrefetchEagerness: - return GetPrefetchTimeByEagerness(override_options.prefetch_eagerness(), - earliest_prefetch_time, - latest_prefetch_time); - case PreferredPrefetchOverrideOptions::kAfterInstructionName: - return GetPrefetchTimeAfterInstruction( - override_options.after_instruction_name(), instruction_schedule); - case PreferredPrefetchOverrideOptions::kBeforeInstructionName: - return GetPrefetchTimeBeforeInstruction( - override_options.before_instruction_name(), instruction_schedule); - case PreferredPrefetchOverrideOptions::OPTIONS_NOT_SET: - break; - } - return static_cast>>(std::nullopt); -} - -absl::StatusOr> GetOverriddenPreferredPrefetchTime( - const PreferredPrefetchOverrides& preferred_prefetch_overrides, - int64_t operand_size, const HloUse& hlo_use, - const absl::flat_hash_map& - instruction_schedule, - int64_t earliest_prefetch_time, int64_t latest_prefetch_time) { - for (const auto& override : preferred_prefetch_overrides.overrides()) { - if (!DoesOperandMatchFilter(override.hlo_operand_filter(), operand_size, - hlo_use)) { - continue; - } - LOG(INFO) << "Config match for instruction " << hlo_use.instruction->name() - << " operand number " << hlo_use.operand_number - << " operand index " << hlo_use.operand_index.ToString() - << " size " << operand_size << " live range (" - << earliest_prefetch_time << ", " << latest_prefetch_time << ")"; - TF_ASSIGN_OR_RETURN( - auto prefetch_time, - GetPrefetchTime(override.override_options(), earliest_prefetch_time, - latest_prefetch_time, instruction_schedule)); - if (prefetch_time.has_value() && - prefetch_time.value() >= earliest_prefetch_time && - prefetch_time.value() <= latest_prefetch_time) { - return prefetch_time; - } - } - return static_cast>>(std::nullopt); -} - } // namespace std::string AllocationValue::ToString() const { @@ -2507,7 +2387,7 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( ? earliest_prefetch_time.value() : std::min(definition_time, use_time)); auto overridden_preferred_prefetch_time = - GetOverriddenPreferredPrefetchTime( + MemorySpaceAssignmentUtils::GetOverriddenPreferredPrefetchTime( options_.preferred_prefetch_overrides, allocation_value.size(), hlo_use, instruction_schedule, live_range_start_time, latest_prefetch_time); diff --git a/xla/service/memory_space_assignment/buffer_interval_comparator.cc b/xla/service/memory_space_assignment/buffer_interval_comparator.cc index 10ced32f478b79..818c00ea9dc945 100644 --- a/xla/service/memory_space_assignment/buffer_interval_comparator.cc +++ b/xla/service/memory_space_assignment/buffer_interval_comparator.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -25,78 +24,20 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" -#include "re2/re2.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/utils.h" #include "xla/shape_util.h" #include "xla/util.h" namespace xla { namespace memory_space_assignment { -namespace { - -bool DoesResultMatchFilter(const HloPositionMatcher& filter, - const MsaBufferInterval& buffer_interval) { - HloInstruction* instruction = buffer_interval.buffer->instruction(); - if (filter.has_instruction_regex() && - !RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) { - return false; - } - if (filter.has_instruction_name_regex() && - !RE2::FullMatch(instruction->name(), filter.instruction_name_regex())) { - return false; - } - if (filter.has_tuple_index() && - buffer_interval.buffer->index() != - ShapeIndex(filter.tuple_index().index().begin(), - filter.tuple_index().index().end())) { - return false; - } - if (filter.has_size_gte() && filter.size_gte() > buffer_interval.size) { - return false; - } - if (filter.has_size_lte() && filter.size_lte() < buffer_interval.size) { - return false; - } - return true; -} - -// Returns an integer representing the priority of a MsaBufferInterval during -// assignment, a smaller number indicates a higher priority. -int64_t GetBufferIntervalOverridePriority( - const MsaSortOrderOverrides& msa_sort_order_overrides, - const MsaBufferInterval& buffer_interval) { - if (msa_sort_order_overrides.overrides_size() == 0) { - return 0; - } - for (int64_t i = 0; i < msa_sort_order_overrides.overrides_size(); ++i) { - const auto& override = msa_sort_order_overrides.overrides(i); - if (!DoesResultMatchFilter(override.hlo_position_matcher(), - buffer_interval)) { - continue; - } - LOG(INFO) << "Override Sort Order Config " << i << " matches " - << buffer_interval.buffer->instruction()->ToString(); - switch (override.override_options().options_case()) { - case MsaSortOrderOverrideOptions::kAssignFirst: - return std::numeric_limits::lowest() + i; - case MsaSortOrderOverrideOptions::kAssignLast: - return std::numeric_limits::max() - i; - case MsaSortOrderOverrideOptions::OPTIONS_NOT_SET: - continue; - } - } - return 0; -} - -} // namespace MemoryBoundednessBufferIntervalComparator:: MemoryBoundednessBufferIntervalComparator( @@ -156,8 +97,9 @@ int64_t MemoryBoundednessBufferIntervalComparator::GetLatestUseTime( MemoryBoundednessBufferIntervalComparator::ComparisonTuple MemoryBoundednessBufferIntervalComparator::GetTuple( const MsaBufferInterval& buffer_interval) { - int64_t priority = GetBufferIntervalOverridePriority( - msa_sort_order_overrides_, buffer_interval); + int64_t priority = + MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority( + msa_sort_order_overrides_, buffer_interval); float inverse_memory_boundedness = -1.0 * cost_analysis_.GetMemoryBoundedness(buffer_interval, cost_analysis_cache_); @@ -173,8 +115,11 @@ MemoryBoundednessBufferIntervalComparator::GetTuple( DefaultCrossProgramPrefetchBufferIntervalComparator:: DefaultCrossProgramPrefetchBufferIntervalComparator( - const HloLiveRange& hlo_live_range) - : BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {} + const HloLiveRange& hlo_live_range, + const MsaSortOrderOverrides& msa_sort_order_overrides) + : BufferIntervalComparator(), + hlo_live_range_(hlo_live_range), + msa_sort_order_overrides_(msa_sort_order_overrides) {} std::string DefaultCrossProgramPrefetchBufferIntervalComparator:: DescribeComparisonCriteria() const { @@ -196,6 +141,9 @@ bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan( DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( const MsaBufferInterval& buffer_interval) { + int64_t priority = + MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority( + msa_sort_order_overrides_, buffer_interval); auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer); if (sort_data_it == additional_sort_data_.end()) { AdditionalSortData sort_data; @@ -213,9 +161,10 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( .first; } - return std::make_tuple( - -1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size, - sort_data_it->second.latest_use, buffer_interval.buffer->id()); + return std::make_tuple(priority, -1 * buffer_interval.size, + -1 * sort_data_it->second.cumulative_use_size, + sort_data_it->second.latest_use, + buffer_interval.buffer->id()); } } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/buffer_interval_comparator.h b/xla/service/memory_space_assignment/buffer_interval_comparator.h index f5705df568f3d9..5c7a94b6ffd468 100644 --- a/xla/service/memory_space_assignment/buffer_interval_comparator.h +++ b/xla/service/memory_space_assignment/buffer_interval_comparator.h @@ -27,12 +27,11 @@ limitations under the License. #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" +#include "xla/service/memory_space_assignment/utils.h" namespace xla { namespace memory_space_assignment { -using MsaBufferInterval = - GlobalDecreasingSizeBestFitHeap::BufferInterval; using MsaBufferIntervalCompare = GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; @@ -115,7 +114,8 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator : public BufferIntervalComparator { public: explicit DefaultCrossProgramPrefetchBufferIntervalComparator( - const HloLiveRange& hlo_live_range); + const HloLiveRange& hlo_live_range, + const MsaSortOrderOverrides& msa_sort_order_overrides); ~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default; @@ -129,7 +129,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator // See the value returned by DescribeComparisonCriteria() for the meaning of // each tuple element. using ComparisonTuple = - std::tuple; + std::tuple; struct AdditionalSortData { int64_t latest_use = 0; @@ -141,6 +141,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator absl::flat_hash_map additional_sort_data_; const HloLiveRange& hlo_live_range_; + const MsaSortOrderOverrides& msa_sort_order_overrides_; }; } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/memory_space_assignment.proto b/xla/service/memory_space_assignment/memory_space_assignment.proto index e15d564dac8f35..09ae1382a4031c 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -105,6 +105,11 @@ message HloOperandFilter { // If operand of an instruction is a tuple and indexing into the tuple is // required. optional TupleShapeIndex tuple_index = 5; + // Regex to match the entire instruction HLO. The HLO string is constructed + // using default HloPrintOptions. Refer to the HloPrintOptions class in + // hlo_instruction.h to know more about the format of the HLO string used for + // matching. + optional string instruction_regex = 6; } // Options to override preferred prefetch time for an operand. @@ -154,6 +159,8 @@ message HloPositionMatcher { optional int64 size_gte = 4; // Filters instructions with output size in bytes less or equal to a value. optional int64 size_lte = 5; + // Filters instructions that have a use that matches the filter. + optional HloOperandFilter hlo_use_filter = 6; } // Options to override preferred prefetch time for an operand. @@ -176,6 +183,7 @@ message MsaSortOrderOverride { optional HloPositionMatcher hlo_position_matcher = 1; optional xla.memory_space_assignment.MsaSortOrderOverrideOptions override_options = 2; + optional bool apply_to_cross_program_prefetches = 3; } // Encloses chained override configs. The first config has highest precedence diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 7fc30f30102e43..c42f29e4d5abe1 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -4892,6 +4891,77 @@ TEST_F(MemorySpaceAssignmentTest, EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); } +TEST_F(MemorySpaceAssignmentTest, + MemoryBoundednessOverrideSortOrderByUseAssignFirst) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[3,4]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + tanh0 = f32[3,4]{1,0} tanh(p0) + negate0 = f32[3,4]{1,0} negate(p1) + tanh1 = f32[3,4]{1,0} tanh(tanh0) + negate1 = f32[3,4]{1,0} negate(negate0) + tanh2 = f32[3,4]{1,0} tanh(tanh1) + negate2 = f32[3,4]{1,0} negate(negate1) + tanh3 = f32[3,4]{1,0} tanh(tanh2) + negate3 = f32[3,4]{1,0} negate(negate2) + tanh4 = f32[3,4]{1,0} tanh(tanh3) + negate4 = f32[3,4]{1,0} negate(negate3) + ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Override MSA sort order and try to assign all negates to alternate memory + // first. Alternate memory size is enough to fit 2 f32[4,3] tensors at a time. + const std::string text_proto = R"pb( + overrides { + hlo_position_matcher { + hlo_use_filter { instruction_name_regex: "negate(.*)" } + } + override_options { assign_first: true } + })pb"; + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, + ParseTextProto(text_proto)); + + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); + // Parameters are in the default memory space. + const HloInstruction* p0 = FindInstruction(module.get(), "p0"); + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* p1 = FindInstruction(module.get(), "p1"); + EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace); + // Check that all negates are in alternate memory space except negate4. + // negate4 is a program output, so it has to land in default memory. + HloInstruction* negate0 = FindInstruction(module.get(), "negate0"); + EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate1 = FindInstruction(module.get(), "negate1"); + EXPECT_EQ(negate1->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate2 = FindInstruction(module.get(), "negate2"); + EXPECT_EQ(negate2->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate3 = FindInstruction(module.get(), "negate3"); + EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate4 = FindInstruction(module.get(), "negate4"); + EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0"); + EXPECT_EQ(tanh0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1"); + EXPECT_EQ(tanh1->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2"); + EXPECT_EQ(tanh2->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3"); + EXPECT_EQ(tanh3->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4"); + EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); +} + TEST_F(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); Shape f32v1 = ShapeUtil::MakeShape(F32, {1}); @@ -8896,10 +8966,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 1); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Dot(op::Parameter(0), @@ -8956,14 +9024,10 @@ TEST_F(MemorySpaceAssignmentTest, MultiCrossProgramPrefetchTest) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 2); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 1); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); - } - if (cross_program_prefetches.size() > 1) { - EXPECT_EQ(cross_program_prefetches[1].parameter, 2); - EXPECT_EQ(cross_program_prefetches[1].index, ShapeIndex({})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); + EXPECT_EQ(cross_program_prefetches[1].parameter, 2); + EXPECT_EQ(cross_program_prefetches[1].index, ShapeIndex({})); EXPECT_THAT( module->entry_computation()->root_instruction(), @@ -9010,10 +9074,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleTest) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 0); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); } TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) { @@ -9052,10 +9114,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 1); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); } TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTupleTest) { @@ -9098,10 +9158,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTupleTest) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 0); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); } TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) { @@ -9631,10 +9689,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 1); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dataflow_analysis, @@ -9711,10 +9767,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 0); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dataflow_analysis, @@ -9790,10 +9844,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 1); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 1); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dataflow_analysis, @@ -9850,10 +9902,8 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleReuse) { auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); - if (!cross_program_prefetches.empty()) { - EXPECT_EQ(cross_program_prefetches[0].parameter, 0); - EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); - } + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({1})); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dataflow_analysis, diff --git a/xla/service/memory_space_assignment/options.h b/xla/service/memory_space_assignment/options.h index e7cb78a17bebf7..6db6917993170d 100644 --- a/xla/service/memory_space_assignment/options.h +++ b/xla/service/memory_space_assignment/options.h @@ -259,6 +259,8 @@ struct Options { // and gives MSA more flexibility in choosing the prefetch time and how much // data to prefetch. bool enable_window_prefetch = false; + + MsaSortOrderOverrides msa_sort_order_overrides; }; } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/utils.cc b/xla/service/memory_space_assignment/utils.cc index 7f40919b53efb8..ee11ccf250c677 100644 --- a/xla/service/memory_space_assignment/utils.cc +++ b/xla/service/memory_space_assignment/utils.cc @@ -15,14 +15,28 @@ limitations under the License. #include "xla/service/memory_space_assignment/utils.h" +#include +#include +#include +#include + #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "re2/re2.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace memory_space_assignment { @@ -103,5 +117,227 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( }); } +bool MemorySpaceAssignmentUtils::DoesUseMatchFilter( + const HloOperandFilter& filter, const HloUse& hlo_use, + int64_t operand_size) { + // The order of checks is such that the most expensive checks are done last. + if (filter.has_size_gte() && operand_size < filter.size_gte()) { + return false; + } + if (filter.has_size_lte() && operand_size > filter.size_lte()) { + return false; + } + if (filter.has_operand_number() && + hlo_use.operand_number != filter.operand_number()) { + return false; + } + if (filter.has_tuple_index() && + hlo_use.operand_index != ShapeIndex(filter.tuple_index().index().begin(), + filter.tuple_index().index().end())) { + return false; + } + if (filter.has_instruction_name_regex() && + !RE2::FullMatch(hlo_use.instruction->name(), + filter.instruction_name_regex())) { + return false; + } + if (filter.has_instruction_regex() && + !RE2::FullMatch(hlo_use.instruction->ToString(), + filter.instruction_regex())) { + return false; + } + return true; +} + +bool MemorySpaceAssignmentUtils::DoesPositionMatchFilter( + const HloPositionMatcher& filter, + const MsaBufferInterval& buffer_interval) { + // The order of checks is such that the most expensive checks are done last. + HloInstruction* instruction = buffer_interval.buffer->instruction(); + if (filter.has_size_gte() && filter.size_gte() > buffer_interval.size) { + return false; + } + if (filter.has_size_lte() && filter.size_lte() < buffer_interval.size) { + return false; + } + if (filter.has_tuple_index() && + buffer_interval.buffer->index() != + ShapeIndex(filter.tuple_index().index().begin(), + filter.tuple_index().index().end())) { + return false; + } + if (filter.has_instruction_name_regex() && + !RE2::FullMatch(instruction->name(), filter.instruction_name_regex())) { + return false; + } + if (filter.has_instruction_regex() && + !RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) { + return false; + } + return DoesBufferIntervalMatchHloUseFilter(filter, buffer_interval); +} + +bool MemorySpaceAssignmentUtils::DoesBufferIntervalMatchHloUseFilter( + const HloPositionMatcher& filter, + const MsaBufferInterval& buffer_interval) { + if (!filter.has_hlo_use_filter()) { + return true; + } + for (const HloUse& use : buffer_interval.buffer->GetUses()) { + if (DoesUseMatchFilter(filter.hlo_use_filter(), use, + buffer_interval.size)) { + return true; + } + } + return false; +} + +absl::StatusOr +MemorySpaceAssignmentUtils::GetScheduleTimeFromInstructionName( + absl::string_view name, + const absl::flat_hash_map& schedule) { + for (auto schedule_entry : schedule) { + if (schedule_entry.first->name() == name) { + return schedule_entry.second; + } + } + return NotFound("Reference instruction %s was not found in the schedule.", + name); +} + +absl::StatusOr> +MemorySpaceAssignmentUtils::GetPrefetchTimeByEagerness( + float prefetch_eagerness, int64_t earliest_prefetch_time, + int64_t latest_prefetch_time) { + CHECK_GE(prefetch_eagerness, 0.0); + CHECK_LE(prefetch_eagerness, 1.0); + if (earliest_prefetch_time > latest_prefetch_time) { + return static_cast>(std::nullopt); + } + return static_cast>( + earliest_prefetch_time + + (latest_prefetch_time - earliest_prefetch_time) * prefetch_eagerness); +} + +absl::StatusOr> +MemorySpaceAssignmentUtils::GetPrefetchTimeAfterInstruction( + const std::string& after_instruction_name, + const absl::flat_hash_map& schedule) { + TF_ASSIGN_OR_RETURN( + auto reference_instruction_time, + GetScheduleTimeFromInstructionName(after_instruction_name, schedule)); + return static_cast>(reference_instruction_time); +} + +absl::StatusOr> +MemorySpaceAssignmentUtils::GetPrefetchTimeBeforeInstruction( + const std::string& before_instruction_name, + const absl::flat_hash_map& schedule) { + TF_ASSIGN_OR_RETURN( + auto reference_instruction_time, + GetScheduleTimeFromInstructionName(before_instruction_name, schedule)); + return static_cast>(reference_instruction_time - 1); +} +absl::StatusOr> +MemorySpaceAssignmentUtils::GetPrefetchTime( + const PreferredPrefetchOverrideOptions& override_options, + int64_t earliest_prefetch_time, int64_t latest_prefetch_time, + const absl::flat_hash_map& + instruction_schedule) { + switch (override_options.options_case()) { + case PreferredPrefetchOverrideOptions::kPrefetchEagerness: + return GetPrefetchTimeByEagerness(override_options.prefetch_eagerness(), + earliest_prefetch_time, + latest_prefetch_time); + case PreferredPrefetchOverrideOptions::kAfterInstructionName: + return GetPrefetchTimeAfterInstruction( + override_options.after_instruction_name(), instruction_schedule); + case PreferredPrefetchOverrideOptions::kBeforeInstructionName: + return GetPrefetchTimeBeforeInstruction( + override_options.before_instruction_name(), instruction_schedule); + case PreferredPrefetchOverrideOptions::OPTIONS_NOT_SET: + break; + } + return static_cast>>(std::nullopt); +} + +absl::StatusOr> +MemorySpaceAssignmentUtils::GetOverriddenPreferredPrefetchTime( + const PreferredPrefetchOverrides& preferred_prefetch_overrides, + int64_t operand_size, const HloUse& hlo_use, + const absl::flat_hash_map& + instruction_schedule, + int64_t earliest_prefetch_time, int64_t latest_prefetch_time) { + for (const auto& override : preferred_prefetch_overrides.overrides()) { + if (!MemorySpaceAssignmentUtils::DoesUseMatchFilter( + override.hlo_operand_filter(), hlo_use, operand_size)) { + continue; + } + VLOG(3) << "Config match for instruction " << hlo_use.instruction->name() + << " operand number " << hlo_use.operand_number << " operand index " + << hlo_use.operand_index.ToString() << " size " << operand_size + << " live range (" << earliest_prefetch_time << ", " + << latest_prefetch_time << ")"; + TF_ASSIGN_OR_RETURN( + auto prefetch_time, + GetPrefetchTime(override.override_options(), earliest_prefetch_time, + latest_prefetch_time, instruction_schedule)); + if (prefetch_time.has_value() && + prefetch_time.value() >= earliest_prefetch_time && + prefetch_time.value() <= latest_prefetch_time) { + return prefetch_time; + } + } + return static_cast>>(std::nullopt); +} + +bool MemorySpaceAssignmentUtils::DoesCrossProgramPrefetchBufferMatchAnyFilter( + const MsaSortOrderOverrides& sort_order_overrides, + const MsaBufferInterval& buffer_interval) { + for (const MsaSortOrderOverride& override : + sort_order_overrides.overrides()) { + if (override.has_apply_to_cross_program_prefetches() && + override.apply_to_cross_program_prefetches() && + MemorySpaceAssignmentUtils::DoesPositionMatchFilter( + override.hlo_position_matcher(), buffer_interval) && + override.override_options().has_assign_first() && + override.override_options().assign_first()) { + VLOG(3) << "Cross program prefetch buffer " + << buffer_interval.buffer->ToString() + << " matches sort order override " << absl::StrCat(override); + return true; + } + } + return false; +} + +int64_t MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority( + const MsaSortOrderOverrides& msa_sort_order_overrides, + const MsaBufferInterval& buffer_interval) { + if (msa_sort_order_overrides.overrides_size() == 0) { + return 0; + } + for (int64_t i = 0; i < msa_sort_order_overrides.overrides_size(); ++i) { + const auto& override = msa_sort_order_overrides.overrides(i); + if (!MemorySpaceAssignmentUtils::DoesPositionMatchFilter( + override.hlo_position_matcher(), buffer_interval)) { + continue; + } + VLOG(3) << "Override Sort Order Config " << i << " matches " + << buffer_interval.buffer->instruction()->ToString(); + switch (override.override_options().options_case()) { + case MsaSortOrderOverrideOptions::kAssignFirst: + return std::numeric_limits::lowest() + i; + case MsaSortOrderOverrideOptions::kAssignLast: + return std::numeric_limits::max() - i; + case MsaSortOrderOverrideOptions::OPTIONS_NOT_SET: + continue; + } + } + return 0; +} } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/utils.h b/xla/service/memory_space_assignment/utils.h index 272688af26735b..7ad6757df2606a 100644 --- a/xla/service/memory_space_assignment/utils.h +++ b/xla/service/memory_space_assignment/utils.h @@ -16,12 +16,25 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_UTILS_H_ +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" namespace xla { namespace memory_space_assignment { +using MsaBufferInterval = + GlobalDecreasingSizeBestFitHeap::BufferInterval; + // Encapsulates common utility methods for memory space assignment. class MemorySpaceAssignmentUtils { public: @@ -34,6 +47,64 @@ class MemorySpaceAssignmentUtils { // Returns true if the HloValue is allowed to be placed in alternate memory. static bool IsValueAllowedInAlternateMemory(const HloValue* value, int64_t alternate_memory_space); + + static bool DoesUseMatchFilter(const HloOperandFilter& filter, + const HloUse& hlo_use, int64_t operand_size); + + static bool DoesPositionMatchFilter(const HloPositionMatcher& filter, + const MsaBufferInterval& buffer_interval); + + static absl::StatusOr + GetScheduleTimeFromInstructionName( + absl::string_view name, + const absl::flat_hash_map& schedule); + + static absl::StatusOr> GetPrefetchTimeByEagerness( + float prefetch_eagerness, int64_t earliest_prefetch_time, + int64_t latest_prefetch_time); + + static absl::StatusOr> GetPrefetchTimeAfterInstruction( + const std::string& after_instruction_name, + const absl::flat_hash_map& schedule); + + static absl::StatusOr> + GetPrefetchTimeBeforeInstruction( + const std::string& before_instruction_name, + const absl::flat_hash_map& schedule); + + static absl::StatusOr> GetPrefetchTime( + const PreferredPrefetchOverrideOptions& override_options, + int64_t earliest_prefetch_time, int64_t latest_prefetch_time, + const absl::flat_hash_map& + instruction_schedule); + + static absl::StatusOr> + GetOverriddenPreferredPrefetchTime( + const PreferredPrefetchOverrides& preferred_prefetch_overrides, + int64_t operand_size, const HloUse& hlo_use, + const absl::flat_hash_map& + instruction_schedule, + int64_t earliest_prefetch_time, int64_t latest_prefetch_time); + + static bool DoesCrossProgramPrefetchBufferMatchAnyFilter( + const MsaSortOrderOverrides& sort_order_overrides, + const MsaBufferInterval& buffer_interval); + + // Returns an integer representing the priority of a MsaBufferInterval during + // assignment, a smaller number indicates a higher priority. + static int64_t GetBufferIntervalOverridePriority( + const MsaSortOrderOverrides& msa_sort_order_overrides, + const MsaBufferInterval& buffer_interval); + + private: + static bool DoesBufferIntervalMatchHloUseFilter( + const HloPositionMatcher& filter, + const MsaBufferInterval& buffer_interval); }; } // namespace memory_space_assignment