From ecc49f6a2a14d7edcd68340521e4b2d2500a3d8d Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 10 Aug 2023 23:37:22 -0700 Subject: [PATCH] Add a CLI `iree-ir-tool` with a command to strip data. (#14636) Been meaning to do this for a while in order to have a place to stash more power-user style things that core developers typically use iree-opt for. This one adds a `strip-data` sub-command which uses the passes from #14627 to systematically replace tensor constants with synthetic values. With an ASAN/asserts build, this was able to strip a 7GiB int4 vicuna MLIR file in ~5s and a 23GiB h2ogpt model in about 30s (the latter has some other characteristics which make it more expensive to load as well as being bigger). Results were a 2.6MiB and 1.4MiB MLIR file respectively, consisting just of program IR and annotations for synthetic data. Getting the opt pipeline right for arbitrary input is a bit tricky, so I decided we should just armor this into a tool From installed packages, this can be used as: ``` iree-ir-tool strip-data input.mlir -o output.mlir ``` From a build tree with Python setup: ``` python -m iree.compiler.tools.ir_tool strip-data input.mlir -o output.mlir ``` Required adding some additional compiler APIs: * `ireeCompilerInvocationRunPassPipeline` to run an arbitrary textual pass pipeline on an invocation. * `ireeCompilerInvocationOutputIRBytecode` to emit bytecode from an invocation. --- .../bindings/c/iree/compiler/embedding_api.h | 15 +++ .../c/iree/compiler/loader/handle_symbols.inc | 2 + .../c/iree/compiler/loader/loader.cpp | 12 ++ compiler/bindings/python/CMakeLists.txt | 1 + .../python/iree/compiler/api/ctypes_dl.py | 22 ++++ .../iree/compiler/tools/ir_tool/__main__.py | 106 ++++++++++++++++ .../bindings/python/test/tools/CMakeLists.txt | 7 ++ .../python/test/tools/ir_tool_test.py | 114 +++++++++++++++++ compiler/setup.py | 1 + .../iree/compiler/API/Internal/BUILD.bazel | 1 + .../iree/compiler/API/Internal/CMakeLists.txt | 1 + .../src/iree/compiler/API/Internal/Embed.cpp | 119 ++++++++++++------ compiler/src/iree/compiler/API/api_exports.c | 6 + .../src/iree/compiler/API/api_exports.def | 3 + compiler/src/iree/compiler/API/api_exports.ld | 3 + .../iree/compiler/API/api_exports.macos.lst | 3 + 16 files changed, 381 insertions(+), 35 deletions(-) create mode 100644 compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py create mode 100644 compiler/bindings/python/test/tools/ir_tool_test.py diff --git a/compiler/bindings/c/iree/compiler/embedding_api.h b/compiler/bindings/c/iree/compiler/embedding_api.h index d56dbcf6d415..2cb5eddba713 100644 --- a/compiler/bindings/c/iree/compiler/embedding_api.h +++ b/compiler/bindings/c/iree/compiler/embedding_api.h @@ -253,11 +253,26 @@ IREE_EMBED_EXPORTED bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv, enum iree_compiler_pipeline_t pipeline); +// Runs an arbitrary pass pipeline. +// Returns false and emits diagnostics on failure. +// Available since: 1.4 +IREE_EMBED_EXPORTED bool +ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv, + const char *textPassPipeline); + // Outputs the current compiler state as textual IR to the output. IREE_EMBED_EXPORTED iree_compiler_error_t * ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *inv, iree_compiler_output_t *output); +// Outputs the current compiler state as bytecode IR to the output. +// Emits as the given bytecode version or most recent if -1. +// Available since: 1.4 +IREE_EMBED_EXPORTED iree_compiler_error_t * +ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv, + iree_compiler_output_t *output, + int bytecodeVersion); + // Assuming that the compiler has produced VM IR, converts it to bytecode // and outputs it. This is a valid next step after running the // IREE_COMPILER_PIPELINE_STD pipeline. diff --git a/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc b/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc index 60ff4ce11eab..665a031aa36c 100644 --- a/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc +++ b/compiler/bindings/c/iree/compiler/loader/handle_symbols.inc @@ -24,7 +24,9 @@ HANDLE_SYMBOL(ireeCompilerInvocationSetCompileFromPhase) HANDLE_SYMBOL(ireeCompilerInvocationSetCompileToPhase) HANDLE_SYMBOL(ireeCompilerInvocationSetVerifyIR) HANDLE_SYMBOL(ireeCompilerInvocationPipeline) +HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationRunPassPipeline, 1, 4) HANDLE_SYMBOL(ireeCompilerInvocationOutputIR) +HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationOutputIRBytecode, 1, 4) HANDLE_SYMBOL(ireeCompilerInvocationOutputVMBytecode) HANDLE_SYMBOL(ireeCompilerInvocationOutputVMCSource) HANDLE_SYMBOL(ireeCompilerInvocationOutputHALExecutable) diff --git a/compiler/bindings/c/iree/compiler/loader/loader.cpp b/compiler/bindings/c/iree/compiler/loader/loader.cpp index b4fd6144d34a..0469b2f3703b 100644 --- a/compiler/bindings/c/iree/compiler/loader/loader.cpp +++ b/compiler/bindings/c/iree/compiler/loader/loader.cpp @@ -275,12 +275,24 @@ bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *run, return __ireeCompilerInvocationPipeline(run, pipeline); } +bool ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv, + const char *textPassPipeline) { + return __ireeCompilerInvocationRunPassPipeline(inv, textPassPipeline); +} + iree_compiler_error_t * ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *run, iree_compiler_output_t *output) { return __ireeCompilerInvocationOutputIR(run, output); } +iree_compiler_error_t * +ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv, + iree_compiler_output_t *output, + int bytecodeVersion) { + return __ireeCompilerInvocationOutputIRBytecode(inv, output, bytecodeVersion); +} + iree_compiler_error_t * ireeCompilerInvocationOutputVMBytecode(iree_compiler_invocation_t *run, iree_compiler_output_t *output) { diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index b3566bb32009..206c08542ff3 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -60,6 +60,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools SOURCES_GLOB api/*.py tools/*.py + tools/ir_tool/*.py ) ################################################################################ diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py index ee14ee1bea9a..40be17a4b9e6 100644 --- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py +++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py @@ -54,7 +54,13 @@ def _init_dylib(): _setsig(_dylib.ireeCompilerInvocationEnableConsoleDiagnostics, None, [c_void_p]) _setsig(_dylib.ireeCompilerInvocationParseSource, c_bool, [c_void_p, c_void_p]) _setsig(_dylib.ireeCompilerInvocationPipeline, c_bool, [c_void_p, c_int]) + _setsig(_dylib.ireeCompilerInvocationRunPassPipeline, c_bool, [c_void_p, c_char_p]) _setsig(_dylib.ireeCompilerInvocationOutputIR, c_void_p, [c_void_p, c_void_p]) + _setsig( + _dylib.ireeCompilerInvocationOutputIRBytecode, + c_void_p, + [c_void_p, c_void_p, c_int], + ) _setsig( _dylib.ireeCompilerInvocationOutputVMBytecode, c_void_p, [c_void_p, c_void_p] ) @@ -328,6 +334,10 @@ def __init__(self, session: Session): # Invocation. self._retained_module_op = None + @property + def session(self) -> Session: + return self._session + def __del__(self): self.close() @@ -365,11 +375,23 @@ def execute( ) -> bool: return _dylib.ireeCompilerInvocationPipeline(self._inv_p, pipeline) + def execute_text_pass_pipeline(self, text_pipeline_spec: str) -> bool: + return _dylib.ireeCompilerInvocationRunPassPipeline( + self._inv_p, text_pipeline_spec.encode() + ) + def output_ir(self, output: Output): _handle_error( _dylib.ireeCompilerInvocationOutputIR(self._inv_p, output._output_p) ) + def output_ir_bytecode(self, output: Output, bytecode_version: int = -1): + _handle_error( + _dylib.ireeCompilerInvocationOutputIRBytecode( + self._inv_p, output._output_p, bytecode_version + ) + ) + def output_vm_bytecode(self, output: Output): _handle_error( _dylib.ireeCompilerInvocationOutputVMBytecode(self._inv_p, output._output_p) diff --git a/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py new file mode 100644 index 000000000000..4448c9447e1d --- /dev/null +++ b/compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py @@ -0,0 +1,106 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import logging +import sys + +from ...api import Invocation, Session, Source, Output + + +def load_source(inv: Invocation, input_file: str) -> Source: + source = Source.open_file(inv.session, input_file) + if not inv.parse_source(source): + raise RuntimeError(f"Error parsing source file {input_file}") + return source + + +def write_output(inv: Invocation, output: Output, args, keep: bool = True): + if args.emit_bytecode: + inv.output_ir_bytecode(output, args.bytecode_version) + else: + inv.output_ir(output) + if keep: + output.keep() + + +############################################################################### +# CLI handling +############################################################################### + + +def parse_arguments(argv=None): + parser = argparse.ArgumentParser(description="IREE IR Tool") + subparsers = parser.add_subparsers( + help="sub-command help", required=True, dest="sub_command" + ) + + def add_ouptut_options(subparser): + subparser.add_argument( + "--emit-bytecode", action="store_true", help="Emit bytecode" + ) + subparser.add_argument( + "--bytecode-version", + default=-1, + type=int, + help="Bytecode version to emit or -1 for latest", + ) + + # strip-data command. + strip_data_parser = subparsers.add_parser( + "strip-data", + help="Strip large constants and values, " + "replacing them with pseudo data suitable for interactive " + "debugging of IR", + ) + add_ouptut_options(strip_data_parser) + strip_data_parser.add_argument( + "--no-import", + action="store_true", + help="Disable import of public dialects to internal", + ) + strip_data_parser.add_argument("input_file", help="File to process") + strip_data_parser.add_argument( + "-o", required=True, dest="output_file", help="Output file" + ) + args = parser.parse_args(argv) + return args + + +def main(args) -> int: + if args.sub_command == "strip-data": + return do_strip_data(args) + else: + print("error: Unrecognized sub-command {args.sub_command}", file=sys.stderr) + return 1 + return 0 + + +def do_strip_data(args) -> int: + session = Session() + output = Output.open_file(args.output_file) + inv = session.invocation() + inv.enable_console_diagnostics() + load_source(inv, args.input_file) + if not args.no_import: + if not inv.execute_text_pass_pipeline( + "iree-import-public, iree-import-ml-program" + ): + return 1 + if not inv.execute_text_pass_pipeline( + "iree-util-outline-constants, iree-util-strip-and-splat-constants" + ): + return 2 + write_output(inv, output, args) + return 0 + + +def _cli_main(): + sys.exit(main(parse_arguments())) + + +if __name__ == "__main__": + _cli_main() diff --git a/compiler/bindings/python/test/tools/CMakeLists.txt b/compiler/bindings/python/test/tools/CMakeLists.txt index 427116b4d299..01456f9fb0f9 100644 --- a/compiler/bindings/python/test/tools/CMakeLists.txt +++ b/compiler/bindings/python/test/tools/CMakeLists.txt @@ -15,6 +15,13 @@ iree_py_test( ) endif() # IREE_BUILD_BUNDLED_LLVM +iree_py_test( + NAME + ir_tool_test + SRCS + "ir_tool_test.py" +) + iree_py_test( NAME compiler_tf_test diff --git a/compiler/bindings/python/test/tools/ir_tool_test.py b/compiler/bindings/python/test/tools/ir_tool_test.py new file mode 100644 index 000000000000..2a7a538ffb14 --- /dev/null +++ b/compiler/bindings/python/test/tools/ir_tool_test.py @@ -0,0 +1,114 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.compiler.tools.ir_tool import __main__ + +import os +import tempfile +import unittest + + +def run_tool(*argv: str): + try: + args = __main__.parse_arguments(list(argv)) + __main__.main(args) + except SystemExit as e: + if e.code != 0: + raise RuntimeError(f"Tool exited with code {e.code}") + + +class IrToolTest(unittest.TestCase): + def setUp(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + self.inputPath = f.name + with tempfile.NamedTemporaryFile(delete=False) as f: + self.outputPath = f.name + + def tearDown(self) -> None: + if os.path.exists(self.inputPath): + os.unlink(self.inputPath) + if os.path.exists(self.outputPath): + os.unlink(self.outputPath) + + def saveInput(self, contents, text=True): + with open(self.inputPath, "wt" if text else "wb") as f: + f.write(contents) + + def loadOutput(self, text=True): + with open(self.outputPath, "rt" if text else "rb") as f: + return f.read() + + def testStripDataWithImport(self): + self.saveInput( + r""" + builtin.module { + func.func @main() -> tensor<4xf32> { + %0 = arith.constant dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32> + func.return %0 : tensor<4xf32> + } + } + """ + ) + run_tool("strip-data", self.inputPath, "-o", self.outputPath) + output = self.loadOutput() + print("Output:", output) + self.assertIn("#util.byte_pattern", output) + + def testStripDataNoImport(self): + # Without import, ml_program.global is not recognized. + self.saveInput( + r""" + builtin.module { + ml_program.global public @foobar(dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>) : tensor<4xf32> + } + """ + ) + run_tool("strip-data", "--no-import", self.inputPath, "-o", self.outputPath) + output = self.loadOutput() + print("Output:", output) + self.assertNotIn("#util.byte_pattern", output) + + def testStripDataParseError(self): + self.saveInput( + r""" + FOOBAR + """ + ) + with self.assertRaisesRegex(RuntimeError, "Error parsing source file"): + run_tool("strip-data", self.inputPath, "-o", self.outputPath) + + def testStripDataEmitBytecode(self): + self.saveInput( + r""" + builtin.module { + } + """ + ) + run_tool("strip-data", "--emit-bytecode", self.inputPath, "-o", self.outputPath) + output = self.loadOutput(text=False) + self.assertIn(b"MLIR", output) + + def testStripDataEmitBytecodeVersion(self): + self.saveInput( + r""" + builtin.module { + } + """ + ) + run_tool( + "strip-data", + "--emit-bytecode", + "--bytecode-version=0", + self.inputPath, + "-o", + self.outputPath, + ) + output = self.loadOutput(text=False) + self.assertIn(b"MLIR", output) + + +if __name__ == "__main__": + unittest.main() diff --git a/compiler/setup.py b/compiler/setup.py index 37a3aa15c3e5..ea1859f0dd60 100644 --- a/compiler/setup.py +++ b/compiler/setup.py @@ -455,6 +455,7 @@ def find_git_submodule_revision(submodule_path): # TODO: We have renamed to iree-compile on 2022-03-18. Remove # this alias once no longer needed. "ireec = iree.compiler.tools.scripts.ireec.__main__:main", + "iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main", ], }, install_requires=[ diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index f9cec5c1139e..ab6a55fd39c9 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index aabf542c71d1..d14608f4d576 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( DEPS LLVMSupport MLIRBuiltinToLLVMIRTranslation + MLIRBytecodeWriter MLIRCAPIIR MLIRIR MLIRParser diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp index 044cd6933fee..36801a921a5a 100644 --- a/compiler/src/iree/compiler/API/Internal/Embed.cpp +++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp @@ -63,6 +63,7 @@ #include "llvm/Support/Signals.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/AsmState.h" @@ -84,7 +85,7 @@ #endif #define IREE_COMPILER_API_MAJOR 1 -#define IREE_COMPILER_API_MINOR 3 +#define IREE_COMPILER_API_MINOR 4 namespace mlir::iree_compiler::embed { namespace { @@ -529,19 +530,23 @@ void Output::keep() { // Invocation corresponds to iree_compiler_invocation_t struct Invocation { + using PassManagerInitializer = std::function; Invocation(Session &session); ~Invocation(); bool initializeInvocation(); + std::unique_ptr createPassManager(); bool parseSource(Source &source); bool importModule(Operation *inputModule, bool steal); bool runPipeline(enum iree_compiler_pipeline_t pipeline); + bool runTextualPassPipeline(const char *textPassPipeline); Error *outputIR(Output &output); + Error *outputIRBytecode(Output &output, int bytecodeVersion); Error *outputVMBytecode(Output &output); Error *outputVMCSource(Output &output); Error *outputHALExecutable(Output &output); Session &session; - PassManager passManager; + llvm::SmallVector passManagerInitializers; IREEVMPipelineHooks pipelineHooks; // Diagnostic handlers are instantiated upon parsing the source (when we @@ -567,17 +572,7 @@ struct Invocation { int diagnosticCallbackFlags = 0; }; -Invocation::Invocation(Session &session) - : session(session), passManager(&session.context) { - if (session.globalInit.usesCommandLine) { - if (failed(mlir::applyPassManagerCLOptions(passManager))) { - emitError(UnknownLoc::get(&session.context)) - << "Failed to apply pass manager CL options"; - } - mlir::applyDefaultTimingPassManagerCLOptions(passManager); - } - passManager.addInstrumentation(std::make_unique()); - +Invocation::Invocation(Session &session) : session(session) { // Since the jitter invokes much of the top-level compiler recursively, // it must be injected at the top-level here vs in the pass pipeline // (or else the circular dependency cannot be resolved). @@ -597,6 +592,23 @@ Invocation::~Invocation() { } } +std::unique_ptr Invocation::createPassManager() { + auto passManager = std::make_unique(&session.context); + if (session.globalInit.usesCommandLine) { + if (failed(mlir::applyPassManagerCLOptions(*passManager))) { + emitError(UnknownLoc::get(&session.context)) + << "Failed to apply pass manager CL options"; + } + mlir::applyDefaultTimingPassManagerCLOptions(*passManager); + } + passManager->addInstrumentation(std::make_unique()); + passManager->enableVerifier(enableVerifier); + for (auto &init : passManagerInitializers) { + init(*passManager); + } + return passManager; +} + bool Invocation::initializeInvocation() { // Initialize callback diagnostics. if (diagnosticCallback && !callbackDiagnosticHandler) { @@ -697,6 +709,7 @@ bool Invocation::importModule(Operation *inputModule, bool steal) { } bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { + auto passManager = createPassManager(); switch (pipeline) { case IREE_COMPILER_PIPELINE_STD: { // Parse the compile to phase name. @@ -745,7 +758,7 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { session.targetRegistry, session.bindingOptions, session.inputOptions, session.preprocessingOptions, session.highLevelOptimizationOptions, session.schedulingOptions, session.halTargetOptions, - session.vmTargetOptions, pipelineHooks, passManager, *compileFromPhase, + session.vmTargetOptions, pipelineHooks, *passManager, *compileFromPhase, *compileToPhase); break; } @@ -764,7 +777,7 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { return false; } IREE::HAL::buildHALTransformPassPipeline( - passManager, session.targetRegistry, session.halTargetOptions); + *passManager, session.targetRegistry, session.halTargetOptions); break; } default: @@ -772,8 +785,18 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) { return false; } - passManager.enableVerifier(enableVerifier); - if (failed(passManager.run(parsedModule))) { + if (failed(passManager->run(parsedModule))) { + return false; + } + return true; +} + +bool Invocation::runTextualPassPipeline(const char *textPassPipeline) { + auto passManager = createPassManager(); + if (failed(mlir::parsePassPipeline(textPassPipeline, *passManager, + llvm::errs()))) + return false; + if (failed(passManager->run(parsedModule))) { return false; } return true; @@ -784,6 +807,17 @@ Error *Invocation::outputIR(Output &output) { return output.getWriteError(); } +Error *Invocation::outputIRBytecode(Output &output, int bytecodeVersion) { + mlir::BytecodeWriterConfig config; + if (bytecodeVersion >= 0) + config.setDesiredBytecodeVersion(bytecodeVersion); + if (failed(mlir::writeBytecodeToFile(parsedModule, *output.outputStream, + config))) { + return new Error("illegal bytecode version requested"); + } + return output.getWriteError(); +} + Error *Invocation::outputVMBytecode(Output &output) { auto vmModule = llvm::dyn_cast(*parsedModule); auto builtinModule = llvm::dyn_cast(*parsedModule); @@ -1134,24 +1168,27 @@ void ireeCompilerInvocationSetCrashHandler( iree_compiler_output_t *output; }; - unwrap(inv)->passManager.enableCrashReproducerGeneration( - [=](std::string &errorMessage) - -> std::unique_ptr { - iree_compiler_output_t *output = nullptr; - auto error = onCrashCallback(&output, userData); - if (error) { - errorMessage = ireeCompilerErrorGetMessage(error); - return nullptr; - } - - if (!output) { - errorMessage = "callback did not set output"; - return nullptr; - } - - return std::make_unique(output); - }, - /*genLocalReproducer=*/genLocalReproducer); + unwrap(inv)->passManagerInitializers.push_back( + [=](mlir::PassManager &passManager) { + passManager.enableCrashReproducerGeneration( + [=](std::string &errorMessage) + -> std::unique_ptr { + iree_compiler_output_t *output = nullptr; + auto error = onCrashCallback(&output, userData); + if (error) { + errorMessage = ireeCompilerErrorGetMessage(error); + return nullptr; + } + + if (!output) { + errorMessage = "callback did not set output"; + return nullptr; + } + + return std::make_unique(output); + }, + /*genLocalReproducer=*/genLocalReproducer); + }); } bool ireeCompilerInvocationParseSource(iree_compiler_invocation_t *inv, @@ -1179,6 +1216,11 @@ bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv, return unwrap(inv)->runPipeline(pipeline); } +bool ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv, + const char *textPassPipeline) { + return unwrap(inv)->runTextualPassPipeline(textPassPipeline); +} + void ireeCompilerSourceDestroy(iree_compiler_source_t *source) { delete unwrap(source); } @@ -1262,6 +1304,13 @@ ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *inv, return wrap(unwrap(inv)->outputIR(*unwrap(output))); } +iree_compiler_error_t * +ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv, + iree_compiler_output_t *output, + int bytecodeVersion) { + return wrap(unwrap(inv)->outputIRBytecode(*unwrap(output), bytecodeVersion)); +} + iree_compiler_error_t * ireeCompilerInvocationOutputVMBytecode(iree_compiler_invocation_t *inv, iree_compiler_output_t *output) { diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c index 5a3cae73dfdf..114d7ca5991b 100644 --- a/compiler/src/iree/compiler/API/api_exports.c +++ b/compiler/src/iree/compiler/API/api_exports.c @@ -25,10 +25,12 @@ extern void ireeCompilerInvocationImportBorrowModule(); extern void ireeCompilerInvocationImportStealModule(); extern void ireeCompilerInvocationOutputHALExecutable(); extern void ireeCompilerInvocationOutputIR(); +extern void ireeCompilerInvocationOutputIRBytecode(); extern void ireeCompilerInvocationOutputVMBytecode(); extern void ireeCompilerInvocationOutputVMCSource(); extern void ireeCompilerInvocationParseSource(); extern void ireeCompilerInvocationPipeline(); +extern void ireeCompilerInvocationRunPassPipeline(); extern void ireeCompilerInvocationSetCompileFromPhase(); extern void ireeCompilerInvocationSetCompileToPhase(); extern void ireeCompilerInvocationSetCrashHandler(); @@ -649,6 +651,7 @@ extern void mlirValueIsAOpResult(); extern void mlirValuePrint(); extern void mlirValuePrintAsOperand(); extern void mlirValueReplaceAllUsesOfWith(); +extern void mlirValueSetType(); extern void mlirVectorTypeGet(); extern void mlirVectorTypeGetChecked(); extern void mlirVectorTypeGetTypeID(); @@ -672,10 +675,12 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&ireeCompilerInvocationImportStealModule; x += (uintptr_t)&ireeCompilerInvocationOutputHALExecutable; x += (uintptr_t)&ireeCompilerInvocationOutputIR; + x += (uintptr_t)&ireeCompilerInvocationOutputIRBytecode; x += (uintptr_t)&ireeCompilerInvocationOutputVMBytecode; x += (uintptr_t)&ireeCompilerInvocationOutputVMCSource; x += (uintptr_t)&ireeCompilerInvocationParseSource; x += (uintptr_t)&ireeCompilerInvocationPipeline; + x += (uintptr_t)&ireeCompilerInvocationRunPassPipeline; x += (uintptr_t)&ireeCompilerInvocationSetCompileFromPhase; x += (uintptr_t)&ireeCompilerInvocationSetCompileToPhase; x += (uintptr_t)&ireeCompilerInvocationSetCrashHandler; @@ -1296,6 +1301,7 @@ uintptr_t __iree_compiler_hidden_force_extern() { x += (uintptr_t)&mlirValuePrint; x += (uintptr_t)&mlirValuePrintAsOperand; x += (uintptr_t)&mlirValueReplaceAllUsesOfWith; + x += (uintptr_t)&mlirValueSetType; x += (uintptr_t)&mlirVectorTypeGet; x += (uintptr_t)&mlirVectorTypeGetChecked; x += (uintptr_t)&mlirVectorTypeGetTypeID; diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def index 986cb461f06f..b07a8a4a6470 100644 --- a/compiler/src/iree/compiler/API/api_exports.def +++ b/compiler/src/iree/compiler/API/api_exports.def @@ -17,10 +17,12 @@ EXPORTS ireeCompilerInvocationImportStealModule ireeCompilerInvocationOutputHALExecutable ireeCompilerInvocationOutputIR + ireeCompilerInvocationOutputIRBytecode ireeCompilerInvocationOutputVMBytecode ireeCompilerInvocationOutputVMCSource ireeCompilerInvocationParseSource ireeCompilerInvocationPipeline + ireeCompilerInvocationRunPassPipeline ireeCompilerInvocationSetCompileFromPhase ireeCompilerInvocationSetCompileToPhase ireeCompilerInvocationSetCrashHandler @@ -641,6 +643,7 @@ EXPORTS mlirValuePrint mlirValuePrintAsOperand mlirValueReplaceAllUsesOfWith + mlirValueSetType mlirVectorTypeGet mlirVectorTypeGetChecked mlirVectorTypeGetTypeID diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld index a252dcbce2a1..f02718fa20a7 100644 --- a/compiler/src/iree/compiler/API/api_exports.ld +++ b/compiler/src/iree/compiler/API/api_exports.ld @@ -18,10 +18,12 @@ VER_0 { ireeCompilerInvocationImportStealModule; ireeCompilerInvocationOutputHALExecutable; ireeCompilerInvocationOutputIR; + ireeCompilerInvocationOutputIRBytecode; ireeCompilerInvocationOutputVMBytecode; ireeCompilerInvocationOutputVMCSource; ireeCompilerInvocationParseSource; ireeCompilerInvocationPipeline; + ireeCompilerInvocationRunPassPipeline; ireeCompilerInvocationSetCompileFromPhase; ireeCompilerInvocationSetCompileToPhase; ireeCompilerInvocationSetCrashHandler; @@ -642,6 +644,7 @@ VER_0 { mlirValuePrint; mlirValuePrintAsOperand; mlirValueReplaceAllUsesOfWith; + mlirValueSetType; mlirVectorTypeGet; mlirVectorTypeGetChecked; mlirVectorTypeGetTypeID; diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst index 07fb7539c832..14283da2ec5a 100644 --- a/compiler/src/iree/compiler/API/api_exports.macos.lst +++ b/compiler/src/iree/compiler/API/api_exports.macos.lst @@ -16,10 +16,12 @@ _ireeCompilerInvocationImportBorrowModule _ireeCompilerInvocationImportStealModule _ireeCompilerInvocationOutputHALExecutable _ireeCompilerInvocationOutputIR +_ireeCompilerInvocationOutputIRBytecode _ireeCompilerInvocationOutputVMBytecode _ireeCompilerInvocationOutputVMCSource _ireeCompilerInvocationParseSource _ireeCompilerInvocationPipeline +_ireeCompilerInvocationRunPassPipeline _ireeCompilerInvocationSetCompileFromPhase _ireeCompilerInvocationSetCompileToPhase _ireeCompilerInvocationSetCrashHandler @@ -640,6 +642,7 @@ _mlirValueIsAOpResult _mlirValuePrint _mlirValuePrintAsOperand _mlirValueReplaceAllUsesOfWith +_mlirValueSetType _mlirVectorTypeGet _mlirVectorTypeGetChecked _mlirVectorTypeGetTypeID