Skip to content

Commit

Permalink
[IFRT] Add IFRT IR passes for compiling atom programs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689507483
  • Loading branch information
ICGog authored and Google-ML-Automation committed Oct 24, 2024
1 parent 93eb762 commit a2d165e
Show file tree
Hide file tree
Showing 16 changed files with 2,260 additions and 7 deletions.
15 changes: 15 additions & 0 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,18 @@ xla_cc_test(
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "atom_program_compiler",
hdrs = ["atom_program_compiler.h"],
compatible_with = get_compatible_with_portable(),
visibility = ["//xla/python/ifrt:friends"],
deps = [
":ir",
"//xla/pjrt:pjrt_executable",
"//xla/python/ifrt",
"//xla/python/ifrt/hlo:hlo_program",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
],
)
63 changes: 63 additions & 0 deletions xla/python/ifrt/ir/atom_program_compiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_
#define XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"
#include "xla/python/ifrt/shape.h"

namespace xla {
namespace ifrt {

// Loaded executable and unique name for a compiled atom program.
struct AtomProgramCompileResult {
std::string name;
std::shared_ptr<LoadedExecutable> executable;
};

using AtomExecutableMap =
absl::flat_hash_map<std::string, std::shared_ptr<LoadedExecutable>>;

class AtomProgramCompiler {
public:
virtual ~AtomProgramCompiler() = default;

// Delegates the compilation of an atom XLA program.
// `options` uses logical device id in the main mlir module.
virtual absl::StatusOr<AtomProgramCompileResult> CompileXla(
std::unique_ptr<HloProgram> computation, xla::CompileOptions options) = 0;

// Delegates the compilation of an MPMD reshard program.
virtual absl::StatusOr<AtomProgramCompileResult> CompileMpmdReshard(
std::vector<DType> dtypes, std::vector<Shape> shapes,
std::vector<IfrtArrayType> in_array_types,
std::vector<IfrtArrayType> out_array_types) = 0;
};

} // namespace ifrt
} // namespace xla

#endif // XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_
16 changes: 16 additions & 0 deletions xla/python/ifrt/ir/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ lit_test_suite(
name = "all_tests",
srcs = enforce_glob(
[
"ifrt_compile_atom_program.mlir",
"ifrt_compile_and_propagate_shardings.mlir",
"ifrt_duplicated_callee_elimination.mlir",
"ifrt_lower_mpmd_reshard_to_call.mlir",
"ifrt_lower_sharding_to_xla.mlir",
Expand Down Expand Up @@ -46,11 +48,25 @@ lit_test_suite(

xla_cc_binary(
name = "ifrt-opt",
testonly = True,
srcs = ["ifrt-opt.cc"],
deps = [
"//xla/mlir_hlo:hlo_dialect_registration",
"//xla/pjrt:pjrt_executable",
"//xla/python/ifrt",
"//xla/python/ifrt:mock",
"//xla/python/ifrt/hlo:hlo_program",
"//xla/python/ifrt/ir",
"//xla/python/ifrt/ir:atom_program_compiler",
"//xla/python/ifrt/ir/transforms:passes",
"//xla/python/ifrt/support:module_parsing",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
Expand Down
108 changes: 107 additions & 1 deletion xla/python/ifrt/ir/tests/ifrt-opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,122 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/ifrt/ir/atom_program_compiler.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"
#include "xla/python/ifrt/ir/transforms/passes.h"
#include "xla/python/ifrt/mock.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/support/module_parsing.h"

namespace xla {
namespace ifrt {
namespace {

static constexpr int kMaxTestMethods = 1000;

class TestChildExecutableCompiler : public AtomProgramCompiler {
public:
TestChildExecutableCompiler() { methods_.reserve(kMaxTestMethods); }

absl::StatusOr<AtomProgramCompileResult> CompileXla(
std::unique_ptr<HloProgram> hlo_program,
xla::CompileOptions options) override ABSL_LOCKS_EXCLUDED(mu_) {
absl::MutexLock lock(&mu_);
methods_.push_back(absl::StrCat("fake_method_", methods_.size()));
CHECK_LT(methods_.size(), kMaxTestMethods)
<< "push_back() might have caused reallocation, which might have "
"invalidated some method string_views.";
auto mock_executable =
std::make_unique<testing::NiceMock<MockLoadedExecutable>>();
int num_parameters_to_propagate =
options.executable_build_options
.allow_spmd_sharding_propagation_to_parameters()
.size();
if (num_parameters_to_propagate > 0) {
xla::OpSharding op_sharding;
op_sharding.set_type(xla::OpSharding::REPLICATED);
std::vector<xla::OpSharding> parameter_shardings(
num_parameters_to_propagate, op_sharding);
ON_CALL(*mock_executable, GetParameterShardings())
.WillByDefault(testing::Return(std::move(parameter_shardings)));
}
int num_outputs_to_propagate =
options.executable_build_options
.allow_spmd_sharding_propagation_to_output()
.size();
if (num_outputs_to_propagate > 0) {
// Always infer output shardings to be replicated for the lit tests.
xla::OpSharding op_sharding;
op_sharding.set_type(xla::OpSharding::REPLICATED);
std::vector<xla::OpSharding> output_shardings(num_outputs_to_propagate,
op_sharding);
ON_CALL(*mock_executable, GetOutputShardings())
.WillByDefault(testing::Return(std::move(output_shardings)));
}
return AtomProgramCompileResult{
/*name=*/absl::StrCat("fake_component__", methods_.back()),
/*executable=*/std::move(mock_executable)};
}

absl::StatusOr<AtomProgramCompileResult> CompileMpmdReshard(
std::vector<DType> dtypes, std::vector<Shape> shapes,
std::vector<IfrtArrayType> in_array_types,
std::vector<IfrtArrayType> out_array_types) override
ABSL_LOCKS_EXCLUDED(mu_) {
absl::MutexLock lock(&mu_);
methods_.push_back(absl::StrCat("fake_method_", methods_.size()));
CHECK_LT(methods_.size(), kMaxTestMethods)
<< "push_back() might have caused reallocation, which might have "
"invalidated some method string_views.";
auto mock_executable =
std::make_unique<testing::NiceMock<MockLoadedExecutable>>();
return AtomProgramCompileResult{
/*name=*/absl::StrCat("fake_mpmd_reshard_component__", methods_.back()),
/*executable=*/std::make_unique<MockLoadedExecutable>()};
}

private:
absl::Mutex mu_;
std::vector<std::string> methods_ ABSL_GUARDED_BY(mu_);
};

} // namespace
} // namespace ifrt
} // namespace xla

int main(int argc, char** argv) {
std::shared_ptr<xla::ifrt::AtomProgramCompiler> compiler =
std::make_shared<xla::ifrt::TestChildExecutableCompiler>();
auto compile_options = std::make_shared<absl::flat_hash_map<
std::string, std::unique_ptr<xla::ifrt::CompileOptions>>>();
std::shared_ptr<xla::ifrt::AtomExecutableMap> atom_executable_map =
std::make_shared<xla::ifrt::AtomExecutableMap>();
std::shared_ptr<xla::ifrt::AtomExecutableMap> bound_executable_map =
std::make_shared<xla::ifrt::AtomExecutableMap>();

mlir::registerAllPasses();
xla::ifrt::RegisterIfrtPassesAndPipelines();
xla::ifrt::RegisterIfrtPassesAndPipelines(
compiler, compile_options, atom_executable_map, bound_executable_map);
mlir::DialectRegistry registry;
xla::ifrt::support::InitializeMlirDialectRegistry(registry);

Expand Down
Loading

0 comments on commit a2d165e

Please sign in to comment.