Add device affinities for arguments in AOT #231
Force-pushed from ce0da33 to 00d8857.
@@ -107,12 +112,27 @@ def __call__(self, *args, **kwargs):
        return self.py_value(*args, **kwargs)


class ExportTargetDef:
I am not sure if I should drop the ExportTargetDef and use half-initialized ExportProcDef and ExportedProgramDef directly.
What is the benefit to having this separate class?
At the point where we need a structure to store this data, we are not yet ready to construct ExportProcDef or ExportedProgramDef. If we used them directly, they would need a more complicated multi-step initialization.
This PR removes the need for #220, which kind of abuses the function generation mechanism.
Thanks. I will review this in a few minutes.
iree/turbine/aot/compiled_module.py
Outdated
        argument_device_affinities: dict[int, "DeviceAffinity"] | None = None,
    ):
        self.target = target
        self.argument_device_affinities = argument_device_affinities
Make the name more succinct. Instead of argument_device_affinities, just arg_device. It should focus on where the argument is placed.
Done.
iree/turbine/aot/compiled_module.py
Outdated
@@ -207,6 +240,21 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
        )  # type: ignore

    def def_attribute(self, key, value):
        if isinstance(value, ExportTargetDef):
            if isinstance(value.target, ExportedProgram):
Swap the if/else and remove the else: part. You can do
    if not isinstance(value.target, ExportedProgram):
        # We expect an exported function.
        assert callable(value.target) and inspect.isfunction(value.target)
        return self.def_export_proc(
            key, value.target, value.argument_device_affinities
        )
And given the return, it would exit right away anyway.
Done.
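For readers following the thread, here is a minimal sketch of the early-return shape suggested above. It is not the code that was merged. `FakeExportedProgram` and `Registry` are hypothetical stand-ins for the real iree.turbine and torch.export types; only the names `ExportTargetDef`, `def_attribute`, and `def_export_proc` come from the diff, and their bodies here are assumptions.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Optional


class FakeExportedProgram:
    """Hypothetical stand-in for torch.export.ExportedProgram."""


@dataclass
class ExportTargetDef:
    # Assumed shape: a target plus optional per-argument affinities.
    target: Any
    argument_device_affinities: Optional[dict] = None


class Registry:
    """Hypothetical container; not the real CompiledModule machinery."""

    def def_export_proc(self, key, fn, affinities):
        return ("proc", key, fn.__name__, affinities)

    def def_attribute(self, key, value):
        if isinstance(value, ExportTargetDef):
            if not isinstance(value.target, FakeExportedProgram):
                # Guard clause: plain functions are handled and returned here.
                assert inspect.isfunction(value.target)
                return self.def_export_proc(
                    key, value.target, value.argument_device_affinities
                )
            # No else branch needed: only exported programs reach this point.
            return ("program", key, value.argument_device_affinities)
        return ("other", key)


def forward(x):
    return x


registry = Registry()
proc_result = registry.def_attribute("f", ExportTargetDef(forward, {0: "dev0"}))
program_result = registry.def_attribute("g", ExportTargetDef(FakeExportedProgram()))
```

Because the function case returns immediately, the exported-program case can sit at the outer indentation level without an else.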
iree/turbine/aot/compiled_module.py
Outdated
@@ -633,6 +708,7 @@ def __new__(
                ep_def.exported_program,
                symbol_name=ep_def.export_name or "main",
                symbol_visibility=None if ep_def.public else "private",
                argument_device_affinities=ep_def.argument_device_affinities or {},
We should just support whatever the default is for ep_def.argument_device_affinities rather than using an or fallback.
Done.
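One way to read the suggestion above, sketched under assumptions: the call site forwards the stored value unchanged, and the callee normalizes its own default in a single place. `import_program` is a hypothetical function invented for this illustration and is not the real importer API; the string affinities are placeholders as well.

```python
from typing import Optional


def import_program(
    symbol_name: str,
    argument_device_affinities: Optional[dict[int, str]] = None,
) -> dict:
    """Hypothetical callee that owns the handling of its own default."""
    # None is normalized in exactly one place, inside the callee.
    if argument_device_affinities is None:
        argument_device_affinities = {}
    return {"symbol": symbol_name, "affinities": argument_device_affinities}


# The call site forwards the stored value as-is, with no `or {}` fallback.
stored_affinities = None
result_default = import_program("main", argument_device_affinities=stored_affinities)
result_explicit = import_program("main", argument_device_affinities={0: "device 0"})
```

With this shape, the default lives in one signature instead of being repeated at each call site.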
tests/aot/fx_programs_test.py
Outdated
Do not wipe out the existing test. Create a separate fx_programs_test_device.py that tests the device affinity work. We should try to guarantee the old patch works for as long as possible.
I moved the new test to a new file as you suggested.
iree/turbine/aot/compiled_module.py
Outdated
@@ -207,6 +240,21 @@ def globals_defs(self) -> Generator[Tuple[str, GlobalsDef], None, None]:
        )  # type: ignore

    def def_attribute(self, key, value):
        if isinstance(value, ExportTargetDef):
            if isinstance(value.target, ExportedProgram):
                value = ExportedProgramDef(
Shouldn't this be returned as well? It sets the value but then falls through. I see it is handled further down; if so, we should handle it right before the follow-up case.
I moved this case down before the handling of ExportedProgramDef.
iree/turbine/aot/exporter.py
Outdated
class DeviceAffinity:
    """This is used to provide device affinities to exported function arguments."""

    def __init__(self, moniker: str):
Make the moniker an int. All cases around just specify it that way anyway.
I changed it to ordinal: int.
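A minimal sketch of what the class might look like after this change, assuming nothing beyond the `ordinal: int` constructor argument mentioned in the thread. The `__eq__`, `__hash__`, and `__repr__` methods are illustrative additions and may not match the merged class.

```python
class DeviceAffinity:
    """Sketch: a device affinity identified by an integer ordinal."""

    def __init__(self, ordinal: int):
        self.ordinal = ordinal

    # The methods below are assumptions added for illustration only.
    def __eq__(self, other) -> bool:
        return isinstance(other, DeviceAffinity) and self.ordinal == other.ordinal

    def __hash__(self) -> int:
        return hash(self.ordinal)

    def __repr__(self) -> str:
        return f"DeviceAffinity({self.ordinal})"


# Affinities keyed by argument index, matching the dict[int, DeviceAffinity]
# annotation seen earlier in the diff.
arg_affinities = {0: DeviceAffinity(0), 1: DeviceAffinity(1)}
```

Keying by argument index leaves arguments without an entry unannotated.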
iree/turbine/aot/compiled_module.py
Outdated
@@ -568,6 +627,22 @@ def save_mlir(inst: "CompiledModule", path: Union[Path, str]):

    jittable = staticmethod(builtins.jittable)

    @staticmethod
    def annotate(
Name this something better than annotate. Even signature_info would be sufficient to explain that it's adding additional information to the function signature.
Done.
@@ -61,7 +68,7 @@ class FxPrograms:
     """

     def __init__(self):
-        self.programs: dict[str, torch.export.ExportedProgram] = {}
+        self.programs: dict[str, ExportTargetDef] = {}
@rsuderman This is something I forgot to bring to attention, but I am not sure if self.programs should be a part of the interface. This changes it and I also changed one test that specifically used it.
Thanks. A minor comment on code organization.
iree/turbine/aot/exporter.py
Outdated
@@ -49,6 +50,21 @@
SaveableTarget = Union[str, Path, None, Output]


class DeviceAffinity:
Put this in tensor_traits.py. Then you won't need to work around weird circular references and can just cleanly import it from anywhere needed.
I moved it there. Why did I not see this?
We don't have support for providing device affinities for function arguments, which need to end up as MLIR function argument attributes. This change adds a class DeviceAffinity and provides the ability to supply affinities when exporting Torch functions/modules or when tracing in IREE-Turbine itself.
Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
Force-pushed from 9baff6b to a7e4bc4.
I squashed and rebased to prepare for merging.