Add support for F8 intrinsics
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 29, 2024
1 parent 98c52e3 commit 8863678
Showing 4 changed files with 390 additions and 6 deletions.
31 changes: 25 additions & 6 deletions iree/turbine/kernel/wave/codegen.py
@@ -1325,16 +1325,35 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):

     # Determine whether to extract or combine.
     if len(args) > 1:
-        raise NotImplementedError(
-            "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes"
-        )
+        concatenated = None
+        for i, sub_arg in enumerate(args):
+            vector = cast_vector(emitter, sub_arg)
+            shape = vector.type.shape[0]
+            if concatenated is None:
+                element_type = vector.type.element_type
+                vector_type = VectorType.get([shape * len(args)], element_type)
+                concatenated = arith_d.ConstantOp(
+                    vector_type,
+                    DenseElementsAttr.get_splat(
+                        vector_type, get_constant_attr(0, element_type)
+                    ),
+                ).result
+            concatenated = vector_d.insert_strided_slice(
+                vector, concatenated, [i * shape], [1]
+            )
+        emitter.bind_node_proxy(node, IRProxyValue(concatenated))
+        return

     # Extract the appropriate slice. The offset is obtained from the expanded_dim
     # and so corresponds to the dim_query during expansion. To obtain the
-    # actual offset, we need to multiple by the size which is determined by comparing
-    # the source and target vector shapes along the innermost dimension.
-    size = target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
+    # actual offset, we need to multiply by the size. The size is obtained by
+    # computing the number of partitions using the source and target vector shapes
+    # and dividing the incoming vector shape by the number of partitions.
+    num_partitions = (
+        target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
+    )
     vector = cast_vector(emitter, args[0])
+    size = vector.type.shape[0] // num_partitions
     result_type = VectorType.get([size], vector.type.element_type)
     slice = vector_d.extract_strided_slice(
         result_type,
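To make the new extract-path arithmetic concrete, here is a minimal sketch with illustrative numbers; the concrete values below are assumptions for the example, not taken from the commit:

```python
# Illustrative values only: a case where the target vector shape along the
# innermost dimension is twice the per-intrinsic (custom) vector shape.
target_innermost = 32   # stands in for target_vector_shapes[innermost_dim]
custom_innermost = 16   # stands in for custom.vector_shapes[innermost_dim]
incoming_len = 16       # stands in for vector.type.shape[0]

num_partitions = target_innermost // custom_innermost  # 2
size = incoming_len // num_partitions                  # 8
# Expanded-dim offsets 0 and 1 then address slices [0:8] and [8:16] of the
# incoming vector, which is exactly what the lit tests below check for.
```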
44 changes: 44 additions & 0 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -41,6 +41,44 @@ def get_chain_mmt_asm(
}}"""


def get_chain_mmt_f8_asm(
    query_type: str, key_type: str, value_type: str, output_type: str
) -> str:
    B, M, K1, input_dtype = query_type.split("x")
    B, K2, K1, input_dtype = key_type.split("x")
    B, N, K2, input_dtype = value_type.split("x")
    B, N, M, output_dtype = output_type.split("x")
    f8_dtype = "f8E4M3FNUZ"
    intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}"
    intermediate_cast_type = f"{B}x{K2}x{M}x{f8_dtype}"
    transposed_cast_type = f"{B}x{M}x{K2}x{f8_dtype}"
    transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
    query_f8_type = "x".join([B, M, K1, f8_dtype])
    key_f8_type = "x".join([B, K2, K1, f8_dtype])
    value_f8_type = "x".join([B, N, K2, f8_dtype])
    return f"""
    func.func @chain_mmt_f8(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
      %c0 = arith.constant 0.0 : f32
      %init = tensor.empty() : tensor<{intermediate_output_type}>
      %query_f8 = arith.truncf %query : tensor<{query_type}> to tensor<{query_f8_type}>
      %key_f8 = arith.truncf %key : tensor<{key_type}> to tensor<{key_f8_type}>
      %initial_result = linalg.fill ins(%c0 : f32) outs(%init : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
      %result = linalg.batch_matmul_transpose_b ins(%key_f8, %query_f8 : tensor<{key_f8_type}>, tensor<{query_f8_type}>)
                outs(%initial_result : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}>
      %trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}>
      %init2 = tensor.empty() : tensor<{transposed_cast_type}>
      %transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation = [0, 2, 1]
      %init3 = tensor.empty() : tensor<{transposed_output_type}>
      %initial_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
      %value_f8 = arith.truncf %value : tensor<{value_type}> to tensor<{value_f8_type}>
      %result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value_f8: tensor<{transposed_cast_type}>, tensor<{value_f8_type}>)
                 outs(%initial_result3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
      %init4 = tensor.empty() : tensor<{output_type}>
      %transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation = [0, 2, 1]
      return %transpose2 : tensor<{output_type}>
    }}"""


def get_mmt_asm(
    lhs_type: str,
    rhs_type: str,
@@ -141,6 +179,12 @@ def generate_iree_ref(
        value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
        output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
        asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type)
    elif kernel_type == "chain_mmt_f8":
        query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
        key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
        value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype)
        output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
        asm = get_chain_mmt_f8_asm(query_type, key_type, value_type, output_type)
    elif kernel_type.startswith(conv_str):
        lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
        rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
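A hedged usage sketch of the new helper: the shapes below are made up for illustration and simply follow the `BxMxK1xdtype`-style convention the function splits on (f8E4M3FNUZ is the FP8 format the helper hard-codes):

```python
# Hypothetical shapes; any well-formed "BxMxK1xdtype" strings would do.
asm = get_chain_mmt_f8_asm(
    query_type="8x128x32xf16",    # B x M  x K1
    key_type="8x256x32xf16",      # B x K2 x K1
    value_type="8x128x256xf16",   # B x N  x K2
    output_type="8x128x128xf32",  # unpacked as B, N, M above
)
# The returned MLIR truncates query/key/value to f8E4M3FNUZ, runs two
# linalg.batch_matmul_transpose_b ops with f32 accumulation, and uses
# linalg.transpose to move between the intermediate layouts.
print(asm)
```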
172 changes: 172 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -918,6 +918,178 @@ def repeat(
# CHECK: scf.yield


@run_test
def test_chained_gemm_32x32x16():
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2
    BLOCK_K2 = tkl.sym.BLOCK_K2

    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    mfma_variant = tkw.MMAType.F32_32x32x16_F8
    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            mma_type=mfma_variant,
            vector_shapes={B: 0},
        )
    ]

    @tkw.wave(constraints)
    def chained_gemm_32x32x16(
        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
    ):
        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)

        @tkw.reduction(K2, init_args=[c_reg])
        def repeat(
            acc: tkl.Register[B, M, N, tkl.f32]
        ) -> tkl.Register[B, M, N, tkl.f32]:
            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            q_reg = tkw.cast(q_reg, tkl.f8e4m3fnuz)
            k_reg = tkw.cast(k_reg, tkl.f8e4m3fnuz)
            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
            qk_cast_reg = tkw.cast(qk_reg, tkl.f8e4m3fnuz)
            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            v_reg = tkw.cast(v_reg, tkl.f8e4m3fnuz)
            acc = tkw.mma(qk_cast_reg, v_reg, acc)
            return acc

        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    with tk.gen.TestLaunchContext(
        {
            M: 128,
            N: 128,
            K1: 32,
            K2: 256,
            B: 8,
            BLOCK_M: 64,
            BLOCK_N: 64,
            BLOCK_K2: 32,
            BLOCK_B: 1,
            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
        },
        canonicalize=True,
    ):
        q = torch.randn(8, 64, 64, dtype=torch.float16)
        k = torch.randn(8, 256, 64, dtype=torch.float16)
        v = torch.zeros(8, 128, 256, dtype=torch.float16)
        output = torch.zeros(8, 64, 128, dtype=torch.float32)
        print(chained_gemm_32x32x16(q, k, v, output).module_op)

    # CHECK: func.func @chained_gemm_32x32x16(
    # CHECK: {{.*}} = scf.for
    # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
    # CHECK-COUNT-3: {{.*}} = arith.truncf
    # CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [0], sizes = [8], strides = [1]}
    # CHECK: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = [8], sizes = [8], strides = [1]}
    # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
    # CHECK: scf.yield
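Why these CHECK lines: for the F32_32x32x16_F8 variant, each mfma produces a wider f32 accumulator per thread than the f8 operand width the next mfma consumes, so feeding the first GEMM's result into the second goes through the extract path of handle_reshape. A rough sketch of the arithmetic; the per-thread element counts are assumptions based on the 32x32x16 MFMA layout, not stated in the diff:

```python
# Assumed per-thread widths for one 32x32x16 f8 mfma (illustrative).
acc_f32_per_thread = 16    # f32 result elements per lane
operand_f8_per_thread = 8  # f8 operand elements per lane

num_partitions = acc_f32_per_thread // operand_f8_per_thread  # 2
slice_size = acc_f32_per_thread // num_partitions             # 8
# -> two vector.extract_strided_slice ops with sizes = [8] at offsets
#    [0] and [8], matching the CHECK lines above.
```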


@run_test
def test_chained_gemm_16x16x32():
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2
    BLOCK_K2 = tkl.sym.BLOCK_K2

    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    mfma_variant = tkw.MMAType.F32_16x16x32_F8
    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            mma_type=mfma_variant,
            vector_shapes={B: 0},
        )
    ]

    @tkw.wave(constraints)
    def chained_gemm_16x16x32(
        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
    ):
        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)

        @tkw.reduction(K2, init_args=[c_reg])
        def repeat(
            acc: tkl.Register[B, M, N, tkl.f32]
        ) -> tkl.Register[B, M, N, tkl.f32]:
            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            q_reg = tkw.cast(q_reg, tkl.f8e4m3fnuz)
            k_reg = tkw.cast(k_reg, tkl.f8e4m3fnuz)
            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
            qk_cast_reg = tkw.cast(qk_reg, tkl.f8e4m3fnuz)
            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            v_reg = tkw.cast(v_reg, tkl.f8e4m3fnuz)
            acc = tkw.mma(qk_cast_reg, v_reg, acc)
            return acc

        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    with tk.gen.TestLaunchContext(
        {
            M: 128,
            N: 128,
            K1: 32,
            K2: 256,
            B: 8,
            BLOCK_M: 64,
            BLOCK_N: 64,
            BLOCK_K2: 32,
            BLOCK_B: 1,
            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
        },
        canonicalize=True,
    ):
        q = torch.randn(8, 64, 64, dtype=torch.float16)
        k = torch.randn(8, 256, 64, dtype=torch.float16)
        v = torch.zeros(8, 128, 256, dtype=torch.float16)
        output = torch.zeros(8, 64, 128, dtype=torch.float32)
        print(chained_gemm_16x16x32(q, k, v, output).module_op)

    # CHECK: func.func @chained_gemm_16x16x32(
    # CHECK: {{.*}} = scf.for
    # CHECK-COUNT-4: {{.*}} = amdgpu.mfma
    # CHECK-COUNT-6: {{.*}} = arith.truncf
    # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [0], strides = [1]}
    # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
    # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [0], strides = [1]}
    # CHECK: {{.*}} = vector.insert_strided_slice {{.*}} {offsets = [4], strides = [1]}
    # CHECK-COUNT-4: {{.*}} = amdgpu.mfma
    # CHECK: scf.yield
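Conversely, the F32_16x16x32_F8 variant produces fewer f32 result elements per thread than the f8 operand width the next mfma consumes, so this test exercises the combine path of handle_reshape (vector.insert_strided_slice) rather than the extract path. A sketch, again with assumed per-thread widths:

```python
# Assumed per-thread widths for one 16x16x32 f8 mfma (illustrative).
mma_result_elems = 4  # f32 elements per lane from one mfma
operand_elems = 8     # f8 elements per lane consumed by the next mfma

num_args = operand_elems // mma_result_elems  # 2 vectors to concatenate
offsets = [i * mma_result_elems for i in range(num_args)]  # [0, 4]
# -> pairs of vector.insert_strided_slice at offsets [0] and [4] build each
#    8-wide operand, matching the CHECK lines above.
```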


@run_test
def test_gemm_pipelined():
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
