diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 348a5ca..7f5a865 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -116,7 +116,10 @@ def generate_mlir(config: GemmConfig): return mlir_template_B return mlir_template + def generate_tk_mlir(config: GemmConfig): + assert config.operand_element_type == 'f16', "Unsupported problem" + assert config.accumulator_element_type == 'f32', "Unsupported problem" # Input sizes M = tkl.sym.M N = tkl.sym.N