Commit 3e1fe06
Convert any compute on host memory into host compute, including dynamic-slice.

PiperOrigin-RevId: 686283784
SandSnip3r authored and Google-ML-Automation committed Oct 24, 2024
1 parent 6c0ce17 commit 3e1fe06
Showing 3 changed files with 144 additions and 47 deletions.
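
As an illustration of the behavior described in the commit message, an elementwise op such as tanh that consumes a MoveToHost-annotated tensor is no longer rejected by the pass; it is kept and marked as host compute. A minimal before/after sketch, where the frontend-attribute key "_xla_compute_type" and value "host" are assumptions about how host compute is recorded (the tests added in this commit only check it through host_offload_utils::ComputeTypeIsHost):

// Illustrative sketch only; not part of the diff below.
constexpr const char* kBeforePassHlo = R"(
  to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost"
  tanh = f32[1024]{0} tanh(to_host)
  to_device = f32[1024]{0} custom-call(tanh), custom_call_target="MoveToDevice"
)";

// After the pass, instead of returning an InvalidArgumentError, the compute is
// kept and annotated as host compute, roughly (attribute name/value assumed):
constexpr const char* kAfterPassHlo = R"(
  tanh = f32[1024]{0} tanh(to_host), frontend_attributes={_xla_compute_type="host"}
)";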
93 changes: 51 additions & 42 deletions xla/hlo/transforms/host_offloader.cc
@@ -130,6 +130,12 @@ bool HostOffloader::InstructionIsAllowedBetweenDsAndMoveToDevice(
absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
const InstructionAndShapeIndex& starting_instruction_and_index,
bool insert_copy_before) {
VLOG(3) << absl::StreamFormat(
"Walking down host memory offload paths starting from (%s, %s). Insert "
"copy before: %v",
starting_instruction_and_index.instruction->name(),
starting_instruction_and_index.shape_index.ToString(),
insert_copy_before);
bool changed = false;
absl::flat_hash_set<HloInstruction*> mth_custom_calls_to_remove;
absl::flat_hash_set<HloInstruction*> slices_to_dynamify;
@@ -147,6 +153,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
VLOG(4) << absl::StreamFormat("Visiting instruction: %s",
instruction_and_shape_index.ToString());
bool already_saved_buffer = false;
bool need_to_wrap_instruction_as_host_compute = false;
if (instruction->opcode() == HloOpcode::kCustomCall &&
instruction->custom_call_target() ==
host_memory_offload_annotations::kMoveToHostCustomCallTarget) {
@@ -198,33 +205,43 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
}
}
} else if (instruction->opcode() == HloOpcode::kDynamicSlice) {
TF_RETURN_IF_ERROR(
ValidateSliceLeadsToMoveToDeviceCustomCall(instruction));
// This DynamicSlice is the end of this path of host memory offload.
continue;
TF_ASSIGN_OR_RETURN(bool is_end_of_offload,
SliceLeadsToMoveToDeviceCustomCall(instruction));
if (is_end_of_offload) {
// This DynamicSlice is the end of this path of host memory offload.
continue;
} else {
// This is not the end of host memory offload. This is treated as device
// compute happening on host memory, so convert it to host compute.
need_to_wrap_instruction_as_host_compute = true;
}
} else if (instruction->opcode() == HloOpcode::kSlice) {
TF_RETURN_IF_ERROR(
ValidateSliceLeadsToMoveToDeviceCustomCall(instruction));
// This Slice is the end of this path of host memory offload.
// This Slice should be a DynamicSlice to be able to work with host
// memory.
slices_to_dynamify.insert(instruction);
continue;
} else if (instruction->opcode() == HloOpcode::kAllGather ||
instruction->opcode() == HloOpcode::kAllReduce) {
TF_ASSIGN_OR_RETURN(bool is_end_of_offload,
SliceLeadsToMoveToDeviceCustomCall(instruction));
if (is_end_of_offload) {
// This Slice is the end of this path of host memory offload.
// This Slice should be a DynamicSlice to be able to work with host
// memory.
slices_to_dynamify.insert(instruction);
continue;
} else {
// This is not the end of host memory offload. This is treated as device
// compute happening on host memory, so convert it to host compute.
need_to_wrap_instruction_as_host_compute = true;
}
} else {
// This is some unaccounted for instruction. Since it is unaccounted for,
// it must be something which is not legal to do with device compute.
need_to_wrap_instruction_as_host_compute = true;
}

if (need_to_wrap_instruction_as_host_compute) {
LOG(WARNING) << absl::StreamFormat(
"Found an instruction (\"%s\") which does device compute in host "
"memory space. Converting into host compute. This is likely to have "
"a very high overhead.",
instruction->name());
SetHostComputeFrontendAttribute(*instruction);
} else {
// Found an instruction which is invalid during host memory offloading.
return absl::InvalidArgumentError(
absl::StrFormat("Tensor which is moved to host (starting from "
"\"%s\") is used by an instruction (\"%s\") which is "
"not acceptable during pure memory offload.",
starting_instruction->name(), instruction->name()));
}

if (!already_saved_buffer) {
@@ -313,11 +330,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
}

for (HloInstruction* slice : slices_to_dynamify) {
TF_ASSIGN_OR_RETURN(HloInstruction * dynamic_slice, DynamifySlice(slice));
// We've already validated this slice. Since we're changing it to a dynamic
// slice, save the new dynamic slice so that we don't try to validate it
// again.
validated_slices_.insert(dynamic_slice);
TF_RETURN_IF_ERROR(DynamifySlice(slice));
changed = true;
}

@@ -364,8 +377,8 @@ absl::StatusOr<bool> HostOffloader::HandleMoveToHostCustomCall(
custom_call_instruction)) {
return false;
}
VLOG(1) << "Offloading " << custom_call_instruction->operand(0)->name()
<< " to host.";
VLOG(1) << "Offloading \"" << custom_call_instruction->operand(0)->name()
<< "\" to host.";
TF_ASSIGN_OR_RETURN(
std::vector<InstructionAndShapeIndex> starting_instruction_and_shapes,
GetStartingInstructions(custom_call_instruction));
@@ -542,12 +555,8 @@ HostOffloader::GetStartingInstructions(
return result;
}

absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
absl::StatusOr<bool> HostOffloader::SliceLeadsToMoveToDeviceCustomCall(
HloInstruction* slice) {
if (validated_slices_.find(slice) != validated_slices_.end()) {
// Already validated this one.
return absl::OkStatus();
}
// Every host-to-device DynamicSlice/Slice must be followed by a MoveToDevice
// custom call. This function verifies that.
CHECK(slice->opcode() == HloOpcode::kDynamicSlice ||
@@ -571,11 +580,13 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
continue;
}
if (!InstructionIsAllowedBetweenDsAndMoveToDevice(current_instruction)) {
return absl::InvalidArgumentError(absl::StrFormat(
"Tensor which is moved to host and back to device (ending at \"%s\") "
"has an invalid instruction (\"%s\") between DynamicSlice/Slice and "
"the MoveToDevice custom call.",
slice->name(), current_instruction->name()));
// We were expecting to find a MoveToDevice custom call here, marking the
// end of host memory offloading, but we did not.
LOG(WARNING) << absl::StreamFormat(
"Encountered %s on tensor which is in host memory. %s does not move "
"the tensor back to device. %s will be converted into host compute.",
HloOpcodeString(slice->opcode()), slice->name(), slice->name());
return false;
}
TF_ASSIGN_OR_RETURN(
const std::vector<InstructionAndShapeIndex> successors,
Expand All @@ -584,8 +595,7 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
queue.push(successor);
}
}
validated_slices_.insert(slice);
return absl::OkStatus();
return true;
}

absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
@@ -732,8 +742,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
return absl::OkStatus();
}

absl::StatusOr<HloInstruction*> HostOffloader::DynamifySlice(
HloInstruction* slice) {
absl::Status HostOffloader::DynamifySlice(HloInstruction* slice) {
std::vector<HloInstruction*> start_constants;
for (int64_t start : slice->slice_starts()) {
HloInstruction* constant = slice->parent()->AddInstruction(
@@ -754,7 +763,7 @@ absl::StatusOr<HloInstruction*> HostOffloader::DynamifySlice(
"Changed slice \"%s\" into dynamic slice \"%s\"", slice->name(),
new_ds->name());
TF_RETURN_IF_ERROR(slice->parent()->RemoveInstruction(slice));
return new_ds;
return absl::OkStatus();
}

absl::StatusOr<bool> HostOffloader::ApplySchedulingFix(
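
The new code path above calls SetHostComputeFrontendAttribute, whose definition is not part of this diff. A minimal sketch of such a helper, assuming the marking is done through a frontend attribute with key "_xla_compute_type" and value "host" (both names are assumptions, and MarkAsHostCompute is a hypothetical stand-in, not the helper the pass uses):

#include "xla/hlo/ir/hlo_instruction.h"

namespace xla {
namespace {

// Assumed key/value; the real constants live elsewhere in XLA.
constexpr char kComputeTypeAttr[] = "_xla_compute_type";
constexpr char kComputeTypeHost[] = "host";

// Hypothetical sketch: mark an instruction as host compute by setting a
// frontend attribute on it.
void MarkAsHostCompute(HloInstruction& instruction) {
  FrontendAttributes attributes = instruction.frontend_attributes();
  (*attributes.mutable_map())[kComputeTypeAttr] = kComputeTypeHost;
  instruction.set_frontend_attributes(attributes);
}

}  // namespace
}  // namespace xla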
10 changes: 5 additions & 5 deletions xla/hlo/transforms/host_offloader.h
@@ -78,7 +78,6 @@ class HostOffloader : public HloModulePass {
absl::flat_hash_set<HloInstruction*>
already_visited_move_to_host_custom_calls_;
absl::flat_hash_set<HloInstruction*> dynamic_update_slices_already_allocated_;
absl::flat_hash_set<HloInstruction*> validated_slices_;
absl::flat_hash_map<HloInstruction*, HloInstruction*> copies_created_after_;
absl::flat_hash_set<HloInstruction*> move_to_device_custom_calls_to_remove_;
absl::flat_hash_set<host_offload_utils::InstructionAndShapeIndex>
@@ -87,7 +86,7 @@ class HostOffloader : public HloModulePass {
// Sometimes previous transformations turn a DynamicSlice into a Slice. Since
// we're doing a DMA between the host and device, we need to turn the Slice
// back into a DynamicSlice.
absl::StatusOr<HloInstruction*> DynamifySlice(HloInstruction* slice);
absl::Status DynamifySlice(HloInstruction* slice);

// Returns true if the instruction is allowed to be in the
// middle of a path between a MoveToHost custom-call annotation and a
@@ -126,9 +125,10 @@ class HostOffloader : public HloModulePass {
absl::Status CreateAllocateBufferForDynamicUpdateSlice(
HloInstruction* dynamic_update_slice);

// Returns an error if something unallowed exists between the
// Slice/DynamicSlice and the MoveToDevice custom call.
absl::Status ValidateSliceLeadsToMoveToDeviceCustomCall(
// One way to move data to the device is to use a Slice or DynamicSlice. This
// function returns true if the slice is followed by a MoveToDevice custom
// call.
absl::StatusOr<bool> SliceLeadsToMoveToDeviceCustomCall(
HloInstruction* slice);

// Common function for doing the actual walking of the graph. Host memory
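
The comment on DynamifySlice above notes that a static Slice on host memory is rewritten back into a DynamicSlice so the host/device DMA can be expressed. A hypothetical before/after at the HLO level (instruction names and the zero start index are illustrative, not taken from this commit's tests):

// Before: a static slice of a tensor living in host memory.
constexpr const char* kStaticSliceHlo = R"(
  slice = f32[256]{0} slice(host_tensor), slice={[0:256]}
)";

// After DynamifySlice: the static start becomes a constant operand of a
// dynamic-slice with the same result shape.
constexpr const char* kDynamifiedSliceHlo = R"(
  start = s32[] constant(0)
  dynamic_slice = f32[256]{0} dynamic-slice(host_tensor, start), dynamic_slice_sizes={256}
)";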
88 changes: 88 additions & 0 deletions xla/hlo/transforms/host_offloader_test.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/service/hlo_verifier.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/host_offload_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
@@ -4190,6 +4191,93 @@ TEST_F(HostOffloaderTest, AvoidRedundantCopiesToHost) {
}
}

TEST_F(HostOffloaderTest, TanhOnHostMemory) {
const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost"
tanh = f32[1024]{0} tanh(to_host)
ROOT to_device = f32[1024]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
EXPECT_TRUE(changed);
VLOG(1) << module->ToString();
HloInstruction* tanh = FindInstruction(module.get(), "tanh");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryParamCopied) {
const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
index = s32[]{:T(128)} parameter(1)
to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost"
dynamic_slice = f32[256]{0} dynamic-slice(to_host, index), dynamic_slice_sizes={256}
tanh = f32[256]{0} tanh(dynamic_slice)
ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
EXPECT_TRUE(changed);
VLOG(1) << module->ToString();
HloInstruction* tanh = FindInstruction(module.get(), "tanh");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
HloInstruction* dynamic_slice =
FindInstruction(module.get(), "dynamic_slice");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice));
// Check memory spaces
ASSERT_EQ(dynamic_slice->operand_count(), 2);
HloInstruction* copy_of_param = dynamic_slice->mutable_operand(0);
EXPECT_EQ(copy_of_param->opcode(), HloOpcode::kCopy);
TestShapeHasMemorySpace(copy_of_param->shape(), Layout::kHostMemorySpace);
// The check below tests something which needn't always be true.
// The current expected behavior of HostOffloader for this test is to detect
// compute happening on data in host memory space, namely the ops
// dynamic_slice and tanh. HostOffloader will mark these two as host compute.
// The interesting thing here is that the index to the dynamic_slice has not
// explicitly been moved to host memory space. The check below expects that
// HostOffloader does not explicitly move the index to host memory space. If
// HostOffloader changes to do so, that is fine; this check just makes sure
// it doesn't happen by accident.
HloInstruction* index = dynamic_slice->mutable_operand(1);
EXPECT_EQ(index->opcode(), HloOpcode::kParameter);
TestShapeHasMemorySpace(index->shape(), Layout::kDefaultMemorySpace);
}

TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) {
const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
index = s32[]{:T(128)} parameter(1)
index_to_host = s32[]{:T(128)} custom-call(index), custom_call_target="MoveToHost"
dynamic_slice = f32[256]{0} dynamic-slice(param, index_to_host), dynamic_slice_sizes={256}
tanh = f32[256]{0} tanh(dynamic_slice)
ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
EXPECT_TRUE(changed);
VLOG(1) << module->ToString();
HloInstruction* dynamic_slice =
FindInstruction(module.get(), "dynamic_slice");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice));
HloInstruction* tanh = FindInstruction(module.get(), "tanh");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

} // namespace

} // namespace xla
