From ae9a51c654b97f82eed9f1de145aef5a83aea8e9 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 23 Oct 2024 12:41:05 -0400 Subject: [PATCH] Add device affinities for arguments in AOT (#231) We don't have support for providing device affinities for function arguments, which need to end up as MLIR function argument attributes. This change adds a class DeviceAffinity and provides the ability to supply affinities when exporting Torch functions/modules or when tracing in IREE-Trubine itself. Signed-off-by: Boian Petkantchin --- iree/turbine/aot/compiled_module.py | 87 ++++++++++++++++++- iree/turbine/aot/exporter.py | 20 ++++- iree/turbine/aot/fx_programs.py | 22 ++++- iree/turbine/aot/support/ir_utils.py | 55 +++++++++++- .../support/procedural/exported_program.py | 15 ++++ iree/turbine/aot/support/procedural/tracer.py | 20 ++++- iree/turbine/aot/tensor_traits.py | 16 ++++ iree/turbine/support/ir_imports.py | 2 + tests/aot/args_test.py | 18 ++++ tests/aot/compiled_exported_program_test.py | 24 +++++ tests/aot/fx_programs_test.py | 8 +- tests/aot/fx_programs_test_device.py | 47 ++++++++++ 12 files changed, 315 insertions(+), 19 deletions(-) create mode 100644 tests/aot/fx_programs_test_device.py diff --git a/iree/turbine/aot/compiled_module.py b/iree/turbine/aot/compiled_module.py index 270534e1..fc1c104b 100644 --- a/iree/turbine/aot/compiled_module.py +++ b/iree/turbine/aot/compiled_module.py @@ -44,6 +44,8 @@ ModuleBuilderOptions, ) +from .tensor_traits import DeviceAffinity + __all__ = [ "CompiledModule", @@ -107,12 +109,27 @@ def __call__(self, *args, **kwargs): return self.py_value(*args, **kwargs) +class ExportTargetDef: + def __init__( + self, + target: Union[Callable, ExportedProgram], + *, + arg_device: dict[int, DeviceAffinity] | None = None, + ): + self.target = target + self.arg_device = arg_device + + def __call__(self, *args, **kwargs): + return self.target(*args, **kwargs) + + class ExportProcDef: __slots__ = [ "callable", "export_name", "signature", "file_line_loc", + "arg_device", ] def __init__( @@ -122,14 +139,22 @@ def __init__( *, signature, file_line_loc: Optional[Tuple[str, int]] = None, + arg_device: dict[int, DeviceAffinity] | None = None, ): self.export_name = export_name self.callable = callable self.signature = signature self.file_line_loc = file_line_loc + self.arg_device = arg_device def copy(self) -> "ExportProcDef": - return ExportProcDef(self.export_name, self.callable, signature=self.signature) + return ExportProcDef( + self.export_name, + self.callable, + signature=self.signature, + file_line_loc=self.file_line_loc, + arg_device=self.arg_device, + ) def __repr__(self): return f"" @@ -142,14 +167,19 @@ def __init__( *, export_name: Optional[str] = None, public: bool = False, + arg_device: dict[int, DeviceAffinity] | None = None, ): self.export_name = export_name self.exported_program = ep self.public = public + self.arg_device = arg_device def copy(self) -> "ExportedProgramDef": return ExportedProgramDef( - self.exported_program, export_name=self.export_name, public=self.public + self.exported_program, + export_name=self.export_name, + public=self.public, + arg_device=self.arg_device, ) def __repr__(self): @@ -207,6 +237,19 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]: ) # type: ignore def def_attribute(self, key, value): + if isinstance(value, ExportTargetDef): + if not isinstance(value.target, ExportedProgram): + # We expect exported function. + assert callable(value.target) and inspect.isfunction(value.target) + return self.def_export_proc(key, value.target, value.arg_device) + + value = ExportedProgramDef( + value.target, + export_name=key, + public=not key.startswith("_"), + arg_device=value.arg_device, + ) + # Some decorators, the only thing we do is convert them to PyOnlyDef. # Do that first so the generic descriptor code below handles them. if isinstance(value, builtins.jittable): @@ -233,6 +276,15 @@ def def_attribute(self, key, value): logging.debug("DEFINE PY_ONLY: %s = %r", key, value) self.add_export(key, value) return value + if isinstance(value, ExportTargetDef) and isinstance( + value.target, ExportedProgram + ): + value = ExportedProgramDef( + value.target, + export_name=key, + public=not key.startswith("_"), + arg_device=value.arg_device, + ) if isinstance(value, ExportedProgramDef): if value.export_name is None: value = value.copy() @@ -250,7 +302,12 @@ def def_attribute(self, key, value): f"compiled module: {value!r}" ) - def def_export_proc(self, name, f) -> ExportProcDef: + def def_export_proc( + self, + name, + f, + arg_device: dict[int, DeviceAffinity] | None = None, + ) -> ExportProcDef: logging.debug("DEFINE EXPORT: %s = %r", name, f) # Get a reasonable location. file_line_loc = None @@ -292,7 +349,13 @@ def def_export_proc(self, name, f) -> ExportProcDef: ) input_sig.append(param_desc) - info = ExportProcDef(name, f, signature=input_sig, file_line_loc=file_line_loc) + info = ExportProcDef( + name, + f, + signature=input_sig, + file_line_loc=file_line_loc, + arg_device=arg_device, + ) self.add_export(name, info) return info @@ -568,6 +631,20 @@ def save_mlir(inst: "CompiledModule", path: Union[Path, str]): jittable = staticmethod(builtins.jittable) + @staticmethod + def signature_info( + *, + arg_device: dict[int, DeviceAffinity] | None = None, + ) -> Callable: + """Annotate an export target function. + This annotation is only required when additional information needs to be + provided.""" + + def _decorator(f: Callable): + return ExportTargetDef(f, arg_device=arg_device) + + return _decorator + def __getattr__(self, name): info = CompiledModule.get_info(self) try: @@ -633,6 +710,7 @@ def __new__( ep_def.exported_program, symbol_name=ep_def.export_name or "main", symbol_visibility=None if ep_def.public else "private", + arg_device=ep_def.arg_device, ) # Instantiate procs. @@ -661,6 +739,7 @@ def invoke_with_self(*args, **kwargs): posargs=proc_def.signature, kwargs={}, # TODO(#128): kwargs loc=loc, + arg_device=proc_def.arg_device, ) trace.trace_py_func(invoke_with_self) info.shadow_dict[key] = _uncallable_public_export diff --git a/iree/turbine/aot/exporter.py b/iree/turbine/aot/exporter.py index c1adb527..dbd859ac 100644 --- a/iree/turbine/aot/exporter.py +++ b/iree/turbine/aot/exporter.py @@ -32,6 +32,8 @@ from .fx_programs import FxPrograms from . import decompositions +from .tensor_traits import DeviceAffinity + __all__ = [ "export", "ExportOutput", @@ -177,6 +179,7 @@ def export( function_name: Optional[str] = None, strict_export: bool = True, import_symbolic_shape_expressions: bool = False, + arg_device: dict[int, DeviceAffinity] | None = None, ) -> ExportOutput: """Exports a torch.nn.Module. @@ -199,6 +202,7 @@ def export( *, module_name: Optional[str] = None, function_name: Optional[str] = None, + arg_device: dict[int, DeviceAffinity] | None = None, ) -> ExportOutput: """Exports a single entry-point module consisting of an ExportedProgram.""" ... @@ -226,6 +230,7 @@ def export( function_name: Optional[str] = None, strict_export: bool = True, import_symbolic_shape_expressions: bool = False, + arg_device: dict[int, DeviceAffinity] | None = None, ) -> ExportOutput: """Generic export of supported entities. @@ -247,6 +252,10 @@ def export( must be empty. kwargs: Example keyword arguments. dynamic_shapes: Dynamic shape specs to pass to torch.export. + arg_device: device affinities for the exported function + arguments. On what devices should the program expect its arguments. + It is a mapping of argument index to device affinity of the flattened + arguments. Returns: An ExportOutput object that wraps the compilation and provides @@ -266,12 +275,14 @@ def export( "This is an experimental feature in PyTorch that the IREE Turbine project is still evaluating. Please report issues or experiences." ) + from .compiled_module import ExportTargetDef + TransformedModule: Any current_decomps = decompositions.current_aot_decompositions() if isinstance(mdl, torch.export.ExportedProgram): TransformedModule = CompiledModule.create_from_dict( "LambdaCompiledModule", - {(function_name or "main"): mdl}, + {(function_name or "main"): ExportTargetDef(mdl, arg_device=arg_device)}, export_name=module_name or "module", options=ModuleBuilderOptions( import_symbolic_shape_expressions=import_symbolic_shape_expressions, @@ -311,7 +322,12 @@ def export( TransformedModule = CompiledModule.create_from_dict( "LambdaCompiledModule", - {(function_name or "main"): exported_program}, + { + (function_name or "main"): ExportTargetDef( + exported_program, + arg_device=arg_device, + ) + }, export_name=module_name or "module", options=ModuleBuilderOptions( import_symbolic_shape_expressions=import_symbolic_shape_expressions, diff --git a/iree/turbine/aot/fx_programs.py b/iree/turbine/aot/fx_programs.py index 696f9a00..1bfd21f0 100644 --- a/iree/turbine/aot/fx_programs.py +++ b/iree/turbine/aot/fx_programs.py @@ -14,6 +14,7 @@ import os from pathlib import Path from typing import Any, Optional, Union +from .compiled_module import ExportTargetDef import functools @@ -21,6 +22,12 @@ import torch.nn as nn from .decompositions import current_aot_decompositions +from .tensor_traits import DeviceAffinity + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .compiled_module import ExportTargetDef # The dynamic_shapes support showed up in the Torch 2.3 timeframe. _supports_dynamic_shapes = hasattr(torch.export, "Dim") @@ -61,7 +68,7 @@ class FxPrograms: """ def __init__(self): - self.programs: dict[str, torch.export.ExportedProgram] = {} + self.programs: dict[str, ExportTargetDef] = {} def save(self, path: Union[str, os.PathLike]) -> int: """Saves the set of exported programs to a descriptor file. @@ -86,7 +93,9 @@ def permute_path(name): count_deduped = 0 # Save each. - for program_name, ep in self.programs.items(): + for program_name, export_def in self.programs.items(): + ep = export_def.target + assert isinstance(ep, torch.export.ExportedProgram) # First validate the ep with normal rules, which we will then # disable since we are violating the spec. ep._validate() @@ -129,7 +138,7 @@ def load(path: Union[str, os.PathLike]) -> "FxPrograms": ep = torch.export.load(path.parent / program_file_name) _unsharify_state_dict(shared_state_dict, ep.state_dict) _unsharify_state_dict(shared_constants, _get_optional_constants(ep)) - instance.programs[program_name] = ep + instance.programs[program_name] = ExportTargetDef(ep) return instance @@ -169,6 +178,7 @@ def export_program( dynamic_shapes=None, strict: bool = True, name: Optional[str] = None, + arg_device: dict[int, DeviceAffinity] | None = None, ): if f is None: return functools.partial( @@ -178,6 +188,7 @@ def export_program( strict=strict, dynamic_shapes=dynamic_shapes, name=name, + arg_device=arg_device, ) if name is None: @@ -234,7 +245,10 @@ def new_forward(self, *forward_args, **forward_kwargs): _patch_op_dispatch_for_export() program = program.run_decompositions(current_decomps) - fx_builder.programs[name] = program + fx_builder.programs[name] = ExportTargetDef( + program, + arg_device=arg_device, + ) return program diff --git a/iree/turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py index ee51f7b9..91f363d2 100644 --- a/iree/turbine/aot/support/ir_utils.py +++ b/iree/turbine/aot/support/ir_utils.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from pathlib import Path import tempfile +from itertools import zip_longest import numpy as np import torch @@ -26,11 +27,13 @@ ) from ...support.ir_imports import ( - AsmState, + ArrayAttr, Attribute, BF16Type, + Context, DenseElementsAttr, DenseResourceElementsAttr, + DictAttr, F16Type, F32Type, F64Type, @@ -63,6 +66,7 @@ from ...support.logging import aot_logger as logger from ..tensor_traits import ( + DeviceAffinity, ExternalTensorTrait, ) @@ -235,6 +239,8 @@ def create_func_op( argument_types: Sequence[IrType], is_public: bool = True, add_entry_block: bool = True, + # Array of DictAttr corresponding to the attributes for each argument. + argument_attributes: ArrayAttr | list[DictAttr] | None = None, ) -> Tuple[str, func_d.FuncOp]: with self.ip: ftype = FunctionType.get(argument_types, []) @@ -245,6 +251,8 @@ def create_func_op( func_op.add_entry_block() self.symbol_table.insert(func_op) actual_symbol_name = StringAttr(func_op.attributes["sym_name"]).value + if argument_attributes is not None: + func_op.arg_attrs = argument_attributes return actual_symbol_name, func_op def torch_dtype_to_iree_type(self, dtype: torch.dtype) -> IrType: @@ -470,3 +478,48 @@ def _is_float_type(type): def _is_integer_like_type(type): return isinstance(type, (IntegerType, IndexType)) + + +def _attribute_from_device_affinity( + affinity: DeviceAffinity, context: Context +) -> Attribute: + return Attribute.parse( + f'#hal.device.promise<@"__device_{affinity.ordinal}">', context + ) + + +def attributes_from_argument_device_affinities( + affinities: dict[int, DeviceAffinity] | None, + arguments_count: int, + context: Context, +) -> list[dict[str, Attribute]]: + """Get as attributes for function op arguments.""" + if affinities is None: + return [{} for _ in range(arguments_count)] + return [ + {"iree.abi.affinity": _attribute_from_device_affinity(affinities[i], context)} + if i in affinities + else {} + for i in range(arguments_count) + ] + + +def update_func_op_argument_attributes( + func_op: func_d.FuncOp, attributes: list[dict[str, Attribute]] +): + if func_d.ARGUMENT_ATTRIBUTE_NAME not in func_op.attributes: + mutable_arg_attrs: list[dict[str, Attribute]] = [ + {} for _ in range(len(func_op.arguments)) + ] + else: + mutable_arg_attrs = [ + {named_attr.name: named_attr.attr for named_attr in dict_attr} + for dict_attr in func_op.arg_attrs + ] + + for src, dst in zip_longest(attributes, mutable_arg_attrs): + dst.update(src) + + func_op.arg_attrs = [ + DictAttr.get(d, context=func_op.context) for d in mutable_arg_attrs + ] diff --git a/iree/turbine/aot/support/procedural/exported_program.py b/iree/turbine/aot/support/procedural/exported_program.py index f6540bab..14a47d7f 100644 --- a/iree/turbine/aot/support/procedural/exported_program.py +++ b/iree/turbine/aot/support/procedural/exported_program.py @@ -45,12 +45,15 @@ ) from ...tensor_traits import ( + DeviceAffinity, ExternalTensorTrait, ) from ..ir_utils import ( + attributes_from_argument_device_affinities, GlobalAttributes, ModuleBuilder, + update_func_op_argument_attributes, ) from .base import ( @@ -71,6 +74,9 @@ IrTrace, ) +from typing import TYPE_CHECKING + + # Limit of tensor volumes. Over this limit, otherwise uncategorized tensor # constants will be emitted out-of-line. Under the limit, inline. INLINE_TENSOR_VOLUME_LIMIT = 1024 @@ -178,6 +184,7 @@ def import_exported_program( exported_program: torch.export.ExportedProgram, symbol_name: str, symbol_visibility: Optional[str], + arg_device: dict[int, DeviceAffinity] | None, ) -> ExportedProgramIntrinsic: fx_importer = _create_fx_importer(module_builder) entry_func_op = fx_importer.import_program( @@ -186,6 +193,14 @@ def import_exported_program( func_visibility=symbol_visibility, import_symbolic_shape_expressions=module_builder.options.import_symbolic_shape_expressions, ) + update_func_op_argument_attributes( + entry_func_op, + attributes_from_argument_device_affinities( + arg_device, + len(entry_func_op.arguments), + entry_func_op.context, + ), + ) module_call_graph = exported_program.module_call_graph assert len(module_call_graph) >= 1, "Expected at least one module call signature" diff --git a/iree/turbine/aot/support/procedural/tracer.py b/iree/turbine/aot/support/procedural/tracer.py index 19342deb..252065c4 100644 --- a/iree/turbine/aot/support/procedural/tracer.py +++ b/iree/turbine/aot/support/procedural/tracer.py @@ -21,6 +21,7 @@ ) from ....support.ir_imports import ( + DictAttr, Location, StringAttr, Value, @@ -29,9 +30,7 @@ from ....support.logging import aot_logger as logger -from ..ir_utils import ( - ModuleBuilder, -) +from ..ir_utils import ModuleBuilder, attributes_from_argument_device_affinities from .base import ( AbstractIntrinsic, @@ -45,6 +44,8 @@ LiveGlobalCollectionProxy, ) +from ...tensor_traits import DeviceAffinity + ############################################################################### # Concrete procedure building IrTracer. ############################################################################### @@ -78,6 +79,7 @@ def define_func( posargs: Sequence, kwargs: dict, loc: Location, + arg_device: dict[int, DeviceAffinity] | None = None, ) -> "ProcedureTrace": # Unpack arguments. arguments_flat, arguments_tree_def = tree_flatten((posargs, kwargs)) @@ -88,7 +90,17 @@ def define_func( argument_ir_types.append(arg.get_ir_type(module_builder)) with loc: - _, func_op = module_builder.create_func_op(symbol_name, argument_ir_types) + argument_attributes = [ + DictAttr.get(d) + for d in attributes_from_argument_device_affinities( + arg_device, + arguments_count=len(argument_ir_types), + context=module_builder.context, + ) + ] + _, func_op = module_builder.create_func_op( + symbol_name, argument_ir_types, argument_attributes=argument_attributes + ) # Bind proxy arguments to an IR value. ir_proxy_arguments_flat = [] diff --git a/iree/turbine/aot/tensor_traits.py b/iree/turbine/aot/tensor_traits.py index bb7a5280..97283784 100644 --- a/iree/turbine/aot/tensor_traits.py +++ b/iree/turbine/aot/tensor_traits.py @@ -11,10 +11,26 @@ __all__ = [ + "DeviceAffinity", "ExternalTensorTrait", ] +class DeviceAffinity: + """This is used to provide device affinities to exported function arguments.""" + + def __init__(self, ordinal: int): + self.ordinal = ordinal + + def __eq__(self, other) -> bool: + if not isinstance(other, DeviceAffinity): + return False + return self.ordinal == other.ordinal + + def __repr__(self) -> str: + return f"DeviceAffinity({self.ordinal})" + + @dataclass class ExternalTensorTrait: """Represents a 'trait' that can be applied to a Tensor to signal that diff --git a/iree/turbine/support/ir_imports.py b/iree/turbine/support/ir_imports.py index 09aa4042..1803f16a 100644 --- a/iree/turbine/support/ir_imports.py +++ b/iree/turbine/support/ir_imports.py @@ -8,6 +8,7 @@ """Unifies all imports of iree.compiler.ir into one place.""" from iree.compiler.ir import ( + ArrayAttr, AsmState, Attribute, Block, @@ -15,6 +16,7 @@ Context, DenseElementsAttr, DenseResourceElementsAttr, + DictAttr, FlatSymbolRefAttr, FloatAttr, FunctionType, diff --git a/tests/aot/args_test.py b/tests/aot/args_test.py index efbce489..2910fa2e 100644 --- a/tests/aot/args_test.py +++ b/tests/aot/args_test.py @@ -65,6 +65,24 @@ def compute(a, b): msg=f"Did not find two linalg.generics in module: module_str", ) + def testDeviceAffinities(self): + class ProcArgsModule(CompiledModule): + @CompiledModule.signature_info(arg_device={1: DeviceAffinity(1)}) + def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): + return a, b + + inst = ProcArgsModule(context=Context(), import_to="import") + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertRegex( + module_str, + ( + "func.func @foobar\(" + "%.+: tensor<3x2xf32>, " + "%.+: tensor<1x1xf32> {iree.abi.affinity = #hal.device.promise<@__device_1>}\)" + ), + ) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/aot/compiled_exported_program_test.py b/tests/aot/compiled_exported_program_test.py index 6b86b185..d14e8ac6 100644 --- a/tests/aot/compiled_exported_program_test.py +++ b/tests/aot/compiled_exported_program_test.py @@ -177,6 +177,30 @@ class BuffersAsGlobalsModule(CompiledModule): self.assertIn("%_buffers.buf = util.global.load @_buffers.buf", module_str) self.assertIn("util.global.store", module_str) + def testDeviceAffinities(self): + class Module(torch.nn.Module): + def forward(self, x, y): + return x, y + + module = Module() + export_output = export( + module, + function_name="foo", + args=(torch.empty(1, dtype=torch.int8), torch.empty(2, dtype=torch.int8)), + arg_device={1: DeviceAffinity(1)}, + ) + asm = str(export_output.mlir_module) + print(asm) + self.assertRegex( + asm, + ( + "func.func @foo\(" + "%.+: !torch.vtensor<\[1\],si8>, " + "%.+: !torch.vtensor<\[2\],si8> " + "{iree.abi.affinity = #hal.device.promise<@__device_1>}\)" + ), + ) + class SimpleParams(nn.Module): def __init__(self): diff --git a/tests/aot/fx_programs_test.py b/tests/aot/fx_programs_test.py index f2c70456..d5241654 100644 --- a/tests/aot/fx_programs_test.py +++ b/tests/aot/fx_programs_test.py @@ -61,10 +61,10 @@ def bs32(module: M, x1, x2): prog_0 = new_programs.programs["dynamic_batch"] prog_1 = new_programs.programs["bs32"] - for key, value_0 in prog_0.state_dict.items(): - value_1 = prog_1.state_dict[key] + for key, value_0 in prog_0.target.state_dict.items(): + value_1 = prog_1.target.state_dict[key] assert value_0 is value_1, f"State dict item {key} was not aliased on load" - for key, value_0 in prog_0.constants.items(): - value_1 = prog_1.constants[key] + for key, value_0 in prog_0.target.constants.items(): + value_1 = prog_1.target.constants[key] assert value_0 is value_1, f"Constant item {key} was not aliased on load" diff --git a/tests/aot/fx_programs_test_device.py b/tests/aot/fx_programs_test_device.py new file mode 100644 index 00000000..2c2ec65b --- /dev/null +++ b/tests/aot/fx_programs_test_device.py @@ -0,0 +1,47 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +import unittest + +import torch + +from iree.turbine.aot import ( + DeviceAffinity, + export, + FxProgramsBuilder, +) + + +class FxProgramsTestDevice(unittest.TestCase): + def test_argument_device_affinities(self): + class Module(torch.nn.Module): + def main(self, x1, x2): + return x1, x2 + + args = ( + torch.empty(2, 3, dtype=torch.int8), + torch.empty(4, 5, dtype=torch.int8), + ) + fxb = FxProgramsBuilder(Module()) + + @fxb.export_program( + args=args, + arg_device={0: DeviceAffinity(0), 1: DeviceAffinity(1)}, + ) + def main(module: Module, x1, x2): + return module.main(x1, x2) + + output = export(fxb) + asm = str(output.mlir_module) + self.assertRegex( + asm, + ( + "func.func @main\(" + "%.+: !torch.vtensor<\[2,3\],si8> {iree.abi.affinity = #hal.device.promise<@__device_0>}, " + "%.+: !torch.vtensor<\[4,5\],si8> {iree.abi.affinity = #hal.device.promise<@__device_1>}\)" + ), + )