-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a CLI
iree-ir-tool
with a command to strip data. (#14636)
Been meaning to do this for a while in order to have a place to stash more power-user style things that core developers typically use iree-opt for. This one adds a `strip-data` sub-command which uses the passes from #14627 to systematically replace tensor constants with synthetic values. With an ASAN/asserts build, this was able to strip a 7GiB int4 vicuna MLIR file in ~5s and a 23GiB h2ogpt model in about 30s (the latter has some other characteristics which make it more expensive to load as well as being bigger). Results were a 2.6MiB and 1.4MiB MLIR file respectively, consisting just of program IR and annotations for synthetic data. Getting the opt pipeline right for arbitrary input is a bit tricky, so I decided we should just armor this into a tool From installed packages, this can be used as: ``` iree-ir-tool strip-data input.mlir -o output.mlir ``` From a build tree with Python setup: ``` python -m iree.compiler.tools.ir_tool strip-data input.mlir -o output.mlir ``` Required adding some additional compiler APIs: * `ireeCompilerInvocationRunPassPipeline` to run an arbitrary textual pass pipeline on an invocation. * `ireeCompilerInvocationOutputIRBytecode` to emit bytecode from an invocation.
- Loading branch information
1 parent
57b9239
commit ecc49f6
Showing
16 changed files
with
381 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
compiler/bindings/python/iree/compiler/tools/ir_tool/__main__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import argparse | ||
import logging | ||
import sys | ||
|
||
from ...api import Invocation, Session, Source, Output | ||
|
||
|
||
def load_source(inv: Invocation, input_file: str) -> Source: | ||
source = Source.open_file(inv.session, input_file) | ||
if not inv.parse_source(source): | ||
raise RuntimeError(f"Error parsing source file {input_file}") | ||
return source | ||
|
||
|
||
def write_output(inv: Invocation, output: Output, args, keep: bool = True): | ||
if args.emit_bytecode: | ||
inv.output_ir_bytecode(output, args.bytecode_version) | ||
else: | ||
inv.output_ir(output) | ||
if keep: | ||
output.keep() | ||
|
||
|
||
############################################################################### | ||
# CLI handling | ||
############################################################################### | ||
|
||
|
||
def parse_arguments(argv=None): | ||
parser = argparse.ArgumentParser(description="IREE IR Tool") | ||
subparsers = parser.add_subparsers( | ||
help="sub-command help", required=True, dest="sub_command" | ||
) | ||
|
||
def add_ouptut_options(subparser): | ||
subparser.add_argument( | ||
"--emit-bytecode", action="store_true", help="Emit bytecode" | ||
) | ||
subparser.add_argument( | ||
"--bytecode-version", | ||
default=-1, | ||
type=int, | ||
help="Bytecode version to emit or -1 for latest", | ||
) | ||
|
||
# strip-data command. | ||
strip_data_parser = subparsers.add_parser( | ||
"strip-data", | ||
help="Strip large constants and values, " | ||
"replacing them with pseudo data suitable for interactive " | ||
"debugging of IR", | ||
) | ||
add_ouptut_options(strip_data_parser) | ||
strip_data_parser.add_argument( | ||
"--no-import", | ||
action="store_true", | ||
help="Disable import of public dialects to internal", | ||
) | ||
strip_data_parser.add_argument("input_file", help="File to process") | ||
strip_data_parser.add_argument( | ||
"-o", required=True, dest="output_file", help="Output file" | ||
) | ||
args = parser.parse_args(argv) | ||
return args | ||
|
||
|
||
def main(args) -> int: | ||
if args.sub_command == "strip-data": | ||
return do_strip_data(args) | ||
else: | ||
print("error: Unrecognized sub-command {args.sub_command}", file=sys.stderr) | ||
return 1 | ||
return 0 | ||
|
||
|
||
def do_strip_data(args) -> int: | ||
session = Session() | ||
output = Output.open_file(args.output_file) | ||
inv = session.invocation() | ||
inv.enable_console_diagnostics() | ||
load_source(inv, args.input_file) | ||
if not args.no_import: | ||
if not inv.execute_text_pass_pipeline( | ||
"iree-import-public, iree-import-ml-program" | ||
): | ||
return 1 | ||
if not inv.execute_text_pass_pipeline( | ||
"iree-util-outline-constants, iree-util-strip-and-splat-constants" | ||
): | ||
return 2 | ||
write_output(inv, output, args) | ||
return 0 | ||
|
||
|
||
def _cli_main(): | ||
sys.exit(main(parse_arguments())) | ||
|
||
|
||
if __name__ == "__main__": | ||
_cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from iree.compiler.tools.ir_tool import __main__ | ||
|
||
import os | ||
import tempfile | ||
import unittest | ||
|
||
|
||
def run_tool(*argv: str): | ||
try: | ||
args = __main__.parse_arguments(list(argv)) | ||
__main__.main(args) | ||
except SystemExit as e: | ||
if e.code != 0: | ||
raise RuntimeError(f"Tool exited with code {e.code}") | ||
|
||
|
||
class IrToolTest(unittest.TestCase): | ||
def setUp(self): | ||
with tempfile.NamedTemporaryFile(delete=False) as f: | ||
self.inputPath = f.name | ||
with tempfile.NamedTemporaryFile(delete=False) as f: | ||
self.outputPath = f.name | ||
|
||
def tearDown(self) -> None: | ||
if os.path.exists(self.inputPath): | ||
os.unlink(self.inputPath) | ||
if os.path.exists(self.outputPath): | ||
os.unlink(self.outputPath) | ||
|
||
def saveInput(self, contents, text=True): | ||
with open(self.inputPath, "wt" if text else "wb") as f: | ||
f.write(contents) | ||
|
||
def loadOutput(self, text=True): | ||
with open(self.outputPath, "rt" if text else "rb") as f: | ||
return f.read() | ||
|
||
def testStripDataWithImport(self): | ||
self.saveInput( | ||
r""" | ||
builtin.module { | ||
func.func @main() -> tensor<4xf32> { | ||
%0 = arith.constant dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32> | ||
func.return %0 : tensor<4xf32> | ||
} | ||
} | ||
""" | ||
) | ||
run_tool("strip-data", self.inputPath, "-o", self.outputPath) | ||
output = self.loadOutput() | ||
print("Output:", output) | ||
self.assertIn("#util.byte_pattern", output) | ||
|
||
def testStripDataNoImport(self): | ||
# Without import, ml_program.global is not recognized. | ||
self.saveInput( | ||
r""" | ||
builtin.module { | ||
ml_program.global public @foobar(dense<[0.1, 0.2, 0.3, 0.4]> : tensor<4xf32>) : tensor<4xf32> | ||
} | ||
""" | ||
) | ||
run_tool("strip-data", "--no-import", self.inputPath, "-o", self.outputPath) | ||
output = self.loadOutput() | ||
print("Output:", output) | ||
self.assertNotIn("#util.byte_pattern", output) | ||
|
||
def testStripDataParseError(self): | ||
self.saveInput( | ||
r""" | ||
FOOBAR | ||
""" | ||
) | ||
with self.assertRaisesRegex(RuntimeError, "Error parsing source file"): | ||
run_tool("strip-data", self.inputPath, "-o", self.outputPath) | ||
|
||
def testStripDataEmitBytecode(self): | ||
self.saveInput( | ||
r""" | ||
builtin.module { | ||
} | ||
""" | ||
) | ||
run_tool("strip-data", "--emit-bytecode", self.inputPath, "-o", self.outputPath) | ||
output = self.loadOutput(text=False) | ||
self.assertIn(b"MLIR", output) | ||
|
||
def testStripDataEmitBytecodeVersion(self): | ||
self.saveInput( | ||
r""" | ||
builtin.module { | ||
} | ||
""" | ||
) | ||
run_tool( | ||
"strip-data", | ||
"--emit-bytecode", | ||
"--bytecode-version=0", | ||
self.inputPath, | ||
"-o", | ||
self.outputPath, | ||
) | ||
output = self.loadOutput(text=False) | ||
self.assertIn(b"MLIR", output) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.