Remove deprecated constraints and dynamic dim support with test fixes (#233)

The backing APIs have been deprecated for a year and were removed from
PyTorch 2.6 nightlies, causing import errors.

torch.export and torch-mlir have long since provided a supported way to export dynamic dims.
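
For context, a minimal sketch of the supported torch.export path for dynamic dims (not part of this commit; the toy module and names are illustrative):

import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


# Dim replaces the removed Constraint/dynamic_dim API: mark the batch
# dimension of "x" as dynamic via dynamic_shapes.
batch = Dim("batch")
example_input = torch.randn(4, 8)
exported = export(Toy(), (example_input,), dynamic_shapes={"x": {0: batch}})
print(exported)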

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
Signed-off-by: Ian <ian.nordeng@amd.com>
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
IanNod and stellaraccident authored Oct 18, 2024
1 parent e66846c commit 97e0517
Showing 16 changed files with 77 additions and 253 deletions.
77 changes: 0 additions & 77 deletions examples/aot_mlp/mlp_export_dynamic.py

This file was deleted.

7 changes: 2 additions & 5 deletions examples/resnet-18/resnet-18.py
@@ -22,11 +22,8 @@ def forward(pixel_values_tensor: torch.Tensor):
class RN18(CompiledModule):
params = export_parameters(model)

def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
# set a constraint for the dynamic number of batches
# interestingly enough, it doesn't seem to limit BATCH_SIZE
const = [x.dynamic_dim(0) < 16]
return jittable(forward)(x, constraints=const)
def forward(self, x=AbstractTensor(10, 3, 224, 224, dtype=torch.float32)):
return jittable(forward)(x)


# build an mlir module with 1-shot exporter
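The deleted constraint x.dynamic_dim(0) < 16 in this example corresponds to a bounded Dim under the supported torch.export API. A hedged sketch of that mapping, using torchvision's resnet18 as a stand-in for the wrapped model (illustrative only, not part of this change):

import torch
from torch.export import Dim, export
from torchvision.models import resnet18

model = resnet18().eval()

# "dynamic_dim(0) < 16" becomes an explicit upper bound of 15 on the
# batch dimension; the tuple form of dynamic_shapes matches positional args.
batch = Dim("batch", max=15)
example = torch.randn(2, 3, 224, 224)
exported = export(model, (example,), dynamic_shapes=({0: batch},))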
2 changes: 1 addition & 1 deletion iree/turbine/aot/builtins/jittable.py
@@ -308,7 +308,7 @@ def flat_wrapped_f(*args):

def _split_py_arg(self, arg) -> Tuple[Value, Any]:
if isinstance(arg, IrTensor):
meta_tensor, _ = arg._to_meta_tensor()
meta_tensor = arg._to_meta_tensor()
return arg.ir_value, meta_tensor

raise TypeError(f"Unsupported argument to jittable: {arg}")
3 changes: 2 additions & 1 deletion iree/turbine/aot/passes/functorch.py
@@ -11,6 +11,7 @@
GraphModule,
)
from torch.fx.experimental import proxy_tensor
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.utils import _pytree as pytree


@@ -43,7 +44,7 @@
def functorch_functionalize(gm_callable: Any, *args) -> GraphModule:
functionalized_callable = _functionalize_callabale(gm_callable)
# TODO: There is more of a dance needed if the user has entered with a fake_mode.
with proxy_tensor.maybe_disable_fake_tensor_mode():
with unset_fake_temporarily():
new_gm = proxy_tensor.make_fx(
functionalized_callable,
decomposition_table={},
20 changes: 5 additions & 15 deletions iree/turbine/aot/support/procedural/iree_emitter.py
@@ -199,7 +199,7 @@ def tensor_reshape(
result_value = flow_d.TensorReshapeOp(
result_type,
source.ir_value,
source.get_only_dynamic_dim_values(constant_cache=constant_cache),
[], # forcing empty list for dynamic dims until supported in CompiledModule
result_dynamic_dims,
).result
result = IrImmediateTensor(result_value, dtype=source.dtype)
@@ -276,7 +276,7 @@ def tensor_slice(
result_value = flow_d.TensorSliceOp(
result_type,
source_value,
source.get_only_dynamic_dim_values(constant_cache=constant_cache),
[], # forcing empty list for dynamic dims until supported in CompiledModule
start_index_values,
length_values,
result_dynamic_dims,
@@ -295,26 +295,19 @@ def tensor_update(
"""Applies an update to a target at start_indices and returns the mutated target."""
constant_cache: Dict[int, Value] = {}
target = cast_tensor_value(target)
target_dynamic_dims = target.get_only_dynamic_dim_values(
constant_cache=constant_cache
)
update = cast_tensor_value(update)
update_dynamic_dims = update.get_only_dynamic_dim_values(
constant_cache=constant_cache
)
start_index_dim_values = [
cast_index_value(idx, constant_cache=constant_cache)
for idx in start_indices
]
result_value = flow_d.TensorUpdateOp(
target.ir_value,
target_dynamic_dims,
[], # forcing empty list for dynamic dims until supported in CompiledModule
start_index_dim_values,
update.ir_value,
update_dynamic_dims,
[], # forcing empty list for updated dynamic dims until supported in CompiledModule
).result
result = IrImmediateTensor(result_value, target.dtype)
result.set_dynamic_dim_values(target_dynamic_dims)
return result

@emitter
@@ -342,11 +335,8 @@ def tensor_splat(

@emitter
def tensor_trace(self, key: str, *ts: BuildableTensorType):
dynamic_dims = []
for t in ts:
dynamic_dims.extend(t.get_only_dynamic_dim_values())
ts = tuple(cast_tensor_value(t).ir_value for t in ts)
flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)
flow_d.TensorTraceOp(StringAttr.get(key), ts, [])


# Circular imports to resolve typing.
147 changes: 25 additions & 122 deletions iree/turbine/aot/support/procedural/primitives.py
@@ -20,11 +20,6 @@

import torch

from torch.export import (
Constraint,
dynamic_dim,
)

from ....support.ir_imports import (
F32Type,
IrType,
@@ -154,67 +149,23 @@ class IrTensor(Intrinsic):
def __init__(self, ir_type: IrType, dtype: torch.dtype):
assert isinstance(dtype, torch.dtype)
ranked_ir_type = RankedTensorType(ir_type)
self._shape = ranked_ir_type.shape
self.ir_type = ranked_ir_type
self.dtype = dtype
# We always cache the meta tensor once asked for since it is used
# to anchor constraints. The constraints list is the same size as
# the rank and has a non-None dynamic_dim constraint for each
# dynamic dimension in the type.
# for anchoring certain constraints.
self._meta_tensor: Optional[torch.Tensor] = None
self._meta_tensor_constraints: Optional[List[Constraint]] = None

# Figure dynamic dims.
# _dynamic_dims is either Empty if static, or Value/None if dynamic.
self._shape = ranked_ir_type.shape
self._dynamic_dims: List[Union[EmptyType, Value, None]] = [
None if d == ShapedTypeDynamicSizeSentinel else Empty for d in self._shape
]

# If we computed a dim, then stash it here for later use.
self._cached_dim_values: List[Optional[Value]] = [None] * len(
self._dynamic_dims
)

def dynamic_dim(self, i: int) -> Constraint:
"""Access the dynamic_dim constraint for the i'th dimension."""
mt, constraints = self._get_meta_tensor_constraints()
c = constraints[i]
if c is None:
raise TypeError(
f"Requested dynamic_dim constraint for dimension {i} of {self.ir_type} which is not dynamic"
)
return c
self._cached_dim_values: List[Optional[Value]] = [None] * len(self._shape)

@property
def rank(self) -> int:
return len(self._shape)

@property
def dynamic_dim_count(self) -> int:
return len(self._dynamic_dims) - self._dynamic_dims.count(Empty)

def set_dim_value(self, index: int, value: Optional[Value]):
"""Sets the value of a dynamic dim.
Raises ValueError if the dimension is not dynamic.
"""
if self._dynamic_dims is Empty:
raise ValueError(f"Dimension {index} of {self} is not dynamic")
self._dynamic_dims[index] = value

def set_dynamic_dim_values(self, values: Sequence[Value]):
"""Sets all dynamic dim values."""
dd = self._dynamic_dims
input_index = 0
for pos in range(len(dd)):
if dd[pos] is Empty:
# Static
continue
assert input_index < len(values), "Mismatched static/dynamic dims"
assert isinstance(values[input_index], Value)
dd[pos] = values[input_index]
input_index += 1
assert input_index == len(values), "Mismatched static/dynamic dims"
assert len(values) == 0, "Dynamic dims not currently supported"

def get_dim_value(
self,
@@ -233,81 +184,33 @@ def get_dim_value(
cached_dim = self._cached_dim_values[index]
if cached_dim:
return cached_dim
dynamic_dim = self._dynamic_dims[index]
if dynamic_dim is Empty or dynamic_dim is None:
if resolved_ir_value is None:
resolved_ir_value = self.ir_value
# Construct a static dimension.
# TODO: Add MLIR API support for creating an insertion point after
# an operation and use that to set the InsertionPoint to the
# earliest point.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/133
dim_value = build_tensor_dim_value(
resolved_ir_value, index, constant_cache=constant_cache
)
self._cached_dim_values[index] = dim_value
return dim_value
else:
# Dynamic dim is known.
return dynamic_dim

def get_only_dynamic_dim_values(
self,
*,
constant_cache: Optional[Dict[int, Value]] = None,
resolved_ir_value: Optional[Value] = None,
) -> List[Value]:
"""Returns a list of *only* the dynamic dim Values."""
values: List[Value] = []
for i, sentinel in enumerate(self._dynamic_dims):
if sentinel is not Empty:
# Cache IR value so we don't materialize for each
# dynamic dim.
if resolved_ir_value is None:
resolved_ir_value = self.ir_value
values.append(
self.get_dim_value(
i,
constant_cache=constant_cache,
resolved_ir_value=resolved_ir_value,
)
)
return values
if resolved_ir_value is None:
resolved_ir_value = self.ir_value
# Construct a static dimension.
# TODO: Add MLIR API support for creating an insertion point after
# an operation and use that to set the InsertionPoint to the
# earliest point.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/133
dim_value = build_tensor_dim_value(
resolved_ir_value, index, constant_cache=constant_cache
)
self._cached_dim_values[index] = dim_value
return dim_value

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return NotImplemented

def _get_meta_tensor_constraints(self) -> tuple[torch.Tensor, list[Constraint]]:
if self._meta_tensor is not None and self._meta_tensor_constraints is not None:
return self._meta_tensor, self._meta_tensor_constraints

ir_tensor_type = self.ir_type
shape = ir_tensor_type.shape
# TODO: We shouldn't need to create a real tensor here, as Dynamo will
# immediately convert it to fake. However, it will also set up the shape
# environment and asserts that any fake tensor inputs are from its
# internal FakeMode. There should be a way but needs more investigation.
# TODO: This tensor needs a device that matches the model being exported.
# We just create these on the CPU because that is common.
# Note that in Dynamo's modeling of dynamic shapes, 0/1 are specialized and
# cannot be dynamic, and we must use a >= 2 dimension value to represent
# a dynamic quantity. We therefore adjust the shape in this way and
# add a dynamic_dim constraint.
# See: https://github.com/nod-ai/SHARK-Turbine/issues/134
extents = [2 if d < 0 else d for d in shape]
mt = self._meta_tensor = torch.empty(extents, dtype=self.dtype)
# Generate constraints that are aligned with any dynamic dimensions or None
# if static.
self._meta_tensor_constraints = constraints = [
dynamic_dim(mt, i) if d < 0 else None for i, d in enumerate(shape)
]
return mt, constraints

def _to_meta_tensor(self) -> Tuple[torch.Tensor, List[Constraint]]:
def _to_meta_tensor(self) -> torch.Tensor:
"""Converts to a fake Tensor that dynamo can handle."""
mt, constraints = self._get_meta_tensor_constraints()
return mt, [c for c in constraints if c is not None]
if self._meta_tensor is None:
ir_tensor_type = self.ir_type
shape = ir_tensor_type.shape
assert not any(
d < 0 for d in shape
), "Unsupported dynamic dims in meta tensor"
self._meta_tensor = torch.empty(shape, dtype=self.dtype)
return self._meta_tensor


class IrImmediateTensor(IrTensor):
4 changes: 2 additions & 2 deletions iree/turbine/dynamo/passes.py
@@ -2,7 +2,7 @@
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from torch.func import functionalize
from typing import List, Optional
from typing import List, Optional, Mapping

from .decompositions import DEFAULT_DECOMPOSITIONS

@@ -15,7 +15,7 @@ def apply_decompositions(
if decompose_ops is None:
return gm

decompositions = get_decompositions(decompose_ops)
decompositions: Mapping = get_decompositions(decompose_ops)
gm = make_fx(
functionalize(gm),
decomposition_table=decompositions,
2 changes: 1 addition & 1 deletion iree/turbine/runtime/device.py
@@ -503,7 +503,7 @@ def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]:
if device:
gcn_arch_name = gcn_arch_name
device.compile_target_flags = device.compile_target_flags + (
f"--iree-rocm-target-chip={gcn_arch_name}",
f"--iree-hip-target={gcn_arch_name}",
)
device._recompute_target_keys()
return device
2 changes: 1 addition & 1 deletion pytorch-cpu-requirements.txt
@@ -1,5 +1,5 @@
--pre
--index-url https://download.pytorch.org/whl/test/cpu
torch>=2.3.0, <2.5.0
torch>=2.3.0
torchaudio
torchvision