Add op to cast between dtypes
This PR adds an op that can cast between
integer and float types of different bitwidths.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 17, 2024
1 parent a92f3db commit 5e5f78f
Showing 4 changed files with 228 additions and 0 deletions.
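
Before the per-file diffs, here is a minimal sketch of how the new op is used from a wave kernel. It is distilled from lit_tests/kernel/wave/casting.py added in this commit; the kernel name cast_kernel and the reduced cast chain are illustrative only.

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw

M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N

constraints = [
    tkw.HardwareConstraint(
        threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
    ),
    tkw.WorkgroupConstraint(M, BLOCK_M, 0),
    tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    tkw.WaveConstraint(M, BLOCK_M),
    tkw.WaveConstraint(N, BLOCK_N),
]


@tkw.wave(constraints)
def cast_kernel(
    a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
    x = tkw.read(a, elements_per_thread=16)
    x = tkw.cast(x, tkl.f32)  # f16 -> f32, lowered to arith.extf
    x = tkw.cast(x, tkl.f16)  # f32 -> f16, lowered to arith.truncf
    tkw.write(x, b, elements_per_thread=16)

Each tkw.cast call lowers to a single MLIR arith op, selected by the new handle_cast handler in codegen.py.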
24 changes: 24 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -123,6 +123,10 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register":
    ...


def cast(src: "Register", dtype: DataType) -> "Register":
    ...


def define_op(op_name: str) -> Callable[[T], T]:
    def decorator(cls: T) -> T:
        cls.tkw_op_name = op_name
@@ -1159,3 +1163,23 @@ def indexing_dims(self) -> list[IndexSymbol]:
    def type(self) -> Memory:
        src_type = get_custom(self.arg).type
        return src_type


@define_op("cast")
@dataclass
class CastOp(CustomOp, ABC):
    """
    Represents a cast of a register to a different element type, between
    integer and float types of possibly different bitwidths.
    """

    arg: fx.Node
    dtype: DataType

    @property
    def indexing_dims(self) -> list[IndexSymbol]:
        return get_custom(self.arg).indexing_dims

    @property
    def type(self) -> "Register":
        src_shape = get_custom(self.arg).type.symbolic_shape
        return Register[*src_shape, self.dtype]
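
As a quick illustration of the two properties just added (a hypothetical sketch, not part of the diff): if the traced source node has type Register[M, N, tkl.f16] and it is cast to tkl.f32, the resulting CastOp inherits the source's indexing dimensions and only swaps the element type.

# Hypothetical sketch of CastOp's type propagation:
#   source node type:               Register[M, N, tkl.f16]
#   CastOp(arg=source, dtype=tkl.f32)
#       .indexing_dims  -> [M, N]                    # inherited from the source
#       .type           -> Register[M, N, tkl.f32]   # same symbolic shape, new dtype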
63 changes: 63 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -18,6 +18,7 @@
    Attribute,
    DenseElementsAttr,
    FloatAttr,
    F16Type,
    F32Type,
    IndexType,
    InsertionPoint,
@@ -63,6 +64,7 @@
    CustomOp,
    scheduling_barrier,
    scheduling_group_barrier,
    cast,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1198,3 +1200,64 @@ def handle_get_result(emitter: WaveEmitter, node: fx.Node):
@handle_op(operator.getitem)
def handle_getitem(emitter: WaveEmitter, node: fx.Node):
    raise NotImplementedError("getitem: Currently only stub implementation")


def get_float_type(bitwidth: int):
    match bitwidth:
        case 16:
            return F16Type.get()
        case 32:
            return F32Type.get()
        case _:
            raise NotImplementedError(f"Unsupported float bitwidth: {bitwidth}")


@handle_op(cast)
def handle_cast(emitter: WaveEmitter, node: fx.Node):
    try:
        register, dtype = node.args
    except ValueError as e:
        raise ValidationError("Malformed arguments") from e
    vector_src = cast_vector(emitter, register)
    src_vector_type = vector_src.type
    dst_elem_type = IrType.parse(dtype.ir_type_asm())
    dst_vector_type = VectorType.get(src_vector_type.shape, dst_elem_type)

    if src_vector_type == dst_vector_type:
        emitter.bind_node_proxy(node, vector_src)
        return

    is_src_float = _is_float_type(src_vector_type.element_type)
    is_dst_float = _is_float_type(dst_elem_type)

    # Conversions between the float and integer elemental kinds.
    conversion_ops = {
        (True, False): arith_d.fptosi,
        (False, True): arith_d.sitofp,
    }

    # Casts within the same elemental kind, keyed on (is_extension, is_float).
    cast_ops = {
        (True, True): arith_d.extf,
        (True, False): arith_d.extsi,
        (False, True): arith_d.truncf,
        (False, False): arith_d.trunci,
    }

    if is_src_float == is_dst_float:
        # Same elemental kind: only the bitwidth changes.
        casted_vector = cast_ops[
            (
                src_vector_type.element_type.width < dst_elem_type.width,
                is_dst_float and is_src_float,
            )
        ](dst_vector_type, vector_src)
    else:
        # Different elemental kinds: a single conversion op handles the cast.
        casted_vector = conversion_ops[(is_src_float, is_dst_float)](
            dst_vector_type, vector_src
        )

    emitter.bind_node_proxy(node, IRProxyValue(casted_vector))
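
As a quick reference (not part of the diff), the cast chain exercised by the lit test below lowers to one arith op per cast. The mapping is distilled from the FileCheck expectations added in this commit; the expected_lowering name is for illustration only.

# One arith op per cast, per the FileCheck lines in lit_tests/kernel/wave/casting.py:
expected_lowering = {
    ("f16", "f32"): "arith.extf",    # float extension
    ("f32", "i32"): "arith.fptosi",  # float -> signed integer
    ("i32", "i16"): "arith.trunci",  # integer truncation
    ("i16", "i32"): "arith.extsi",   # signed integer extension
    ("i32", "f32"): "arith.sitofp",  # signed integer -> float
    ("f32", "f16"): "arith.truncf",  # float truncation
}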
83 changes: 83 additions & 0 deletions lit_tests/kernel/wave/casting.py
@@ -0,0 +1,83 @@
# RUN: python %s | FileCheck %s

import pytest
from typing import Callable
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import run_test
import torch

M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
B = tkl.sym.B
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
BLOCK_B = tkl.sym.BLOCK_B
LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEM_PER_THREAD
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEM_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


def codegen_test_context(canonicalize: bool = False, dynamic_symbols=[]):
    bindings = {
        M: 16,
        N: 16,
        K: 16,
        BLOCK_M: 16,
        BLOCK_N: 16,
        BLOCK_K: 16,
        ADDRESS_SPACE: tkl.AddressSpace.SHARED_MEMORY.value,
    }

    # Remove dynamic symbols from the bindings.
    for sym in dynamic_symbols:
        if sym in bindings:
            del bindings[sym]

    return tk.gen.TestLaunchContext(
        bindings, canonicalize=canonicalize, dynamic_symbols=dynamic_symbols
    )


@run_test
def test_cast():
    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    @tkw.wave(constraints)
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        a_reg = tkw.read(a, elements_per_thread=16)
        a_reg = tkw.cast(a_reg, tkl.f32)
        a_reg = tkw.cast(a_reg, tkl.i32)
        a_reg = tkw.cast(a_reg, tkl.i16)
        a_reg = tkw.cast(a_reg, tkl.i32)
        a_reg = tkw.cast(a_reg, tkl.f32)
        a_reg = tkw.cast(a_reg, tkl.f16)
        tkw.write(a_reg, b, elements_per_thread=16)

    with codegen_test_context(canonicalize=True):
        a = torch.randn(16, 16, dtype=torch.float16)
        b = torch.zeros(16, 16, dtype=torch.float16)
        print(test(a, b).module_op)

    # CHECK: %[[D0:.*]] = arith.extf {{.*}} : vector<16xf16> to vector<16xf32>
    # CHECK: %[[D1:.*]] = arith.fptosi %[[D0]] : vector<16xf32> to vector<16xi32>
    # CHECK: %[[D2:.*]] = arith.trunci %[[D1]] : vector<16xi32> to vector<16xi16>
    # CHECK: %[[D3:.*]] = arith.extsi %[[D2]] : vector<16xi16> to vector<16xi32>
    # CHECK: %[[D4:.*]] = arith.sitofp %[[D3]] : vector<16xi32> to vector<16xf32>
    # CHECK: %[[D5:.*]] = arith.truncf %[[D4]] : vector<16xf32> to vector<16xf16>
58 changes: 58 additions & 0 deletions tests/kernel/wave/wave_e2e_test.py
@@ -898,3 +898,61 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
        stride=stride,
        run_bench=True,
    )


@require_e2e
@pytest.mark.parametrize("shape", [(256, 64)])
def test_cast(shape, request):
    run_bench = request.config.getoption("--runperf")
    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    # Each workgroup works on a single row of the input data, and rows are
    # further split into blocks of size up to 256. We have a single wave per
    # workgroup, and with the default wave size of 64, each thread operates
    # on up to 4 elements.
    wave_size = 64
    BLOCK_M = 1
    # Tile size cannot be dynamic, so we use a fixed value here.
    BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)
    ELEMS_PER_THREAD = BLOCK_N / wave_size
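    # Illustrative arithmetic for the (256, 64) shape above: BLOCK_N =
    # Max(Min(64, 256), 64) = 64, so ELEMS_PER_THREAD = 64 / 64 = 1 and each
    # of the 64 threads in the wave casts a single f32 element per row.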

    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=wave_size,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: BLOCK_M, N: BLOCK_N},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    @tkw.wave(constraints)
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
    ):
        res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
        res = tkw.cast(res, tkl.f16)
        tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD)

    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

    a = torch.randn(shape, dtype=torch.float32)
    b = torch.zeros(shape, dtype=torch.float16)
    with tk.gen.TestLaunchContext(
        {
            M: shape[0],
            N: shape[1],
            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
        },
        canonicalize=True,
        run=True,
        run_bench=run_bench,
        run_config=config,
    ):
        test(a, b)
        assert_allclose(a.to(dtype=torch.float16), b)
