diff --git a/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp b/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp index 2fca7cf..16399e1 100644 --- a/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp +++ b/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp @@ -119,8 +119,11 @@ void ConfigureForSnitch::runOnOperation() { if (failed(setTranslationInfo(funcOp))) return signalPassFailure(); - if (failed(setRootConfig(funcOp, rootOperation))) - return signalPassFailure(); + auto loweringConfig = + getLoweringConfig(rootOperation); + if (!loweringConfig) + if (failed(setRootConfig(funcOp, rootOperation))) + return signalPassFailure(); // The root configuration setting introduces `tensor.dim` operations. // Resolve those away. diff --git a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp index 0e5e17b..5e57741 100644 --- a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp +++ b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp @@ -574,7 +574,14 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend { class QuidditchSession final : public PluginSession { +public: + static void registerGlobalDialects(DialectRegistry ®istry) { + // Required to allow the 'quidditch_snitch' dialect to also be used in + // input IR without just being parsed as an 'OpaqueAttr'. + registry.insert(); + } +private: void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) override { targets.add("quidditch_device", []() { return std::make_shared(); }); diff --git a/runtime/cmake/quidditch_module.cmake b/runtime/cmake/quidditch_module.cmake index 5819098..cce7252 100644 --- a/runtime/cmake/quidditch_module.cmake +++ b/runtime/cmake/quidditch_module.cmake @@ -52,7 +52,7 @@ find_program(XDSL_OPT_PATH xdsl-opt # The resulting library is the source file's name with the extension removed and # '_module' appended. function(quidditch_module) - cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC;N_THREADS;DST" "FLAGS;DEPENDS" ${ARGN}) + cmake_parse_arguments(_RULE "LLVM;ASSERT_XDSL" "SRC;DST" "FLAGS;DEPENDS" ${ARGN}) set(_MLIR_SRC "${_RULE_SRC}") if (NOT _RULE_DST) @@ -60,10 +60,6 @@ function(quidditch_module) set(_RULE_DST "${_RULE_DST}") endif () - if (NOT _RULE_N_THREADS) - set(_RULE_N_THREADS 8) - endif () - get_filename_component(_MLIR_SRC "${_MLIR_SRC}" REALPATH) set(_O_QUIDDITCH_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}.o") set(_O_LLVM_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_DST}/${_RULE_DST}_llvm.o") @@ -95,7 +91,9 @@ function(quidditch_module) list(APPEND _OUTPUT_FILES "${_O_QUIDDITCH_FILE_NAME}") list(APPEND _OBJECT_FILES "${_O_QUIDDITCH_FILE_NAME}") - list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}") + if (NOT _RULE_ASSERT_XDSL) + list(APPEND _OBJECT_FILES "${_O_LLVM_FILE_NAME}") + endif () string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_QUIDDITCH_FILE_NAME}") list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}") @@ -105,19 +103,23 @@ function(quidditch_module) endif () string(REPLACE ".o" ".h" _STATIC_HDR_PATH "${_O_LLVM_FILE_NAME}") - list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}" "${_O_LLVM_FILE_NAME}") - - list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=llvm-cpu") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-debug-symbols=true") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-triple=riscv32-unknown-elf") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu=generic-rv32") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu-features=+m,+f,+d,+zfh") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-abi=ilp32d") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-float-abi=hard") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-embedded=false") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-static") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=${_RULE_N_THREADS}") - list(APPEND _COMPILER_ARGS "--iree-llvmcpu-static-library-output-path=${_O_LLVM_FILE_NAME}") + list(APPEND _OUTPUT_FILES "${_STATIC_HDR_PATH}") + + if (NOT _RULE_ASSERT_XDSL) + list(APPEND _COMPILER_ARGS "--iree-hal-target-backends=llvm-cpu") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-debug-symbols=true") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-triple=riscv32-unknown-elf") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu=generic-rv32") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-cpu-features=+m,+f,+d,+zfh") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-abi=ilp32d") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-target-float-abi=hard") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-embedded=false") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-link-static") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-number-of-threads=8") + list(APPEND _COMPILER_ARGS "--iree-llvmcpu-static-library-output-path=${_O_LLVM_FILE_NAME}") + + list(APPEND _OUTPUT_FILES "${_O_LLVM_FILE_NAME}") + endif () list(APPEND _COMPILER_ARGS "--output-format=vm-c") list(APPEND _COMPILER_ARGS "--iree-vm-target-index-bits=32") diff --git a/runtime/samples/CMakeLists.txt b/runtime/samples/CMakeLists.txt index d64669d..0452780 100644 --- a/runtime/samples/CMakeLists.txt +++ b/runtime/samples/CMakeLists.txt @@ -1,5 +1,6 @@ include(quidditch_module) +add_subdirectory(big_matvec) add_subdirectory(nsnet2) add_subdirectory(util) add_subdirectory(vec_multiply) diff --git a/runtime/samples/big_matvec/CMakeLists.txt b/runtime/samples/big_matvec/CMakeLists.txt new file mode 100644 index 0000000..7749503 --- /dev/null +++ b/runtime/samples/big_matvec/CMakeLists.txt @@ -0,0 +1,12 @@ +quidditch_module(SRC big_matvec.mlir ASSERT_XDSL) + +add_executable(big_matvec_sample main.c) +target_link_libraries( + big_matvec_sample + PRIVATE + samples_util + big_matvec + snRuntime + Quidditch::dispatch::dispatch +) + diff --git a/runtime/samples/big_matvec/big_matvec.mlir b/runtime/samples/big_matvec/big_matvec.mlir new file mode 100644 index 0000000..6d0d04f --- /dev/null +++ b/runtime/samples/big_matvec/big_matvec.mlir @@ -0,0 +1,126 @@ +builtin.module @big_matvec { + func.func @test32(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> { + %init = tensor.empty() : tensor<1x320xf64> + %out = linalg.matmul_transpose_b { + lowering_config = #quidditch_snitch.lowering_config< + l1_tiles = [0, 32, 80], + l1_tiles_interchange = [2, 0, 1], + dual_buffer = true + > + } + ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>) + outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64> + %out2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], iterator_types = ["parallel", "parallel"] + } ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) { + ^bb0(%element : f64, %outs : f64): + %bias = arith.constant 5.0 : f64 + %added = arith.addf %element, %bias : f64 + linalg.yield %added : f64 + } -> tensor<1x320xf64> + func.return %out2 : tensor<1x320xf64> + } + + func.func @test40(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> { + %init = tensor.empty() : tensor<1x320xf64> + %out = linalg.matmul_transpose_b { + lowering_config = #quidditch_snitch.lowering_config< + l1_tiles = [0, 40, 80], + l1_tiles_interchange = [2, 0, 1], + dual_buffer = true + > + } + ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>) + outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64> + %out2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], iterator_types = ["parallel", "parallel"] + } ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) { + ^bb0(%element : f64, %outs : f64): + %bias = arith.constant 5.0 : f64 + %added = arith.addf %element, %bias : f64 + linalg.yield %added : f64 + } -> tensor<1x320xf64> + func.return %out2 : tensor<1x320xf64> + } + + func.func @test64(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> { + %init = tensor.empty() : tensor<1x320xf64> + %out = linalg.matmul_transpose_b { + lowering_config = #quidditch_snitch.lowering_config< + l1_tiles = [0, 64, 80], + l1_tiles_interchange = [2, 0, 1], + dual_buffer = true + > + } + ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>) + outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64> + %out2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], iterator_types = ["parallel", "parallel"] + } ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) { + ^bb0(%element : f64, %outs : f64): + %bias = arith.constant 5.0 : f64 + %added = arith.addf %element, %bias : f64 + linalg.yield %added : f64 + } -> tensor<1x320xf64> + func.return %out2 : tensor<1x320xf64> + } + + func.func @test32_100(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> { + %init = tensor.empty() : tensor<1x320xf64> + %out = linalg.matmul_transpose_b { + lowering_config = #quidditch_snitch.lowering_config< + l1_tiles = [0, 32, 100], + l1_tiles_interchange = [2, 0, 1], + dual_buffer = true + > + } + ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>) + outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64> + %out2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], iterator_types = ["parallel", "parallel"] + } ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) { + ^bb0(%element : f64, %outs : f64): + %bias = arith.constant 5.0 : f64 + %added = arith.addf %element, %bias : f64 + linalg.yield %added : f64 + } -> tensor<1x320xf64> + func.return %out2 : tensor<1x320xf64> + } + + func.func @test40_100(%arg0: tensor<1x400xf64>, %arg1: tensor<320x400xf64>) -> tensor<1x320xf64> { + %init = tensor.empty() : tensor<1x320xf64> + %out = linalg.matmul_transpose_b { + lowering_config = #quidditch_snitch.lowering_config< + l1_tiles = [0, 40, 100], + l1_tiles_interchange = [2, 0, 1], + dual_buffer = true + > + } + ins(%arg0, %arg1 : tensor<1x400xf64>, tensor<320x400xf64>) + outs(%init : tensor<1x320xf64>) -> tensor<1x320xf64> + %out2 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], iterator_types = ["parallel", "parallel"] + } ins(%out : tensor<1x320xf64>) outs(%out : tensor<1x320xf64>) { + ^bb0(%element : f64, %outs : f64): + %bias = arith.constant 5.0 : f64 + %added = arith.addf %element, %bias : f64 + linalg.yield %added : f64 + } -> tensor<1x320xf64> + func.return %out2 : tensor<1x320xf64> + } +} diff --git a/runtime/samples/big_matvec/main.c b/runtime/samples/big_matvec/main.c new file mode 100644 index 0000000..dc494a5 --- /dev/null +++ b/runtime/samples/big_matvec/main.c @@ -0,0 +1,179 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static iree_status_t setup_instance_and_device( + const model_config_t* config, iree_allocator_t host_allocator, + iree_vm_instance_t** out_instance, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(out_instance); + IREE_ASSERT_ARGUMENT(out_device); + + IREE_RETURN_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, + host_allocator, out_instance)); + + iree_status_t result = iree_hal_module_register_all_types(*out_instance); + if (!iree_status_is_ok(result)) goto error_release_vm; + + iree_hal_executable_loader_t* loader; + result = quidditch_loader_create(config->num_libraries, config->libraries, + iree_hal_executable_import_provider_null(), + host_allocator, &loader); + if (!iree_status_is_ok(result)) goto error_release_vm; + + iree_hal_allocator_t* device_allocator; + result = iree_hal_allocator_create_heap(iree_make_cstring_view("quidditch"), + /*data_allocator=*/host_allocator, + host_allocator, &device_allocator); + if (!iree_status_is_ok(result)) goto error_release_library_loader; + + quidditch_device_params_t params; + quidditch_device_params_initialize(¶ms); + result = + quidditch_device_create(IREE_SV("snitch"), ¶ms, + /*loader_count=*/1, &loader, device_allocator, + host_allocator, out_device); + iree_hal_executable_loader_release(loader); + iree_hal_allocator_release(device_allocator); + return result; + +error_release_library_loader: + iree_hal_executable_loader_release(loader); +error_release_vm: + iree_vm_instance_release(*out_instance); + return result; +} + +int main() { + if (!snrt_is_dm_core()) return quidditch_dispatch_enter_worker_loop(); + + model_config_t config = { + .libraries = + (iree_hal_executable_library_query_fn_t[]){ + quidditch_big_matvec_linked_quidditch_library_query, + }, + .num_libraries = 1, + .module_constructor = big_matvec_create, + + .element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_64, + + .num_inputs = 2, + .input_sizes = (const iree_host_size_t[]){1 * 320, 320 * 400}, + .input_ranks = (const iree_host_size_t[]){2, 2}, + .input_shapes = (const iree_hal_dim_t*[]){(iree_hal_dim_t[]){1, 400}, + (iree_hal_dim_t[]){320, 400}}, + + .num_outputs = 1, + .output_sizes = (const iree_host_size_t[]){1, 320}, + }; + + // Inlined 'run_model' to avoid constructing and destructing the device + // and host module multiple times. + + iree_allocator_t host_allocator = iree_allocator_system(); + + iree_vm_instance_t* vmInstance; + iree_hal_device_t* device; + IREE_CHECK_OK( + setup_instance_and_device(&config, host_allocator, &vmInstance, &device)); + + iree_vm_module_t* hal_module = NULL; + IREE_CHECK_OK(iree_hal_module_create(vmInstance, /*device_count=*/1, + /*devices=*/&device, + IREE_HAL_MODULE_FLAG_NONE, + host_allocator, &hal_module)); + + iree_vm_module_t* mlir_module = NULL; + IREE_CHECK_OK( + config.module_constructor(vmInstance, host_allocator, &mlir_module)); + + iree_vm_module_t* modules[] = {hal_module, mlir_module}; + + iree_vm_context_t* context; + IREE_CHECK_OK(iree_vm_context_create_with_modules( + vmInstance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), modules, + host_allocator, &context)); + + iree_vm_list_t* inputs = NULL; + IREE_CHECK_OK(iree_vm_list_create( + /*element_type=*/iree_vm_make_undefined_type_def(), + /*initial_capacity=*/config.num_inputs, host_allocator, &inputs)); + + for (iree_host_size_t i = 0; i < config.num_inputs; i++) { + iree_const_byte_span_t span = iree_make_const_byte_span( + config.input_data[i], + config.input_sizes[i] * + iree_hal_element_dense_byte_count(config.element_type)); + + iree_device_size_t out_size; + IREE_CHECK_OK(iree_hal_buffer_compute_view_size( + config.input_ranks[i], config.input_shapes[i], config.element_type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &out_size)); + + iree_hal_buffer_params_t params = { + .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE, + .access = IREE_HAL_MEMORY_ACCESS_NONE, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }; + iree_hal_buffer_params_canonicalize(¶ms); + + iree_hal_buffer_t* buffer = NULL; + IREE_CHECK_OK(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device), params, out_size, &buffer)); + + iree_hal_buffer_view_t* buffer_view; + IREE_CHECK_OK(iree_hal_buffer_view_create( + buffer, config.input_ranks[i], config.input_shapes[i], + config.element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + host_allocator, &buffer_view)); + + iree_vm_ref_t arg_buffer_view_ref; + arg_buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view); + IREE_CHECK_OK(iree_vm_list_push_ref_retain(inputs, &arg_buffer_view_ref)); + } + + iree_vm_list_t* outputs = NULL; + IREE_CHECK_OK(iree_vm_list_create( + /*element_type=*/iree_vm_make_undefined_type_def(), + /*initial_capacity=*/1, host_allocator, &outputs)); + + iree_string_view_t functions[] = { + iree_make_cstring_view("big_matvec.test32"), + iree_make_cstring_view("big_matvec.test40"), + iree_make_cstring_view("big_matvec.test64"), + iree_make_cstring_view("big_matvec.test32_100"), + iree_make_cstring_view("big_matvec.test40_100"), + }; + + for (int i = 0; i < IREE_ARRAYSIZE(functions); i++) { + iree_vm_function_t main_function; + IREE_CHECK_OK(iree_vm_context_resolve_function(context, functions[i], + &main_function)); + + IREE_CHECK_OK( + iree_vm_invoke(context, main_function, IREE_VM_CONTEXT_FLAG_NONE, + /*policy=*/NULL, inputs, outputs, host_allocator)); + } + + iree_vm_list_release(outputs); + iree_vm_list_release(inputs); + iree_vm_context_release(context); + iree_vm_module_release(mlir_module); + iree_vm_module_release(hal_module); + iree_hal_device_release(device); + iree_vm_instance_release(vmInstance); + + quidditch_dispatch_quit(); + + return 0; +} diff --git a/runtime/tests/CMakeLists.txt b/runtime/tests/CMakeLists.txt index 9d3895c..7704a11 100644 --- a/runtime/tests/CMakeLists.txt +++ b/runtime/tests/CMakeLists.txt @@ -82,5 +82,6 @@ endmacro() test_executable(HelloWorld PRECOMMIT) test_executable(vec_multiply PRECOMMIT) +test_executable(big_matvec NIGHTLY) test_executable(NsNet2 NIGHTLY) test_executable(NsNet2LLVM NIGHTLY)