Skip to content

Commit

Permalink
Add device affinities for arguments in AOT
Browse files Browse the repository at this point in the history
We don't have support for providing device affinities for function
arguments, which need to end up as MLIR function argument attributes.

This change adds a class DeviceAffinity and provides the ability to
supply affinities when exporting Torch functions/modules or when
tracing in IREE-Trubine itself.

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
  • Loading branch information
sogartar committed Oct 18, 2024
1 parent dc1060b commit 3f8b34a
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 66 deletions.
85 changes: 81 additions & 4 deletions iree/turbine/aot/compiled_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
"CompiledModule",
]

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .exporter import DeviceAffinity

################################################################################
# Data structures
################################################################################
Expand Down Expand Up @@ -107,12 +112,27 @@ def __call__(self, *args, **kwargs):
return self.py_value(*args, **kwargs)


class ExportTargetDef:
def __init__(
self,
target: Union[Callable, ExportedProgram],
*,
argument_device_affinities: dict[int, "DeviceAffinity"] | None,
):
self.target = target
self.argument_device_affinities = argument_device_affinities

def __call__(self, *args, **kwargs):
return self.target(*args, **kwargs)


class ExportProcDef:
__slots__ = [
"callable",
"export_name",
"signature",
"file_line_loc",
"argument_device_affinities",
]

def __init__(
Expand All @@ -122,14 +142,22 @@ def __init__(
*,
signature,
file_line_loc: Optional[Tuple[str, int]] = None,
argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
):
self.export_name = export_name
self.callable = callable
self.signature = signature
self.file_line_loc = file_line_loc
self.argument_device_affinities = argument_device_affinities

def copy(self) -> "ExportProcDef":
return ExportProcDef(self.export_name, self.callable, signature=self.signature)
return ExportProcDef(
self.export_name,
self.callable,
signature=self.signature,
file_line_loc=self.file_line_loc,
argument_device_affinities=self.argument_device_affinities,
)

def __repr__(self):
return f"<def {self.export_name}({self.signature})>"
Expand All @@ -142,14 +170,19 @@ def __init__(
*,
export_name: Optional[str] = None,
public: bool = False,
argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
):
self.export_name = export_name
self.exported_program = ep
self.public = public
self.argument_device_affinities = argument_device_affinities

def copy(self) -> "ExportedProgramDef":
return ExportedProgramDef(
self.exported_program, export_name=self.export_name, public=self.public
self.exported_program,
export_name=self.export_name,
public=self.public,
argument_device_affinities=self.argument_device_affinities,
)

def __repr__(self):
Expand Down Expand Up @@ -207,6 +240,21 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
) # type: ignore

def def_attribute(self, key, value):
if isinstance(value, ExportTargetDef):
if isinstance(value.target, ExportedProgram):
value = ExportedProgramDef(
value.target,
export_name=key,
public=not key.startswith("_"),
argument_device_affinities=value.argument_device_affinities,
)
else:
# We expect exported function.
assert callable(value.target) and inspect.isfunction(value.target)
return self.def_export_proc(
key, value.target, value.argument_device_affinities
)

# Some decorators, the only thing we do is convert them to PyOnlyDef.
# Do that first so the generic descriptor code below handles them.
if isinstance(value, builtins.jittable):
Expand Down Expand Up @@ -250,7 +298,12 @@ def def_attribute(self, key, value):
f"compiled module: {value!r}"
)

def def_export_proc(self, name, f) -> ExportProcDef:
def def_export_proc(
self,
name,
f,
argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
) -> ExportProcDef:
logging.debug("DEFINE EXPORT: %s = %r", name, f)
# Get a reasonable location.
file_line_loc = None
Expand Down Expand Up @@ -292,7 +345,13 @@ def def_export_proc(self, name, f) -> ExportProcDef:
)
input_sig.append(param_desc)

info = ExportProcDef(name, f, signature=input_sig, file_line_loc=file_line_loc)
info = ExportProcDef(
name,
f,
signature=input_sig,
file_line_loc=file_line_loc,
argument_device_affinities=argument_device_affinities,
)
self.add_export(name, info)
return info

Expand Down Expand Up @@ -568,6 +627,22 @@ def save_mlir(inst: "CompiledModule", path: Union[Path, str]):

jittable = staticmethod(builtins.jittable)

@staticmethod
def annotate(
*,
argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
) -> Callable:
"""Annotate an export target function.
This annotation is only required when additional information needs to be
provided."""

def _decorator(f: Callable):
return ExportTargetDef(
f, argument_device_affinities=argument_device_affinities
)

return _decorator

def __getattr__(self, name):
info = CompiledModule.get_info(self)
try:
Expand Down Expand Up @@ -633,6 +708,7 @@ def __new__(
ep_def.exported_program,
symbol_name=ep_def.export_name or "main",
symbol_visibility=None if ep_def.public else "private",
argument_device_affinities=ep_def.argument_device_affinities or {},
)

# Instantiate procs.
Expand Down Expand Up @@ -661,6 +737,7 @@ def invoke_with_self(*args, **kwargs):
posargs=proc_def.signature,
kwargs={}, # TODO(#128): kwargs
loc=loc,
argument_device_affinities=proc_def.argument_device_affinities,
)
trace.trace_py_func(invoke_with_self)
info.shadow_dict[key] = _uncallable_public_export
Expand Down
30 changes: 28 additions & 2 deletions iree/turbine/aot/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from . import decompositions

__all__ = [
"DeviceAffinity",
"export",
"ExportOutput",
]
Expand All @@ -49,6 +50,13 @@
SaveableTarget = Union[str, Path, None, Output]


class DeviceAffinity:
"""This is used to provide device affinities to exported function arguments."""

def __init__(self, moniker: str):
self.moniker = moniker


class ExportOutput:
"""Wrapper around a CompiledModule produced by `export`."""

Expand Down Expand Up @@ -177,6 +185,7 @@ def export(
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
argument_device_affinities: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Exports a torch.nn.Module.
Expand All @@ -199,6 +208,7 @@ def export(
*,
module_name: Optional[str] = None,
function_name: Optional[str] = None,
argument_device_affinities: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Exports a single entry-point module consisting of an ExportedProgram."""
...
Expand Down Expand Up @@ -226,6 +236,7 @@ def export(
function_name: Optional[str] = None,
strict_export: bool = True,
import_symbolic_shape_expressions: bool = False,
argument_device_affinities: dict[int, DeviceAffinity] | None = None,
) -> ExportOutput:
"""Generic export of supported entities.
Expand All @@ -247,6 +258,10 @@ def export(
must be empty.
kwargs: Example keyword arguments.
dynamic_shapes: Dynamic shape specs to pass to torch.export.
argument_device_affinities: device affinities for the exported function
arguments. On what devices should the program expect its arguments.
It is a mapping of argument index to device affinity of the flattened
arguments.
Returns:
An ExportOutput object that wraps the compilation and provides
Expand All @@ -266,12 +281,18 @@ def export(
"This is an experimental feature in PyTorch that the IREE Turbine project is still evaluating. Please report issues or experiences."
)

from .compiled_module import ExportTargetDef

TransformedModule: Any
current_decomps = decompositions.current_aot_decompositions()
if isinstance(mdl, torch.export.ExportedProgram):
TransformedModule = CompiledModule.create_from_dict(
"LambdaCompiledModule",
{(function_name or "main"): mdl},
{
(function_name or "main"): ExportTargetDef(
mdl, argument_device_affinities=argument_device_affinities
)
},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
Expand Down Expand Up @@ -311,7 +332,12 @@ def export(

TransformedModule = CompiledModule.create_from_dict(
"LambdaCompiledModule",
{(function_name or "main"): exported_program},
{
(function_name or "main"): ExportTargetDef(
exported_program,
argument_device_affinities=argument_device_affinities,
)
},
export_name=module_name or "module",
options=ModuleBuilderOptions(
import_symbolic_shape_expressions=import_symbolic_shape_expressions,
Expand Down
19 changes: 16 additions & 3 deletions iree/turbine/aot/fx_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
from pathlib import Path
from typing import Any, Optional, Union
from .compiled_module import ExportTargetDef

import functools

Expand All @@ -22,6 +23,12 @@

from .decompositions import current_aot_decompositions

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .exporter import DeviceAffinity
from .compiled_module import ExportTargetDef

# The dynamic_shapes support showed up in the Torch 2.3 timeframe.
_supports_dynamic_shapes = hasattr(torch.export, "Dim")

Expand Down Expand Up @@ -61,7 +68,7 @@ class FxPrograms:
"""

def __init__(self):
self.programs: dict[str, torch.export.ExportedProgram] = {}
self.programs: dict[str, ExportTargetDef] = {}

def save(self, path: Union[str, os.PathLike]) -> int:
"""Saves the set of exported programs to a descriptor file.
Expand All @@ -86,7 +93,8 @@ def permute_path(name):
count_deduped = 0

# Save each.
for program_name, ep in self.programs.items():
for program_name, export_def in self.programs.items():
ep: torch.export.ExportedProgram = export_def.target
# First validate the ep with normal rules, which we will then
# disable since we are violating the spec.
ep._validate()
Expand Down Expand Up @@ -169,6 +177,7 @@ def export_program(
dynamic_shapes=None,
strict: bool = True,
name: Optional[str] = None,
argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
):
if f is None:
return functools.partial(
Expand All @@ -178,6 +187,7 @@ def export_program(
strict=strict,
dynamic_shapes=dynamic_shapes,
name=name,
argument_device_affinities=argument_device_affinities,
)

if name is None:
Expand Down Expand Up @@ -234,7 +244,10 @@ def new_forward(self, *forward_args, **forward_kwargs):

_patch_op_dispatch_for_export()
program = program.run_decompositions(current_decomps)
fx_builder.programs[name] = program
fx_builder.programs[name] = ExportTargetDef(
program,
argument_device_affinities=argument_device_affinities,
)
return program


Expand Down
Loading

0 comments on commit 3f8b34a

Please sign in to comment.