From 032739849c8a3e235973e68c555cb0468e4dcc8d Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Mon, 23 Sep 2024 22:13:44 -0500 Subject: [PATCH 01/28] Reorder iter args to match ordering of init args and outputs (#161) This PR modifies the insertion point for iter args to ensure that the iter args are in the same order as the init args and outputs. This simplifies the mapping between init args, iter args and outputs. Signed-off-by: Harsh Menon --- lit_tests/kernel/wave/barriers.py | 4 ++-- lit_tests/kernel/wave/expansion.py | 12 +++++----- .../kernel/wave/index_sequence_analysis.py | 24 +++++++++---------- .../kernel/wave/minimize_global_loads.py | 8 +++---- shark_turbine/kernel/ops/wave_ops.py | 5 +++- shark_turbine/kernel/wave/expansion.py | 16 ++++++++++++- 6 files changed, 43 insertions(+), 26 deletions(-) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 1b446dc0..8f8b4a6f 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -207,9 +207,9 @@ def test_gemm(): # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 # CHECK-NEXT: %read_0_0_1 diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index a20965f3..f1bf0ede 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -265,9 +265,9 @@ def test_gemm(): # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_0_0_0 @@ -305,13 +305,13 @@ def test_gemm(): # CHECK-SAME: (%read_0_0_0, %read_0_1_0, %acc_0_1_0) # CHECK-NEXT: %mma_0_1_1 # CHECK-SAME: (%read_0_0_1, %read_0_1_1, %mma_0_1_0) - # CHECK-NEXT: return [mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1] + # CHECK-NEXT: return [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1] # Custom format: # CHECK-NEXT: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) : 4 : 1}) # CHECK-NEXT: read(memory=a, elements_per_thread=4, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) @@ -346,7 +346,7 @@ def test_gemm(): # CHECK-NEXT: mma(lhs=read_0_0_1 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: rhs=read_0_1_1 (index = {N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K: ARGK*BLOCK_K + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) # CHECK-SAME: acc=mma_0_1_0 (index = {M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})) - # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_1_1_1, mma_1_0_1, mma_0_1_1],)) + # CHECK-NEXT: output(return_vals=([mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1],)) # CHECK-NEXT: ----- diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index d49ee3b2..4cdb997b 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -210,16 +210,16 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_4, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_5, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_6, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_7, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_1_0_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_8, memory=c, elements_per_thread=1, # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 32}) @@ -234,22 +234,22 @@ def test_gemm(): # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 32}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[0], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_12, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 16, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16), N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[1], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_13, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 17, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 1, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[2], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_14, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 18, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 2, N: 64*$WG1 + Mod($T0, 16) + 48}) # CHECK-NEXT: extract_slice(register_=getresult_0_1_0, offset=[3], size=[1], stride=[1]) # CHECK-NEXT: write(register_=extract_slice_15, memory=c, elements_per_thread=1, - # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 19, N: 64*$WG1 + Mod($T0, 16) + 48}) + # CHECK-SAME: index={M: 64*$WG0 + 4*floor((Mod($T0, 64))/16) + 3, N: 64*$WG1 + Mod($T0, 16) + 48}) # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -303,9 +303,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 310e9ef4..b085d8b1 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -156,9 +156,9 @@ def test_gemm(): # Reduction subgraph: # CHECK: %acc_0_0_0 - # CHECK-NEXT: %acc_1_1_0 - # CHECK-NEXT: %acc_1_0_0 # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 # CHECK-NEXT: %a # CHECK-NEXT: %read_4 # CHECK-SAME: (%a, 8, None, None) @@ -215,9 +215,9 @@ def test_gemm(): # Reduction subgraph (custom format): # CHECK: placeholder(_name=acc_0_0_0 - # CHECK-NEXT: placeholder(_name=acc_1_1_0 - # CHECK-NEXT: placeholder(_name=acc_1_0_0 # CHECK-NEXT: placeholder(_name=acc_0_1_0 + # CHECK-NEXT: placeholder(_name=acc_1_0_0 + # CHECK-NEXT: placeholder(_name=acc_1_1_0 # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: read(memory=a, elements_per_thread=8, # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod(16*$T1 + 32*$T2 + floor($T0/8), 64), K: ARGK*BLOCK_K + 8*(Mod($T0, 8)) : 8 : 1}) diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 0298a065..3a2d3d3b 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -365,12 +365,15 @@ def copy( new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None, arg_transform: Optional[Callable[[Any], Any]] = lambda x: x, + anchor: Optional[fx.Node] = None, ) -> Self: """Returns a duplicate of this node.""" graph = new_graph if new_graph is None: graph = self.graph - graph.inserting_after(self.fx_node) + if anchor is None: + anchor = self.fx_node + graph.inserting_after(anchor) new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index b96f778c..53796682 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -329,14 +329,28 @@ def _expand_node( # Filter out the dimensions that are not indexed by the node restricted_dims = filter_and_zero_unselected_dims(dim_query, node.indexing_dims) logger.debug(f"Expanding node: {node} in {restricted_dims}") + + # For iter args, we want to insert + if not hasattr(_expand_node, "last_expanded_iter_arg"): + _expand_node.last_expanded_iter_arg = None + # Clone the node for the new expansion. The original node is reused for the # case of all dimensions being zero. if expansion_needed(restricted_dims, node.indexing_dims): - new_node = node.copy() + new_node = node.copy( + anchor=( + _expand_node.last_expanded_iter_arg + if isinstance(node, IterArg) + else None + ) + ) else: new_node = node logger.debug(f"did not clone node: {node} in {restricted_dims}") + if isinstance(node, IterArg): + _expand_node.last_expanded_iter_arg = new_node.fx_node + new_node.fx_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) node_index_setter(new_node, restricted_dims) From 4e9535111b1da24f06445f933a99cfcd7d5cb9f0 Mon Sep 17 00:00:00 2001 From: Christopher McGirr <7071833+chrsmcgrr@users.noreply.github.com> Date: Tue, 24 Sep 2024 06:07:15 +0200 Subject: [PATCH 02/28] [ExportedProgram] Add mutable attribute to buffer (#123) Fixes iree-org/iree-turbine#85 PR based on the work of @maxbartel Requires changes in torch-mlir: [llvm/torch-mlir/#3688](https://github.com/llvm/torch-mlir/pull/3688) Adds the mutable modifier to a global buffer and lifts said buffer to a global if there is a store-producer node associated with it. Signed-off-by: Christopher McGirr Co-authored-by: Maximilian Bartel --- .../support/procedural/exported_program.py | 15 +++-- tests/aot/globals_test.py | 62 +++++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/shark_turbine/aot/support/procedural/exported_program.py index 331a7345..bbc431ae 100644 --- a/shark_turbine/aot/support/procedural/exported_program.py +++ b/shark_turbine/aot/support/procedural/exported_program.py @@ -234,6 +234,8 @@ def store_produced_value( raise ValueError(f"Cannot store value to unmapped global for: {info}") logger.debug("Resolved global for store %r", mapping) materialized_global: MaterializedGlobal = mapping.value # type: ignore + assert isinstance(materialized_global.global_op, util_d.GlobalOp) + materialized_global.global_op.is_mutable = True converted_value = Operation.create( "torch_c.to_builtin_tensor", results=[materialized_global.ir_type], @@ -251,7 +253,7 @@ def resolve_literal( return None # See if we know about it. - materialized_global = self._lift_tensor_to_global(literal) + materialized_global = self._lift_tensor_to_global(literal, info) if not materialized_global: # If it is unknown, just let the default importer take it on. return None @@ -269,7 +271,7 @@ def resolve_literal( return converted_value def _lift_tensor_to_global( - self, literal: torch.Tensor + self, literal: torch.Tensor, info: InputInfo | None ) -> Optional[MaterializedGlobal]: module_builder = self.module_builder mapping = module_builder.global_ref_tracker.track(literal) @@ -282,7 +284,7 @@ def _lift_tensor_to_global( # Policy check: Should we auto-import? Generally, we keep "small" # tensors as inline as they can be optimized. external_trait = ExternalTensorTrait.get(literal) - if not self._should_lift_tensor_to_global(literal, external_trait): + if not self._should_lift_tensor_to_global(literal, external_trait, info): return None # If it is a tensor we haven't seen yet, materialize it @@ -304,8 +306,13 @@ def _lift_tensor_to_global( return materialized_global def _should_lift_tensor_to_global( - self, literal: torch.Tensor, external_trait: Optional[ExternalTensorTrait] + self, + literal: torch.Tensor, + external_trait: Optional[ExternalTensorTrait], + info: InputInfo | None, ) -> bool: + if info is not None and info.store_producer_node: + return True if external_trait is not None: return True volume = math.prod(literal.shape) diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 26bab1a6..607382fd 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -425,6 +425,68 @@ def testUnsupportedCombinations(self): export_global(AbstractF32, external=True, uninitialized=True) +class SimpleCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + + return cache + + +class ReadWriteReadCache(torch.nn.Module): + def __init__(self, max_size, dtype=torch.float32): + super().__init__() + self.register_buffer("cache", torch.zeros(max_size, dtype=dtype)) + + def forward(self, input_pos, values): + # input_pos: [S], values: [S] + assert input_pos.shape[0] == values.shape[0] + cache_value_0 = self.cache[2].clone() + # Writing the values to the buffer at the specified positions + cache = torch.ops.aten.index_put_(self.cache, [input_pos], values) + cache_value_1 = cache[2].clone() + return cache, cache_value_0, cache_value_1 + + +class BufferTest(unittest.TestCase): + def testMutableBuffer(self): + max_size = 10 + simple_cache = SimpleCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + def testReadWriteReadMutableBuffer(self): + max_size = 10 + simple_cache = ReadWriteReadCache(max_size) + + input_pos = torch.tensor([2, 5, 7]) + values = torch.tensor([1.0, 2.0, 3.0]) + simple_cache(input_pos, values) + exported_fx_graph = torch.export.export(simple_cache, args=(input_pos, values)) + exported_programm = export(exported_fx_graph) + module_str = str(exported_programm.mlir_module) + self.assertIn( + "util.global private mutable @__auto.constant_10_torch.float32", + module_str, + ) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From 65eb532abb2acec0b95b66565b35352579a7d896 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:15:31 -0700 Subject: [PATCH 03/28] [TKW] Add xfail decorator for unaligned shape (#163) --- .github/workflows/ci.yaml | 1 - tests/kernel/wave/wave_e2e_test.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8f938a0e..5146f144 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -60,7 +60,6 @@ jobs: if: "contains(matrix.os, 'mi300') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 - export TEST_PARAMS_PATH=./tests/kernel/wave/test_param.json pytest -n 4 ./tests/kernel/wave/ - name: Run LIT tests diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 4c2f04db..fcabb8d9 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -44,6 +44,15 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: return default_test_shapes +def xfail_unaligned(func): + def wrapper(shape): + if shape[-1] % 2 != 0: + pytest.xfail("Unaligned shape is not expected to work on this test yet.") + func(shape) + + return wrapper + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) def test_copy(shape): @@ -269,13 +278,14 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_tiled_reduce_max")) +@xfail_unaligned def test_tiled_reduce_max(shape): M = tkl.sym.M N = tkl.sym.N wave_size = 64 BLOCK_M = 1 BLOCK_N = tkl.sym.BLOCK_N - ELEMS_PER_THREAD = BLOCK_N / wave_size + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE constraints: list[tkw.Constraint] = [ @@ -322,6 +332,7 @@ def repeat( M: shape[0], N: shape[1], BLOCK_N: min(128, shape[1]), + ELEMS_PER_THREAD: min(128, shape[1]) // wave_size, ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, }, canonicalize=True, From 909411a5e47f41ff70368a1acba2de5cafea9440 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:45:27 -0700 Subject: [PATCH 04/28] [TKW] Fix indexing of Reduction and GetResult to enable post-tile op. (#162) This PR introduces changes to handle elementwise or general arithmetic operations after we did some tiled-loop-reduction ("Reduction") operation. The main problem with the current stack is indexing_dims information for Reduction relies on the user. This would work if it's user/consumer is tkw.write, but in other cases such as BinaryPyOp or UnaryPyOp, it will lack such information. To make matters worst BinaryPyOp/UnaryPyOp depends on it's src/producer for indexing dim, while Reduction op depends on it's dst/consumer for its' indexing dim information. This would ended up causing infinite loop between UnaryPyOp/BinaryPyOp <-> Reduction. This PR fixes the indexing dimension logic Reduction and GetResult (required for expanded Reduction) to be based on it's reduction axis(for Reduction) and it's source/consumer information. --------- Signed-off-by: Stanley Winata --- lit_tests/kernel/wave/codegen.py | 88 ++++++++++++++++++++++++++ shark_turbine/kernel/ops/wave_ops.py | 45 +++++++++---- shark_turbine/kernel/wave/expansion.py | 5 ++ 3 files changed, 125 insertions(+), 13 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 0bba2384..b84cc271 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -756,6 +756,94 @@ def test( # CHECK: arith.addf {{.*}} : vector<1xf16> +# This test is to ensure that the propagation of indexing_dims between reduction and operations +# outside the reduction is working properly. +@run_test +def test_reduction_and_elemwise(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + + @tkw.reduction(N, init_args=[init_max]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + return partial_max + + result = repeat + repeat + tkw.write(result, c, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT]], %[[ACC1:.+]] = %[[INIT]]) -> (vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Reduction + # CHECK: %[[ACC_REDUCE_1:.+]] = arith.maximumf %[[ACC1]], %{{.*}} + + # CHECK: scf.yield %[[ACC_REDUCE_0]], %[[ACC_REDUCE_1]] : vector<1xf16>, vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_0:.+]] = arith.addf %[[TILED]]#0, %[[TILED]]#0 : vector<1xf16> + # CHECK: %[[POST_TILE_ELEMWISE_1:.+]] = arith.addf %[[TILED]]#1, %[[TILED]]#1 : vector<1xf16> + # CHECK: vector.store %[[POST_TILE_ELEMWISE_0:.+]], %{{.*}} + # CHECK: vector.store %[[POST_TILE_ELEMWISE_1:.+]], %{{.*}} + + @run_test def test_tiled_reduce_max(): M = tkl.sym.M diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 3a2d3d3b..ebadf0c4 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -861,12 +861,23 @@ def wrapper(f): return wrapper @property - def indexing_dims(self) -> list[IndexSymbol]: + def indexing_dims(self) -> list[IndexSymbol] | list[list[IndexSymbol]]: expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) + return_node = [ + nested_node + for nested_node in self.graph.subgraphs[self.subgraph_name].nodes + if isinstance(get_custom(nested_node), Output) + ] + assert len(return_node) == 1 + return_vals = get_custom(return_node[0]).return_vals[0] + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for return_val in return_vals: + return_dims = get_custom(return_val).indexing_dims + reduced_dims = [dims for dims in return_dims if dims != self.axis] + expand_dims.append(reduced_dims) + if len(expand_dims) == 1: + expand_dims = expand_dims[0] return expand_dims def iter_args(self, graph: fx.Graph) -> list[fx.Node]: @@ -952,16 +963,24 @@ class GetResult(CustomOp): @property def type(self) -> "Memory": - return get_custom(self.value).type[self.res_idx] + src_type = get_custom(self.value).type + if isinstance(src_type, list): + return src_type[self.res_idx] + else: + return src_type @property - def indexing_dims(self) -> list[IndexSymbol]: - expand_dims: list[IndexSymbol] = [] - for user in self.users: - for indexing_dim in user.indexing_dims: - if indexing_dim not in expand_dims: - expand_dims.append(indexing_dim) - return expand_dims + def indexing_dims(self) -> list[IndexExpr]: + has_multiple_value = lambda x: all(isinstance(el, list) for el in x) + is_valid_indexing_dim = lambda x: isinstance(src_indexing, list) and all( + isinstance(el, IndexExpr) for el in x + ) + src_indexing = get_custom(self.value).indexing_dims + if has_multiple_value(src_indexing): + assert self.res_idx <= len(src_indexing) - 1 + src_indexing = src_indexing[self.res_idx] + assert is_valid_indexing_dim(src_indexing) + return src_indexing @property def index(self) -> dict[IndexSymbol, IndexSequence]: diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index 53796682..2610f968 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -81,6 +81,11 @@ def get_indexed_dims( """ if isinstance(nodeOrDims, CustomOp): nodeOrDims = nodeOrDims.indexing_dims + # Flatten dims for node with multiple values or expanded Reduction. + if all(isinstance(el, Sequence) for el in nodeOrDims): + flattened_dims = list(itertools.chain.from_iterable(nodeOrDims)) + flatten_dims_set = dict.fromkeys(flattened_dims) + nodeOrDims = list(flatten_dims_set) return tuple((key, all_dims[key]) for key in nodeOrDims if key in all_dims) From d37c6a49c112abd9019df1a43c7c4702935fde9c Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Thu, 26 Sep 2024 12:34:45 -0500 Subject: [PATCH 05/28] Get GEMMs working without minimize_global_loads (#167) This PR removes the need for propagating indices using post expansion. The new approach propagates the MMA indices to the MMA dimensions of all tensors (rather than just MMA nodes) and then specializes them depending on whether they lie within the backward slices of the LHS and RHS or forward slices of the ACC. --------- Signed-off-by: Harsh Menon --- lit_tests/kernel/wave/codegen.py | 185 +++++++++--------- lit_tests/kernel/wave/expansion.py | 20 +- .../kernel/wave/index_sequence_analysis.py | 8 +- .../kernel/wave/minimize_global_loads.py | 16 +- shark_turbine/kernel/ops/wave_ops.py | 10 - shark_turbine/kernel/wave/expansion.py | 15 +- .../kernel/wave/index_sequence_analysis.py | 2 +- shark_turbine/kernel/wave/utils.py | 150 +++++++++++++- 8 files changed, 271 insertions(+), 135 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b84cc271..d102f353 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -386,7 +386,7 @@ def mma( print(mma(a, b, c).module_op) # CHECK: func.func @mma(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index @@ -405,60 +405,63 @@ def mma( # CHECK: %[[D1:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index # CHECK: %[[D2:.+]] = arith.muli %[[D1]], %[[C16]] : index # CHECK: %[[D3:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index - # CHECK: %[[D5:.+]] = vector.load %[[D0]][%[[D4]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, - # CHECK-SAME: vector<4xf16> + # CHECK: %[[D4:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[D3]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D2]] : index + # CHECK: %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D8:.+]] = arith.divsi %[[D7]], %[[C16]] : index + # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index + # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: vector.store %[[D5]], %[[ALLOC]][%[[D2]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index + # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D6:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D2]] : index - # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index - # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index - # CHECK: %[[D11:.+]] = vector.load %[[ALLOC]][%[[D7]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, + # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, # CHECK-SAME: strided<[16, 1], offset: ?>> - # CHECK: %[[D13:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D13]] : index - # CHECK: %[[D16:.+]] = vector.load %[[D12]][%[[D15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: + # CHECK: %[[D14:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D15:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D4]], %[[D15]] : index + # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index + # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: vector.store %[[D16]], %[[ALLOC_0]][%[[D13]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index + # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D17:.+]] = arith.addi %[[D6]], %[[D13]] : index - # CHECK: %[[D18:.+]] = vector.load %[[ALLOC_0]][%[[D17]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D19:.+]] = amdgpu.mfma %[[D11]] * %[[D18]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D22:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D22:.+]] = arith.addi %[[D4]], %[[D10]] : index - # CHECK: %[[D23:.+]] = arith.addi %[[D6]], %[[D14]] : index - # CHECK: %[[D24:.+]] = arith.addi %[[D23]], %[[D13]] : index - # CHECK: vector.store %[[D20]], %[[D21]][%[[D22]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D24:.+]] = arith.addi %[[D3]], %[[D2]] : index + # CHECK: %[[D25:.+]] = arith.addi %[[D24]], %[[D9]] : index + # CHECK: vector.store %[[D22]], %[[D23]][%[[D25]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D25:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D26:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D26:.+]] = arith.addi %[[D22]], %[[C1]] : index - # CHECK: vector.store %[[D25]], %[[D21]][%[[D26]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D27:.+]] = arith.addi %[[D25]], %[[C1]] : index + # CHECK: vector.store %[[D26]], %[[D23]][%[[D27]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D27:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D28:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D28:.+]] = arith.addi %[[D22]], %[[C2]] : index - # CHECK: vector.store %[[D27]], %[[D21]][%[[D28]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D29:.+]] = arith.addi %[[D25]], %[[C2]] : index + # CHECK: vector.store %[[D28]], %[[D23]][%[[D29]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D29:.+]] = vector.extract_strided_slice %[[D19]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D30:.+]] = vector.extract_strided_slice %[[D21]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D30:.+]] = arith.addi %[[D22]], %[[C3]] : index - # CHECK: vector.store %[[D29]], %[[D21]][%[[D30]], %[[D24]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D31:.+]] = arith.addi %[[D25]], %[[C3]] : index + # CHECK: vector.store %[[D30]], %[[D23]][%[[D31]], %[[D17]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: return @run_test @@ -515,7 +518,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(gemm(a, b, c).module_op) # CHECK: func.func @gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, - # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index @@ -531,77 +534,81 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> - # CHECK: %[[D22:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D23:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, - # CHECK-SAME: strided<[64, 1], offset: ?>> - # CHECK: %[[D24:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D25:.+]] = arith.muli %[[D24]], %[[C16]] : index - # CHECK: %[[D26:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D25]] : index - # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D25]] : index - # CHECK: %[[D32:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D33:.+]] = arith.divsi %[[D32]], %[[C16]] : index - # CHECK: %[[D34:.+]] = arith.muli %[[D33]], %[[C4]] : index - # CHECK: %[[D36:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D37:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D38:.+]] = arith.addi %[[D37]], %[[D36]] : index - # CHECK: %[[D40:.+]] = arith.addi %[[D30]], %[[D36]] : index - # CHECK: %[[D0:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, + # CHECK-SAME: strided<[64, 1], offset: ?>> + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D4]] : index + # CHECK: %[[D7:.+]] = arith.addi %[[D6]], %[[D3]] : index + # CHECK: %[[D8:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D9:.+]] = arith.divsi %[[D8]], %[[C16]] : index + # CHECK: %[[D10:.+]] = arith.muli %[[D9]], %[[C4]] : index + # CHECK: %[[D11:.+]] = arith.addi %[[D5]], %[[D3]] : index + # CHECK: %[[D12:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D13:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D14:.+]] = arith.addi %[[D5]], %[[D13]] : index + # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D12]] : index + # CHECK: %[[D16:.+]] = arith.addi %[[D5]], %[[D12]] : index + # CHECK: %[[D17:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C4]] step %[[C1]] # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[CST]]) -> (vector<4xf32>) { - # CHECK: %[[D28:.+]] = arith.muli %[[ARG3]], %[[C16]] : index - # CHECK: %[[D29:.+]] = vector.load %[[D22]][%[[D27]], %[[D28]]] : memref<64x64xf16, strided<[64, 1], - # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D25]], %[[C0]]] : memref<32x16xf16, + # CHECK: %[[D39:.+]] = arith.muli %[[ARG3]], %[[C16]] : index + # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index + # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset: + # CHECK-SAME: ?>>, vector<4xf16> + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D35:.+]] = vector.load %[[ALLOC]][%[[D31]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D39:.+]] = vector.load %[[D23]][%[[D38]], %[[D28]]] : memref<128x64xf16, strided<[64, 1], + # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1], # CHECK-SAME: offset: ?>>, vector<4xf16> - # CHECK: vector.store %[[D39]], %[[ALLOC_0]][%[[D36]], %[[C0]]] : memref<32x16xf16, + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D41:.+]] = vector.load %[[ALLOC_0]][%[[D40]], %[[D34]]] : memref<32x16xf16, + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> - # CHECK: %[[D42:.+]] = amdgpu.mfma %[[D35]] * %[[D41]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> - # CHECK: scf.yield %[[D42]] : vector<4xf32> + # CHECK: scf.yield %[[D45]] : vector<4xf32> # CHECK: } - # CHECK: %[[D1:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [0], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, + # CHECK: %[[D19:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, # CHECK-SAME: strided<[128, 1], offset: ?>> - # CHECK: %[[D3:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D4:.+]] = arith.divsi %[[D3]], %[[C16]] : index - # CHECK: %[[D5:.+]] = arith.muli %[[D4]], %[[C4]] : index - # CHECK: %[[D6:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[D6]], %[[C16]] : index - # CHECK: %[[D8:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index - # CHECK: %[[D9:.+]] = arith.addi %[[D8]], %[[D7]] : index - # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[D5]] : index - # CHECK: %[[D11:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index - # CHECK: %[[D12:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index - # CHECK: %[[D13:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index - # CHECK: %[[D14:.+]] = arith.addi %[[D13]], %[[D12]] : index - # CHECK: %[[D15:.+]] = arith.addi %[[D14]], %[[D11]] : index - # CHECK: vector.store %[[D1]], %[[D2]][%[[D10]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D20:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D21:.+]] = arith.divsi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = arith.muli %[[D21]], %[[C4]] : index + # CHECK: %[[D23:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D24:.+]] = arith.muli %[[D23]], %[[C16]] : index + # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index + # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index + # CHECK: %[[D27:.+]] = arith.addi %[[D26]], %[[D22]] : index + # CHECK: %[[D28:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index + # CHECK: %[[D29:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index + # CHECK: %[[D30:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index + # CHECK: %[[D32:.+]] = arith.addi %[[D31]], %[[D28]] : index + # CHECK: vector.store %[[D18]], %[[D19]][%[[D27]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D16:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK: %[[D33:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [1], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D17:.+]] = arith.addi %[[D10]], %[[C1]] : index - # CHECK: vector.store %[[D16]], %[[D2]][%[[D17]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D34:.+]] = arith.addi %[[D27]], %[[C1]] : index + # CHECK: vector.store %[[D33]], %[[D19]][%[[D34]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D18:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK: %[[D35:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [2], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D19:.+]] = arith.addi %[[D10]], %[[C2]] : index - # CHECK: vector.store %[[D18]], %[[D2]][%[[D19]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D36:.+]] = arith.addi %[[D27]], %[[C2]] : index + # CHECK: vector.store %[[D35]], %[[D19]][%[[D36]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> - # CHECK: %[[D20:.+]] = vector.extract_strided_slice %[[D0]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK: %[[D37:.+]] = vector.extract_strided_slice %[[D17]] {offsets = [3], sizes = [1], strides = [1]} : # CHECK-SAME: vector<4xf32> to vector<1xf32> - # CHECK: %[[D21:.+]] = arith.addi %[[D10]], %[[C3]] : index - # CHECK: vector.store %[[D20]], %[[D2]][%[[D21]], %[[D15]]] : memref<64x128xf32, strided<[128, 1], offset: + # CHECK: %[[D38:.+]] = arith.addi %[[D27]], %[[C3]] : index + # CHECK: vector.store %[[D37]], %[[D19]][%[[D38]], %[[D32]]] : memref<64x128xf32, strided<[128, 1], offset: # CHECK-SAME: ?>>, vector<1xf32> # CHECK: return diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index f1bf0ede..6f4e2f29 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -243,23 +243,23 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_1_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: write(register_=getresult_1_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)} # CHECK-NEXT: write(register_=getresult_0_1_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16} + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16} # CHECK-NEXT: output # Reduction subgraph: @@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 - # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}) + # CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: output(return_vals=(None,)) # Reduction subgraph: diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 4cdb997b..2bebc690 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -182,13 +182,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index b085d8b1..dcf6b225 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -131,13 +131,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( @@ -146,13 +146,13 @@ def test_gemm(): # CHECK-NEXT: get_result(value=reduction, res_idx=1) # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_1_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: write(register_=getresult_1_0_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: write(register_=getresult_0_1_0, memory=c - # CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # Reduction subgraph: # CHECK: %acc_0_0_0 diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index ebadf0c4..905095c6 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -780,16 +780,6 @@ def custom_string(self, value_map: dict[str, str]) -> str: custom_str += f"acc={self.acc} (index = {self.acc_index}))" return custom_str - def post_expansion(self, constraints: list["Constraint"]) -> None: - """ - Once the arguments have been expanded, we set their indices, - ensuring that the LHS and RHS indices are consistent with their - corresponding address spaces. - """ - self.lhs.index = self.lhs_index - self.rhs.index = self.rhs_index - self.acc.index = self.acc_index - @define_op("read") @dataclass diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index 2610f968..ebbc2d46 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -19,7 +19,7 @@ from .._support.indexing import IndexingContext, IndexSequence from ...support.logging import get_logger from .._support.tracing import CapturedTrace -from .utils import get_mma_dimensional_mapping +from .utils import get_mma_dimensional_mapping, specialize_index_sequence from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") @@ -146,6 +146,7 @@ def compute_stride( def set_node_index( constraints: Sequence[Constraint], mma_index: dict[IndexSymbol, int], + mma_slices: dict[IndexSymbol, list[fx.Node]], dim_tile_size: dict[IndexSymbol, int], custom: CustomOp, dim_scaling: dict[IndexSymbol, int], @@ -176,11 +177,7 @@ def set_node_index( for dim in custom.indexing_dims: index_seq = None for constraint in sorted_constraints: - mma_check = ( - isinstance(constraint, HardwareConstraint) - and dim in mma_index - and isinstance(custom, MMA) - ) + mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index vector_check = ( isinstance(constraint, HardwareConstraint) @@ -222,6 +219,8 @@ def set_node_index( index_seq = constraint.apply( constraint_index, dim, elements_per_thread, stride ) + if mma_index: + index_seq = specialize_index_sequence(index_seq, mma_slices, custom) else: if index_seq is None: @@ -251,10 +250,10 @@ def expand_graph( dim_scaling = constraints_or_scaling node_index_setter = lambda *args: None else: - mma_index = get_mma_dimensional_mapping(trace) + mma_index, mma_slices = get_mma_dimensional_mapping(trace) dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index) node_index_setter = partial( - set_node_index, constraints_or_scaling, mma_index, dim_tile_size + set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size ) # Start from the back and expand in the corresponding indexing dimensions of a node diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/shark_turbine/kernel/wave/index_sequence_analysis.py index cec8b60b..b9212f01 100644 --- a/shark_turbine/kernel/wave/index_sequence_analysis.py +++ b/shark_turbine/kernel/wave/index_sequence_analysis.py @@ -24,7 +24,7 @@ def get_vector_shape( hardware_constraint: HardwareConstraint, symbolic_shape: list[IndexSymbol], ) -> list[int]: - mma_indices = get_mma_dimensional_mapping(trace) + mma_indices, _ = get_mma_dimensional_mapping(trace) return [ get_hardware_vector_size(dim, hardware_constraint, mma_indices) for dim in symbolic_shape diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index affd5fef..42e5bca3 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -1,5 +1,4 @@ # Copyright 2024 The IREE Authors -# # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -16,7 +15,16 @@ from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * -from ..ops.wave_ops import get_custom, Output, Write, MMA +from ..ops.wave_ops import ( + get_custom, + Output, + Write, + MMA, + CustomOp, + Reduction, + GetResult, + IterArg, +) from .constraints import Constraint, HardwareConstraint, TilingConstraint import torch.fx as fx import shark_turbine.kernel.lang as tkl @@ -145,7 +153,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) -def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]: +def get_mma_dimensional_mapping( + trace: CapturedTrace, +) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have @@ -159,7 +169,8 @@ def is_mma(node): return isinstance(get_custom(node), MMA) mapping: dict[IndexSymbol, int] = {} - for node in trace.walk(is_mma): + mma_nodes = trace.walk(is_mma) + for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] lhs_shape = custom.lhs_type.symbolic_shape @@ -170,7 +181,7 @@ def is_mma(node): mapping[n] = 1 mapping[k] = 2 - return mapping + return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) def get_hardware_vector_size( @@ -378,3 +389,132 @@ def erase_graph(graph: fx.Graph): for user in node.users: graph.erase_node(user) graph.erase_node(node) + + +def get_users( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the users of a node, propagating through reductions. + """ + users = [] + for user in node.users: + custom = get_custom(user) + if isinstance(custom, Reduction): + # Map init arg to iter arg + reduction = custom + init_arg_idx = custom.init_args.index(node) + users.append(custom.iter_args[init_arg_idx]) + continue + if isinstance(custom, Output) and reduction: + # Map output to get result + return_vals = custom.return_vals[0] + get_results = sorted( + [x for x in reduction.users if isinstance(get_custom(x), GetResult)], + lambda x: get_custom(x).res_idx, + ) + if isinstance(return_vals, list): + output_idx = return_vals.index(node) + users.append(get_results[output_idx]) + else: + users.append(get_results[0]) + continue + users.append(user) + return users, reduction + + +def get_inputs( + node: fx.Node, reduction: fx.Node = None +) -> tuple[list[fx.Node], fx.Node]: + """ + Return the inputs of a node, propagating through reductions. + """ + inputs = [] + for input in node.all_input_nodes: + custom = get_custom(input) + if isinstance(custom, GetResult): + reduction = custom.value + assert isinstance( + reduction, Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + inputs.append(reduction.outputs[custom.res_idx]) + continue + if isinstance(custom, IterArg): + # Map iter args to init args + iter_arg_idx = reduction.iter_args.index(node) + inputs.append(reduction.init_args[iter_arg_idx]) + continue + inputs.append(input) + return inputs, reduction + + +def bfs( + node: fx.Node, + get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], +) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + visited: set[fx.Node] = set() + queue: list[fx.Node] = [] + visited.add(node) + queue.append(node) + reduction = None + while queue: + s = queue.pop(0) + neighbors, reduction = get_neighbors(s, reduction) + for neighbor in neighbors: + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + return visited + + +def capture_forward_slice(node: fx.Node) -> set[fx.Node]: + """ + Run BFS on the graph to capture the forward slice of a node. + """ + return bfs(node, lambda x, y: get_users(x, y)) + + +def capture_backward_slice(node: fx.Node) -> set[fx.Node]: + """ + Capture backward slice from a node and return the tree. + Assumes graph is directed. + """ + return bfs(node, lambda x, y: get_inputs(x, y)) + + +def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + """ + mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} + for mma in mma_nodes: + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) + mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + return mma_slices + + +def specialize_index_sequence( + index_seq: IndexSequence, + mma_slices: dict[IndexSymbol, list[fx.Node]], + custom: CustomOp, +) -> IndexSequence: + """ + Given an index sequence, specialize it to a LHS, RHS or ACC index sequence + based on whether the node is used as the LHS, RHS or ACC in the MMA node. + If the node is not used as any of the operands, return the original index sequence + with all the MMA symbols zeroed out. + """ + if isinstance(custom, MMA): + return index_seq + operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} + for key in mma_slices: + if custom.fx_node in mma_slices[key]: + operand_map[key] = 1 + return index_seq.subs(operand_map) + return index_seq.subs(operand_map) From 04a4ba5534c4d2ad398c96115cc40d8a3406f4b5 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Thu, 26 Sep 2024 12:35:31 -0500 Subject: [PATCH 06/28] Add first draft of introduction (#168) This PR adds more documentation about tkw. Specifically, it provides a first draft of the introduction and adds a section on memory access patterns. Signed-off-by: Harsh Menon --- .../kernel/wave/docs/mlsys/.gitignore | 1 + shark_turbine/kernel/wave/docs/mlsys/tkw.bbl | 236 +++++++++++++++--- shark_turbine/kernel/wave/docs/mlsys/tkw.bib | 153 +++++++----- shark_turbine/kernel/wave/docs/mlsys/tkw.blg | 46 ---- shark_turbine/kernel/wave/docs/mlsys/tkw.tex | 80 +++++- 5 files changed, 364 insertions(+), 152 deletions(-) delete mode 100644 shark_turbine/kernel/wave/docs/mlsys/tkw.blg diff --git a/shark_turbine/kernel/wave/docs/mlsys/.gitignore b/shark_turbine/kernel/wave/docs/mlsys/.gitignore index f2e31fe2..b4c7d64b 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/.gitignore +++ b/shark_turbine/kernel/wave/docs/mlsys/.gitignore @@ -3,3 +3,4 @@ *.out *.pdf *.synctex.gz +*.blg diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl b/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl index 1295a6d4..5ca46234 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl @@ -1,54 +1,208 @@ -\begin{thebibliography}{8} +\begin{thebibliography}{6} \providecommand{\natexlab}[1]{#1} \providecommand{\url}[1]{\texttt{#1}} \expandafter\ifx\csname urlstyle\endcsname\relax \providecommand{\doi}[1]{doi: #1}\else \providecommand{\doi}{doi: \begingroup \urlstyle{rm}\Url}\fi -\bibitem[Author(2018)]{anonymous} -Author, N.~N. -\newblock Suppressed for anonymity, 2018. +\bibitem[Chetlur et~al.(2014)Chetlur, Woolley, Vandermersch, Cohen, Tran, + Catanzaro, and Shelhamer]{chetlur_cudnn_2014} +Chetlur, S., Woolley, C., Vandermersch, P., Cohen, J., Tran, J., Catanzaro, B., + and Shelhamer, E. +\newblock {cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}, December + 2014. +\newblock URL \url{http://arxiv.org/abs/1410.0759}. +\newblock arXiv:1410.0759 [cs]. -\bibitem[Duda et~al.(2000)Duda, Hart, and Stork]{DudaHart2nd} -Duda, R.~O., Hart, P.~E., and Stork, D.~G. -\newblock \emph{Pattern Classification}. -\newblock John Wiley and Sons, 2nd edition, 2000. +\bibitem[Dubey et~al.(2024)Dubey, Jauhri, Pandey, Kadian, Al-Dahle, Letman, + Mathur, Schelten, Yang, Fan, Goyal, Hartshorn, Yang, Mitra, Sravankumar, + Korenev, Hinsvark, Rao, Zhang, Rodriguez, Gregerson, Spataru, Roziere, Biron, + Tang, Chern, Caucheteux, Nayak, Bi, Marra, McConnell, Keller, Touret, Wu, + Wong, Ferrer, Nikolaidis, Allonsius, Song, Pintz, Livshits, Esiobu, + Choudhary, Mahajan, Garcia-Olano, Perino, Hupkes, Lakomkin, AlBadawy, + Lobanova, Dinan, Smith, Radenovic, Zhang, Synnaeve, Lee, Anderson, Nail, + Mialon, Pang, Cucurell, Nguyen, Korevaar, Xu, Touvron, Zarov, Ibarra, + Kloumann, Misra, Evtimov, Copet, Lee, Geffert, Vranes, Park, Mahadeokar, + Shah, van~der Linde, Billock, Hong, Lee, Fu, Chi, Huang, Liu, Wang, Yu, + Bitton, Spisak, Park, Rocca, Johnstun, Saxe, Jia, Alwala, Upasani, Plawiak, + Li, Heafield, Stone, El-Arini, Iyer, Malik, Chiu, Bhalla, Rantala-Yeary, + van~der Maaten, Chen, Tan, Jenkins, Martin, Madaan, Malo, Blecher, Landzaat, + de~Oliveira, Muzzi, Pasupuleti, Singh, Paluri, Kardas, Oldham, Rita, Pavlova, + Kambadur, Lewis, Si, Singh, Hassan, Goyal, Torabi, Bashlykov, Bogoychev, + Chatterji, Duchenne, Çelebi, Alrassy, Zhang, Li, Vasic, Weng, Bhargava, + Dubal, Krishnan, Koura, Xu, He, Dong, Srinivasan, Ganapathy, Calderer, + Cabral, Stojnic, Raileanu, Girdhar, Patel, Sauvestre, Polidoro, Sumbaly, + Taylor, Silva, Hou, Wang, Hosseini, Chennabasappa, Singh, Bell, Kim, Edunov, + Nie, Narang, Raparthy, Shen, Wan, Bhosale, Zhang, Vandenhende, Batra, + Whitman, Sootla, Collot, Gururangan, Borodinsky, Herman, Fowler, Sheasha, + Georgiou, Scialom, Speckbacher, Mihaylov, Xiao, Karn, Goswami, Gupta, + Ramanathan, Kerkez, Gonguet, Do, Vogeti, Petrovic, Chu, Xiong, Fu, Meers, + Martinet, Wang, Tan, Xie, Jia, Wang, Goldschlag, Gaur, Babaei, Wen, Song, + Zhang, Li, Mao, Coudert, Yan, Chen, Papakipos, Singh, Grattafiori, Jain, + Kelsey, Shajnfeld, Gangidi, Victoria, Goldstand, Menon, Sharma, Boesenberg, + Vaughan, Baevski, Feinstein, Kallet, Sangani, Yunus, Lupu, Alvarado, Caples, + Gu, Ho, Poulton, Ryan, Ramchandani, Franco, Saraf, Chowdhury, Gabriel, + Bharambe, Eisenman, Yazdan, James, Maurer, Leonhardi, Huang, Loyd, De~Paola, + Paranjape, Liu, Wu, Ni, Hancock, Wasti, Spence, Stojkovic, Gamido, Montalvo, + Parker, Burton, Mejia, Wang, Kim, Zhou, Hu, Chu, Cai, Tindal, Feichtenhofer, + Civin, Beaty, Kreymer, Li, Wyatt, Adkins, Xu, Testuggine, David, Parikh, + Liskovich, Foss, Wang, Le, Holland, Dowling, Jamil, Montgomery, Presani, + Hahn, Wood, Brinkman, Arcaute, Dunbar, Smothers, Sun, Kreuk, Tian, Ozgenel, + Caggioni, Guzmán, Kanayet, Seide, Florez, Schwarz, Badeer, Swee, Halpern, + Thattai, Herman, Sizov, Guangyi, Zhang, Lakshminarayanan, Shojanazeri, Zou, + Wang, Zha, Habeeb, Rudolph, Suk, Aspegren, Goldman, Damlaj, Molybog, Tufanov, + Veliche, Gat, Weissman, Geboski, Kohli, Asher, Gaya, Marcus, Tang, Chan, + Zhen, Reizenstein, Teboul, Zhong, Jin, Yang, Cummings, Carvill, Shepard, + McPhie, Torres, Ginsburg, Wang, Wu, U, Saxena, Prasad, Khandelwal, Zand, + Matosich, Veeraraghavan, Michelena, Li, Huang, Chawla, Lakhotia, Huang, Chen, + Garg, A, Silva, Bell, Zhang, Guo, Yu, Moshkovich, Wehrstedt, Khabsa, Avalani, + Bhatt, Tsimpoukelli, Mankus, Hasson, Lennie, Reso, Groshev, Naumov, Lathi, + Keneally, Seltzer, Valko, Restrepo, Patel, Vyatskov, Samvelyan, Clark, Macey, + Wang, Hermoso, Metanat, Rastegari, Bansal, Santhanam, Parks, White, Bawa, + Singhal, Egebo, Usunier, Laptev, Dong, Zhang, Cheng, Chernoguz, Hart, + Salpekar, Kalinli, Kent, Parekh, Saab, Balaji, Rittner, Bontrager, Roux, + Dollar, Zvyagina, Ratanchandani, Yuvraj, Liang, Alao, Rodriguez, Ayub, + Murthy, Nayani, Mitra, Li, Hogan, Battey, Wang, Maheswari, Howes, Rinott, + Bondu, Datta, Chugh, Hunt, Dhillon, Sidorov, Pan, Verma, Yamamoto, Ramaswamy, + Lindsay, Lindsay, Feng, Lin, Zha, Shankar, Zhang, Zhang, Wang, Agarwal, + Sajuyigbe, Chintala, Max, Chen, Kehoe, Satterfield, Govindaprasad, Gupta, + Cho, Virk, Subramanian, Choudhury, Goldman, Remez, Glaser, Best, Kohler, + Robinson, Li, Zhang, Matthews, Chou, Shaked, Vontimitta, Ajayi, Montanez, + Mohan, Kumar, Mangla, Albiero, Ionescu, Poenaru, Mihailescu, Ivanov, Li, + Wang, Jiang, Bouaziz, Constable, Tang, Wang, Wu, Wang, Xia, Wu, Gao, Chen, + Hu, Jia, Qi, Li, Zhang, Zhang, Adi, Nam, Yu, Wang, Hao, Qian, He, Rait, + DeVito, Rosnbrick, Wen, Yang, and Zhao]{dubey_llama_2024} +Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., + Mathur, A., Schelten, A., Yang, A., Fan, A., Goyal, A., Hartshorn, A., Yang, + A., Mitra, A., Sravankumar, A., Korenev, A., Hinsvark, A., Rao, A., Zhang, + A., Rodriguez, A., Gregerson, A., Spataru, A., Roziere, B., Biron, B., Tang, + B., Chern, B., Caucheteux, C., Nayak, C., Bi, C., Marra, C., McConnell, C., + Keller, C., Touret, C., Wu, C., Wong, C., Ferrer, C.~C., Nikolaidis, C., + Allonsius, D., Song, D., Pintz, D., Livshits, D., Esiobu, D., Choudhary, D., + Mahajan, D., Garcia-Olano, D., Perino, D., Hupkes, D., Lakomkin, E., + AlBadawy, E., Lobanova, E., Dinan, E., Smith, E.~M., Radenovic, F., Zhang, + F., Synnaeve, G., Lee, G., Anderson, G.~L., Nail, G., Mialon, G., Pang, G., + Cucurell, G., Nguyen, H., Korevaar, H., Xu, H., Touvron, H., Zarov, I., + Ibarra, I.~A., Kloumann, I., Misra, I., Evtimov, I., Copet, J., Lee, J., + Geffert, J., Vranes, J., Park, J., Mahadeokar, J., Shah, J., van~der Linde, + J., Billock, J., Hong, J., Lee, J., Fu, J., Chi, J., Huang, J., Liu, J., + Wang, J., Yu, J., Bitton, J., Spisak, J., Park, J., Rocca, J., Johnstun, J., + Saxe, J., Jia, J., Alwala, K.~V., Upasani, K., Plawiak, K., Li, K., Heafield, + K., Stone, K., El-Arini, K., Iyer, K., Malik, K., Chiu, K., Bhalla, K., + Rantala-Yeary, L., van~der Maaten, L., Chen, L., Tan, L., Jenkins, L., + Martin, L., Madaan, L., Malo, L., Blecher, L., Landzaat, L., de~Oliveira, L., + Muzzi, M., Pasupuleti, M., Singh, M., Paluri, M., Kardas, M., Oldham, M., + Rita, M., Pavlova, M., Kambadur, M., Lewis, M., Si, M., Singh, M.~K., Hassan, + M., Goyal, N., Torabi, N., Bashlykov, N., Bogoychev, N., Chatterji, N., + Duchenne, O., Çelebi, O., Alrassy, P., Zhang, P., Li, P., Vasic, P., Weng, + P., Bhargava, P., Dubal, P., Krishnan, P., Koura, P.~S., Xu, P., He, Q., + Dong, Q., Srinivasan, R., Ganapathy, R., Calderer, R., Cabral, R.~S., + Stojnic, R., Raileanu, R., Girdhar, R., Patel, R., Sauvestre, R., Polidoro, + R., Sumbaly, R., Taylor, R., Silva, R., Hou, R., Wang, R., Hosseini, S., + Chennabasappa, S., Singh, S., Bell, S., Kim, S.~S., Edunov, S., Nie, S., + Narang, S., Raparthy, S., Shen, S., Wan, S., Bhosale, S., Zhang, S., + Vandenhende, S., Batra, S., Whitman, S., Sootla, S., Collot, S., Gururangan, + S., Borodinsky, S., Herman, T., Fowler, T., Sheasha, T., Georgiou, T., + Scialom, T., Speckbacher, T., Mihaylov, T., Xiao, T., Karn, U., Goswami, V., + Gupta, V., Ramanathan, V., Kerkez, V., Gonguet, V., Do, V., Vogeti, V., + Petrovic, V., Chu, W., Xiong, W., Fu, W., Meers, W., Martinet, X., Wang, X., + Tan, X.~E., Xie, X., Jia, X., Wang, X., Goldschlag, Y., Gaur, Y., Babaei, Y., + Wen, Y., Song, Y., Zhang, Y., Li, Y., Mao, Y., Coudert, Z.~D., Yan, Z., Chen, + Z., Papakipos, Z., Singh, A., Grattafiori, A., Jain, A., Kelsey, A., + Shajnfeld, A., Gangidi, A., Victoria, A., Goldstand, A., Menon, A., Sharma, + A., Boesenberg, A., Vaughan, A., Baevski, A., Feinstein, A., Kallet, A., + Sangani, A., Yunus, A., Lupu, A., Alvarado, A., Caples, A., Gu, A., Ho, A., + Poulton, A., Ryan, A., Ramchandani, A., Franco, A., Saraf, A., Chowdhury, A., + Gabriel, A., Bharambe, A., Eisenman, A., Yazdan, A., James, B., Maurer, B., + Leonhardi, B., Huang, B., Loyd, B., De~Paola, B., Paranjape, B., Liu, B., Wu, + B., Ni, B., Hancock, B., Wasti, B., Spence, B., Stojkovic, B., Gamido, B., + Montalvo, B., Parker, C., Burton, C., Mejia, C., Wang, C., Kim, C., Zhou, C., + Hu, C., Chu, C.-H., Cai, C., Tindal, C., Feichtenhofer, C., Civin, D., Beaty, + D., Kreymer, D., Li, D., Wyatt, D., Adkins, D., Xu, D., Testuggine, D., + David, D., Parikh, D., Liskovich, D., Foss, D., Wang, D., Le, D., Holland, + D., Dowling, E., Jamil, E., Montgomery, E., Presani, E., Hahn, E., Wood, E., + Brinkman, E., Arcaute, E., Dunbar, E., Smothers, E., Sun, F., Kreuk, F., + Tian, F., Ozgenel, F., Caggioni, F., Guzmán, F., Kanayet, F., Seide, F., + Florez, G.~M., Schwarz, G., Badeer, G., Swee, G., Halpern, G., Thattai, G., + Herman, G., Sizov, G., Guangyi, Zhang, Lakshminarayanan, G., Shojanazeri, H., + Zou, H., Wang, H., Zha, H., Habeeb, H., Rudolph, H., Suk, H., Aspegren, H., + Goldman, H., Damlaj, I., Molybog, I., Tufanov, I., Veliche, I.-E., Gat, I., + Weissman, J., Geboski, J., Kohli, J., Asher, J., Gaya, J.-B., Marcus, J., + Tang, J., Chan, J., Zhen, J., Reizenstein, J., Teboul, J., Zhong, J., Jin, + J., Yang, J., Cummings, J., Carvill, J., Shepard, J., McPhie, J., Torres, J., + Ginsburg, J., Wang, J., Wu, K., U, K.~H., Saxena, K., Prasad, K., Khandelwal, + K., Zand, K., Matosich, K., Veeraraghavan, K., Michelena, K., Li, K., Huang, + K., Chawla, K., Lakhotia, K., Huang, K., Chen, L., Garg, L., A, L., Silva, + L., Bell, L., Zhang, L., Guo, L., Yu, L., Moshkovich, L., Wehrstedt, L., + Khabsa, M., Avalani, M., Bhatt, M., Tsimpoukelli, M., Mankus, M., Hasson, M., + Lennie, M., Reso, M., Groshev, M., Naumov, M., Lathi, M., Keneally, M., + Seltzer, M.~L., Valko, M., Restrepo, M., Patel, M., Vyatskov, M., Samvelyan, + M., Clark, M., Macey, M., Wang, M., Hermoso, M.~J., Metanat, M., Rastegari, + M., Bansal, M., Santhanam, N., Parks, N., White, N., Bawa, N., Singhal, N., + Egebo, N., Usunier, N., Laptev, N.~P., Dong, N., Zhang, N., Cheng, N., + Chernoguz, O., Hart, O., Salpekar, O., Kalinli, O., Kent, P., Parekh, P., + Saab, P., Balaji, P., Rittner, P., Bontrager, P., Roux, P., Dollar, P., + Zvyagina, P., Ratanchandani, P., Yuvraj, P., Liang, Q., Alao, R., Rodriguez, + R., Ayub, R., Murthy, R., Nayani, R., Mitra, R., Li, R., Hogan, R., Battey, + R., Wang, R., Maheswari, R., Howes, R., Rinott, R., Bondu, S.~J., Datta, S., + Chugh, S., Hunt, S., Dhillon, S., Sidorov, S., Pan, S., Verma, S., Yamamoto, + S., Ramaswamy, S., Lindsay, S., Lindsay, S., Feng, S., Lin, S., Zha, S.~C., + Shankar, S., Zhang, S., Zhang, S., Wang, S., Agarwal, S., Sajuyigbe, S., + Chintala, S., Max, S., Chen, S., Kehoe, S., Satterfield, S., Govindaprasad, + S., Gupta, S., Cho, S., Virk, S., Subramanian, S., Choudhury, S., Goldman, + S., Remez, T., Glaser, T., Best, T., Kohler, T., Robinson, T., Li, T., Zhang, + T., Matthews, T., Chou, T., Shaked, T., Vontimitta, V., Ajayi, V., Montanez, + V., Mohan, V., Kumar, V.~S., Mangla, V., Albiero, V., Ionescu, V., Poenaru, + V., Mihailescu, V.~T., Ivanov, V., Li, W., Wang, W., Jiang, W., Bouaziz, W., + Constable, W., Tang, X., Wang, X., Wu, X., Wang, X., Xia, X., Wu, X., Gao, + X., Chen, Y., Hu, Y., Jia, Y., Qi, Y., Li, Y., Zhang, Y., Zhang, Y., Adi, Y., + Nam, Y., Yu, Wang, Hao, Y., Qian, Y., He, Y., Rait, Z., DeVito, Z., + Rosnbrick, Z., Wen, Z., Yang, Z., and Zhao, Z. +\newblock The {Llama} 3 {Herd} of {Models}, August 2024. +\newblock URL \url{http://arxiv.org/abs/2407.21783}. +\newblock arXiv:2407.21783 [cs]. -\bibitem[Kearns(1989)]{kearns89} -Kearns, M.~J. -\newblock \emph{Computational Complexity of Machine Learning}. -\newblock PhD thesis, Department of Computer Science, Harvard University, 1989. +\bibitem[Paszke et~al.(2019)Paszke, Gross, Massa, Lerer, Bradbury, Chanan, + Killeen, Lin, Gimelshein, Antiga, Desmaison, Köpf, Yang, DeVito, Raison, + Tejani, Chilamkurthy, Steiner, Fang, Bai, and Chintala]{paszke_pytorch_2019} +Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, + T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., + DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., + Bai, J., and Chintala, S. +\newblock {PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} + {Learning} {Library}, December 2019. +\newblock URL \url{http://arxiv.org/abs/1912.01703}. +\newblock arXiv:1912.01703 [cs, stat]. -\bibitem[Langley(2000)]{langley00} -Langley, P. -\newblock Crafting papers on machine learning. -\newblock In Langley, P. (ed.), \emph{Proceedings of the 17th International - Conference on Machine Learning (ICML 2000)}, pp.\ 1207--1216, Stanford, CA, - 2000. Morgan Kaufmann. +\bibitem[Podell et~al.(2023)Podell, English, Lacey, Blattmann, Dockhorn, + Müller, Penna, and Rombach]{podell_sdxl_2023} +Podell, D., English, Z., Lacey, K., Blattmann, A., Dockhorn, T., Müller, J., + Penna, J., and Rombach, R. +\newblock {SDXL}: {Improving} {Latent} {Diffusion} {Models} for + {High}-{Resolution} {Image} {Synthesis}, July 2023. +\newblock URL \url{http://arxiv.org/abs/2307.01952}. +\newblock arXiv:2307.01952 [cs]. -\bibitem[Michalski et~al.(1983)Michalski, Carbonell, and - Mitchell]{MachineLearningI} -Michalski, R.~S., Carbonell, J.~G., and Mitchell, T.~M. (eds.). -\newblock \emph{Machine Learning: An Artificial Intelligence Approach, Vol. I}. -\newblock Tioga, Palo Alto, CA, 1983. +\bibitem[Sun et~al.(2023)Sun, Li, Geng, Stuijk, and + Corporaal]{sun_dissecting_2023} +Sun, W., Li, A., Geng, T., Stuijk, S., and Corporaal, H. +\newblock Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, + {Throughput} and {Numeric} {Behaviors}. +\newblock \emph{IEEE Transactions on Parallel and Distributed Systems}, + 34\penalty0 (1):\penalty0 246--261, January 2023. +\newblock ISSN 1045-9219, 1558-2183, 2161-9883. +\newblock \doi{10.1109/TPDS.2022.3217824}. +\newblock URL \url{https://ieeexplore.ieee.org/document/9931992/}. -\bibitem[Mitchell(1980)]{mitchell80} -Mitchell, T.~M. -\newblock The need for biases in learning generalizations. -\newblock Technical report, Computer Science Department, Rutgers University, - New Brunswick, MA, 1980. - -\bibitem[Newell \& Rosenbloom(1981)Newell and Rosenbloom]{Newell81} -Newell, A. and Rosenbloom, P.~S. -\newblock Mechanisms of skill acquisition and the law of practice. -\newblock In Anderson, J.~R. (ed.), \emph{Cognitive Skills and Their - Acquisition}, chapter~1, pp.\ 1--51. Lawrence Erlbaum Associates, Inc., - Hillsdale, NJ, 1981. - -\bibitem[Samuel(1959)]{Samuel59} -Samuel, A.~L. -\newblock Some studies in machine learning using the game of checkers. -\newblock \emph{IBM Journal of Research and Development}, 3\penalty0 - (3):\penalty0 211--229, 1959. +\bibitem[Tillet et~al.(2019)Tillet, Kung, and Cox]{tillet_triton_2019} +Tillet, P., Kung, H.~T., and Cox, D. +\newblock Triton: an intermediate language and compiler for tiled neural + network computations. +\newblock In \emph{Proceedings of the 3rd {ACM} {SIGPLAN} {International} + {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, pp.\ + 10--19, Phoenix AZ USA, June 2019. ACM. +\newblock ISBN 978-1-4503-6719-6. +\newblock \doi{10.1145/3315508.3329973}. +\newblock URL \url{https://dl.acm.org/doi/10.1145/3315508.3329973}. \end{thebibliography} diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib b/shark_turbine/kernel/wave/docs/mlsys/tkw.bib index 6bd0e3ee..61f02789 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.bib @@ -1,75 +1,102 @@ -@inproceedings{langley00, - author = {P. Langley}, - title = {Crafting Papers on Machine Learning}, - year = {2000}, - pages = {1207--1216}, - editor = {Pat Langley}, - booktitle = {Proceedings of the 17th International Conference - on Machine Learning (ICML 2000)}, - address = {Stanford, CA}, - publisher = {Morgan Kaufmann} -} - -@TechReport{mitchell80, - author = "T. M. Mitchell", - title = "The Need for Biases in Learning Generalizations", - institution = "Computer Science Department, Rutgers University", - year = "1980", - address = "New Brunswick, MA", -} -@phdthesis{kearns89, - author = {M. J. Kearns}, - title = {Computational Complexity of Machine Learning}, - school = {Department of Computer Science, Harvard University}, - year = {1989} +@inproceedings{tillet_triton_2019, + address = {Phoenix AZ USA}, + title = {Triton: an intermediate language and compiler for tiled neural network computations}, + isbn = {978-1-4503-6719-6}, + shorttitle = {Triton}, + url = {https://dl.acm.org/doi/10.1145/3315508.3329973}, + doi = {10.1145/3315508.3329973}, + abstract = {The validation and deployment of novel research ideas in the field of Deep Learning is often limited by the availability of efficient compute kernels for certain basic primitives. In particular, operations that cannot leverage existing vendor libraries (e.g., cuBLAS, cuDNN) are at risk of facing poor device utilization unless custom implementations are written by experts – usually at the expense of portability. For this reason, the development of new programming abstractions for specifying custom Deep Learning workloads at a minimal performance cost has become crucial.}, + language = {en}, + urldate = {2024-09-25}, + booktitle = {Proceedings of the 3rd {ACM} {SIGPLAN} {International} {Workshop} on {Machine} {Learning} and {Programming} {Languages}}, + publisher = {ACM}, + author = {Tillet, Philippe and Kung, H. T. and Cox, David}, + month = jun, + year = {2019}, + pages = {10--19}, + file = {PDF:/Users/harsh/Zotero/storage/FMLLYK4M/Tillet et al. - 2019 - Triton an intermediate language and compiler for tiled neural network computations.pdf:application/pdf}, } -@Book{MachineLearningI, - editor = "R. S. Michalski and J. G. Carbonell and T. - M. Mitchell", - title = "Machine Learning: An Artificial Intelligence - Approach, Vol. I", - publisher = "Tioga", - year = "1983", - address = "Palo Alto, CA" +@misc{podell_sdxl_2023, + title = {{SDXL}: {Improving} {Latent} {Diffusion} {Models} for {High}-{Resolution} {Image} {Synthesis}}, + shorttitle = {{SDXL}}, + url = {http://arxiv.org/abs/2307.01952}, + abstract = {We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared to previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators. In the spirit of promoting open research and fostering transparency in large model training and evaluation, we provide access to code and model weights.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Podell, Dustin and English, Zion and Lacey, Kyle and Blattmann, Andreas and Dockhorn, Tim and Müller, Jonas and Penna, Joe and Rombach, Robin}, + month = jul, + year = {2023}, + note = {arXiv:2307.01952 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition}, + file = {PDF:/Users/harsh/Zotero/storage/ARJZQZ42/Podell et al. - 2023 - SDXL Improving Latent Diffusion Models for High-Resolution Image Synthesis.pdf:application/pdf}, } -@Book{DudaHart2nd, - author = "R. O. Duda and P. E. Hart and D. G. Stork", - title = "Pattern Classification", - publisher = "John Wiley and Sons", - edition = "2nd", - year = "2000" +@misc{dubey_llama_2024, + title = {The {Llama} 3 {Herd} of {Models}}, + url = {http://arxiv.org/abs/2407.21783}, + abstract = {Modern artificial intelligence (AI) systems are powered by foundation models. This paper presents a new set of foundation models, called Llama 3. It is a herd of language models that natively support multilinguality, coding, reasoning, and tool usage. Our largest model is a dense Transformer with 405B parameters and a context window of up to 128K tokens. This paper presents an extensive empirical evaluation of Llama 3. We find that Llama 3 delivers comparable quality to leading language models such as GPT-4 on a plethora of tasks. We publicly release Llama 3, including pre-trained and post-trained versions of the 405B parameter language model and our Llama Guard 3 model for input and output safety. The paper also presents the results of experiments in which we integrate image, video, and speech capabilities into Llama 3 via a compositional approach. We observe this approach performs competitively with the state-of-the-art on image, video, and speech recognition tasks. The resulting models are not yet being broadly released as they are still under development.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Dubey, Abhimanyu and Jauhri, Abhinav and Pandey, Abhinav and Kadian, Abhishek and Al-Dahle, Ahmad and Letman, Aiesha and Mathur, Akhil and Schelten, Alan and Yang, Amy and Fan, Angela and Goyal, Anirudh and Hartshorn, Anthony and Yang, Aobo and Mitra, Archi and Sravankumar, Archie and Korenev, Artem and Hinsvark, Arthur and Rao, Arun and Zhang, Aston and Rodriguez, Aurelien and Gregerson, Austen and Spataru, Ava and Roziere, Baptiste and Biron, Bethany and Tang, Binh and Chern, Bobbie and Caucheteux, Charlotte and Nayak, Chaya and Bi, Chloe and Marra, Chris and McConnell, Chris and Keller, Christian and Touret, Christophe and Wu, Chunyang and Wong, Corinne and Ferrer, Cristian Canton and Nikolaidis, Cyrus and Allonsius, Damien and Song, Daniel and Pintz, Danielle and Livshits, Danny and Esiobu, David and Choudhary, Dhruv and Mahajan, Dhruv and Garcia-Olano, Diego and Perino, Diego and Hupkes, Dieuwke and Lakomkin, Egor and AlBadawy, Ehab and Lobanova, Elina and Dinan, Emily and Smith, Eric Michael and Radenovic, Filip and Zhang, Frank and Synnaeve, Gabriel and Lee, Gabrielle and Anderson, Georgia Lewis and Nail, Graeme and Mialon, Gregoire and Pang, Guan and Cucurell, Guillem and Nguyen, Hailey and Korevaar, Hannah and Xu, Hu and Touvron, Hugo and Zarov, Iliyan and Ibarra, Imanol Arrieta and Kloumann, Isabel and Misra, Ishan and Evtimov, Ivan and Copet, Jade and Lee, Jaewon and Geffert, Jan and Vranes, Jana and Park, Jason and Mahadeokar, Jay and Shah, Jeet and van der Linde, Jelmer and Billock, Jennifer and Hong, Jenny and Lee, Jenya and Fu, Jeremy and Chi, Jianfeng and Huang, Jianyu and Liu, Jiawen and Wang, Jie and Yu, Jiecao and Bitton, Joanna and Spisak, Joe and Park, Jongsoo and Rocca, Joseph and Johnstun, Joshua and Saxe, Joshua and Jia, Junteng and Alwala, Kalyan Vasuden and Upasani, Kartikeya and Plawiak, Kate and Li, Ke and Heafield, Kenneth and Stone, Kevin and El-Arini, Khalid and Iyer, Krithika and Malik, Kshitiz and Chiu, Kuenley and Bhalla, Kunal and Rantala-Yeary, Lauren and van der Maaten, Laurens and Chen, Lawrence and Tan, Liang and Jenkins, Liz and Martin, Louis and Madaan, Lovish and Malo, Lubo and Blecher, Lukas and Landzaat, Lukas and de Oliveira, Luke and Muzzi, Madeline and Pasupuleti, Mahesh and Singh, Mannat and Paluri, Manohar and Kardas, Marcin and Oldham, Mathew and Rita, Mathieu and Pavlova, Maya and Kambadur, Melanie and Lewis, Mike and Si, Min and Singh, Mitesh Kumar and Hassan, Mona and Goyal, Naman and Torabi, Narjes and Bashlykov, Nikolay and Bogoychev, Nikolay and Chatterji, Niladri and Duchenne, Olivier and Çelebi, Onur and Alrassy, Patrick and Zhang, Pengchuan and Li, Pengwei and Vasic, Petar and Weng, Peter and Bhargava, Prajjwal and Dubal, Pratik and Krishnan, Praveen and Koura, Punit Singh and Xu, Puxin and He, Qing and Dong, Qingxiao and Srinivasan, Ragavan and Ganapathy, Raj and Calderer, Ramon and Cabral, Ricardo Silveira and Stojnic, Robert and Raileanu, Roberta and Girdhar, Rohit and Patel, Rohit and Sauvestre, Romain and Polidoro, Ronnie and Sumbaly, Roshan and Taylor, Ross and Silva, Ruan and Hou, Rui and Wang, Rui and Hosseini, Saghar and Chennabasappa, Sahana and Singh, Sanjay and Bell, Sean and Kim, Seohyun Sonia and Edunov, Sergey and Nie, Shaoliang and Narang, Sharan and Raparthy, Sharath and Shen, Sheng and Wan, Shengye and Bhosale, Shruti and Zhang, Shun and Vandenhende, Simon and Batra, Soumya and Whitman, Spencer and Sootla, Sten and Collot, Stephane and Gururangan, Suchin and Borodinsky, Sydney and Herman, Tamar and Fowler, Tara and Sheasha, Tarek and Georgiou, Thomas and Scialom, Thomas and Speckbacher, Tobias and Mihaylov, Todor and Xiao, Tong and Karn, Ujjwal and Goswami, Vedanuj and Gupta, Vibhor and Ramanathan, Vignesh and Kerkez, Viktor and Gonguet, Vincent and Do, Virginie and Vogeti, Vish and Petrovic, Vladan and Chu, Weiwei and Xiong, Wenhan and Fu, Wenyin and Meers, Whitney and Martinet, Xavier and Wang, Xiaodong and Tan, Xiaoqing Ellen and Xie, Xinfeng and Jia, Xuchao and Wang, Xuewei and Goldschlag, Yaelle and Gaur, Yashesh and Babaei, Yasmine and Wen, Yi and Song, Yiwen and Zhang, Yuchen and Li, Yue and Mao, Yuning and Coudert, Zacharie Delpierre and Yan, Zheng and Chen, Zhengxing and Papakipos, Zoe and Singh, Aaditya and Grattafiori, Aaron and Jain, Abha and Kelsey, Adam and Shajnfeld, Adam and Gangidi, Adithya and Victoria, Adolfo and Goldstand, Ahuva and Menon, Ajay and Sharma, Ajay and Boesenberg, Alex and Vaughan, Alex and Baevski, Alexei and Feinstein, Allie and Kallet, Amanda and Sangani, Amit and Yunus, Anam and Lupu, Andrei and Alvarado, Andres and Caples, Andrew and Gu, Andrew and Ho, Andrew and Poulton, Andrew and Ryan, Andrew and Ramchandani, Ankit and Franco, Annie and Saraf, Aparajita and Chowdhury, Arkabandhu and Gabriel, Ashley and Bharambe, Ashwin and Eisenman, Assaf and Yazdan, Azadeh and James, Beau and Maurer, Ben and Leonhardi, Benjamin and Huang, Bernie and Loyd, Beth and De Paola, Beto and Paranjape, Bhargavi and Liu, Bing and Wu, Bo and Ni, Boyu and Hancock, Braden and Wasti, Bram and Spence, Brandon and Stojkovic, Brani and Gamido, Brian and Montalvo, Britt and Parker, Carl and Burton, Carly and Mejia, Catalina and Wang, Changhan and Kim, Changkyu and Zhou, Chao and Hu, Chester and Chu, Ching-Hsiang and Cai, Chris and Tindal, Chris and Feichtenhofer, Christoph and Civin, Damon and Beaty, Dana and Kreymer, Daniel and Li, Daniel and Wyatt, Danny and Adkins, David and Xu, David and Testuggine, Davide and David, Delia and Parikh, Devi and Liskovich, Diana and Foss, Didem and Wang, Dingkang and Le, Duc and Holland, Dustin and Dowling, Edward and Jamil, Eissa and Montgomery, Elaine and Presani, Eleonora and Hahn, Emily and Wood, Emily and Brinkman, Erik and Arcaute, Esteban and Dunbar, Evan and Smothers, Evan and Sun, Fei and Kreuk, Felix and Tian, Feng and Ozgenel, Firat and Caggioni, Francesco and Guzmán, Francisco and Kanayet, Frank and Seide, Frank and Florez, Gabriela Medina and Schwarz, Gabriella and Badeer, Gada and Swee, Georgia and Halpern, Gil and Thattai, Govind and Herman, Grant and Sizov, Grigory and Guangyi and Zhang and Lakshminarayanan, Guna and Shojanazeri, Hamid and Zou, Han and Wang, Hannah and Zha, Hanwen and Habeeb, Haroun and Rudolph, Harrison and Suk, Helen and Aspegren, Henry and Goldman, Hunter and Damlaj, Ibrahim and Molybog, Igor and Tufanov, Igor and Veliche, Irina-Elena and Gat, Itai and Weissman, Jake and Geboski, James and Kohli, James and Asher, Japhet and Gaya, Jean-Baptiste and Marcus, Jeff and Tang, Jeff and Chan, Jennifer and Zhen, Jenny and Reizenstein, Jeremy and Teboul, Jeremy and Zhong, Jessica and Jin, Jian and Yang, Jingyi and Cummings, Joe and Carvill, Jon and Shepard, Jon and McPhie, Jonathan and Torres, Jonathan and Ginsburg, Josh and Wang, Junjie and Wu, Kai and U, Kam Hou and Saxena, Karan and Prasad, Karthik and Khandelwal, Kartikay and Zand, Katayoun and Matosich, Kathy and Veeraraghavan, Kaushik and Michelena, Kelly and Li, Keqian and Huang, Kun and Chawla, Kunal and Lakhotia, Kushal and Huang, Kyle and Chen, Lailin and Garg, Lakshya and A, Lavender and Silva, Leandro and Bell, Lee and Zhang, Lei and Guo, Liangpeng and Yu, Licheng and Moshkovich, Liron and Wehrstedt, Luca and Khabsa, Madian and Avalani, Manav and Bhatt, Manish and Tsimpoukelli, Maria and Mankus, Martynas and Hasson, Matan and Lennie, Matthew and Reso, Matthias and Groshev, Maxim and Naumov, Maxim and Lathi, Maya and Keneally, Meghan and Seltzer, Michael L. and Valko, Michal and Restrepo, Michelle and Patel, Mihir and Vyatskov, Mik and Samvelyan, Mikayel and Clark, Mike and Macey, Mike and Wang, Mike and Hermoso, Miquel Jubert and Metanat, Mo and Rastegari, Mohammad and Bansal, Munish and Santhanam, Nandhini and Parks, Natascha and White, Natasha and Bawa, Navyata and Singhal, Nayan and Egebo, Nick and Usunier, Nicolas and Laptev, Nikolay Pavlovich and Dong, Ning and Zhang, Ning and Cheng, Norman and Chernoguz, Oleg and Hart, Olivia and Salpekar, Omkar and Kalinli, Ozlem and Kent, Parkin and Parekh, Parth and Saab, Paul and Balaji, Pavan and Rittner, Pedro and Bontrager, Philip and Roux, Pierre and Dollar, Piotr and Zvyagina, Polina and Ratanchandani, Prashant and Yuvraj, Pritish and Liang, Qian and Alao, Rachad and Rodriguez, Rachel and Ayub, Rafi and Murthy, Raghotham and Nayani, Raghu and Mitra, Rahul and Li, Raymond and Hogan, Rebekkah and Battey, Robin and Wang, Rocky and Maheswari, Rohan and Howes, Russ and Rinott, Ruty and Bondu, Sai Jayesh and Datta, Samyak and Chugh, Sara and Hunt, Sara and Dhillon, Sargun and Sidorov, Sasha and Pan, Satadru and Verma, Saurabh and Yamamoto, Seiji and Ramaswamy, Sharadh and Lindsay, Shaun and Lindsay, Shaun and Feng, Sheng and Lin, Shenghao and Zha, Shengxin Cindy and Shankar, Shiva and Zhang, Shuqiang and Zhang, Shuqiang and Wang, Sinong and Agarwal, Sneha and Sajuyigbe, Soji and Chintala, Soumith and Max, Stephanie and Chen, Stephen and Kehoe, Steve and Satterfield, Steve and Govindaprasad, Sudarshan and Gupta, Sumit and Cho, Sungmin and Virk, Sunny and Subramanian, Suraj and Choudhury, Sy and Goldman, Sydney and Remez, Tal and Glaser, Tamar and Best, Tamara and Kohler, Thilo and Robinson, Thomas and Li, Tianhe and Zhang, Tianjun and Matthews, Tim and Chou, Timothy and Shaked, Tzook and Vontimitta, Varun and Ajayi, Victoria and Montanez, Victoria and Mohan, Vijai and Kumar, Vinay Satish and Mangla, Vishal and Albiero, Vítor and Ionescu, Vlad and Poenaru, Vlad and Mihailescu, Vlad Tiberiu and Ivanov, Vladimir and Li, Wei and Wang, Wenchen and Jiang, Wenwen and Bouaziz, Wes and Constable, Will and Tang, Xiaocheng and Wang, Xiaofang and Wu, Xiaojian and Wang, Xiaolan and Xia, Xide and Wu, Xilun and Gao, Xinbo and Chen, Yanjun and Hu, Ye and Jia, Ye and Qi, Ye and Li, Yenda and Zhang, Yilin and Zhang, Ying and Adi, Yossi and Nam, Youngjin and Yu and Wang and Hao, Yuchen and Qian, Yundi and He, Yuzi and Rait, Zach and DeVito, Zachary and Rosnbrick, Zef and Wen, Zhaoduo and Yang, Zhenyu and Zhao, Zhiwei}, + month = aug, + year = {2024}, + note = {arXiv:2407.21783 [cs]}, + keywords = {Computer Science - Artificial Intelligence, Computer Science - Computer Vision and Pattern Recognition, Computer Science - Computation and Language}, + file = {PDF:/Users/harsh/Zotero/storage/BQKY8VZZ/Dubey et al. - 2024 - The Llama 3 Herd of Models.pdf:application/pdf}, } -@misc{anonymous, - title= {Suppressed for Anonymity}, - author= {Author, N. N.}, - year= {2018} +@article{sun_dissecting_2023, + title = {Dissecting {Tensor} {Cores} via {Microbenchmarks}: {Latency}, {Throughput} and {Numeric} {Behaviors}}, + volume = {34}, + copyright = {https://ieeexplore.ieee.org/Xplorehelp/downloads/license-information/IEEE.html}, + issn = {1045-9219, 1558-2183, 2161-9883}, + shorttitle = {Dissecting {Tensor} {Cores} via {Microbenchmarks}}, + url = {https://ieeexplore.ieee.org/document/9931992/}, + doi = {10.1109/TPDS.2022.3217824}, + abstract = {Tensor Cores have been an important unit to accelerate Fused Matrix Multiplication Accumulation (MMA) in all NVIDIA GPUs since Volta Architecture. To program Tensor Cores, users have to use either legacy wmma APIs or current mma APIs. Legacy wmma APIs are more easy-to-use but can only exploit limited features and power of Tensor Cores. Specifically, wmma APIs support fewer operand shapes and can not leverage the new sparse matrix multiplication feature of the newest Ampere Tensor Cores. However, the performance of current programming interface has not been well explored. Furthermore, the computation numeric behaviors of lowprecision floating points (TF32, BF16, and FP16) supported by the newest Ampere Tensor Cores are also mysterious. In this paper, we explore the throughput and latency of current programming APIs. We also intuitively study the numeric behaviors of Tensor Cores MMA and profile the intermediate operations including multiplication, addition of inner product, and accumulation. All codes used in this work can be found in https://github.com/sunlex0717/DissectingTensorCores.}, + language = {en}, + number = {1}, + urldate = {2024-09-25}, + journal = {IEEE Transactions on Parallel and Distributed Systems}, + author = {Sun, Wei and Li, Ang and Geng, Tong and Stuijk, Sander and Corporaal, Henk}, + month = jan, + year = {2023}, + pages = {246--261}, + file = {PDF:/Users/harsh/Zotero/storage/NZD3FJUB/Sun et al. - 2023 - Dissecting Tensor Cores via Microbenchmarks Latency, Throughput and Numeric Behaviors.pdf:application/pdf}, } -@InCollection{Newell81, - author = "A. Newell and P. S. Rosenbloom", - title = "Mechanisms of Skill Acquisition and the Law of - Practice", - booktitle = "Cognitive Skills and Their Acquisition", - pages = "1--51", - publisher = "Lawrence Erlbaum Associates, Inc.", - year = "1981", - editor = "J. R. Anderson", - chapter = "1", - address = "Hillsdale, NJ" +@misc{paszke_pytorch_2019, + title = {{PyTorch}: {An} {Imperative} {Style}, {High}-{Performance} {Deep} {Learning} {Library}}, + shorttitle = {{PyTorch}}, + url = {http://arxiv.org/abs/1912.01703}, + abstract = {Deep learning frameworks have often focused on either usability or speed, but not both. PyTorch is a machine learning library that shows that these two goals are in fact compatible: it provides an imperative and Pythonic programming style that supports code as a model, makes debugging easy and is consistent with other popular scientific computing libraries, while remaining efficient and supporting hardware accelerators such as GPUs.}, + language = {en}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and Köpf, Andreas and Yang, Edward and DeVito, Zach and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith}, + month = dec, + year = {2019}, + note = {arXiv:1912.01703 [cs, stat]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Statistics - Machine Learning}, + annote = {Comment: 12 pages, 3 figures, NeurIPS 2019}, + file = {PDF:/Users/harsh/Zotero/storage/D72HUVME/Paszke et al. - 2019 - PyTorch An Imperative Style, High-Performance Deep Learning Library.pdf:application/pdf}, } - -@Article{Samuel59, - author = "A. L. Samuel", - title = "Some Studies in Machine Learning Using the Game of - Checkers", - journal = "IBM Journal of Research and Development", - year = "1959", - volume = "3", - number = "3", - pages = "211--229" +@misc{chetlur_cudnn_2014, + title = {{cuDNN}: {Efficient} {Primitives} for {Deep} {Learning}}, + shorttitle = {{cuDNN}}, + url = {http://arxiv.org/abs/1410.0759}, + doi = {10.48550/arXiv.1410.0759}, + abstract = {We present a library of efficient implementations of deep learning primitives. Deep learning workloads are computationally intensive, and optimizing their kernels is difficult and time-consuming. As parallel architectures evolve, kernels must be reoptimized, which makes maintaining codebases difficult over time. Similar issues have long been addressed in the HPC community by libraries such as the Basic Linear Algebra Subroutines (BLAS). However, there is no analogous library for deep learning. Without such a library, researchers implementing deep learning workloads on parallel processors must create and optimize their own implementations of the main computational kernels, and this work must be repeated as new parallel processors emerge. To address this problem, we have created a library similar in intent to BLAS, with optimized routines for deep learning workloads. Our implementation contains routines for GPUs, although similarly to the BLAS library, these routines could be implemented for other platforms. The library is easy to integrate into existing frameworks, and provides optimized performance and memory usage. For example, integrating cuDNN into Caffe, a popular framework for convolutional networks, improves performance by 36\% on a standard model while also reducing memory consumption.}, + urldate = {2024-09-25}, + publisher = {arXiv}, + author = {Chetlur, Sharan and Woolley, Cliff and Vandermersch, Philippe and Cohen, Jonathan and Tran, John and Catanzaro, Bryan and Shelhamer, Evan}, + month = dec, + year = {2014}, + note = {arXiv:1410.0759 [cs]}, + keywords = {Computer Science - Machine Learning, Computer Science - Mathematical Software, Computer Science - Neural and Evolutionary Computing}, } diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg b/shark_turbine/kernel/wave/docs/mlsys/tkw.blg deleted file mode 100644 index ef864a1b..00000000 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.blg +++ /dev/null @@ -1,46 +0,0 @@ -This is BibTeX, Version 0.99d (TeX Live 2020) -Capacity: max_strings=200000, hash_size=200000, hash_prime=170003 -The top-level auxiliary file: example_paper.aux -The style file: mlsys2024.bst -Database file #1: example_paper.bib -You've used 8 entries, - 2773 wiz_defined-function locations, - 645 strings with 5916 characters, -and the built_in function-call counts, 3248 in all, are: -= -- 293 -> -- 140 -< -- 9 -+ -- 49 -- -- 41 -* -- 223 -:= -- 507 -add.period$ -- 25 -call.type$ -- 8 -change.case$ -- 36 -chr.to.int$ -- 8 -cite$ -- 16 -duplicate$ -- 174 -empty$ -- 295 -format.name$ -- 51 -if$ -- 691 -int.to.chr$ -- 1 -int.to.str$ -- 1 -missing$ -- 6 -newline$ -- 47 -num.names$ -- 37 -pop$ -- 81 -preamble$ -- 1 -purify$ -- 29 -quote$ -- 0 -skip$ -- 127 -stack$ -- 0 -substring$ -- 100 -swap$ -- 24 -text.length$ -- 3 -text.prefix$ -- 0 -top$ -- 0 -type$ -- 78 -warning$ -- 0 -while$ -- 34 -width$ -- 0 -write$ -- 113 diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex b/shark_turbine/kernel/wave/docs/mlsys/tkw.tex index cb56cab1..33282648 100644 --- a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex +++ b/shark_turbine/kernel/wave/docs/mlsys/tkw.tex @@ -30,7 +30,7 @@ \begin{document} \twocolumn[ -\mlsystitle{Submission and Formatting Instructions for MLSys 2024} +\mlsystitle{Wave : A Python DSL for High Performance Machine Learning} % It is OKAY to include author information, even for blind % submissions: the style file will automatically remove it for you @@ -94,6 +94,81 @@ %\printAffiliationsAndNotice{} % leave blank if no need to mention equal contribution \printAffiliationsAndNotice{\mlsysEqualContribution} % otherwise use the standard text. +\section{Introduction} +Generative models have seen tremendous success in a wide variety of +domains ranging from image generation to natural language processing and beyond. +\cite{podell_sdxl_2023,dubey_llama_2024}. Much of this success is being +driven by graphics processing units (GPUs) which while originally +designed for graphics, are now being optimized for machine learning. +Both datacenter and consumer grade GPUs feature powerful matrix multiplication hardware units +and specialized instructions to enable high performance inference and training \cite{sun_dissecting_2023}. +\\ \\ +Given the importance of GPUs in machine learning, significant +effort has been put into developing frameworks that allow developers to +write high performance machine learning models with a low barrier to entry. Frameworks such +as Pytorch \cite{paszke_pytorch_2019} have become extremely popular +because they expose a Python based approach to programming GPUs. Prior +to the advent of these frameworks, developers had to write CUDA or OpenCL +kernels by hand which required significant expertise to achieve +good performance and did not scale well to new operators. +\\ \\ +Under the hood, these machine learning frameworks rely heavily +on vendor-specific libraries such as cuDNN \cite{chetlur_cudnn_2014} to achieve high performance. +These libraries are performant but are black boxes consisting of +hand-written kernels and often do not support the full set of +fused operators encountered in machine learning models. +To address these limitations, recent work has focused on developing +Python domain specific languages (DSL) that allow developers to get high performance +while reducing the kernel complexity. Triton \cite{tillet_triton_2019}. +is a popular Python DSL that exposes a workgroup level programming +model and allows developers to author high performance kernels. +However, Triton kernels often get quite complex and start to +resemble hand-written kernels as the kernel complexity grows. +Furthermore, fusion of Triton kernels is limited to a few operators +and remains an open problem. + +In this paper, we introduce Wave, a Python DSL for high performance machine learning. +Wave exposes a subgroup (wave or warp) level programming model that allows +for much simpler kernels compared to Triton. Through the use of constraints, Wave forces developers to +come up with the distribution strategy for their kernel - +which dimensions are parallel and which are sequential and how to distribute those +dimensions across the memory and compute hierarchy of the GPU. This allows for a separation +between the kernel and the distribution strategy and makes the kernel simpler. +Wave also embraces symbolic data types using sympy to represent the shapes and +memory access patterns of tensors in the kernel. +It has a Python based compiler that uses torch.fx tracing to define +and trace operators written in the language. The torch.fx graphs are then run through a series of optimization passes +on the computation graph and are finally lowered to MLIR and subsequently LLVM. This code generation flow allows compiler writers +to blend high productivity in Python with high performance from the MLIR and LLVM +code generation flow. +\\ \\ +In summary, the contributions of this paper are as follows: +\begin{itemize} + \item A novel subgroup programming model for GPU with a Python DSL that separates distribution strategies from the core kernel allowing for simpler kernels, + \item A symbolic data type system that allows for reasoning about tensor shapes and memory access patterns in the kernel, + \item A Python compiler that leverages torch.fx for tracing and maps torch.fx graphs to MLIR and LLVM for high performance code generation. +\end{itemize} + + +\section{Memory Access Patterns} +We represent memory access patterns in the language using the standard +triplet notation consisting of an offset, number of elements, and absolute stride and associate +a triplet with each tensor dimension. The memory access pattern for a given operation +is determined by the access patterns of the operands of the operation as well as +the user-specified constraints. For example, the memory access pattern for the output +of an elementwise operation is determined from the access patterns of the inputs, +whereas for a matrix-multiply accumulate operation, the memory access patterns of the operands are specified by +the hardware constraint. +\\ \\ +One of the advantages of the dimension based specification is that it obviates +the need for any propagation of memory access patterns through the computation graph, +as is commonly done in other frameworks. When setting the access pattern for a specific +dimension of a tensor, the access pattern is taken to be the union of all possible +access patterns with the determination of which access pattern to use based on +the minimization of an appropriate metric across the entire graph (see Section 3). + + +\iffalse \section{Electronic Submission} \label{submission} @@ -527,8 +602,9 @@ \section*{Acknowledgements} % In the unusual situation where you want a paper to appear in the % references without citing it in the main text, use \nocite \nocite{langley00} +\fi -\bibliography{example_paper} +\bibliography{tkw} \bibliographystyle{mlsys2024} From 76861577f3bbedc2ccec21647376e6b5f771ea94 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 27 Sep 2024 00:31:38 +0300 Subject: [PATCH 07/28] [TKW] igemm shared mem tests (#171) Signed-off-by: Ivan Butygin --- shark_turbine/kernel/wave/minimize_global_loads.py | 5 ++--- tests/kernel/wave/wave_e2e_test.py | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/shark_turbine/kernel/wave/minimize_global_loads.py b/shark_turbine/kernel/wave/minimize_global_loads.py index 3ea1a3d0..17971354 100644 --- a/shark_turbine/kernel/wave/minimize_global_loads.py +++ b/shark_turbine/kernel/wave/minimize_global_loads.py @@ -63,12 +63,11 @@ def materialize_shape( constraint_tile_size: dict[IndexSymbol, int], symbolic_shape: list[IndexSymbol] ) -> list[int]: materialized_shape = [] - idxc = IndexingContext.current() for dim in symbolic_shape: if dim in constraint_tile_size: - materialized_shape.append(constraint_tile_size[dim].subs(idxc.subs)) + materialized_shape.append(subs_idxc(constraint_tile_size[dim])) else: - materialized_shape.append(dim.subs(idxc.subs)) + materialized_shape.append(subs_idxc(dim)) return materialized_shape diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index fcabb8d9..0e55a3f9 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -582,7 +582,8 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: @pytest.mark.parametrize("c", [1, 3, 4, 10]) @pytest.mark.parametrize("nf", [1, 2, 16]) @pytest.mark.parametrize("stride", [1, 2, 3]) -def test_igemm_conv(n, c, nf, stride): +@pytest.mark.parametrize("mem_space", [GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE]) +def test_igemm_conv(n, c, nf, stride, mem_space): h, w = 5, 5 # Image. cf, hf, wf = c, 2, 2 # Filters. padding = 0 # TODO: only pad=0 is supported for now @@ -702,7 +703,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: BLOCK_M: 16, BLOCK_N: 16, ELEMS_PER_THREAD: 4, - ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE: mem_space, }, canonicalize=True, run=True, From 0e16d541822b4a3478ce38948881701eec38a55a Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Fri, 27 Sep 2024 00:24:24 -0700 Subject: [PATCH 08/28] [TKW] Implement support for multiple iter args on Reduction (#166) The main motivation behind this PR is to enable multiple induction variable/iterArg on the same tiled "Reduction" loop. To enable above we did a couple things: 1. Enable lowering/expansion on `operator.getitem` (the op that extract multiple results in python i.e `res0, res1 = fn`) by templating it on`GetResult(CustomOp)` since they have the same args and interface and can reuse most of the indexing/expansion helper. 2. Introduce `res_idx`, a variable to represent which result index of an op we are referring to, during expansion and context map. This is useful for ops that has more than one results / variables as outputs. 3. bug fix in expand_reduction, where we hoist out iterating and expanding of `reduction.init_args` out of the loop that iterates and expands over the `yield`/`return_val` of the reduction loop. It is expected that the size of `init_args` is the same as size of `yield`/`return_val`. Hence if we had N iter_args/yields, we ended up expanding the `init_args` N x N time instead of N times. We haven't seen it thus far because we have been only playing with 1 init_arg/iterArg, and 1x1 == 1. 4. Introduce a canonicalization pattern to fold chains of GetResult. this is because GetResult by semantic/design is only expected to extract and have one result. Hence a chain of GetResult should just be replaced by itself. This help clean up the IR. num.4 also helps circumvent issue where Reduction and GetResult is expanded completely by itself not following the DFS structure per dimension like the rest of the expansion code. This becomes especially problematic for multiple IterArg since Getitem is not expecting its' source value to be expanded without it. --------- Signed-off-by: Stanley Winata --- lit_tests/kernel/wave/codegen.py | 105 +++++++++++++++++++++++++ shark_turbine/kernel/ops/wave_ops.py | 2 +- shark_turbine/kernel/wave/expansion.py | 89 ++++++++++++++------- shark_turbine/kernel/wave/utils.py | 13 +++ shark_turbine/kernel/wave/wave.py | 10 ++- tests/kernel/wave/wave_e2e_test.py | 38 +++++---- 6 files changed, 212 insertions(+), 45 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index d102f353..e18764b1 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -946,6 +946,111 @@ def repeat( # CHECK: scf.yield %[[ACC_REDUCE]] : vector<1xf16> +# This test is to ensure that the we can handle multiple IV in reduction properly. +@run_test +def test_multiple_reduction_iv(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, N, 0)] + constraints += [tkw.TilingConstraint(N, BLOCK_N)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + d: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + init_max = tkl.Register[M, tkl.f16](-1e6) + init_sum = tkl.Register[M, tkl.f16](0) + + @tkw.reduction(N, init_args=[init_max, init_sum]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], + partial_sum: tkl.Register[M, tkl.f16], + ) -> tkl.Register[M, tkl.f16]: + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + partial_max = tkw.max(lhs, partial_max, dim=N) + partial_sum = tkw.sum(lhs, partial_sum, dim=N) + return partial_max, partial_sum + + res_max, res_sum = repeat + tkw.write(res_max, c, elements_per_thread=1) + tkw.write(res_sum, d, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 512) + a = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + d = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, c).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[INIT_MAX:.+]] = arith.constant dense<0xFC00> : vector<1xf16> + # CHECK-DAG: %[[INIT_SUM:.+]] = arith.constant dense<0.000000e+00> : vector<1xf16> + + # Tile Reduction Loop + # CHECK: %[[TILED:.+]]:4 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %[[INIT_MAX]], %[[ACC1:.+]] = %[[INIT_SUM]], %[[ACC2:.+]] = %[[INIT_MAX]], %[[ACC3:.+]] = %[[INIT_SUM]]) + # CHECK-SAME: -> (vector<1xf16>, vector<1xf16>, vector<1xf16>, vector<1xf16>) { + # 1st Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 1st Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_0:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + + # 2nd Expanded Local Max Reduction + # CHECK: arith.maximumf {{.*}} : vector<1xf16> + # 2nd Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Max Reduction + # CHECK: %[[ACC_MAX_1:.+]] = arith.maximumf %[[ACC2]], %{{.*}} + + # 1st Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 1st Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 1st Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_0:.+]] = arith.addf %[[ACC1]], %{{.*}} + + # 2nd Expanded Local Sum Reduction + # CHECK: arith.addf {{.*}} : vector<1xf16> + # 2nd Expanded Global Sum Reduction + # CHECK-COUNT-6: gpu.shuffle xor + # 2nd Expanded Accumulator Sum Reduction + # CHECK: %[[ACC_SUM_1:.+]] = arith.addf %[[ACC3]], %{{.*}} + + # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]] + + @run_test def test_binary_lowerings(): constraints: list[tkw.Constraint] = [ diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 905095c6..2c38c9c2 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -486,7 +486,6 @@ def post_expansion(self, constraints: list["Constraint"]) -> None: pass -@define_py_op(operator.getitem) @define_py_op(operator.add) @define_py_op(operator.sub) @define_py_op(operator.mul) @@ -945,6 +944,7 @@ def register_index(self) -> dict[IndexSymbol, IndexSequence]: return custom.index +@define_py_op(operator.getitem) @define_op("get_result") @dataclass class GetResult(CustomOp): diff --git a/shark_turbine/kernel/wave/expansion.py b/shark_turbine/kernel/wave/expansion.py index ebbc2d46..69785031 100644 --- a/shark_turbine/kernel/wave/expansion.py +++ b/shark_turbine/kernel/wave/expansion.py @@ -23,11 +23,11 @@ from ..lang.global_symbols import * logger = get_logger("turbine.wave.expansion") -# This represents a mapping of a node + indexing into the dimensions to the -# corresponding expanded node in these specific dimensions. An example for a -# record in this map is (read_0_0_0, ((M,0),(N,0),(K,1)) -> read_0_0_1 +# This represents a mapping of a node + indexing + res_idx(output index for op with multiple results) +# of node into the dimensions to the corresponding expanded node in these specific dimensions. +# An example for a record in this map is (read_0_0_0, ((M,0),(N,0),(K,1), 0) -> read_0_0_1. ExpandedNodeMap: TypeAlias = dict[ - tuple[CustomOp, tuple[tuple[IndexSymbol, int], ...]], CustomOp + tuple[CustomOp, tuple[tuple[IndexSymbol, int], int, ...]], CustomOp ] @@ -302,6 +302,7 @@ def _expand_node( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" if isinstance(node, list): @@ -309,23 +310,31 @@ def _expand_node( for elem in node: expanded_nodes.append( _expand_node( - elem, trace, dim_query, dim_scaling, node_index_setter, context + elem, + trace, + dim_query, + dim_scaling, + node_index_setter, + context, + res_idx, ).fx_node ) return expanded_nodes # If we expanded a node in the same dimensions before, we can reuse it - if (node, get_indexed_dims(dim_query, node)) in context: + if (node, get_indexed_dims(dim_query, node), res_idx) in context: logger.debug(f"Already expanded node: {node} in {dim_query}") - return context[(node, get_indexed_dims(dim_query, node))] + return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, Reduction): return _expand_reduction( node, trace, dim_query, dim_scaling, node_index_setter, context ) - elif isinstance(node, GetResult): + elif isinstance(node, Getitem): + res_idx = node.res_idx + elif isinstance(node, GetResult) and not isinstance(node, Getitem): # The presence of a GetResult node indicates that the reduction has already # been expanded. Simply return the corresponding node. reduction = get_custom(node.value) - return context[(reduction, get_indexed_dims(dim_query, reduction))] + return context[(reduction, get_indexed_dims(dim_query, reduction), res_idx)] elif isinstance(node, Allocate): # Allocate nodes are not expanded. return node @@ -371,12 +380,13 @@ def _expand_node( dim_scaling, node_index_setter, context, + res_idx, ) new_node.update_arg(i, new_arg) new_node.post_expansion(constraints) - context[(node, get_indexed_dims(restricted_dims, node))] = new_node + context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -387,6 +397,7 @@ def _expand_reduction( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int = 0, ) -> CustomOp: """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" # Determine the dimensions to expand the reduction from the indexing of its users @@ -409,32 +420,41 @@ def _expand_reduction( new_output_args = [] new_init_args = [] for dim_vals in get_dim_combinations(dim_scaling, expand_dims): - for arg_idx, arg in output.node_args.items(): - dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + return_vals = output.return_vals[0] + dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)} + if not isinstance(return_vals, Sequence): + return_vals = [return_vals] + for arg_idx, arg in enumerate(return_vals): + arg = get_custom(arg) # Add GetResult nodes for the corresponding dimensions reduction.graph.inserting_after(reduction.fx_node) new_node = GetResult(reduction.fx_node, len(new_output_args)) new_node.add_to_graph(reduction.graph) new_node.fx_node.name = get_expanded_name(new_node, dims) - context[(reduction, get_indexed_dims(dims, expand_dims))] = new_node + context[ + (reduction, get_indexed_dims(dims, expand_dims), arg_idx) + ] = new_node # Proceed with expansion inside the reduction new_output_args.append( - _expand_node(arg, trace, dims, dim_scaling, node_index_setter, context) + _expand_node( + arg, trace, dims, dim_scaling, node_index_setter, context, res_idx + ) ) - # Proceed with expansion outside the reduction - for init_arg in reduction.init_args: - new_init_args.append( - _expand_node( - get_custom(init_arg), - trace, - dims, - dim_scaling, - node_index_setter, - context, - ) + # Proceed with expansion outside the reduction + for init_arg in reduction.init_args: + new_init_args.append( + _expand_node( + get_custom(init_arg), + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) + ) # Update init_args and return values reduction.update_arg( @@ -442,11 +462,17 @@ def _expand_reduction( ) output.update_arg("return_vals", [node.fx_node for node in new_output_args]) _handle_reduction_dim( - reduction, output, trace, dim_scaling, node_index_setter, context + reduction, + output, + trace, + dim_scaling, + node_index_setter, + context, + res_idx, ) # Even though we expanded the reduction in multiple dimensions, we only return # the node corresponding to the original query - return context[(reduction, get_indexed_dims(dim_query, expand_dims))] + return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)] def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: @@ -536,6 +562,7 @@ def _handle_reduction_dim( dim_scaling: dict[IndexSymbol, int], node_index_setter: Callable[[CustomOp, dict[IndexSymbol, int]], None], context: ExpandedNodeMap, + res_idx: int, ): # Rediscover iter args # TODO: Register iter args with the reduction initially so accessing them is easier @@ -572,7 +599,13 @@ def _handle_reduction_dim( saved_arg = user.node_args[index] user.update_arg(index, dummy) new_node = _expand_node( - user, trace, dims, dim_scaling, node_index_setter, context + user, + trace, + dims, + dim_scaling, + node_index_setter, + context, + res_idx, ) # This expansion always happens, user should never be reused diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index 42e5bca3..dda9013b 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -123,6 +123,19 @@ def is_removable_operator(node: fx.Node) -> bool: get_custom(node).graph.erase_node(node) +def remove_chained_getresult(trace: CapturedTrace): + def is_chained_getresult(node: fx.Node) -> bool: + custom = get_custom(node) + return isinstance(custom, GetResult) and isinstance( + get_custom(custom.value), GetResult + ) + + while removable_nodes := trace.walk(is_chained_getresult): + for node in removable_nodes: + get_custom(node).replace_all_uses_with(get_custom(node).value) + get_custom(node).graph.erase_node(node) + + def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]: """ Delinearizes a 1D index into a multi-dimensional index diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index eb6003de..4d19d99f 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -23,7 +23,12 @@ from .expansion import expand_graph from .promotion import promote_placeholders from .hoisting import hoist_allocs -from .utils import canonicalize_module, compile_and_invoke, safe_subs +from .utils import ( + canonicalize_module, + compile_and_invoke, + safe_subs, + remove_chained_getresult, +) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops from .barriers import add_shared_memory_barriers @@ -205,6 +210,9 @@ def _trace_and_get_kernel_signature( # Expansion expand_graph(graph, self.constraints) + # Clean up chains of GetResults + remove_chained_getresult(graph) + # Register analysis to determine register shapes. determine_register_shape(graph, self.constraints) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 0e55a3f9..dbe88424 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -279,7 +279,7 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_tiled_reduce_max")) @xfail_unaligned -def test_tiled_reduce_max(shape): +def test_toy_online_softmax(shape): M = tkl.sym.M N = tkl.sym.N wave_size = 64 @@ -303,30 +303,38 @@ def test_tiled_reduce_max(shape): @tkw.wave(constraints) def test( - a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f32], ): - init_max = tkl.Register[M, tkl.f16](-1e6) + init_max = tkl.Register[M, tkl.f32](-1e6) + init_sum = tkl.Register[M, tkl.f32](0) - @tkw.reduction(N, init_args=[init_max]) + @tkw.reduction(N, init_args=[init_max, init_sum]) def repeat( - partial_max: tkl.Register[M, tkl.f16], - ) -> tkl.Register[M, tkl.f16]: + partial_max: tkl.Register[M, tkl.f32], + partial_sum: tkl.Register[M, tkl.f32], + ) -> tkl.Register[M, tkl.f32]: lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD) res = lhs * rhs partial_max = tkw.max(res, partial_max, dim=N) - return partial_max + partial_sum = tkw.sum(res, partial_sum, dim=N) + return partial_max, partial_sum - tkw.write(repeat, c, elements_per_thread=1) + res_max, res_sum = repeat + result = res_max / res_sum + tkw.write(result, c, elements_per_thread=1) config = {"backend": "rocm", "device": "hip", "target": "gfx942"} - a = torch.randn(shape, dtype=torch.float16) - b = torch.randn(shape, dtype=torch.float16) - c = torch.zeros((shape[0],), dtype=torch.float16) - ref = torch.max((a * b), dim=-1) + torch.manual_seed(1) + a = torch.randn(shape, dtype=torch.float32) + b = torch.randn(shape, dtype=torch.float32) + c = torch.zeros((shape[0],), dtype=torch.float32) + ref_max = torch.max((a * b), dim=-1).values + ref_sum = torch.sum((a * b), dim=-1) + ref = ref_max / ref_sum with tk.gen.TestLaunchContext( { M: shape[0], @@ -343,7 +351,7 @@ def repeat( # Assert equal does cast to boolean on torch.Tensor # which causes issues, hence we cast to numpy before # checking. - assert_equal(c, ref.values.numpy()) + assert_allclose(ref, c, atol=0.015) @require_e2e From 192a78640de88da8ad6e3085faf052e087511380 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 27 Sep 2024 13:42:13 -0400 Subject: [PATCH 09/28] Handle complex element type in torch.vtensor conversion (#175) Signed-off-by: Boian Petkantchin --- shark_turbine/dynamo/type_conversion.py | 2 +- tests/dynamo/type_conversion_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/shark_turbine/dynamo/type_conversion.py b/shark_turbine/dynamo/type_conversion.py index 8206e10f..e829bafc 100644 --- a/shark_turbine/dynamo/type_conversion.py +++ b/shark_turbine/dynamo/type_conversion.py @@ -32,7 +32,7 @@ # 1. Local name (int, float, vtensor) # 2. Parameter block ("<...>"), including the delimitters # 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") +DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$") # Decomposes a vtensor parameter block into a dimension list and dtype. Groups: # 1. Dimension list diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index dfc3de25..617c5d05 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -32,6 +32,7 @@ def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") + self._compareNative("!torch.vtensor<[],complex>", "tensor>") def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: From 92ad9007470ecb8e1ebf8f4b803aee48956306df Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 30 Sep 2024 17:46:17 +0300 Subject: [PATCH 10/28] [TKW] Rework vector mask generation (#172) Instead of generating individual element comparisons and doing `vector.insertelement` generate the whole mask using vector ops. Add support for vector codegen when generating MLIR IR from sympy expressions. Add method `IndexingContext.iota` to generate special symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index` will start to generate vector ops when encountering such symbols, inserting proper `splat`'s between scalar vals when necessary. --------- Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/codegen.py | 79 +++++++-------- shark_turbine/kernel/_support/indexing.py | 16 +++ shark_turbine/kernel/wave/codegen.py | 116 ++++++++++++++-------- shark_turbine/kernel/wave/utils.py | 33 +++++- 4 files changed, 154 insertions(+), 90 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index e18764b1..be4b04bb 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -231,51 +231,40 @@ def test( print(test(a, b).module_op) # CHECK: func.func @test(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding) - # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16> - # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index - # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index - # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index - # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index - # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index - # CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index - # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index - # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index - # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index - # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x - # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y - # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index - # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index - # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C4]] : index - # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index - # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index - # CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index - # CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index - # CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index - # CHECK: %[[D9:.+]] = vector.constant_mask [4] : vector<4xi1> - # CHECK: %[[D10:.+]] = arith.cmpi slt, %[[D5]], %[[C1]] : index - # CHECK: %[[D11:.+]] = arith.cmpi slt, %[[D8]], %[[C3]] : index - # CHECK: %[[D12:.+]] = arith.andi %[[D10]], %[[D11]] : i1 - # CHECK: %[[D13:.+]] = vector.insertelement %[[D12]], %[[D9]][%[[C0]] : index] : vector<4xi1> - # CHECK: %[[D14:.+]] = arith.addi %[[D8]], %[[C1]] : index - # CHECK: %[[D15:.+]] = arith.cmpi slt, %[[D14]], %[[C3]] : index - # CHECK: %[[D16:.+]] = arith.andi %[[D10]], %[[D15]] : i1 - # CHECK: %[[D17:.+]] = vector.insertelement %[[D16]], %[[D13]][%[[C1]] : index] : vector<4xi1> - # CHECK: %[[D18:.+]] = arith.addi %[[D8]], %[[C2]] : index - # CHECK: %[[D19:.+]] = arith.cmpi slt, %[[D18]], %[[C3]] : index - # CHECK: %[[D20:.+]] = arith.andi %[[D10]], %[[D19]] : i1 - # CHECK: %[[D21:.+]] = vector.insertelement %[[D20]], %[[D17]][%[[C2]] : index] : vector<4xi1> - # CHECK: %[[D22:.+]] = arith.addi %[[D8]], %[[C3]] : index - # CHECK: %[[D23:.+]] = arith.cmpi slt, %[[D22]], %[[C3]] : index - # CHECK: %[[D24:.+]] = arith.andi %[[D10]], %[[D23]] : i1 - # CHECK: %[[D25:.+]] = vector.insertelement %[[D24]], %[[D21]][%[[C3]] : index] : vector<4xi1> - # CHECK: %[[D26:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D25]], %[[CST]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> - # CHECK: %[[D27:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>> - # CHECK: vector.maskedstore %[[D27]][%[[D5]], %[[D8]]], %[[D25]], %[[D26]] : - # CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> + # CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16> + # CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<3> : vector<4xindex> + # CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + # CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + # CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + # CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index + # CHECK: %[[WORKGROUP_ID_0:.*]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.*]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.*]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y + # CHECK: %[[D0:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: %[[D1:.*]] = arith.muli %[[WORKGROUP_ID_0]], %[[C4]] : index + # CHECK: %[[D2:.*]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.*]] = arith.muli %[[D2]], %[[C4]] : index + # CHECK: %[[D4:.*]] = arith.addi %[[D3]], %[[D1]] : index + # CHECK: %[[D5:.*]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index + # CHECK: %[[D6:.*]] = arith.muli %[[WORKGROUP_ID_1]], %[[C4]] : index + # CHECK: %[[D7:.*]] = arith.muli %[[THREAD_ID_Y]], %[[C8]] : index + # CHECK: %[[D8:.*]] = arith.addi %[[D7]], %[[D6]] : index + # CHECK: %[[D9:.*]] = vector.splat %[[D8]] : vector<4xindex> + # CHECK: %[[D10:.*]] = arith.addi %[[D9]], %[[CST_1]] : vector<4xindex> + # CHECK: %[[D11:.*]] = arith.cmpi slt, %[[D10]], %[[CST_0]] : vector<4xindex> + # CHECK: %[[D12:.*]] = arith.cmpi slt, %[[D5]], %[[C1]] : index + # CHECK: %[[D13:.*]] = vector.splat %[[D12]] : vector<4xi1> + # CHECK: %[[D14:.*]] = arith.andi %[[D11]], %[[D13]] : vector<4xi1> + # CHECK: %[[D15:.*]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D14]], %[[CST]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> into vector<4xf16> + # CHECK: %[[D16:.*]] = stream.binding.subspan %arg1[%[[C0]]] : !stream.binding -> memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>> + # CHECK: vector.maskedstore %[[D16]][%[[D5]], %[[D8]]], %[[D14]], %[[D15]] : memref<1x3xf16, + # CHECK-SAME: strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16> @run_test diff --git a/shark_turbine/kernel/_support/indexing.py b/shark_turbine/kernel/_support/indexing.py index 3f092278..b99d7b5b 100644 --- a/shark_turbine/kernel/_support/indexing.py +++ b/shark_turbine/kernel/_support/indexing.py @@ -99,6 +99,7 @@ class IndexingContext: __slots__ = [ "subs", + "special_subs", "shaped_bindings", "dyn_dims", "frozen_subs", @@ -109,6 +110,7 @@ class IndexingContext: def __init__(self): self.subs: dict[IndexSymbol, int] = {} + self.special_subs: dict[IndexSymbol, Any] = {} # Indexed by .instance self.shaped_bindings: dict[Any, _ShapedBinding] = {} self.dyn_dims: list[IndexSymbol] = [] @@ -245,6 +247,20 @@ def get_static_value(self, expr: IndexExpr | int) -> Optional[int]: except TypeError: return None + def iota(self, n: int) -> IndexExpr: + sym = index_symbol(f"$IOTA{n}") + if sym not in self.special_subs: + self.special_subs[sym] = tuple(range(n)) + + return sym + + def get_val(self, sym: IndexSymbol) -> Any: + res = self.subs.get(sym, None) + if res is None: + res = self.special_subs.get(sym, None) + + return res + ##### Context management. @staticmethod def current() -> "IndexingContext": diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index e4a8cf72..aff72cf3 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools import operator import sympy import math @@ -77,7 +78,7 @@ WorkgroupConstraint, TilingConstraint, ) -from .utils import subs_idxc +from .utils import subs_idxc, find_index_bounds # Indexing imports. from .._support.indexing import IndexingContext, IndexExpr, IndexSequence @@ -171,6 +172,32 @@ def get_type_or_element_type(operand_type: IrType): def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult: stack: list[OpResult] = [] + def _broadcast(a, b): + if not isinstance(a, (Value, OpResult)): + a = a.result + + if not isinstance(b, (Value, OpResult)): + b = b.result + + if a.type == b.type: + return a, b + + if isinstance(a.type, VectorType) and isinstance( + b.type, (IndexType, IntegerType) + ): + assert a.type.element_type == b.type + b = vector_d.splat(a.type, b) + return a, b + + if isinstance(a.type, (IndexType, IntegerType)) and isinstance( + b.type, VectorType + ): + assert b.type.element_type == a.type + a = vector_d.splat(b.type, a) + return a, b + + raise CodegenError(f"Cannot broadcast {a.type} and {b.type}") + def _process_mul_add_ops(term, is_mul): args = [] callables = [] @@ -187,9 +214,9 @@ def _process_mul_add_ops(term, is_mul): continue if is_mul: - operation = arith_d.MulIOp(operation, arg) + operation = arith_d.MulIOp(*_broadcast(operation, arg)) else: - operation = arith_d.AddIOp(operation, arg) + operation = arith_d.AddIOp(*_broadcast(operation, arg)) for arg in callables: operation = arg(operation, is_mul) @@ -197,16 +224,29 @@ def _process_mul_add_ops(term, is_mul): stack.append(operation) def _get_mul(numerator): - return lambda x: arith_d.MulIOp(x, numerator) + return lambda x: arith_d.MulIOp(*_broadcast(x, numerator)) def _get_add(numerator, denominator): - return lambda x: arith_d.AddIOp(arith_d.MulIOp(x, denominator), numerator) + return lambda x: arith_d.AddIOp( + *_broadcast(arith_d.MulIOp(*_broadcast(x, denominator)), numerator) + ) def _get_div(mul, add, denominator): return lambda x, is_mul: arith_d.DivSIOp( - mul(x) if is_mul else add(x), denominator + *_broadcast(mul(x) if is_mul else add(x), denominator) ) + def _get_const(val): + if isinstance(val, int): + return arith_d.constant(IndexType.get(), res) + + if isinstance(val, (tuple, list)): + vec_type = VectorType.get([len(val)], IndexType.get()) + vals = [IntegerAttr.get(IndexType.get(), v) for v in val] + return arith_d.constant(vec_type, DenseElementsAttr.get(vals, vec_type)) + + raise CodegenError(f"Unsupported const val {val} : {type(val)}") + induction_var_syms = [] induction_vars = [] for constraint in emitter.constraints: @@ -237,9 +277,9 @@ def _get_div(mul, add, denominator): for term in sympy.postorder_traversal(expr): match term: case sympy.Symbol(): - if term in idxc.subs.keys(): - cst = arith_d.constant(IndexType.get(), idxc.subs[term]) - stack.append(cst) + res = idxc.get_val(term) + if res is not None: + stack.append(_get_const(res)) elif term in dynamics.keys(): stack.append(dynamics[term]) else: @@ -253,7 +293,7 @@ def _get_div(mul, add, denominator): case sympy.Mod(): rhs = stack.pop() lhs = stack.pop() - mod = arith_d.RemSIOp(lhs, rhs) + mod = arith_d.RemSIOp(*_broadcast(lhs, rhs)) stack.append(mod) case sympy.floor(): # TODO: Since divsi rounds to zero, this seems to work. @@ -267,17 +307,27 @@ def _get_div(mul, add, denominator): # Assumes that the negative term is always carried on the numerator if abs(term.p) > term.p: zero = arith_d.constant(IndexType.get(), int(0)) - numerator = arith_d.SubIOp(zero, numerator) + numerator = arith_d.SubIOp(*_broadcast(zero, numerator)) mul = lambda x: x if abs(term.p) != 1: mul = _get_mul(numerator) add = _get_add(numerator, denominator) operation = _get_div(mul, add, denominator) stack.append(operation) + case sympy.StrictLessThan(): + rhs = stack.pop() + lhs = stack.pop() + res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs)) + stack.append(res) + case sympy.And(): + rhs = stack.pop() + lhs = stack.pop() + res = arith_d.andi(*_broadcast(lhs, rhs)) + stack.append(res) case sympy.UnevaluatedExpr(): continue case _: - raise CodegenError(f"Can not handle {term} yet") + raise CodegenError(f"Can not handle {type(term)} : {term}") if len(stack) != 1: raise CodegenError(f"Expected single result, got {len(stack)}") return stack[0] @@ -392,44 +442,24 @@ def _is_identity_mapping( def _build_mask( emitter: WaveEmitter, index: Dict[IndexExpr, IndexExpr], elements_per_thread: int ) -> Optional[OpResult]: - bounds = [] - for constraint in emitter.constraints: - if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): - continue - - dim = constraint.dim - if dim not in index: - continue - - work_size = constraint.count * constraint.tile_size - if subs_idxc(work_size) == subs_idxc(dim): - continue - - bounds.append((dim, gen_sympy_index(emitter, dim))) - - if len(bounds) == 0: + bounds = find_index_bounds(emitter.constraints, index) + if bounds is None: return None - mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) - mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread]) - + idxc = IndexingContext.current() last_dim = tuple(index.keys())[-1] new_index = {k: _get_start_index(v) for k, v in index.items()} - for i in range(elements_per_thread): - cond = None - for dim, bound in bounds: - idx = gen_sympy_index(emitter, new_index[dim]) - lt = arith_d.cmpi(arith_d.CmpIPredicate.slt, idx, bound) - if cond is None: - cond = lt - else: - cond = arith_d.andi(cond, lt) + new_index[last_dim] = new_index[last_dim] + idxc.iota(elements_per_thread) - pos = arith_d.ConstantOp(IndexType.get(), i) - mask = vector_d.insertelement(cond, mask, position=pos) + mask_expr = functools.reduce( + lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds) + ) + mask = gen_sympy_index(emitter, mask_expr) - new_index[last_dim] = new_index[last_dim] + 1 + mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) + if mask.type != mask_vec_type: + mask = vector_d.splat(mask_vec_type, mask) return mask diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index dda9013b..c2d0a582 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -11,7 +11,7 @@ transform_d, UnitAttr, ) -from typing import Callable, Any, List, Tuple +from typing import Optional, Callable, Any, List, Tuple from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * @@ -25,7 +25,12 @@ GetResult, IterArg, ) -from .constraints import Constraint, HardwareConstraint, TilingConstraint +from .constraints import ( + Constraint, + WorkgroupConstraint, + HardwareConstraint, + TilingConstraint, +) import torch.fx as fx import shark_turbine.kernel.lang as tkl @@ -531,3 +536,27 @@ def specialize_index_sequence( operand_map[key] = 1 return index_seq.subs(operand_map) return index_seq.subs(operand_map) + + +def find_index_bounds( + constraints: list[Constraint], index: dict[IndexExpr, IndexExpr] +) -> Optional[list[IndexExpr]]: + bounds = [] + for constraint in constraints: + if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)): + continue + + dim = constraint.dim + if dim not in index: + continue + + work_size = constraint.count * constraint.tile_size + if subs_idxc(work_size) == subs_idxc(dim): + continue + + bounds.append(dim) + + if len(bounds) == 0: + return None + + return bounds From 621cbe1f814354ddbb0a24debbe1aceac18bda25 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 30 Sep 2024 15:47:20 -0700 Subject: [PATCH 11/28] Enable import_symbolic_shape_expressions in the FxImporter. (#179) * Adds an option to `aot.export(import_symbolic_shape_expressions=True)` to enable emission of torch-mlir symbolic shape constraints. This is currently set to False until IREE is ready to ingest these by default. Rough sequence of work in IREE proper: * Custom lowering of `torch.symbolic_int` and `torch.bind_symbolic_shape` ops to IREE util "assume" ops. Note that we are only planning to lower "terminal" bindings (basically function arguments and a couple of other such categories). * Canonicalizations to ensure that assume equalities are == 0 (versus the native form from torch where they assume a non zero equality). * Fusion will clone corresponding bindings on dependent dims into dispatch regions. * Existing linalg shape analysis extended and queryable by codegen. --------- Signed-off-by: Stella Laurenzo --- shark_turbine/aot/compiled_module.py | 26 ++++++++-- shark_turbine/aot/exporter.py | 16 +++++- shark_turbine/aot/support/ir_utils.py | 13 ++++- .../support/procedural/exported_program.py | 5 +- tests/aot/dynamic_shape_export_test.py | 50 +++++++++++++++++++ 5 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 tests/aot/dynamic_shape_export_test.py diff --git a/shark_turbine/aot/compiled_module.py b/shark_turbine/aot/compiled_module.py index 3f44c8b9..5fffd6a0 100644 --- a/shark_turbine/aot/compiled_module.py +++ b/shark_turbine/aot/compiled_module.py @@ -41,6 +41,7 @@ from .support.ir_utils import ( ModuleBuilder, + ModuleBuilderOptions, ) @@ -162,11 +163,13 @@ class CompiledModuleClassInfo: __slots__ = [ "all_exports", "ir_module_name", + "options", ] - def __init__(self, *, ir_module_name: str): + def __init__(self, *, ir_module_name: str, options: ModuleBuilderOptions): self.ir_module_name = ir_module_name self.all_exports: Dict[str, Exportable] = dict() + self.options = options def add_export(self, key: str, value: Exportable): if key in self.all_exports: @@ -370,13 +373,23 @@ class CompiledModuleMeta(type): # It is passed the dictionary of declared attributes and any keyword # arguments from the class declaration: # class Foo(Bar, kwarg="you probably just learned this is possible"): - def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None): + def __new__( + mcls, + name: str, + bases, + dct, + *, + export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, + ): if not _metaclass_setup_complete: return type.__new__(mcls, name, bases, dct) ir_module_name = _derive_ir_module_name(name, export_name) logger.debug("Create new CompiledModule: %s", ir_module_name) - info = CompiledModuleClassInfo(ir_module_name=ir_module_name) + info = CompiledModuleClassInfo( + ir_module_name=ir_module_name, options=options or ModuleBuilderOptions() + ) # Process that attributes that were set as part of class definition. # Any attributes that we decide are part of the compiled module @@ -436,6 +449,7 @@ def create_from_dict( dct: dict, *, export_name: Optional[str] = None, + options: Optional[ModuleBuilderOptions] = None, ) -> CompiledModuleMeta: """Creates a CompiledModule subclass with an explicit dictionary of members. @@ -446,7 +460,9 @@ class Foo(CompiledModule, export_name="bar"): def member(): ... ``` """ - return CompiledModuleMeta(name, (cls,), dct, export_name=export_name) + return CompiledModuleMeta( + name, (cls,), dct, export_name=export_name, options=options + ) @staticmethod def get_class_info(cls: CompiledModuleMeta) -> CompiledModuleClassInfo: @@ -596,7 +612,7 @@ def __new__( module_op.attributes["sym_name"] = StringAttr.get( class_info.ir_module_name, context=context ) - module_builder = ModuleBuilder(module_op) + module_builder = ModuleBuilder(module_op, options=class_info.options) info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder) _all_compiled_module_instance_infos[self] = info diff --git a/shark_turbine/aot/exporter.py b/shark_turbine/aot/exporter.py index 4c0e0160..c1adb527 100644 --- a/shark_turbine/aot/exporter.py +++ b/shark_turbine/aot/exporter.py @@ -26,6 +26,7 @@ from .builtins import * from .compiled_module import ( CompiledModule, + ModuleBuilderOptions, ImportPhase, ) from .fx_programs import FxPrograms @@ -175,6 +176,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Exports a torch.nn.Module. @@ -223,6 +225,7 @@ def export( module_name: Optional[str] = None, function_name: Optional[str] = None, strict_export: bool = True, + import_symbolic_shape_expressions: bool = False, ) -> ExportOutput: """Generic export of supported entities. @@ -270,11 +273,19 @@ def export( "LambdaCompiledModule", {(function_name or "main"): mdl}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, FxPrograms): TransformedModule = CompiledModule.create_from_dict( - "LambdaCompiledModule", mdl.programs, export_name=module_name or "module" + "LambdaCompiledModule", + mdl.programs, + export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif isinstance(mdl, torch.nn.Module): # Normalize arguments for torch.export. @@ -302,6 +313,9 @@ def export( "LambdaCompiledModule", {(function_name or "main"): exported_program}, export_name=module_name or "module", + options=ModuleBuilderOptions( + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ), ) elif issubclass(mdl, CompiledModule): TransformedModule = mdl diff --git a/shark_turbine/aot/support/ir_utils.py b/shark_turbine/aot/support/ir_utils.py index a662c15c..e1eb9d56 100644 --- a/shark_turbine/aot/support/ir_utils.py +++ b/shark_turbine/aot/support/ir_utils.py @@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Sequence, Tuple +from dataclasses import dataclass from pathlib import Path import tempfile @@ -148,6 +149,12 @@ def infer_external_from_tensor( ############################################################################### +@dataclass +class ModuleBuilderOptions: + # Whether to import torch symbolic shape expressions for ExportedPrograms. + import_symbolic_shape_expressions: bool = False + + class ModuleBuilder: """Wrapper around module and IR accounting for a module being built.""" @@ -159,14 +166,18 @@ class ModuleBuilder: "last_global_op", "ip", "module_op", + "options", "symbol_table", "global_ref_tracker", "native_type_converter", "_auto_symbol_counts", ] - def __init__(self, module_op: Operation): + def __init__( + self, module_op: Operation, *, options: Optional[ModuleBuilderOptions] = None + ): self.module_op = module_op + self.options = options or ModuleBuilderOptions() self.context = module_op.context self.body = module_op.regions[0].blocks[0] self.symbol_table = SymbolTable(module_op) diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/shark_turbine/aot/support/procedural/exported_program.py index bbc431ae..f6540bab 100644 --- a/shark_turbine/aot/support/procedural/exported_program.py +++ b/shark_turbine/aot/support/procedural/exported_program.py @@ -181,7 +181,10 @@ def import_exported_program( ) -> ExportedProgramIntrinsic: fx_importer = _create_fx_importer(module_builder) entry_func_op = fx_importer.import_program( - exported_program, func_name=symbol_name, func_visibility=symbol_visibility + exported_program, + func_name=symbol_name, + func_visibility=symbol_visibility, + import_symbolic_shape_expressions=module_builder.options.import_symbolic_shape_expressions, ) module_call_graph = exported_program.module_call_graph diff --git a/tests/aot/dynamic_shape_export_test.py b/tests/aot/dynamic_shape_export_test.py new file mode 100644 index 00000000..da8c11b7 --- /dev/null +++ b/tests/aot/dynamic_shape_export_test.py @@ -0,0 +1,50 @@ +import torch + +import pytest + +from shark_turbine.aot import * + + +@pytest.mark.parametrize( + "import_symbolic_shape_expressions", + [ + True, + False, + ], +) +def test_exported_program_dynamic_shapes(import_symbolic_shape_expressions): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) + + example_args = (torch.randn(32, 64), torch.randn(32, 128)) + + # Create a dynamic batch size + batch = torch.export.Dim("batch") + # Specify that the first dimension of each input is that batch size + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + output = export( + M(), + args=example_args, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, + ) + output.print_readable() + asm = str(output.mlir_module) + + if import_symbolic_shape_expressions: + assert "bind_symbolic_shape" in asm + else: + assert "bind_symbolic_shape" not in asm From 84320eaaf267e0dc014d6872d29db1031c1872d5 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Tue, 1 Oct 2024 11:45:32 -0700 Subject: [PATCH 12/28] Add code to construct pipelined loop from schedule (#160) This PR adds code to construct the epilogue, kernel and prologue once we have computed a schedule. We simulate rotating registers in software and add visualization tools to show the pipelined graphs. --------- Signed-off-by: Harsh Menon --- lit_tests/kernel/wave/codegen.py | 84 +++ lit_tests/kernel/wave/scheduling.py | 227 +++++++ shark_turbine/kernel/_support/tracing.py | 3 + shark_turbine/kernel/ops/wave_ops.py | 18 +- shark_turbine/kernel/wave/codegen.py | 24 +- .../kernel/wave/scheduling/graph_utils.py | 3 +- .../wave/scheduling/loop_reconstruction.py | 556 ++++++++++++++++++ .../scheduling/loop_reconstruction_utils.py | 285 +++++++++ .../wave/scheduling/modulo_scheduling.py | 9 + .../kernel/wave/scheduling/schedule.py | 32 +- shark_turbine/kernel/wave/utils.py | 54 ++ shark_turbine/kernel/wave/visualization.py | 95 ++- shark_turbine/kernel/wave/wave.py | 14 + tests/kernel/wave/wave_gemm_test.py | 25 +- 14 files changed, 1402 insertions(+), 27 deletions(-) create mode 100644 lit_tests/kernel/wave/scheduling.py create mode 100644 shark_turbine/kernel/wave/scheduling/loop_reconstruction.py create mode 100644 shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index be4b04bb..a7781a39 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -602,6 +602,90 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: return +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 128, + N: 128, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + }, + canonicalize=True, + schedule=True, + ): + a = torch.randn(64, 32, dtype=torch.float16) + b = torch.randn(128, 32, dtype=torch.float16) + c = torch.zeros(64, 128, dtype=torch.float32) + print(gemm_pipelined(a, b, c).module_op) + + # CHECK: func.func @gemm_pipelined + # CHECK-COUNT-2: vector.load + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-10: vector.load + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-1: scf.for + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-10: vector.load + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-2: vector.store + # CHECK-COUNT-1: scf.yield + # CHECK-COUNT-4: amdgpu.mfma + # CHECK-COUNT-1: amdgpu.lds_barrier + # CHECK-COUNT-8: vector.load + # CHECK-COUNT-8: amdgpu.mfma + + @run_test def test_add_float(): constraints: list[tkw.Constraint] = [ diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py new file mode 100644 index 00000000..eafabb27 --- /dev/null +++ b/lit_tests/kernel/wave/scheduling.py @@ -0,0 +1,227 @@ +# RUN: python %s | FileCheck %s + +import logging +import unittest +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl +import shark_turbine.kernel.wave as tkw +from shark_turbine.kernel.wave.promotion import promote_placeholders +from shark_turbine.kernel.wave.hoisting import hoist_allocs +from shark_turbine.kernel.wave.expansion import expand_graph +from shark_turbine.kernel.lang.global_symbols import * +from shark_turbine.kernel._support.tracing import CapturedTrace +from shark_turbine.kernel._support.indexing import IndexingContext +from shark_turbine.kernel.ops.wave_ops import * +from shark_turbine.kernel.wave.utils import run_test, print_subgraph +from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from shark_turbine.kernel.wave.shared_memory_indexing import ( + apply_shared_memory_indexing_corrections, +) +from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph + + +# Input sizes +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K + +# Workgroup tile sizes +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K + +# Address space +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + +# Induction variable for dimension K +ARGK = tkl.sym.ARGK + + +@tkw.wave_trace_only() +def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], +): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=4) + + +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + with tk.gen.TestLaunchContext( + { + M: 128, + N: 256, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 2, + GLOBAL_MEMORY_UNITS: 2, + MMA_UNITS: 2, + } + ): + trace: CapturedTrace = gemm_pipelined() + IndexingContext.current().finalize() + promote_placeholders(trace, constraints) + hoist_allocs(trace) + expand_graph(trace, constraints) + minimize_global_loads(trace, constraints) + apply_shared_memory_indexing_corrections(trace, constraints) + schedule_graph(trace, constraints) + + print_subgraph(trace, "pipelined_reduction", False) + # CHECK: %acc_0_0_0 + # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 + # CHECK-NEXT: %rotating_reg_0 + # CHECK-NEXT: %rotating_reg_1 + # CHECK-NEXT: %rotating_reg_2 + # CHECK-NEXT: %rotating_reg_3 + # CHECK-NEXT: %rotating_reg_4 + # CHECK-NEXT: %rotating_reg_5 + # CHECK-NEXT: %rotating_reg_6 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6) + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0) + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1) + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0] + + print_subgraph(trace, "region_1", False) + # CHECK: %a + # CHECK-NEXT: %b + # CHECK-NEXT: %c + # CHECK-NEXT: %register_0_0_0 + # CHECK-NEXT: %register_1_1_0 + # CHECK-NEXT: %register_1_0_0 + # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %allocate + # CHECK-NEXT: %allocate_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: %read_6 + # CHECK-NEXT: %read_7 + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, %mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0) + # CHECK-NEXT: %write_4 + # CHECK-NEXT: %write_5 + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_2 + # CHECK-NEXT: %read_shared_0_1_3 + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0) + # CHECK-NEXT: %read_shared_0_0_4 + # CHECK-NEXT: %read_shared_0_0_5 + # CHECK-NEXT: %reduction_1 + # CHECK-NEXT: %getresult_1_1_0 + # CHECK-NEXT: %getresult_1_0_0 + # CHECK-NEXT: %getresult_0_1_0 + # CHECK-NEXT: %getresult_0_0_0 + # CHECK-NEXT: %get_result_4 + # CHECK-NEXT: %get_result_5 + # CHECK-NEXT: %get_result_6 + # CHECK-NEXT: %get_result_7 + # CHECK-NEXT: %get_result_8 + # CHECK-NEXT: %get_result_9 + # CHECK-NEXT: %get_result_10 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10) + # CHECK-NEXT: %read_shared_0_0_6 + # CHECK-NEXT: %read_shared_0_0_7 + # CHECK-NEXT: %read_shared_1_0_2 + # CHECK-NEXT: %read_shared_1_0_3 + # CHECK-NEXT: %mma_0_0_2 + # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0) + # CHECK-NEXT: %mma_0_1_2 + # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0) + # CHECK-NEXT: %mma_0_0_3 + # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2) + # CHECK-NEXT: %mma_1_0_2 + # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0) + # CHECK-NEXT: %mma_1_0_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2) + # CHECK-NEXT: %mma_0_1_3 + # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2) + # CHECK-NEXT: %mma_1_1_2 + # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1) + # CHECK-NEXT: %mma_1_1_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2) + # CHECK-NEXT: %write_0_0_0 + # CHECK-NEXT: %write_1_1_0 + # CHECK-NEXT: %write_1_0_0 + # CHECK-NEXT: %write_0_1_0 + # CHECK-NEXT: return None + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py index 42424257..857cdb34 100644 --- a/shark_turbine/kernel/_support/tracing.py +++ b/shark_turbine/kernel/_support/tracing.py @@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str): def get_subgraph(self, name: str) -> fx.Graph: return self.region_graph.subgraphs[name] + def add_subgraph(self, name: str, graph: fx.Graph): + self.region_graph.subgraphs[name] = graph + def get_root_graph(self) -> fx.Graph: return self.get_subgraph(self.root_graph) diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 2c38c9c2..f1292e4d 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -453,6 +453,8 @@ def index(self, value: Any): self.fx_node.index = {} for dim, key in value.items(): self.fx_node.index[dim] = key + elif isinstance(value, list): + self.fx_node.index = value else: raise ValueError("Index must be a dict") @@ -691,7 +693,7 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool: prev_node, found_src = prev_node.prev, prev_node == src if not found_src: return False - while next_node and not found_dst: + while next_node.next.op != "root" and not found_dst: next_node, found_dst = next_node.next, next_node == dst return found_dst @@ -910,6 +912,20 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]: else None ) + @index.setter + def index(self, value: Any): + CustomOp.index.fset(self, value) + + @property + def count(self) -> int: + if hasattr(self.fx_node, "count"): + return self.fx_node.count + return None + + @count.setter + def count(self, value: int): + self.fx_node.count = value + @define_op("write") @dataclass diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index aff72cf3..abb40aaf 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -249,13 +249,14 @@ def _get_const(val): induction_var_syms = [] induction_vars = [] - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint): - assert ( - constraint.dim in emitter.induction_vars - ), f"Could not find induction var for {constraint.dim} dimension" - induction_var_syms.append(constraint.induction_var) - induction_vars.append(emitter.induction_vars[constraint.dim]) + if emitter.induction_vars: + for constraint in emitter.constraints: + if isinstance(constraint, TilingConstraint): + assert ( + constraint.dim in emitter.induction_vars + ), f"Could not find induction var for {constraint.dim} dimension" + induction_var_syms.append(constraint.induction_var) + induction_vars.append(emitter.induction_vars[constraint.dim]) # TODO: factor this out all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars @@ -940,18 +941,11 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node): flat_init_args, _ = pytree.tree_flatten((init_args)) flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args] - # Without scheduling, we assume that we always start at 0. start = arith_d.constant(IndexType.get(), int(0)) - count = None - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint) and constraint.dim == axis: - count = subs_idxc(constraint.count) - assert count is not None, "Could not find tiling constraint for reduction axis." - # For now, we assume that dimensions that have tiling constraints on them, # do not have any other constraints. - end = arith_d.constant(IndexType.get(), int(count)) + end = arith_d.constant(IndexType.get(), int(node.count)) # Since we divide the end by the tile size, we need to make sure that the # step is 1. diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/shark_turbine/kernel/wave/scheduling/graph_utils.py index e625b666..af398af3 100644 --- a/shark_turbine/kernel/wave/scheduling/graph_utils.py +++ b/shark_turbine/kernel/wave/scheduling/graph_utils.py @@ -213,12 +213,13 @@ def topological_sort_nodes( Perform a topological sort on the nodes in the strongly connected component that have an edge in edges, excluding certain nodes. """ - scc_nodes = set(scc) - set(exclude) + scc_nodes = set(scc) filtered_nodes = set() for edge in edges: if edge._from in scc_nodes and edge._to in scc_nodes: filtered_nodes.add(edge._to) filtered_nodes.add(edge._from) + filtered_nodes -= set(exclude) if exclude is not None else set() sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f) return sorted_nodes diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py new file mode 100644 index 00000000..52f205b1 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -0,0 +1,556 @@ +from ..constraints import Constraint, TilingConstraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import ( + Reduction, + IterArg, + Placeholder, + Allocate, + Output, + Write, + GetResult, + get_custom, +) +from .modulo_scheduling import ModuloScheduler +from ..utils import ( + graph_copy, + erase_graph, + get_induction_variable, + replace_uses_in, +) +from ..utils import subs_idxc +import torch.fx as fx +import math +from collections import deque +from ..visualization import visualize_mapped_graphs, visualize_graph +from ....support.logging import get_logger +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import random +from typing import Optional +from .loop_reconstruction_utils import ( + ArgumentContext, + create_fill_stage_schedule, + create_drain_stage_schedule, + liveness_analysis, + partition_graph_by_stage, + interleave_instructions, +) +from enum import Enum + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction") + + +class PipelineStage(Enum): + PROLOGUE = 0 + KERNEL = 1 + EPILOGUE = 2 + + +def add_nodes_by_schedule( + reduction_graph: fx.Graph, + partitioned_graph: list[dict[int, fx.Node]], + arg_context: ArgumentContext, + stages: list[int], + initiation_interval: int, + induction_variable: IndexSymbol, + current_induction_variables: list[int], + rotating_registers: dict[fx.Node, list[fx.Node]], + pipelining_stage: PipelineStage = PipelineStage.KERNEL, +): + """ + Interleave the instructions in the partitioned graph by stage + for a single initiation interval, updating the argument maps + per stage starting at the provided start times and indices. + """ + fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE] + fill = pipelining_stage == PipelineStage.PROLOGUE + drain = pipelining_stage == PipelineStage.EPILOGUE + + for cycle in range(initiation_interval): + logger.debug(f"Cycle: {cycle}") + # Interleave the instructions that are scheduled at the same cycle. + interleaved_instructions = [] + for iteration, stage in enumerate(stages): + if stage is None: + continue + if cycle not in partitioned_graph[stage]: + continue + for node in partitioned_graph[stage][cycle]: + interleaved_instructions.append((iteration, stage, node)) + interleave_instructions(interleaved_instructions) + + for iteration, stage, node in interleaved_instructions: + logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}") + custom_node = get_custom(node) + logger.debug(f"Node args: {node.args}") + for arg in node.args: + if arg_context.contains_in_iteration(iteration, arg): + logger.debug( + f"Found arg: {arg} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}." + ) + continue + new_node = custom_node.copy( + new_graph=reduction_graph, + arg_transform=lambda x: ( + arg_context.get_from_iteration(iteration, x) + if arg_context.contains_in_iteration(iteration, x) + else x + ), + ) + # Update the argument context. + arg_context[(iteration, stage, node)] = new_node.fx_node + logger.debug( + f"Copying Node: {node}, Stage: {stage}, Iteration: {iteration} -> {new_node.fx_node}" + ) + # Set the index for the new node by substituting the induction variable + # for the current iteration. + new_node.index = node.index + for dim in new_node.index: + new_node.index[dim] = new_node.index[dim].subs( + {induction_variable: current_induction_variables[iteration]} + ) + # Add scheduling parameters for debugging. + new_node.scheduling_parameters = node.scheduling_parameters + # Update the rotating registers and argument context for the current node (if applicable). + if node in rotating_registers: + rotating_registers[node].append(new_node.fx_node) + rotating_registers[node].popleft() + # If draining, then override the rotating registers and update the argument context. + if fill_or_drain: + for next_stage in range(stage + 1, len(stages)): + arg_context[(iteration, next_stage, node)] = new_node.fx_node + + # Update the init args in the argument context whenever a result is computed. + if node in arg_context.results: + if ( + pipelining_stage == PipelineStage.KERNEL + or pipelining_stage == PipelineStage.EPILOGUE + ): + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_iter_arg[node], new_node.fx_node + ) + if pipelining_stage == PipelineStage.PROLOGUE: + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_init_arg[node], new_node.fx_node + ) + + +def push_placeholders( + implicit_captures: list[fx.Node], + reduction_subgraph: fx.Node, + arg_context: ArgumentContext, +): + """ + Push placeholders into the argument context for the reduction graph. + """ + for node in reduction_subgraph.nodes: + custom = get_custom(node) + if isinstance(custom, Placeholder) and not isinstance(custom, IterArg): + root_node = [x for x in implicit_captures if x.name == node.name][0] + assert root_node is not None + arg_context.map_arg_all(node, root_node) + + +def construct_prologue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], +): + """ + Construct the prologue of the pipelined loop. + For this, we need to copy nodes from the reduction_graph and insert them + before the reduction operator in the root graph in the appropriate order. + We also need to initialize the rotating registers and update the indices + of the nodes to use the appropriate values of the induction variable. + """ + logger.debug("=====================================") + logger.debug("Constructing prologue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + # Map iter args to init args in the prologue. + for iter_arg, init_arg in zip( + reduction.iter_args(reduction_subgraph), reduction.init_args + ): + arg_context.map_arg_all(iter_arg, init_arg) + + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + with reduction.graph.inserting_before(reduction.fx_node): + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.PROLOGUE, + ) + + # During the prologue, we may have computed results that need to be passed as init args + # to the kernel. + new_init_args: list[fx.Node] = [] + for init_arg in reduction.init_args: + mapped_init_arg = arg_context.lookup(init_arg) + if mapped_init_arg is None: + mapped_init_arg = init_arg + new_init_args.append(mapped_init_arg) + reduction.init_args = new_init_args + + +def flatten_dict_values( + rotating_registers: dict[fx.Node, list[fx.Node]] +) -> list[fx.Node]: + """ + Flatten the values of the rotating registers into a list. + """ + return [ + register for registers in rotating_registers.values() for register in registers + ] + + +def unflatten_dict_values( + rotating_registers_shapes: dict[fx.Node, int], values: list[fx.Node] +) -> dict[fx.Node, list[fx.Node]]: + """ + Unflatten the values of the rotating registers into a dictionary + using the provided shapes. + """ + rotating_registers = {} + count = 0 + for node, shape in rotating_registers_shapes.items(): + rotating_registers[node] = deque(values[count : count + shape]) + count += shape + assert count == sum(rotating_registers_shapes.values()) + return rotating_registers + + +def push_rotating_registers( + arg_context: ArgumentContext, + rotating_registers: dict[fx.Node, list[fx.Node]], + graph: fx.Graph, + node_map: dict[fx.Node, fx.Node], + create_new_nodes: bool = False, +) -> dict[fx.Node, deque[fx.Node]]: + """ + Pushes the rotating registers into the argument map + at the appropriate stages. Create new nodes in the + specified graph if requested. + + For each rotating register, + we evaluate which stage it belongs to and update the argument + context for the next stage and n - 1 stages after it, where + n is the total number of rotating registers. + If var a has [a, b, c] as rotating registers, then in a 3-stage schedule + a is used in stage 2, (iteration 0) + b in stage 1, (iteration 1) + c in stage 0. (iteration 2) + """ + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = {} + count = 0 + for node, registers in rotating_registers.items(): + new_registers: deque[fx.Node] = deque() + custom = get_custom(node) + stage = custom.scheduling_parameters["stage"] + iteration = arg_context.get_kernel_iteration(stage) + arg_context[(iteration, stage, node)] = registers[-1] + for i, register in enumerate(registers): + mapped_stage = stage + len(registers) - i + mapped_iteration = arg_context.get_kernel_iteration(mapped_stage) + if create_new_nodes: + iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg + new_registers.append(iter_arg) + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}." + ) + else: + arg_context[(mapped_iteration, mapped_stage, node)] = register + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage}." + ) + count += 1 + if new_registers: + new_rotating_registers[node] = new_registers + return new_rotating_registers + + +def construct_kernel( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, +) -> tuple[Reduction, fx.Graph]: + """ + Construct the kernel of the pipelined loop. + First, we construct a new reduction op with an empty graph. + Then, we set the init args, construct the iter args and add the ops. + Finally, we create the output node with the return values. + The iter args/results of the pipelined reduction are always: + [results0, result1, ..., resultN, rotating_reg0, rotating_reg1, ..., rotating_regN] + """ + logger.debug("=====================================") + logger.debug("Constructing kernel.") + logger.debug("=====================================") + + with reduction.graph.inserting_before(reduction.fx_node): + pipelined_reduction = Reduction( + reduction.axis, + init_args=reduction.init_args + flatten_dict_values(rotating_registers), + subgraph_name="pipelined_reduction", + implicit_captures=reduction.implicit_captures, + ).add_to_graph(reduction.graph) + pipelined_reduction.index = reduction.index + pipelined_reduction_graph = fx.Graph() + reduction.graph.subgraphs["pipelined_reduction"] = pipelined_reduction_graph + + # Update the argument map for the new reduction. + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + # For the original iter args, we just map the old ones to the new ones. + # Do this for all stages, since the original iter args are "dummy" nodes + # during scheduling. + for node in arg_context.iter_args: + iter_arg = IterArg(node.name).add_to_graph(pipelined_reduction_graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context.map_arg_all(node, iter_arg) + + # Push the rotating registers into the argument context. + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = push_rotating_registers( + arg_context, + rotating_registers, + pipelined_reduction_graph, + node_map, + create_new_nodes=True, + ) + + add_nodes_by_schedule( + pipelined_reduction_graph, + partitioned_graph, + arg_context, + list(reversed(range(scheduler.num_stages))), + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + new_rotating_registers, + PipelineStage.KERNEL, + ) + + # Create output node (last node in the graph). + return_vals: list[fx.Node] = arg_context.get_kernel_results() + for registers in new_rotating_registers.values(): + return_vals.extend(registers) + + Output(return_vals).add_to_graph(pipelined_reduction_graph) + reduction.replace_all_uses_with(pipelined_reduction) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction_graph, + new_rotating_registers, + arg_context.argument_map, + "kernel.png", + ) + + return pipelined_reduction, pipelined_reduction_graph + + +def construct_epilogue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + pipelined_reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], + num_rotating_registers: dict[fx.Node, int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, +): + """ + Construct the epilogue of the pipelined loop. + The difference from the prologue is that we need to map the results + of the pipelined reduction to the remaining stages. (In the prologue, + no iteration is every completed and so we don't compute the final results) + We emit GetResult nodes for the rotating registers and map them to + the different epilogue stages. + """ + logger.debug("=====================================") + logger.debug("Constructing epilogue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + existing_get_results: list[GetResult] = sorted( + [x for x in pipelined_reduction.users if isinstance(x, GetResult)], + key=lambda x: x.res_idx, + ) + existing_users = {x: x.users for x in existing_get_results} + + # Map the results from the kernel to the init args (for stages). + for iter_arg, get_result in zip( + reduction.iter_args(reduction_subgraph), existing_get_results + ): + arg_context.map_arg_all(iter_arg, get_result.fx_node) + + with pipelined_reduction.graph.inserting_before( + existing_get_results[0].fx_node.next + ): + # Add get result nodes for the rotating registers and update the + # argument map with them. + rotating_registers_get_results = [] + offset = len(existing_get_results) + for i in range(len(flatten_dict_values(rotating_registers))): + rotating_registers_get_results.append( + GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph( + pipelined_reduction.graph + ) + ) + rotating_registers = unflatten_dict_values( + num_rotating_registers, rotating_registers_get_results + ) + + # Push the rotating registers onto the argument map. + push_rotating_registers(arg_context, rotating_registers, None, node_map, False) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + pipelined_reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.EPILOGUE, + ) + + # Replace the existing uses with the new results. + new_results = arg_context.get_mapped_results(existing_get_results) + assert len(new_results) == len(existing_get_results) + for i, get_result in enumerate(existing_get_results): + replace_uses_in(existing_users, get_result, new_results[i]) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction.graph, + rotating_registers, + arg_context.argument_map, + "epilogue.png", + ) + + +def construct_pipelined_loop( + trace: CapturedTrace, + reduction: Reduction, + graph: fx.Graph, + constraints: list[Constraint], + scheduler: ModuloScheduler, + node_map: dict[fx.Node, fx.Node], + max_induction_variable: int, + visualize: bool = False, +) -> fx.Node: + """ + Given a graph annotated with scheduling parameters, construct a pipelined loop + with a prologue, kernel and epilogue. + """ + induction_variable = get_induction_variable(reduction, constraints) + num_rotating_registers = liveness_analysis(graph, constraints, scheduler) + rotating_registers: dict[fx.Node, deque[fx.Node]] = { + k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items() + } + partitioned_graph = partition_graph_by_stage(graph, scheduler) + # Construct prologue. + construct_prologue( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + list(range(scheduler.num_stages)), + create_fill_stage_schedule(scheduler.num_stages), + ) + # Construct kernel. + pipelined_reduction, pipelined_reduction_graph = construct_kernel( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [induction_variable + i for i in range(scheduler.num_stages)], + node_map, + visualize, + ) + trace.add_subgraph( + get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph + ) + # Construct epilogue. + construct_epilogue( + graph, + reduction, + get_custom(pipelined_reduction), + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [ + max_induction_variable - scheduler.num_stages + i + for i in range(scheduler.num_stages) + ], + create_drain_stage_schedule(scheduler.num_stages), + num_rotating_registers, + node_map, + visualize, + ) + + # Remove the unpipelined reduction. + reduction.graph.erase_node(reduction.fx_node) + + if visualize: + visualize_graph(pipelined_reduction.graph, "pipelined.png") + + return pipelined_reduction diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py new file mode 100644 index 00000000..b6993a21 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py @@ -0,0 +1,285 @@ +from ..constraints import Constraint, TilingConstraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import Reduction, IterArg, Output, Write, GetResult, get_custom +from .modulo_scheduling import ModuloScheduler +from ..utils import graph_copy, erase_graph +from ..utils import subs_idxc +import torch.fx as fx +import math +from collections import defaultdict, deque, ChainMap +from ..visualization import visualize_mapped_graphs +from ....support.logging import get_logger +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import random +from typing import Optional + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction_utils") + + +class ArgumentContext: + """ + The argument context is used to store the mapping of arguments + for each modulo pipelining stage. + """ + + def __init__( + self, + results: list[fx.Node], + iter_args: list[fx.Node], + init_args: list[fx.Node], + num_stages: int, + ) -> None: + self.argument_map: list[list[dict[fx.Node, fx.Node]]] = [ + [{} for _ in range(num_stages)] for _ in range(num_stages) + ] + self.results = results + self.iter_args = iter_args + self.init_args = init_args + self.num_stages = num_stages + self.num_iterations = num_stages + self.result_to_iter_arg: dict[fx.Node, fx.Node] = {} + self.result_to_init_arg: dict[fx.Node, fx.Node] = {} + + for result, iter_arg in zip(results, iter_args): + self.result_to_iter_arg[result] = iter_arg + for result, init_arg in zip(results, init_args): + self.result_to_init_arg[result] = init_arg + + def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + self.argument_map[iteration][stage][from_] = to_ + + def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + self.argument_map[iteration][stage][from_] = to_ + + def get_mapped_results(self, get_results: list[GetResult]) -> list[fx.Node]: + """ + Gets the mapped results from the last iteration. If the result is not + in the last iteration, then get it from the get result nodes. + """ + mapped_results = [] + for result, get_result in zip(self.results, get_results): + stage = result.scheduling_parameters["stage"] + if result not in self.argument_map[self.num_iterations - 1][stage]: + mapped_results.append(get_result.fx_node) + else: + mapped_results.append( + self.argument_map[self.num_iterations - 1][stage][result] + ) + return mapped_results + + def get_kernel_iteration(self, stage: int) -> int: + """ + Get the iteration from the stage for the kernel. + """ + return self.num_stages - 1 - stage + + def get_kernel_results(self) -> list[fx.Node]: + """ + Gets the mapped results for the kernel. Here there + exists a fixed relationship between the iteration and stage. + """ + mapped_results = [] + for result in self.results: + stage = result.scheduling_parameters["stage"] + iteration = self.get_kernel_iteration(stage) + mapped_results.append(self.argument_map[iteration][stage][result]) + return mapped_results + + def __setitem__(self, key: tuple[int, fx.Node], value: fx.Node) -> None: + """ + Sets the argument mapping for the given stage. + """ + assert isinstance(key, tuple), "Argument context key must be a tuple" + iteration, stage, from_ = key + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + self.argument_map[iteration][stage][from_] = value + + def __getitem__(self, value: tuple[int, fx.Node]) -> fx.Node: + """ + Gets the argument mapping for the given stage. + """ + assert isinstance(value, tuple), "Argument context key must be a tuple" + iteration, stage, key = value + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + return self.argument_map[iteration][stage].get(key, None) + + def __contains__(self, key: fx.Node | tuple[int, fx.Node]) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration and stage or at all iterations and stages. + """ + if isinstance(key, tuple): + iteration, stage, key = key + return key in self.argument_map[iteration][stage] + return any( + key in self.argument_map[iteration][stage] + for iteration in range(self.num_iterations) + for stage in range(self.num_stages) + ) + + def lookup(self, key: fx.Node) -> Optional[fx.Node]: + """ + Looks up the argument mapping for the given node. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration. + """ + return any( + key in self.argument_map[iteration][stage] + for stage in range(self.num_stages) + ) + + def get_from_iteration(self, iteration: int, key: fx.Node) -> fx.Node: + """ + Gets the argument mapping for the given iteration. + """ + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def dump(self): + """ + Dump the argument context to the logger. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + logger.debug(f"Iteration: {iteration}, Stage: {stage}") + for key, value in self.argument_map[iteration][stage].items(): + logger.debug(f" {key} -> {value}") + + +def create_fill_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the prologue (fill). + This looks like: + [0 None None None] + [1 0 None None] + [2 1 0 None] + """ + schedule = [] + for i in range(n - 1): + row = list(range(i, -1, -1)) + row.extend([None] * (n - i - 1)) + schedule.append(row) + return schedule + + +def create_drain_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the epilogue (drain). + This looks like: + [None 3 2 1] + [None None 3 2] + [None None None 3] + """ + schedule = [] + for i in range(n - 1): + row = [None] * (i + 1) + row.extend(range(n - 1, i, -1)) + schedule.append(row) + return schedule + + +def liveness_analysis( + graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler +) -> dict[fx.Node, int]: + """ + Perform liveness analysis on the graph to determine the live ranges of + variables and use that to deduce how many rotating registers we need. + """ + lifetime: dict[fx.Node, int] = {} + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if node not in lifetime: + lifetime[node] = 0 + for user in custom.users: + if user.scheduling_parameters is None: + continue + logger.debug( + f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}" + ) + lifetime[node] = max( + user.scheduling_parameters["stage"] + - custom.scheduling_parameters["stage"], + lifetime[node], + ) + + # Determine how many copies we need for each node. If the lifetime of a node + # is l clocks and the initiation interval is T, then only ceil(l/T) values + # of the node can be live at the same time. We need to create copies of only + # those nodes that are live at more than one stage. + num_rotating_registers: dict[fx.Node, int] = {} + for node, l in lifetime.items(): + if node in num_rotating_registers: + continue + custom = get_custom(node) + if ( + isinstance(custom, Write) + and custom.memory_type.address_space == SHARED_ADDRESS_SPACE + ): + continue + if l > 0: + num_rotating_registers[node] = l + + return num_rotating_registers + + +def partition_graph_by_stage( + graph: fx.Graph, scheduler: ModuloScheduler +) -> list[dict[int, list[fx.Node]]]: + """ + Partition the graph into stages based on the scheduling parameters. + """ + partitioned_graph: list[dict[int, list[fx.Node]]] = [ + defaultdict(list) for _ in range(scheduler.num_stages) + ] + for stage in range(scheduler.num_stages): + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if isinstance(custom, IterArg): + continue + if custom.scheduling_parameters["stage"] == stage: + cycle = custom.scheduling_parameters["cycle"] + partitioned_graph[stage][cycle].append(node) + return partitioned_graph + + +def interleave_instructions(instructions: list[tuple[int, int, fx.Node]]): + """ + Interleave the instructions that are scheduled in the same cycle. + Currently, we just randomly shuffle them, but we could also sort + them based on some criteria. + """ + rng = random.Random(0) + # rng.shuffle(instructions) diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py index f2abbd13..82940113 100644 --- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -18,6 +18,7 @@ ) from typing import Callable import numpy as np +import math logger = get_logger("turbine.wave.modulo_scheduling") @@ -263,3 +264,11 @@ def resource_reservations(self) -> np.array: Returns the resource reservations of the schedule. """ return self.RT + + @property + def num_stages(self) -> int: + """ + Returns the number of stages in the kernel of the pipelined loop. + """ + max_cycle = max([t for t in self.schedule.values()]) + return math.ceil(max_cycle / self.initiation_interval) diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/shark_turbine/kernel/wave/scheduling/schedule.py index a03ad082..e2b3a88e 100644 --- a/shark_turbine/kernel/wave/scheduling/schedule.py +++ b/shark_turbine/kernel/wave/scheduling/schedule.py @@ -11,8 +11,12 @@ from .graph_utils import create_scheduling_edges, Edge from .resources import get_available_resources, annotate_resource_usage from ..visualization import visualize_edges, visualize_graph, visualize_schedule -from ..utils import subs_idxc, graph_copy, erase_graph +from .loop_reconstruction import construct_pipelined_loop +from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc import torch.fx as fx +from ....support.logging import get_logger + +logger = get_logger("turbine.wave.scheduling.schedule") def visualize_scheduling_graph(edges: list[Edge]): @@ -68,6 +72,32 @@ def schedule_reduction( erase_graph(graph) + # After scheduling has completed, we have enough information to decide + # whether to pipeline the loop. For pipelining to be possible, we need + # to have atleast N iterations of the loop where N > num_stages - 1 (because + # we will be peeling off num_stages iterations from the loop). + tiling_constraint = get_tiling_constraint(reduction, constraints) + max_induction_variable = int( + subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size) + ) + if max_induction_variable <= scheduler.num_stages - 1: + logger.warn("Not enough iterations to pipeline the loop. Skipping pipelining.") + return {} + + new_reduction = construct_pipelined_loop( + trace, + reduction, + reduction_graph, + constraints, + scheduler, + node_map, + max_induction_variable, + visualize, + ) + + # Update new reduction count. + new_reduction.count = max_induction_variable - (scheduler.num_stages - 1) + def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): """ diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index c2d0a582..e3f7a62b 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -103,6 +103,21 @@ def print_trace(trace: CapturedTrace, custom_print: bool = True): print(get_custom(node)) +def print_subgraph(trace: CapturedTrace, subgraph_name: str, custom_print: bool = True): + """ + Prints a specific subgraphs of a trace. + The graphs are printed first in the torch printing format and + then using our custom node format. + """ + # The root graph is at the back so we print the subgraphs in reverse order + for name, subgraph in trace.region_graph.subgraphs.items(): + if name == subgraph_name: + print(subgraph) + if custom_print: + for node in subgraph.nodes: + print(get_custom(node)) + + def DCE(trace: CapturedTrace): """ Removes all operators that are not used in the graph, @@ -560,3 +575,42 @@ def find_index_bounds( return None return bounds + + +def get_induction_variable( + reduction: Reduction, constraints: list[Constraint] +) -> IndexSymbol: + induction_var = None + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + induction_var = constraint.induction_var + break + else: + raise ValueError(f"Could not find induction variable for reduction {reduction}") + return induction_var + + +def get_tiling_constraint( + reduction: Reduction, constraints: list[Constraint] +) -> TilingConstraint: + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + return constraint + else: + raise ValueError(f"Could not find tiling constraint for reduction {reduction}") + + +def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node): + """ + Replace all uses of `old` with `new` in the list of users. + """ + for user in users[old]: + for i, arg in enumerate(user.fx_node.args): + if arg == old.fx_node: + user.update_arg(i, new) diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py index 924c36bd..d6438bfc 100644 --- a/shark_turbine/kernel/wave/visualization.py +++ b/shark_turbine/kernel/wave/visualization.py @@ -11,6 +11,8 @@ graphviz_disabled = True from torch import fx from .scheduling.graph_utils import Edge +from ..ops.wave_ops import Output, Placeholder, IterArg, get_custom +from collections import ChainMap import math @@ -27,6 +29,9 @@ def visualize_graph(graph: fx.Graph, file_name: str): G.add_node(node_numbering[id(node)], label=node.name) for node in graph.nodes: for user in node.users.keys(): + # Handle scenario where nodes are shared across graphs. + if user not in graph.nodes: + continue G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) G.layout(prog="dot") G.draw(file_name) @@ -71,7 +76,7 @@ def visualize_schedule( for key, value in schedule.items(): table[value + stage * initiation_interval][stage] += f"{key}
" - df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)]) + df = pd.DataFrame(table, columns=[f"Iteration {i}" for i in range(cols)]) s = df.style.set_properties(**{"text-align": "center"}) s = s.set_table_styles( [ @@ -95,3 +100,91 @@ def visualize_schedule( ).to_html() with open(f"{file_name}", "w") as f: f.write(output) + + +def visualize_mapped_graphs( + second: fx.Graph, + rotating_registers: dict[fx.Node, list[fx.Node]], + mappings: list[list[dict[fx.Node, fx.Node]]], + file_name: str, +): + """ + Given the pipelined graph and a list of mappings of nodes from the original + graph to the pipelined graph (per stage), visualize the pipelined graph (with their original labels) + + """ + + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + second_numbering = number_nodes(second) + + flat_inverse_map: dict[fx.Node, fx.Node] = {} + flat_map: dict[fx.Node, fx.Node] = {} + for iteration_mapping in mappings: + for mapping in iteration_mapping: + flat_inverse_map.update({v: k for k, v in mapping.items()}) + flat_map.update(mapping) + flat_inverse_map = ChainMap(flat_inverse_map) + flat_map = ChainMap(flat_map) + + # Draw nodes and edges in the pipelined graph. + G = pgv.AGraph(directed=True) + G0 = G.add_subgraph(name="pipelined") + stage: dict[fx.Node, int] = {} + for node in second.nodes: + if hasattr(node, "scheduling_parameters"): + if node in flat_inverse_map: + name = flat_inverse_map[node].name + else: + name = node.name + else: + name = node.name + G0.add_node( + second_numbering[id(node)], + label=name, + color="lightblue", + style="filled", + ) + for user in node.users.keys(): + if user not in second.nodes: + continue + if isinstance(get_custom(user), Output): + continue + G0.add_edge( + second_numbering[id(node)], + second_numbering[id(user)], + color="black", + ) + + # Draw nodes and edges in the original graph. + colors = ["red", "green", "orange", "purple", "orange", "cyan", "magenta"] + max_stage = len(mappings) + for node, mapped_node in flat_map.items(): + for user in node.users.keys(): + if user not in flat_map: + continue + mapped_user = flat_map[user] + if mapped_user not in second.nodes or mapped_node not in second.nodes: + continue + stage = "" + if hasattr(user, "scheduling_parameters"): + stage = user.scheduling_parameters["stage"] + G.add_edge( + second_numbering[id(mapped_node)], + second_numbering[id(mapped_user)], + label=f"{stage}", + color=colors[stage % max_stage], + ) + + # Draw edges between rotating registers for the same variable. + for node in rotating_registers: + all_registers = [k for k, v in flat_inverse_map.items() if v == node] + for second, first in zip(all_registers[:-1], all_registers[1:]): + G.add_edge( + second_numbering[id(first)], + second_numbering[id(second)], + color="blue", + ) + + G.layout(prog="dot") + G.draw(file_name) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index 4d19d99f..323a3198 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -28,6 +28,7 @@ compile_and_invoke, safe_subs, remove_chained_getresult, + subs_idxc, ) from .minimize_global_loads import minimize_global_loads from .decompose_reduce_ops import decompose_reduce_ops @@ -184,6 +185,18 @@ def initialize_wave_constraints(self, trace: CapturedTrace) -> None: / hardware_constraint.threads_per_wave ) + def initialize_reductions(self, trace: CapturedTrace) -> None: + """ + For each reduction, initializes the reduction count by looking at the + tiling constraints associated with the reduction. + + """ + is_reduction = lambda node: isinstance(get_custom(node), Reduction) + for reduction in trace.walk(is_reduction): + for tiling_constraint in self.tiling_constraints: + if tiling_constraint.dim == get_custom(reduction).axis: + reduction.count = subs_idxc(tiling_constraint.count) + def _trace_and_get_kernel_signature( self, args, @@ -196,6 +209,7 @@ def _trace_and_get_kernel_signature( self.create_induction_vars(graph) self.initialize_wave_constraints(graph) + self.initialize_reductions(graph) idxc = IndexingContext.current() idxc.finalize() diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 344032a4..2386ebd9 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -15,6 +15,7 @@ from shark_turbine.kernel.wave.iree_utils import generate_iree_ref import os import json +from torch.testing import assert_close _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") @@ -40,7 +41,8 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) -def testGemm(shape: tuple[int]): +@pytest.mark.parametrize("enable_scheduling", [False, True]) +def testGemm(shape: tuple[int], enable_scheduling: bool): # Input sizes M = tkl.sym.M @@ -106,10 +108,22 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: M: shape[0], N: shape[1], K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} with tk.gen.TestLaunchContext( - hyperparams, canonicalize=True, run=True, run_config=config + hyperparams, + canonicalize=True, + run=True, + run_config=config, + schedule=enable_scheduling, ): a = torch.randn(shape[0], shape[2], dtype=torch.float16) b = torch.randn(shape[1], shape[2], dtype=torch.float16) @@ -123,9 +137,4 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32) generate_iree_ref("mmt", [a, b], [iree_ref], config) - assert torch.equal(c, iree_ref) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + assert_close(c, iree_ref) From 553e929459ba84773468cce2eb53622dd46497ab Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Tue, 1 Oct 2024 16:27:27 -0700 Subject: [PATCH 13/28] Add support for dynamic dims (#178) This PR adds support for dynamic dimensions in the kernels. The user specifies the dynamic dimensions by - Not adding them to the hyperparameter dictionary - Explicitly specifying them with the dynamic_symbols kwarg and the dynamic_symbols_mapping kwarg to specify which values to use for the dynamic dims at runtime This PR does not modify the codegen and so incorrect or unsupported values for the dynamic dims will result in incorrect results. (garbage in -> garbage out) --------- Signed-off-by: Harsh Menon --- lit_tests/kernel/wave/codegen.py | 94 ++++++++++++++++--- .../kernel/compiler/dispatch_codegen.py | 75 ++++++++++----- shark_turbine/kernel/compiler/host_codegen.py | 40 +++++++- .../kernel/compiler/kernel_codegen.py | 16 ++++ shark_turbine/kernel/wave/codegen.py | 68 +++++++++----- shark_turbine/kernel/wave/utils.py | 12 ++- shark_turbine/kernel/wave/wave.py | 22 ++++- tests/kernel/wave/wave_e2e_test.py | 58 +++++++++++- 8 files changed, 318 insertions(+), 67 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index a7781a39..4c732079 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -21,18 +21,24 @@ ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 -def codegen_test_context(canonicalize: bool = False): +def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]): + bindings = { + M: 16, + N: 16, + K: 16, + BLOCK_M: 16, + BLOCK_N: 16, + BLOCK_K: 16, + ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, + } + + # Remove dynamic symbols from the bindings. + for sym in dynamic_symbols: + if sym in bindings: + del bindings[sym] + return tk.gen.TestLaunchContext( - { - M: 16, - N: 16, - K: 16, - BLOCK_M: 16, - BLOCK_N: 16, - BLOCK_K: 16, - ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value, - }, - canonicalize=canonicalize, + bindings, canonicalize=canonicalize, dynamic_symbols=dynamic_symbols ) @@ -328,6 +334,72 @@ def test( # CHECK-SAME: strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16> +@run_test +def test_dynamic_copy(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16} + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]): + b = tkw.read(a, elements_per_thread=16) + tkw.write(b, a, elements_per_thread=16) + + with codegen_test_context(canonicalize=True, dynamic_symbols=[M, N]): + a = torch.randn(16, 16, dtype=torch.float16) + print(test(a).module_op) + + # CHECK: stream.executable.export public @test workgroups(%[[ARG0:.*]]: index, %[[ARG1:.*]]: + # CHECK-SAME: index) -> (index, index, index) { + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + # CHECK: %[[D0:.+]] = arith.ceildivsi %[[ARG0]], %[[C16]] : index + # CHECK: %[[D1:.+]] = arith.ceildivsi %[[ARG1]], %[[C16]] : index + # CHECK: stream.return %[[D0]], %[[D1]], %[[C1]] : index, index, index + # CHECK: } + # CHECK: func.func @test(%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) + # CHECK-SAME: attributes {translation_info = #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<16xf16> + # CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : + # CHECK-SAME: vector<16xindex> + # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + # CHECK-DAG: %[[C16]] = arith.constant 16 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y + # CHECK: %[[D0]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref{%[[ARG1]], + # CHECK-SAME: %[[ARG2]]} + # CHECK: %[[D1]] = arith.muli %[[WORKGROUP_ID_0]], %[[C16]] : index + # CHECK: %[[D2:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D3:.+]] = arith.muli %[[D2]], %[[C16]] : index + # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D1]] : index + # CHECK: %[[D5:.+]] = arith.addi %[[D4]], %[[THREAD_ID_X]] : index + # CHECK: %[[D6:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C16]] : index + # CHECK: %[[D7:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] : index + # CHECK: %[[D8:.+]] = arith.addi %[[D7]], %[[D6]] : index + # CHECK: %[[D9:.+]] = vector.splat %[[D8]] : vector<16xindex> + # CHECK: %[[D10:.+]] = arith.addi %[[D9]], %[[CST_0]] : vector<16xindex> + # CHECK: %[[D11:.+]] = vector.splat %[[ARG2]] : vector<16xindex> + # CHECK: %[[D12:.+]] = arith.cmpi slt, %[[D10]], %[[D11]] : vector<16xindex> + # CHECK: %[[D13:.+]] = arith.cmpi slt, %[[D5]], %[[ARG1]] : index + # CHECK: %[[D14:.+]] = vector.splat %[[D13]] : vector<16xi1> + # CHECK: %[[D15:.+]] = arith.andi %[[D12]], %[[D14]] : vector<16xi1> + # CHECK: %[[D16:.+]] = vector.maskedload %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[CST]] : memref, + # CHECK-SAME: vector<16xi1>, vector<16xf16> into vector<16xf16> + # CHECK: vector.maskedstore %[[D0]][%[[D5]], %[[D8]]], %[[D15]], %[[D16]] : memref, vector<16xi1>, + # CHECK-SAME: vector<16xf16> + # CHECK: return + + @run_test def test_mma(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] diff --git a/shark_turbine/kernel/compiler/dispatch_codegen.py b/shark_turbine/kernel/compiler/dispatch_codegen.py index 0fccf39c..32dab88c 100644 --- a/shark_turbine/kernel/compiler/dispatch_codegen.py +++ b/shark_turbine/kernel/compiler/dispatch_codegen.py @@ -7,9 +7,7 @@ from typing import Any, Callable, Optional, Type -from .._support.indexing import ( - IndexingContext, -) +from .._support.indexing import IndexingContext, IndexSymbol, IndexExpr from .base import ( CodegenError, @@ -99,6 +97,7 @@ def define_entrypoint( grid: Grid, workgroup_size: list[int] = None, subgroup_size: int = None, + dynamic_symbols: list[IndexSymbol] = [], ) -> "DispatchEntrypoint": """Defines a dispatch function with a signature like: @@ -119,7 +118,6 @@ def define_entrypoint( The given name is not uniqued (must be unique as given by the caller). """ kb_input_bindings = sig.kernel_buffer_input_bindings - kb_temp_bindings = sig.kernel_buffer_temporary_bindings kb_output_bindings = sig.kernel_buffer_output_bindings # TODO: The way we are doing grid bindings is wrong. The Grid type # should be paramerized with special grid axis symbols which are @@ -127,18 +125,17 @@ def define_entrypoint( # just assuming that the grid dims can be resolved to constants , when # in reality, we should pass the workload and parameterize the grid # dims on the workloads. - workload_axis_bindings = [] + dynamic_dim_bindings = sig.dynamic_dim_bindings # Input bindings are always user specified. - # Grid/workgroup bindings are in the inputs section but are implied. - # Temp bindings are a special kind of output bindings. # Output bindings are the real outputs. - linear_bindings = ( - kb_input_bindings - + workload_axis_bindings - + kb_temp_bindings - + kb_output_bindings - ) + # Dynamic dim bindings are the dynamic dims of the input and output tensors. + linear_bindings = kb_input_bindings + dynamic_dim_bindings + kb_output_bindings + + dynamic_dim_indices = { + "begin": len(kb_input_bindings), + "end": len(linear_bindings) - len(kb_output_bindings), + } # TODO: This is sloppy. This assert will hit on some user errors for # unsupported type combinations and is just a last resort right now. @@ -177,7 +174,7 @@ def abi_type(binding: BindingDesc): with InsertionPoint.at_block_begin(self._exe_block): export_op = stream_d.ExecutableExportOp(name, name) export_block = export_op.workgroup_count.blocks.append( - *([b.as_mlir_type() for b in workload_axis_bindings]) + *([b.as_mlir_type() for b in dynamic_dim_bindings]) ) workgroup_builder = WorkgroupBuilder( @@ -185,12 +182,30 @@ def abi_type(binding: BindingDesc): ) # TODO: Support passing workload to the dispatch function. + from ..wave.codegen import gen_sympy_index + + # Map dynamic symbols to block arguments. + dynamic_symbols_mapping = { + k: v + for k, v in zip( + dynamic_symbols, workgroup_builder.entry_block.arguments + ) + } + with InsertionPoint(workgroup_builder.entry_block): result_type = IndexType.get() - workgroup_values = [ - arith_d.constant(result_type, IntegerAttr.get(result_type, dim)) - for dim in grid.dims - ] + workgroup_values = [] + for dim in grid.dims: + if isinstance(dim, IndexExpr): + workgroup_values.append( + gen_sympy_index(dynamic_symbols_mapping, dim) + ) + else: + workgroup_values.append( + arith_d.constant( + result_type, IntegerAttr.get(result_type, dim) + ) + ) while len(workgroup_values) < 3: workgroup_values.append( @@ -198,7 +213,20 @@ def abi_type(binding: BindingDesc): ) workgroup_builder.terminate(workgroup_values) - return DispatchEntrypoint(sig, def_func_block, linear_bindings) + # Map dynamic symbols to func arguments for dispatch entrypoint. + dynamic_symbols_mapping = { + k: v + for k, v in zip( + dynamic_symbols, + def_func_args[ + dynamic_dim_indices["begin"] : dynamic_dim_indices["end"] + ], + ) + } + + return DispatchEntrypoint( + sig, def_func_block, linear_bindings, dynamic_symbols_mapping + ) class WorkgroupBuilder: @@ -231,8 +259,10 @@ def __init__( sig: KernelSignature, entry_block: Block, linear_bindings: list[BindingDesc], + dynamic_symbols_mapping: dict[IndexSymbol, Value], ): super().__init__(sig, entry_block) + self.dynamic_symbols_mapping = dynamic_symbols_mapping self._abi_value_by_reference: dict[tuple[str, Any], Value] = { b.reference: value for value, b in zip(entry_block.arguments, linear_bindings) @@ -250,12 +280,15 @@ def resolve(self, binding: BindingDesc) -> Value: result_type = IndexType.get() zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0)) linear_arg_value = self._abi_value_by_reference[binding.reference] - # TODO: Need to also look up dynamic symbol values. return stream_d.binding_subspan( binding.as_mlir_type(), linear_arg_value, byte_offset=zero_value, - dynamic_dims=[], + dynamic_dims=[ + self.dynamic_symbols_mapping[dim] + for dim in binding.kernel_buffer_type.symbolic_shape + if dim in self.dynamic_symbols_mapping + ], ) raise ValidationError(f"Unhandled binding type: {binding}") diff --git a/shark_turbine/kernel/compiler/host_codegen.py b/shark_turbine/kernel/compiler/host_codegen.py index 9225d831..d74af490 100644 --- a/shark_turbine/kernel/compiler/host_codegen.py +++ b/shark_turbine/kernel/compiler/host_codegen.py @@ -8,6 +8,7 @@ from .ir import ( Block, FunctionType, + IndexType, InsertionPoint, IrType, Location, @@ -19,6 +20,9 @@ func_d, ) +from .._support.indexing import IndexSymbol +from .kernel_codegen import BindingDesc + def memref_to_tensor(memrefs: list[IrType]): tensors = [] @@ -29,22 +33,47 @@ def memref_to_tensor(memrefs: list[IrType]): return tensors +def get_dynamic_dims(bindings: list[BindingDesc], dynamic_symbols: list[IndexSymbol]): + dynamic_dims: list[IndexSymbol] = [] + for b in bindings: + for dim in b.kernel_buffer_type.symbolic_shape: + if dim in dynamic_symbols: + dynamic_dims.append(dim) + return dynamic_dims + + def isolated_test_call( - mb: ModuleBuilder, exe: StreamExecutable, sig: KernelSignature, entrypoint: str + mb: ModuleBuilder, + exe: StreamExecutable, + sig: KernelSignature, + entrypoint: str, + dynamic_symbols: list[IndexSymbol] = [], ): with InsertionPoint(mb.body_block), Location.unknown(): input_types = [b.as_mlir_type() for b in sig.kernel_buffer_input_bindings] input_tensors = memref_to_tensor(input_types) + argument_dims = get_dynamic_dims( + sig.kernel_buffer_input_bindings, dynamic_symbols + ) + input_tensors += [IndexType.get() for _ in argument_dims] + output_types = [b.as_mlir_type() for b in sig.kernel_buffer_output_bindings] output_tensors = memref_to_tensor(output_types) + result_dims = get_dynamic_dims( + sig.kernel_buffer_output_bindings, dynamic_symbols + ) ftype = FunctionType.get(input_tensors, output_tensors) func_op = func_d.FuncOp("isolated_benchmark", ftype) arg_locs = [ (Location.name(b.name) if b.name is not None else Location.unknown()) - for b in sig.kernel_buffer_input_bindings + for b in sig.kernel_buffer_input_bindings + sig.dynamic_dim_bindings ] entry_block = func_op.add_entry_block(arg_locs) + offset = len(sig.kernel_buffer_input_bindings) + dynamic_argument_map = { + k: v for k, v in zip(dynamic_symbols, entry_block.arguments[offset:]) + } with InsertionPoint(entry_block): assert isinstance(entry_block, Block) # Create a flow.dispatch op to the kernel @@ -52,7 +81,12 @@ def isolated_test_call( entrypoints = ArrayAttr.get([dispatch]) out = flow_d.DispatchOp( - output_tensors, [], entrypoints, entry_block.arguments, [], [] + output_tensors, + [dynamic_argument_map[dim] for dim in dynamic_symbols], + entrypoints, + entry_block.arguments, + [dynamic_argument_map[dim] for dim in argument_dims], + [dynamic_argument_map[dim] for dim in result_dims], ) func_d.ReturnOp(out) diff --git a/shark_turbine/kernel/compiler/kernel_codegen.py b/shark_turbine/kernel/compiler/kernel_codegen.py index 0069630c..0ca1fa5a 100644 --- a/shark_turbine/kernel/compiler/kernel_codegen.py +++ b/shark_turbine/kernel/compiler/kernel_codegen.py @@ -177,6 +177,22 @@ def kernel_buffer_temporary_bindings(self) -> list[BindingDesc]: and b.kernel_buffer_type.usage == KernelBufferUsage.TEMPORARY ] + @property + def dynamic_dim_bindings(self) -> list[BindingDesc]: + """Gets all dynamic dimension bindings.""" + return [b for b in self.bindings if b.binding_type == BindingType.SYMBOL_VALUE] + + def add_from_dynamic_symbols(self, dynamic_symbols: list[IndexSymbol]): + for symbol in dynamic_symbols: + self.bindings.append( + BindingDesc( + ("symbol", symbol), + BindingType.SYMBOL_VALUE, + name=symbol.name, + symbol_type=symbol, + ) + ) + def add_from_graph_placeholders(self, graph: fx.Graph): # Extract all placeholder nodes. placeholder_nodes = filter_fx_graph(graph, is_placeholder) diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index abb40aaf..3587e0f3 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -91,6 +91,7 @@ class WaveEmitter: root_sig: BoundKernelSignature trace: CapturedTrace constraints: list[Constraint] + dynamic_symbols: list[IndexSymbol] ip: InsertionPoint = None OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {} _node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {} @@ -110,6 +111,11 @@ def emit_program_invariants(self): gpu_d.thread_id(gpu_d.Dimension.z), ] self.induction_vars: dict[IndexSymbol, Value] = {} + self.dynamic_dims: dict[IndexSymbol, Value] = {} + symbol_iterator = iter(self.dynamic_symbols) + for arg in self.root_sig.entry_block.arguments: + if arg.type == IndexType.get(): + self.dynamic_dims[next(symbol_iterator)] = arg def emit(self, graph: Optional[fx.Graph] = None): with self.ip, Location.unknown(): @@ -169,7 +175,32 @@ def get_type_or_element_type(operand_type: IrType): return operand_type -def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult: +def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]: + induction_var_syms = [] + induction_vars = [] + if emitter.induction_vars: + for constraint in emitter.constraints: + if isinstance(constraint, TilingConstraint): + assert ( + constraint.dim in emitter.induction_vars + ), f"Could not find induction var for {constraint.dim} dimension" + induction_var_syms.append(constraint.induction_var) + induction_vars.append(emitter.induction_vars[constraint.dim]) + + # TODO: factor this out + all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars + dynamics = dict( + zip( + [THREAD_0, THREAD_1, THREAD_2, WORKGROUP_0, WORKGROUP_1, WORKGROUP_2] + + induction_var_syms, + all_symbols, + ) + ) + dynamics.update(emitter.dynamic_dims) + return dynamics + + +def gen_sympy_index(dynamics: dict[IndexSymbol, Any], expr: sympy.Expr) -> OpResult: stack: list[OpResult] = [] def _broadcast(a, b): @@ -247,27 +278,6 @@ def _get_const(val): raise CodegenError(f"Unsupported const val {val} : {type(val)}") - induction_var_syms = [] - induction_vars = [] - if emitter.induction_vars: - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint): - assert ( - constraint.dim in emitter.induction_vars - ), f"Could not find induction var for {constraint.dim} dimension" - induction_var_syms.append(constraint.induction_var) - induction_vars.append(emitter.induction_vars[constraint.dim]) - - # TODO: factor this out - all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars - dynamics = dict( - zip( - [THREAD_0, THREAD_1, THREAD_2, WORKGROUP_0, WORKGROUP_1, WORKGROUP_2] - + induction_var_syms, - all_symbols, - ) - ) - idxc = IndexingContext.current() # Substitute in frozen vars to simplify expression. if not isinstance(expr, sympy.Expr): @@ -325,6 +335,11 @@ def _get_const(val): lhs = stack.pop() res = arith_d.andi(*_broadcast(lhs, rhs)) stack.append(res) + case sympy.ceiling(): + value = stack.pop() + if not isinstance(value, arith_d.DivSIOp): + raise CodegenError(f"Cannot handle ceil({value}) yet") + stack.append(arith_d.CeilDivSIOp(value.lhs, value.rhs)) case sympy.UnevaluatedExpr(): continue case _: @@ -412,7 +427,10 @@ def _get_start_indices( def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr] ) -> list[OpResult]: - return [gen_sympy_index(emitter, i) for i in _get_start_indices(src_indices)] + return [ + gen_sympy_index(add_emitter_subs(emitter), i) + for i in _get_start_indices(src_indices) + ] def _compute_offset(indices: list[IndexExpr], strides: list[IndexExpr]) -> IndexExpr: @@ -456,7 +474,7 @@ def _build_mask( mask_expr = functools.reduce( lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds) ) - mask = gen_sympy_index(emitter, mask_expr) + mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) mask_vec_type = VectorType.get([elements_per_thread], IntegerType.get_signless(1)) if mask.type != mask_vec_type: @@ -534,7 +552,7 @@ def _construct_gather_scatter_indices( # arith ops and then `vector.insertelement` them into offsets vec. offset = int(offset) else: - dyn_offset = gen_sympy_index(emitter, offset) + dyn_offset = gen_sympy_index(add_emitter_subs(emitter), offset) dynamic_offsets.append((i, dyn_offset)) offset = 0 diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index e3f7a62b..278666b0 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -10,6 +10,7 @@ Operation, transform_d, UnitAttr, + Value, ) from typing import Optional, Callable, Any, List, Tuple from .._support.tracing import CapturedTrace @@ -267,9 +268,14 @@ def _invoke(vm_context, device, entry_function, inputs, outputs): ret_list = rt.VmVariantList(len(outputs)) for input in inputs: - input_cpu = input.cpu().contiguous() - device_array = rt.asdevicearray(device, input_cpu) - arg_list.push_ref(device_array._buffer_view) + if isinstance(input, torch.Tensor): + input_cpu = input.cpu().contiguous() + device_array = rt.asdevicearray(device, input_cpu) + arg_list.push_ref(device_array._buffer_view) + elif isinstance(input, int): + arg_list.push_int(input) + else: + raise ValueError(f"Unsupported input type: {type(input)}") vm_context.invoke(entry_function, arg_list, ret_list) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index 323a3198..fde0c792 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -260,6 +260,8 @@ def _trace_and_get_kernel_signature( root_graph = graph.get_root_graph() kernel_sig = kernel_codegen.KernelSignature() kernel_sig.add_from_graph_placeholders(root_graph) + dynamic_symbols = kwargs.get("dynamic_symbols", []) + kernel_sig.add_from_dynamic_symbols(dynamic_symbols) kernel_sig.add_grid(self.grid_type) kernel_sig.determine_input_output_buffers(root_graph) @@ -269,10 +271,17 @@ def _trace_and_get_kernel_signature( workgroup_size = self.hardware_constraints[0].threads_per_block subgroup_size = self.hardware_constraints[0].threads_per_wave dispatch_entrypoint = exe.define_entrypoint( - entrypoint_name, kernel_sig, grid, workgroup_size, subgroup_size + entrypoint_name, + kernel_sig, + grid, + workgroup_size, + subgroup_size, + dynamic_symbols, ) - emitter = WaveEmitter(dispatch_entrypoint, graph, self.constraints) + emitter = WaveEmitter( + dispatch_entrypoint, graph, self.constraints, dynamic_symbols + ) emitter.emit(graph.get_root_graph()) emitter.finish() @@ -294,7 +303,10 @@ def test_execute(self, args, kwargs): run_bench = kwargs.get("run_bench", False) if run or run_bench: # TODO: cache compiled code - host_codegen.isolated_test_call(mb, exe, kernel_sig, entrypoint_name) + dynamic_symbols = kwargs.get("dynamic_symbols", []) + host_codegen.isolated_test_call( + mb, exe, kernel_sig, entrypoint_name, dynamic_symbols + ) asm = mb.module_op.get_asm() kernel_inputs = [] @@ -307,6 +319,10 @@ def test_execute(self, args, kwargs): if usage == kernel_codegen.KernelBufferUsage.OUTPUT: kernel_outputs.append(arg) + dynamic_symbols_map = kwargs.get("dynamic_symbols_map", {}) + if dynamic_symbols: + kernel_inputs += [dynamic_symbols_map[sym] for sym in dynamic_symbols] + config = kwargs.get("run_config", None) if not config: raise ValueError("no config provided") diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index dbe88424..5b6aa640 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -66,7 +66,8 @@ def test_copy(shape): # elements. wave_size = 64 BLOCK_M = 1 - BLOCK_N = sympy.Max(sympy.Min(N, 256), wave_size) + # Tile size cannot be dynamic, so we use a fixed value here. + BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) ELEMS_PER_THREAD = BLOCK_N / wave_size constraints: list[tkw.Constraint] = [ @@ -107,6 +108,61 @@ def test( assert_allclose(a, b) +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_copy")) +def test_dynamic_copy(shape): + M = tkl.sym.M + N = tkl.sym.N + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Each workgroup works on single row of input data, and rows are further + # split into blocks of size up to 256. We have single wave per WG, + # and with default wave size of 64, each thread is operating on up to 4 + # elements. + wave_size = 64 + BLOCK_M = 1 + # Tile size cannot be dynamic, so we use a fixed value here. + BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size) + ELEMS_PER_THREAD = BLOCK_N / wave_size + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=wave_size, + waves_per_block=(1, 1, 1), + vector_shapes={M: BLOCK_M, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + a = torch.randn(shape, dtype=torch.float16) + b = torch.zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + dynamic_symbols=(M, N), + dynamic_symbols_map={M: shape[0], N: shape[1]}, + canonicalize=True, + run=True, + run_config=config, + ): + test(a, b) + assert_allclose(a, b) + + @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_transpose_read")) def test_transpose_read(shape): From 9ed388a6494b2fe27e25436528379013ea16e7c6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 3 Oct 2024 18:28:28 +0300 Subject: [PATCH 14/28] [TKW] Fix sympy expr lowering and add some more igemm test shapes (#184) * Rework how we are lowering `rational` sympy expressions, instead of delayed materialization via lambdas introduce `_Rational` type and propagate `numerator/denominator` values independently. Division will only be materialized on explicit `sympy.floor/ceiling` op. * Rework how igemm test cases are generated and introduce few real shapes. * Use custom pytest markers to separate perf/non-perf tests --------- Signed-off-by: Ivan Butygin --- shark_turbine/kernel/wave/codegen.py | 202 +++++++++++++++++---------- tests/conftest.py | 31 ++++ tests/kernel/wave/wave_e2e_test.py | 192 ++++++++++++++++++++++++- 3 files changed, 346 insertions(+), 79 deletions(-) create mode 100644 tests/conftest.py diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index 3587e0f3..adcd69b8 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -12,6 +12,7 @@ from dataclasses import dataclass import torch.fx as fx import torch.utils._pytree as pytree +from collections import namedtuple from ..compiler.ir import ( Attribute, @@ -200,76 +201,139 @@ def add_emitter_subs(emitter: WaveEmitter) -> dict[IndexSymbol, Any]: return dynamics +_Rational = namedtuple("_Rational", ["numerator", "denominator"]) + + def gen_sympy_index(dynamics: dict[IndexSymbol, Any], expr: sympy.Expr) -> OpResult: stack: list[OpResult] = [] - def _broadcast(a, b): - if not isinstance(a, (Value, OpResult)): - a = a.result + def _get_ir_value(arg): + if not isinstance(arg, (Value, OpResult)): + arg = arg.result + + return arg - if not isinstance(b, (Value, OpResult)): - b = b.result + def _check_vec_scalar(a, b): + return isinstance(a.type, VectorType) and a.type.element_type == b.type + + def _broadcast(a, b): + a = _get_ir_value(a) + b = _get_ir_value(b) if a.type == b.type: return a, b - if isinstance(a.type, VectorType) and isinstance( - b.type, (IndexType, IntegerType) - ): - assert a.type.element_type == b.type + if _check_vec_scalar(a, b): b = vector_d.splat(a.type, b) return a, b - if isinstance(a.type, (IndexType, IntegerType)) and isinstance( - b.type, VectorType - ): - assert b.type.element_type == a.type + if _check_vec_scalar(b, a): a = vector_d.splat(b.type, a) return a, b raise CodegenError(f"Cannot broadcast {a.type} and {b.type}") - def _process_mul_add_ops(term, is_mul): - args = [] - callables = [] - for _ in range(len(term.args)): - val = stack.pop() - if callable(val): - callables.append(val) - else: - args.append(val) - operation = None - for arg in args: - if operation is None: - operation = arg - continue + def get_const_val(arg): + if isinstance(arg, OpResult): + arg = arg.owner.opview - if is_mul: - operation = arith_d.MulIOp(*_broadcast(operation, arg)) - else: - operation = arith_d.AddIOp(*_broadcast(operation, arg)) + if isinstance(arg, arith_d.ConstantOp): + value = arg.attributes["value"] + if isinstance(value, IntegerAttr): + return int(value) - for arg in callables: - operation = arg(operation, is_mul) + return None - stack.append(operation) + def muli_fold(lhs, rhs): + if get_const_val(lhs) == 1: + return rhs + + if get_const_val(rhs) == 1: + return lhs + + return arith_d.muli(lhs, rhs) + + # `x + (a/b)` transformed into `(x*b + a) / b` + def _add(lhs, rhs): + is_rational_lhs = isinstance(lhs, _Rational) + is_rational_rhs = isinstance(rhs, _Rational) + if is_rational_lhs and not is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.denominator, rhs)) + numerator = arith_d.addi(*_broadcast(numerator, lhs.numerator)) + return _Rational(numerator, lhs.denominator) + elif not is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs, rhs.denominator)) + numerator = arith_d.addi(*_broadcast(numerator, rhs.numerator)) + return _Rational(numerator, rhs.denominator) + elif is_rational_lhs and is_rational_rhs: + lhs_numerator = muli_fold(*_broadcast(lhs.numerator, rhs.denominator)) + rhs_numerator = muli_fold(*_broadcast(rhs.numerator, lhs.denominator)) + numerator = arith_d.addi(*_broadcast(lhs_numerator, rhs_numerator)) + denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator)) + return _Rational(numerator, denominator) + else: + return arith_d.addi(*_broadcast(lhs, rhs)) + + # `x * (a/b)` transformed into `(x * a) / b` + def _mul(lhs, rhs): + is_rational_lhs = isinstance(lhs, _Rational) + is_rational_rhs = isinstance(rhs, _Rational) + if is_rational_lhs and not is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.numerator, rhs)) + return _Rational(numerator, lhs.denominator) + elif not is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs, rhs.numerator)) + return _Rational(numerator, rhs.denominator) + elif is_rational_lhs and is_rational_rhs: + numerator = muli_fold(*_broadcast(lhs.numerator, rhs.numerator)) + denominator = muli_fold(*_broadcast(lhs.denominator, rhs.denominator)) + return _Rational(numerator, denominator) + else: + return muli_fold(*_broadcast(lhs, rhs)) - def _get_mul(numerator): - return lambda x: arith_d.MulIOp(*_broadcast(x, numerator)) + def _floor(value): + if isinstance(value, _Rational): + value = arith_d.divsi(*_broadcast(value.numerator, value.denominator)) - def _get_add(numerator, denominator): - return lambda x: arith_d.AddIOp( - *_broadcast(arith_d.MulIOp(*_broadcast(x, denominator)), numerator) - ) + return value - def _get_div(mul, add, denominator): - return lambda x, is_mul: arith_d.DivSIOp( - *_broadcast(mul(x) if is_mul else add(x), denominator) - ) + def _ceiling(value): + if isinstance(value, _Rational): + value = arith_d.ceildivsi(*_broadcast(value.numerator, value.denominator)) + + return value + + def _group_rationals(stack, count): + """Group rationals and non-rationals args into 2 contiguous sets. + + This allows to mul/add all non-rationals first, reducing total number of ops. + """ + rationals = [] + non_rationals = [] + for _ in range(count): + val = stack.pop() + if isinstance(val, _Rational): + rationals.append(val) + else: + non_rationals.append(val) + + return non_rationals + rationals + + def _apply(args, func): + assert len(args) > 0 + value = args[0] + for val in args[1:]: + value = func(value, val) + + return value + + def _enforce_non_rational(val, term): + if isinstance(val, _Rational): + raise CodegenError(f"Rational is not supported yet in '{type(term)}'") def _get_const(val): if isinstance(val, int): - return arith_d.constant(IndexType.get(), res) + return arith_d.constant(IndexType.get(), val) if isinstance(val, (tuple, list)): vec_type = VectorType.get([len(val)], IndexType.get()) @@ -296,56 +360,50 @@ def _get_const(val): else: raise CodegenError(f"Unknown symbol {term}") case sympy.Integer(): - stack.append(arith_d.constant(IndexType.get(), int(term))) + stack.append(_get_const(int(term))) case sympy.Mul(): - _process_mul_add_ops(term, is_mul=True) + args = _group_rationals(stack, len(term.args)) + stack.append(_apply(args, _mul)) case sympy.Add(): - _process_mul_add_ops(term, is_mul=False) + args = _group_rationals(stack, len(term.args)) + stack.append(_apply(args, _add)) case sympy.Mod(): rhs = stack.pop() lhs = stack.pop() - mod = arith_d.RemSIOp(*_broadcast(lhs, rhs)) + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) + mod = arith_d.remsi(*_broadcast(lhs, rhs)) stack.append(mod) case sympy.floor(): - # TODO: Since divsi rounds to zero, this seems to work. - # But check whether floordivsi is needed. - stack.append(stack.pop()) + stack.append(_floor(stack.pop())) + case sympy.ceiling(): + stack.append(_ceiling(stack.pop())) case sympy.Rational(): - # `x * (a/b)` transformed into `(x * a) / b` - # `x + (a/b)` transformed into `(x*b + a) / b` - numerator = arith_d.constant(IndexType.get(), abs(term.p)) - denominator = arith_d.constant(IndexType.get(), abs(term.q)) - # Assumes that the negative term is always carried on the numerator - if abs(term.p) > term.p: - zero = arith_d.constant(IndexType.get(), int(0)) - numerator = arith_d.SubIOp(*_broadcast(zero, numerator)) - mul = lambda x: x - if abs(term.p) != 1: - mul = _get_mul(numerator) - add = _get_add(numerator, denominator) - operation = _get_div(mul, add, denominator) - stack.append(operation) + numerator = _get_const(term.p) + denominator = _get_const(term.q) + stack.append(_Rational(numerator, denominator)) case sympy.StrictLessThan(): rhs = stack.pop() lhs = stack.pop() + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) res = arith_d.cmpi(arith_d.CmpIPredicate.slt, *_broadcast(lhs, rhs)) stack.append(res) case sympy.And(): rhs = stack.pop() lhs = stack.pop() + _enforce_non_rational(rhs, term) + _enforce_non_rational(lhs, term) res = arith_d.andi(*_broadcast(lhs, rhs)) stack.append(res) - case sympy.ceiling(): - value = stack.pop() - if not isinstance(value, arith_d.DivSIOp): - raise CodegenError(f"Cannot handle ceil({value}) yet") - stack.append(arith_d.CeilDivSIOp(value.lhs, value.rhs)) case sympy.UnevaluatedExpr(): continue case _: raise CodegenError(f"Can not handle {type(term)} : {term}") - if len(stack) != 1: + + if len(stack) != 1 or isinstance(stack[0], _Rational): raise CodegenError(f"Expected single result, got {len(stack)}") + return stack[0] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..93bd8d6f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--runperf", action="store_true", default=False, help="run performance tests" + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "perf_only: performace test, runs only with '--runperf'" + ) + + +def pytest_collection_modifyitems(config, items): + run_perf = config.getoption("--runperf") + for item in items: + is_perf_only = next(item.iter_markers("perf_only"), None) is not None + if run_perf: + if not is_perf_only: + item.add_marker(pytest.mark.skip("skip non-perf test")) + else: + if is_perf_only: + item.add_marker(pytest.mark.skip("skip perf test")) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 5b6aa640..69ba718a 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -641,15 +641,193 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05) +_igemm_cases = [ + (4, 5, 5, 10, 2, 2, 16, 3), + (2, 5, 5, 10, 2, 2, 16, 3), + (1, 5, 5, 10, 2, 2, 16, 3), + (4, 5, 5, 4, 2, 2, 16, 3), + (1, 5, 5, 4, 2, 2, 16, 3), + (1, 5, 5, 3, 2, 2, 16, 3), + (2, 5, 5, 1, 2, 2, 16, 3), + (4, 5, 5, 10, 2, 2, 2, 3), + (2, 5, 5, 10, 2, 2, 2, 3), + (1, 5, 5, 10, 2, 2, 2, 3), + (4, 5, 5, 4, 2, 2, 2, 3), + (2, 5, 5, 4, 2, 2, 2, 3), + (1, 5, 5, 3, 2, 2, 2, 3), + (2, 5, 5, 1, 2, 2, 2, 3), + (1, 5, 5, 1, 2, 2, 2, 3), + (4, 5, 5, 10, 2, 2, 1, 3), + (2, 5, 5, 10, 2, 2, 1, 3), + (1, 5, 5, 10, 2, 2, 1, 3), + (4, 5, 5, 4, 2, 2, 1, 3), + (2, 5, 5, 4, 2, 2, 1, 3), + (1, 5, 5, 4, 2, 2, 1, 3), + (2, 5, 5, 3, 2, 2, 1, 3), + (4, 5, 5, 1, 2, 2, 1, 3), + (2, 5, 5, 1, 2, 2, 1, 3), + (1, 5, 5, 1, 2, 2, 1, 3), + (4, 5, 5, 10, 2, 2, 16, 2), + (2, 5, 5, 10, 2, 2, 16, 2), + (1, 5, 5, 10, 2, 2, 16, 2), + (4, 5, 5, 4, 2, 2, 16, 2), + (1, 5, 5, 4, 2, 2, 16, 2), + (4, 5, 5, 3, 2, 2, 16, 2), + (4, 5, 5, 1, 2, 2, 16, 2), + (1, 5, 5, 1, 2, 2, 16, 2), + (4, 5, 5, 10, 2, 2, 2, 2), + (2, 5, 5, 10, 2, 2, 2, 2), + (1, 5, 5, 10, 2, 2, 2, 2), + (4, 5, 5, 4, 2, 2, 2, 2), + (2, 5, 5, 4, 2, 2, 2, 2), + (2, 5, 5, 3, 2, 2, 2, 2), + (2, 5, 5, 1, 2, 2, 2, 2), + (1, 5, 5, 1, 2, 2, 2, 2), + (4, 5, 5, 10, 2, 2, 1, 2), + (2, 5, 5, 10, 2, 2, 1, 2), + (1, 5, 5, 10, 2, 2, 1, 2), + (4, 5, 5, 4, 2, 2, 1, 2), + (2, 5, 5, 4, 2, 2, 1, 2), + (1, 5, 5, 4, 2, 2, 1, 2), + (4, 5, 5, 1, 2, 2, 1, 2), + (1, 5, 5, 1, 2, 2, 1, 2), + (4, 5, 5, 10, 2, 2, 16, 1), + (2, 5, 5, 10, 2, 2, 16, 1), + (4, 5, 5, 4, 2, 2, 16, 1), + (2, 5, 5, 4, 2, 2, 16, 1), + (1, 5, 5, 4, 2, 2, 16, 1), + (4, 5, 5, 3, 2, 2, 16, 1), + (1, 5, 5, 3, 2, 2, 16, 1), + (2, 5, 5, 1, 2, 2, 16, 1), + (1, 5, 5, 1, 2, 2, 16, 1), + (4, 5, 5, 10, 2, 2, 2, 1), + (2, 5, 5, 10, 2, 2, 2, 1), + (1, 5, 5, 10, 2, 2, 2, 1), + (4, 5, 5, 4, 2, 2, 2, 1), + (2, 5, 5, 4, 2, 2, 2, 1), + (1, 5, 5, 4, 2, 2, 2, 1), + (1, 5, 5, 3, 2, 2, 2, 1), + (2, 5, 5, 1, 2, 2, 2, 1), + (1, 5, 5, 1, 2, 2, 2, 1), + (4, 5, 5, 10, 2, 2, 1, 1), + (2, 5, 5, 10, 2, 2, 1, 1), + (4, 5, 5, 4, 2, 2, 1, 1), + (2, 5, 5, 4, 2, 2, 1, 1), + (1, 5, 5, 4, 2, 2, 1, 1), + (2, 5, 5, 1, 2, 2, 1, 1), + (1, 5, 5, 1, 2, 2, 1, 1), + (4, 5, 5, 10, 2, 2, 16, 3), + (2, 5, 5, 10, 2, 2, 16, 3), + (1, 5, 5, 10, 2, 2, 16, 3), + (4, 5, 5, 4, 2, 2, 16, 3), + (2, 5, 5, 4, 2, 2, 16, 3), + (1, 5, 5, 4, 2, 2, 16, 3), + (4, 5, 5, 1, 2, 2, 16, 3), + (1, 5, 5, 1, 2, 2, 16, 3), + (4, 5, 5, 10, 2, 2, 2, 3), + (1, 5, 5, 10, 2, 2, 2, 3), + (2, 5, 5, 4, 2, 2, 2, 3), + (1, 5, 5, 4, 2, 2, 2, 3), + (2, 5, 5, 3, 2, 2, 2, 3), + (4, 5, 5, 1, 2, 2, 2, 3), + (2, 5, 5, 1, 2, 2, 2, 3), + (1, 5, 5, 1, 2, 2, 2, 3), + (4, 5, 5, 10, 2, 2, 1, 3), + (2, 5, 5, 10, 2, 2, 1, 3), + (1, 5, 5, 10, 2, 2, 1, 3), + (4, 5, 5, 4, 2, 2, 1, 3), + (2, 5, 5, 4, 2, 2, 1, 3), + (1, 5, 5, 4, 2, 2, 1, 3), + (4, 5, 5, 1, 2, 2, 1, 3), + (2, 5, 5, 1, 2, 2, 1, 3), + (1, 5, 5, 1, 2, 2, 1, 3), + (4, 5, 5, 10, 2, 2, 16, 2), + (2, 5, 5, 10, 2, 2, 16, 2), + (1, 5, 5, 10, 2, 2, 16, 2), + (4, 5, 5, 4, 2, 2, 16, 2), + (1, 5, 5, 4, 2, 2, 16, 2), + (4, 5, 5, 1, 2, 2, 16, 2), + (2, 5, 5, 1, 2, 2, 16, 2), + (4, 5, 5, 10, 2, 2, 2, 2), + (2, 5, 5, 10, 2, 2, 2, 2), + (1, 5, 5, 10, 2, 2, 2, 2), + (4, 5, 5, 4, 2, 2, 2, 2), + (2, 5, 5, 4, 2, 2, 2, 2), + (1, 5, 5, 4, 2, 2, 2, 2), + (1, 5, 5, 3, 2, 2, 2, 2), + (2, 5, 5, 1, 2, 2, 2, 2), + (1, 5, 5, 1, 2, 2, 2, 2), + (2, 5, 5, 10, 2, 2, 1, 2), + (1, 5, 5, 10, 2, 2, 1, 2), + (4, 5, 5, 4, 2, 2, 1, 2), + (2, 5, 5, 4, 2, 2, 1, 2), + (1, 5, 5, 4, 2, 2, 1, 2), + (1, 5, 5, 3, 2, 2, 1, 2), + (2, 5, 5, 1, 2, 2, 1, 2), + (1, 5, 5, 1, 2, 2, 1, 2), + (4, 5, 5, 10, 2, 2, 16, 1), + (2, 5, 5, 10, 2, 2, 16, 1), + (1, 5, 5, 10, 2, 2, 16, 1), + (4, 5, 5, 4, 2, 2, 16, 1), + (2, 5, 5, 4, 2, 2, 16, 1), + (1, 5, 5, 4, 2, 2, 16, 1), + (2, 5, 5, 3, 2, 2, 16, 1), + (1, 5, 5, 3, 2, 2, 16, 1), + (4, 5, 5, 1, 2, 2, 16, 1), + (1, 5, 5, 1, 2, 2, 16, 1), + (4, 5, 5, 10, 2, 2, 2, 1), + (1, 5, 5, 10, 2, 2, 2, 1), + (4, 5, 5, 4, 2, 2, 2, 1), + (2, 5, 5, 4, 2, 2, 2, 1), + (1, 5, 5, 4, 2, 2, 2, 1), + (4, 5, 5, 3, 2, 2, 2, 1), + (4, 5, 5, 1, 2, 2, 2, 1), + (2, 5, 5, 1, 2, 2, 2, 1), + (1, 5, 5, 1, 2, 2, 2, 1), + (4, 5, 5, 10, 2, 2, 1, 1), + (2, 5, 5, 10, 2, 2, 1, 1), + (4, 5, 5, 4, 2, 2, 1, 1), + (2, 5, 5, 4, 2, 2, 1, 1), + (1, 5, 5, 4, 2, 2, 1, 1), + (4, 5, 5, 3, 2, 2, 1, 1), + (2, 5, 5, 3, 2, 2, 1, 1), + (1, 5, 5, 3, 2, 2, 1, 1), + (2, 5, 5, 1, 2, 2, 1, 1), + (1, 5, 5, 1, 2, 2, 1, 1), + (1, 5, 5, 1, 3, 3, 1, 1), +] + +perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only) + +_igemm_cases += [ + perf_test(2, 128, 128, 16, 3, 3, 320, 1), + perf_test(2, 128, 128, 320, 1, 1, 640, 1), + perf_test(2, 128, 128, 320, 1, 1, 960, 1), + perf_test(2, 128, 128, 320, 3, 3, 16, 1), + perf_test(2, 128, 128, 320, 3, 3, 320, 1), + perf_test(2, 32, 32, 1280, 1, 1, 1920, 1), + perf_test(2, 32, 32, 1280, 1, 1, 2560, 1), + perf_test(2, 32, 32, 1280, 1, 1, 640, 1), + perf_test(2, 32, 32, 1280, 3, 3, 1280, 1), + perf_test(2, 32, 32, 1280, 3, 3, 1920, 1), + perf_test(2, 32, 32, 1280, 3, 3, 2560, 1), + perf_test(2, 32, 32, 1280, 3, 3, 640, 1), + perf_test(2, 32, 32, 640, 3, 3, 640, 1), + perf_test(2, 64, 64, 320, 3, 3, 320, 1), + perf_test(2, 64, 64, 640, 1, 1, 1280, 1), + perf_test(2, 64, 64, 640, 1, 1, 1920, 1), + perf_test(2, 64, 64, 640, 1, 1, 320, 1), + perf_test(2, 64, 64, 640, 1, 1, 960, 1), + perf_test(2, 64, 64, 640, 3, 3, 320, 1), + perf_test(2, 64, 64, 640, 3, 3, 640, 1), +] + + @require_e2e -@pytest.mark.parametrize("n", [1, 2, 4]) -@pytest.mark.parametrize("c", [1, 3, 4, 10]) -@pytest.mark.parametrize("nf", [1, 2, 16]) -@pytest.mark.parametrize("stride", [1, 2, 3]) +@pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases) @pytest.mark.parametrize("mem_space", [GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE]) -def test_igemm_conv(n, c, nf, stride, mem_space): - h, w = 5, 5 # Image. - cf, hf, wf = c, 2, 2 # Filters. +def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space): + cf = c padding = 0 # TODO: only pad=0 is supported for now torch.manual_seed(1) From 0f00c6d77e9a0ef19069fe7c021b89d36dc100df Mon Sep 17 00:00:00 2001 From: erman-gurses <99776114+erman-gurses@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:08:06 -0700 Subject: [PATCH 15/28] Add benchmark support for e2e tests (#183) Signed-off-by: erman-gurses --- .github/workflows/ci.yaml | 2 +- shark_turbine/kernel/wave/utils.py | 27 ++++++++++++++++++++++++++- tests/kernel/wave/wave_e2e_test.py | 8 ++++++++ tests/kernel/wave/wave_gemm_test.py | 1 + 4 files changed, 36 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5146f144..0796d30a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -60,7 +60,7 @@ jobs: if: "contains(matrix.os, 'mi300') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 - pytest -n 4 ./tests/kernel/wave/ + pytest -n 4 --capture=tee-sys ./tests/kernel/wave/ - name: Run LIT tests if: ${{ !cancelled() }} diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index 278666b0..9ea9adad 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -35,6 +35,9 @@ import torch.fx as fx import shark_turbine.kernel.lang as tkl + +import tempfile +from ...support.conversions import TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM from iree.compiler.dialects.transform import ( interpreter as transform_interpreter, any_op_t, @@ -372,7 +375,29 @@ def compile_and_invoke( _invoke(ctx.vm_context, device, func, kernel_inputs, kernel_outputs) if run_bench: - inputs = [inp.numpy() for inp in kernel_inputs] + bench_with_constant_weights = config.get("bench_with_constant_weights", False) + tempfiles = [] + inputs = [] + if bench_with_constant_weights: + for inp in kernel_inputs: + inputs.append( + "x".join( + [str(x) for x in inp.shape] + + [TORCH_DTYPE_TO_SIGNED_MLIR_TYPE_ASM[inp.dtype]] + ) + ) + else: + for inp in kernel_inputs: + tf = tempfile.NamedTemporaryFile() + torch.save(inp, tf) + tempfiles.append(tf) + inputs.append("@" + tf.name) + + benchmark_results = bench.benchmark_module( + mod, + entry_function=func_name, + ) + benchmark_results = bench.benchmark_module( mod, entry_function=func_name, diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 69ba718a..633c73c5 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -102,6 +102,7 @@ def test( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b) @@ -214,6 +215,7 @@ def test( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b) @@ -270,6 +272,7 @@ def test( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b) @@ -326,6 +329,7 @@ def test( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b, c) @@ -401,6 +405,7 @@ def repeat( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b, c) @@ -505,6 +510,7 @@ def test( }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): test(a, b) @@ -635,6 +641,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): gpu_func(x, we, out) @@ -949,6 +956,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, + run_bench=True, run_config=config, ): conv(x, we, out) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 2386ebd9..63e51909 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -122,6 +122,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: hyperparams, canonicalize=True, run=True, + run_bench=True, run_config=config, schedule=enable_scheduling, ): From e0a8fdf6b50a3b66faad752cc9aaef9ed459aee0 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:43:14 -0700 Subject: [PATCH 16/28] [TKW] Thread Shape analysis (#186) The motivation of this pass is to generalize the register analysis pass which is used to determine the thread shape of TKW.Register, to all other operations. One main use case for such is to allow reduction, and later on "broadcast" to use thread shape information from the kernel as opposed to relying on vector_shape which may not always be valid. We generalize the register analysis metho by finding a few anchor ops who's thread shape information is determined, and then propagate to it's successors and ancestors. In addition to that we also implemented a couple helper function/attributes. 1. Control_fn on BFS, ForwardSlice, BackwardSlice. This is to make it easier for us to control/stop the search when we hit ops we do not want to explore. In this case, we do not want to explore/propagate onto other anchor ops and their children. 2. Introducing parent_op to IterArg and region of Reduction, for developer ergonomics. 3. Move handling of IterArg and GetUser in BackwardSlice/BFS's get_input exploration phase to be handled individually as opposed to being handled when its' consumer is being explored. Previously to explore/propagate IterArg/GetUser, we need to explore its' consumer, just exploring IterArg/GetUser will not get handled correctly. This is useful for the case where we want to propagate/explore mma.acc (usually IterArg) directly. --------- Signed-off-by: Stanley Winata --- shark_turbine/kernel/ops/wave_ops.py | 8 + shark_turbine/kernel/wave/codegen.py | 4 +- .../kernel/wave/decompose_reduce_ops.py | 11 +- .../kernel/wave/register_analysis.py | 93 ------------ .../kernel/wave/thread_shape_analysis.py | 142 ++++++++++++++++++ shark_turbine/kernel/wave/utils.py | 51 ++++--- shark_turbine/kernel/wave/wave.py | 8 +- 7 files changed, 193 insertions(+), 124 deletions(-) delete mode 100644 shark_turbine/kernel/wave/register_analysis.py create mode 100644 shark_turbine/kernel/wave/thread_shape_analysis.py diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index f1292e4d..2373feec 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -651,6 +651,13 @@ class IterArg(Placeholder): a reduction node. """ + def parent_op(self): + return get_custom(self.graph.parent_op) + + def get_iter_idx(self): + src_reduction = self.parent_op() + return src_reduction.iter_args(self.graph).index(self.fx_node) + # Ops modeling TKW operations in the kernel language @@ -847,6 +854,7 @@ def wrapper(f): node._add_proxy_to_graph(graph) node.fx_node.node.tkw_op = cls node.fx_node.node.tkw_op_name = cls.tkw_op_name + graph.subgraphs[subgraph_name].parent_op = node.fx_node.node return node.fx_node return wrapper diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index adcd69b8..313e72cb 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -436,8 +436,8 @@ def handle_register(emitter: WaveEmitter, node: fx.Node): shape, dtype, value = node.args except ValueError as e: raise ValidationError("Malformed arguments") from e - if hasattr(node, "thread_shape"): - shape = [node.thread_shape] + get_thread_shape = lambda index: max(x.size for x in index.values()) + shape = [get_thread_shape(get_custom(node).index)] vector_shape = cast_py_literal(emitter, shape) element_type = IrType.parse(dtype.ir_type_asm()) vector_type = VectorType.get(vector_shape, element_type) diff --git a/shark_turbine/kernel/wave/decompose_reduce_ops.py b/shark_turbine/kernel/wave/decompose_reduce_ops.py index 1dac06cc..9916bb50 100644 --- a/shark_turbine/kernel/wave/decompose_reduce_ops.py +++ b/shark_turbine/kernel/wave/decompose_reduce_ops.py @@ -20,9 +20,10 @@ ShuffleOp, CustomOp, ExtractSlice, + Reduction, ) -from .utils import DCE +from .utils import DCE, subs_idxc import torch.fx as fx import math from typing import Callable @@ -103,9 +104,11 @@ def decompose_reduce_ops( raise NotImplementedError( "Only implemented reduction on fastest dimension." ) - reduction_block_size = constraint_tile_size[reduction_dim] - reduction_size = reduction_block_size.subs(index_map) - local_reduction_size = reduction_size / subgroup_size + + get_thread_shape = lambda index: max( + subs_idxc(x.size) for x in index.values() + ) + local_reduction_size = get_thread_shape(get_custom(custom.arg).index) local_reduction = emit_local_reduction( binary_fn, reduction_src, custom.graph, local_reduction_size ) diff --git a/shark_turbine/kernel/wave/register_analysis.py b/shark_turbine/kernel/wave/register_analysis.py deleted file mode 100644 index cbd42fbe..00000000 --- a/shark_turbine/kernel/wave/register_analysis.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from ..wave.constraints import Constraint -from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr -from .._support.tracing import CapturedTrace -from ...support.logging import get_logger -from ..ops.wave_ops import get_custom, NewRegister, CustomOp, MMA, Reduction, ReduceOp -from .utils import get_hardware_vector_map -import torch.fx as fx - -logger = get_logger("turbine.wave.register_analysis") - - -def set_register_shape( - trace: CapturedTrace, custom: CustomOp, vector_map: dict[IndexSymbol, int] -) -> None: - for custom_user in custom.users: - if isinstance(custom_user, MMA): - arg_index = custom_user.fx_node.args.index(custom.fx_node) - get_thread_shape = lambda index: max(x.size for x in index.values()) - match arg_index: - case 0: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.lhs_index - ) - case 1: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.rhs_index - ) - case 2: - custom.fx_node.thread_shape = get_thread_shape( - custom_user.acc_index - ) - break - - elif isinstance(custom_user, Reduction): - idx = custom_user.init_args.index(custom.fx_node) - iter_arg = get_custom( - custom_user.iter_args(trace.get_subgraph(custom_user.subgraph_name))[ - idx - ] - ) - set_register_shape(trace, iter_arg, vector_map) - custom.fx_node.thread_shape = iter_arg.fx_node.thread_shape - break - elif isinstance(custom_user, ReduceOp): - # Check that dim is non-reduction && in hw_constraint.vector_shape. - is_parallel_dim = lambda dim: dim != custom_user.dim and dim in vector_map - # TODO: Modify num_reduction_dims once we add support for multi-dim reduction. - num_reduction_dims = 1 - register_shape = [ - vector_map[dim] - for dim in custom_user.type.symbolic_shape - if is_parallel_dim(dim) - ] - expected_result_rank = ( - len(custom_user.type.symbolic_shape) - custom_user.num_reduction_dims - ) - # If rank do not match => some dims not found in hw_constraint.vector_shape. - if len(register_shape) != expected_result_rank: - raise NotImplementedError( - "NYI: Handling of dim not in vector_shapes during register analysis." - ) - non_unit_dims = sum(1 for dim in register_shape if dim > 1) - if non_unit_dims > 1: - raise NotImplementedError( - "NYI: Currently Register semantic only support 0-D vector." - ) - custom.fx_node.thread_shape = max(register_shape) - else: - raise NotImplementedError( - f"Register shape propagation not implemented for {custom_user}" - ) - - -def determine_register_shape( - trace: CapturedTrace | fx.Graph, constraints: list[Constraint] -) -> None: - """ - Each register op is annotated with the wave shape of the register. This - function determines the thread shape of the register based on the uses - of the register in the graph. - """ - register_nodes = trace.walk(lambda node: isinstance(get_custom(node), NewRegister)) - if not register_nodes: - return - vector_map = get_hardware_vector_map(constraints) - for node in register_nodes: - set_register_shape(trace, get_custom(node), vector_map) diff --git a/shark_turbine/kernel/wave/thread_shape_analysis.py b/shark_turbine/kernel/wave/thread_shape_analysis.py new file mode 100644 index 00000000..572561fb --- /dev/null +++ b/shark_turbine/kernel/wave/thread_shape_analysis.py @@ -0,0 +1,142 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...support.logging import get_logger +from shark_turbine.kernel._support.tracing import CapturedTrace +import torch.fx as fx +from ..ops.wave_ops import * +from ..lang.global_symbols import * +from .utils import capture_forward_slice, capture_backward_slice + +logger = get_logger("turbine.wave.thread_shape_analysis") + + +@dataclass(order=True) +class DimSize: + dim: IndexSymbol + size: int + + def __hash__(self): + return hash((self.dim, self.size)) + + +def get_dim_sizes(indices: list[IndexSequence]): + dims = frozenset([DimSize(dim, seq.size) for dim, seq in indices.items()]) + return dims + + +def get_custom_dim_sizes(custom: CustomOp): + return get_dim_sizes(custom.index) + + +def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): + for target in target_dim_sizes: + if target.dim not in custom.index: + raise NotImplementedError( + "NYI: Handle when source target index size is not found in target/user index." + ) + custom.index[target.dim].size = target.size + + +def determine_thread_shapes(trace: CapturedTrace): + """ + This function does analysis and propagation of thread shape. It does by such: + 1. Look for "anchor" ops who has information of it's elem_per_thread. + 2. Do a forward/backward slice on these anchor ops to get ops that + who's shapes depends on these anchor ops. + 3. We bucket these ops to Variadic(Index->elem_per_thread) mapping. + 4. At every bucket of (index -> elem_per_thread), we apply these information + by updating their indexSequence size. + + We stored the buckets above in a variable/dict called `thread_size_to_ops`. + + `thread_size_to_ops` is a dict that uses thread_shapes as key and for every + key/thread_shape will map to a set of fx.nodes that needs to have that + thread_shape in it's indexSequence. + + `thread_shapes` is used to store thread_size at every dimension that the op + cares about. We use a frozenset[DimSize] to represent it, where DimSize + is essentially a pair. we are using + frozen_set since we do not care about the order of dims for the shape/size + propagation. + + We use sets[CustomOp] to represent the values of `thread_size_ops` S.T we can + easily find any conflicting of index using set operations and handle/resolve it + if required. + + For better illustration, here's an example: + Kernel: + imm = tkw.mul(x, y) + lhs = tkw.neg(imm) + a = tkw.mma(lhs, rhs, acc) + b = tkw.exp2(a) + Anchors: + mma.lhs: {IndexSize(index=M, size=1), IndexSize(index=K, size=4)} + mma.rhs: {IndexSize(index=K, size=4), IndexSize(index=N, size=1)} + mma.acc: {IndexSize(index=M, size=4), IndexSize(index=N, size=1)} + Bucket Entry: + thread_sizes_to_ops[frozenset({IndexSize(index=M, size=1), IndexSize(index=K, size=4)}] = set(lhs, imm, x, y) + thread_sizes_to_ops[frozenset({IndexSize(index=M, size=4), IndexSize(index=N, size=1)}] = set(acc, exp2_0) + thread_sizes_to_ops[frozenset({IndexSize(index=K, size=4), IndexSize(index=N, size=1)}] = set(rhs, ...) + + """ + + # Anchor ops are ops who's thread shape are predetermined. + anchorOpTypes = (Read, Write, MMA, ReduceOp) + noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) + nonPropagatableTypes = anchorOpTypes + noHandleTypes + + def is_anchor_op(node: fx.Node): + return isinstance(get_custom(node), anchorOpTypes) + + def propagatable_op(node: fx.Node): + return not isinstance(get_custom(node), nonPropagatableTypes) + + anchor_ops = trace.walk(is_anchor_op) + thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {} + for anchor_op in anchor_ops: + custom = get_custom(anchor_op) + index_sizes = get_custom_dim_sizes(custom) + if isinstance(custom, (Read, ReduceOp)): + fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) + thread_size_to_ops[index_sizes] = thread_size_to_ops.get( + index_sizes, set([]) + ).union(fwd_slice) + elif isinstance(custom, Write): + bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op) + thread_size_to_ops[index_sizes] = thread_size_to_ops.get( + index_sizes, set([]) + ).union(bwd_slice) + elif isinstance(custom, MMA): + lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op) + rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op) + acc_slice = capture_forward_slice(custom.acc, propagatable_op) + acc_slice = acc_slice.union( + capture_backward_slice(custom.acc, propagatable_op) + ) + acc_index = get_dim_sizes(custom.acc_index) + lhs_index = get_dim_sizes(custom.lhs_index) + rhs_index = get_dim_sizes(custom.rhs_index) + thread_size_to_ops[acc_index] = thread_size_to_ops.get( + acc_index, set([]) + ).union(acc_slice) + thread_size_to_ops[lhs_index] = thread_size_to_ops.get( + lhs_index, set([]) + ).union(lhs_bwd_slice) + thread_size_to_ops[rhs_index] = thread_size_to_ops.get( + rhs_index, set([]) + ).union(rhs_bwd_slice) + + # Go through each index-size buckets, and apply the index-size to ops in the bucket. + cummulative_set = set() + for target_index_size, target_ops in thread_size_to_ops.items(): + # Ensure that we do not have any conflicts. + if not cummulative_set.isdisjoint(target_ops): + raise NotImplementedError("NYI: Handling of conflicting thread shape.") + cummulative_set = cummulative_set.union(target_ops) + for user in target_ops: + custom_user = get_custom(user) + set_index_size(custom_user, target_index_size) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index 9ea9adad..5e221f52 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -494,28 +494,33 @@ def get_inputs( Return the inputs of a node, propagating through reductions. """ inputs = [] - for input in node.all_input_nodes: - custom = get_custom(input) - if isinstance(custom, GetResult): - reduction = custom.value - assert isinstance( - reduction, Reduction - ), "GetResult must be used by a Reduction" - # Map get result to output - inputs.append(reduction.outputs[custom.res_idx]) - continue - if isinstance(custom, IterArg): - # Map iter args to init args - iter_arg_idx = reduction.iter_args.index(node) - inputs.append(reduction.init_args[iter_arg_idx]) - continue - inputs.append(input) + custom = get_custom(node) + if isinstance(custom, IterArg): + # Map iter args to init args + local_reduction = reduction + if reduction is None: + local_reduction = custom.parent_op() + iter_arg_idx = custom.get_iter_idx() + inputs.append(local_reduction.init_args[iter_arg_idx]) + elif isinstance(custom, GetResult): + reduction = get_custom(custom.value) + assert isinstance( + get_custom(reduction), Reduction + ), "GetResult must be used by a Reduction" + # Map get result to output + reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] + inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + else: + # Default handling for other ops. + for input in node.all_input_nodes: + inputs.append(input) return inputs, reduction def bfs( node: fx.Node, get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]], + filter_fn: Callable[[fx.node], bool], ) -> set[fx.Node]: """ Run BFS on the graph to capture the forward slice of a node. @@ -529,25 +534,29 @@ def bfs( s = queue.pop(0) neighbors, reduction = get_neighbors(s, reduction) for neighbor in neighbors: - if neighbor not in visited: + if neighbor not in visited and filter_fn(neighbor): visited.add(neighbor) queue.append(neighbor) return visited -def capture_forward_slice(node: fx.Node) -> set[fx.Node]: +def capture_forward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: """ Run BFS on the graph to capture the forward slice of a node. """ - return bfs(node, lambda x, y: get_users(x, y)) + return bfs(node, lambda x, y: get_users(x, y), filter_fn) -def capture_backward_slice(node: fx.Node) -> set[fx.Node]: +def capture_backward_slice( + node: fx.Node, filter_fn: Callable[[fx.node], bool] = lambda x: True +) -> set[fx.Node]: """ Capture backward slice from a node and return the tree. Assumes graph is directed. """ - return bfs(node, lambda x, y: get_inputs(x, y)) + return bfs(node, lambda x, y: get_inputs(x, y), filter_fn) def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index fde0c792..202cdd92 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -39,7 +39,7 @@ from ..ops.wave_ops import Reduction, CustomOp, get_custom from .index_sequence_analysis import partition_strided_operators from .shared_memory_indexing import apply_shared_memory_indexing_corrections -from .register_analysis import determine_register_shape +from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr import shark_turbine.kernel.lang as tkl @@ -227,9 +227,6 @@ def _trace_and_get_kernel_signature( # Clean up chains of GetResults remove_chained_getresult(graph) - # Register analysis to determine register shapes. - determine_register_shape(graph, self.constraints) - # Optimizations. minimize_global_loads(graph, self.constraints) @@ -239,6 +236,9 @@ def _trace_and_get_kernel_signature( # Partition strided operators. partition_strided_operators(graph, self.constraints) + # Analyze Thread Shapes per Op. + determine_thread_shapes(graph) + # Decompose reduce Ops. decompose_reduce_ops(graph, self.constraints, idxc.subs) From d98e5213db2d5bc4eb819f92a773ae3dd31e991c Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Thu, 3 Oct 2024 15:51:45 -0700 Subject: [PATCH 17/28] Disable benchmarking on all e2e tests for now (#189) We would like this to be controlled with a flag. Signed-off-by: Harsh Menon --- tests/kernel/wave/wave_e2e_test.py | 16 ++++++++-------- tests/kernel/wave/wave_gemm_test.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 633c73c5..066a4d4e 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -102,7 +102,7 @@ def test( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b) @@ -215,7 +215,7 @@ def test( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b) @@ -272,7 +272,7 @@ def test( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b) @@ -329,7 +329,7 @@ def test( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b, c) @@ -405,7 +405,7 @@ def repeat( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b, c) @@ -510,7 +510,7 @@ def test( }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): test(a, b) @@ -641,7 +641,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): gpu_func(x, we, out) @@ -956,7 +956,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, ): conv(x, we, out) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 63e51909..dc9b00a6 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -122,7 +122,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: hyperparams, canonicalize=True, run=True, - run_bench=True, + run_bench=False, run_config=config, schedule=enable_scheduling, ): From a04ea80ff897f8dacb9228f267e035ed094579e8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 4 Oct 2024 01:59:52 +0300 Subject: [PATCH 18/28] Set `fail-fast: false` (#190) Our tests are flaky, `fail-fast: false` won't allow failing builds abort other. Signed-off-by: Ivan Butygin --- .github/workflows/ci.yaml | 1 + .github/workflows/perf.yaml | 1 + .github/workflows/test_build_release.yml | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0796d30a..9486845b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -18,6 +18,7 @@ jobs: test: name: "Unit Tests and Type Checking" strategy: + fail-fast: false matrix: version: [3.11] os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf.yaml index 87adb13d..d86af8ac 100644 --- a/.github/workflows/perf.yaml +++ b/.github/workflows/perf.yaml @@ -21,6 +21,7 @@ jobs: test: name: "Unit Tests and Type Checking" strategy: + fail-fast: false matrix: version: [3.11] os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] diff --git a/.github/workflows/test_build_release.yml b/.github/workflows/test_build_release.yml index c0546365..aea5a16a 100644 --- a/.github/workflows/test_build_release.yml +++ b/.github/workflows/test_build_release.yml @@ -19,6 +19,7 @@ jobs: test: name: "Test Build Release Process" strategy: + fail-fast: false matrix: version: [3.11] os: [ubuntu-latest] From 207efd961d54a629dfedd8469a88045fd80dfd3a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 4 Oct 2024 20:09:57 +0300 Subject: [PATCH 19/28] [TKW] IGEMM Benchmarking (#187) Initial version of IGEMM benchmarking. * If `--runperf` pytest option is set, generate IREE ref code and run both TKW and ref code with `run_bench=True` * Add `--dump-perf-files-path` option to save perf info files into provided directory (filenames based on test name and params) --------- Signed-off-by: Ivan Butygin --- shark_turbine/kernel/wave/iree_utils.py | 36 +++++++++++++++++++++- shark_turbine/kernel/wave/utils.py | 9 ++---- tests/conftest.py | 8 ++++- tests/kernel/wave/wave_e2e_test.py | 40 +++++++++++++++++++++++-- 4 files changed, 81 insertions(+), 12 deletions(-) diff --git a/shark_turbine/kernel/wave/iree_utils.py b/shark_turbine/kernel/wave/iree_utils.py index 6d612c91..39f67404 100644 --- a/shark_turbine/kernel/wave/iree_utils.py +++ b/shark_turbine/kernel/wave/iree_utils.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +from typing import Any from .utils import compile_and_invoke from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM @@ -23,6 +24,23 @@ def get_mmt_asm(lhs_type: str, rhs_type: str, acc_type: str) -> str: return matmul_function +def get_conv_asm( + conv_type: str, lhs_type: str, rhs_type: str, res_type: str, stride: int +) -> str: + res_dtype = res_type.split("x")[-1] + return f""" + func.func @conv_{conv_type}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{res_type}> {{ + %c0 = arith.constant 0.0 : {res_dtype} + %init = tensor.empty() : tensor<{res_type}> + %inital_result = linalg.fill ins(%c0 : {res_dtype}) outs(%init : tensor<{res_type}>) -> tensor<{res_type}> + %result = linalg.conv_{conv_type} + {{dilations = dense<1> : tensor<2xi64>, strides = dense<{stride}> : tensor<2xi64>}} + ins(%lhs, %rhs : tensor<{lhs_type}>, tensor<{rhs_type}>) + outs(%inital_result : tensor<{res_type}>) -> tensor<{res_type}> + return %result : tensor<{res_type}> + }}""" + + def dtype_str(dtype: torch.dtype) -> str: dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None) if dtype_str is None: @@ -39,20 +57,36 @@ def generate_iree_ref( kernel_inputs: list[torch.Tensor], kernel_outputs: list[torch.Tensor], config: dict[str, str], + **kwargs: dict[str, Any], ): """ Generate a reference output for the given kernel type and arguments. """ asm = None + conv_str = "conv_" if kernel_type == "mmt": lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) asm = get_mmt_asm(lhs_type, rhs_type, acc_type) + elif kernel_type.startswith(conv_str): + lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) + rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) + acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) + conv_type = kernel_type[len(conv_str) :] + asm = get_conv_asm( + conv_type, lhs_type, rhs_type, acc_type, int(kwargs["stride"]) + ) else: raise ValueError(f"Unknown kernel type: {kernel_type}") compile_and_invoke( - asm, kernel_type, config, kernel_inputs, kernel_outputs, True, False + asm, + kernel_type, + config, + kernel_inputs, + kernel_outputs, + run=True, + run_bench=kwargs.get("run_bench", False), ) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index 5e221f52..871e9e79 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -388,16 +388,11 @@ def compile_and_invoke( ) else: for inp in kernel_inputs: - tf = tempfile.NamedTemporaryFile() - torch.save(inp, tf) + tf = tempfile.NamedTemporaryFile(suffix=".npy") + numpy.save(tf, inp.numpy()) tempfiles.append(tf) inputs.append("@" + tf.name) - benchmark_results = bench.benchmark_module( - mod, - entry_function=func_name, - ) - benchmark_results = bench.benchmark_module( mod, entry_function=func_name, diff --git a/tests/conftest.py b/tests/conftest.py index 93bd8d6f..2129b19f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,17 @@ def pytest_addoption(parser): parser.addoption( "--runperf", action="store_true", default=False, help="run performance tests" ) + parser.addoption( + "--dump-perf-files-path", + action="store", + default=None, + help="save performance info into provided directory, filename based on current test name", + ) def pytest_configure(config): config.addinivalue_line( - "markers", "perf_only: performace test, runs only with '--runperf'" + "markers", "perf_only: performance test, runs only with '--runperf'" ) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 066a4d4e..a7ff3342 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -9,6 +9,7 @@ import shark_turbine.kernel.wave as tkw from shark_turbine.kernel.wave.wave_sim import wave_sim from shark_turbine.kernel.lang.global_symbols import * +from shark_turbine.kernel.wave.iree_utils import generate_iree_ref import torch from numpy.testing import assert_allclose, assert_equal import pytest @@ -829,11 +830,16 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: perf_test(2, 64, 64, 640, 3, 3, 640, 1), ] +_mem_spaces = [ + pytest.param(GLOBAL_ADDRESS_SPACE, id="global"), + pytest.param(SHARED_ADDRESS_SPACE, id="shared"), +] + @require_e2e @pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases) -@pytest.mark.parametrize("mem_space", [GLOBAL_ADDRESS_SPACE, SHARED_ADDRESS_SPACE]) -def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space): +@pytest.mark.parametrize("mem_space", _mem_spaces) +def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, request): cf = c padding = 0 # TODO: only pad=0 is supported for now @@ -940,6 +946,18 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + with tk.gen.TestLaunchContext( { N: n, @@ -956,8 +974,24 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): conv(x, we, out) assert_allclose(out, out_ref, rtol=1e-03, atol=1e-03) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + + iree_ref = torch.zeros_like(out_ref) + generate_iree_ref( + "conv_2d_nchw_fchw", + [x, we], + [iree_ref], + config, + stride=stride, + run_bench=True, + ) From 64b7d27b030c250a8a9dbefc7363e39c131b6403 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 4 Oct 2024 20:33:11 +0300 Subject: [PATCH 20/28] [TKW] Update IR interpreter (#182) * Add `arith.andi`, `arith.cmpi`, `vector.maskedload`, `vector.gather`, `vector.contant_mask`, `vector.insertelement`, `vectot.splat`, support non-splatted contants. * Add `interpret_ndrange` helper --------- Signed-off-by: Ivan Butygin --- mypy.ini | 4 + shark_turbine/tools/interpreter.py | 140 ++++++++++++++++++++++------- 2 files changed, 113 insertions(+), 31 deletions(-) diff --git a/mypy.ini b/mypy.ini index 29c35b65..528b8d48 100644 --- a/mypy.ini +++ b/mypy.ini @@ -20,6 +20,10 @@ ignore_errors = True [mypy-shark_turbine.kernel.*] ignore_errors = True +# TODO: Some pytorch errors. +[mypy-shark_turbine.tools.interpreter] +ignore_errors = True + # Ignore all typing errors in tests/tools (these depend on TK). [mypy-tests.tools.*] ignore_errors = True diff --git a/shark_turbine/tools/interpreter.py b/shark_turbine/tools/interpreter.py index 5e4a0b15..5022933d 100644 --- a/shark_turbine/tools/interpreter.py +++ b/shark_turbine/tools/interpreter.py @@ -4,32 +4,33 @@ import re from typing import Callable from collections import namedtuple +import numpy as np logger = get_logger("turbine.wave.interpreter") from ..kernel.compiler.ir import ( - amdgpu_d, - builtin_d, Context, + F16Type, + F32Type, IndexType, - Value, - VectorType, + IntegerAttr, + IntegerType, Module, Operation, + Value, + VectorType, + amdgpu_d, + arith_d, + builtin_d, flow_d, func_d, gpu_d, llvm_d, - scf_d, - vector_d, memref_d, - IntegerAttr, - IndexType, - arith_d, + scf_d, stream_d, - F32Type, - F16Type, + vector_d, ) @@ -53,17 +54,12 @@ def get_dtype(self, dtype): return torch.float32 if type(dtype) == F16Type: return torch.float16 + if type(dtype) == IndexType: + return torch.int64 + if dtype == IntegerType.get_signless(1): + return torch.bool raise NotImplementedError(f"Unsupported dtype: {dtype}") - def create_tensor(self, shape: list[int], dtype, value) -> torch.Tensor: - """ - Creates a constant tensor with the given shape, dtype and value. - The tensor is filled with ones. - """ - if type(dtype) == F32Type or type(dtype) == F16Type: - value = float(value) - return torch.ones(*shape, dtype=self.get_dtype(dtype)) * value - def callback(self, op: Operation) -> None: if ( op.operation.parent.name == "func.func" @@ -80,11 +76,13 @@ def callback(self, op: Operation) -> None: elif vtype == VectorType: shape = op.value.type.shape dtype = op.value.type.element_type - value = self.create_tensor( - shape, - dtype, - op.attributes["value"].get_splat_value(), - ) + val = op.attributes["value"] + dtype = self.get_dtype(dtype) + if val.is_splat: + val = val.get_splat_value().value + value = torch.full(shape, val, dtype=dtype) + else: + value = torch.from_numpy(np.array(val)).type(dtype=dtype) else: raise NotImplementedError(f"Unsupported constant type: {vtype}") case arith_d.MulIOp: @@ -112,6 +110,21 @@ def callback(self, op: Operation) -> None: self.symbol_table[op.operands[0]] // self.symbol_table[op.operands[1]] ) + case arith_d.AndIOp: + value = ( + self.symbol_table[op.operands[0]] + & self.symbol_table[op.operands[1]] + ) + case arith_d.CmpIOp: + lhs = self.symbol_table[op.lhs] + rhs = self.symbol_table[op.rhs] + pred = int(op.predicate) + if pred == int(arith_d.CmpIPredicate.slt): + value = lhs < rhs + else: + raise NotImplementedError( + f"Unsupported predicate: {op.predicate}" + ) case amdgpu_d.LDSBarrierOp: return case amdgpu_d.MFMAOp: @@ -136,11 +149,10 @@ def callback(self, op: Operation) -> None: ) # Row-major load offset = [0 for _ in range(len(load_indices))] - offset[-1] += 1 for i in range(*result_shape): - value[i] = memref[ - *[int(x) + y for x, y in zip(load_indices, offset)] - ] + ind = [int(x) + y for x, y in zip(load_indices, offset)] + value[i] = memref[*ind] + offset[-1] += 1 case vector_d.ExtractStridedSliceOp: vector = self.symbol_table[op.vector] value = vector[[int(x) for x in op.offsets]] @@ -154,11 +166,69 @@ def callback(self, op: Operation) -> None: result_shape = vector.shape # Row-major store offset = [0 for _ in range(len(store_indices))] - offset[-1] += 1 for i in range(*result_shape): memref[ *[int(x) + y for x, y in zip(store_indices, offset)] ] = vector[i] + offset[-1] += 1 + case vector_d.MaskedStoreOp: + store_indices = [] + for index in op.indices: + store_indices.append(self.symbol_table[index]) + vector = self.symbol_table[op.valueToStore] + memref = self.symbol_table[op.base] + mask = self.symbol_table[op.mask] + result_type = vector.type + result_shape = vector.shape + # Row-major store + offset = [0 for _ in range(len(store_indices))] + for i in range(*result_shape): + if mask[i]: + ind = [int(x) + y for x, y in zip(store_indices, offset)] + memref[*ind] = vector[i] + + offset[-1] += 1 + case vector_d.ConstantMaskOp: + shape = op.result.type.shape + value = torch.ones(shape, dtype=torch.bool) + case vector_d.GatherOp: + load_indices = [] + for index in op.indices: + load_indices.append(self.symbol_table[index]) + logger.debug("Gather indices:", load_indices) + memref = self.symbol_table[op.base] + mask = self.symbol_table[op.mask] + index_vec = self.symbol_table[op.index_vec] + pass_thru = self.symbol_table[op.pass_thru] + result_type = op.result.type + result_shape = result_type.shape + result_dtype = result_type.element_type + value = torch.zeros( + *result_shape, dtype=self.get_dtype(result_dtype) + ) + # Row-major load + offset = [0 for _ in range(len(load_indices))] + for i in range(*result_shape): + if mask[i]: + off = [ + slice(int(x) + y, None) + for x, y in zip(load_indices, offset) + ] + m = memref[off].flatten() + value[i] = m[index_vec[i]] + else: + value[i] = pass_thru[i] + case vector_d.InsertElementOp: + source = self.symbol_table[op.source] + value = self.symbol_table[op.dest].clone() + position = self.symbol_table[op.position] + value[int(position[0])] = source + case vector_d.SplatOp: + mtype = op.result.type + shape = mtype.shape + dtype = mtype.element_type + input = self.symbol_table[op.input][0] + value = torch.full(shape, input, dtype=self.get_dtype(dtype)) case stream_d.DispatchWorkgroupIDOp: index = int(op.attributes["dimension"]) value = self.workgroup_ids[index] @@ -214,7 +284,7 @@ def callback(self, op: Operation) -> None: case _: raise NotImplementedError(f"Unsupported operation: {op}") - if type(op) != vector_d.StoreOp: + if type(op) not in (vector_d.StoreOp, vector_d.MaskedStoreOp): self.symbol_table[op.result] = value def walk_operations(self, operation: Operation, callback: Callable) -> None: @@ -237,6 +307,14 @@ def interpret(self, asm: str) -> None: operation = module.operation self.walk_operations(operation, self.callback) + @staticmethod + def interpret_ndrange( + asm: str, workgroup_count: list[int], workgroup_size: list[int] + ): + for wg in np.ndindex(*workgroup_count): + for t in np.ndindex(*workgroup_size): + Interpreter([*wg], [*t]).interpret(asm) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="MLIR Interpreter") From 7617c94ea8d09ab51db494d92b967ed5628e274f Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:16:27 -0700 Subject: [PATCH 21/28] [TKW] Implement broadcastOp class, lowering and insertion (#176) Motivation of this PR is to be able to codegen/lower broadcast properly. With that in mind, we implemented these things: 1. BroadcastOp class, op and lowering, to represent and store broadcasting information. Mostly S.T we can query target shape information and the source operand of broadcast. 2. Treat broadcast-add as an index conflict and handle it by emitting broadcastOp. --------- Signed-off-by: Stanley Winata --- lit_tests/kernel/wave/codegen.py | 72 +++++++++++++++++++ shark_turbine/kernel/ops/wave_ops.py | 61 ++++++++++++++-- shark_turbine/kernel/wave/codegen.py | 39 +++++++++- .../kernel/wave/thread_shape_analysis.py | 47 ++++++++++-- 4 files changed, 208 insertions(+), 11 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 4c732079..b65a9035 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1196,6 +1196,78 @@ def repeat( # CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]] +@run_test +def test_broadcast_add(): + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + ): + lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + rhs = tkw.read(b, elements_per_thread=1) + res = lhs + rhs + tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 128) + a = torch.ones(shape, dtype=torch.float16) + b = torch.ones(shape[0], dtype=torch.float16) + c = torch.zeros(shape, dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 2, + BLOCK_N: 128, + LOAD_ELEMS_PER_THREAD: 2, + STORE_ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + run=False, + run_config=config, + ): + print(test(a, b, c).module_op) + # CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %{{.+}}: !stream.binding) + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + + # Slicing LHS + # CHECK: %[[LHS:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<256x128xf16 + # CHECK: %[[LHS_0:.+]] = vector.load %[[LHS]][%[[X_SLICE_0:.+]], %[[Y_SLICE:.+]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16> + # CHECK: %[[X_SLICE_1:.+]] = arith.addi %[[X_SLICE_0]], %c1 : index + # CHECK: %[[LHS_1:.+]] = vector.load %[[LHS]][%[[X_SLICE_1]], %[[Y_SLICE]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16> + + # Slicing RHS + # CHECK: %[[RHS:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<256xf16 + # CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> + # CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16> + + # 1st Broadcast-ADD RHS + # CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16> + # CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16> + # CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16> + + # 2nd Broadcast-ADD RHS + # CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16> + # CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16> + # CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16> + + @run_test def test_binary_lowerings(): constraints: list[tkw.Constraint] = [ diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index 2373feec..47c6a5ec 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -96,6 +96,12 @@ def maximum(lhs: "Register", rhs: "Register") -> "Register": ... +def broadcast( + arg: "Register", target_shape: Optional[IndexExpr | int] = None +) -> "Register": + ... + + def sum( src: "Register", acc: Optional["Register"] = None, @@ -496,7 +502,12 @@ def post_expansion(self, constraints: list["Constraint"]) -> None: @dataclass class BinaryPyOp(CustomOp, ABC): """ - Represents a binary python operator. + Represents an elementwise binary python operator. + + DTYPE requirement: lhs and rhs needs to have the same dtpye. + Shape requirement: lhs and rhs either have same shape or + their shape must be broadcastable to + one another. """ lhs: Any @@ -522,9 +533,16 @@ def type(self) -> Memory: lhs_type = get_custom(self.lhs).type rhs_type = get_custom(self.rhs).type has_same_type = has_same_custom_type(lhs_type, rhs_type) - if not has_same_type: - raise ValueError("Expected lhs and rhs to have same type post-expansion") - return lhs_type + if has_same_type: + return lhs_type + lhs_dim_set = set(lhs_type.symbolic_shape) + rhs_dim_set = set(rhs_type.symbolic_shape) + if lhs_dim_set.isdisjoint(rhs_dim_set): + raise ValueError( + "BinaryPyOp requires lhs and rhs shape to be at least broadcastable." + ) + broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype + return broadcasted_type @define_interface_op("exp2") @@ -899,8 +917,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: return captured_vars @property - def type(self) -> list[Memory | Register]: - return [get_custom(x).type for x in self.init_args] + def type(self) -> Memory | Register | list[Memory | Register]: + res_types = [get_custom(x).type for x in self.init_args] + if len(res_types) == 1: + res_types = res_types[0] + return res_types def outputs(self, graph: fx.Graph) -> list[fx.Node]: for node in graph.nodes: @@ -1022,6 +1043,34 @@ def type(self) -> "Register": return get_custom(self.register_).type +@define_op("broadcast") +@dataclass +class Broadcast(CustomOp, ABC): + """ + Represents a Broadcast operation. + + arg: Source tensor/value to broadcast + target_shape: symbolic target broadcast shape. + """ + + arg: fx.Node + target_type: Sequence[IndexSymbol] = None + + @property + def target_shape(self): + return self.target_type.symbolic_shape + + @property + def indexing_dims(self) -> list[IndexSymbol]: + return self.target_shape + + @property + def type(self) -> Memory: + src_dtype = get_custom(self.arg).type.dtype + dst_type = Register[*self.target_shape, src_dtype] + return dst_type + + @define_interface_op("max") @define_interface_op("sum") @dataclass diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index 313e72cb..e218d71c 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -46,6 +46,7 @@ from shark_turbine.kernel.lang.global_symbols import * from ..ops.wave_ops import ( write, + broadcast, register, mma, shuffle, @@ -79,7 +80,7 @@ WorkgroupConstraint, TilingConstraint, ) -from .utils import subs_idxc, find_index_bounds +from .utils import subs_idxc, find_index_bounds, get_hardware_vector_map # Indexing imports. from .._support.indexing import IndexingContext, IndexExpr, IndexSequence @@ -1095,6 +1096,42 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): emitter.bind_node_proxy(node, IRProxyValue(element)) +############################################################################### +# Reshape ops +############################################################################### + + +@handle_op(broadcast) +def handle_broadcast(emitter: WaveEmitter, node: fx.Node): + try: + register, target_type = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + + # Get thread_shape/size for broadcast. + get_thread_shape = lambda index: max(x.size for x in index.values()) + bcast_dim_lane_dim_size = get_thread_shape(node.index) + + # Check MLIR shape + vector_src = cast_vector(emitter, register) + vector_type = vector_src.type + # Only support broadcasting vector<1xdtype> for now. + if not VectorType.isinstance(vector_type): + raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.") + assert vector_type.rank == 1 + assert vector_type.shape[0] == 1 + + # Extract and Splat + # If by chance broadcast size matches current size, we can return src. + if bcast_dim_lane_dim_size == vector_type.shape[0]: + emitter.bind_node_proxy(node, IRProxyValue(vector_src)) + + result_type = VectorType.get([bcast_dim_lane_dim_size], vector_type.element_type) + element = vector_d.extract(vector_src, static_position=[0], dynamic_position=[]) + splat = vector_d.splat(result_type, element) + emitter.bind_node_proxy(node, IRProxyValue(splat)) + + ############################################################################### # Miscellanous ops ############################################################################### diff --git a/shark_turbine/kernel/wave/thread_shape_analysis.py b/shark_turbine/kernel/wave/thread_shape_analysis.py index 572561fb..5fd0b999 100644 --- a/shark_turbine/kernel/wave/thread_shape_analysis.py +++ b/shark_turbine/kernel/wave/thread_shape_analysis.py @@ -9,7 +9,7 @@ import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * -from .utils import capture_forward_slice, capture_backward_slice +from .utils import capture_forward_slice, capture_backward_slice, subs_idxc logger = get_logger("turbine.wave.thread_shape_analysis") @@ -24,7 +24,9 @@ def __hash__(self): def get_dim_sizes(indices: list[IndexSequence]): - dims = frozenset([DimSize(dim, seq.size) for dim, seq in indices.items()]) + dims = frozenset( + [DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()] + ) return dims @@ -41,6 +43,39 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): custom.index[target.dim].size = target.size +def handle_binaryop_conflict(custom_node: CustomOp): + # Analyze if we can resolve conflict with broadcast. + lhs = get_custom(custom_node.lhs) + rhs = get_custom(custom_node.rhs) + lhs_dim_set = set(lhs.type.symbolic_shape) + rhs_dim_set = set(rhs.type.symbolic_shape) + if lhs_dim_set == rhs_dim_set: + raise ValueError("Cannot broadcast if lhs and rhs is already same.") + if lhs_dim_set.isdisjoint(rhs_dim_set): + raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.") + # Determine the correct indexSize for binaryOp and insert broadcasting. + dst_op = lhs if lhs_dim_set > rhs_dim_set else rhs + broadcast_idx, broadcast_src = (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs) + broadcast = Broadcast(broadcast_src.fx_node, dst_op.type) + with custom_node.graph.inserting_before(custom_node.fx_node): + broadcast.add_to_graph(custom_node.graph) + setattr(broadcast.fx_node, "index", dst_op.index) + custom_node.index = dst_op.index + custom_node.update_arg(broadcast_idx, broadcast.fx_node) + return True + + +# Returns True iff all conflicts are handled succesfully. +def handle_conflicts(conflicted_ops: set[CustomOp]): + for conflict in conflicted_ops: + custom = get_custom(conflict) + if isinstance(custom, BinaryPyOp): + handle_binaryop_conflict(custom) + else: + return False + return True + + def determine_thread_shapes(trace: CapturedTrace): """ This function does analysis and propagation of thread shape. It does by such: @@ -133,10 +168,14 @@ def propagatable_op(node: fx.Node): # Go through each index-size buckets, and apply the index-size to ops in the bucket. cummulative_set = set() for target_index_size, target_ops in thread_size_to_ops.items(): - # Ensure that we do not have any conflicts. + # Try to handle conflicts and remove from target set if successfully handled. if not cummulative_set.isdisjoint(target_ops): - raise NotImplementedError("NYI: Handling of conflicting thread shape.") + conflicted_ops = cummulative_set.intersection(target_ops) + if handle_conflicts(conflicted_ops) == False: + raise NotImplementedError("Failed to handle conflicting thread shape.") + target_ops = target_ops.difference(conflicted_ops) cummulative_set = cummulative_set.union(target_ops) + # Set target ops's indexSize to be the determined from analysis. for user in target_ops: custom_user = get_custom(user) set_index_size(custom_user, target_index_size) From 83bbc402de36e35542e6742bb7d67755d6403b38 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Fri, 4 Oct 2024 12:49:33 -0700 Subject: [PATCH 22/28] Add ability to dump intermediates (#194) This PR adds a flag to dump intermediates which include .ll and .s files to see what instructions were generated. --------- Signed-off-by: Harsh Menon --- shark_turbine/kernel/wave/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index 871e9e79..869df061 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -334,6 +334,12 @@ def compile_and_invoke( if config.get("print_ir_after_all", False): flags.append("--mlir-print-ir-after-all") + if "dump_intermediates" in config: + intermediates_path = config.get("dump_intermediates") + flags.append( + f"--iree-hal-dump-executable-intermediates-to={intermediates_path}" + ) + if run_bench: bench_batch_size = config.get("benchmark_batch_size", None) bench_repetitions = config.get("benchmark_repetitions", None) From 39acab892bff934866a6a4c25572f7485840b8f7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 4 Oct 2024 23:03:49 +0300 Subject: [PATCH 23/28] Split TK CI from main CI (#195) * Main CI is flaky, add a separate pipeline, which tests only TK as temp solution * Make `pytest` output more verbose * Remove unnecessary stuff from perf pipeline --------- Signed-off-by: Ivan Butygin --- .github/workflows/ci-tk.yaml | 74 ++++++++++++++++++++++++++++++++++++ .github/workflows/ci.yaml | 10 +---- .github/workflows/perf.yaml | 17 ++------- 3 files changed, 79 insertions(+), 22 deletions(-) create mode 100644 .github/workflows/ci-tk.yaml diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml new file mode 100644 index 00000000..b8e44b74 --- /dev/null +++ b/.github/workflows/ci-tk.yaml @@ -0,0 +1,74 @@ +name: "TK CI" + +on: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test: + name: "Unit Tests and Type Checking" + strategy: + fail-fast: false + matrix: + version: [3.11] + os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + runs-on: ${{matrix.os}} + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@v3 + + - name: Cache Pip Packages + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-cache-dir -r iree-requirements-ci.txt + pip install -r requirements.txt -e . + + - name: Run unit tests + if: ${{ !cancelled() }} + run: | + pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + + - name: Run e2e tests on MI300 + if: "contains(matrix.os, 'mi300') && !cancelled()" + run: | + export WAVE_RUN_E2E_TESTS=1 + pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + + - name: Run LIT tests + if: ${{ !cancelled() }} + run: | + lit lit_tests/ -v + + - name: MyPy Type Checking + if: ${{ !cancelled() }} + run: | + mypy diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9486845b..bfb1175e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: version: [3.11] - os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + os: [ubuntu-latest] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -55,13 +55,7 @@ jobs: - name: Run unit tests if: ${{ !cancelled() }} run: | - pytest -n 4 . - - - name: Run e2e tests on MI300 - if: "contains(matrix.os, 'mi300') && !cancelled()" - run: | - export WAVE_RUN_E2E_TESTS=1 - pytest -n 4 --capture=tee-sys ./tests/kernel/wave/ + pytest -n 4 --capture=tee-sys -vv . - name: Run LIT tests if: ${{ !cancelled() }} diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf.yaml index d86af8ac..1b8d8271 100644 --- a/.github/workflows/perf.yaml +++ b/.github/workflows/perf.yaml @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: version: [3.11] - os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + os: [nodai-amdgpu-mi300-x86-64] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -54,21 +54,10 @@ jobs: pip install --no-compile -r pytorch-cpu-requirements.txt pip install --no-cache-dir -r iree-requirements-ci.txt pip install -r requirements.txt -e . - - name: Run unit tests - if: ${{ !cancelled() }} - run: | - pytest -n 4 . + - name: Run e2e tests on MI300 if: "contains(matrix.os, 'mi300') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" - pytest -n 1 ./tests/kernel/wave/ - - name: Run LIT tests - if: ${{ !cancelled() }} - run: | - lit lit_tests/ -v - - name: MyPy Type Checking - if: ${{ !cancelled() }} - run: | - mypy + pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ From 4fec47c101bde3af5410794f4dd24b18590b0564 Mon Sep 17 00:00:00 2001 From: erman-gurses <99776114+erman-gurses@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:02:56 -0700 Subject: [PATCH 24/28] Add parameterization for benchmark flag (#192) Signed-off-by: erman-gurses --- tests/kernel/wave/wave_e2e_test.py | 35 +++++++++++++++++------------ tests/kernel/wave/wave_gemm_test.py | 31 +++++++++++++++++++++---- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index a7ff3342..cf2d1315 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -56,7 +56,8 @@ def wrapper(shape): @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) -def test_copy(shape): +def test_copy(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -103,7 +104,7 @@ def test( }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -112,7 +113,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_copy")) -def test_dynamic_copy(shape): +def test_dynamic_copy(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -159,6 +161,7 @@ def test( dynamic_symbols_map={M: shape[0], N: shape[1]}, canonicalize=True, run=True, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -167,7 +170,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_transpose_read")) -def test_transpose_read(shape): +def test_transpose_read(shape, request): + run_bench = request.config.getoption("--runperf") shape = shape[::-1] M = tkl.sym.M N = tkl.sym.N @@ -216,7 +220,7 @@ def test( }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -225,7 +229,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_transpose_write")) -def test_transpose_write(shape): +def test_transpose_write(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE @@ -273,7 +278,7 @@ def test( }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -282,7 +287,8 @@ def test( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_reduce_sum")) -def test_reduce_sum(shape): +def test_reduce_sum(shape, request): + run_bench = request.config.getoption("--runperf") M = tkl.sym.M N = tkl.sym.N wave_size = 64 @@ -330,7 +336,7 @@ def test( }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): test(a, b, c) @@ -406,7 +412,6 @@ def repeat( }, canonicalize=True, run=True, - run_bench=False, run_config=config, ): test(a, b, c) @@ -417,7 +422,8 @@ def repeat( @require_e2e -def test_im2col(): +def test_im2col(request): + run_bench = request.config.getoption("--runperf") # TODO: we don't support unaligned access at the moment so all sizes must # be aligned to WG/Wave sizes, c * hw * wf == 8 and number of windows == 64. n, c, h, w = 1, 2, 9, 9 # Image. @@ -511,7 +517,7 @@ def test( }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): test(a, b) @@ -519,7 +525,8 @@ def test( @require_e2e -def test_im2col_mma(): +def test_im2col_mma(request): + run_bench = request.config.getoption("--runperf") # igemm without final col2im n, c, h, w = 1, 4, 9, 9 # Image. nf, cf, hf, wf = 64, c, 2, 2 # Filters. @@ -642,7 +649,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: }, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, ): gpu_func(x, we, out) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index dc9b00a6..f9de28b2 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -24,6 +24,14 @@ default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)] +perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only) + +default_test_shapes += [ + perf_test((1024, 5120, 640)), + perf_test((2048, 10240, 1280)), + perf_test((4096, 20480, 2560)), +] + user_specified_test_shapes = "" test_params_path = os.environ.get("TEST_PARAMS_PATH", None) @@ -42,8 +50,9 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) @pytest.mark.parametrize("enable_scheduling", [False, True]) -def testGemm(shape: tuple[int], enable_scheduling: bool): - +def testGemm(shape: tuple[int], enable_scheduling: bool, request): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") # Input sizes M = tkl.sym.M N = tkl.sym.N @@ -118,11 +127,20 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: MMA_UNITS: 4, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + with tk.gen.TestLaunchContext( hyperparams, canonicalize=True, run=True, - run_bench=False, + run_bench=run_bench, run_config=config, schedule=enable_scheduling, ): @@ -136,6 +154,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with open(filename, "w") as f: f.write(mb.module_op.get_asm()) + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32) - generate_iree_ref("mmt", [a, b], [iree_ref], config) + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) assert_close(c, iree_ref) From da3436d768a0bcf9cce0ff00fb3b89f369e4fd56 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Fri, 4 Oct 2024 14:11:21 -0700 Subject: [PATCH 25/28] Add padding to reduce shared memory bank conflicts (#193) --- lit_tests/kernel/wave/barriers.py | 6 ++--- lit_tests/kernel/wave/codegen.py | 24 +++++++++---------- .../kernel/wave/index_sequence_analysis.py | 4 ++-- .../kernel/wave/minimize_global_loads.py | 4 ++-- lit_tests/kernel/wave/promotion.py | 8 +++---- shark_turbine/kernel/wave/promotion.py | 22 ++++++++++++++++- 6 files changed, 44 insertions(+), 24 deletions(-) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 8f8b4a6f..fcb7dbed 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -98,7 +98,7 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: %read_0_1 # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_shared_0_0 # CHECK-SAME: (%read_0_0, %allocate, 4, None) # CHECK-NEXT: %write_shared_1_1 @@ -182,9 +182,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b65a9035..4800e9bd 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -474,12 +474,12 @@ def mma( # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> - # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space> # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index - # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, + # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16, + # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, # CHECK-SAME: strided<[16, 1], offset: ?>> @@ -489,13 +489,13 @@ def mma( # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> - # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space> # CHECK: amdgpu.lds_barrier # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index - # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, + # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16, + # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> @@ -593,8 +593,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y - # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space> - # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space> + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space> # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16, # CHECK-SAME: strided<[64, 1], offset: ?>> # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16, @@ -620,18 +620,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset: # CHECK-SAME: ?>>, vector<4xf16> - # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, + # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1], # CHECK-SAME: offset: ?>>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, + # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: amdgpu.lds_barrier - # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16, + # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16, # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16 # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 2bebc690..7dd266ee 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -98,9 +98,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index dcf6b225..7596a94b 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -103,9 +103,9 @@ def test_gemm(): # CHECK-NEXT: %register_1_0_0 # CHECK-NEXT: %register_0_1_0 # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0] # CHECK-NEXT: %getresult_1_1_0 diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 01db88cc..3843c406 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -74,7 +74,7 @@ def test_read_write_equal_sizes(): # CHECK-NEXT: %read # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_1 # CHECK-SAME: (%read, %allocate, 4, None) # CHECK-NEXT: %read_1 @@ -123,7 +123,7 @@ def test_read_write_equal_sizes_different_address_spaces(): # CHECK-NEXT: %read # CHECK-SAME: (%a, 4, None, None) # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %write_1 # CHECK-SAME: (%read, %allocate, 4, None) # CHECK-NEXT: %read_1 @@ -181,9 +181,9 @@ def test_gemm(): # CHECK-NEXT: %c # CHECK-NEXT: %register # CHECK-NEXT: %allocate - # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: %allocate_1 - # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE) + # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE) # CHECK-NEXT: reduction # CHECK-NEXT: %write # CHECK-SAME: (%reduction, %c, 4, None) diff --git a/shark_turbine/kernel/wave/promotion.py b/shark_turbine/kernel/wave/promotion.py index fd1aa541..3711436f 100644 --- a/shark_turbine/kernel/wave/promotion.py +++ b/shark_turbine/kernel/wave/promotion.py @@ -15,6 +15,25 @@ logger = get_logger("turbine.wave.promotion") +def apply_padding( + shape: tuple[IndexSymbol | int], dtype: DataType +) -> tuple[IndexSymbol | int]: + """ + When accessing shared memory, we need to be cognizant of bank conflicts + that can have a significant impact on performance. One way to mitigate + these conflicts is by applying padding to the shared memory allocation. + This function applies padding of 64 bits to the shared memory allocation. + While this approach accomplishes the goal of reducing bank conflicts, it + is inefficient in terms of memory usage. A more sophisticated approach + would involve swizzling of the shared memory access patterns. + """ + padding = 64 // dtype.bitwidth() + return tuple( + value + padding if i == len(shape) - 1 else value + for i, value in enumerate(shape) + ) + + def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate): match custom_node: case Read(memory, elements_per_thread) if get_custom( @@ -47,9 +66,10 @@ def promote_node( assert isinstance(node, Read) or isinstance(node, Write) with node.graph.inserting_before(node.fx_node.next): constrained_shape = get_constrained_shape(node.type.symbolic_shape, constraints) + padded_shape = apply_padding(constrained_shape, node.type.dtype) allocate_node = Allocate( node.type.symbolic_shape, - constrained_shape, + padded_shape, node.type.dtype, address_space, ) From b0ef345ee1d95ba5c26942b904caa4925d9a74c6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 5 Oct 2024 19:30:17 +0300 Subject: [PATCH 26/28] Rename `shark-turbine` -> `iree.turbine` (#197) * Move files from files from `shark-turbine` to `iree/turbine`. * Update imports * Update `setup.py` * Make backward redirect `shark-turbine` -> `iree.turbine` (do we need this?) Progress on #28 --------- Signed-off-by: Ivan Butygin --- MANIFEST.in | 2 +- README.md | 2 +- build_tools/build_release.py | 4 +- examples/aot_mlp/mlp_export_dynamic.py | 2 +- examples/aot_mlp/mlp_export_simple.py | 2 +- examples/llama2_inference/README.md | 47 -- examples/llama2_inference/llama2.ipynb | 503 ------------------ .../llama2_inference/llama2_state_schema.json | 1 - examples/llama2_inference/requirements.txt | 4 - examples/resnet-18/requirements.txt | 2 +- examples/resnet-18/resnet-18.py | 2 +- .../runtime_torture/launchable_torture.py | 4 +- iree/turbine/__init__.py | 12 - .../turbine}/aot/__init__.py | 0 .../turbine}/aot/builtins/__init__.py | 0 .../turbine}/aot/builtins/globals.py | 0 .../turbine}/aot/builtins/jittable.py | 0 .../turbine}/aot/compiled_module.py | 0 .../turbine}/aot/decompositions.py | 0 .../turbine}/aot/exporter.py | 0 .../turbine}/aot/fx_programs.py | 0 {shark_turbine => iree/turbine}/aot/params.py | 0 .../turbine}/aot/passes/__init__.py | 0 .../turbine}/aot/passes/functorch.py | 0 .../turbine}/aot/support/ir_utils.py | 0 .../aot/support/procedural/__init__.py | 0 .../turbine}/aot/support/procedural/base.py | 0 .../support/procedural/exported_program.py | 0 .../aot/support/procedural/globals.py | 0 .../aot/support/procedural/iree_emitter.py | 0 .../aot/support/procedural/primitives.py | 0 .../turbine}/aot/support/procedural/tracer.py | 0 .../turbine}/aot/tensor_traits.py | 0 .../turbine}/dynamo/__init__.py | 0 .../turbine}/dynamo/backends/cpu.py | 0 .../turbine}/dynamo/decompositions.py | 0 .../turbine}/dynamo/executor.py | 0 .../turbine}/dynamo/passes.py | 0 .../turbine}/dynamo/tensor.py | 4 +- .../turbine}/dynamo/type_conversion.py | 0 .../turbine}/importers/README.md | 0 .../turbine}/importers/ir.py | 0 .../turbine}/importers/utils.py | 0 .../turbine}/kernel/__init__.py | 0 .../turbine}/kernel/_support/context.py | 0 .../turbine}/kernel/_support/dtype.py | 0 .../turbine}/kernel/_support/indexing.py | 0 .../turbine}/kernel/_support/regions.py | 0 .../turbine}/kernel/_support/shaped_type.py | 0 .../turbine}/kernel/_support/tracing.py | 0 .../turbine}/kernel/compiler/base.py | 0 .../turbine}/kernel/compiler/builder.py | 0 .../kernel/compiler/dispatch_codegen.py | 0 .../turbine}/kernel/compiler/host_codegen.py | 0 .../turbine}/kernel/compiler/ir.py | 0 .../kernel/compiler/kernel_codegen.py | 0 .../turbine}/kernel/compiler/op_matchers.py | 0 .../turbine}/kernel/compiler/utils.py | 0 .../kernel/compiler/vector_codegen.py | 0 .../turbine}/kernel/gen/__init__.py | 0 .../turbine}/kernel/gen/kernel.py | 0 .../turbine}/kernel/gen/thread.py | 0 .../turbine}/kernel/lang/__init__.py | 0 .../turbine}/kernel/lang/global_symbols.py | 0 .../turbine}/kernel/lang/grid.py | 0 .../turbine}/kernel/lang/kernel_buffer.py | 0 .../turbine}/kernel/lang/prims.py | 0 .../turbine}/kernel/lang/types.py | 0 .../turbine}/kernel/lang/wave_types.py | 0 .../turbine}/kernel/ops/__init__.py | 0 .../turbine}/kernel/ops/base.py | 0 .../turbine}/kernel/ops/control_flow.py | 0 .../turbine}/kernel/ops/core.py | 0 .../turbine}/kernel/ops/math.py | 0 .../turbine}/kernel/ops/memory.py | 0 .../turbine}/kernel/ops/reduction.py | 0 .../turbine}/kernel/ops/shape_manipulation.py | 0 .../turbine}/kernel/ops/wave_ops.py | 0 .../turbine}/kernel/wave/README.md | 0 .../turbine}/kernel/wave/__init__.py | 0 .../turbine}/kernel/wave/barriers.py | 0 .../turbine}/kernel/wave/codegen.py | 4 +- .../turbine}/kernel/wave/constraints.py | 0 .../kernel/wave/decompose_reduce_ops.py | 0 .../turbine}/kernel/wave/docs/gemm_example.md | 0 .../kernel/wave/docs/mlsys/.gitignore | 0 .../kernel/wave/docs/mlsys/algorithm.sty | 0 .../kernel/wave/docs/mlsys/algorithmic.sty | 0 .../kernel/wave/docs/mlsys/fancyhdr.sty | 0 .../kernel/wave/docs/mlsys/mlsys2024.bst | 0 .../kernel/wave/docs/mlsys/mlsys2024.sty | 0 .../turbine}/kernel/wave/docs/mlsys/tkw.bbl | 0 .../turbine}/kernel/wave/docs/mlsys/tkw.bib | 0 .../turbine}/kernel/wave/docs/mlsys/tkw.tex | 0 .../turbine}/kernel/wave/expansion.py | 0 .../turbine}/kernel/wave/hoisting.py | 2 +- .../kernel/wave/index_sequence_analysis.py | 0 .../turbine}/kernel/wave/iree_utils.py | 0 .../kernel/wave/minimize_global_loads.py | 0 .../turbine}/kernel/wave/promotion.py | 0 .../kernel/wave/scheduling/__init__.py | 0 .../kernel/wave/scheduling/graph_utils.py | 0 .../wave/scheduling/loop_reconstruction.py | 0 .../scheduling/loop_reconstruction_utils.py | 0 .../wave/scheduling/modulo_scheduling.py | 0 .../kernel/wave/scheduling/resources.py | 0 .../kernel/wave/scheduling/schedule.py | 0 .../kernel/wave/shared_memory_indexing.py | 0 .../kernel/wave/thread_shape_analysis.py | 2 +- .../turbine}/kernel/wave/utils.py | 2 +- .../turbine}/kernel/wave/visualization.py | 0 .../turbine}/kernel/wave/wave.py | 2 +- .../turbine}/kernel/wave/wave_sim.py | 0 .../turbine}/ops/__init__.py | 0 .../turbine}/ops/_jinja_test_ops.py | 0 .../turbine}/ops/_str_format_test_ops.py | 0 {shark_turbine => iree/turbine}/ops/iree.py | 0 .../ops/templates/test_add_jinja.mlir | 0 .../ops/templates/test_add_strformat.mlir | 0 .../ops/templates/test_syntax_error.mlir | 0 .../turbine}/runtime/__init__.py | 0 .../turbine}/runtime/device.py | 0 .../turbine}/runtime/launch.py | 0 .../turbine}/runtime/op_reg/__init__.py | 0 .../turbine}/runtime/op_reg/base.py | 0 .../turbine}/runtime/op_reg/compiler.py | 0 .../turbine}/runtime/op_reg/eager.py | 0 .../turbine}/runtime/op_reg/impl_helper.py | 0 .../turbine}/runtime/tracing.py | 0 .../turbine}/support/__init__.py | 0 .../turbine}/support/conversions.py | 0 .../turbine}/support/debugging.py | 0 .../turbine}/support/exceptions.py | 0 .../turbine}/support/ir_imports.py | 0 .../turbine}/support/logging.py | 0 .../turbine}/tools/__init__.py | 0 .../turbine}/tools/interpreter.py | 0 .../turbine}/transforms/builder.py | 0 .../transforms/general/add_metadata.py | 2 +- .../transforms/general/custom_op_expansion.py | 0 .../transforms/general/rename_parameters.py | 0 .../turbine}/transforms/merger.py | 0 .../transforms/quantization/mm_group_quant.py | 0 .../turbine}/transforms/rewriter.py | 0 lit_tests/kernel/wave/barriers.py | 24 +- lit_tests/kernel/wave/codegen.py | 10 +- lit_tests/kernel/wave/expansion.py | 14 +- .../kernel/wave/index_sequence_analysis.py | 28 +- .../kernel/wave/minimize_global_loads.py | 30 +- lit_tests/kernel/wave/promotion.py | 20 +- lit_tests/kernel/wave/scheduling.py | 28 +- lit_tests/kernel/wave/tracing.py | 10 +- lit_tests/lit.cfg.py | 2 +- mypy.ini | 8 +- setup.py | 15 +- shark_turbine/__init__.py | 13 + tests/aot/api_test.py | 2 +- tests/aot/args_test.py | 2 +- tests/aot/compiled_exported_program_test.py | 4 +- tests/aot/decompositions_test.py | 2 +- tests/aot/dynamic_shape_export_test.py | 2 +- tests/aot/functionalize_test.py | 2 +- tests/aot/fx_programs_test.py | 2 +- tests/aot/globals_test.py | 2 +- tests/aot/iree_procedural_test.py | 2 +- tests/aot/jittable_test.py | 2 +- tests/aot/non_strict_export_test.py | 2 +- tests/aot/params_test.py | 2 +- tests/dynamo/importer_dynamic_test.py | 2 +- tests/dynamo/tensor_test.py | 4 +- tests/dynamo/type_conversion_test.py | 2 +- tests/generated/evaluate.py | 2 +- tests/kernel/aot_kernel_test.py | 6 +- tests/kernel/arith_test.py | 8 +- tests/kernel/compiler/utils_test.py | 6 +- tests/kernel/dispatch_codegen_test.py | 8 +- tests/kernel/fused_attention_test.py | 4 +- tests/kernel/indexing_test.py | 4 +- tests/kernel/simple_kernel_test.py | 20 +- tests/kernel/types_test.py | 2 +- tests/kernel/vector_codegen_test.py | 4 +- tests/kernel/wave/constraints_test.py | 4 +- tests/kernel/wave/scheduling_test.py | 30 +- tests/kernel/wave/types_test.py | 6 +- tests/kernel/wave/visualization_test.py | 18 +- tests/kernel/wave/wave_e2e_test.py | 12 +- tests/kernel/wave/wave_gemm_test.py | 10 +- tests/kernel/wave/wave_sim_test.py | 6 +- tests/kernel/wave/wave_utils_test.py | 4 +- tests/ops/iree_test.py | 4 +- tests/runtime/device_test.py | 12 +- tests/runtime/launch_test.py | 4 +- tests/runtime/op_reg/impl_helper_test.py | 2 +- tests/runtime/op_reg/kernel_aot_test.py | 6 +- tests/runtime/op_reg/kernel_reg_test.py | 4 +- tests/tools/interpreter_test.py | 10 +- tests/top_level_package_test.py | 4 +- tests/transforms/general/add_metadata_test.py | 2 +- .../general/custom_op_expansion_test.py | 6 +- .../general/rename_parameters_test.py | 4 +- .../quantization/mm_group_quant_test.py | 4 +- 201 files changed, 250 insertions(+), 807 deletions(-) delete mode 100644 examples/llama2_inference/README.md delete mode 100644 examples/llama2_inference/llama2.ipynb delete mode 100644 examples/llama2_inference/llama2_state_schema.json delete mode 100644 examples/llama2_inference/requirements.txt rename {shark_turbine => iree/turbine}/aot/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/globals.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/jittable.py (100%) rename {shark_turbine => iree/turbine}/aot/compiled_module.py (100%) rename {shark_turbine => iree/turbine}/aot/decompositions.py (100%) rename {shark_turbine => iree/turbine}/aot/exporter.py (100%) rename {shark_turbine => iree/turbine}/aot/fx_programs.py (100%) rename {shark_turbine => iree/turbine}/aot/params.py (100%) rename {shark_turbine => iree/turbine}/aot/passes/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/passes/functorch.py (100%) rename {shark_turbine => iree/turbine}/aot/support/ir_utils.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/base.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/exported_program.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/globals.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/iree_emitter.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/primitives.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/tracer.py (100%) rename {shark_turbine => iree/turbine}/aot/tensor_traits.py (100%) rename {shark_turbine => iree/turbine}/dynamo/__init__.py (100%) rename {shark_turbine => iree/turbine}/dynamo/backends/cpu.py (100%) rename {shark_turbine => iree/turbine}/dynamo/decompositions.py (100%) rename {shark_turbine => iree/turbine}/dynamo/executor.py (100%) rename {shark_turbine => iree/turbine}/dynamo/passes.py (100%) rename {shark_turbine => iree/turbine}/dynamo/tensor.py (99%) rename {shark_turbine => iree/turbine}/dynamo/type_conversion.py (100%) rename {shark_turbine => iree/turbine}/importers/README.md (100%) rename {shark_turbine => iree/turbine}/importers/ir.py (100%) rename {shark_turbine => iree/turbine}/importers/utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/context.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/dtype.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/indexing.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/regions.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/shaped_type.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/tracing.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/base.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/builder.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/dispatch_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/host_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/ir.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/kernel_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/op_matchers.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/vector_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/kernel.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/thread.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/global_symbols.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/grid.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/kernel_buffer.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/prims.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/types.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/wave_types.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/base.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/control_flow.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/core.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/math.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/memory.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/reduction.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/shape_manipulation.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/wave_ops.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/README.md (100%) rename {shark_turbine => iree/turbine}/kernel/wave/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/barriers.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/codegen.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/constraints.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/decompose_reduce_ops.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/gemm_example.md (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/.gitignore (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/algorithm.sty (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/algorithmic.sty (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/fancyhdr.sty (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/mlsys2024.bst (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/mlsys2024.sty (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/tkw.bbl (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/tkw.bib (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/mlsys/tkw.tex (100%) rename {shark_turbine => iree/turbine}/kernel/wave/expansion.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/hoisting.py (95%) rename {shark_turbine => iree/turbine}/kernel/wave/index_sequence_analysis.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/iree_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/minimize_global_loads.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/promotion.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/graph_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/loop_reconstruction.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/loop_reconstruction_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/modulo_scheduling.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/resources.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/schedule.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/shared_memory_indexing.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/thread_shape_analysis.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/utils.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/visualization.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/wave.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/wave_sim.py (100%) rename {shark_turbine => iree/turbine}/ops/__init__.py (100%) rename {shark_turbine => iree/turbine}/ops/_jinja_test_ops.py (100%) rename {shark_turbine => iree/turbine}/ops/_str_format_test_ops.py (100%) rename {shark_turbine => iree/turbine}/ops/iree.py (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_add_jinja.mlir (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_add_strformat.mlir (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_syntax_error.mlir (100%) rename {shark_turbine => iree/turbine}/runtime/__init__.py (100%) rename {shark_turbine => iree/turbine}/runtime/device.py (100%) rename {shark_turbine => iree/turbine}/runtime/launch.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/__init__.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/base.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/compiler.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/eager.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/impl_helper.py (100%) rename {shark_turbine => iree/turbine}/runtime/tracing.py (100%) rename {shark_turbine => iree/turbine}/support/__init__.py (100%) rename {shark_turbine => iree/turbine}/support/conversions.py (100%) rename {shark_turbine => iree/turbine}/support/debugging.py (100%) rename {shark_turbine => iree/turbine}/support/exceptions.py (100%) rename {shark_turbine => iree/turbine}/support/ir_imports.py (100%) rename {shark_turbine => iree/turbine}/support/logging.py (100%) rename {shark_turbine => iree/turbine}/tools/__init__.py (100%) rename {shark_turbine => iree/turbine}/tools/interpreter.py (100%) rename {shark_turbine => iree/turbine}/transforms/builder.py (100%) rename {shark_turbine => iree/turbine}/transforms/general/add_metadata.py (97%) rename {shark_turbine => iree/turbine}/transforms/general/custom_op_expansion.py (100%) rename {shark_turbine => iree/turbine}/transforms/general/rename_parameters.py (100%) rename {shark_turbine => iree/turbine}/transforms/merger.py (100%) rename {shark_turbine => iree/turbine}/transforms/quantization/mm_group_quant.py (100%) rename {shark_turbine => iree/turbine}/transforms/rewriter.py (100%) create mode 100644 shark_turbine/__init__.py diff --git a/MANIFEST.in b/MANIFEST.in index 97971bba..65338637 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,4 @@ include README.md include requirements.txt include pytorch-cpu-requirements.txt include version_info.json -include shark_turbine/ops/templates/*.mlir +include iree/turbine/ops/templates/*.mlir diff --git a/README.md b/README.md index 4d0d0c22..aa01b826 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Turbine provides a collection of tools: * *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py)) - for simple models and an underlying [advanced API](shark_turbine/aot/compiled_module.py) for complicated models + for simple models and an underlying [advanced API](iree/turbine/aot/compiled_module.py) for complicated models and accessing the full features of the runtime. * *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device is available for more native, interactive use within a PyTorch session. diff --git a/build_tools/build_release.py b/build_tools/build_release.py index 5a6ef98d..5a90a7cf 100755 --- a/build_tools/build_release.py +++ b/build_tools/build_release.py @@ -159,10 +159,8 @@ def main(): print("Downloading remaining requirements") download_requirements(REPO_ROOT / "requirements.txt") - print("Building shark-turbine") - build_wheel(REPO_ROOT) print("Building iree-turbine") - build_wheel(REPO_ROOT, env={"TURBINE_PACKAGE_NAME": "iree-turbine"}) + build_wheel(REPO_ROOT) if __name__ == "__main__": diff --git a/examples/aot_mlp/mlp_export_dynamic.py b/examples/aot_mlp/mlp_export_dynamic.py index cd863655..3bedd7c1 100644 --- a/examples/aot_mlp/mlp_export_dynamic.py +++ b/examples/aot_mlp/mlp_export_dynamic.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py index fed4795d..30d7ae95 100644 --- a/examples/aot_mlp/mlp_export_simple.py +++ b/examples/aot_mlp/mlp_export_simple.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/llama2_inference/README.md b/examples/llama2_inference/README.md deleted file mode 100644 index 50bc6537..00000000 --- a/examples/llama2_inference/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# LLAMA 2 Inference - -This example require some extra dependencies. Here's an easy way to get it running on a fresh server. - -Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens - -```bash -#!/bin/bash - - -# if you don't insert it, you will be prompted to log in later; -# you may need to rerun this script after logging in -YOUR_HF_TOKEN="insert token for headless" - -# clone and install dependencies -sudo apt install -y git -git clone https://github.com/nod-ai/SHARK-Turbine.git -cd SHARK-Turbine -pip install -r requirements.txt -pip install --update "huggingface_hub[cli]" transformers sentencepiece protobuf - -# do an editable install from the cloned SHARK-Turbine -pip install --editable . - -# Log in with Hugging Face CLI if token setup is required -if [[ $YOUR_HF_TOKEN == hf_* ]]; then - huggingface login --token $YOUR_HF_TOKEN - echo "Logged in with YOUR_HF_TOKEN." -elif [ -f ~/.cache/huggingface/token ]; then - # Read token from the file - TOKEN_CONTENT=$(cat ~/.cache/huggingface/token) - - # Check if the token starts with "hf_" - if [[ $TOKEN_CONTENT == hf_* ]]; then - echo "Already logged in with a Hugging Face token." - else - echo "Token in file does not start with 'hf_'. Please log into huggingface to download models." - huggingface-cli login - fi -else - echo "Please log into huggingface to download models." - huggingface-cli login -fi - -# Step 7: Run the Python script -python examples/llama2_inference/stateless_llama.py -``` diff --git a/examples/llama2_inference/llama2.ipynb b/examples/llama2_inference/llama2.ipynb deleted file mode 100644 index b008bbd2..00000000 --- a/examples/llama2_inference/llama2.ipynb +++ /dev/null @@ -1,503 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "c0c9f034-7af1-4dc2-bbfb-5bb9e27c07ca", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "import torch\n", - "from torch.utils import _pytree as pytree\n", - "from shark_turbine.aot import *\n", - "from iree.compiler.ir import Context\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4d92bb47-2b93-4f32-a445-c0ad2adc37ad", - "metadata": {}, - "outputs": [], - "source": [ - "#set some config values\n", - "\n", - "hf_auth_token = \"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\"\n", - "hf_model_name = \"meta-llama/Llama-2-7b-chat-hf\"\n", - "state_schema_path = \"llama2_state_schema.json\"\n", - "with open(state_schema_path, \"r+\") as f:\n", - " state_schema = pytree.treespec_loads(f.read())\n", - "prompt = \"\"\"\n", - "[INST] <>\n", - "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST]\n", - "\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d4664585-5e15-45c7-8c5c-c8eaf6381435", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:640: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n", - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:479: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5e411acda19c4228b008ff622bdf110e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00.5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:26 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,234] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:72 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,409] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s2, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:118 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:33,707] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s3 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:189 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,845] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s3, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:228 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:33,878] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s4, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:235 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,188] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s5 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:306 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,326] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s5, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:345 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,359] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s6, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:352 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,661] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s7 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:423 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,800] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s7, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:462 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,832] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s8, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:469 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,130] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s9 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:540 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,271] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s9, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:579 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,305] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s10, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:586 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,611] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s11 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:657 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,762] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s11, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:696 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,795] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s12, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:703 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,107] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s13 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:774 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s13, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:813 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,282] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s14, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:820 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,589] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s15 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:891 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s15, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:930 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s16, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:937 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,105] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s17 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1008 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s17, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1047 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,286] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s18, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1054 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,595] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s19 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1125 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,744] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s19, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1164 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,778] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s20, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1171 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,090] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s21 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1242 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,238] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s21, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1281 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,272] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s22, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1288 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,584] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s23 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1359 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s23, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1398 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s24, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1405 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,086] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s25 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1476 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,239] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s25, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1515 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,274] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s26, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1522 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,597] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s27 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1593 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,759] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s27, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1632 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,812] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s28, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1639 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:40,330] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s29 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1710 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:40,534] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s29, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1749 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:40,582] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s30, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1756 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,068] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s31 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1827 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,242] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s31, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1866 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:41,280] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s32, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1873 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,686] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s33 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1944 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,968] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s33, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1983 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,004] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s34, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1990 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:42,419] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s35 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2061 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:42,580] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s35, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2100 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,618] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s36, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2107 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,002] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s37 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2178 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,174] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s37, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2217 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,215] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s38, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2224 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,566] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s39 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2295 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,738] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s39, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2334 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,776] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s40, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2341 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,116] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s41 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2412 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,281] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s41, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2451 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,320] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s42, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2458 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,656] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s43 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2529 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,822] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s43, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2568 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,860] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s44, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2575 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,218] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s45 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2646 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,387] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s45, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2685 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,426] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s46, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2692 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,772] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s47 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2763 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,943] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s47, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2802 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,983] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s48, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2809 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,376] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s49 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2880 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:46,563] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s49, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2919 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:46,605] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s50, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2926 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,962] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s51 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2997 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,136] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s51, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3036 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,176] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s52, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3043 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:47,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s53 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3114 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,718] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s53, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3153 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,758] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s54, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3160 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,125] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s55 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3231 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,308] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s55, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3270 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,349] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s56, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3277 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,715] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s57 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3348 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,897] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s57, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3387 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,937] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s58, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3394 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,317] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s59 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3465 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:49,499] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s59, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3504 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:49,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s60, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3511 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,915] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s61 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3582 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,113] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s61, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3621 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,155] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s62, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3628 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:50,515] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s63 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3699 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,697] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s63, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3738 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,737] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s64, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3745 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:53,791] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards\n", - "[2023-10-09 18:49:54,155] torch.fx.experimental.symbolic_shapes: [WARNING] Ignored guard s0 + s1 > 4096 == False, this could result in accuracy problems\n", - "[2023-10-09 18:49:54,157] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] (_decomp/decompositions.py:725 in slice_forward)\n" - ] - } - ], - "source": [ - "#Run the export pipeline\n", - "inst = StateUpdateModule(context=Context(), import_to=\"IMPORT\")\n", - "module_str = str(CompiledModule.get_mlir_module(inst))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "bc04e1db-a8cc-4182-884d-ba3d8ae5adeb", - "metadata": {}, - "outputs": [], - "source": [ - "#Output a torch-ir mlir file\n", - "with open(\"llama2_torch.mlir\", \"w+\") as f:\n", - " f.write(module_str)\n", - "#TODO: run the rest of the compile pipeline and do inference" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/llama2_inference/llama2_state_schema.json b/examples/llama2_inference/llama2_state_schema.json deleted file mode 100644 index b5506055..00000000 --- a/examples/llama2_inference/llama2_state_schema.json +++ /dev/null @@ -1 +0,0 @@ -[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}] diff --git a/examples/llama2_inference/requirements.txt b/examples/llama2_inference/requirements.txt deleted file mode 100644 index acbc93ca..00000000 --- a/examples/llama2_inference/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -protobuf -sentencepiece -shark_turbine -transformers @ git+https://github.com/huggingface/transformers.git@7d8ff3629b2725ec43ace99c1a6e87ac1978d433 diff --git a/examples/resnet-18/requirements.txt b/examples/resnet-18/requirements.txt index a5123e97..b7428649 100644 --- a/examples/resnet-18/requirements.txt +++ b/examples/resnet-18/requirements.txt @@ -1,2 +1,2 @@ transformers -shark_turbine==0.9.2 +iree_turbine==0.9.2 diff --git a/examples/resnet-18/resnet-18.py b/examples/resnet-18/resnet-18.py index 20340013..2b3fce56 100644 --- a/examples/resnet-18/resnet-18.py +++ b/examples/resnet-18/resnet-18.py @@ -1,6 +1,6 @@ from transformers import AutoFeatureExtractor, AutoModelForImageClassification import torch -from shark_turbine.aot import * +from iree.turbine.aot import * import iree.runtime as rt # Loading feature extractor and pretrained model from huggingface diff --git a/examples/runtime_torture/launchable_torture.py b/examples/runtime_torture/launchable_torture.py index 56f92a99..d58c6a80 100644 --- a/examples/runtime_torture/launchable_torture.py +++ b/examples/runtime_torture/launchable_torture.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/iree/turbine/__init__.py b/iree/turbine/__init__.py index c59e85c2..d95aa54f 100644 --- a/iree/turbine/__init__.py +++ b/iree/turbine/__init__.py @@ -8,15 +8,3 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# TODO: This redirection layer exists while we are migrating from the -# shark_turbine top-level package name to iree.turbine. It exports the -# public API but not the internal details. In a future switch, all code -# will be directly located here and the redirect will be done in the -# shark_turbine namespace. - -from shark_turbine import aot -from shark_turbine import dynamo -from shark_turbine import kernel -from shark_turbine import ops -from shark_turbine import runtime diff --git a/shark_turbine/aot/__init__.py b/iree/turbine/aot/__init__.py similarity index 100% rename from shark_turbine/aot/__init__.py rename to iree/turbine/aot/__init__.py diff --git a/shark_turbine/aot/builtins/__init__.py b/iree/turbine/aot/builtins/__init__.py similarity index 100% rename from shark_turbine/aot/builtins/__init__.py rename to iree/turbine/aot/builtins/__init__.py diff --git a/shark_turbine/aot/builtins/globals.py b/iree/turbine/aot/builtins/globals.py similarity index 100% rename from shark_turbine/aot/builtins/globals.py rename to iree/turbine/aot/builtins/globals.py diff --git a/shark_turbine/aot/builtins/jittable.py b/iree/turbine/aot/builtins/jittable.py similarity index 100% rename from shark_turbine/aot/builtins/jittable.py rename to iree/turbine/aot/builtins/jittable.py diff --git a/shark_turbine/aot/compiled_module.py b/iree/turbine/aot/compiled_module.py similarity index 100% rename from shark_turbine/aot/compiled_module.py rename to iree/turbine/aot/compiled_module.py diff --git a/shark_turbine/aot/decompositions.py b/iree/turbine/aot/decompositions.py similarity index 100% rename from shark_turbine/aot/decompositions.py rename to iree/turbine/aot/decompositions.py diff --git a/shark_turbine/aot/exporter.py b/iree/turbine/aot/exporter.py similarity index 100% rename from shark_turbine/aot/exporter.py rename to iree/turbine/aot/exporter.py diff --git a/shark_turbine/aot/fx_programs.py b/iree/turbine/aot/fx_programs.py similarity index 100% rename from shark_turbine/aot/fx_programs.py rename to iree/turbine/aot/fx_programs.py diff --git a/shark_turbine/aot/params.py b/iree/turbine/aot/params.py similarity index 100% rename from shark_turbine/aot/params.py rename to iree/turbine/aot/params.py diff --git a/shark_turbine/aot/passes/__init__.py b/iree/turbine/aot/passes/__init__.py similarity index 100% rename from shark_turbine/aot/passes/__init__.py rename to iree/turbine/aot/passes/__init__.py diff --git a/shark_turbine/aot/passes/functorch.py b/iree/turbine/aot/passes/functorch.py similarity index 100% rename from shark_turbine/aot/passes/functorch.py rename to iree/turbine/aot/passes/functorch.py diff --git a/shark_turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py similarity index 100% rename from shark_turbine/aot/support/ir_utils.py rename to iree/turbine/aot/support/ir_utils.py diff --git a/shark_turbine/aot/support/procedural/__init__.py b/iree/turbine/aot/support/procedural/__init__.py similarity index 100% rename from shark_turbine/aot/support/procedural/__init__.py rename to iree/turbine/aot/support/procedural/__init__.py diff --git a/shark_turbine/aot/support/procedural/base.py b/iree/turbine/aot/support/procedural/base.py similarity index 100% rename from shark_turbine/aot/support/procedural/base.py rename to iree/turbine/aot/support/procedural/base.py diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/iree/turbine/aot/support/procedural/exported_program.py similarity index 100% rename from shark_turbine/aot/support/procedural/exported_program.py rename to iree/turbine/aot/support/procedural/exported_program.py diff --git a/shark_turbine/aot/support/procedural/globals.py b/iree/turbine/aot/support/procedural/globals.py similarity index 100% rename from shark_turbine/aot/support/procedural/globals.py rename to iree/turbine/aot/support/procedural/globals.py diff --git a/shark_turbine/aot/support/procedural/iree_emitter.py b/iree/turbine/aot/support/procedural/iree_emitter.py similarity index 100% rename from shark_turbine/aot/support/procedural/iree_emitter.py rename to iree/turbine/aot/support/procedural/iree_emitter.py diff --git a/shark_turbine/aot/support/procedural/primitives.py b/iree/turbine/aot/support/procedural/primitives.py similarity index 100% rename from shark_turbine/aot/support/procedural/primitives.py rename to iree/turbine/aot/support/procedural/primitives.py diff --git a/shark_turbine/aot/support/procedural/tracer.py b/iree/turbine/aot/support/procedural/tracer.py similarity index 100% rename from shark_turbine/aot/support/procedural/tracer.py rename to iree/turbine/aot/support/procedural/tracer.py diff --git a/shark_turbine/aot/tensor_traits.py b/iree/turbine/aot/tensor_traits.py similarity index 100% rename from shark_turbine/aot/tensor_traits.py rename to iree/turbine/aot/tensor_traits.py diff --git a/shark_turbine/dynamo/__init__.py b/iree/turbine/dynamo/__init__.py similarity index 100% rename from shark_turbine/dynamo/__init__.py rename to iree/turbine/dynamo/__init__.py diff --git a/shark_turbine/dynamo/backends/cpu.py b/iree/turbine/dynamo/backends/cpu.py similarity index 100% rename from shark_turbine/dynamo/backends/cpu.py rename to iree/turbine/dynamo/backends/cpu.py diff --git a/shark_turbine/dynamo/decompositions.py b/iree/turbine/dynamo/decompositions.py similarity index 100% rename from shark_turbine/dynamo/decompositions.py rename to iree/turbine/dynamo/decompositions.py diff --git a/shark_turbine/dynamo/executor.py b/iree/turbine/dynamo/executor.py similarity index 100% rename from shark_turbine/dynamo/executor.py rename to iree/turbine/dynamo/executor.py diff --git a/shark_turbine/dynamo/passes.py b/iree/turbine/dynamo/passes.py similarity index 100% rename from shark_turbine/dynamo/passes.py rename to iree/turbine/dynamo/passes.py diff --git a/shark_turbine/dynamo/tensor.py b/iree/turbine/dynamo/tensor.py similarity index 99% rename from shark_turbine/dynamo/tensor.py rename to iree/turbine/dynamo/tensor.py index cd1de1ea..bdf1cb83 100644 --- a/shark_turbine/dynamo/tensor.py +++ b/iree/turbine/dynamo/tensor.py @@ -474,8 +474,8 @@ def _get_device_state() -> DeviceState: return DeviceState(driver="local-task") -# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 -# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py +# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/iree/turbine/aot/builtins/jittable.py#L212-L237 +# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/iree/turbine/dynamo/backends/cpu.py # TODO: Try to generalize for other devices. def compute_method(super_fn, *args, **kwargs): # Compute factory fns reserve the last arg as src_op diff --git a/shark_turbine/dynamo/type_conversion.py b/iree/turbine/dynamo/type_conversion.py similarity index 100% rename from shark_turbine/dynamo/type_conversion.py rename to iree/turbine/dynamo/type_conversion.py diff --git a/shark_turbine/importers/README.md b/iree/turbine/importers/README.md similarity index 100% rename from shark_turbine/importers/README.md rename to iree/turbine/importers/README.md diff --git a/shark_turbine/importers/ir.py b/iree/turbine/importers/ir.py similarity index 100% rename from shark_turbine/importers/ir.py rename to iree/turbine/importers/ir.py diff --git a/shark_turbine/importers/utils.py b/iree/turbine/importers/utils.py similarity index 100% rename from shark_turbine/importers/utils.py rename to iree/turbine/importers/utils.py diff --git a/shark_turbine/kernel/__init__.py b/iree/turbine/kernel/__init__.py similarity index 100% rename from shark_turbine/kernel/__init__.py rename to iree/turbine/kernel/__init__.py diff --git a/shark_turbine/kernel/_support/context.py b/iree/turbine/kernel/_support/context.py similarity index 100% rename from shark_turbine/kernel/_support/context.py rename to iree/turbine/kernel/_support/context.py diff --git a/shark_turbine/kernel/_support/dtype.py b/iree/turbine/kernel/_support/dtype.py similarity index 100% rename from shark_turbine/kernel/_support/dtype.py rename to iree/turbine/kernel/_support/dtype.py diff --git a/shark_turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py similarity index 100% rename from shark_turbine/kernel/_support/indexing.py rename to iree/turbine/kernel/_support/indexing.py diff --git a/shark_turbine/kernel/_support/regions.py b/iree/turbine/kernel/_support/regions.py similarity index 100% rename from shark_turbine/kernel/_support/regions.py rename to iree/turbine/kernel/_support/regions.py diff --git a/shark_turbine/kernel/_support/shaped_type.py b/iree/turbine/kernel/_support/shaped_type.py similarity index 100% rename from shark_turbine/kernel/_support/shaped_type.py rename to iree/turbine/kernel/_support/shaped_type.py diff --git a/shark_turbine/kernel/_support/tracing.py b/iree/turbine/kernel/_support/tracing.py similarity index 100% rename from shark_turbine/kernel/_support/tracing.py rename to iree/turbine/kernel/_support/tracing.py diff --git a/shark_turbine/kernel/compiler/base.py b/iree/turbine/kernel/compiler/base.py similarity index 100% rename from shark_turbine/kernel/compiler/base.py rename to iree/turbine/kernel/compiler/base.py diff --git a/shark_turbine/kernel/compiler/builder.py b/iree/turbine/kernel/compiler/builder.py similarity index 100% rename from shark_turbine/kernel/compiler/builder.py rename to iree/turbine/kernel/compiler/builder.py diff --git a/shark_turbine/kernel/compiler/dispatch_codegen.py b/iree/turbine/kernel/compiler/dispatch_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/dispatch_codegen.py rename to iree/turbine/kernel/compiler/dispatch_codegen.py diff --git a/shark_turbine/kernel/compiler/host_codegen.py b/iree/turbine/kernel/compiler/host_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/host_codegen.py rename to iree/turbine/kernel/compiler/host_codegen.py diff --git a/shark_turbine/kernel/compiler/ir.py b/iree/turbine/kernel/compiler/ir.py similarity index 100% rename from shark_turbine/kernel/compiler/ir.py rename to iree/turbine/kernel/compiler/ir.py diff --git a/shark_turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/kernel_codegen.py rename to iree/turbine/kernel/compiler/kernel_codegen.py diff --git a/shark_turbine/kernel/compiler/op_matchers.py b/iree/turbine/kernel/compiler/op_matchers.py similarity index 100% rename from shark_turbine/kernel/compiler/op_matchers.py rename to iree/turbine/kernel/compiler/op_matchers.py diff --git a/shark_turbine/kernel/compiler/utils.py b/iree/turbine/kernel/compiler/utils.py similarity index 100% rename from shark_turbine/kernel/compiler/utils.py rename to iree/turbine/kernel/compiler/utils.py diff --git a/shark_turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/vector_codegen.py rename to iree/turbine/kernel/compiler/vector_codegen.py diff --git a/shark_turbine/kernel/gen/__init__.py b/iree/turbine/kernel/gen/__init__.py similarity index 100% rename from shark_turbine/kernel/gen/__init__.py rename to iree/turbine/kernel/gen/__init__.py diff --git a/shark_turbine/kernel/gen/kernel.py b/iree/turbine/kernel/gen/kernel.py similarity index 100% rename from shark_turbine/kernel/gen/kernel.py rename to iree/turbine/kernel/gen/kernel.py diff --git a/shark_turbine/kernel/gen/thread.py b/iree/turbine/kernel/gen/thread.py similarity index 100% rename from shark_turbine/kernel/gen/thread.py rename to iree/turbine/kernel/gen/thread.py diff --git a/shark_turbine/kernel/lang/__init__.py b/iree/turbine/kernel/lang/__init__.py similarity index 100% rename from shark_turbine/kernel/lang/__init__.py rename to iree/turbine/kernel/lang/__init__.py diff --git a/shark_turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py similarity index 100% rename from shark_turbine/kernel/lang/global_symbols.py rename to iree/turbine/kernel/lang/global_symbols.py diff --git a/shark_turbine/kernel/lang/grid.py b/iree/turbine/kernel/lang/grid.py similarity index 100% rename from shark_turbine/kernel/lang/grid.py rename to iree/turbine/kernel/lang/grid.py diff --git a/shark_turbine/kernel/lang/kernel_buffer.py b/iree/turbine/kernel/lang/kernel_buffer.py similarity index 100% rename from shark_turbine/kernel/lang/kernel_buffer.py rename to iree/turbine/kernel/lang/kernel_buffer.py diff --git a/shark_turbine/kernel/lang/prims.py b/iree/turbine/kernel/lang/prims.py similarity index 100% rename from shark_turbine/kernel/lang/prims.py rename to iree/turbine/kernel/lang/prims.py diff --git a/shark_turbine/kernel/lang/types.py b/iree/turbine/kernel/lang/types.py similarity index 100% rename from shark_turbine/kernel/lang/types.py rename to iree/turbine/kernel/lang/types.py diff --git a/shark_turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py similarity index 100% rename from shark_turbine/kernel/lang/wave_types.py rename to iree/turbine/kernel/lang/wave_types.py diff --git a/shark_turbine/kernel/ops/__init__.py b/iree/turbine/kernel/ops/__init__.py similarity index 100% rename from shark_turbine/kernel/ops/__init__.py rename to iree/turbine/kernel/ops/__init__.py diff --git a/shark_turbine/kernel/ops/base.py b/iree/turbine/kernel/ops/base.py similarity index 100% rename from shark_turbine/kernel/ops/base.py rename to iree/turbine/kernel/ops/base.py diff --git a/shark_turbine/kernel/ops/control_flow.py b/iree/turbine/kernel/ops/control_flow.py similarity index 100% rename from shark_turbine/kernel/ops/control_flow.py rename to iree/turbine/kernel/ops/control_flow.py diff --git a/shark_turbine/kernel/ops/core.py b/iree/turbine/kernel/ops/core.py similarity index 100% rename from shark_turbine/kernel/ops/core.py rename to iree/turbine/kernel/ops/core.py diff --git a/shark_turbine/kernel/ops/math.py b/iree/turbine/kernel/ops/math.py similarity index 100% rename from shark_turbine/kernel/ops/math.py rename to iree/turbine/kernel/ops/math.py diff --git a/shark_turbine/kernel/ops/memory.py b/iree/turbine/kernel/ops/memory.py similarity index 100% rename from shark_turbine/kernel/ops/memory.py rename to iree/turbine/kernel/ops/memory.py diff --git a/shark_turbine/kernel/ops/reduction.py b/iree/turbine/kernel/ops/reduction.py similarity index 100% rename from shark_turbine/kernel/ops/reduction.py rename to iree/turbine/kernel/ops/reduction.py diff --git a/shark_turbine/kernel/ops/shape_manipulation.py b/iree/turbine/kernel/ops/shape_manipulation.py similarity index 100% rename from shark_turbine/kernel/ops/shape_manipulation.py rename to iree/turbine/kernel/ops/shape_manipulation.py diff --git a/shark_turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py similarity index 100% rename from shark_turbine/kernel/ops/wave_ops.py rename to iree/turbine/kernel/ops/wave_ops.py diff --git a/shark_turbine/kernel/wave/README.md b/iree/turbine/kernel/wave/README.md similarity index 100% rename from shark_turbine/kernel/wave/README.md rename to iree/turbine/kernel/wave/README.md diff --git a/shark_turbine/kernel/wave/__init__.py b/iree/turbine/kernel/wave/__init__.py similarity index 100% rename from shark_turbine/kernel/wave/__init__.py rename to iree/turbine/kernel/wave/__init__.py diff --git a/shark_turbine/kernel/wave/barriers.py b/iree/turbine/kernel/wave/barriers.py similarity index 100% rename from shark_turbine/kernel/wave/barriers.py rename to iree/turbine/kernel/wave/barriers.py diff --git a/shark_turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py similarity index 99% rename from shark_turbine/kernel/wave/codegen.py rename to iree/turbine/kernel/wave/codegen.py index e218d71c..233a571d 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -40,10 +40,10 @@ scf_d, vector_d, ) -from shark_turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type +from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type # TK infrastructure imports. -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.lang.global_symbols import * from ..ops.wave_ops import ( write, broadcast, diff --git a/shark_turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py similarity index 100% rename from shark_turbine/kernel/wave/constraints.py rename to iree/turbine/kernel/wave/constraints.py diff --git a/shark_turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py similarity index 100% rename from shark_turbine/kernel/wave/decompose_reduce_ops.py rename to iree/turbine/kernel/wave/decompose_reduce_ops.py diff --git a/shark_turbine/kernel/wave/docs/gemm_example.md b/iree/turbine/kernel/wave/docs/gemm_example.md similarity index 100% rename from shark_turbine/kernel/wave/docs/gemm_example.md rename to iree/turbine/kernel/wave/docs/gemm_example.md diff --git a/shark_turbine/kernel/wave/docs/mlsys/.gitignore b/iree/turbine/kernel/wave/docs/mlsys/.gitignore similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/.gitignore rename to iree/turbine/kernel/wave/docs/mlsys/.gitignore diff --git a/shark_turbine/kernel/wave/docs/mlsys/algorithm.sty b/iree/turbine/kernel/wave/docs/mlsys/algorithm.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/algorithm.sty rename to iree/turbine/kernel/wave/docs/mlsys/algorithm.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/algorithmic.sty b/iree/turbine/kernel/wave/docs/mlsys/algorithmic.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/algorithmic.sty rename to iree/turbine/kernel/wave/docs/mlsys/algorithmic.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/fancyhdr.sty b/iree/turbine/kernel/wave/docs/mlsys/fancyhdr.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/fancyhdr.sty rename to iree/turbine/kernel/wave/docs/mlsys/fancyhdr.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/mlsys2024.bst b/iree/turbine/kernel/wave/docs/mlsys/mlsys2024.bst similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/mlsys2024.bst rename to iree/turbine/kernel/wave/docs/mlsys/mlsys2024.bst diff --git a/shark_turbine/kernel/wave/docs/mlsys/mlsys2024.sty b/iree/turbine/kernel/wave/docs/mlsys/mlsys2024.sty similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/mlsys2024.sty rename to iree/turbine/kernel/wave/docs/mlsys/mlsys2024.sty diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bbl b/iree/turbine/kernel/wave/docs/mlsys/tkw.bbl similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/tkw.bbl rename to iree/turbine/kernel/wave/docs/mlsys/tkw.bbl diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.bib b/iree/turbine/kernel/wave/docs/mlsys/tkw.bib similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/tkw.bib rename to iree/turbine/kernel/wave/docs/mlsys/tkw.bib diff --git a/shark_turbine/kernel/wave/docs/mlsys/tkw.tex b/iree/turbine/kernel/wave/docs/mlsys/tkw.tex similarity index 100% rename from shark_turbine/kernel/wave/docs/mlsys/tkw.tex rename to iree/turbine/kernel/wave/docs/mlsys/tkw.tex diff --git a/shark_turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py similarity index 100% rename from shark_turbine/kernel/wave/expansion.py rename to iree/turbine/kernel/wave/expansion.py diff --git a/shark_turbine/kernel/wave/hoisting.py b/iree/turbine/kernel/wave/hoisting.py similarity index 95% rename from shark_turbine/kernel/wave/hoisting.py rename to iree/turbine/kernel/wave/hoisting.py index df68c753..5a4773d7 100644 --- a/shark_turbine/kernel/wave/hoisting.py +++ b/iree/turbine/kernel/wave/hoisting.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ...support.logging import get_logger -from shark_turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.tracing import CapturedTrace import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py similarity index 100% rename from shark_turbine/kernel/wave/index_sequence_analysis.py rename to iree/turbine/kernel/wave/index_sequence_analysis.py diff --git a/shark_turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py similarity index 100% rename from shark_turbine/kernel/wave/iree_utils.py rename to iree/turbine/kernel/wave/iree_utils.py diff --git a/shark_turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py similarity index 100% rename from shark_turbine/kernel/wave/minimize_global_loads.py rename to iree/turbine/kernel/wave/minimize_global_loads.py diff --git a/shark_turbine/kernel/wave/promotion.py b/iree/turbine/kernel/wave/promotion.py similarity index 100% rename from shark_turbine/kernel/wave/promotion.py rename to iree/turbine/kernel/wave/promotion.py diff --git a/shark_turbine/kernel/wave/scheduling/__init__.py b/iree/turbine/kernel/wave/scheduling/__init__.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/__init__.py rename to iree/turbine/kernel/wave/scheduling/__init__.py diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/graph_utils.py rename to iree/turbine/kernel/wave/scheduling/graph_utils.py diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/loop_reconstruction.py rename to iree/turbine/kernel/wave/scheduling/loop_reconstruction.py diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py rename to iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/modulo_scheduling.py rename to iree/turbine/kernel/wave/scheduling/modulo_scheduling.py diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/resources.py rename to iree/turbine/kernel/wave/scheduling/resources.py diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/schedule.py rename to iree/turbine/kernel/wave/scheduling/schedule.py diff --git a/shark_turbine/kernel/wave/shared_memory_indexing.py b/iree/turbine/kernel/wave/shared_memory_indexing.py similarity index 100% rename from shark_turbine/kernel/wave/shared_memory_indexing.py rename to iree/turbine/kernel/wave/shared_memory_indexing.py diff --git a/shark_turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py similarity index 99% rename from shark_turbine/kernel/wave/thread_shape_analysis.py rename to iree/turbine/kernel/wave/thread_shape_analysis.py index 5fd0b999..927bd363 100644 --- a/shark_turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ...support.logging import get_logger -from shark_turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.tracing import CapturedTrace import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * diff --git a/shark_turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py similarity index 99% rename from shark_turbine/kernel/wave/utils.py rename to iree/turbine/kernel/wave/utils.py index 869df061..020adf1f 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -33,7 +33,7 @@ TilingConstraint, ) import torch.fx as fx -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl import tempfile diff --git a/shark_turbine/kernel/wave/visualization.py b/iree/turbine/kernel/wave/visualization.py similarity index 100% rename from shark_turbine/kernel/wave/visualization.py rename to iree/turbine/kernel/wave/visualization.py diff --git a/shark_turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py similarity index 99% rename from shark_turbine/kernel/wave/wave.py rename to iree/turbine/kernel/wave/wave.py index 202cdd92..21485ed1 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -42,7 +42,7 @@ from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, CompiledContext, diff --git a/shark_turbine/kernel/wave/wave_sim.py b/iree/turbine/kernel/wave/wave_sim.py similarity index 100% rename from shark_turbine/kernel/wave/wave_sim.py rename to iree/turbine/kernel/wave/wave_sim.py diff --git a/shark_turbine/ops/__init__.py b/iree/turbine/ops/__init__.py similarity index 100% rename from shark_turbine/ops/__init__.py rename to iree/turbine/ops/__init__.py diff --git a/shark_turbine/ops/_jinja_test_ops.py b/iree/turbine/ops/_jinja_test_ops.py similarity index 100% rename from shark_turbine/ops/_jinja_test_ops.py rename to iree/turbine/ops/_jinja_test_ops.py diff --git a/shark_turbine/ops/_str_format_test_ops.py b/iree/turbine/ops/_str_format_test_ops.py similarity index 100% rename from shark_turbine/ops/_str_format_test_ops.py rename to iree/turbine/ops/_str_format_test_ops.py diff --git a/shark_turbine/ops/iree.py b/iree/turbine/ops/iree.py similarity index 100% rename from shark_turbine/ops/iree.py rename to iree/turbine/ops/iree.py diff --git a/shark_turbine/ops/templates/test_add_jinja.mlir b/iree/turbine/ops/templates/test_add_jinja.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_jinja.mlir rename to iree/turbine/ops/templates/test_add_jinja.mlir diff --git a/shark_turbine/ops/templates/test_add_strformat.mlir b/iree/turbine/ops/templates/test_add_strformat.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_strformat.mlir rename to iree/turbine/ops/templates/test_add_strformat.mlir diff --git a/shark_turbine/ops/templates/test_syntax_error.mlir b/iree/turbine/ops/templates/test_syntax_error.mlir similarity index 100% rename from shark_turbine/ops/templates/test_syntax_error.mlir rename to iree/turbine/ops/templates/test_syntax_error.mlir diff --git a/shark_turbine/runtime/__init__.py b/iree/turbine/runtime/__init__.py similarity index 100% rename from shark_turbine/runtime/__init__.py rename to iree/turbine/runtime/__init__.py diff --git a/shark_turbine/runtime/device.py b/iree/turbine/runtime/device.py similarity index 100% rename from shark_turbine/runtime/device.py rename to iree/turbine/runtime/device.py diff --git a/shark_turbine/runtime/launch.py b/iree/turbine/runtime/launch.py similarity index 100% rename from shark_turbine/runtime/launch.py rename to iree/turbine/runtime/launch.py diff --git a/shark_turbine/runtime/op_reg/__init__.py b/iree/turbine/runtime/op_reg/__init__.py similarity index 100% rename from shark_turbine/runtime/op_reg/__init__.py rename to iree/turbine/runtime/op_reg/__init__.py diff --git a/shark_turbine/runtime/op_reg/base.py b/iree/turbine/runtime/op_reg/base.py similarity index 100% rename from shark_turbine/runtime/op_reg/base.py rename to iree/turbine/runtime/op_reg/base.py diff --git a/shark_turbine/runtime/op_reg/compiler.py b/iree/turbine/runtime/op_reg/compiler.py similarity index 100% rename from shark_turbine/runtime/op_reg/compiler.py rename to iree/turbine/runtime/op_reg/compiler.py diff --git a/shark_turbine/runtime/op_reg/eager.py b/iree/turbine/runtime/op_reg/eager.py similarity index 100% rename from shark_turbine/runtime/op_reg/eager.py rename to iree/turbine/runtime/op_reg/eager.py diff --git a/shark_turbine/runtime/op_reg/impl_helper.py b/iree/turbine/runtime/op_reg/impl_helper.py similarity index 100% rename from shark_turbine/runtime/op_reg/impl_helper.py rename to iree/turbine/runtime/op_reg/impl_helper.py diff --git a/shark_turbine/runtime/tracing.py b/iree/turbine/runtime/tracing.py similarity index 100% rename from shark_turbine/runtime/tracing.py rename to iree/turbine/runtime/tracing.py diff --git a/shark_turbine/support/__init__.py b/iree/turbine/support/__init__.py similarity index 100% rename from shark_turbine/support/__init__.py rename to iree/turbine/support/__init__.py diff --git a/shark_turbine/support/conversions.py b/iree/turbine/support/conversions.py similarity index 100% rename from shark_turbine/support/conversions.py rename to iree/turbine/support/conversions.py diff --git a/shark_turbine/support/debugging.py b/iree/turbine/support/debugging.py similarity index 100% rename from shark_turbine/support/debugging.py rename to iree/turbine/support/debugging.py diff --git a/shark_turbine/support/exceptions.py b/iree/turbine/support/exceptions.py similarity index 100% rename from shark_turbine/support/exceptions.py rename to iree/turbine/support/exceptions.py diff --git a/shark_turbine/support/ir_imports.py b/iree/turbine/support/ir_imports.py similarity index 100% rename from shark_turbine/support/ir_imports.py rename to iree/turbine/support/ir_imports.py diff --git a/shark_turbine/support/logging.py b/iree/turbine/support/logging.py similarity index 100% rename from shark_turbine/support/logging.py rename to iree/turbine/support/logging.py diff --git a/shark_turbine/tools/__init__.py b/iree/turbine/tools/__init__.py similarity index 100% rename from shark_turbine/tools/__init__.py rename to iree/turbine/tools/__init__.py diff --git a/shark_turbine/tools/interpreter.py b/iree/turbine/tools/interpreter.py similarity index 100% rename from shark_turbine/tools/interpreter.py rename to iree/turbine/tools/interpreter.py diff --git a/shark_turbine/transforms/builder.py b/iree/turbine/transforms/builder.py similarity index 100% rename from shark_turbine/transforms/builder.py rename to iree/turbine/transforms/builder.py diff --git a/shark_turbine/transforms/general/add_metadata.py b/iree/turbine/transforms/general/add_metadata.py similarity index 97% rename from shark_turbine/transforms/general/add_metadata.py rename to iree/turbine/transforms/general/add_metadata.py index 44aa2413..340169ec 100644 --- a/shark_turbine/transforms/general/add_metadata.py +++ b/iree/turbine/transforms/general/add_metadata.py @@ -12,7 +12,7 @@ import re -from shark_turbine.support.ir_imports import * +from iree.turbine.support.ir_imports import * from ..rewriter import * from iree.compiler.ir import Context, DictAttr diff --git a/shark_turbine/transforms/general/custom_op_expansion.py b/iree/turbine/transforms/general/custom_op_expansion.py similarity index 100% rename from shark_turbine/transforms/general/custom_op_expansion.py rename to iree/turbine/transforms/general/custom_op_expansion.py diff --git a/shark_turbine/transforms/general/rename_parameters.py b/iree/turbine/transforms/general/rename_parameters.py similarity index 100% rename from shark_turbine/transforms/general/rename_parameters.py rename to iree/turbine/transforms/general/rename_parameters.py diff --git a/shark_turbine/transforms/merger.py b/iree/turbine/transforms/merger.py similarity index 100% rename from shark_turbine/transforms/merger.py rename to iree/turbine/transforms/merger.py diff --git a/shark_turbine/transforms/quantization/mm_group_quant.py b/iree/turbine/transforms/quantization/mm_group_quant.py similarity index 100% rename from shark_turbine/transforms/quantization/mm_group_quant.py rename to iree/turbine/transforms/quantization/mm_group_quant.py diff --git a/shark_turbine/transforms/rewriter.py b/iree/turbine/transforms/rewriter.py similarity index 100% rename from shark_turbine/transforms/rewriter.py rename to iree/turbine/transforms/rewriter.py diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index fcb7dbed..14eb2e60 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -3,18 +3,18 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 4800e9bd..30144024 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -2,11 +2,11 @@ import pytest from typing import Callable -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test import torch M = tkl.sym.M diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 6f4e2f29..efcdd582 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -2,13 +2,13 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test, print_trace import sympy # Input sizes diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 7dd266ee..f0149b70 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -2,22 +2,22 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) -from shark_turbine.kernel.wave.index_sequence_analysis import ( +from iree.turbine.kernel.wave.index_sequence_analysis import ( partition_strided_operators, ) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 7596a94b..329a9ccf 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -3,21 +3,21 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 3843c406..c3836f4f 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -2,16 +2,16 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index eafabb27..aefad516 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -2,22 +2,22 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_subgraph -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_subgraph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) -from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph # Input sizes diff --git a/lit_tests/kernel/wave/tracing.py b/lit_tests/kernel/wave/tracing.py index 283b6436..f6c9306b 100644 --- a/lit_tests/kernel/wave/tracing.py +++ b/lit_tests/kernel/wave/tracing.py @@ -1,11 +1,11 @@ # RUN: python %s | FileCheck %s from typing import Callable -from shark_turbine.kernel._support.tracing import CapturedTrace -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.ops.wave_ops import get_custom, Read, Write -from shark_turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel._support.tracing import CapturedTrace +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.ops.wave_ops import get_custom, Read, Write +from iree.turbine.kernel.wave.utils import run_test, print_trace M = tkl.sym.M N = tkl.sym.N diff --git a/lit_tests/lit.cfg.py b/lit_tests/lit.cfg.py index 5b40c7eb..614383fc 100644 --- a/lit_tests/lit.cfg.py +++ b/lit_tests/lit.cfg.py @@ -7,7 +7,7 @@ import lit.llvm -from shark_turbine.support.logging import get_logger +from iree.turbine.support.logging import get_logger logger = get_logger("turbine.lit_tests") diff --git a/mypy.ini b/mypy.ini index 528b8d48..5638faef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ explicit_package_bases = True mypy_path = $MYPY_CONFIG_FILE_DIR -packages = shark_turbine +packages = iree.turbine # Missing typing stubs for iree.compiler. [mypy-iree.compiler.*] @@ -13,15 +13,15 @@ ignore_missing_imports = True ignore_missing_imports = True # fx_importer needs to be fixed upstream. -[mypy-shark_turbine.importers.fx_importer.*] +[mypy-iree.turbine.importers.fx_importer.*] ignore_errors = True # TODO: Fix all typing errors in TK. -[mypy-shark_turbine.kernel.*] +[mypy-iree.turbine.kernel.*] ignore_errors = True # TODO: Some pytorch errors. -[mypy-shark_turbine.tools.interpreter] +[mypy-iree.turbine.tools.interpreter] ignore_errors = True # Ignore all typing errors in tests/tools (these depend on TK). diff --git a/setup.py b/setup.py index 63a028cb..c73c3532 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,7 @@ REPO_DIR = THIS_DIR VERSION_INFO_FILE = os.path.join(REPO_DIR, "version_info.json") -# Transitional as we migrate from shark-turbine -> iree-turbine. -TURBINE_PACKAGE_NAME = os.getenv("TURBINE_PACKAGE_NAME", "shark-turbine") +TURBINE_PACKAGE_NAME = "iree-turbine" with open( os.path.join( @@ -81,12 +80,12 @@ def initialize_options(self): setup( name=f"{TURBINE_PACKAGE_NAME}", version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK Turbine Machine Learning Deployment Tools", + author="IREE Authors", + author_email="iree-technical-discussion@lists.lfaidata.foundation", + description="IREE Turbine Machine Learning Deployment Tools", long_description=README, long_description_content_type="text/markdown", - url="https://github.com/nod-ai/SHARK-Turbine", + url="https://github.com/iree-org/iree-turbine/", license="Apache-2.0", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -96,11 +95,11 @@ def initialize_options(self): packages=packages, include_package_data=True, package_data={ - "shark_turbine": ["ops/templates/*.mlir"], # Include MLIR templates + "iree.turbine": ["ops/templates/*.mlir"], # Include MLIR templates }, entry_points={ "torch_dynamo_backends": [ - "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend", + "turbine_cpu = iree.turbine.dynamo.backends.cpu:backend", ], }, install_requires=[ diff --git a/shark_turbine/__init__.py b/shark_turbine/__init__.py new file mode 100644 index 00000000..f1e1c318 --- /dev/null +++ b/shark_turbine/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +# Temp redirect from old shark_turbine namespace. +from iree.turbine import aot +from iree.turbine import dynamo +from iree.turbine import kernel +from iree.turbine import ops +from iree.turbine import runtime diff --git a/tests/aot/api_test.py b/tests/aot/api_test.py index e038704d..0d5f4215 100644 --- a/tests/aot/api_test.py +++ b/tests/aot/api_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn diff --git a/tests/aot/args_test.py b/tests/aot/args_test.py index d7ec458d..efbce489 100644 --- a/tests/aot/args_test.py +++ b/tests/aot/args_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class ArgsTest(unittest.TestCase): diff --git a/tests/aot/compiled_exported_program_test.py b/tests/aot/compiled_exported_program_test.py index baaeb9bb..6b86b185 100644 --- a/tests/aot/compiled_exported_program_test.py +++ b/tests/aot/compiled_exported_program_test.py @@ -14,8 +14,8 @@ Context, ) -from shark_turbine.aot import * -from shark_turbine.aot.builtins import * +from iree.turbine.aot import * +from iree.turbine.aot.builtins import * class TorchExportTests(unittest.TestCase): diff --git a/tests/aot/decompositions_test.py b/tests/aot/decompositions_test.py index baf96604..f186cf12 100644 --- a/tests/aot/decompositions_test.py +++ b/tests/aot/decompositions_test.py @@ -9,7 +9,7 @@ import logging import unittest -from shark_turbine.aot import decompositions +from iree.turbine.aot import decompositions class DecompTest(unittest.TestCase): diff --git a/tests/aot/dynamic_shape_export_test.py b/tests/aot/dynamic_shape_export_test.py index da8c11b7..8f53df27 100644 --- a/tests/aot/dynamic_shape_export_test.py +++ b/tests/aot/dynamic_shape_export_test.py @@ -2,7 +2,7 @@ import pytest -from shark_turbine.aot import * +from iree.turbine.aot import * @pytest.mark.parametrize( diff --git a/tests/aot/functionalize_test.py b/tests/aot/functionalize_test.py index 0cad8e93..2a2ea309 100644 --- a/tests/aot/functionalize_test.py +++ b/tests/aot/functionalize_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class FunctionalizeTests(unittest.TestCase): diff --git a/tests/aot/fx_programs_test.py b/tests/aot/fx_programs_test.py index c54f1851..f2c70456 100644 --- a/tests/aot/fx_programs_test.py +++ b/tests/aot/fx_programs_test.py @@ -10,7 +10,7 @@ import pytest import torch -from shark_turbine.aot import ( +from iree.turbine.aot import ( FxPrograms, FxProgramsBuilder, ) diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 607382fd..7a250531 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py index 9f479921..251c8f12 100644 --- a/tests/aot/iree_procedural_test.py +++ b/tests/aot/iree_procedural_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class CompiledModuleAPI(unittest.TestCase): diff --git a/tests/aot/jittable_test.py b/tests/aot/jittable_test.py index 9c87fb11..d19988bc 100644 --- a/tests/aot/jittable_test.py +++ b/tests/aot/jittable_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class JittableTests(unittest.TestCase): diff --git a/tests/aot/non_strict_export_test.py b/tests/aot/non_strict_export_test.py index ece961dc..2ed1b603 100644 --- a/tests/aot/non_strict_export_test.py +++ b/tests/aot/non_strict_export_test.py @@ -3,7 +3,7 @@ from torch import nn import torch -from shark_turbine.aot import * +from iree.turbine.aot import * logger = logging.getLogger(__file__) diff --git a/tests/aot/params_test.py b/tests/aot/params_test.py index a1d64206..895cb2b9 100644 --- a/tests/aot/params_test.py +++ b/tests/aot/params_test.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from shark_turbine.aot import ( +from iree.turbine.aot import ( export, externalize_module_parameters, save_module_parameters, diff --git a/tests/dynamo/importer_dynamic_test.py b/tests/dynamo/importer_dynamic_test.py index 72ff4f82..682aa140 100644 --- a/tests/dynamo/importer_dynamic_test.py +++ b/tests/dynamo/importer_dynamic_test.py @@ -14,7 +14,7 @@ # from torch._export.constraints import constrain_as_size, constrain_as_value from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index fcd40660..0562c071 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -12,8 +12,8 @@ import torch # Public API imports. -from shark_turbine.runtime import Device -from shark_turbine.dynamo import TurbineMode, DeviceTensor +from iree.turbine.runtime import Device +from iree.turbine.dynamo import TurbineMode, DeviceTensor class TensorTest(unittest.TestCase): diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index 617c5d05..70375efb 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -12,7 +12,7 @@ Type as IrType, ) -import shark_turbine.dynamo.type_conversion as tc +import iree.turbine.dynamo.type_conversion as tc class TypeConversionTest(unittest.TestCase): diff --git a/tests/generated/evaluate.py b/tests/generated/evaluate.py index 3184930d..a971e23c 100644 --- a/tests/generated/evaluate.py +++ b/tests/generated/evaluate.py @@ -2,7 +2,7 @@ import logging from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/kernel/aot_kernel_test.py b/tests/kernel/aot_kernel_test.py index 690e366a..16363048 100644 --- a/tests/kernel/aot_kernel_test.py +++ b/tests/kernel/aot_kernel_test.py @@ -8,9 +8,9 @@ import unittest import torch -from shark_turbine.aot import export -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +from iree.turbine.aot import export +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl def export_softmax_kernel(): diff --git a/tests/kernel/arith_test.py b/tests/kernel/arith_test.py index 1631454c..ce9e659e 100644 --- a/tests/kernel/arith_test.py +++ b/tests/kernel/arith_test.py @@ -8,15 +8,15 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/compiler/utils_test.py b/tests/kernel/compiler/utils_test.py index be084613..6f2db310 100644 --- a/tests/kernel/compiler/utils_test.py +++ b/tests/kernel/compiler/utils_test.py @@ -1,9 +1,9 @@ import logging import pytest import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel._support.indexing import IndexSymbol, IndexingContext -from shark_turbine.kernel.compiler.utils import strides_from_symbolic_shape +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel._support.indexing import IndexSymbol, IndexingContext +from iree.turbine.kernel.compiler.utils import strides_from_symbolic_shape class UtilsTest(unittest.TestCase): diff --git a/tests/kernel/dispatch_codegen_test.py b/tests/kernel/dispatch_codegen_test.py index be17a86d..b76ed2e1 100644 --- a/tests/kernel/dispatch_codegen_test.py +++ b/tests/kernel/dispatch_codegen_test.py @@ -8,16 +8,16 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, dispatch_codegen, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/fused_attention_test.py b/tests/kernel/fused_attention_test.py index 89883780..abc9d7ad 100644 --- a/tests/kernel/fused_attention_test.py +++ b/tests/kernel/fused_attention_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl BATCH = tkl.sym.BATCH N_HEADS = tkl.sym.N_HEADS diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index 8bc27c50..677bbf09 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -9,8 +9,8 @@ import torch -from shark_turbine.kernel._support.indexing import * -from shark_turbine.kernel.lang import * +from iree.turbine.kernel._support.indexing import * +from iree.turbine.kernel.lang import * M = sym.M N = sym.N diff --git a/tests/kernel/simple_kernel_test.py b/tests/kernel/simple_kernel_test.py index 87cf3ed2..bffe723c 100644 --- a/tests/kernel/simple_kernel_test.py +++ b/tests/kernel/simple_kernel_test.py @@ -9,8 +9,8 @@ import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K @@ -37,9 +37,9 @@ def iota_kernel(out: tk.lang.KernelBuffer[M, tkl.index]): print(iota_kernel._trace().region_graph) # Prints: # .graph(): - # %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_global_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) + # %out : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_global_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) # return None def testSoftmax(self): @@ -76,17 +76,17 @@ def softmax(x): print(softmax_kernel._trace().region_graph) # Prints: # graph(): - # %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] - # %output : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %input_1 : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] + # %output : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) # %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%input_1, (%program_id, slice(None, None, None))), kwargs = {}) # %max_1 : [num_users=1] = call_function[target=torch.max](args = (%getitem,), kwargs = {}) # %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem, %max_1), kwargs = {}) # %exp : [num_users=2] = call_function[target=torch.exp](args = (%sub,), kwargs = {}) # %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%exp,), kwargs = {}) # %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%exp, %sum_1), kwargs = {}) - # %program_id_1 : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) + # %program_id_1 : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_kernel_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) # return None diff --git a/tests/kernel/types_test.py b/tests/kernel/types_test.py index 87dc6536..e355db31 100644 --- a/tests/kernel/types_test.py +++ b/tests/kernel/types_test.py @@ -7,7 +7,7 @@ import logging import unittest -from shark_turbine.kernel.lang import ( +from iree.turbine.kernel.lang import ( Index, ) diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index fcd33462..696852c0 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K diff --git a/tests/kernel/wave/constraints_test.py b/tests/kernel/wave/constraints_test.py index 418c3c8b..f2915ac4 100644 --- a/tests/kernel/wave/constraints_test.py +++ b/tests/kernel/wave/constraints_test.py @@ -8,8 +8,8 @@ import pytest import unittest from sympy import ceiling -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.constraints import ( +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.constraints import ( WorkgroupConstraint, get_grid_shape, TilingConstraint, diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 93d9cb6c..bb7cbc25 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -6,32 +6,32 @@ import unittest import logging -from shark_turbine.kernel.wave.scheduling.modulo_scheduling import ( +from iree.turbine.kernel.wave.scheduling.modulo_scheduling import ( ModuloScheduler, EdgeWeight, Edge, ) import torch.fx as fx import numpy as np -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.scheduling.graph_utils import ( +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.scheduling.graph_utils import ( find_strongly_connected_components, find_cycles_in_scc, all_pairs_longest_paths, evaluate_all_pairs_longest_paths, ) -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph -from shark_turbine.kernel.ops.wave_ops import get_custom +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.ops.wave_ops import get_custom class SchedulingTest(unittest.TestCase): diff --git a/tests/kernel/wave/types_test.py b/tests/kernel/wave/types_test.py index d27c4c47..cdb05c3e 100644 --- a/tests/kernel/wave/types_test.py +++ b/tests/kernel/wave/types_test.py @@ -9,9 +9,9 @@ import sympy import unittest -from shark_turbine.kernel.lang import Memory, Register, sym, f16 -from shark_turbine.kernel.lang.wave_types import AddressSpace -from shark_turbine.kernel.lang.kernel_buffer import KernelBufferUsage +from iree.turbine.kernel.lang import Memory, Register, sym, f16 +from iree.turbine.kernel.lang.wave_types import AddressSpace +from iree.turbine.kernel.lang.kernel_buffer import KernelBufferUsage M = sym.M N = sym.N diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py index 17cce11c..ebe6a75f 100644 --- a/tests/kernel/wave/visualization_test.py +++ b/tests/kernel/wave/visualization_test.py @@ -9,15 +9,15 @@ import unittest import os import pytest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import get_custom -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.visualization import visualize_graph +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import get_custom +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.visualization import visualize_graph def run(func: Callable[[], None]) -> Callable[[], None]: diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index cf2d1315..0611257d 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -4,12 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.iree_utils import generate_iree_ref +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import torch from numpy.testing import assert_allclose, assert_equal import pytest diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index f9de28b2..e9487a64 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -8,11 +8,11 @@ import pytest import torch import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.iree_utils import generate_iree_ref +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import os import json from torch.testing import assert_close diff --git a/tests/kernel/wave/wave_sim_test.py b/tests/kernel/wave/wave_sim_test.py index 5fa5695a..58ec1255 100644 --- a/tests/kernel/wave/wave_sim_test.py +++ b/tests/kernel/wave/wave_sim_test.py @@ -6,9 +6,9 @@ import pytest import torch -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim from numpy.testing import assert_allclose diff --git a/tests/kernel/wave/wave_utils_test.py b/tests/kernel/wave/wave_utils_test.py index ec1198fd..bce6de9f 100644 --- a/tests/kernel/wave/wave_utils_test.py +++ b/tests/kernel/wave/wave_utils_test.py @@ -6,8 +6,8 @@ import logging import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.utils import delinearize_index +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.utils import delinearize_index import numpy as np M = sym.M diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py index b06a7910..facbf545 100644 --- a/tests/ops/iree_test.py +++ b/tests/ops/iree_test.py @@ -10,8 +10,8 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops # See runtime/op_reg/kernel_aot_test.py for additional tests of the trace diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index e78aff8e..89cd83c8 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -14,17 +14,17 @@ from iree.runtime import HalElementType # Public API imports. -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Device, ) # Internals. -from shark_turbine.runtime.device import ( +from iree.turbine.runtime.device import ( _CURRENT_THREAD, get_device_from_torch, ) -from shark_turbine.support.exceptions import * +from iree.turbine.support.exceptions import * class DeviceTest(unittest.TestCase): @@ -151,7 +151,7 @@ def testFromTorchDevice(self): print(device.dump_device_info()) def testJit(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") result = test_ops.test_add(t, t) @@ -161,7 +161,7 @@ def testJit(self): class TorchCPUInterop(unittest.TestCase): def testJitStrFormat(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) @@ -169,7 +169,7 @@ def testJitStrFormat(self): torch.testing.assert_close(result, expected) def testJitJinja(self): - from shark_turbine.ops import _jinja_test_ops as test_ops + from iree.turbine.ops import _jinja_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) diff --git a/tests/runtime/launch_test.py b/tests/runtime/launch_test.py index 1a142161..ad12b2e3 100644 --- a/tests/runtime/launch_test.py +++ b/tests/runtime/launch_test.py @@ -8,11 +8,11 @@ import torch import unittest -from shark_turbine.aot.params import ( +from iree.turbine.aot.params import ( ParameterArchiveBuilder, ) -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/tests/runtime/op_reg/impl_helper_test.py b/tests/runtime/op_reg/impl_helper_test.py index b0797c2d..2661dc5b 100644 --- a/tests/runtime/op_reg/impl_helper_test.py +++ b/tests/runtime/op_reg/impl_helper_test.py @@ -9,7 +9,7 @@ import torch -from shark_turbine.ops import _str_format_test_ops +from iree.turbine.ops import _str_format_test_ops class KernelRegTest(unittest.TestCase): diff --git a/tests/runtime/op_reg/kernel_aot_test.py b/tests/runtime/op_reg/kernel_aot_test.py index 4aa04857..4533326a 100644 --- a/tests/runtime/op_reg/kernel_aot_test.py +++ b/tests/runtime/op_reg/kernel_aot_test.py @@ -10,10 +10,10 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass class MLP(nn.Module): diff --git a/tests/runtime/op_reg/kernel_reg_test.py b/tests/runtime/op_reg/kernel_reg_test.py index 75554b04..dfc88d83 100644 --- a/tests/runtime/op_reg/kernel_reg_test.py +++ b/tests/runtime/op_reg/kernel_reg_test.py @@ -9,9 +9,9 @@ import torch -from shark_turbine.runtime.op_reg import * +from iree.turbine.runtime.op_reg import * -from shark_turbine.runtime.op_reg.compiler import _testing_get_cache_size +from iree.turbine.runtime.op_reg.compiler import _testing_get_cache_size class KernelRegTest(unittest.TestCase): diff --git a/tests/tools/interpreter_test.py b/tests/tools/interpreter_test.py index 0513b10b..2152c701 100644 --- a/tests/tools/interpreter_test.py +++ b/tests/tools/interpreter_test.py @@ -1,8 +1,8 @@ -from shark_turbine.tools.interpreter import Interpreter -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.tools.interpreter import Interpreter +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * import torch diff --git a/tests/top_level_package_test.py b/tests/top_level_package_test.py index 52ea796b..b2c04cdd 100644 --- a/tests/top_level_package_test.py +++ b/tests/top_level_package_test.py @@ -11,8 +11,8 @@ class TopLevelPackageTest(unittest.TestCase): def testIreeTurbineRedirect(self): # We have a temporary redirect of the top-level API to the - # iree.turbine namespace. - from iree.turbine import aot, dynamo, kernel, ops, runtime + # shark-turbine namespace. + from shark_turbine import aot, dynamo, kernel, ops, runtime if __name__ == "__main__": diff --git a/tests/transforms/general/add_metadata_test.py b/tests/transforms/general/add_metadata_test.py index 8055fa26..da5d0207 100644 --- a/tests/transforms/general/add_metadata_test.py +++ b/tests/transforms/general/add_metadata_test.py @@ -11,7 +11,7 @@ from iree.compiler.ir import Context, Operation, Module -from shark_turbine.transforms.general import add_metadata +from iree.turbine.transforms.general import add_metadata SIMPLE_FUNC_ASM = r""" func.func @list_func(%arg0 : !iree_input.list) -> !iree_input.list { diff --git a/tests/transforms/general/custom_op_expansion_test.py b/tests/transforms/general/custom_op_expansion_test.py index b94e2750..f621320d 100644 --- a/tests/transforms/general/custom_op_expansion_test.py +++ b/tests/transforms/general/custom_op_expansion_test.py @@ -9,15 +9,15 @@ import torch import unittest -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass -from shark_turbine.runtime.op_reg import ( +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.runtime.op_reg import ( def_library, CustomOp, KernelBuilder, KernelSelection, ) -from shark_turbine.support.ir_imports import ( +from iree.turbine.support.ir_imports import ( Context, Module, ) diff --git a/tests/transforms/general/rename_parameters_test.py b/tests/transforms/general/rename_parameters_test.py index 74fc6753..a14dbcbd 100644 --- a/tests/transforms/general/rename_parameters_test.py +++ b/tests/transforms/general/rename_parameters_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.general import rename_parameters +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.general import rename_parameters SIMPLE_GLOBALS_ASM = r""" module { diff --git a/tests/transforms/quantization/mm_group_quant_test.py b/tests/transforms/quantization/mm_group_quant_test.py index c6870d2c..b465301b 100644 --- a/tests/transforms/quantization/mm_group_quant_test.py +++ b/tests/transforms/quantization/mm_group_quant_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.quantization import mm_group_quant +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.quantization import mm_group_quant MM_F32_TO_INT4_CONTENTS = ( Path(__file__).resolve().parent / "mm_f32_to_int4.mlir" From 796f3a57bb80b153f9b832112047597d52a183ab Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Sat, 5 Oct 2024 16:07:51 -0700 Subject: [PATCH 27/28] [TKW] Minor bug fix expansion to handle reduction and MMA at same time. (#196) Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/expansion.py | 2 +- lit_tests/kernel/wave/codegen.py | 78 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 69785031..75cc5551 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -326,7 +326,7 @@ def _expand_node( return context[(node, get_indexed_dims(dim_query, node), res_idx)] elif isinstance(node, Reduction): return _expand_reduction( - node, trace, dim_query, dim_scaling, node_index_setter, context + node, trace, dim_query, dim_scaling, node_index_setter, context, res_idx ) elif isinstance(node, Getitem): res_idx = node.res_idx diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 30144024..65db9a78 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -758,6 +758,84 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-8: amdgpu.mfma +# This test is used to check two things +# 1. Reduction with multiple different types(MMA, ReduceOp) of iterArg works +# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape). +@run_test +def test_gemm_and_reduce(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE_0, tkl.f16], + d: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + init_max = tkl.Register[M, tkl.f16](-1e6) + + @tkw.reduction(K, init_args=[init_max, c_reg]) + def repeat( + partial_max: tkl.Register[M, tkl.f16], acc: tkl.Register[M, N, tkl.f32] + ) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + partial_max = tkw.max(a_reg, partial_max, dim=K) + acc = tkw.mma(a_reg, b_reg, acc) + return partial_max, acc + + res_max, res_mm = repeat + tkw.write(res_max, c, elements_per_thread=1) + tkw.write(res_mm, d, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 16, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + a = torch.randn(64, 32, dtype=torch.float16) + b = torch.randn(128, 32, dtype=torch.float16) + c = torch.zeros(64, dtype=torch.float16) + d = torch.zeros(64, 128, dtype=torch.float32) + print(gemm(a, b, c, d).module_op) + # CHECK-DAG: %[[C0_IDX:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[C4_IDX:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1_IDX:.+]] = arith.constant 1 : index + + # Tile Reduction Loop + # Note: Shape is 32x20 instead of 32x16 because of padding to avoid bank conflicts + # CHECK: %{{.*}}:2 = scf.for %[[ITER:.+]] = %[[C0_IDX]] to %[[C4_IDX]] step %[[C1_IDX]] + # CHECK-SAME: iter_args(%[[ACC0:.+]] = %{{.*}}, %[[ACC1:.+]] = {{.*}}) + # CHECK-COUNT-2: vector.load{{.*}} memref<32x20xf16, #gpu.address_space>, vector<4xf16> + # CHECK-COUNT-6: gpu.shuffle xor + # CHECK: %[[MAX:.+]] = arith.maximumf %[[ACC0]], %{{.*}} + # CHECK: %[[MMA:.+]] = amdgpu.mfma %{{.*}} * %{{.*}} + %[[ACC1]] + # CHECK: scf.yield %[[MAX]], %[[MMA]] : vector<1xf16>, vector<4xf32> + + @run_test def test_add_float(): constraints: list[tkw.Constraint] = [ From 47720f22b5fec29c64d1e688820dd53e99589810 Mon Sep 17 00:00:00 2001 From: Vinayak Dev Date: Mon, 7 Oct 2024 22:52:55 +0530 Subject: [PATCH 28/28] [tests][aot] Add test for externalized parameters --- tests/aot/params_test.py | 53 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/aot/params_test.py b/tests/aot/params_test.py index 895cb2b9..bd484c3e 100644 --- a/tests/aot/params_test.py +++ b/tests/aot/params_test.py @@ -12,6 +12,9 @@ import torch import torch.nn as nn +import iree.runtime as rt +import iree.compiler as ireec + from iree.turbine.aot import ( export, externalize_module_parameters, @@ -35,6 +38,56 @@ def forward(self, x): return result +class LinearModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(in_features, out_features)) + self.bias = torch.nn.Parameter(torch.randn(out_features)) + + def forward(self, input): + return (input @ self.weight) + self.bias + + +class ExternalParamsTest(unittest.TestCase): + def setUp(self): + self.instance = rt.VmInstance() + self.device = rt.get_device(ireec.core.DEFAULT_TESTING_DRIVER) + self.config = rt.Config(device=self.device) + + def testSeparateWeightsAtRuntime(self): + linear_module = LinearModule(4, 3).requires_grad_(False) + externalize_module_parameters(linear_module) + wt = linear_module.weight.data.contiguous() + bias = linear_module.bias.data.contiguous() + + input = torch.randn(4) + exported_module = export(linear_module, input) + binary = exported_module.compile(save_to=None) + + idx = rt.ParameterIndex() + idx.add_buffer("weight", wt.detach().numpy().tobytes()) + idx.add_buffer("bias", bias.detach().numpy().tobytes()) + + config = rt.Config(driver_name="local-task") + instance = config.vm_instance + param_module = rt.create_io_parameters_module( + instance, + idx.create_provider(scope="model"), + ) + + vm_modules = rt.load_vm_modules( + param_module, + rt.create_hal_module(instance, config.device), + rt.VmModule.copy_buffer(instance, binary.map_memory()), + config=config, + ) + + m = vm_modules[-1] + result_vm = m.main(input).to_host() + result_torch = linear_module(input) + torch.testing.assert_close(torch.from_numpy(result_vm), result_torch) + + class ParamsTest(unittest.TestCase): def testCreateArchive(self): with tempfile.TemporaryDirectory() as td: