Skip to content

Commit

Permalink
Add a CLI iree-ir-tool with a command to strip data. (#14636)
Browse files Browse the repository at this point in the history
Been meaning to do this for a while in order to have a place to stash
more power-user style things that core developers typically use iree-opt
for.

This one adds a `strip-data` sub-command which uses the passes from
#14627 to systematically replace tensor constants with synthetic values.
With an ASAN/asserts build, this was able to strip a 7GiB int4 vicuna
MLIR file in ~5s and a 23GiB h2ogpt model in about 30s (the latter has
some other characteristics which make it more expensive to load as well
as being bigger). Results were a 2.6MiB and 1.4MiB MLIR file
respectively, consisting just of program IR and annotations for
synthetic data.

Getting the opt pipeline right for arbitrary input is a bit tricky, so I
decided we should just armor this into a tool

From installed packages, this can be used as:

```
iree-ir-tool strip-data input.mlir -o output.mlir
```

From a build tree with Python setup:

```
python -m iree.compiler.tools.ir_tool strip-data input.mlir -o output.mlir
```

Required adding some additional compiler APIs:

* `ireeCompilerInvocationRunPassPipeline` to run an arbitrary textual
pass pipeline on an invocation.
* `ireeCompilerInvocationOutputIRBytecode` to emit bytecode from an
invocation.
  • Loading branch information
stellaraccident authored Aug 11, 2023
1 parent 57b9239 commit ecc49f6
Show file tree
Hide file tree
Showing 16 changed files with 381 additions and 35 deletions.
15 changes: 15 additions & 0 deletions compiler/bindings/c/iree/compiler/embedding_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,26 @@ IREE_EMBED_EXPORTED bool
ireeCompilerInvocationPipeline(iree_compiler_invocation_t *inv,
enum iree_compiler_pipeline_t pipeline);

// Runs an arbitrary pass pipeline.
// Returns false and emits diagnostics on failure.
// Available since: 1.4
IREE_EMBED_EXPORTED bool
ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv,
const char *textPassPipeline);

// Outputs the current compiler state as textual IR to the output.
IREE_EMBED_EXPORTED iree_compiler_error_t *
ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *inv,
iree_compiler_output_t *output);

// Outputs the current compiler state as bytecode IR to the output.
// Emits as the given bytecode version or most recent if -1.
// Available since: 1.4
IREE_EMBED_EXPORTED iree_compiler_error_t *
ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv,
iree_compiler_output_t *output,
int bytecodeVersion);

// Assuming that the compiler has produced VM IR, converts it to bytecode
// and outputs it. This is a valid next step after running the
// IREE_COMPILER_PIPELINE_STD pipeline.
Expand Down
2 changes: 2 additions & 0 deletions compiler/bindings/c/iree/compiler/loader/handle_symbols.inc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ HANDLE_SYMBOL(ireeCompilerInvocationSetCompileFromPhase)
HANDLE_SYMBOL(ireeCompilerInvocationSetCompileToPhase)
HANDLE_SYMBOL(ireeCompilerInvocationSetVerifyIR)
HANDLE_SYMBOL(ireeCompilerInvocationPipeline)
HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationRunPassPipeline, 1, 4)
HANDLE_SYMBOL(ireeCompilerInvocationOutputIR)
HANDLE_VERSIONED_SYMBOL(ireeCompilerInvocationOutputIRBytecode, 1, 4)
HANDLE_SYMBOL(ireeCompilerInvocationOutputVMBytecode)
HANDLE_SYMBOL(ireeCompilerInvocationOutputVMCSource)
HANDLE_SYMBOL(ireeCompilerInvocationOutputHALExecutable)
Expand Down
12 changes: 12 additions & 0 deletions compiler/bindings/c/iree/compiler/loader/loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,24 @@ bool ireeCompilerInvocationPipeline(iree_compiler_invocation_t *run,
return __ireeCompilerInvocationPipeline(run, pipeline);
}

bool ireeCompilerInvocationRunPassPipeline(iree_compiler_invocation_t *inv,
const char *textPassPipeline) {
return __ireeCompilerInvocationRunPassPipeline(inv, textPassPipeline);
}

iree_compiler_error_t *
ireeCompilerInvocationOutputIR(iree_compiler_invocation_t *run,
iree_compiler_output_t *output) {
return __ireeCompilerInvocationOutputIR(run, output);
}

iree_compiler_error_t *
ireeCompilerInvocationOutputIRBytecode(iree_compiler_invocation_t *inv,
iree_compiler_output_t *output,
int bytecodeVersion) {
return __ireeCompilerInvocationOutputIRBytecode(inv, output, bytecodeVersion);
}

iree_compiler_error_t *
ireeCompilerInvocationOutputVMBytecode(iree_compiler_invocation_t *run,
iree_compiler_output_t *output) {
Expand Down
1 change: 1 addition & 0 deletions compiler/bindings/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools
SOURCES_GLOB
api/*.py
tools/*.py
tools/ir_tool/*.py
)

################################################################################
Expand Down
22 changes: 22 additions & 0 deletions compiler/bindings/python/iree/compiler/api/ctypes_dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ def _init_dylib():
_setsig(_dylib.ireeCompilerInvocationEnableConsoleDiagnostics, None, [c_void_p])
_setsig(_dylib.ireeCompilerInvocationParseSource, c_bool, [c_void_p, c_void_p])
_setsig(_dylib.ireeCompilerInvocationPipeline, c_bool, [c_void_p, c_int])
_setsig(_dylib.ireeCompilerInvocationRunPassPipeline, c_bool, [c_void_p, c_char_p])
_setsig(_dylib.ireeCompilerInvocationOutputIR, c_void_p, [c_void_p, c_void_p])
_setsig(
_dylib.ireeCompilerInvocationOutputIRBytecode,
c_void_p,
[c_void_p, c_void_p, c_int],
)
_setsig(
_dylib.ireeCompilerInvocationOutputVMBytecode, c_void_p, [c_void_p, c_void_p]
)
Expand Down Expand Up @@ -328,6 +334,10 @@ def __init__(self, session: Session):
# Invocation.
self._retained_module_op = None

@property
def session(self) -> Session:
return self._session

def __del__(self):
self.close()

Expand Down Expand Up @@ -365,11 +375,23 @@ def execute(
) -> bool:
return _dylib.ireeCompilerInvocationPipeline(self._inv_p, pipeline)

def execute_text_pass_pipeline(self, text_pipeline_spec: str) -> bool:
return _dylib.ireeCompilerInvocationRunPassPipeline(
self._inv_p, text_pipeline_spec.encode()
)

def output_ir(self, output: Output):
_handle_error(
_dylib.ireeCompilerInvocationOutputIR(self._inv_p, output._output_p)
)

def output_ir_bytecode(self, output: Output, bytecode_version: int = -1):
_handle_error(
_dylib.ireeCompilerInvocationOutputIRBytecode(
self._inv_p, output._output_p, bytecode_version
)
)

def output_vm_bytecode(self, output: Output):
_handle_error(
_dylib.ireeCompilerInvocationOutputVMBytecode(self._inv_p, output._output_p)
Expand Down
106 changes: 106 additions & 0 deletions compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import logging
import sys

from ...api import Invocation, Session, Source, Output


def load_source(inv: Invocation, input_file: str) -> Source:
source = Source.open_file(inv.session, input_file)
if not inv.parse_source(source):
raise RuntimeError(f"Error parsing source file {input_file}")
return source


def write_output(inv: Invocation, output: Output, args, keep: bool = True):
if args.emit_bytecode:
inv.output_ir_bytecode(output, args.bytecode_version)
else:
inv.output_ir(output)
if keep:
output.keep()


###############################################################################
# CLI handling
###############################################################################


def parse_arguments(argv=None):
parser = argparse.ArgumentParser(description="IREE IR Tool")
subparsers = parser.add_subparsers(
help="sub-command help", required=True, dest="sub_command"
)

def add_ouptut_options(subparser):
subparser.add_argument(
"--emit-bytecode", action="store_true", help="Emit bytecode"
)
subparser.add_argument(
"--bytecode-version",
default=-1,
type=int,
help="Bytecode version to emit or -1 for latest",
)

# strip-data command.
strip_data_parser = subparsers.add_parser(
"strip-data",
help="Strip large constants and values, "
"replacing them with pseudo data suitable for interactive "
"debugging of IR",
)
add_ouptut_options(strip_data_parser)
strip_data_parser.add_argument(
"--no-import",
action="store_true",
help="Disable import of public dialects to internal",
)
strip_data_parser.add_argument("input_file", help="File to process")
strip_data_parser.add_argument(
"-o", required=True, dest="output_file", help="Output file"
)
args = parser.parse_args(argv)
return args


def main(args) -> int:
if args.sub_command == "strip-data":
return do_strip_data(args)
else:
print("error: Unrecognized sub-command {args.sub_command}", file=sys.stderr)
return 1
return 0


def do_strip_data(args) -> int:
session = Session()
output = Output.open_file(args.output_file)
inv = session.invocation()
inv.enable_console_diagnostics()
load_source(inv, args.input_file)
if not args.no_import:
if not inv.execute_text_pass_pipeline(
"iree-import-public, iree-import-ml-program"
):
return 1
if not inv.execute_text_pass_pipeline(
"iree-util-outline-constants, iree-util-strip-and-splat-constants"
):
return 2
write_output(inv, output, args)
return 0


def _cli_main():
sys.exit(main(parse_arguments()))


if __name__ == "__main__":
_cli_main()
7 changes: 7 additions & 0 deletions compiler/bindings/python/test/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ iree_py_test(
)
endif() # IREE_BUILD_BUNDLED_LLVM

iree_py_test(
NAME
ir_tool_test
SRCS
"ir_tool_test.py"
)

iree_py_test(
NAME
compiler_tf_test
Expand Down
114 changes: 114 additions & 0 deletions compiler/bindings/python/test/tools/ir_tool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.compiler.tools.ir_tool import __main__

import os
import tempfile
import unittest


def run_tool(*argv: str):
try:
args = __main__.parse_arguments(list(argv))
__main__.main(args)
except SystemExit as e:
if e.code != 0:
raise RuntimeError(f"Tool exited with code {e.code}")


class IrToolTest(unittest.TestCase):
def setUp(self):
with tempfile.NamedTemporaryFile(delete=False) as f:
self.inputPath = f.name
with tempfile.NamedTemporaryFile(delete=False) as f:
self.outputPath = f.name

def tearDown(self) -> None:
if os.path.exists(self.inputPath):
os.unlink(self.inputPath)
if os.path.exists(self.outputPath):
os.unlink(self.outputPath)

def saveInput(self, contents, text=True):
with open(self.inputPath, "wt" if text else "wb") as f:
f.write(contents)

def loadOutput(self, text=True):
with open(self.outputPath, "rt" if text else "rb") as f:
return f.read()

def testStripDataWithImport(self):
self.saveInput(
r"""
builtin.module {
func.func @main() -> tensor<4xf32> {
%0 = arith.constant dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>
func.return %0 : tensor<4xf32>
}
}
"""
)
run_tool("strip-data", self.inputPath, "-o", self.outputPath)
output = self.loadOutput()
print("Output:", output)
self.assertIn("#util.byte_pattern", output)

def testStripDataNoImport(self):
# Without import, ml_program.global is not recognized.
self.saveInput(
r"""
builtin.module {
ml_program.global public @foobar(dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>) : tensor<4xf32>
}
"""
)
run_tool("strip-data", "--no-import", self.inputPath, "-o", self.outputPath)
output = self.loadOutput()
print("Output:", output)
self.assertNotIn("#util.byte_pattern", output)

def testStripDataParseError(self):
self.saveInput(
r"""
FOOBAR
"""
)
with self.assertRaisesRegex(RuntimeError, "Error parsing source file"):
run_tool("strip-data", self.inputPath, "-o", self.outputPath)

def testStripDataEmitBytecode(self):
self.saveInput(
r"""
builtin.module {
}
"""
)
run_tool("strip-data", "--emit-bytecode", self.inputPath, "-o", self.outputPath)
output = self.loadOutput(text=False)
self.assertIn(b"MLIR", output)

def testStripDataEmitBytecodeVersion(self):
self.saveInput(
r"""
builtin.module {
}
"""
)
run_tool(
"strip-data",
"--emit-bytecode",
"--bytecode-version=0",
self.inputPath,
"-o",
self.outputPath,
)
output = self.loadOutput(text=False)
self.assertIn(b"MLIR", output)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions compiler/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def find_git_submodule_revision(submodule_path):
# TODO: We have renamed to iree-compile on 2022-03-18. Remove
# this alias once no longer needed.
"ireec = iree.compiler.tools.scripts.ireec.__main__:main",
"iree-ir-tool = iree.compiler.tools.ir_tool.__main__:_cli_main",
],
},
install_requires=[
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/API/Internal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_cc_library(
DEPS
LLVMSupport
MLIRBuiltinToLLVMIRTranslation
MLIRBytecodeWriter
MLIRCAPIIR
MLIRIR
MLIRParser
Expand Down
Loading

0 comments on commit ecc49f6

Please sign in to comment.