From 3e1fe06c73cdbd33002447ba1d7e06c1c5dd0b0e Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Tue, 15 Oct 2024 16:47:43 -0700 Subject: [PATCH] Convert any compute on host memory into host compute, including dynamic-slice. PiperOrigin-RevId: 686283784 --- xla/hlo/transforms/host_offloader.cc | 93 +++++++++++++---------- xla/hlo/transforms/host_offloader.h | 10 +-- xla/hlo/transforms/host_offloader_test.cc | 88 +++++++++++++++++++++ 3 files changed, 144 insertions(+), 47 deletions(-) diff --git a/xla/hlo/transforms/host_offloader.cc b/xla/hlo/transforms/host_offloader.cc index f24f85d5f13c6..1341dc7be9705 100644 --- a/xla/hlo/transforms/host_offloader.cc +++ b/xla/hlo/transforms/host_offloader.cc @@ -130,6 +130,12 @@ bool HostOffloader::InstructionIsAllowedBetweenDsAndMoveToDevice( absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( const InstructionAndShapeIndex& starting_instruction_and_index, bool insert_copy_before) { + VLOG(3) << absl::StreamFormat( + "Walking down host memory offload paths starting from (%s, %s). Insert " + "copy before: %v", + starting_instruction_and_index.instruction->name(), + starting_instruction_and_index.shape_index.ToString(), + insert_copy_before); bool changed = false; absl::flat_hash_set mth_custom_calls_to_remove; absl::flat_hash_set slices_to_dynamify; @@ -147,6 +153,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( VLOG(4) << absl::StreamFormat("Visiting instruction: %s", instruction_and_shape_index.ToString()); bool already_saved_buffer = false; + bool need_to_wrap_instruction_as_host_compute = false; if (instruction->opcode() == HloOpcode::kCustomCall && instruction->custom_call_target() == host_memory_offload_annotations::kMoveToHostCustomCallTarget) { @@ -198,33 +205,43 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } } } else if (instruction->opcode() == HloOpcode::kDynamicSlice) { - TF_RETURN_IF_ERROR( - ValidateSliceLeadsToMoveToDeviceCustomCall(instruction)); - // This DynamicSlice is the end of this path of host memory offload. - continue; + TF_ASSIGN_OR_RETURN(bool is_end_of_offload, + SliceLeadsToMoveToDeviceCustomCall(instruction)); + if (is_end_of_offload) { + // This DynamicSlice is the end of this path of host memory offload. + continue; + } else { + // This is not the end of host memory offload. This is treated as device + // compute happening on host memory, convert it to host compute. + need_to_wrap_instruction_as_host_compute = true; + } } else if (instruction->opcode() == HloOpcode::kSlice) { - TF_RETURN_IF_ERROR( - ValidateSliceLeadsToMoveToDeviceCustomCall(instruction)); - // This Slice is the end of this path of host memory offload. - // This Slice should be a DynamicSlice to be able to work with host - // memory. - slices_to_dynamify.insert(instruction); - continue; - } else if (instruction->opcode() == HloOpcode::kAllGather || - instruction->opcode() == HloOpcode::kAllReduce) { + TF_ASSIGN_OR_RETURN(bool is_end_of_offload, + SliceLeadsToMoveToDeviceCustomCall(instruction)); + if (is_end_of_offload) { + // This Slice is the end of this path of host memory offload. + // This Slice should be a DynamicSlice to be able to work with host + // memory. + slices_to_dynamify.insert(instruction); + continue; + } else { + // This is not the end of host memory offload. This is treated as device + // compute happening on host memory, convert it to host compute. + need_to_wrap_instruction_as_host_compute = true; + } + } else { + // This is some unaccounted for instruction. Since it is unaccounted for, + // it must be something which is not legal to do with device compute. + need_to_wrap_instruction_as_host_compute = true; + } + + if (need_to_wrap_instruction_as_host_compute) { LOG(WARNING) << absl::StreamFormat( "Found an instruction (\"%s\") which does device compute in host " "memory space. Converting into host compute. This is likely to have " "a very high overhead.", instruction->name()); SetHostComputeFrontendAttribute(*instruction); - } else { - // Found an instruction which is invalid during host memory offloading. - return absl::InvalidArgumentError( - absl::StrFormat("Tensor which is moved to host (starting from " - "\"%s\") is used by an instruction (\"%s\") which is " - "not acceptable during pure memory offload.", - starting_instruction->name(), instruction->name())); } if (!already_saved_buffer) { @@ -313,11 +330,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } for (HloInstruction* slice : slices_to_dynamify) { - TF_ASSIGN_OR_RETURN(HloInstruction * dynamic_slice, DynamifySlice(slice)); - // We've already validated this slice. Since we're changing it to a dynamic - // slice, save the new dynamic slice so that we don't try to validate it - // again. - validated_slices_.insert(dynamic_slice); + TF_RETURN_IF_ERROR(DynamifySlice(slice)); changed = true; } @@ -364,8 +377,8 @@ absl::StatusOr HostOffloader::HandleMoveToHostCustomCall( custom_call_instruction)) { return false; } - VLOG(1) << "Offloading " << custom_call_instruction->operand(0)->name() - << " to host."; + VLOG(1) << "Offloading \"" << custom_call_instruction->operand(0)->name() + << "\" to host."; TF_ASSIGN_OR_RETURN( std::vector starting_instruction_and_shapes, GetStartingInstructions(custom_call_instruction)); @@ -542,12 +555,8 @@ HostOffloader::GetStartingInstructions( return result; } -absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( +absl::StatusOr HostOffloader::SliceLeadsToMoveToDeviceCustomCall( HloInstruction* slice) { - if (validated_slices_.find(slice) != validated_slices_.end()) { - // Already validated this one. - return absl::OkStatus(); - } // Every host-to-device DynamicSlice/Slice must be followed by a MoveToDevice // custom call. This function verifiest that. CHECK(slice->opcode() == HloOpcode::kDynamicSlice || @@ -571,11 +580,13 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( continue; } if (!InstructionIsAllowedBetweenDsAndMoveToDevice(current_instruction)) { - return absl::InvalidArgumentError(absl::StrFormat( - "Tensor which is moved to host and back to device (ending at \"%s\") " - "has an invalid instruction (\"%s\") between DynamicSlice/Slice and " - "the MoveToDevice custom call.", - slice->name(), current_instruction->name())); + // We were expecting to find a MoveToDevice custom call here, marking the + // end of host memory offloading, but we did not. + LOG(WARNING) << absl::StreamFormat( + "Encountered %s on tensor which is in host memory. %s does not move " + "the tensor back to device. %s will be converted into host compute.", + HloOpcodeString(slice->opcode()), slice->name(), slice->name()); + return false; } TF_ASSIGN_OR_RETURN( const std::vector successors, @@ -584,8 +595,7 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( queue.push(successor); } } - validated_slices_.insert(slice); - return absl::OkStatus(); + return true; } absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( @@ -732,8 +742,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( return absl::OkStatus(); } -absl::StatusOr HostOffloader::DynamifySlice( - HloInstruction* slice) { +absl::Status HostOffloader::DynamifySlice(HloInstruction* slice) { std::vector start_constants; for (int64_t start : slice->slice_starts()) { HloInstruction* constant = slice->parent()->AddInstruction( @@ -754,7 +763,7 @@ absl::StatusOr HostOffloader::DynamifySlice( "Changed slice \"%s\" into dynamic slice \"%s\"", slice->name(), new_ds->name()); TF_RETURN_IF_ERROR(slice->parent()->RemoveInstruction(slice)); - return new_ds; + return absl::OkStatus(); } absl::StatusOr HostOffloader::ApplySchedulingFix( diff --git a/xla/hlo/transforms/host_offloader.h b/xla/hlo/transforms/host_offloader.h index a4d7a755c8302..765b3c2709856 100644 --- a/xla/hlo/transforms/host_offloader.h +++ b/xla/hlo/transforms/host_offloader.h @@ -78,7 +78,6 @@ class HostOffloader : public HloModulePass { absl::flat_hash_set already_visited_move_to_host_custom_calls_; absl::flat_hash_set dynamic_update_slices_already_allocated_; - absl::flat_hash_set validated_slices_; absl::flat_hash_map copies_created_after_; absl::flat_hash_set move_to_device_custom_calls_to_remove_; absl::flat_hash_set @@ -87,7 +86,7 @@ class HostOffloader : public HloModulePass { // Sometimes previous transformations turn a DynamicSlice into a Slice. Since // we're doing a DMA between the host and device, we need to turn the Slice // back into a DynamicSlice. - absl::StatusOr DynamifySlice(HloInstruction* slice); + absl::Status DynamifySlice(HloInstruction* slice); // Returns true if the instruction is allowed to be in the // middle of a path between a MoveToHost custom-call annotation and a @@ -126,9 +125,10 @@ class HostOffloader : public HloModulePass { absl::Status CreateAllocateBufferForDynamicUpdateSlice( HloInstruction* dynamic_update_slice); - // Returns an error if something unallowed exists between the - // Slice/DynamicSlice and the MoveToDevice custom call. - absl::Status ValidateSliceLeadsToMoveToDeviceCustomCall( + // One way to move data to the device is to use a Slice or DynamicSlice. This + // function returns true if the slice is followed by a MoveToDevice custom + // call. + absl::StatusOr SliceLeadsToMoveToDeviceCustomCall( HloInstruction* slice); // Common function for doing the actual walking of the graph. Host memory diff --git a/xla/hlo/transforms/host_offloader_test.cc b/xla/hlo/transforms/host_offloader_test.cc index 0306f5974bfd0..1452815127f1a 100644 --- a/xla/hlo/transforms/host_offloader_test.cc +++ b/xla/hlo/transforms/host_offloader_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/service/hlo_verifier.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/host_offload_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -4190,6 +4191,93 @@ TEST_F(HostOffloaderTest, AvoidRedundantCopiesToHost) { } } +TEST_F(HostOffloaderTest, TanhOnHostMemory) { + const absl::string_view hlo_string = R"( + HloModule module, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}} + + ENTRY main { + param = f32[1024]{0} parameter(0) + to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost" + tanh = f32[1024]{0} tanh(to_host) + ROOT to_device = f32[1024]{0} custom-call(tanh), custom_call_target="MoveToDevice" + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + HloInstruction* tanh = FindInstruction(module.get(), "tanh"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh)); +} + +TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryParamCopied) { + const absl::string_view hlo_string = R"( + HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}} + + ENTRY main { + param = f32[1024]{0} parameter(0) + index = s32[]{:T(128)} parameter(1) + to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost" + dynamic_slice = f32[256]{0} dynamic-slice(to_host, index), dynamic_slice_sizes={256} + tanh = f32[256]{0} tanh(dynamic_slice) + ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice" + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + HloInstruction* tanh = FindInstruction(module.get(), "tanh"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh)); + HloInstruction* dynamic_slice = + FindInstruction(module.get(), "dynamic_slice"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice)); + // Check memory spaces + ASSERT_EQ(dynamic_slice->operand_count(), 2); + HloInstruction* copy_of_param = dynamic_slice->mutable_operand(0); + EXPECT_EQ(copy_of_param->opcode(), HloOpcode::kCopy); + TestShapeHasMemorySpace(copy_of_param->shape(), Layout::kHostMemorySpace); + // The below tests something which needn't always be true. + // The current expected behavior of HostOffloader for this test is to detect + // compute happening on data in host memory space, which is the ops + // dynamic_slice and tanh. HostOffloader will mark these two as host compute. + // The interesting thing here is that the index to the dynamic_slice has not + // explicitly been moved to host memory space. The below check expects that + // HostOffloader does not explicitly move the index to host memory space. If + // HostOffloader changes to enable this, that is fine, I just wanted to make + // sure that it doesn't happen by accident. + HloInstruction* index = dynamic_slice->mutable_operand(1); + EXPECT_EQ(index->opcode(), HloOpcode::kParameter); + TestShapeHasMemorySpace(index->shape(), Layout::kDefaultMemorySpace); +} + +TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) { + const absl::string_view hlo_string = R"( + HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}} + + ENTRY main { + param = f32[1024]{0} parameter(0) + index = s32[]{:T(128)} parameter(1) + index_to_host = s32[]{:T(128)} custom-call(index), custom_call_target="MoveToHost" + dynamic_slice = f32[256]{0} dynamic-slice(param, index_to_host), dynamic_slice_sizes={256} + tanh = f32[256]{0} tanh(dynamic_slice) + ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice" + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + HloInstruction* dynamic_slice = + FindInstruction(module.get(), "dynamic_slice"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice)); + HloInstruction* tanh = FindInstruction(module.get(), "tanh"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh)); +} + } // namespace } // namespace xla