diff --git a/examples/aot_mlp/mlp_export_dynamic.py b/examples/aot_mlp/mlp_export_dynamic.py
deleted file mode 100644
index 3bedd7c1..00000000
--- a/examples/aot_mlp/mlp_export_dynamic.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright 2023 Nod Labs, Inc
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# This sample builds a dynamic shape version of the MLP with
-# a dynamic batch dimension. It uses the advanced, low-level
-# API because we don't have dynamic shapes available in the
-# simple API yet.
-
-import torch
-import torch.nn as nn
-
-import iree.turbine.aot as aot
-
-
-class MLP(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.layer0 = nn.Linear(8, 8, bias=True)
-        self.layer1 = nn.Linear(8, 4, bias=True)
-        self.layer2 = nn.Linear(4, 2, bias=True)
-        self.layer3 = nn.Linear(2, 2, bias=True)
-
-    def forward(self, x: torch.Tensor):
-        x = self.layer0(x)
-        x = torch.sigmoid(x)
-        x = self.layer1(x)
-        x = torch.sigmoid(x)
-        x = self.layer2(x)
-        x = torch.sigmoid(x)
-        x = self.layer3(x)
-        return x
-
-
-model = MLP()
-
-
-class CompiledMLP(aot.CompiledModule):
-    params = aot.export_parameters(model)
-
-    def main(self, x=aot.AbstractTensor(None, 97, 8, dtype=torch.float32)):
-        return aot.jittable(model.forward)(
-            x,
-            constraints=[
-                x.dynamic_dim(0),
-            ],
-        )
-
-
-batch = torch.export.Dim("batch")
-exported = aot.export(
-    model,
-    args=(torch.empty([2, 97, 8], dtype=torch.float32),),
-    dynamic_shapes={"x": {0: batch}},
-)
-# Note that dynamic Torch IR is created below.
-exported.print_readable()
-
-
-# TODO: Enable once version roll to ToT torch-mlir with dynamic view
-# op legalization fixes.
-# compiled_binary = exported.compile(save_to=None)
-# def infer():
-#     import numpy as np
-#     import iree.runtime as rt
-
-#     config = rt.Config("local-task")
-#     vmm = rt.load_vm_module(
-#         rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
-#         config,
-#     )
-#     x = np.random.rand(10, 97, 8).astype(np.float32)
-#     y = vmm.main(x)
-#     print(y.to_host())
-# infer()
diff --git a/examples/resnet-18/resnet-18.py b/examples/resnet-18/resnet-18.py
index 2b3fce56..5a6346bf 100644
--- a/examples/resnet-18/resnet-18.py
+++ b/examples/resnet-18/resnet-18.py
@@ -22,11 +22,8 @@ def forward(pixel_values_tensor: torch.Tensor):
 class RN18(CompiledModule):
     params = export_parameters(model)
 
-    def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float32)):
-        # set a constraint for the dynamic number of batches
-        # interestingly enough, it doesn't seem to limit BATCH_SIZE
-        const = [x.dynamic_dim(0) < 16]
-        return jittable(forward)(x, constraints=const)
+    def forward(self, x=AbstractTensor(10, 3, 224, 224, dtype=torch.float32)):
+        return jittable(forward)(x)
 
 
 # build an mlir module with 1-shot exporter
diff --git a/iree/turbine/aot/builtins/jittable.py b/iree/turbine/aot/builtins/jittable.py
index 7a34a01b..23ad0424 100644
--- a/iree/turbine/aot/builtins/jittable.py
+++ b/iree/turbine/aot/builtins/jittable.py
@@ -308,7 +308,7 @@ def flat_wrapped_f(*args):
 
     def _split_py_arg(self, arg) -> Tuple[Value, Any]:
         if isinstance(arg, IrTensor):
-            meta_tensor, _ = arg._to_meta_tensor()
+            meta_tensor = arg._to_meta_tensor()
             return arg.ir_value, meta_tensor
 
         raise TypeError(f"Unsupported argument to jittable: {arg}")
diff --git a/iree/turbine/aot/passes/functorch.py b/iree/turbine/aot/passes/functorch.py
index 06967ecf..c36a761c 100644
--- a/iree/turbine/aot/passes/functorch.py
+++ b/iree/turbine/aot/passes/functorch.py
@@ -11,6 +11,7 @@
     GraphModule,
 )
 from torch.fx.experimental import proxy_tensor
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.utils import _pytree as pytree
 
 
@@ -43,7 +44,7 @@
 def functorch_functionalize(gm_callable: Any, *args) -> GraphModule:
     functionalized_callable = _functionalize_callabale(gm_callable)
     # TODO: There is more of a dance needed if the user has entered with a fake_mode.
-    with proxy_tensor.maybe_disable_fake_tensor_mode():
+    with unset_fake_temporarily():
         new_gm = proxy_tensor.make_fx(
             functionalized_callable,
             decomposition_table={},
diff --git a/iree/turbine/aot/support/procedural/iree_emitter.py b/iree/turbine/aot/support/procedural/iree_emitter.py
index dbfd4ea2..54f89987 100644
--- a/iree/turbine/aot/support/procedural/iree_emitter.py
+++ b/iree/turbine/aot/support/procedural/iree_emitter.py
@@ -199,7 +199,7 @@ def tensor_reshape(
         result_value = flow_d.TensorReshapeOp(
             result_type,
             source.ir_value,
-            source.get_only_dynamic_dim_values(constant_cache=constant_cache),
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             result_dynamic_dims,
         ).result
         result = IrImmediateTensor(result_value, dtype=source.dtype)
@@ -276,7 +276,7 @@ def tensor_slice(
         result_value = flow_d.TensorSliceOp(
             result_type,
             source_value,
-            source.get_only_dynamic_dim_values(constant_cache=constant_cache),
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             start_index_values,
             length_values,
             result_dynamic_dims,
@@ -295,26 +295,19 @@ def tensor_update(
         """Applies an update to a target at start_indices and returns the mutated target."""
         constant_cache: Dict[int, Value] = {}
         target = cast_tensor_value(target)
-        target_dynamic_dims = target.get_only_dynamic_dim_values(
-            constant_cache=constant_cache
-        )
         update = cast_tensor_value(update)
-        update_dynamic_dims = update.get_only_dynamic_dim_values(
-            constant_cache=constant_cache
-        )
         start_index_dim_values = [
             cast_index_value(idx, constant_cache=constant_cache)
             for idx in start_indices
         ]
         result_value = flow_d.TensorUpdateOp(
             target.ir_value,
-            target_dynamic_dims,
+            [],  # forcing empty list for dynamic dims until supported in CompiledModule
             start_index_dim_values,
             update.ir_value,
-            update_dynamic_dims,
+            [],  # forcing empty list for updated dynamic dims until supported in CompiledModule
         ).result
         result = IrImmediateTensor(result_value, target.dtype)
-        result.set_dynamic_dim_values(target_dynamic_dims)
         return result
 
     @emitter
@@ -342,11 +335,8 @@ def tensor_splat(
 
     @emitter
     def tensor_trace(self, key: str, *ts: BuildableTensorType):
-        dynamic_dims = []
-        for t in ts:
-            dynamic_dims.extend(t.get_only_dynamic_dim_values())
         ts = tuple(cast_tensor_value(t).ir_value for t in ts)
-        flow_d.TensorTraceOp(StringAttr.get(key), ts, dynamic_dims)
+        flow_d.TensorTraceOp(StringAttr.get(key), ts, [])
 
 
 # Circular imports to resolve typing.
diff --git a/iree/turbine/aot/support/procedural/primitives.py b/iree/turbine/aot/support/procedural/primitives.py
index ad406c87..db690d3a 100644
--- a/iree/turbine/aot/support/procedural/primitives.py
+++ b/iree/turbine/aot/support/procedural/primitives.py
@@ -20,11 +20,6 @@
 
 import torch
 
-from torch.export import (
-    Constraint,
-    dynamic_dim,
-)
-
 from ....support.ir_imports import (
     F32Type,
     IrType,
@@ -154,67 +149,23 @@ class IrTensor(Intrinsic):
     def __init__(self, ir_type: IrType, dtype: torch.dtype):
         assert isinstance(dtype, torch.dtype)
         ranked_ir_type = RankedTensorType(ir_type)
+        self._shape = ranked_ir_type.shape
         self.ir_type = ranked_ir_type
         self.dtype = dtype
         # We always cache the meta tensor once asked for since it is used
-        # to anchor constraints. The constraints list is the same size as
-        # the rank and has a non-None dynamic_dim constraint for each
-        # dynamic dimension in the type.
+        # for anchoring certain constraints.
         self._meta_tensor: Optional[torch.Tensor] = None
-        self._meta_tensor_constraints: Optional[List[Constraint]] = None
-
-        # Figure dynamic dims.
-        # _dynamic_dims is either Empty if static, or Value/None if dynamic.
-        self._shape = ranked_ir_type.shape
-        self._dynamic_dims: List[Union[EmptyType, Value, None]] = [
-            None if d == ShapedTypeDynamicSizeSentinel else Empty for d in self._shape
-        ]
         # If we computed a dim, then stash it here for later use.
-        self._cached_dim_values: List[Optional[Value]] = [None] * len(
-            self._dynamic_dims
-        )
-
-    def dynamic_dim(self, i: int) -> Constraint:
-        """Access the dynamic_dim constraint for the i'th dimension."""
-        mt, constraints = self._get_meta_tensor_constraints()
-        c = constraints[i]
-        if c is None:
-            raise TypeError(
-                f"Requested dynamic_dim constraint for dimension {i} of {self.ir_type} which is not dynamic"
-            )
-        return c
+        self._cached_dim_values: List[Optional[Value]] = [None] * len(self._shape)
 
     @property
     def rank(self) -> int:
         return len(self._shape)
 
-    @property
-    def dynamic_dim_count(self) -> int:
-        return len(self._dynamic_dims) - self._dynamic_dims.count(Empty)
-
-    def set_dim_value(self, index: int, value: Optional[Value]):
-        """Sets the value of a dynamic dim.
-
-        Raises ValueError if the dimension is not dynamic.
-        """
-        if self._dynamic_dims is Empty:
-            raise ValueError(f"Dimension {index} of {self} is not dynamic")
-        self._dynamic_dims[index] = value
-
     def set_dynamic_dim_values(self, values: Sequence[Value]):
         """Sets all dynamic dim values."""
-        dd = self._dynamic_dims
-        input_index = 0
-        for pos in range(len(dd)):
-            if dd[pos] is Empty:
-                # Static
-                continue
-            assert input_index < len(values), "Mismatched static/dynamic dims"
-            assert isinstance(values[input_index], Value)
-            dd[pos] = values[input_index]
-            input_index += 1
-        assert input_index == len(values), "Mismatched static/dynamic dims"
+        assert len(values) == 0, "Dynamic dims not currently supported"
 
     def get_dim_value(
         self,
@@ -233,81 +184,33 @@ def get_dim_value(
         cached_dim = self._cached_dim_values[index]
         if cached_dim:
             return cached_dim
-        dynamic_dim = self._dynamic_dims[index]
-        if dynamic_dim is Empty or dynamic_dim is None:
-            if resolved_ir_value is None:
-                resolved_ir_value = self.ir_value
-            # Construct a static dimension.
-            # TODO: Add MLIR API support for creating an insertion point after
-            # an operation and use that to set the InsertionPoint to the
-            # earliest point.
-            # See: https://github.com/nod-ai/SHARK-Turbine/issues/133
-            dim_value = build_tensor_dim_value(
-                resolved_ir_value, index, constant_cache=constant_cache
-            )
-            self._cached_dim_values[index] = dim_value
-            return dim_value
-        else:
-            # Dynamic dim is known.
-            return dynamic_dim
-
-    def get_only_dynamic_dim_values(
-        self,
-        *,
-        constant_cache: Optional[Dict[int, Value]] = None,
-        resolved_ir_value: Optional[Value] = None,
-    ) -> List[Value]:
-        """Returns a list of *only* the dynamic dim Values."""
-        values: List[Value] = []
-        for i, sentinel in enumerate(self._dynamic_dims):
-            if sentinel is not Empty:
-                # Cache IR value so we don't materialize for each
-                # dynamic dim.
-                if resolved_ir_value is None:
-                    resolved_ir_value = self.ir_value
-                values.append(
-                    self.get_dim_value(
-                        i,
-                        constant_cache=constant_cache,
-                        resolved_ir_value=resolved_ir_value,
-                    )
-                )
-        return values
+        if resolved_ir_value is None:
+            resolved_ir_value = self.ir_value
+        # Construct a static dimension.
+        # TODO: Add MLIR API support for creating an insertion point after
+        # an operation and use that to set the InsertionPoint to the
+        # earliest point.
+        # See: https://github.com/nod-ai/SHARK-Turbine/issues/133
+        dim_value = build_tensor_dim_value(
+            resolved_ir_value, index, constant_cache=constant_cache
+        )
+        self._cached_dim_values[index] = dim_value
+        return dim_value
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         return NotImplemented
 
-    def _get_meta_tensor_constraints(self) -> tuple[torch.Tensor, list[Constraint]]:
-        if self._meta_tensor is not None and self._meta_tensor_constraints is not None:
-            return self._meta_tensor, self._meta_tensor_constraints
-
-        ir_tensor_type = self.ir_type
-        shape = ir_tensor_type.shape
-        # TODO: We shouldn't need to create a real tensor here, as Dynamo will
-        # immediately convert it to fake. However, it will also set up the shape
-        # environment and asserts that any fake tensor inputs are from its
-        # internal FakeMode. There should be a way but needs more investigation.
-        # TODO: This tensor needs a device that matches the model being exported.
-        # We just create these on the CPU because that is common.
-        # Note that in Dynamo's modeling of dynamic shapes, 0/1 are specialized and
-        # cannot be dynamic, and we must use a >= 2 dimension value to represent
-        # a dynamic quantity. We therefore adjust the shape in this way and
-        # add a dynamic_dim constraint.
-        # See: https://github.com/nod-ai/SHARK-Turbine/issues/134
-        extents = [2 if d < 0 else d for d in shape]
-        mt = self._meta_tensor = torch.empty(extents, dtype=self.dtype)
-        # Generate constraints that are aligned with any dynamic dimensions or None
-        # if static.
-        self._meta_tensor_constraints = constraints = [
-            dynamic_dim(mt, i) if d < 0 else None for i, d in enumerate(shape)
-        ]
-        return mt, constraints
-
-    def _to_meta_tensor(self) -> Tuple[torch.Tensor, List[Constraint]]:
+    def _to_meta_tensor(self) -> torch.Tensor:
         """Converts to a fake Tensor that dynamo can handle."""
-        mt, constraints = self._get_meta_tensor_constraints()
-        return mt, [c for c in constraints if c is not None]
+        if self._meta_tensor is None:
+            ir_tensor_type = self.ir_type
+            shape = ir_tensor_type.shape
+            assert not any(
+                d < 0 for d in shape
+            ), "Unsupported dynamic dims in meta tensor"
+            self._meta_tensor = torch.empty(shape, dtype=self.dtype)
+        return self._meta_tensor
 
 
 class IrImmediateTensor(IrTensor):
diff --git a/iree/turbine/dynamo/passes.py b/iree/turbine/dynamo/passes.py
index 23078a83..4543fc2c 100644
--- a/iree/turbine/dynamo/passes.py
+++ b/iree/turbine/dynamo/passes.py
@@ -2,7 +2,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch._decomp import get_decompositions
 from torch.func import functionalize
-from typing import List, Optional
+from typing import List, Optional, Mapping
 
 from .decompositions import DEFAULT_DECOMPOSITIONS
 
@@ -15,7 +15,7 @@ def apply_decompositions(
     if decompose_ops is None:
         return gm
 
-    decompositions = get_decompositions(decompose_ops)
+    decompositions: Mapping = get_decompositions(decompose_ops)
     gm = make_fx(
         functionalize(gm),
         decomposition_table=decompositions,
diff --git a/iree/turbine/runtime/device.py b/iree/turbine/runtime/device.py
index d34f49a8..e3c717e1 100644
--- a/iree/turbine/runtime/device.py
+++ b/iree/turbine/runtime/device.py
@@ -503,7 +503,7 @@ def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]:
     if device:
         gcn_arch_name = gcn_arch_name
         device.compile_target_flags = device.compile_target_flags + (
-            f"--iree-rocm-target-chip={gcn_arch_name}",
+            f"--iree-hip-target={gcn_arch_name}",
         )
         device._recompute_target_keys()
     return device
diff --git a/pytorch-cpu-requirements.txt b/pytorch-cpu-requirements.txt
index 68f37da8..20aa5da8 100644
--- a/pytorch-cpu-requirements.txt
+++ b/pytorch-cpu-requirements.txt
@@ -1,5 +1,5 @@
 --pre
 --index-url https://download.pytorch.org/whl/test/cpu
-torch>=2.3.0, <2.5.0
+torch>=2.3.0
 torchaudio
 torchvision
diff --git a/pytorch-rocm-requirements.txt b/pytorch-rocm-requirements.txt
index 3bdde67a..f8e262b8 100644
--- a/pytorch-rocm-requirements.txt
+++ b/pytorch-rocm-requirements.txt
@@ -1,5 +1,5 @@
 --pre
 --index-url https://download.pytorch.org/whl/rocm6.0
-torch>=2.3.0, <2.5.0
+torch>=2.3.0
 torchaudio
 torchvision
diff --git a/tests/aot/functionalize_test.py b/tests/aot/functionalize_test.py
index 2a2ea309..6938a87c 100644
--- a/tests/aot/functionalize_test.py
+++ b/tests/aot/functionalize_test.py
@@ -6,6 +6,7 @@
 
 import logging
 import unittest
+import pytest
 
 import torch
 
@@ -34,6 +35,9 @@ def compute():
         print(module_str)
         self.assertNotIn("add_", module_str)
 
+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testDynamicDims(self):
         class ProcArgsModule(CompiledModule):
             def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)):
diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py
index 251c8f12..c8007bf3 100644
--- a/tests/aot/iree_procedural_test.py
+++ b/tests/aot/iree_procedural_test.py
@@ -6,6 +6,7 @@
 
 import logging
 import unittest
+import pytest
 
 import torch
 
@@ -44,59 +45,55 @@ def foobar(self, a=AbstractTensor(None, 3)):
 
     def testTensorEmpty(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex):
-                empty = IREE.tensor_empty(x, 16)
+            def foobar(self):
+                empty = IREE.tensor_empty(1, 16)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0
 
         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("%0 = flow.tensor.empty : tensor<?x16xf32>{%arg0}", module_str)
-        # NOTE: We are testing below that the dynamic dimension is associated
-        # and used from the input vs being recalculated.
-        self.assertIn("return %0, %arg0 : tensor<?x16xf32>, index", module_str)
+        self.assertIn("%0 = flow.tensor.empty : tensor<1x16xf32>", module_str)
+        self.assertIn("return %0, %dim : tensor<1x16xf32>, index", module_str)
 
     def testTensorSplat(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractF32):
-                empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.float32)
+            def foobar(self, y=AbstractF32):
+                empty = IREE.tensor_splat(2, 34, value=y, dtype=torch.float32)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0
 
         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "%0 = flow.tensor.splat %arg1 : tensor<?x34xf32>{%arg0}", module_str
-        )
+        self.assertIn("%0 = flow.tensor.splat %arg0 : tensor<2x34xf32>", module_str)
         # NOTE: We are testing below that the dynamic dimension is associated
         # and used from the input vs being recalculated.
-        self.assertIn("return %0, %arg0 : tensor<?x34xf32>, index", module_str)
+        self.assertIn("return %0, %dim : tensor<2x34xf32>, index", module_str)
 
     def testTensorSplatCasting(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractIndex):
-                empty = IREE.tensor_splat(x, 34, value=y, dtype=torch.int32)
+            def foobar(self, y=AbstractIndex):
+                empty = IREE.tensor_splat(8, 34, value=y, dtype=torch.int32)
                 dim0 = IREE.tensor_dim(empty, 0)
                 return empty, dim0
 
         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("%0 = arith.index_castui %arg1 : index to i32", module_str)
-        self.assertIn("%1 = flow.tensor.splat %0 : tensor<?x34xi32>{%arg0}", module_str)
+        self.assertIn("%0 = arith.index_castui %arg0 : index to i32", module_str)
+        self.assertIn("%1 = flow.tensor.splat %0 : tensor<8x34xi32>", module_str)
 
     def testTensorTrace(self):
         class BasicModule(CompiledModule):
-            def foobar(self, x=AbstractTensor(None), y=AbstractTensor(3)):
+            def foobar(self, x=AbstractTensor(5), y=AbstractTensor(3)):
                 IREE.tensor_trace("DEBUG", x, y)
 
         inst = BasicModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            'flow.tensor.trace "DEBUG" = [%arg0 : tensor<?xf32>{%dim}, %arg1 : tensor<3xf32>]',
+            'flow.tensor.trace "DEBUG" = [%arg0 : tensor<5xf32>, %arg1 : tensor<3xf32>]',
             module_str,
         )
 
@@ -128,6 +125,9 @@ def foobar(self, x=AbstractTensor(3, 4)):
             module_str,
         )
 
+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorSliceDynamicIndex(self):
         class SliceDynamicIndex(CompiledModule):
             def foobar(self, x=AbstractIndex):
@@ -142,6 +142,9 @@ def foobar(self, x=AbstractIndex):
             module_str,
         )
 
+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorSliceDynamicLength(self):
         class SliceDynamicIndex(CompiledModule):
             def foobar(self, x=AbstractIndex, y=AbstractIndex):
@@ -175,6 +178,9 @@ def foobar(
             module_str,
         )
 
+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testTensorUpdateDynamic(self):
         class UpdateDynamic(CompiledModule):
             def foobar(
@@ -199,16 +205,16 @@ def foobar(
 
     def testTensorReshape(self):
         class ReshapeModule(CompiledModule):
-            def foobar(self, x=AbstractIndex, y=AbstractIndex):
-                empty = IREE.tensor_empty(x, 16)
-                reshaped = IREE.tensor_reshape(empty, 1, y, y)
+            def foobar(self):
+                empty = IREE.tensor_empty(4, 16)
+                reshaped = IREE.tensor_reshape(empty, 1, 2, 2)
                 return reshaped
 
         inst = ReshapeModule(context=Context(), import_to=None)
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            "flow.tensor.reshape %0 : tensor<?x16xf32>{%arg0} -> tensor<1x?x?xf32>{%arg1, %arg1}",
+            "flow.tensor.reshape %0 : tensor<4x16xf32> -> tensor<1x2x2xf32>",
             module_str,
         )
 
diff --git a/tests/aot/jittable_test.py b/tests/aot/jittable_test.py
index d19988bc..c1101aec 100644
--- a/tests/aot/jittable_test.py
+++ b/tests/aot/jittable_test.py
@@ -6,6 +6,7 @@
 
 import logging
 import unittest
+import pytest
 
 import torch
 
@@ -72,6 +73,9 @@ def compute(*, a, b):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
 
+    @pytest.mark.xfail(
+        reason="CompiledModule dynamic dims no longer supported in latest torch versions"
+    )
     def testDynamicDims(self):
         class DynamicDimsModule(CompiledModule):
             def dynamic_dim(self, a=AbstractTensor(None, 2), b=AbstractTensor(None, 1)):
@@ -108,6 +112,7 @@ def compute(a, b):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
 
+    @pytest.mark.xfail(reason="CompiledModule dynamic dims no longer supported")
     def testIrImmediateTensorAsInputToDynamicDims(self):
         class ProcArgsModule(CompiledModule):
             def dynamic_dim(self, x=AbstractIndex):
diff --git a/tests/dynamo/importer_dynamic_test.py b/tests/dynamo/importer_dynamic_test.py
index 682aa140..ba773344 100644
--- a/tests/dynamo/importer_dynamic_test.py
+++ b/tests/dynamo/importer_dynamic_test.py
@@ -10,7 +10,6 @@
 
 import torch
 import torch._dynamo as dynamo
-from torch._export import dynamic_dim
 
 # from torch._export.constraints import constrain_as_size, constrain_as_value
 from iree.compiler.extras.fx_importer import FxImporter
diff --git a/tests/dynamo/llama_test.py b/tests/dynamo/llama_test.py
index 65750277..99be0f53 100644
--- a/tests/dynamo/llama_test.py
+++ b/tests/dynamo/llama_test.py
@@ -315,7 +315,6 @@ def main():
     opt(example_tokens, start_pos)
 
 
-@pytest.mark.xfail(reason="https://github.com/nod-ai/SHARK-Turbine/issues/221")
 class ModelTests(unittest.TestCase):
     def testLLama(self):
         main()
diff --git a/tests/examples/aot_mlp_test.py b/tests/examples/aot_mlp_test.py
index c4266a4a..d55e0b9c 100644
--- a/tests/examples/aot_mlp_test.py
+++ b/tests/examples/aot_mlp_test.py
@@ -22,9 +22,6 @@ class AOTMLPTest(unittest.TestCase):
     def testMLPExportSimple(self):
         _run("examples/aot_mlp/mlp_export_simple.py")
 
-    def testMLPExportSimple(self):
-        _run("examples/aot_mlp/mlp_export_dynamic.py")
-
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)