Skip to content

Commit

Permalink
[XLA:GPU] Move GPU specific combiner utils to GPU directory.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684487102
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Oct 10, 2024
1 parent e1be920 commit 4f12ccd
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 44 deletions.
23 changes: 0 additions & 23 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3019,16 +3019,13 @@ xla_cc_test(

cc_library(
name = "collective_combiner_utils",
srcs = ["collective_combiner_utils.cc"],
hdrs = ["collective_combiner_utils.h"],
deps = [
":collective_utils",
"//xla:shape_util",
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:hlo_reachability",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:function_ref",
Expand All @@ -3040,26 +3037,6 @@ cc_library(
],
)

# Test target for :collective_combiner_utils, built from
# collective_combiner_utils_test.cc.
# NOTE(review): this target depends on //xla/service/gpu:gpu_hlo_schedule, i.e.
# on a target that lives in the GPU package.
xla_cc_test(
name = "collective_combiner_utils_test",
srcs = ["collective_combiner_utils_test.cc"],
deps = [
":collective_combiner_utils",
":collective_utils",
":hlo_module_config",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:gpu_hlo_schedule",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

cc_library(
name = "collective_decomposer_utils",
srcs = ["collective_decomposer_utils.cc"],
Expand Down
16 changes: 0 additions & 16 deletions xla/service/collective_combiner_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#define XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <vector>
Expand All @@ -30,29 +29,14 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_reachability.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"

namespace xla {

// Suggests a combiner threshold to the caller (combiner). At the moment it only
// suggests a lower value than a default combiner threshold if it exceeds
// available memory on a device. If the scheduling of a `module` fails for any
// reason, the method returns the default combiner threshold for
// `collective_opcode`.
//
// `scheduler` is a callback taking a module pointer, an int64_t and an
// int64_t*, and returning either an HloSchedule or an error status.
// NOTE(review): the int64_t is presumably `pointer_size` and the int64_t* an
// out-parameter receiving peak memory usage -- confirm against the definition.
int64_t ComputeSuggestedCombinerThreshold(
const HloModule& module, const se::DeviceDescription& device_info,
std::function<absl::StatusOr<HloSchedule>(const HloModule*, int64_t,
int64_t*)>
scheduler,
HloOpcode collective_opcode, int64_t pointer_size);

// Combines instructions with matching keys together.
//
// Instructions are combined in topological post-order.
Expand Down
33 changes: 33 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3013,3 +3013,36 @@ xla_cc_test(
"@tsl//tsl/platform:test",
],
)

# GPU-specific collective combiner utilities: gpu_collective_combiner_utils.h
# and its implementation in gpu_collective_combiner_utils.cc.
cc_library(
name = "gpu_collective_combiner_utils",
srcs = ["gpu_collective_combiner_utils.cc"],
hdrs = ["gpu_collective_combiner_utils.h"],
deps = [
"//xla/hlo/ir:hlo",
"//xla/service:collective_utils",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
],
)

# Test target for :gpu_collective_combiner_utils, built from
# gpu_collective_combiner_utils_test.cc.
# NOTE(review): :gpu_hlo_schedule is presumably needed so the test can pass a
# real GPU scheduler to the code under test -- confirm against the test source.
xla_cc_test(
name = "gpu_collective_combiner_utils_test",
srcs = ["gpu_collective_combiner_utils_test.cc"],
deps = [
":gpu_collective_combiner_utils",
":gpu_hlo_schedule",
"//xla/hlo/ir:hlo",
"//xla/service:collective_utils",
"//xla/service:hlo_module_config",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ limitations under the License.
#include "xla/service/collective_utils.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
namespace xla::gpu {

// Callback type for a scheduler: takes a module pointer, an int64_t and an
// int64_t*, and returns either an HloSchedule or an error status.
// NOTE(review): the int64_t is presumably a pointer size and the int64_t* an
// out-parameter receiving peak memory usage -- confirm against the call site.
using MemoryAwareScheduler = std::function<absl::StatusOr<HloSchedule>(
const HloModule*, int64_t, int64_t*)>;
Expand Down Expand Up @@ -65,4 +65,4 @@ int64_t ComputeSuggestedCombinerThreshold(
return base_limit * slop_factor / 100 - peak_memory_bytes;
}

} // namespace xla
} // namespace xla::gpu
43 changes: 43 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_
#define XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_

#include <cstdint>
#include <functional>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/stream_executor/device_description.h"

namespace xla::gpu {

// Suggests a combiner threshold to the caller (combiner). At the moment it only
// suggests a lower value than a default combiner threshold if it exceeds
// available memory on a device. If the scheduling of a `module` fails for any
// reason, the method returns the default combiner threshold for
// `collective_opcode`.
//
// `scheduler` is a callback taking a module pointer, an int64_t and an
// int64_t*, and returning either an HloSchedule or an error status.
// NOTE(review): the int64_t is presumably `pointer_size` and the int64_t* an
// out-parameter receiving peak memory usage -- confirm against the definition.
int64_t ComputeSuggestedCombinerThreshold(
const HloModule& module, const se::DeviceDescription& device_info,
std::function<absl::StatusOr<HloSchedule>(const HloModule*, int64_t,
int64_t*)>
scheduler,
HloOpcode collective_opcode, int64_t pointer_size);
} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/collective_combiner_utils.h"
#include "xla/service/gpu/gpu_collective_combiner_utils.h"

#include <cstdint>

Expand All @@ -32,7 +32,7 @@ limitations under the License.
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace xla::gpu {
namespace {

// Plain alias of HloTestBase; presumably the fixture name used by the TEST_F
// cases in this file.
using CollectiveCombinerUtilsTest = HloTestBase;
Expand Down Expand Up @@ -137,4 +137,4 @@ TEST_F(
}

} // namespace
} // namespace xla
} // namespace xla::gpu

0 comments on commit 4f12ccd

Please sign in to comment.