diff --git a/build_tools/build_test_cpp.ps1 b/build_tools/build_test_cpp.ps1 index 43bcbe901..f14da8721 100644 --- a/build_tools/build_test_cpp.ps1 +++ b/build_tools/build_test_cpp.ps1 @@ -129,7 +129,7 @@ echo "-----" # better have git-bash installed... $env:Path = "C:\Program Files\Git\bin;$env:Path" pushd $build_dir -& bash -l -c "ctest -R amd-aie -E driver/xrt-lite --output-on-failure -j --repeat until-pass:5" +& bash -l -c "ctest -R amd-aie -E driver --output-on-failure -j --repeat until-pass:5" popd if ($llvm_install_dir -and (Test-Path "$llvm_install_dir")) diff --git a/build_tools/build_test_cpp.sh b/build_tools/build_test_cpp.sh index e4a0a661e..612f5999a 100644 --- a/build_tools/build_test_cpp.sh +++ b/build_tools/build_test_cpp.sh @@ -143,7 +143,7 @@ cmake --build "$build_dir" --target iree-install-dist echo "CTest" echo "-----" if [[ "$OSTYPE" == "linux"* ]]; then - ctest --test-dir "$build_dir" -R amd-aie -E "driver/xrt-lite" --output-on-failure -j + ctest --test-dir "$build_dir" -R amd-aie -E "driver" --output-on-failure -j elif [[ "$OSTYPE" == "darwin"* ]]; then ctest --test-dir "$build_dir" -R amd-aie -E "matmul_pack_peel_air_e2e|matmul_elementwise_pack_peel_air_e2e|conv_fill_spec_pad" --output-on-failure -j --repeat until-pass:5 fi diff --git a/iree_compiler_plugin.cmake b/iree_compiler_plugin.cmake index 958d6de46..3b50361c8 100644 --- a/iree_compiler_plugin.cmake +++ b/iree_compiler_plugin.cmake @@ -17,7 +17,13 @@ if("xrt" IN_LIST IREE_EXTERNAL_HAL_DRIVERS) set(IREE_AMD_AIE_ENABLE_XRT_DRIVER ON) endif() -if(IREE_AMD_AIE_ENABLE_XRT_DRIVER) +set(IREE_AMD_AIE_ENABLE_XRT_LITE_DRIVER OFF) +if("xrt-lite" IN_LIST IREE_EXTERNAL_HAL_DRIVERS) + message(STATUS "Enabling XRT-LITE build because it is an enabled HAL driver") + set(IREE_AMD_AIE_ENABLE_XRT_LITE_DRIVER ON) +endif() + +if(IREE_AMD_AIE_ENABLE_XRT_DRIVER OR IREE_AMD_AIE_ENABLE_XRT_LITE_DRIVER) include(iree_aie_xrt) endif() include(iree_aie_bootgen) diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/xrt-lite/CMakeLists.txt index 9fcdb521f..43c265ca2 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/CMakeLists.txt +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/CMakeLists.txt @@ -25,18 +25,21 @@ iree_cc_library( api.h buffer.cc buffer.h - command_buffer.cc - command_buffer.h + direct_command_buffer.cc + direct_command_buffer.h device.cc driver.cc executable.cc executable.h nop_executable_cache.cc nop_executable_cache.h + nop_semaphore.cc + nop_semaphore.h util.h DEPS iree::base iree::base::core_headers + iree::hal::utils::deferred_command_buffer iree::base::internal::flatcc::parsing iree-amd-aie::schemas::xrt_executable_def_c_fbs iree-amd-aie::driver::xrt-lite::shim::linux::kmq::shim-xdna diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.cc index 7f92c9811..a78ee9aa2 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.cc +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.cc @@ -71,8 +71,7 @@ iree_status_t iree_hal_xrt_lite_buffer::invalidate_range( IREE_STATUS_FAILED_PRECONDITION, "buffer does not have device memory attached and cannot be mapped"); } - this->bo->sync(shim_xdna::direction::device2host, local_byte_length, - local_byte_offset); + this->bo->sync(shim_xdna::direction::device2host); return iree_ok_status(); } @@ -86,8 +85,7 @@ iree_status_t iree_hal_xrt_lite_buffer::flush_range( IREE_STATUS_FAILED_PRECONDITION, "buffer does not have device memory attached and cannot be mapped"); } - this->bo->sync(shim_xdna::direction::host2device, local_byte_length, - local_byte_offset); + this->bo->sync(shim_xdna::direction::host2device); return iree_ok_status(); } @@ -146,11 +144,10 @@ static void iree_hal_xrt_lite_buffer_destroy(iree_hal_buffer_t* base_buffer) { IREE_TRACE_ZONE_END(z0); } -std::unique_ptr iree_hal_xrt_lite_buffer_unwrap( - iree_hal_buffer_t* base_buffer) { +shim_xdna::bo* iree_hal_xrt_lite_buffer_handle(iree_hal_buffer_t* base_buffer) { iree_hal_xrt_lite_buffer* buffer = reinterpret_cast(base_buffer); - return std::move(buffer->bo); + return buffer->bo.get(); } #define BUFFER_MEMBER_STATUS(member) \ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.h b/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.h index c89f14164..c6f34b7b9 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.h +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/buffer.h @@ -41,7 +41,6 @@ iree_status_t iree_hal_xrt_lite_buffer_wrap( iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer); -std::unique_ptr iree_hal_xrt_lite_buffer_unwrap( - iree_hal_buffer_t* base_buffer); +shim_xdna::bo* iree_hal_xrt_lite_buffer_handle(iree_hal_buffer_t* base_buffer); #endif // IREE_HAL_DRIVERS_XRT_LITE_BUFFER_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.cc deleted file mode 100644 index 52e124af2..000000000 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.cc +++ /dev/null @@ -1,341 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree-amd-aie/driver/xrt-lite/command_buffer.h" - -#include "buffer.h" -#include "iree-amd-aie/driver/xrt-lite/util.h" -#include "shim/linux/kmq/bo.h" - -#define MAX_EXEC_BO_SIZE (4096) - -namespace { -extern const iree_hal_command_buffer_vtable_t - iree_hal_xrt_lite_command_buffer_vtable; -} - -struct iree_hal_xrt_lite_command_buffer { - iree_hal_command_buffer_t base; - iree_allocator_t host_allocator; - iree_hal_buffer_t* exec_buffer; - - iree_status_t begin() { - // TODO(null): if the implementation needs to route the begin to the - // implementation it can be done here. Note that creation may happen much - // earlier than recording and any expensive work should be deferred until - // this point to make profiling easier. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "command buffer recording start not implemented"); - return status; - } - - iree_status_t end() { - // TODO(null): if recording requires multiple passes any fixup/linking can - // happen here. Recording-only resources are no longer needed after this - // point and can be disposed. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "command buffer finalization not implemented"); - return status; - } - - void begin_debug_group(iree_string_view_t label, - iree_hal_label_color_t label_color, - const iree_hal_label_location_t* location) { - // TODO(null): begin a nested debug group (push) if the implementation has a - // way to insert markers. This is informational and can be ignored. - (void)this; - } - - void end_debug_group() { - // TODO(null): end a nested debug group (pop). Always called 1:1 in stack - // order with begin_debug_group. - (void)this; - } - - iree_status_t execution_barrier( - iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - iree_hal_execution_barrier_flags_t flags, - iree_host_size_t memory_barrier_count, - const iree_hal_memory_barrier_t* memory_barriers, - iree_host_size_t buffer_barrier_count, - const iree_hal_buffer_barrier_t* buffer_barriers) { - // TODO(null): barriers split the execution sequence into all operations - // that did happen before the barrier and all that will happen after. In - // implementations that have no concurrency this can be a no-op. This is - // effectively just a signal_event followed by a wait_event. - (void)this; - iree_status_t status = iree_make_status( - IREE_STATUS_UNIMPLEMENTED, "execution barriers not implemented"); - return status; - } - - iree_status_t signal_event(iree_hal_event_t* event, - iree_hal_execution_stage_t source_stage_mask) { - // TODO(null): WIP API and may change; signals the given event allowing - // waiters to proceed. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented"); - return status; - } - - iree_status_t reset_event(iree_hal_event_t* event, - iree_hal_execution_stage_t source_stage_mask) { - // TODO(null): WIP API and may change; resets the given event to unsignaled. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented"); - return status; - } - - iree_status_t wait_events(iree_host_size_t event_count, - const iree_hal_event_t** events, - iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - iree_host_size_t memory_barrier_count, - const iree_hal_memory_barrier_t* memory_barriers, - iree_host_size_t buffer_barrier_count, - const iree_hal_buffer_barrier_t* buffer_barriers) { - // TODO(null): WIP API and may change; waits on the list of events and - // enacts the specified set of barriers. Implementations without - // fine-grained tracking can treat this as an execution_barrier and ignore - // the memory/buffer barriers provided. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "events not implemented"); - return status; - } - - iree_status_t discard_buffer(iree_hal_buffer_ref_t buffer_ref) { - // TODO(null): WIP API and may change; this is likely to become an - // madvise-like command that can be used to control prefetching and other - // cache behavior. The current discard behavior is a hint that the buffer - // contents will never be used again and that if they are in a cache they - // need not be written back to global memory. - (void)this; - iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "discard buffer not implemented"); - return status; - } - - iree_status_t fill_buffer(iree_hal_buffer_ref_t target_ref, - const void* pattern, - iree_host_size_t pattern_length) { - // TODO(null): memset on the buffer. The pattern_length is 1, 2, or 4 bytes. - // Note that the buffer may be a reference to a binding table slot in which - // case it will be provided during submission to a queue. - (void)this; - iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "fill buffer not implemented"); - return status; - } - - iree_status_t update_buffer(const void* source_buffer, - iree_host_size_t source_offset, - iree_hal_buffer_ref_t target_ref) { - // TODO(null): embed and copy a small (~64KB) chunk of host memory to the - // target buffer. The source_buffer contents must be captured as they may - // change/be freed after this call completes. - // Note that the target buffer may be a reference to a binding table slot in - // which case it will be provided during submission to a queue. - (void)this; - iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "update buffer not implemented"); - - return status; - } - - iree_status_t copy_buffer(iree_hal_buffer_ref_t source_ref, - iree_hal_buffer_ref_t target_ref) { - // TODO(null): memcpy between two buffers. The buffers must both be - // device-visible but may reside on either the host or device. - // Note that either buffer may be a reference to a binding table slot in - // which case it will be provided during submission to a queue. - (void)this; - iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "copy buffer not implemented"); - - return status; - } - - iree_status_t collective(iree_hal_channel_t* channel, - iree_hal_collective_op_t op, uint32_t param, - iree_hal_buffer_ref_t send_ref, - iree_hal_buffer_ref_t recv_ref, - iree_device_size_t element_count) { - // TODO(null): perform the collective operation defined by op. See the - // headers for more information. The channel is fixed for a particular - // recording but note that either buffer may be a reference to a binding - // table slot in which case it will be provided during submission to a - // queue. - (void)this; - iree_status_t status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "collectives not implemented"); - - return status; - } - - iree_status_t dispatch(iree_hal_executable_t* executable, int32_t entry_point, - const uint32_t workgroup_count[3], - iree_const_byte_span_t constants, - iree_hal_buffer_ref_list_t bindings, - iree_hal_dispatch_flags_t flags) { - // TODO(null): dispatch the specified executable entry point with the given - // workgroup count. The constants must be copied into the command buffer as - // they may be mutated or freed after this call returns. - // Note that any of the bindings may be references to binding table slots in - // which case they will be provided during submission to a queue. - (void)this; - iree_status_t status = - iree_make_status(IREE_STATUS_UNIMPLEMENTED, "dispatch not implemented"); - - return status; - } - - iree_status_t dispatch_indirect(iree_hal_executable_t* executable, - int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref, - iree_const_byte_span_t constants, - iree_hal_buffer_ref_list_t bindings, - iree_hal_dispatch_flags_t flags) { - // TODO(null): dispatch the specified executable entry point with a - // workgroup count that is stored in the given workgroup count buffer as a - // uint32_t[3]. The workgroup count may change up until immediately prior to - // the dispatch. The constants must be copied into the command buffer as - // they may be mutated or freed after this call returns. Note that any of - // the bindings may be references to binding table slots in which case they - // will be provided during submission to a queue. - (void)this; - iree_status_t status = iree_make_status( - IREE_STATUS_UNIMPLEMENTED, "indirect dispatch not implemented"); - - return status; - } -}; - -static iree_hal_xrt_lite_command_buffer* iree_hal_xrt_lite_command_buffer_cast( - iree_hal_command_buffer_t* base_value) { - IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_lite_command_buffer_vtable); - return (iree_hal_xrt_lite_command_buffer*)base_value; -} - -iree_status_t iree_hal_xrt_lite_command_buffer_create( - iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, - iree_allocator_t host_allocator, - iree_hal_command_buffer_t** out_command_buffer) { - IREE_ASSERT_ARGUMENT(out_command_buffer); - *out_command_buffer = nullptr; - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_xrt_lite_command_buffer* command_buffer = nullptr; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_allocator_malloc(host_allocator, - sizeof(*command_buffer) + - iree_hal_command_buffer_validation_state_size( - mode, binding_capacity), - (void**)&command_buffer)); - iree_hal_command_buffer_initialize( - device_allocator, mode, command_categories, queue_affinity, - binding_capacity, (uint8_t*)command_buffer + sizeof(*command_buffer), - &iree_hal_xrt_lite_command_buffer_vtable, &command_buffer->base); - command_buffer->host_allocator = host_allocator; - - // TODO(null): allocate any additional resources for managing command buffer - // state. Some implementations may have their own command buffer/command list - // APIs this can route to or may need to implement it all themselves using - // iree_arena_t/block pools. Implementations should also retain any resources - // used during the recording and can use iree_hal_resource_set_t* to make that - // easier. - iree_hal_buffer_params_t params; - params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; - iree_status_t status = iree_hal_allocator_allocate_buffer( - device_allocator, params, MAX_EXEC_BO_SIZE, &command_buffer->exec_buffer); - - if (iree_status_is_ok(status)) { - *out_command_buffer = &command_buffer->base; - } else { - iree_hal_command_buffer_release(&command_buffer->base); - } - IREE_TRACE_ZONE_END(z0); - return status; -} - -static void iree_hal_xrt_lite_command_buffer_destroy( - iree_hal_command_buffer_t* base_command_buffer) { - iree_hal_xrt_lite_command_buffer* command_buffer = - iree_hal_xrt_lite_command_buffer_cast(base_command_buffer); - iree_allocator_t host_allocator = command_buffer->host_allocator; - IREE_TRACE_ZONE_BEGIN(z0); - - // TODO(null): release any implementation resources and - // iree_hal_resource_set_t. - iree_hal_buffer_destroy(command_buffer->exec_buffer); - iree_allocator_free(host_allocator, command_buffer); - - IREE_TRACE_ZONE_END(z0); -} - -bool iree_hal_xrt_lite_command_buffer_isa( - iree_hal_command_buffer_t* command_buffer) { - return iree_hal_resource_is(&command_buffer->resource, - &iree_hal_xrt_lite_command_buffer_vtable); -} - -#define COMMAND_BUFFER_MEMBER(member, return_t) \ - MEMBER_WRAPPER(iree_hal_command_buffer_t, iree_hal_xrt_lite_command_buffer, \ - member, return_t) -#define COMMAND_BUFFER_MEMBER_STATUS(member) \ - MEMBER_WRAPPER_STATUS(iree_hal_command_buffer_t, \ - iree_hal_xrt_lite_command_buffer, member) -#define COMMAND_BUFFER_MEMBER_VOID(member) \ - MEMBER_WRAPPER_VOID(iree_hal_command_buffer_t, \ - iree_hal_xrt_lite_command_buffer, member) - -COMMAND_BUFFER_MEMBER_STATUS(begin); -COMMAND_BUFFER_MEMBER_STATUS(end); -COMMAND_BUFFER_MEMBER_VOID(begin_debug_group); -COMMAND_BUFFER_MEMBER_VOID(end_debug_group); -COMMAND_BUFFER_MEMBER_STATUS(execution_barrier); -COMMAND_BUFFER_MEMBER_STATUS(signal_event); -COMMAND_BUFFER_MEMBER_STATUS(reset_event); -COMMAND_BUFFER_MEMBER_STATUS(wait_events); -COMMAND_BUFFER_MEMBER_STATUS(discard_buffer); -COMMAND_BUFFER_MEMBER_STATUS(fill_buffer); -COMMAND_BUFFER_MEMBER_STATUS(update_buffer); -COMMAND_BUFFER_MEMBER_STATUS(copy_buffer); -COMMAND_BUFFER_MEMBER_STATUS(collective); -COMMAND_BUFFER_MEMBER_STATUS(dispatch); -COMMAND_BUFFER_MEMBER_STATUS(dispatch_indirect); - -namespace { -const iree_hal_command_buffer_vtable_t iree_hal_xrt_lite_command_buffer_vtable = - { - .destroy = iree_hal_xrt_lite_command_buffer_destroy, - .begin = iree_hal_xrt_lite_command_buffer_begin, - .end = iree_hal_xrt_lite_command_buffer_end, - .begin_debug_group = iree_hal_xrt_lite_command_buffer_begin_debug_group, - .end_debug_group = iree_hal_xrt_lite_command_buffer_end_debug_group, - .execution_barrier = iree_hal_xrt_lite_command_buffer_execution_barrier, - .signal_event = iree_hal_xrt_lite_command_buffer_signal_event, - .reset_event = iree_hal_xrt_lite_command_buffer_reset_event, - .wait_events = iree_hal_xrt_lite_command_buffer_wait_events, - .discard_buffer = iree_hal_xrt_lite_command_buffer_discard_buffer, - .fill_buffer = iree_hal_xrt_lite_command_buffer_fill_buffer, - .update_buffer = iree_hal_xrt_lite_command_buffer_update_buffer, - .copy_buffer = iree_hal_xrt_lite_command_buffer_copy_buffer, - .collective = iree_hal_xrt_lite_command_buffer_collective, - .dispatch = iree_hal_xrt_lite_command_buffer_dispatch, - .dispatch_indirect = iree_hal_xrt_lite_command_buffer_dispatch_indirect, -}; -} diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.h b/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.h deleted file mode 100644 index 7283582bf..000000000 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/command_buffer.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_HAL_DRIVERS_XRT_LITE_COMMAND_BUFFER_H_ -#define IREE_HAL_DRIVERS_XRT_LITE_COMMAND_BUFFER_H_ - -#include "iree/base/api.h" -#include "iree/hal/api.h" - -// Creates {Null} command buffer. -iree_status_t iree_hal_xrt_lite_command_buffer_create( - iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, - iree_allocator_t host_allocator, - iree_hal_command_buffer_t** out_command_buffer); - -// Returns true if |command_buffer| is a {Null} command buffer. -bool iree_hal_xrt_lite_command_buffer_isa( - iree_hal_command_buffer_t* command_buffer); - -#endif // IREE_HAL_DRIVERS_XRT_LITE_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/CMakeLists.txt index 8ed1891b0..a8125ec00 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/CMakeLists.txt +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/CMakeLists.txt @@ -52,6 +52,8 @@ iree_bytecode_module( --iree-amd-aie-show-invoked-commands --iree-hal-memoization=false --iree-hal-indirect-command-buffers=false + DEPS + iree-aie-xclbinutil PUBLIC TESTONLY ) @@ -76,6 +78,51 @@ iree_c_embed_data( TESTONLY ) +#iree_bytecode_module( +# NAME +# xrt_lite_command_buffer_dispatch_test_module +# MODULE_FILE_NAME +# xrt_lite_command_buffer_dispatch_test.bin +# SRC +# "${CMAKE_CURRENT_LIST_DIR}/command_buffer_dispatch_test.mlir" +# FLAGS +# --compile-mode=hal-executable +# --iree-hal-dump-executable-files-to=${CMAKE_CURRENT_BINARY_DIR} +# --iree-hal-target-backends=amd-aie +# --iree-amdaie-lower-to-aie-pipeline=air +# --iree-amdaie-target-device=${TARGET_DEVICE} +# --iree-amd-aie-peano-install-dir=${PEANO_INSTALL_DIR} +# --iree-amd-aie-vitis-install-dir=${VITIS_DIR} +# --iree-amd-aie-enable-chess=$ +# --iree-amd-aie-show-invoked-commands +# --iree-hal-memoization=false +# --iree-hal-indirect-command-buffers=false +# DEPS +# iree-aie-xclbinutil +# PUBLIC +# TESTONLY +#) +# +#iree_c_embed_data( +# NAME +# xrt_lite_command_buffer_dispatch_c +# SRCS +# xrt_lite_command_buffer_dispatch_test.bin +# C_FILE_OUTPUT +# xrt_lite_command_buffer_dispatch_c.c +# H_FILE_OUTPUT +# xrt_lite_command_buffer_dispatch_c.h +# IDENTIFIER +# iree_cts_testdata_command_buffer_dispatch_aie_xrt_lite +# STRIP_PREFIX +# xrt_lite_ +# DEPENDS +# ::xrt_lite_command_buffer_dispatch_test_module +# FLATTEN +# PUBLIC +# TESTONLY +#) + iree_cc_test( NAME xrt_lite_executable_cache_test @@ -92,9 +139,9 @@ iree_cc_test( iree_cc_test( NAME - xrt_lite_command_buffer_dispatch_test + xrt_lite_dispatch_test SRCS - command_buffer_dispatch_test.cc + matmul_dispatch_test.cc DEPS ::xrt_lite_executables_c iree-amd-aie::driver::xrt-lite::registration @@ -106,4 +153,4 @@ iree_cc_test( ) target_include_directories(iree-amd-aie_driver_xrt-lite_cts_xrt_lite_executable_cache_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") -target_include_directories(iree-amd-aie_driver_xrt-lite_cts_xrt_lite_command_buffer_dispatch_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") +target_include_directories(iree-amd-aie_driver_xrt-lite_cts_xrt_lite_dispatch_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/executable_cache_test.mlir b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/executable_cache_test.mlir index 4a27d79e0..dedbcab6b 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/executable_cache_test.mlir +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/executable_cache_test.mlir @@ -9,24 +9,24 @@ flags = Indirect > hal.executable.source public @amdaie_fb { - hal.executable.export public @matmul_f32_dispatch_0_matmul_256x256x32_f32 ordinal(0) layout(#pipeline_layout) { + hal.executable.export public @matmul_f32_dispatch_0_matmul_32x32x32_f32 ordinal(0) layout(#pipeline_layout) { ^bb0(%arg0: !hal.device): %x, %y, %z = flow.dispatch.workgroup_count_from_slice hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_f32_dispatch_0_matmul_256x256x32_f32() { + func.func @matmul_f32_dispatch_0_matmul_32x32x32_f32() { %c0_f32 = arith.constant 0.0 : f32 %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf32> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf32> - %5 = tensor.empty() : tensor<256x256xf32> - %6 = linalg.fill ins(%c0_f32 : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> - flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x32xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x32xf32> + %5 = tensor.empty() : tensor<32x32xf32> + %6 = linalg.fill ins(%c0_f32 : f32) outs(%5 : tensor<32x32xf32>) -> tensor<32x32xf32> + %7 = linalg.matmul ins(%3, %4 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%6 : tensor<32x32xf32>) -> tensor<32x32xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : tensor<32x32xf32> -> !flow.dispatch.tensor> return } } diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/command_buffer_dispatch_test.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/matmul_dispatch_test.cc similarity index 94% rename from runtime/src/iree-amd-aie/driver/xrt-lite/cts/command_buffer_dispatch_test.cc rename to runtime/src/iree-amd-aie/driver/xrt-lite/cts/matmul_dispatch_test.cc index 00053e145..f00bfbddc 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/cts/command_buffer_dispatch_test.cc +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/cts/matmul_dispatch_test.cc @@ -23,7 +23,7 @@ iree_status_t register_test_driver(iree_hal_driver_registry_t* registry) { return iree_hal_xrt_lite_driver_module_register(registry); } -const char* get_test_executable_format() { return "amdaie-pdi-fb"; } +const char* get_test_executable_format() { return "amdaie-xclbin-fb"; } iree_const_byte_span_t get_test_executable_data(iree_string_view_t file_name) { const struct iree_file_toc_t* toc = @@ -32,7 +32,7 @@ iree_const_byte_span_t get_test_executable_data(iree_string_view_t file_name) { return iree_make_const_byte_span(file.data, file.size); } -class CommandBufferDispatchTest +class MatMulDispatchTest : public CTSTestBase<::testing::TestWithParam> { protected: void PrepareMatmulExecutable() { @@ -75,7 +75,7 @@ int32_t generate_random_number(iree_hal_element_type_t element_type, min; } -TEST_F(CommandBufferDispatchTest, Create) { +TEST_F(MatMulDispatchTest, Create) { iree_hal_command_buffer_t* command_buffer = nullptr; IREE_ASSERT_OK(iree_hal_command_buffer_create( device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, @@ -89,7 +89,7 @@ TEST_F(CommandBufferDispatchTest, Create) { iree_hal_command_buffer_release(command_buffer); } -TEST_F(CommandBufferDispatchTest, BeginEnd) { +TEST_F(MatMulDispatchTest, BeginEnd) { iree_hal_command_buffer_t* command_buffer = nullptr; IREE_ASSERT_OK(iree_hal_command_buffer_create( device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, @@ -102,7 +102,7 @@ TEST_F(CommandBufferDispatchTest, BeginEnd) { iree_hal_command_buffer_release(command_buffer); } -TEST_F(CommandBufferDispatchTest, SubmitEmpty) { +TEST_F(MatMulDispatchTest, SubmitEmpty) { iree_hal_command_buffer_t* command_buffer = nullptr; IREE_ASSERT_OK(iree_hal_command_buffer_create( device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, @@ -117,11 +117,11 @@ TEST_F(CommandBufferDispatchTest, SubmitEmpty) { iree_hal_command_buffer_release(command_buffer); } -TEST_P(CommandBufferDispatchTest, DispatchMatmul) { +TEST_P(MatMulDispatchTest, DispatchMatmul) { PrepareMatmulExecutable(); // Create input buffer. - constexpr iree_device_size_t WIDTH = 256; + constexpr iree_device_size_t WIDTH = 32; constexpr iree_device_size_t M = WIDTH, K = WIDTH, N = WIDTH; iree_hal_buffer_t *input_A = nullptr, *input_B = nullptr, *output_C = nullptr; int32_t seed = @@ -133,7 +133,7 @@ TEST_P(CommandBufferDispatchTest, DispatchMatmul) { iree_hal_element_types_t::IREE_HAL_ELEMENT_TYPE_FLOAT_32, seed + 1); CreateFilledDeviceBuffer(M * K * sizeof(float), a, &input_A); CreateFilledDeviceBuffer(K * N * sizeof(float), b, &input_B); - CreateFilledDeviceBuffer(M * N * sizeof(float), 0, &output_C); + CreateFilledDeviceBuffer(M * N * sizeof(float), -1, &output_C); iree_hal_buffer_ref_t binding_refs[3]; iree_hal_buffer_binding_table_t binding_table = @@ -217,7 +217,7 @@ TEST_P(CommandBufferDispatchTest, DispatchMatmul) { CleanupExecutable(); } -INSTANTIATE_TEST_SUITE_P(CommandBufferDispatchTest, CommandBufferDispatchTest, +INSTANTIATE_TEST_SUITE_P(MatMulDispatchTest, MatMulDispatchTest, ::testing::Values(RecordingType::kDirect), GenerateTestName()); diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/device.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/device.cc index ab3dbca70..b5035c79f 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/device.cc +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/device.cc @@ -4,14 +4,18 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree-amd-aie/driver/xrt-lite/device.h" +#include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.h" #include "iree-amd-aie/driver/xrt-lite/allocator.h" #include "iree-amd-aie/driver/xrt-lite/api.h" -#include "iree-amd-aie/driver/xrt-lite/command_buffer.h" -#include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.h" +#include "iree-amd-aie/driver/xrt-lite/direct_command_buffer.h" +#include "iree-amd-aie/driver/xrt-lite/nop_executable_cache.h" +#include "iree-amd-aie/driver/xrt-lite/nop_semaphore.h" #include "iree-amd-aie/driver/xrt-lite/util.h" -#include "nop_executable_cache.h" +#include "iree/hal/utils/deferred_command_buffer.h" +#include "iree/hal/utils/deferred_work_queue.h" + +#define ARENA_BLOCK_SIZE (32 * 1024) struct iree_hal_xrt_lite_device { iree_hal_resource_t resource; @@ -19,6 +23,9 @@ struct iree_hal_xrt_lite_device { iree_allocator_t host_allocator; // not used iree_hal_allocator_t* device_allocator; + // Block pool used for command buffers with a larger block size (as command + // buffers can contain inlined data uploads). + iree_arena_block_pool_t block_pool; std::shared_ptr shim_device; iree_status_t create_executable_cache( @@ -36,9 +43,50 @@ struct iree_hal_xrt_lite_device { iree_hal_command_buffer_t** out_command_buffer) { // TODO(null): pass any additional resources required to create the command // buffer. The implementation could pool command buffers here. - return iree_hal_xrt_lite_command_buffer_create( - device_allocator, mode, command_categories, queue_affinity, - binding_capacity, host_allocator, out_command_buffer); + if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "unimplmented multi-shot command buffer"); + } + return iree_hal_deferred_command_buffer_create( + device_allocator, mode, command_categories, binding_capacity, + &block_pool, host_allocator, out_command_buffer); + } + + iree_status_t create_semaphore(uint64_t initial_value, + iree_hal_semaphore_flags_t flags, + iree_hal_semaphore_t** out_semaphore) { + return iree_hal_xrt_lite_semaphore_create(host_allocator, initial_value, + out_semaphore); + } + + iree_status_t queue_execute( + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t* const* command_buffers, + iree_hal_buffer_binding_table_t const* binding_tables) { + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < command_buffer_count; i++) { + iree_hal_command_buffer_t* xrt_command_buffer = nullptr; + iree_hal_command_buffer_mode_t mode = + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | + IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_xrt_lite_direct_command_buffer_create( + shim_device, device_allocator, mode, + IREE_HAL_COMMAND_CATEGORY_ANY, + /*binding_capacity=*/0, &block_pool, host_allocator, + &xrt_command_buffer)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_deferred_command_buffer_apply( + command_buffers[i], xrt_command_buffer, + iree_hal_buffer_binding_table_empty())); + } + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); } }; @@ -82,6 +130,8 @@ iree_status_t iree_hal_xrt_lite_device_create( // from the same driver. iree_status_t status = iree_hal_xrt_lite_allocator_create( host_allocator, device->shim_device, &device->device_allocator); + iree_arena_block_pool_initialize(ARENA_BLOCK_SIZE, host_allocator, + &device->block_pool); // TODO(max): device id *out_device = reinterpret_cast(device); if (iree_status_is_ok(status)) { @@ -143,6 +193,8 @@ static iree_hal_allocator_t* iree_hal_xrt_lite_device_device_allocator( DEVICE_MEMBER_STATUS(create_executable_cache); DEVICE_MEMBER_STATUS(create_command_buffer); +DEVICE_MEMBER_STATUS(create_semaphore); +DEVICE_MEMBER_STATUS(queue_execute); namespace { const iree_hal_device_vtable_t iree_hal_xrt_lite_device_vtable = { @@ -150,6 +202,9 @@ const iree_hal_device_vtable_t iree_hal_xrt_lite_device_vtable = { .id = iree_hal_xrt_lite_device_id, .host_allocator = iree_hal_xrt_lite_device_host_allocator, .device_allocator = iree_hal_xrt_lite_device_device_allocator, + .create_command_buffer = iree_hal_xrt_lite_device_create_command_buffer, .create_executable_cache = iree_hal_xrt_lite_device_create_executable_cache, - .create_command_buffer = iree_hal_xrt_lite_device_create_command_buffer}; + .create_semaphore = iree_hal_xrt_lite_device_create_semaphore, + .queue_execute = iree_hal_xrt_lite_device_queue_execute, +}; } diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/device.h b/runtime/src/iree-amd-aie/driver/xrt-lite/device.h deleted file mode 100644 index c8d2a6e1f..000000000 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/device.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_AMD_AIE_DRIVER_XRT_LITE_DEVICE_H_ -#define IREE_AMD_AIE_DRIVER_XRT_LITE_DEVICE_H_ - -#include "iree/base/api.h" -#include "iree/hal/api.h" - -// NOTE: nothing in the skeleton implementation. Device creation and adoption is -// part of the public API header. This header can contain internal types and -// functions. - -#endif // IREE_AMD_AIE_DRIVER_XRT_LITE_DEVICE_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.cc new file mode 100644 index 000000000..7c856b88e --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.cc @@ -0,0 +1,371 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/xrt-lite/direct_command_buffer.h" + +#include "iree-amd-aie/driver/xrt-lite/buffer.h" +#include "iree-amd-aie/driver/xrt-lite/executable.h" +#include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwq.h" +#include "iree/hal/utils/resource_set.h" + +// The max number of bindings per descriptor set allowed in the XRT HAL +// implementation. +#define IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_BINDING_COUNT 16 + +// The max number of descriptor sets allowed in the XRT HAL implementation. +// This depends on the general descriptor set planning in IREE and should adjust +// with it. +#define IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_COUNT 4 + +struct iree_hal_xrt_lite_direct_command_buffer { + iree_hal_command_buffer_t base; + iree_allocator_t host_allocator; + // A resource set to maintain references to all resources used within the + // command buffer. Reset on each begin. + iree_hal_resource_set_t* resource_set; + // Staging arena used for host->device transfers. + iree_arena_allocator_t arena; + + std::shared_ptr shim_device; + + struct { + shim_xdna::bo* bindings[IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_BINDING_COUNT]; + // Offset and length are used to get the sub buffer at kernel launch. + iree_device_size_t + offsets[IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_BINDING_COUNT]; + iree_device_size_t + lengths[IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_BINDING_COUNT]; + + } descriptor_sets[IREE_HAL_XRT_LITE_MAX_DESCRIPTOR_SET_COUNT]; +}; + +namespace { +extern const iree_hal_command_buffer_vtable_t + iree_hal_xrt_lite_direct_command_buffer_vtable; +} // namespace + +static iree_hal_xrt_lite_direct_command_buffer* +iree_hal_xrt_lite_direct_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_xrt_lite_direct_command_buffer_vtable); + return (iree_hal_xrt_lite_direct_command_buffer*)base_value; +} + +iree_status_t iree_hal_xrt_lite_direct_command_buffer_create( + std::shared_ptr shim_device, + iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, iree_arena_block_pool_t* block_pool, + iree_allocator_t host_allocator, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device_allocator); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = nullptr; + if (binding_capacity > 0) { + // TODO(#10144): support indirect command buffers with binding tables. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_xrt_lite_direct_command_buffer* command_buffer = nullptr; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, + sizeof(*command_buffer) + + iree_hal_command_buffer_validation_state_size( + mode, binding_capacity), + (void**)&command_buffer)); + iree_hal_command_buffer_initialize( + device_allocator, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, (uint8_t*)command_buffer + sizeof(*command_buffer), + &iree_hal_xrt_lite_direct_command_buffer_vtable, &command_buffer->base); + command_buffer->host_allocator = host_allocator; + command_buffer->shim_device = shim_device; + iree_arena_initialize(block_pool, &command_buffer->arena); + iree_status_t status = + iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set); + if (iree_status_is_ok(status)) { + *out_command_buffer = &command_buffer->base; + } else { + iree_hal_command_buffer_release(&command_buffer->base); + } + + IREE_TRACE_ZONE_END(z0); + + return status; +} +static void iree_hal_xrt_lite_direct_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_xrt_lite_direct_command_buffer* command_buffer = + iree_hal_xrt_lite_direct_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = command_buffer->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + command_buffer->shim_device.reset(); + iree_hal_resource_set_free(command_buffer->resource_set); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + // Nothing to do. + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_xrt_lite_direct_command_buffer* command_buffer = + iree_hal_xrt_lite_direct_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_arena_reset(&command_buffer->arena); + iree_hal_resource_set_free(command_buffer->resource_set); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_allocate(command_buffer->arena.block_pool, + &command_buffer->resource_set)); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_xrt_lite_direct_command_buffer_begin_debug_group( + iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, + iree_hal_label_color_t label_color, + const iree_hal_label_location_t* location) { + (void)iree_status_from_code(IREE_STATUS_UNIMPLEMENTED); +} + +static void iree_hal_xrt_lite_direct_command_buffer_end_debug_group( + iree_hal_command_buffer_t* base_command_buffer) { + (void)iree_status_from_code(IREE_STATUS_UNIMPLEMENTED); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || + iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "barrier involving host not yet supported"); + } + + if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-zero barrier flag not yet supported"); + } + + // Nothing to do in current synchronous mode. + + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t buffer) { + // It is okay to do nothing here. + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t target_ref, const void* pattern, + iree_host_size_t pattern_length) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "fill buffer not yet supported"); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref) { + IREE_TRACE_ZONE_BEGIN(z0); + const uint8_t* src = (const uint8_t*)source_buffer + source_offset; + + // No need to Allocate scratch space (in an arena) as the memcpy + // used below is expected to be synchronized. + shim_xdna::bo* target_device_buffer = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(target_ref.buffer)); + void* target_device_buffer_ptr = target_device_buffer->map(); + uint8_t* dst = (uint8_t*)target_device_buffer_ptr + + iree_hal_buffer_byte_offset(target_ref.buffer) + + target_ref.offset; + memcpy(dst, src, target_ref.length); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) { + IREE_TRACE_ZONE_BEGIN(z0); + + shim_xdna::bo* target_device_buffer = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(target_ref.buffer)); + void* target_device_buffer_ptr = target_device_buffer->map(); + iree_device_size_t target_offset = + iree_hal_buffer_byte_offset(target_ref.buffer) + target_ref.offset; + + shim_xdna::bo* source_device_buffer = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(source_ref.buffer)); + void* source_device_buffer_ptr = source_device_buffer->map(); + iree_device_size_t source_offset = + iree_hal_buffer_byte_offset(source_ref.buffer) + source_ref.offset; + + uint8_t* dst = (uint8_t*)target_device_buffer_ptr + target_offset; + uint8_t* src = (uint8_t*)source_device_buffer_ptr + source_offset; + memcpy(dst, src, target_ref.length); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_collective( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, + iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref, + iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "collectives not yet supported"); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_xrt_lite_direct_command_buffer* command_buffer = + reinterpret_cast( + base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_xrt_lite_kernel_params_t kernel_params; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_xrt_lite_native_executable_entry_point_kernel_params( + executable, entry_point, &kernel_params)); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &executable)); + + xrt::xclbin xclbin = xrt::xclbin(kernel_params.xclbinVector); + kernel_params.context = + command_buffer->shim_device->create_hw_context(xclbin); + uint32_t num_instr = flatbuffers_uint32_vec_len(kernel_params.asm_inst); + size_t ctrl_code_size = num_instr * sizeof(uint32_t); + auto bo_ctrl_code = command_buffer->shim_device->alloc_bo( + ctrl_code_size, XCL_BO_FLAGS_CACHEABLE); + uint32_t* instr_buffer = static_cast(bo_ctrl_code->map()); + memcpy(instr_buffer, kernel_params.asm_inst, ctrl_code_size); + bo_ctrl_code->sync(shim_xdna::direction::host2device); + + std::string cu_name = kernel_params.kernel_name; + cu_name += ":IREE"; + shim_xdna::cuidx_t cu_idx = kernel_params.context->open_cu_context(cu_name); + + shim_xdna::exec_buf ebuf(command_buffer->shim_device->get_pdev(), + ERT_START_CU); + ebuf.set_cu_idx(cu_idx); + unsigned int opcode = 3; + ebuf.add_arg_64(opcode); + ebuf.add_arg_bo(*bo_ctrl_code); + ebuf.add_arg_32(num_instr); + for (iree_host_size_t j = 0; j < bindings.count; ++j) { + shim_xdna::bo* bo = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(bindings.values[j].buffer)); + ebuf.add_arg_bo(*bo); + } + + for (iree_host_size_t j = 0; j < bindings.count; ++j) { + shim_xdna::bo* bo = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(bindings.values[j].buffer)); + bo->sync(shim_xdna::direction::host2device); + } + shim_xdna::hw_q* hwq = kernel_params.context->get_hw_queue(); + hwq->issue_command(ebuf.get_exec_buf_bo()); + hwq->wait_command(ebuf.get_exec_buf_bo(), 0); + + for (iree_host_size_t j = 0; j < bindings.count; ++j) { + shim_xdna::bo* bo = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(bindings.values[j].buffer)); + bo->sync(shim_xdna::direction::device2host); + } + + for (iree_host_size_t j = 0; j < bindings.count; ++j) { + shim_xdna::bo* bo = iree_hal_xrt_lite_buffer_handle( + iree_hal_buffer_allocated_buffer(bindings.values[j].buffer)); + } + + IREE_TRACE_ZONE_END(z0); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_direct_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "need xrt implementation of dispatch indirect"); +} + +namespace { +const iree_hal_command_buffer_vtable_t + iree_hal_xrt_lite_direct_command_buffer_vtable = { + .destroy = iree_hal_xrt_lite_direct_command_buffer_destroy, + .begin = iree_hal_xrt_lite_direct_command_buffer_begin, + .end = iree_hal_xrt_lite_direct_command_buffer_end, + .execution_barrier = + iree_hal_xrt_lite_direct_command_buffer_execution_barrier, + .signal_event = iree_hal_xrt_lite_direct_command_buffer_signal_event, + .reset_event = iree_hal_xrt_lite_direct_command_buffer_reset_event, + .wait_events = iree_hal_xrt_lite_direct_command_buffer_wait_events, + .discard_buffer = + iree_hal_xrt_lite_direct_command_buffer_discard_buffer, + .fill_buffer = iree_hal_xrt_lite_direct_command_buffer_fill_buffer, + .update_buffer = iree_hal_xrt_lite_direct_command_buffer_update_buffer, + .copy_buffer = iree_hal_xrt_lite_direct_command_buffer_copy_buffer, + .collective = iree_hal_xrt_lite_direct_command_buffer_collective, + .dispatch = iree_hal_xrt_lite_direct_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_xrt_lite_direct_command_buffer_dispatch_indirect, +}; +} // namespace diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.h b/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.h new file mode 100644 index 000000000..91eb4aece --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/direct_command_buffer.h @@ -0,0 +1,32 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_AMD_AIE_DRIVER_XRT_LITE_XRT_LITE_COMMAND_BUFFER_H_ +#define IREE_AMD_AIE_DRIVER_XRT_LITE_XRT_LITE_COMMAND_BUFFER_H_ + +#include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.h" +#include "iree/base/internal/arena.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// |out_command_buffer| must be released by the caller (see +// iree_hal_command_buffer_release). +iree_status_t iree_hal_xrt_lite_direct_command_buffer_create( + std::shared_ptr shim_device, + iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, iree_arena_block_pool_t* block_pool, + iree_allocator_t host_allocator, + iree_hal_command_buffer_t** out_command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_AMD_AIE_DRIVER_XRT_LITE_XRT_LITE_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/executable.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/executable.cc index fdb48b0d8..2e9c46ec1 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/executable.cc +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/executable.cc @@ -161,7 +161,9 @@ iree_status_t iree_hal_xrt_lite_native_executable_create( executable->entry_point_count = entry_point_count; for (iree_host_size_t entry_ordinal = 0; entry_ordinal < entry_point_count; entry_ordinal++) { - const char* entry_name = + iree_hal_xrt_lite_kernel_params_t* params = + &executable->entry_points[entry_ordinal]; + params->kernel_name = flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); uint32_t xclbin_index = flatbuffers_uint32_vec_at(xclbin_indices_vec, entry_ordinal); @@ -170,38 +172,33 @@ iree_status_t iree_hal_xrt_lite_native_executable_create( flatbuffers_string_t xclbin_fb = iree_amd_aie_hal_xrt_XclbinDef_xclbin_get(xclbin_def); - iree_hal_xrt_lite_kernel_params_t* params = - &executable->entry_points[entry_ordinal]; - // XRT API needs this vector and cant actually read a void*. std::vector xclbinVector( xclbin_fb, xclbin_fb + flatbuffers_string_len(xclbin_fb)); - xrt::xclbin xclbin = xrt::xclbin(xclbinVector); - params->context = shim_device->create_hw_context(xclbin); + params->xclbinVector = xclbinVector; +// xrt::xclbin xclbin = xrt::xclbin(xclbinVector); +// params->context = shim_device->create_hw_context(xclbin); uint32_t asm_instr_index = flatbuffers_uint32_vec_at(asm_instr_indices_vec, entry_ordinal); iree_amd_aie_hal_xrt_AsmInstDef_table_t asminst_def = iree_amd_aie_hal_xrt_AsmInstDef_vec_at(asm_instrs_vec, asm_instr_index); - flatbuffers_uint32_vec_t asm_inst = + params->asm_inst = iree_amd_aie_hal_xrt_AsmInstDef_asm_inst_get(asminst_def); - uint32_t num_instr = flatbuffers_uint32_vec_len(asm_inst); - size_t ctrl_code_size = num_instr * sizeof(uint32_t); - params->bo_ctrl_code = - shim_device->alloc_bo(ctrl_code_size, XCL_BO_FLAGS_CACHEABLE); - uint32_t* instr_buffer = - static_cast(params->bo_ctrl_code->map()); - memcpy(instr_buffer, asm_inst, ctrl_code_size); - params->num_instr = num_instr; +// uint32_t num_instr = flatbuffers_uint32_vec_len(asm_inst); +// size_t ctrl_code_size = num_instr * sizeof(uint32_t); +// params->bo_ctrl_code = +// shim_device->alloc_bo(ctrl_code_size, XCL_BO_FLAGS_CACHEABLE); +// uint32_t* instr_buffer = +// static_cast(params->bo_ctrl_code->map()); +// memcpy(instr_buffer, asm_inst, ctrl_code_size); // Stash the entry point name in the string table for use when tracing. IREE_TRACE({ - iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name); - memcpy(string_table_buffer, entry_name, entry_name_length); - params->kernel_name = - iree_make_string_view(string_table_buffer, entry_name_length); - string_table_buffer += entry_name_length; + memcpy(string_table_buffer, params->kernel_name.data(), + params->kernel_name.size()); + string_table_buffer += params->kernel_name.size(); }); IREE_TRACE({ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/executable.h b/runtime/src/iree-amd-aie/driver/xrt-lite/executable.h index b70c266cc..ee57055e4 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/executable.h +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/executable.h @@ -9,6 +9,7 @@ #include +#include "flatbuffers_common_reader.h" #include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.h" #include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.h" #include "iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwctx.h" @@ -24,9 +25,10 @@ extern "C" { struct iree_hal_xrt_lite_kernel_params_t { std::unique_ptr context; std::unique_ptr bo_ctrl_code; + std::vector xclbinVector; + flatbuffers_uint32_vec_t asm_inst; // Number of assembly instructions argument to the kernel - uint32_t num_instr; // number of instructions - IREE_TRACE(iree_string_view_t kernel_name;) + std::string kernel_name; IREE_TRACE(iree_string_view_t source_filename;) IREE_TRACE(uint32_t source_line;) }; @@ -44,7 +46,7 @@ iree_status_t iree_hal_xrt_lite_native_executable_entry_point_kernel_params( iree_hal_xrt_lite_kernel_params_t* out_params); #ifdef __cplusplus -} // extern "C" +} // extern "C" #endif // __cplusplus #endif // IREE_AMD_AIE_DRIVER_XRT_LITE_NATIVE_EXECUTABLE_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/nop_executable_cache.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_executable_cache.cc index 2753eebb7..8a617f977 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/nop_executable_cache.cc +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_executable_cache.cc @@ -62,6 +62,7 @@ static void iree_hal_xrt_lite_nop_executable_cache_destroy( iree_hal_xrt_lite_nop_executable_cache_cast(base_executable_cache); IREE_TRACE_ZONE_BEGIN(z0); + executable_cache->shim_device.reset(); iree_allocator_free(executable_cache->host_allocator, executable_cache); IREE_TRACE_ZONE_END(z0); diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.cc b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.cc new file mode 100644 index 000000000..17810350f --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.cc @@ -0,0 +1,115 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/xrt-lite/nop_semaphore.h" + +#include + +#include "iree/base/api.h" +#include "iree/hal/utils/semaphore_base.h" + +struct iree_hal_xrt_lite_semaphore_t { + iree_hal_semaphore_t base; + iree_atomic_int64_t value; + iree_allocator_t host_allocator; +}; + +namespace { +extern const iree_hal_semaphore_vtable_t iree_hal_xrt_lite_semaphore_vtable; +} // namespace + +static iree_hal_xrt_lite_semaphore_t* iree_hal_xrt_lite_semaphore_cast( + iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_xrt_lite_semaphore_vtable); + return (iree_hal_xrt_lite_semaphore_t*)base_value; +} + +iree_status_t iree_hal_xrt_lite_semaphore_create( + iree_allocator_t host_allocator, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(out_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_xrt_lite_semaphore_t* semaphore = nullptr; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*semaphore), (void**)&semaphore); + if (iree_status_is_ok(status)) { + iree_hal_semaphore_initialize(&iree_hal_xrt_lite_semaphore_vtable, + &semaphore->base); + semaphore->host_allocator = host_allocator; + iree_atomic_store_int64(&semaphore->value, initial_value, + iree_memory_order_release); + *out_semaphore = &semaphore->base; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_xrt_lite_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_xrt_lite_semaphore_t* semaphore = + iree_hal_xrt_lite_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator = semaphore->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_semaphore_deinitialize(&semaphore->base); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_xrt_lite_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_xrt_lite_semaphore_t* semaphore = + iree_hal_xrt_lite_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. + *out_value = + iree_atomic_load_int64(&semaphore->value, iree_memory_order_acquire); + return iree_ok_status(); +} + +static iree_status_t iree_hal_xrt_lite_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_xrt_lite_semaphore_t* semaphore = + iree_hal_xrt_lite_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. Return OK currently as everything is + // synchronized for each submit to allow things to run. + iree_atomic_store_int64(&semaphore->value, new_value, + iree_memory_order_release); + iree_hal_semaphore_poll(&semaphore->base); + return iree_ok_status(); +} + +static void iree_hal_xrt_lite_semaphore_fail( + iree_hal_semaphore_t* base_semaphore, iree_status_t status) { + iree_hal_xrt_lite_semaphore_t* semaphore = + iree_hal_xrt_lite_semaphore_cast(base_semaphore); + // TODO: save status and mark timepoint as failed. + iree_status_ignore(status); + iree_hal_semaphore_poll(&semaphore->base); +} + +static iree_status_t iree_hal_xrt_lite_semaphore_wait( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_timeout_t timeout) { + iree_hal_xrt_lite_semaphore_t* semaphore = + iree_hal_xrt_lite_semaphore_cast(base_semaphore); + // TODO: Support semaphores completely. Return OK currently as everything is + // synchronized for each submit to allow things to run. + iree_hal_semaphore_poll(&semaphore->base); + return iree_ok_status(); +} + +namespace { +const iree_hal_semaphore_vtable_t iree_hal_xrt_lite_semaphore_vtable = { + /*.destroy = */ iree_hal_xrt_lite_semaphore_destroy, + /*.query = */ iree_hal_xrt_lite_semaphore_query, + /*.signal = */ iree_hal_xrt_lite_semaphore_signal, + /*.fail = */ iree_hal_xrt_lite_semaphore_fail, + /*.wait = */ iree_hal_xrt_lite_semaphore_wait, +}; +} // namespace diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.h b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.h new file mode 100644 index 000000000..0a8623863 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/nop_semaphore.h @@ -0,0 +1,27 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_AMD_AIE_DRIVER_XRT_LITE_NOP_SEMAPHORE_H_ +#define IREE_AMD_AIE_DRIVER_XRT_LITE_NOP_SEMAPHORE_H_ + +#include + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +iree_status_t iree_hal_xrt_lite_semaphore_create( + iree_allocator_t host_allocator, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_AMD_AIE_DRIVER_XRT_LITE_NOP_SEMAPHORE_H_ diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.cpp b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.cpp index cc349197c..4cb322881 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.cpp +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.cpp @@ -270,6 +270,12 @@ bo::bo(const pdev &p, uint32_t ctx_id, size_t size, shim_xcl_bo_flags flags) shim_err(EINVAL, "Invalid BO flags: 0x%lx", flags); } +bo::bo(const pdev &p, uint32_t ctx_id, size_t size, uint32_t flags) + : bo(p, ctx_id, size, shim_xcl_bo_flags{.flags = flags}) { + if (m_type == AMDXDNA_BO_INVALID) + shim_err(EINVAL, "Invalid BO flags: 0x%lx", flags); +} + bo::bo(const pdev &pdev, uint32_t ctx_id, size_t size, shim_xcl_bo_flags flags, amdxdna_bo_type type) : m_pdev(pdev), @@ -430,6 +436,8 @@ void bo::sync(direction dir, size_t size, size_t offset) { } } +void bo::sync(direction dir) { sync(dir, size(), 0); } + void bo::bind_at(size_t pos, const bo &boh, size_t offset, size_t size) { std::lock_guard lg(m_args_map_lock); @@ -471,10 +479,12 @@ uint32_t bo::get_arg_bo_handles(uint32_t *handles, size_t num) const { return sz; } -exec_buf::exec_buf(bo &bo_execbuf, uint32_t op) - : m_exec_buf_bo(bo_execbuf), - m_cmd_pkt(reinterpret_cast(bo_execbuf.map())), - m_cmd_size(bo_execbuf.size()), +exec_buf::exec_buf(const pdev &p, uint32_t op) + : m_exec_buf_bo(std::make_unique(p, AMDXDNA_INVALID_CTX_HANDLE, + MAX_EXEC_BO_SIZE, + XCL_BO_FLAGS_EXECBUF)), + m_cmd_pkt(reinterpret_cast(m_exec_buf_bo->map())), + m_cmd_size(m_exec_buf_bo->size()), m_op(op), m_arg_cnt(0), m_reg_idx(0) { @@ -498,7 +508,7 @@ void exec_buf::set_cu_idx(cuidx_t cu_idx) { void exec_buf::add_ctrl_bo(bo &bo_ctrl) { ert_start_kernel_cmd *cmd_packet = - reinterpret_cast(m_exec_buf_bo.map()); + reinterpret_cast(m_exec_buf_bo->map()); switch (m_op) { case ERT_START_CU: break; @@ -541,7 +551,7 @@ void exec_buf::add_arg_64(uint64_t val) { void exec_buf::add_arg_bo(bo &bo_arg, std::string arg_name) { // Add to argument list for driver - m_exec_buf_bo.bind_at(m_arg_cnt, bo_arg, 0, bo_arg.size()); + m_exec_buf_bo->bind_at(m_arg_cnt, bo_arg, 0, bo_arg.size()); // Add to argument list for control code patching if (arg_name.empty()) m_patching_args.emplace_back(std::to_string(m_arg_cnt), bo_arg.get_paddr()); @@ -553,7 +563,7 @@ void exec_buf::add_arg_bo(bo &bo_arg, std::string arg_name) { void exec_buf::dump() { std::cout << "Dumping exec buf:"; - int *data = static_cast(m_exec_buf_bo.map()); + int *data = static_cast(m_exec_buf_bo->map()); std::cout << std::hex; for (int i = 0; i < m_cmd_pkt->count + 1; i++) { if (i % 4 == 0) std::cout << "\n"; @@ -574,4 +584,7 @@ void exec_buf::inc_pkt_count(uint32_t n) { throw std::runtime_error("Size of exec buf too small: " + std::to_string(m_cmd_size)); } + +bo *exec_buf::get_exec_buf_bo() { return m_exec_buf_bo.get(); } + } // namespace shim_xdna diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.h b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.h index 617e9335a..16d01fe8c 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.h +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/bo.h @@ -13,6 +13,8 @@ namespace shim_xdna { +#define MAX_EXEC_BO_SIZE 4096 + enum xclBOSyncDirection { XCL_BO_SYNC_BO_TO_DEVICE = 0, XCL_BO_SYNC_BO_FROM_DEVICE, @@ -68,6 +70,7 @@ struct bo { bo(const pdev &p, uint32_t ctx_id, size_t size, shim_xcl_bo_flags flags, amdxdna_bo_type type); bo(const pdev &p, uint32_t ctx_id, size_t size, shim_xcl_bo_flags flags); + bo(const pdev &p, uint32_t ctx_id, size_t size, uint32_t flags); bo(const pdev &p, int ehdl); // Support BO creation from internal bo(const pdev &p, size_t size, amdxdna_bo_type type); @@ -76,6 +79,7 @@ struct bo { void *map() const; void unmap(void *addr); void sync(direction, size_t size, size_t offset); + void sync(direction); properties get_properties() const; size_t size(); @@ -104,7 +108,7 @@ struct bo { }; struct exec_buf { - bo &m_exec_buf_bo; + std::unique_ptr m_exec_buf_bo; ert_start_kernel_cmd *m_cmd_pkt; size_t m_cmd_size; uint32_t m_op; @@ -112,16 +116,17 @@ struct exec_buf { uint32_t m_reg_idx; std::vector > m_patching_args; - exec_buf(bo &bo_execbuf, uint32_t op); + exec_buf(const pdev &p, uint32_t op); + static void set_cu_idx(bo &bo_execbuf, cuidx_t cu_idx); void set_cu_idx(cuidx_t cu_idx); + bo* get_exec_buf_bo(); + void add_ctrl_bo(bo &bo_ctrl); void add_arg_32(uint32_t val); void add_arg_64(uint64_t val); void add_arg_bo(bo &bo_arg, std::string arg_name = ""); void dump(); - static size_t get_ctrl_code_size(const std::string &elf_path); - void patch_ctrl_code(bo &bo_ctrl, const std::string &elf_path); void inc_pkt_count(uint32_t n); }; diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.cpp b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.cpp index 0e521e4f2..3b1ddb73a 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.cpp +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/device.cpp @@ -169,9 +169,8 @@ std::unique_ptr device::alloc_bo(size_t size, shim_xcl_bo_flags flags) { } std::unique_ptr device::alloc_bo(size_t size, uint32_t flags) { - shim_xcl_bo_flags f{}; - f.flags = flags; - return alloc_bo(AMDXDNA_INVALID_CTX_HANDLE, size, f); + return alloc_bo(AMDXDNA_INVALID_CTX_HANDLE, size, + shim_xcl_bo_flags{.flags = flags}); } std::unique_ptr device::import_bo(pid_t pid, int ehdl) { diff --git a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwctx.cpp b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwctx.cpp index 2deefa14b..574fc8a20 100644 --- a/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwctx.cpp +++ b/runtime/src/iree-amd-aie/driver/xrt-lite/shim/linux/kmq/hwctx.cpp @@ -85,6 +85,7 @@ hw_ctx::~hw_ctx() { cuidx_t hw_ctx::open_cu_context(const std::string &cu_name) { for (uint32_t i = 0; i < m_cu_info.size(); i++) { auto &ci = m_cu_info[i]; + shim_debug("ci.m_name %s\n", ci.m_name.c_str()); if (ci.m_name == cu_name) return cuidx_t{.index = i}; } diff --git a/runtime/src/iree-amd-aie/driver/xrt/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/xrt/CMakeLists.txt index 81f90689b..9d9cabd44 100644 --- a/runtime/src/iree-amd-aie/driver/xrt/CMakeLists.txt +++ b/runtime/src/iree-amd-aie/driver/xrt/CMakeLists.txt @@ -38,6 +38,7 @@ iree_cc_library( "native_executable.h" "native_executable.cc" "nop_semaphore.cc" + "nop_semaphore.h" "nop_executable_cache.h" "nop_executable_cache.cc" DEPS @@ -48,6 +49,7 @@ iree_cc_library( iree::base::internal::flatcc::parsing iree::hal::utils::deferred_command_buffer iree::hal::utils::file_transfer + iree::hal::utils::semaphore_base iree::hal iree-amd-aie::schemas::xrt_executable_def_c_fbs # hide the target from all exports so it doesn't need to be installed diff --git a/runtime/src/iree-amd-aie/driver/xrt/cts/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/xrt/cts/CMakeLists.txt new file mode 100644 index 000000000..07746787d --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt/cts/CMakeLists.txt @@ -0,0 +1,111 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(CMakeDependentOption) + +iree_hal_cts_test_suite( + DRIVER_NAME + xrt + DRIVER_REGISTRATION_HDR + "iree-amd-aie/driver/xrt/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_xrt_driver_module_register" + COMPILER_TARGET_BACKEND + "amd-aie" + EXECUTABLE_FORMAT + "\"amdaie-xclbin-fb\"" + DEPS + iree-amd-aie::driver::xrt::registration + INCLUDED_TESTS + "allocator" + "buffer_mapping" + "driver" +) + +set(PEANO_INSTALL_DIR "" CACHE PATH "") +set(VITIS_DIR "" CACHE PATH "") +if((NOT PEANO_INSTALL_DIR) AND (NOT VITIS_DIR)) + message(FATAL_ERROR "either PEANO_INSTALL_DIR or VITIS_DIR must be set") +endif() +cmake_dependent_option(USE_CHESS "" "1" "VITIS_DIR" "0") +set(TARGET_DEVICE "npu1_4col" CACHE STRING "") + +iree_bytecode_module( + NAME + xrt_executable_cache_test_module + MODULE_FILE_NAME + xrt_executable_cache_test.bin + SRC + "${CMAKE_CURRENT_LIST_DIR}/executable_cache_test.mlir" + FLAGS + --compile-mode=hal-executable + --iree-hal-dump-executable-files-to=${CMAKE_CURRENT_BINARY_DIR} + --iree-hal-target-backends=amd-aie + --iree-amdaie-lower-to-aie-pipeline=air + --iree-amdaie-target-device=${TARGET_DEVICE} + --iree-amd-aie-peano-install-dir=${PEANO_INSTALL_DIR} + --iree-amd-aie-vitis-install-dir=${VITIS_DIR} + --iree-amd-aie-enable-chess=$ + --iree-amd-aie-show-invoked-commands + --iree-hal-memoization=false + --iree-hal-indirect-command-buffers=false + DEPS + iree-aie-xclbinutil + PUBLIC + TESTONLY +) + +iree_c_embed_data( + NAME + xrt_executables_c + SRCS + xrt_executable_cache_test.bin + C_FILE_OUTPUT + xrt_executables_c.c + H_FILE_OUTPUT + xrt_executables_c.h + IDENTIFIER + iree_cts_testdata_executables_aie_xrt + STRIP_PREFIX + xrt_ + DEPENDS + ::xrt_executable_cache_test_module + FLATTEN + PUBLIC + TESTONLY +) + +iree_cc_test( + NAME + xrt_executable_cache_test + SRCS + executable_cache_test.cc + DEPS + ::xrt_executables_c + iree-amd-aie::driver::xrt::registration + iree::base + iree::hal + iree::hal::cts::cts_test_base + iree::testing::gtest_main +) + +iree_cc_test( + NAME + xrt_dispatch_test + SRCS + matmul_dispatch_test.cc + DEPS + ::xrt_executables_c + iree-amd-aie::driver::xrt::registration + iree::base + iree::hal + iree::hal::cts::cts_test_base + iree::testing::gtest_main + iree::tools::testing::e2e::e2e_test_util +) + +target_include_directories(iree-amd-aie_driver_xrt_cts_xrt_executable_cache_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") +target_include_directories(iree-amd-aie_driver_xrt_cts_xrt_dispatch_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") diff --git a/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.cc b/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.cc new file mode 100644 index 000000000..3e9411cf2 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.cc @@ -0,0 +1,85 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/xrt/registration/driver_module.h" +#include "iree/base/api.h" +#include "iree/base/string_view.h" +#include "iree/hal/api.h" +#include "iree/hal/cts/cts_test_base.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "xrt_executables_c.h" + +namespace iree::hal::cts { + +const char* get_test_driver_name() { return "xrt"; } + +iree_status_t register_test_driver(iree_hal_driver_registry_t* registry) { + return iree_hal_xrt_driver_module_register(registry); +} + +const char* get_test_executable_format() { return "amdaie-xclbin-fb"; } + +iree_const_byte_span_t get_test_executable_data(iree_string_view_t file_name) { + const struct iree_file_toc_t* toc = + iree_cts_testdata_executables_aie_xrt_create(); + const auto& file = toc[0]; + return iree_make_const_byte_span(file.data, file.size); +} + +class ExecutableCacheTest : public CTSTestBase<> {}; + +TEST_F(ExecutableCacheTest, Create) { + iree_status_t loop_status = iree_ok_status(); + iree_hal_executable_cache_t* executable_cache = nullptr; + IREE_ASSERT_OK(iree_hal_executable_cache_create( + device_, iree_make_cstring_view("default"), + iree_loop_inline(&loop_status), &executable_cache)); + + iree_hal_executable_cache_release(executable_cache); + IREE_ASSERT_OK(loop_status); +} + +TEST_F(ExecutableCacheTest, CantPrepareUnknownFormat) { + iree_status_t loop_status = iree_ok_status(); + iree_hal_executable_cache_t* executable_cache = nullptr; + IREE_ASSERT_OK(iree_hal_executable_cache_create( + device_, iree_make_cstring_view("default"), + iree_loop_inline(&loop_status), &executable_cache)); + + EXPECT_FALSE(iree_hal_executable_cache_can_prepare_format( + executable_cache, /*caching_mode=*/0, iree_make_cstring_view("FOO?"))); + + iree_hal_executable_cache_release(executable_cache); + IREE_ASSERT_OK(loop_status); +} + +TEST_F(ExecutableCacheTest, PrepareExecutable) { + iree_status_t loop_status = iree_ok_status(); + iree_hal_executable_cache_t* executable_cache = nullptr; + IREE_ASSERT_OK(iree_hal_executable_cache_create( + device_, iree_make_cstring_view("default"), + iree_loop_inline(&loop_status), &executable_cache)); + + iree_hal_executable_params_t executable_params; + iree_hal_executable_params_initialize(&executable_params); + executable_params.caching_mode = + IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA; + executable_params.executable_format = + iree_make_cstring_view(get_test_executable_format()); + executable_params.executable_data = get_test_executable_data( + iree_make_cstring_view("executable_cache_test.bin")); + + iree_hal_executable_t* executable = nullptr; + IREE_ASSERT_OK(iree_hal_executable_cache_prepare_executable( + executable_cache, &executable_params, &executable)); + + iree_hal_executable_release(executable); + iree_hal_executable_cache_release(executable_cache); + IREE_ASSERT_OK(loop_status); +} + +} // namespace iree::hal::cts diff --git a/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.mlir b/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.mlir new file mode 100644 index 000000000..dedbcab6b --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt/cts/executable_cache_test.mlir @@ -0,0 +1,33 @@ +// bootstrapped from https://github.com/nod-ai/iree-amd-aie/blob/9c4c167baf89a279888fba8db75907845946077c/tests/samples/matmul_pack_peel_objectfifo_e2e.mlir + +#pipeline_layout = #hal.pipeline.layout< + bindings = [ + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding + ], + flags = Indirect +> +hal.executable.source public @amdaie_fb { + hal.executable.export public @matmul_f32_dispatch_0_matmul_32x32x32_f32 ordinal(0) layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @matmul_f32_dispatch_0_matmul_32x32x32_f32() { + %c0_f32 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x32xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<32x32xf32> + %5 = tensor.empty() : tensor<32x32xf32> + %6 = linalg.fill ins(%c0_f32 : f32) outs(%5 : tensor<32x32xf32>) -> tensor<32x32xf32> + %7 = linalg.matmul ins(%3, %4 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%6 : tensor<32x32xf32>) -> tensor<32x32xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] : tensor<32x32xf32> -> !flow.dispatch.tensor> + return + } + } +} diff --git a/runtime/src/iree-amd-aie/driver/xrt/cts/matmul_dispatch_test.cc b/runtime/src/iree-amd-aie/driver/xrt/cts/matmul_dispatch_test.cc new file mode 100644 index 000000000..c48ea13f7 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/xrt/cts/matmul_dispatch_test.cc @@ -0,0 +1,224 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/xrt/registration/driver_module.h" +#include "iree/base/api.h" +#include "iree/base/string_view.h" +#include "iree/hal/api.h" +#include "iree/hal/buffer_view_util.h" +#include "iree/hal/cts/cts_test_base.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "tools/testing/e2e/test_utils.h" +#include "xrt_executables_c.h" + +namespace iree::hal::cts { + +const char* get_test_driver_name() { return "xrt"; } + +iree_status_t register_test_driver(iree_hal_driver_registry_t* registry) { + return iree_hal_xrt_driver_module_register(registry); +} + +const char* get_test_executable_format() { return "amdaie-xclbin-fb"; } + +iree_const_byte_span_t get_test_executable_data(iree_string_view_t file_name) { + const struct iree_file_toc_t* toc = + iree_cts_testdata_executables_aie_xrt_create(); + const auto& file = toc[0]; + return iree_make_const_byte_span(file.data, file.size); +} + +class MatMulDispatchTest + : public CTSTestBase<::testing::TestWithParam> { + protected: + void PrepareMatmulExecutable() { + IREE_ASSERT_OK(iree_hal_executable_cache_create( + device_, iree_make_cstring_view("default"), + iree_loop_inline(&loop_status_), &executable_cache_)); + + iree_hal_executable_params_t executable_params; + iree_hal_executable_params_initialize(&executable_params); + executable_params.caching_mode = + IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA; + executable_params.executable_format = + iree_make_cstring_view(get_test_executable_format()); + executable_params.executable_data = get_test_executable_data( + iree_make_cstring_view("xrt_executable_cache_test.bin")); + + IREE_ASSERT_OK(iree_hal_executable_cache_prepare_executable( + executable_cache_, &executable_params, &executable_)); + } + + void CleanupExecutable() { + iree_hal_executable_release(executable_); + iree_hal_executable_cache_release(executable_cache_); + IREE_ASSERT_OK(loop_status_); + } + + iree_status_t loop_status_ = iree_ok_status(); + iree_hal_executable_cache_t* executable_cache_ = nullptr; + iree_hal_executable_t* executable_ = nullptr; +}; + +int32_t generate_random_number(iree_hal_element_type_t element_type, + int32_t seed) { + int32_t min = 0; + int32_t max = 0; + iree_test_utils_get_min_max_for_element_type(element_type, &min, &max); + uint32_t range = (max - min + 1); + return (int32_t)iree_test_utils_pseudorandom_range( + reinterpret_cast(&seed), range) + + min; +} + +TEST_F(MatMulDispatchTest, Create) { + iree_hal_command_buffer_t* command_buffer = nullptr; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY, + /*binding_capacity=*/0, &command_buffer)); + + EXPECT_TRUE((iree_hal_command_buffer_allowed_categories(command_buffer) & + IREE_HAL_COMMAND_CATEGORY_DISPATCH) == + IREE_HAL_COMMAND_CATEGORY_DISPATCH); + + iree_hal_command_buffer_release(command_buffer); +} + +TEST_F(MatMulDispatchTest, BeginEnd) { + iree_hal_command_buffer_t* command_buffer = nullptr; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY, + /*binding_capacity=*/0, &command_buffer)); + + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + iree_hal_command_buffer_release(command_buffer); +} + +TEST_F(MatMulDispatchTest, SubmitEmpty) { + iree_hal_command_buffer_t* command_buffer = nullptr; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY, + /*binding_capacity=*/0, &command_buffer)); + + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + IREE_ASSERT_OK(SubmitCommandBufferAndWait(command_buffer)); + + iree_hal_command_buffer_release(command_buffer); +} + +TEST_P(MatMulDispatchTest, DispatchMatmul) { + PrepareMatmulExecutable(); + + // Create input buffer. + constexpr iree_device_size_t WIDTH = 32; + constexpr iree_device_size_t M = WIDTH, K = WIDTH, N = WIDTH; + iree_hal_buffer_t *input_A = nullptr, *input_B = nullptr, *output_C = nullptr; + int32_t seed = + std::chrono::high_resolution_clock::now().time_since_epoch().count() >> + 32; + int32_t a = generate_random_number( + iree_hal_element_types_t::IREE_HAL_ELEMENT_TYPE_FLOAT_32, seed); + int32_t b = generate_random_number( + iree_hal_element_types_t::IREE_HAL_ELEMENT_TYPE_FLOAT_32, seed + 1); + CreateFilledDeviceBuffer(M * K * sizeof(float), a, &input_A); + CreateFilledDeviceBuffer(K * N * sizeof(float), b, &input_B); + CreateFilledDeviceBuffer(M * N * sizeof(float), -1, &output_C); + + iree_hal_buffer_ref_t binding_refs[3]; + iree_hal_buffer_binding_table_t binding_table = + iree_hal_buffer_binding_table_empty(); + binding_refs[0] = { + /*binding=*/0, + /*buffer_slot=*/0, + /*buffer=*/input_A, + /*offset=*/0, + /*length=*/M * K * sizeof(float), + }; + binding_refs[1] = { + /*binding=*/0, + /*buffer_slot=*/0, + /*buffer=*/input_B, + /*offset=*/0, + /*length=*/K * N * sizeof(float), + }; + binding_refs[2] = { + /*binding=*/0, + /*buffer_slot=*/0, + /*buffer=*/output_C, + /*offset=*/0, + /*length=*/M * N * sizeof(float), + }; + iree_hal_buffer_ref_list_t bindings = { + /*.count=*/IREE_ARRAYSIZE(binding_refs), + /*.values=*/binding_refs, + }; + + iree_hal_command_buffer_t* command_buffer = nullptr; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_table.count, &command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + + uint32_t workgroup_count[3] = {1, 1, 1}; + IREE_ASSERT_OK(iree_hal_command_buffer_dispatch( + command_buffer, executable_, /*entry_point=*/0, workgroup_count, + iree_const_byte_span_empty(), bindings, IREE_HAL_DISPATCH_FLAG_NONE)); + + IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier( + command_buffer, + /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH | + IREE_HAL_EXECUTION_STAGE_TRANSFER | + IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE, + /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE | + IREE_HAL_EXECUTION_STAGE_DISPATCH | IREE_HAL_EXECUTION_STAGE_TRANSFER, + IREE_HAL_EXECUTION_BARRIER_FLAG_NONE, /*memory_barrier_count=*/0, + /*memory_barriers=*/nullptr, + /*buffer_barrier_count=*/0, /*buffer_barriers=*/nullptr)); + + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + IREE_ASSERT_OK(SubmitCommandBufferAndWait(command_buffer, binding_table)); + + std::vector output_values; + output_values.reserve(M * N); + IREE_ASSERT_OK(iree_hal_device_transfer_d2h( + device_, output_C, + /*source_offset=*/0, output_values.data(), M * N * sizeof(float), + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout())); + std::vector correct_output_values; + correct_output_values.reserve(M * N); + std::fill_n(correct_output_values.data(), M * N, (float)WIDTH * (a * b)); + int n_wrong = 0; + for (int i = 0; i < M * N; ++i) { + if (output_values[i] != correct_output_values[i]) { + std::cout << "wrong @ i:" << i << ", " << output_values[i] + << " != " << correct_output_values[i] << "\n"; + n_wrong += 1; + } + } + EXPECT_EQ(n_wrong, 0); + + iree_hal_command_buffer_release(command_buffer); + iree_hal_buffer_release(output_C); + iree_hal_buffer_release(input_B); + iree_hal_buffer_release(input_A); + CleanupExecutable(); +} + +INSTANTIATE_TEST_SUITE_P(MatMulDispatchTest, MatMulDispatchTest, + ::testing::Values(RecordingType::kDirect), + GenerateTestName()); + +} // namespace iree::hal::cts