diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 5c02872d..b2b1c6a6 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -1325,16 +1325,35 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
 
     # Determine whether to extract or combine.
     if len(args) > 1:
-        raise NotImplementedError(
-            "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes"
-        )
+        concatenated = None
+        for i, sub_arg in enumerate(args):
+            vector = cast_vector(emitter, sub_arg)
+            shape = vector.type.shape[0]
+            if concatenated is None:
+                element_type = vector.type.element_type
+                vector_type = VectorType.get([shape * len(args)], element_type)
+                concatenated = arith_d.ConstantOp(
+                    vector_type,
+                    DenseElementsAttr.get_splat(
+                        vector_type, get_constant_attr(0, element_type)
+                    ),
+                ).result
+            concatenated = vector_d.insert_strided_slice(
+                vector, concatenated, [i * shape], [1]
+            )
+        emitter.bind_node_proxy(node, IRProxyValue(concatenated))
+        return
 
     # Extract the appropriate slice. The offset is obtained from the expanded_dim
     # and so corresponds to the dim_query during expansion. To obtain the
-    # actual offset, we need to multiple by the size which is determined by comparing
-    # the source and target vector shapes along the innermost dimension.
-    size = target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
+    # actual offset, we need to multiply by the size. The size is obtained by
+    # computing the number of partitions using the source and target vector shapes
+    # and dividing the incoming vector shape by the number of partitions.
+    num_partitions = (
+        target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
+    )
     vector = cast_vector(emitter, args[0])
+    size = vector.type.shape[0] // num_partitions
     result_type = VectorType.get([size], vector.type.element_type)
     slice = vector_d.extract_strided_slice(
         result_type,
diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py
index d5e23629..d1031bd5 100644
--- a/iree/turbine/kernel/wave/iree_utils.py
+++ b/iree/turbine/kernel/wave/iree_utils.py
@@ -41,6 +41,44 @@ def get_chain_mmt_asm(
     }}"""
 
 
+def get_chain_mmt_f8_asm(
+    query_type: str, key_type: str, value_type: str, output_type: str
+) -> str:
+    B, M, K1, input_dtype = query_type.split("x")
+    B, K2, K1, input_dtype = key_type.split("x")
+    B, N, K2, input_dtype = value_type.split("x")
+    B, N, M, output_dtype = output_type.split("x")
+    f8_dtype = "f8E4M3FNUZ"
+    intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}"
+    intermediate_cast_type = f"{B}x{K2}x{M}x{f8_dtype}"
+    transposed_cast_type = f"{B}x{M}x{K2}x{f8_dtype}"
+    transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
+    query_f8_type = "x".join([B, M, K1, f8_dtype])
+    key_f8_type = "x".join([B, K2, K1, f8_dtype])
+    value_f8_type = "x".join([B, N, K2, f8_dtype])
+    return f"""
+    func.func @chain_mmt_f8(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
+      %c0 = arith.constant 0.0 : f32
+      %init = tensor.empty() : tensor<{intermediate_output_type}>
+      %query_f8 = arith.truncf %query : tensor<{query_type}> to tensor<{query_f8_type}>
+      %key_f8 = arith.truncf %key : tensor<{key_type}> to tensor<{key_f8_type}>
+      %initial_result = linalg.fill ins(%c0 : f32) outs(%init : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
+      %result = linalg.batch_matmul_transpose_b ins(%key_f8, %query_f8 : tensor<{key_f8_type}>, tensor<{query_f8_type}>)
+                outs(%initial_result : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
+      %trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}>
+      %init2 = tensor.empty() : tensor<{transposed_cast_type}>
+      %transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation=[0, 2, 1]
+      %init3 = tensor.empty() : tensor<{transposed_output_type}>
+      %initial_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
+      %value_f8 = arith.truncf %value : tensor<{value_type}> to tensor<{value_f8_type}>
+      %result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value_f8: tensor<{transposed_cast_type}>, tensor<{value_f8_type}>)
+                outs(%initial_result3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
+      %init4 = tensor.empty() : tensor<{output_type}>
+      %transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
+      return %transpose2 : tensor<{output_type}>
+    }}"""
+
+
 def get_mmt_asm(
     lhs_type: str,
     rhs_type: str,
@@ -141,6 +179,12 @@ def generate_iree_ref(
         value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
         output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
         asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type)
+    elif kernel_type == "chain_mmt_f8":
+        query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
+        key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
+        value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
+        output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
+        asm = get_chain_mmt_f8_asm(query_type, key_type, value_type, output_type)
     elif kernel_type.startswith(conv_str):
         lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
         rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 95075a38..fc5e482c 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -918,6 +918,178 @@ def repeat(
         # CHECK: scf.yield
 
 
+@run_test
+def test_chained_gemm_32x32x16():
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    mfma_variant = tkw.MMAType.F32_32x32x16_F8
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=mfma_variant,
+            vector_shapes={B: 0},
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def chained_gemm_32x32x16(
+        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
+
+        @tkw.reduction(K2, init_args=[c_reg])
+        def repeat(
+            acc: tkl.Register[B, M, N, tkl.f32]
+        ) -> tkl.Register[B, M, N, tkl.f32]:
+            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            q_reg = tkw.cast(q_reg, tkl.f8e4m3fnuz)
+            k_reg = tkw.cast(k_reg, tkl.f8e4m3fnuz)
+            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
+            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
+            qk_cast_reg = tkw.cast(qk_reg, tkl.f8e4m3fnuz)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            v_reg = tkw.cast(v_reg, tkl.f8e4m3fnuz)
+            acc = tkw.mma(qk_cast_reg, v_reg, acc)
+            return acc
+
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 128,
+            N: 128,
+            K1: 32,
+            K2: 256,
+            B: 8,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K2: 32,
+            BLOCK_B: 1,
+            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
+            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+    ):
+        q = torch.randn(8, 64, 64, dtype=torch.float16)
+        k = torch.randn(8, 256, 64, dtype=torch.float16)
+        v = torch.zeros(8, 128, 256, dtype=torch.float16)
+        output = torch.zeros(8, 64, 128, dtype=torch.float32)
+        print(chained_gemm_32x32x16(q, k, v, output).module_op)
+
+        # CHECK: func.func @chained_gemm_32x32x16(
+        # CHECK: {{.*}} = scf.for
+        # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-3: {{.*}} = arith.truncf
+        # CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [0], sizes = [8], strides = [1]}
+        # CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [8], sizes = [8], strides = [1]}
+        # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
+        # CHECK: scf.yield
+
+
+@run_test
+def test_chained_gemm_16x16x32():
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    mfma_variant = tkw.MMAType.F32_16x16x32_F8
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=mfma_variant,
+            vector_shapes={B: 0},
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def chained_gemm_16x16x32(
+        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
+
+        @tkw.reduction(K2, init_args=[c_reg])
+        def repeat(
+            acc: tkl.Register[B, M, N, tkl.f32]
+        ) -> tkl.Register[B, M, N, tkl.f32]:
+            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            q_reg = tkw.cast(q_reg, tkl.f8e4m3fnuz)
+            k_reg = tkw.cast(k_reg, tkl.f8e4m3fnuz)
+            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
+            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
+            qk_cast_reg = tkw.cast(qk_reg, tkl.f8e4m3fnuz)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            v_reg = tkw.cast(v_reg, tkl.f8e4m3fnuz)
+            acc = tkw.mma(qk_cast_reg, v_reg, acc)
+            return acc
+
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 128,
+            N: 128,
+            K1: 32,
+            K2: 256,
+            B: 8,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            BLOCK_K2: 32,
+            BLOCK_B: 1,
+            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
+            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+    ):
+        q = torch.randn(8, 64, 64, dtype=torch.float16)
+        k = torch.randn(8, 256, 64, dtype=torch.float16)
+        v = torch.zeros(8, 128, 256, dtype=torch.float16)
+        output = torch.zeros(8, 64, 128, dtype=torch.float32)
+        print(chained_gemm_16x16x32(q, k, v, output).module_op)
+
+        # CHECK: func.func @chained_gemm_16x16x32(
+        # CHECK: {{.*}} = scf.for
+        # CHECK-COUNT-4: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-6: {{.*}} = arith.truncf
+        # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [0], strides = [1]}
+        # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
+        # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [0], strides = [1]}
+        # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
+        # CHECK-COUNT-4: {{.*}} = amdgpu.mfma
+        # CHECK: scf.yield
+
+
 @run_test
 def test_gemm_pipelined():
     constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py
index 6ae11055..f2cb3672 100644
--- a/tests/kernel/wave/wave_attention_test.py
+++ b/tests/kernel/wave/wave_attention_test.py
@@ -198,6 +198,155 @@ def repeat(
     assert_close(output, iree_ref)
 
 
+# This test requires further analysis of the index sequences between
+# the two chained GEMMs.
+@require_e2e
+@pytest.mark.xfail
+@pytest.mark.parametrize("shape", get_test_shapes("test_attention"))
+@pytest.mark.parametrize("enable_scheduling", [False])
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        MMAType.F32_16x16x32_F8,
+        MMAType.F32_32x32x16_F8,
+    ],
+)
+def testChainedGemm_f8(
+    shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
+):
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    # Input sizes
+    B = tkl.sym.B
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    # Workgroup tile sizes
+    BLOCK_B = tkl.sym.BLOCK_B
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
+    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=mfma_variant,
+            vector_shapes={B: 0},
+        )
+    ]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.iterator(2)
+    mapping = tkw.IndexMapping(
+        num_iterators=3, inputs={B: i, M: j, N: k}, outputs={B: i, N: k, M: j}
+    )
+
+    @tkw.wave(constraints)
+    def chained_gemm_f8(
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[B, N, M, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
+
+        @tkw.reduction(K2, init_args=[c_reg])
+        def repeat(
+            acc: tkl.Register[B, M, N, tkl.f32]
+        ) -> tkl.Register[B, M, N, tkl.f32]:
+            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            q_reg = tkw.cast(q_reg, tkl.f8e4m3fnuz)
+            k_reg = tkw.cast(k_reg, tkl.f8e4m3fnuz)
+            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
+            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
+            qk_cast_reg = tkw.cast(qk_reg, tkl.f8e4m3fnuz)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            v_reg = tkw.cast(v_reg, tkl.f8e4m3fnuz)
+            acc = tkw.mma(qk_cast_reg, v_reg, acc)
+            return acc
+
+        # repeat represents the results of the loop
+        tkw.write(
+            repeat, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD
+        )
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
+        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+        BLOCK_B: 1,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K2: 32,
+        B: shape[0],
+        M: shape[1],
+        N: shape[2],
+        K1: shape[3],
+        K2: shape[4],
+        READ_SHARED_DELAY: 1,
+        WRITE_SHARED_DELAY: 1,
+        READ_GLOBAL_DELAY: 2,
+        WRITE_GLOBAL_DELAY: 2,
+        MMA_DELAY: 1,
+        SHARED_MEMORY_UNITS: 4,
+        GLOBAL_MEMORY_UNITS: 4,
+        MMA_UNITS: 4,
+    }
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+    if run_bench:
+        config["benchmark_batch_size"] = 10
+        config["benchmark_repetitions"] = 3
+    if dump_perf is not None:
+        perf_filename = request.node.name + ".json"
+        config["benchmark_results_file"] = os.path.join(
+            dump_perf, "tk_" + perf_filename
+        )
+
+    with tk.gen.TestLaunchContext(
+        hyperparams,
+        canonicalize=True,
+        run=True,
+        run_bench=run_bench,
+        run_config=config,
+        schedule=enable_scheduling,
+        use_scheduling_barriers=enable_scheduling_barriers,
+    ):
+        q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16)
+        k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16)
+        v = torch.randn(shape[0], shape[2], shape[4], dtype=torch.float16)
+        output = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
+        mb = chained_gemm_f8(q, k, v, output)
+
+        if test_dump_generated_mlir:
+            filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
+            with open(filename, "w") as f:
+                f.write(mb.module_op.get_asm())
+
+        iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
+        generate_iree_ref(
+            "chain_mmt_f8", [q, k, v], [iree_ref], config, run_bench=run_bench
+        )
+        assert_close(output, iree_ref)
+
+
 @require_e2e
 @pytest.mark.parametrize("shape", get_test_shapes("test_attention"))
 @pytest.mark.parametrize("enable_scheduling", [False])
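
Note on the handle_reshape change in codegen.py: the sketch below is a minimal NumPy illustration, not part of the patch, and the names combine and extract are invented here; it mirrors the semantics of the emitted vector.insert_strided_slice and vector.extract_strided_slice ops.

import numpy as np

def combine(vectors):
    # Multi-argument path: start from a zero splat, then insert each
    # source vector at offset i * shape with stride 1, as the new
    # vector_d.insert_strided_slice loop does.
    shape = vectors[0].shape[0]
    concatenated = np.zeros(shape * len(vectors), dtype=vectors[0].dtype)
    for i, v in enumerate(vectors):
        concatenated[i * shape : (i + 1) * shape] = v
    return concatenated

def extract(vector, target_shape, source_shape, offset):
    # Single-argument path: num_partitions comes from the target/source
    # vector shapes along the innermost dimension, and the slice size is
    # the incoming vector length divided by that partition count.
    num_partitions = target_shape // source_shape
    size = vector.shape[0] // num_partitions
    return vector[offset * size : offset * size + size]

This lines up with the lit test CHECKs above: for F32_16x16x32_F8 two 4-element slices are combined into an 8-element vector (offsets [0] and [4]), while for F32_32x32x16_F8 two 8-element slices are extracted from a 16-element vector (offsets [0] and [8]).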
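For readers who prefer not to parse the MLIR, here is a hedged PyTorch sketch of what the chain_mmt_f8 reference computes. It is illustrative only: it assumes torch.float8_e4m3fnuz is available in your PyTorch build, and it emulates the f8 matmuls by rounding operands through f8 and multiplying in f32, which matches an f8 matmul with f32 accumulation for representable values.

import torch

def _round_f8(t: torch.Tensor) -> torch.Tensor:
    # Emulate f8E4M3FNUZ quantization (assumes torch.float8_e4m3fnuz exists).
    return t.to(torch.float8_e4m3fnuz).to(torch.float32)

def chain_mmt_f8_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: BxMxK1, k: BxK2xK1, v: BxNxK2 (f16 inputs); returns BxNxM in f32.
    kq = torch.matmul(_round_f8(k), _round_f8(q).transpose(-1, -2))  # BxK2xM
    qk = _round_f8(kq.transpose(-1, -2))                             # BxMxK2, cast through f8
    out = torch.matmul(qk, _round_f8(v).transpose(-1, -2))           # BxMxN
    return out.transpose(-1, -2).contiguous()                        # BxNxM

With the shapes used in testChainedGemm_f8, this produces the same BxNxM layout (shape[0], shape[2], shape[1]) that the IREE reference writes into iree_ref.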