Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert any compute on host memory into host compute, including dynamic-slice. #18624

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 51 additions & 42 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ bool HostOffloader::InstructionIsAllowedBetweenDsAndMoveToDevice(
absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
const InstructionAndShapeIndex& starting_instruction_and_index,
bool insert_copy_before) {
VLOG(3) << absl::StreamFormat(
"Walking down host memory offload paths starting from (%s, %s). Insert "
"copy before: %v",
starting_instruction_and_index.instruction->name(),
starting_instruction_and_index.shape_index.ToString(),
insert_copy_before);
bool changed = false;
absl::flat_hash_set<HloInstruction*> mth_custom_calls_to_remove;
absl::flat_hash_set<HloInstruction*> slices_to_dynamify;
Expand All @@ -147,6 +153,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
VLOG(4) << absl::StreamFormat("Visiting instruction: %s",
instruction_and_shape_index.ToString());
bool already_saved_buffer = false;
bool need_to_wrap_instruction_as_host_compute = false;
if (instruction->opcode() == HloOpcode::kCustomCall &&
instruction->custom_call_target() ==
host_memory_offload_annotations::kMoveToHostCustomCallTarget) {
Expand Down Expand Up @@ -198,33 +205,43 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
}
}
} else if (instruction->opcode() == HloOpcode::kDynamicSlice) {
TF_RETURN_IF_ERROR(
ValidateSliceLeadsToMoveToDeviceCustomCall(instruction));
// This DynamicSlice is the end of this path of host memory offload.
continue;
TF_ASSIGN_OR_RETURN(bool is_end_of_offload,
SliceLeadsToMoveToDeviceCustomCall(instruction));
if (is_end_of_offload) {
// This DynamicSlice is the end of this path of host memory offload.
continue;
} else {
// This is not the end of host memory offload. This is treated as device
// compute happening on host memory, so convert it to host compute.
need_to_wrap_instruction_as_host_compute = true;
}
} else if (instruction->opcode() == HloOpcode::kSlice) {
TF_RETURN_IF_ERROR(
ValidateSliceLeadsToMoveToDeviceCustomCall(instruction));
// This Slice is the end of this path of host memory offload.
// This Slice should be a DynamicSlice to be able to work with host
// memory.
slices_to_dynamify.insert(instruction);
continue;
} else if (instruction->opcode() == HloOpcode::kAllGather ||
instruction->opcode() == HloOpcode::kAllReduce) {
TF_ASSIGN_OR_RETURN(bool is_end_of_offload,
SliceLeadsToMoveToDeviceCustomCall(instruction));
if (is_end_of_offload) {
// This Slice is the end of this path of host memory offload.
// This Slice should be a DynamicSlice to be able to work with host
// memory.
slices_to_dynamify.insert(instruction);
continue;
} else {
// This is not the end of host memory offload. This is treated as device
// compute happening on host memory, so convert it to host compute.
need_to_wrap_instruction_as_host_compute = true;
}
} else {
// This is some unaccounted for instruction. Since it is unaccounted for,
// it must be something which is not legal to do with device compute.
need_to_wrap_instruction_as_host_compute = true;
}

if (need_to_wrap_instruction_as_host_compute) {
LOG(WARNING) << absl::StreamFormat(
"Found an instruction (\"%s\") which does device compute in host "
"memory space. Converting into host compute. This is likely to have "
"a very high overhead.",
instruction->name());
SetHostComputeFrontendAttribute(*instruction);
} else {
// Found an instruction which is invalid during host memory offloading.
return absl::InvalidArgumentError(
absl::StrFormat("Tensor which is moved to host (starting from "
"\"%s\") is used by an instruction (\"%s\") which is "
"not acceptable during pure memory offload.",
starting_instruction->name(), instruction->name()));
}

if (!already_saved_buffer) {
Expand Down Expand Up @@ -313,11 +330,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
}

for (HloInstruction* slice : slices_to_dynamify) {
TF_ASSIGN_OR_RETURN(HloInstruction * dynamic_slice, DynamifySlice(slice));
// We've already validated this slice. Since we're changing it to a dynamic
// slice, save the new dynamic slice so that we don't try to validate it
// again.
validated_slices_.insert(dynamic_slice);
TF_RETURN_IF_ERROR(DynamifySlice(slice));
changed = true;
}

Expand Down Expand Up @@ -364,8 +377,8 @@ absl::StatusOr<bool> HostOffloader::HandleMoveToHostCustomCall(
custom_call_instruction)) {
return false;
}
VLOG(1) << "Offloading " << custom_call_instruction->operand(0)->name()
<< " to host.";
VLOG(1) << "Offloading \"" << custom_call_instruction->operand(0)->name()
<< "\" to host.";
TF_ASSIGN_OR_RETURN(
std::vector<InstructionAndShapeIndex> starting_instruction_and_shapes,
GetStartingInstructions(custom_call_instruction));
Expand Down Expand Up @@ -542,12 +555,8 @@ HostOffloader::GetStartingInstructions(
return result;
}

absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
absl::StatusOr<bool> HostOffloader::SliceLeadsToMoveToDeviceCustomCall(
HloInstruction* slice) {
if (validated_slices_.find(slice) != validated_slices_.end()) {
// Already validated this one.
return absl::OkStatus();
}
// Every host-to-device DynamicSlice/Slice must be followed by a MoveToDevice
// custom call. This function verifies that.
CHECK(slice->opcode() == HloOpcode::kDynamicSlice ||
Expand All @@ -571,11 +580,13 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
continue;
}
if (!InstructionIsAllowedBetweenDsAndMoveToDevice(current_instruction)) {
return absl::InvalidArgumentError(absl::StrFormat(
"Tensor which is moved to host and back to device (ending at \"%s\") "
"has an invalid instruction (\"%s\") between DynamicSlice/Slice and "
"the MoveToDevice custom call.",
slice->name(), current_instruction->name()));
// We were expecting to find a MoveToDevice custom call here, marking the
// end of host memory offloading, but we did not.
LOG(WARNING) << absl::StreamFormat(
"Encountered %s on tensor which is in host memory. %s does not move "
"the tensor back to device. %s will be converted into host compute.",
HloOpcodeString(slice->opcode()), slice->name(), slice->name());
return false;
}
TF_ASSIGN_OR_RETURN(
const std::vector<InstructionAndShapeIndex> successors,
Expand All @@ -584,8 +595,7 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall(
queue.push(successor);
}
}
validated_slices_.insert(slice);
return absl::OkStatus();
return true;
}

absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
Expand Down Expand Up @@ -732,8 +742,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice(
return absl::OkStatus();
}

absl::StatusOr<HloInstruction*> HostOffloader::DynamifySlice(
HloInstruction* slice) {
absl::Status HostOffloader::DynamifySlice(HloInstruction* slice) {
std::vector<HloInstruction*> start_constants;
for (int64_t start : slice->slice_starts()) {
HloInstruction* constant = slice->parent()->AddInstruction(
Expand All @@ -754,7 +763,7 @@ absl::StatusOr<HloInstruction*> HostOffloader::DynamifySlice(
"Changed slice \"%s\" into dynamic slice \"%s\"", slice->name(),
new_ds->name());
TF_RETURN_IF_ERROR(slice->parent()->RemoveInstruction(slice));
return new_ds;
return absl::OkStatus();
}

absl::StatusOr<bool> HostOffloader::ApplySchedulingFix(
Expand Down
10 changes: 5 additions & 5 deletions xla/hlo/transforms/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class HostOffloader : public HloModulePass {
absl::flat_hash_set<HloInstruction*>
already_visited_move_to_host_custom_calls_;
absl::flat_hash_set<HloInstruction*> dynamic_update_slices_already_allocated_;
absl::flat_hash_set<HloInstruction*> validated_slices_;
absl::flat_hash_map<HloInstruction*, HloInstruction*> copies_created_after_;
absl::flat_hash_set<HloInstruction*> move_to_device_custom_calls_to_remove_;
absl::flat_hash_set<host_offload_utils::InstructionAndShapeIndex>
Expand All @@ -87,7 +86,7 @@ class HostOffloader : public HloModulePass {
// Sometimes previous transformations turn a DynamicSlice into a Slice. Since
// we're doing a DMA between the host and device, we need to turn the Slice
// back into a DynamicSlice.
absl::StatusOr<HloInstruction*> DynamifySlice(HloInstruction* slice);
absl::Status DynamifySlice(HloInstruction* slice);

// Returns true if the instruction is allowed to be in the
// middle of a path between a MoveToHost custom-call annotation and a
Expand Down Expand Up @@ -126,9 +125,10 @@ class HostOffloader : public HloModulePass {
absl::Status CreateAllocateBufferForDynamicUpdateSlice(
HloInstruction* dynamic_update_slice);

// Returns an error if something unallowed exists between the
// Slice/DynamicSlice and the MoveToDevice custom call.
absl::Status ValidateSliceLeadsToMoveToDeviceCustomCall(
// One way to move data to the device is to use a Slice or DynamicSlice. This
// function returns true if the slice is followed by a MoveToDevice custom
// call.
absl::StatusOr<bool> SliceLeadsToMoveToDeviceCustomCall(
HloInstruction* slice);

// Common function for doing the actual walking of the graph. Host memory
Expand Down
88 changes: 88 additions & 0 deletions xla/hlo/transforms/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/service/hlo_verifier.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/host_offload_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
Expand Down Expand Up @@ -4190,6 +4191,93 @@ TEST_F(HostOffloaderTest, AvoidRedundantCopiesToHost) {
}
}

// Runs HostOffloader on a module where a tanh directly consumes the result of
// a MoveToHost custom call, and expects the tanh to be reported as host
// compute afterwards.
TEST_F(HostOffloaderTest, TanhOnHostMemory) {
  const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost"
tanh = f32[1024]{0} tanh(to_host)
ROOT to_device = f32[1024]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
  EXPECT_TRUE(changed);
  VLOG(1) << module->ToString();

  HloInstruction* tanh = FindInstruction(module.get(), "tanh");
  // Stop here with a clear failure if the instruction is no longer present,
  // rather than handing a null pointer to the check below.
  ASSERT_NE(tanh, nullptr);
  EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

// Runs HostOffloader on a module where a dynamic-slice reads from a tensor
// that was moved to host and is followed by a tanh (i.e. the dynamic-slice is
// not immediately followed by MoveToDevice). Expects both the dynamic-slice
// and the tanh to be reported as host compute, the sliced operand to be a copy
// in host memory space, and the slice index to remain an untouched parameter
// in the default memory space.
TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryParamCopied) {
  const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
index = s32[]{:T(128)} parameter(1)
to_host = f32[1024]{0} custom-call(param), custom_call_target="MoveToHost"
dynamic_slice = f32[256]{0} dynamic-slice(to_host, index), dynamic_slice_sizes={256}
tanh = f32[256]{0} tanh(dynamic_slice)
ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
  EXPECT_TRUE(changed);
  VLOG(1) << module->ToString();

  HloInstruction* tanh = FindInstruction(module.get(), "tanh");
  // Stop here with a clear failure if the instruction is no longer present,
  // rather than handing a null pointer to the check below.
  ASSERT_NE(tanh, nullptr);
  EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));

  HloInstruction* dynamic_slice =
      FindInstruction(module.get(), "dynamic_slice");
  // dynamic_slice is dereferenced below, so a missing instruction must abort
  // the test instead of crashing it.
  ASSERT_NE(dynamic_slice, nullptr);
  EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice));

  // Check memory spaces.
  ASSERT_EQ(dynamic_slice->operand_count(), 2);
  HloInstruction* copy_of_param = dynamic_slice->mutable_operand(0);
  EXPECT_EQ(copy_of_param->opcode(), HloOpcode::kCopy);
  TestShapeHasMemorySpace(copy_of_param->shape(), Layout::kHostMemorySpace);

  // The checks below pin current behavior; they need not always hold.
  // HostOffloader is currently expected to detect compute happening on data in
  // host memory space (the ops dynamic_slice and tanh) and mark both as host
  // compute. The index operand of dynamic_slice has not explicitly been moved
  // to host memory space, and HostOffloader is expected not to move it there
  // either. If HostOffloader is deliberately changed to do so, that is fine;
  // this check only guards against it happening by accident.
  HloInstruction* index = dynamic_slice->mutable_operand(1);
  EXPECT_EQ(index->opcode(), HloOpcode::kParameter);
  TestShapeHasMemorySpace(index->shape(), Layout::kDefaultMemorySpace);
}

// Runs HostOffloader on a module where only the index operand of a
// dynamic-slice has been moved to host (the sliced tensor is a plain
// parameter). Expects both the dynamic-slice and the tanh that consumes it to
// be reported as host compute.
TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) {
  const absl::string_view hlo_string = R"(
HloModule module, entry_computation_layout={(f32[1024]{0}, s32[]{:T(128)})->f32[256]{0}}
ENTRY main {
param = f32[1024]{0} parameter(0)
index = s32[]{:T(128)} parameter(1)
index_to_host = s32[]{:T(128)} custom-call(index), custom_call_target="MoveToHost"
dynamic_slice = f32[256]{0} dynamic-slice(param, index_to_host), dynamic_slice_sizes={256}
tanh = f32[256]{0} tanh(dynamic_slice)
ROOT to_device = f32[256]{0} custom-call(tanh), custom_call_target="MoveToDevice"
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
  EXPECT_TRUE(changed);
  VLOG(1) << module->ToString();

  HloInstruction* dynamic_slice =
      FindInstruction(module.get(), "dynamic_slice");
  // Stop here with a clear failure if the instruction is no longer present,
  // rather than handing a null pointer to the check below.
  ASSERT_NE(dynamic_slice, nullptr);
  EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_slice));

  HloInstruction* tanh = FindInstruction(module.get(), "tanh");
  ASSERT_NE(tanh, nullptr);
  EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

} // namespace

} // namespace xla
Loading