diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 1ecfa7b2..dfbd18f3 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -249,7 +249,7 @@ def testAttention( j = tkw.IndexMapping.iterator(1) k = tkw.IndexMapping.iterator(2) mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, M: j, N: k}, outputs={B: i, N: k, M: j} + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} ) @tkw.wave(constraints) @@ -295,7 +295,7 @@ def repeat( # repeat represents the results of the loop res_max, res_sum, res_mm = repeat res = res_mm / res_sum - tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, @@ -357,5 +357,4 @@ def repeat( with open(filename, "w") as f: f.write(mb.module_op.get_asm()) - # TODO: Fix transposed writes to output. - assert_allclose(output.permute([0, 2, 1]), torch_ref) + assert_allclose(output, torch_ref)