Add more shapes to Tk benchmark
Also add a tuned config that specifies which tile sizes and scheduling parameters to use.
harsh-amd committed Oct 8, 2024
1 parent 11fd8c4 commit 8a5fdd8
Showing 2 changed files with 51 additions and 10 deletions.
51 changes: 44 additions & 7 deletions gemmbench/gemm_utils.py
@@ -116,8 +116,37 @@ def generate_mlir(config: GemmConfig):
        return mlir_template_B
    return mlir_template

+@dataclass
+class TkTunedConfig:
+    BLOCK_M: int
+    BLOCK_N: int
+    BLOCK_K: int
+    RATIO_M: int
+    RATIO_N: int
+    WAVES_PER_EU: int
+    MMA_UNITS: int
+    SHARED_UNITS: int
+    GLOBAL_UNITS: int
+    DELAY_MMA: int
+    DELAY_SHARED: int
+    DELAY_GLOBAL: int
+
+def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
+    if config.M == 2048 and config.N == 10240 and config.K == 1280:
+        return TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
+    if config.M == 2048 and config.N == 1280 and config.K == 1280:
+        return TkTunedConfig(64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2)
+    if config.M == 2048 and config.N == 1280 and config.K == 5120:
+        return TkTunedConfig(128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2)
+    if config.M == 128 and config.N == 1280 and config.K == 2048:
+        return TkTunedConfig(64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2)
+    if config.M == 8192 and config.N == 5120 and config.K == 640:
+        return TkTunedConfig(128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2)
+
def generate_tk_mlir(config: GemmConfig):
    # TODO: Enable waves_per_eu
    # TODO: Use scheduling barriers with LLVM patch
+    tc = get_tk_tuned_config(config)
    assert config.operand_element_type == 'f16', "Unsupported problem"
    assert config.accumulator_element_type == 'f32', "Unsupported problem"
    # Input sizes
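
For reference, the shape-to-parameters dispatch above could also be written as a dictionary lookup. The sketch below is illustrative only: the tuning values are copied verbatim from the cases above, it assumes the TkTunedConfig and GemmConfig definitions from gemm_utils.py, and the explicit error for untuned shapes is an assumption rather than behaviour introduced by this commit.

# Illustrative alternative (not part of this commit): the same tuning table
# keyed by (M, N, K), with values copied from get_tk_tuned_config above.
_TK_TUNED_CONFIGS = {
    (2048, 10240, 1280): TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2),
    (2048, 1280, 1280): TkTunedConfig(64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2),
    (2048, 1280, 5120): TkTunedConfig(128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2),
    (128, 1280, 2048): TkTunedConfig(64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2),
    (8192, 5120, 640): TkTunedConfig(128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2),
}

def lookup_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
    # Hypothetical helper: fail loudly for shapes that have no tuning entry.
    try:
        return _TK_TUNED_CONFIGS[(config.M, config.N, config.K)]
    except KeyError:
        raise ValueError(f"No TK tuning for {config.M}x{config.N}x{config.K}") from None
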
@@ -138,11 +167,11 @@ def generate_tk_mlir(config: GemmConfig):
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
-    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
-    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / tc.RATIO_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / tc.RATIO_N)]

    constraints += [
-        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
+        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(tc.RATIO_M, tc.RATIO_N, 1))
    ]

    # Wave-level micro-kernel.
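
To make the constraint arithmetic concrete, here is a small worked example (not part of the commit) using the tuned values for the 2048x10240x1280 shape: RATIO_M and RATIO_N split each BLOCK_M x BLOCK_N workgroup tile across the waves of the block.

# Worked example (illustrative): wave tiling for the 2048x10240x1280 shape,
# using the tuned values returned by get_tk_tuned_config for that shape.
tc = TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)

wave_tile_m = tc.BLOCK_M // tc.RATIO_M   # 128 / 2 = 64 rows per wave
wave_tile_n = tc.BLOCK_N // tc.RATIO_N   # 320 / 2 = 160 columns per wave
num_waves = tc.RATIO_M * tc.RATIO_N      # waves_per_block = (2, 2, 1) -> 4 waves of 64 threads

assert (wave_tile_m, wave_tile_n, num_waves) == (64, 160, 4)
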
@@ -184,16 +213,24 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD: 4,
        STORE_ELEMS_PER_THREAD: 4,
-        BLOCK_M: 64,
-        BLOCK_N: 64,
-        BLOCK_K: 32,
+        BLOCK_M: tc.BLOCK_M,
+        BLOCK_N: tc.BLOCK_N,
+        BLOCK_K: tc.BLOCK_K,
        M: shape[0],
        N: shape[1],
        K: shape[2],
+        READ_SHARED_DELAY: tc.DELAY_SHARED,
+        WRITE_SHARED_DELAY: tc.DELAY_SHARED,
+        READ_GLOBAL_DELAY: tc.DELAY_GLOBAL,
+        WRITE_GLOBAL_DELAY: tc.DELAY_GLOBAL,
+        MMA_DELAY: tc.DELAY_MMA,
+        SHARED_MEMORY_UNITS: tc.SHARED_UNITS,
+        GLOBAL_MEMORY_UNITS: tc.GLOBAL_UNITS,
+        MMA_UNITS: tc.MMA_UNITS,
    }
    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
    with tk.gen.TestLaunchContext(
-        hyperparams, canonicalize=True, run=True, run_config=config
+        hyperparams, canonicalize=True, run=True, run_config=config, schedule=True,
    ):
        a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
        b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
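
As a concrete illustration (not part of the commit), substituting the tuned config for the 2048x10240x1280 shape into the scheduling entries above yields the values shown below; the dictionary only mirrors the mapping in this hunk and assumes the TkTunedConfig definition from earlier in the file.

# Illustrative substitution of the tuned values for the 2048x10240x1280 shape.
tc = TkTunedConfig(128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2)
scheduling_params = {
    "READ_SHARED_DELAY": tc.DELAY_SHARED,    # 1
    "WRITE_SHARED_DELAY": tc.DELAY_SHARED,   # 1
    "READ_GLOBAL_DELAY": tc.DELAY_GLOBAL,    # 2
    "WRITE_GLOBAL_DELAY": tc.DELAY_GLOBAL,   # 2
    "MMA_DELAY": tc.DELAY_MMA,               # 1
    "SHARED_MEMORY_UNITS": tc.SHARED_UNITS,  # 2
    "GLOBAL_MEMORY_UNITS": tc.GLOBAL_UNITS,  # 2
    "MMA_UNITS": tc.MMA_UNITS,               # 2
}

These scheduling knobs presumably take effect now that schedule=True is passed to tk.gen.TestLaunchContext.
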
10 changes: 7 additions & 3 deletions gemmbench/problems.py
@@ -864,11 +864,15 @@ def tk_default(dtype: str) -> list[GemmConfig]:
    acc_type = get_default_accumulator_element_type(dtype)
    res_type = get_default_result_element_type(dtype)
    configs = []
-    M, N, K = 1024, 5120, 640
-    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
    M, N, K = 2048, 10240, 1280
    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
-    M, N, K = 4096, 20480, 2560
+    M, N, K = 2048, 1280, 1280
    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 2048, 1280, 5120
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 128, 1280, 2048
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
+    M, N, K = 8192, 5120, 640
+    configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type))
    return configs
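
Since generate_tk_mlir now requires a tuned config for every shape it compiles, a small sanity check (hypothetical, not part of this commit) can confirm that each shape registered in tk_default has a matching tuning entry. The import paths below are assumptions about the module layout, and "f16" is used because the TK path asserts f16 operands in gemm_utils.py.

# Hypothetical sanity check: every tk_default shape should have a TK tuning entry.
# Import paths are assumptions, not taken from this commit.
from gemm_utils import get_tk_tuned_config
from problems import tk_default

for cfg in tk_default("f16"):
    assert get_tk_tuned_config(cfg) is not None, (
        f"no TK tuning for {cfg.M}x{cfg.N}x{cfg.K}"
    )
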
