diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11cf13cc35c2..6edcff51e7ad 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -32,7 +32,6 @@ # Experimental # It's experimental, but we still don't want any old directory added here. /experimental/ @benvanik @stellaraccident -/experimental/rocm/ @benvanik /experimental/web/ @ScottTodd /experimental/webgpu/ @benvanik @ScottTodd diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index e6fb8ad2b210..9849c574dd72 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -220,7 +220,7 @@ jobs: --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 337.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2280000 \ @@ -241,7 +241,7 @@ jobs: --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 80.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2270000 \ diff --git a/.gitmodules b/.gitmodules index 58e22edefcec..2c6c117e9c57 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,10 +7,6 @@ [submodule "third_party/vulkan_headers"] path = third_party/vulkan_headers url = https://github.com/KhronosGroup/Vulkan-Headers.git -[submodule "third_party/pybind11"] - path = third_party/pybind11 - url = https://github.com/pybind/pybind11.git - branch = stable [submodule "third_party/benchmark"] path = third_party/benchmark url = https://github.com/google/benchmark.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e34ee8b05bd..a8a9b70bfa44 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -772,6 +772,30 @@ endif() # MLIR/LLVM Dependency #------------------------------------------------------------------------------- +# Both the IREE and MLIR Python bindings require pybind11. We initialize it here +# at the top level so that everything uses ours consistently. +if(IREE_BUILD_PYTHON_BINDINGS AND IREE_BUILD_COMPILER) + set(pybind11_VERSION 2.13.6) + include(FetchContent) + FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11 + GIT_TAG v${pybind11_VERSION} + ) + set(PYBIND11_FINDPYTHON ON) + FetchContent_MakeAvailable(pybind11) + # pybind11 source fetches do not include find_package integration, which is + # a shame since sub-projects can require that to work. If we were using + # CMake 3.24, we could just add OVERRIDE_FIND_PACKAGE to the + # FetchContent_Declare call above and it would take care of doing the + # following to let subsequent sub-project find_package calls to resolve + # successfully. 
+ set(pybind11_DIR "${pybind11_BINARY_DIR}") + file(WRITE "${pybind11_BINARY_DIR}/pybind11Config.cmake" "") + file(WRITE "${pybind11_BINARY_DIR}/pybind11ConfigVersion.cmake" + "set(PACKAGE_VERSION ${pybind11_VERSION})\nset(PACKAGE_VERSION_COMPATIBLE TRUE)") +endif() + if(NOT IREE_BUILD_COMPILER) message(STATUS "Not adding LLVM/MLIR because the configuration does not require it") else() @@ -921,19 +945,6 @@ if(IREE_BUILD_TESTS) include(iree_configure_testing) endif() -if(IREE_BUILD_PYTHON_BINDINGS) - # The compiler uses pybind11 - if(IREE_BUILD_COMPILER) - if(NOT TARGET pybind11::module) - message(STATUS "Using bundled pybind11") - set(PYBIND11_FINDPYTHON ON) - add_subdirectory(third_party/pybind11 EXCLUDE_FROM_ALL) - else() - message(STATUS "Not including bundled pybind11 (already configured)") - endif() - endif() -endif() - if(IREE_TARGET_BACKEND_METAL_SPIRV) # SPIRV-Cross is needed to cross compile SPIR-V into MSL source code. iree_set_spirv_cross_cmake_options() diff --git a/build_tools/bazel/workspace.bzl b/build_tools/bazel/workspace.bzl index 654649508d3c..bb39437ea561 100644 --- a/build_tools/bazel/workspace.bzl +++ b/build_tools/bazel/workspace.bzl @@ -147,6 +147,13 @@ def configure_iree_submodule_deps(iree_repo_alias = "@", iree_path = "./"): path = paths.join(iree_path, "third_party/nccl"), ) + maybe( + native.new_local_repository, + name = "hsa_runtime_headers", + build_file = iree_repo_alias + "//:build_tools/third_party/hsa-runtime-headers/BUILD.overlay", + path = paths.join(iree_path, "third_party/hsa-runtime-headers"), + ) + maybe( native.new_local_repository, name = "webgpu_headers", diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index cecc21777f5f..0c0469eb335c 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -113,6 +113,7 @@ def __init__(self, repo_map: Dict[str, str]): "@com_google_googletest//:gtest": ["gmock", "gtest"], "@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"], "@cpuinfo": ["${IREE_CPUINFO_TARGET}"], + "@hsa_runtime_headers": ["hsa_runtime::headers"], "@webgpu_headers": [], } ) diff --git a/build_tools/cmake/build_and_test_byo_llvm.sh b/build_tools/cmake/build_and_test_byo_llvm.sh index 043bc98c8784..d233664be2f1 100755 --- a/build_tools/cmake/build_and_test_byo_llvm.sh +++ b/build_tools/cmake/build_and_test_byo_llvm.sh @@ -24,6 +24,12 @@ echo "Setting up venv at $VENV_DIR" python3 -m venv "$VENV_DIR" source "$VENV_DIR/bin/activate" python -m pip install -r runtime/bindings/python/iree/runtime/build_requirements.txt +python -m pip install -r third_party/llvm-project/mlir/python/requirements.txt +# Note: IREE's Python bindings for Python 3.13 are build with support for +# free-threading for which support was added to pybind with version 2.13.0. +# Therefore, we upgrade to a more recent version and avoid mixing of different +# pybind versions. +python -m pip install pybind11==2.13.6 # Note: by using the `build_llvm` action here, we are exercising byo_llvm.sh's # ability to build LLVM... from our own third_party/llvm-project. 
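Reviewer note on the `pybind11==2.13.6` pin above: the free-threading support mentioned in the comment arrived in pybind11 2.13.0. As a minimal, hedged illustration of what that support looks like from an extension's point of view (this is not IREE code; the module name `example_ext` and the `add` function are made up for the example):

```cpp
// Stand-alone pybind11 module sketch showing the free-threading opt-in that
// pybind11 2.13+ provides for CPython 3.13 no-GIL builds.
#include <pybind11/pybind11.h>
namespace py = pybind11;

int add(int a, int b) { return a + b; }

// py::mod_gil_not_used() declares the module safe to import under a
// free-threaded interpreter; on regular GIL-enabled builds it has no effect.
PYBIND11_MODULE(example_ext, m, py::mod_gil_not_used()) {
  m.def("add", &add, "Add two integers");
}
```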
That's not diff --git a/build_tools/llvm/byo_llvm.sh b/build_tools/llvm/byo_llvm.sh index 0f3d0fda3d4d..d88fb57caf45 100755 --- a/build_tools/llvm/byo_llvm.sh +++ b/build_tools/llvm/byo_llvm.sh @@ -113,6 +113,9 @@ do_build_mlir() { cmake_options="-DLLVM_DIR='${main_install_dir}/lib/cmake/llvm'" cmake_options="${cmake_options} -DPython3_EXECUTABLE='$(which $python3_command)'" + # Note: Building the MLIR Python bindings requires the installation of + # dependencies as specified in `mlir/python/requirements.txt`, which among + # others include pybind11. cmake_options="${cmake_options} -DMLIR_ENABLE_BINDINGS_PYTHON=ON" cmake_options="${cmake_options} -DCMAKE_INSTALL_PREFIX=${mlir_install_dir}" cmake_options="${cmake_options} -C $TD/mlir_config.cmake" diff --git a/build_tools/third_party/hsa-runtime-headers/BUILD.overlay b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay new file mode 100644 index 000000000000..b3e0b85dd47b --- /dev/null +++ b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay @@ -0,0 +1,16 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "hsa_runtime_headers", + hdrs = glob([ + "include/hsa/*.h", + ]), + include_prefix = "third_party/hsa-runtime-headers/", + includes = ["include"], +) diff --git a/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt b/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt new file mode 100644 index 000000000000..e939e9895122 --- /dev/null +++ b/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(HSA_RUNTIME_HEADERS_ROOT "${IREE_ROOT_DIR}/third_party/hsa-runtime-headers/") + +external_cc_library( + PACKAGE + hsa_runtime + NAME + headers + ROOT + ${HSA_RUNTIME_HEADERS_ROOT} + SYSTEM_INCLUDES + ${HSA_RUNTIME_HEADERS_ROOT}/include/ + PUBLIC +) + +iree_install_targets( + TARGETS + hsa_runtime_headers + COMPONENT + IREEBundledLibraries + EXPORT_SET + Runtime +) diff --git a/compiler/plugins/input/Torch/PluginRegistration.cpp b/compiler/plugins/input/Torch/PluginRegistration.cpp index b9a497f12550..6f79686ad267 100644 --- a/compiler/plugins/input/Torch/PluginRegistration.cpp +++ b/compiler/plugins/input/Torch/PluginRegistration.cpp @@ -57,8 +57,19 @@ struct TorchSession OpPassManager &passManager, std::string_view typeMnemonic) override { if (typeMnemonic == "onnx") { // ONNX input is a pre-processing step to torch. - passManager.addNestedPass( - mlir::torch::onnx_c::createTorchOnnxToTorchPass()); + mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions; + // The `aten.flatten.using_ints` and `aten.unflatten.int` are added to the + // list of backend legal ops so that they are not decomposed into the + // `aten.view` op during the run of `DecomposeComplexOps` pass. The issue + // with this is that the `aten.view` op eventually lowers to + // `tensor.reshape` op while there exists a direct torch->linalg lowering + // for both the flatten/unflatten ops which lowers to + // `tensor.collapse_shape/expand_shape` op, and this is a more preferred + // path for the downstream pipeline. 
+ torchOnnxPipelineOptions.backendLegalOps = {"aten.flatten.using_ints", + "aten.unflatten.int"}; + mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline( + passManager, torchOnnxPipelineOptions); } if (typeMnemonic == "torch" || typeMnemonic == "onnx") { diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 18896f2bb0fe..ffc49b57fa7d 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -461,6 +461,10 @@ class CUDATargetBackend final : public TargetBackend { buildLLVMGPUCodegenPassPipeline(passManager, false); } + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildLLVMGPULinkingPassPipeline(passManager, "cuda"); + } + LogicalResult serializeExecutable(const SerializationOptions &serOptions, IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) override { diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir index 6e6fa946fcd9..0c12f0652e84 100644 --- a/compiler/plugins/target/CUDA/test/smoketest.mlir +++ b/compiler/plugins/target/CUDA/test/smoketest.mlir @@ -1,8 +1,6 @@ // RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s // RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX -#map = affine_map<(d0) -> (d0)> - module attributes { hal.device.targets = [ #hal.device.target<"cuda", [ @@ -11,13 +9,13 @@ module attributes { ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @add_dispatch_executable { + stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { - func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { %c0 = arith.constant 0 : index %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> @@ -26,7 +24,7 @@ stream.executable public @add_dispatch_0 { %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): %4 = arith.addf %arg3, %arg4 : f32 linalg.yield %4 : f32 } -> tensor<16xf32> @@ -36,12 +34,42 @@ stream.executable public @add_dispatch_0 { } } +stream.executable public @mul_dispatch_executable { + stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : 
index, index, index + } + builtin.module { + func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + %c0 = arith.constant 0 : index + %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %0 = tensor.empty() : tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = arith.mulf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor> + return + } + } +} + } -// PTX: .entry add_dispatch_0 +// PTX: .entry add_dispatch // PTX: .maxntid 64, 1, 1 // PTX: add.rn.f32 -// CHECK: hal.executable.binary public @cuda_nvptx_fb attributes { +// PTX: .entry mul_dispatch +// PTX: .maxntid 64, 1, 1 +// PTX: mul.rn.f32 + +// CHECK: hal.executable public @smoketest_linked +// CHECK-NEXT: hal.executable.binary public @cuda_nvptx_fb attributes { // CHECK-SAME: data = dense // CHECK-SAME: format = "cuda-nvptx-fb" diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index 7db50acd0033..ee8e256321a5 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp @@ -241,7 +241,7 @@ class LLVMCPUTargetBackend final : public TargetBackend { } void buildLinkingPassPipeline(OpPassManager &passManager) override { - buildLLVMCPULinkingPassPipeline(passManager); + buildLLVMCPULinkingPassPipeline(passManager, "llvm-cpu"); } // Gets the LLVM target from |variantOp|. 
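Note on the new `third_party/hsa-runtime-headers` dependency wired up above (Bazel overlay plus `external_cc_library`): it vendors headers only. As a rough illustration of the API surface those headers expose to a consumer, a generic sketch follows (plain HSA runtime usage, not IREE code; compiling needs only the headers, while linking/running requires an actual HSA runtime such as AMD's ROCR):

```cpp
// Enumerate HSA agents using the vendored headers.
#include <hsa/hsa.h>
#include <cstdio>

static hsa_status_t printAgentName(hsa_agent_t agent, void * /*userData*/) {
  // HSA_AGENT_INFO_NAME fills a 64-byte character buffer.
  char name[64] = {0};
  hsa_agent_get_info(agent, HSA_AGENT_INFO_NAME, name);
  std::printf("HSA agent: %s\n", name);
  return HSA_STATUS_SUCCESS; // keep iterating
}

int main() {
  if (hsa_init() != HSA_STATUS_SUCCESS) return 1;
  hsa_iterate_agents(printAgentName, /*data=*/nullptr);
  hsa_shut_down();
  return 0;
}
```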
diff --git a/compiler/plugins/target/ROCM/BUILD.bazel b/compiler/plugins/target/ROCM/BUILD.bazel index 7962cf8e6073..9692d1aafd26 100644 --- a/compiler/plugins/target/ROCM/BUILD.bazel +++ b/compiler/plugins/target/ROCM/BUILD.bazel @@ -39,6 +39,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/Utils", + "//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs", "//runtime/src/iree/schemas:executable_debug_info_c_fbs", "//runtime/src/iree/schemas:hip_executable_def_c_fbs", "@llvm-project//llvm:AMDGPUCodeGen", diff --git a/compiler/plugins/target/ROCM/CMakeLists.txt b/compiler/plugins/target/ROCM/CMakeLists.txt index 9430dca4fc16..938261acd14e 100644 --- a/compiler/plugins/target/ROCM/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/CMakeLists.txt @@ -64,6 +64,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils iree::compiler::PluginAPI iree::compiler::Utils + iree::schemas::amdgpu_executable_def_c_fbs iree::schemas::executable_debug_info_c_fbs iree::schemas::hip_executable_def_c_fbs PUBLIC diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 0a2fcc388b27..05ab66779271 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -23,6 +23,7 @@ #include "iree/compiler/PluginAPI/Client.h" #include "iree/compiler/Utils/FlatbufferUtils.h" #include "iree/compiler/Utils/ToolUtils.h" +#include "iree/schemas/amdgpu_executable_def_builder.h" #include "iree/schemas/hip_executable_def_builder.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -54,13 +55,17 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { -struct ROCmOptions { +// TODO(#18792): rename flags back to iree-rocm- as they are not HIP-specific. +// Only iree-hip-legacy-sync applies uniquely to HIP. +struct ROCMOptions { std::string target = ""; std::string targetFeatures = ""; std::string bitcodeDirectory = getDefaultBitcodeDirectory(); int wavesPerEu = 0; std::string enableROCMUkernels = "none"; bool legacySync = true; + bool slpVectorization = false; + bool globalISel = false; /// List of LLVM opt pass pluggins to be loaded during GPU code /// generation. The pluggins are paths to dynamic libraries that @@ -108,6 +113,13 @@ struct ROCmOptions { "to be passed to the target backend compiler during HIP " "executable serialization"), cl::ZeroOrMore, cl::cat(category)); + binder.opt( + "iree-hip-llvm-slp-vec", slpVectorization, cl::cat(category), + cl::desc( + "Enable slp vectorization in llvm opt. This can have an impact on " + "performance/numerics so its turned off by default currently.")); + binder.opt("iree-hip-llvm-global-isel", globalISel, cl::cat(category), + cl::desc("Enable global instruction selection in llvm.")); } LogicalResult verify(mlir::Builder &builder) const { @@ -187,45 +199,9 @@ static std::string translateModuleToISA(llvm::Module &module, } } // namespace -class ROCMTargetDevice final : public TargetDevice { -public: - ROCMTargetDevice(const ROCmOptions &options) : options(options) {} - - IREE::HAL::DeviceTargetAttr - getDefaultDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - - SmallVector deviceConfigAttrs; - if (options.legacySync) { - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. 
- deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), - b.getUnitAttr()); - } - auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); - - SmallVector executableConfigAttrs; - auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - SmallVector executableTargetAttrs; - targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( - context, "rocm", executableConfigAttr, executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), - deviceConfigAttr, - executableTargetAttrs); - } - -private: - const ROCmOptions &options; -}; - class ROCMTargetBackend final : public TargetBackend { public: - ROCMTargetBackend(const ROCmOptions &options) : options(options) {} + ROCMTargetBackend(const ROCMOptions &options) : options(options) {} std::string getLegacyDefaultDeviceID() const override { return "hip"; } @@ -233,31 +209,43 @@ class ROCMTargetBackend final : public TargetBackend { MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr, SmallVectorImpl &executableTargetAttrs) const override { - if (auto target = getExecutableTarget(context)) + if (auto target = getExecutableTarget(deviceID, context)) { executableTargetAttrs.push_back(target); + } } IREE::HAL::ExecutableTargetAttr - getExecutableTarget(MLIRContext *context) const { + getExecutableTarget(StringRef deviceID, MLIRContext *context) const { Builder b(context); SmallVector configItems; auto addConfig = [&](StringRef name, Attribute value) { configItems.emplace_back(b.getStringAttr(name), value); }; - if (failed(options.verify(b))) + if (failed(options.verify(b))) { return nullptr; + } + + addConfig("abi", b.getStringAttr(deviceID)); + std::string format; + if (deviceID == "amdgpu") { + format = options.target; + } else { + format = "rocm-hsaco-fb"; // legacy HIP + } - if (auto target = GPU::getHIPTargetDetails(options.target, - options.targetFeatures, context)) + if (auto target = GPU::getHIPTargetDetails( + options.target, options.targetFeatures, context)) { addConfig("iree.gpu.target", target); + } addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels)); - if (options.wavesPerEu > 0) + if (options.wavesPerEu > 0) { addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu)); + } return b.getAttr( - b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"), + b.getStringAttr("rocm"), b.getStringAttr(format), b.getDictionaryAttr(configItems)); } @@ -281,12 +269,17 @@ class ROCMTargetBackend final : public TargetBackend { buildLLVMGPUCodegenPassPipeline(passManager, true); } + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildLLVMGPULinkingPassPipeline(passManager, "rocm"); + } + // Performs optimizations on |module| (including LTO-style whole-program // ones). 
Inspired by code section in // https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp static void optimizeModule(llvm::Module &module, llvm::TargetMachine &targetMachine, - ArrayRef passPlugins) { + ArrayRef passPlugins, + bool slpVectorization) { llvm::LoopAnalysisManager lam; llvm::FunctionAnalysisManager fam; llvm::CGSCCAnalysisManager cgam; @@ -295,7 +288,7 @@ class ROCMTargetBackend final : public TargetBackend { fam.registerPass([&] { return targetMachine.getTargetIRAnalysis(); }); llvm::PipelineTuningOptions pto; - pto.SLPVectorization = false; + pto.SLPVectorization = slpVectorization; llvm::PassInstrumentationCallbacks pic; @@ -346,9 +339,10 @@ class ROCMTargetBackend final : public TargetBackend { return success(); } - LogicalResult serializeExecutable(const SerializationOptions &serOptions, - IREE::HAL::ExecutableVariantOp variantOp, - OpBuilder &executableBuilder) override { + LogicalResult + serializeExecutable(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + OpBuilder &executableBuilder) override { ModuleOp innerModuleOp = variantOp.getInnerModule(); auto targetAttr = variantOp.getTargetAttr(); StringRef targetArch = options.target; @@ -459,6 +453,7 @@ class ROCMTargetBackend final : public TargetBackend { opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; + opt.EnableGlobalISel = options.globalISel; SmallVector features; if (targetArch.starts_with("gfx10") || targetArch.starts_with("gfx11")) { @@ -541,17 +536,18 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".linked.ll", *llvmModule); } // Run LLVM optimization passes. - optimizeModule(*llvmModule, *targetMachine, options.passPlugins); - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + optimizeModule(*llvmModule, *targetMachine, options.passPlugins, + options.slpVectorization); + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".optimized.ll", *llvmModule); } @@ -560,7 +556,7 @@ class ROCMTargetBackend final : public TargetBackend { } // Dump the assembly output. 
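The two new options introduced above, `iree-hip-llvm-slp-vec` and `iree-hip-llvm-global-isel`, feed two standard LLVM knobs: `PipelineTuningOptions::SLPVectorization` for the new pass manager and `TargetOptions::EnableGlobalISel` at TargetMachine creation time. A stand-alone sketch of the same wiring (illustrative only; the actual backend additionally loads pass plugins and configures more target options):

```cpp
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include <memory>
#include <string>

// Global instruction selection is a property of the TargetMachine, so the
// flag must be applied to TargetOptions before the machine is created.
static std::unique_ptr<llvm::TargetMachine>
createTargetMachine(bool enableGlobalISel, const std::string &cpu) {
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  std::string error;
  const std::string triple = "amdgcn-amd-amdhsa";
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget(triple, error);
  if (!target) return nullptr;
  llvm::TargetOptions opt;
  opt.EnableGlobalISel = enableGlobalISel; // off by default, as in the diff
  return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine(
      triple, cpu, /*Features=*/"", opt, llvm::Reloc::PIC_));
}

// SLP vectorization is a tuning option of the new pass manager, threaded
// through PipelineTuningOptions into PassBuilder.
static void optimize(llvm::Module &module, llvm::TargetMachine &targetMachine,
                     bool slpVectorization) {
  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager cgam;
  llvm::ModuleAnalysisManager mam;
  fam.registerPass([&] { return targetMachine.getTargetIRAnalysis(); });

  llvm::PipelineTuningOptions pto;
  pto.SLPVectorization = slpVectorization;

  llvm::PassBuilder pb(&targetMachine, pto);
  pb.registerModuleAnalyses(mam);
  pb.registerCGSCCAnalyses(cgam);
  pb.registerFunctionAnalyses(fam);
  pb.registerLoopAnalyses(lam);
  pb.crossRegisterProxies(lam, fam, cgam, mam);

  llvm::ModulePassManager mpm =
      pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
  mpm.run(module, mam);
}
```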
- if (!serOptions.dumpIntermediatesPath.empty()) { + if (!serializationOptions.dumpIntermediatesPath.empty()) { auto moduleCopy = llvm::CloneModule(*llvmModule); if (!moduleCopy) { llvm::errs() << "Error: cloning LLVM IR failed\n"; @@ -568,9 +564,9 @@ class ROCMTargetBackend final : public TargetBackend { } std::string targetISA = translateModuleToISA(*moduleCopy.get(), *targetMachine); - dumpDataToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), ".rocmasm", - targetISA); + dumpDataToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".rocmasm", targetISA); } // Serialize hsaco kernel into the binary that we will embed in the @@ -581,23 +577,136 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpBinariesPath.empty()) { - dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName, - variantOp.getName(), ".hsaco", targetHSACO); + if (!serializationOptions.dumpBinariesPath.empty()) { + dumpDataToPath(serializationOptions.dumpBinariesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".hsaco", targetHSACO); + } + + // Wrap the HSACO ELF binary in a Flatbuffers container. + FailureOr binaryContainer; + if (targetAttr.getConfiguration() && + targetAttr.getConfiguration().getAs("abi") == "amdgpu") { + binaryContainer = serializeAMDGPUBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } else { + binaryContainer = serializeHIPBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } + if (failed(binaryContainer) || !binaryContainer.value()) { + return failure(); } + // Add the binary data to the target executable. + executableBuilder.create( + variantOp.getLoc(), variantOp.getSymName(), + variantOp.getTarget().getFormat(), binaryContainer.value()); + + return success(); + } + +protected: + FailureOr serializeAMDGPUBinaryContainer( + const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { + iree_compiler::FlatbufferBuilder builder; + iree_hal_amdgpu_ExecutableDef_start_as_root(builder); + + // Attach embedded source file contents. + auto sourceFilesRef = createSourceFilesVec( + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); + + // Only a single module today. + SmallVector moduleRefs; + { + auto hsacoImageRef = flatbuffers_string_create( + builder, hsacoModule.data(), hsacoModule.size()); + moduleRefs.push_back( + iree_hal_amdgpu_ModuleDef_create(builder, hsacoImageRef)); + } + auto modulesRef = builder.createOffsetVecDestructive(moduleRefs); + + // Generate optional per-export debug information. + // May be empty if no debug information was requested. 
+ auto exportDebugInfos = + createExportDefs(serializationOptions.debugLevel, exportOps, builder); + + SmallVector exportRefs; + exportRefs.resize(exportOps.size(), 0); + for (auto exportOp : exportOps) { + auto ordinalAttr = exportOp.getOrdinalAttr(); + if (!ordinalAttr) { + return mlir::emitError(exportOp.getLoc()) + << "could not compile rocm binary: export op is missing ordinal"; + } + int64_t ordinal = ordinalAttr.getInt(); + + auto symbolNameRef = builder.createString(exportOp.getName()); + + iree_hal_amdgpu_Dims_t workgroupSize = {0}; + if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { + auto workgroupSizeDims = workgroupSizeAttr->getValue(); + workgroupSize.x = cast(workgroupSizeDims[0]).getInt(); + workgroupSize.y = cast(workgroupSizeDims[1]).getInt(); + workgroupSize.z = cast(workgroupSizeDims[2]).getInt(); + } + + auto layoutAttr = exportOp.getLayoutAttr(); + uint32_t constantCount = static_cast(layoutAttr.getConstants()); + SmallVector bindingFlags; + for (auto bindingAttr : layoutAttr.getBindings()) { + iree_hal_amdgpu_BindingBits_enum_t flags = 0; + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::ReadOnly)) { + flags |= iree_hal_amdgpu_BindingBits_READ_ONLY; + } + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::Indirect)) { + flags |= iree_hal_amdgpu_BindingBits_INDIRECT; + } + bindingFlags.push_back(flags); + } + auto bindingFlagsRef = iree_hal_amdgpu_BindingBits_vec_create( + builder, bindingFlags.data(), bindingFlags.size()); + + iree_hal_amdgpu_ExportDef_start(builder); + iree_hal_amdgpu_ExportDef_symbol_name_add(builder, symbolNameRef); + iree_hal_amdgpu_ExportDef_workgroup_size_add(builder, &workgroupSize); + iree_hal_amdgpu_ExportDef_constant_count_add(builder, constantCount); + iree_hal_amdgpu_ExportDef_binding_flags_add(builder, bindingFlagsRef); + iree_hal_amdgpu_ExportDef_debug_info_add(builder, + exportDebugInfos[ordinal]); + exportRefs[ordinal] = iree_hal_amdgpu_ExportDef_end(builder); + } + auto exportsRef = builder.createOffsetVecDestructive(exportRefs); + + iree_hal_amdgpu_ExecutableDef_exports_add(builder, exportsRef); + iree_hal_amdgpu_ExecutableDef_modules_add(builder, modulesRef); + iree_hal_amdgpu_ExecutableDef_source_files_add(builder, sourceFilesRef); + iree_hal_amdgpu_ExecutableDef_end_as_root(builder); + + return builder.getBufferAttr(variantOp.getContext()); + } + + FailureOr + serializeHIPBinaryContainer(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { iree_compiler::FlatbufferBuilder builder; iree_hal_hip_ExecutableDef_start_as_root(builder); // Attach embedded source file contents. auto sourceFilesRef = createSourceFilesVec( - serOptions.debugLevel, variantOp.getSourcesAttr(), builder); + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); // Only a single module today. SmallVector moduleRefs; { auto hsacoImageRef = flatbuffers_string_create( - builder, targetHSACO.c_str(), targetHSACO.size()); + builder, hsacoModule.data(), hsacoModule.size()); moduleRefs.push_back( iree_hal_hip_ModuleDef_create(builder, hsacoImageRef)); } @@ -606,7 +715,7 @@ class ROCMTargetBackend final : public TargetBackend { // Generate optional per-export debug information. // May be empty if no debug information was requested. 
auto exportDebugInfos = - createExportDefs(serOptions.debugLevel, exportOps, builder); + createExportDefs(serializationOptions.debugLevel, exportOps, builder); SmallVector exportRefs; exportRefs.resize(exportOps.size(), 0); @@ -670,27 +779,91 @@ class ROCMTargetBackend final : public TargetBackend { iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef); iree_hal_hip_ExecutableDef_end_as_root(builder); - // Add the binary data to the target executable. - executableBuilder.create( - variantOp.getLoc(), variantOp.getSymName(), - variantOp.getTarget().getFormat(), - builder.getBufferAttr(executableBuilder.getContext())); + return builder.getBufferAttr(variantOp.getContext()); + } - return success(); +private: + const ROCMOptions &options; +}; + +class AMDGPUTargetDevice final : public TargetDevice { +public: + AMDGPUTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "amdgpu", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("amdgpu"), + deviceConfigAttr, + executableTargetAttrs); + } + +private: + const ROCMOptions &options; +}; + +class HIPTargetDevice final : public TargetDevice { +public: + HIPTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + if (options.legacySync) { + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), + b.getUnitAttr()); + } + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "hip", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), + deviceConfigAttr, + executableTargetAttrs); } private: - const ROCmOptions &options; + const ROCMOptions &options; }; namespace { struct ROCMSession final - : PluginSession { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { + // #hal.device.target<"amdgpu", ... + targets.add("amdgpu", [&]() { + return std::make_shared(options); + }); // #hal.device.target<"hip", ... targets.add("hip", - [&]() { return std::make_shared(options); }); + [&]() { return std::make_shared(options); }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { // #hal.executable.target<"rocm", ... 
@@ -716,4 +889,4 @@ extern "C" bool iree_register_compiler_plugin_hal_target_rocm( return true; } -IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCmOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCMOptions); diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp index 7453af749b80..a1757afd75f1 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp @@ -35,7 +35,7 @@ loadIRModule(Location loc, const std::string &filename, diagnostic, *llvm_context)); if (!module) { - mlir::emitError(loc) << "error loading HIP LLVM module: " + mlir::emitError(loc) << "error loading ROCM LLVM module: " << diagnostic.getFilename().str() << ":" << diagnostic.getLineNo() << ":" << diagnostic.getColumnNo() << ": " @@ -90,7 +90,7 @@ static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker, auto setAlwaysInline = [&](llvm::Module &module) { if (targetMachine.getTargetCPU().contains("gfx10") || targetMachine.getTargetCPU().contains("gfx11")) { - // some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if + // Some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if // inlined. return; } diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir index 1afe688467ee..a25547b387e2 100644 --- a/compiler/plugins/target/ROCM/test/smoketest.mlir +++ b/compiler/plugins/target/ROCM/test/smoketest.mlir @@ -2,19 +2,19 @@ module attributes { hal.device.targets = [ - #hal.device.target<"hip", [ - #hal.executable.target<"rocm", "rocm-hsaco-fb"> + #hal.device.target<"amdgpu", [ + #hal.executable.target<"rocm", "amdgcn-amd-amdhsa"> ]> : !hal.device ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @add_dispatch_executable { + stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { - func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { %c0 = arith.constant 0 : index %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> @@ -23,7 +23,7 @@ stream.executable public @add_dispatch_0 { %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): %4 = arith.addf %arg3, %arg4 : f32 linalg.yield %4 : f32 } -> tensor<16xf32> @@ -33,11 +33,37 @@ stream.executable public @add_dispatch_0 { } } +stream.executable public @mul_dispatch_executable { + stream.executable.export @mul_dispatch 
workgroups(%arg0 : index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + %c0 = arith.constant 0 : index + %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %0 = tensor.empty() : tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = arith.mulf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor> + return + } + } +} + } -// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes { +// CHECK: hal.executable public @smoketest_linked +// CHECK: hal.executable.binary public @amdgcn_amd_amdhsa attributes { // CHECK-SAME: data = dense -// CHECK-SAME: format = "rocm-hsaco-fb" +// CHECK-SAME: format = "amdgcn-amd-amdhsa" // ----- @@ -52,13 +78,13 @@ module attributes { ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @executable { + stream.executable.export @export workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : index, index, index } loc(#loc) builtin.module { - func.func @add_dispatch_0() { + func.func @export() { return } loc(#loc) } loc(#loc) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index b7a5ab68a014..578cd5921c59 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -15,10 +15,10 @@ // GFX942: target = #iree_gpu.target, , , , , ], +// GFX942-SAME: mma = [, , , , , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, -// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>, +// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], // MI300X: chip = > // MI300A: chip = > @@ -26,7 +26,7 @@ // GFX941-SAME: features = "+sramecc,-xnack" // GFX940: target = #iree_gpu.target, , , , , ], +// GFX940-SAME: mma = [, , , , , , , , ], // GFX1100: target = #iree_gpu.target, , ] diff --git a/compiler/pyproject.toml b/compiler/pyproject.toml index 5a07bd90cc5e..b7a4fd4de382 100644 --- a/compiler/pyproject.toml +++ b/compiler/pyproject.toml @@ -3,15 +3,10 @@ requires = [ "setuptools>=42", "wheel", "cmake", - # Note that the compiler wheel does not presently need 
nanobind, but - # it's build is enabled by the same flag which enables the runtime - # configuration, which does. - "nanobind==2.2.0", "ninja", # MLIR build depends. "numpy", "packaging", - "pybind11==2.13.6", "sympy", ] build-backend = "setuptools.build_meta" diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index c8ac4551dd1d..2413bed54150 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:Debug", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index 61631e148162..191ea93a1cbe 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( MLIRBuiltinToLLVMIRTranslation MLIRBytecodeWriter MLIRCAPIIR + MLIRDebug MLIRIR MLIRParser MLIRSupport diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index 488555af6640..7f83a5e3b3fe 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -67,6 +67,7 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/Debug/CLOptionsSetup.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" @@ -274,6 +275,7 @@ void GlobalInit::registerCommandLineOptions() { // Register pass manager command-line options like -mlir-print-ir-*. mlir::registerPassManagerCLOptions(); mlir::registerDefaultTimingManagerCLOptions(); + mlir::tracing::DebugConfig::registerCLOptions(); // Bind session options to the command line environment. clPluginManagerOptions = &PluginManagerOptions::FromFlags::get(); @@ -366,6 +368,11 @@ struct Session { // All user access to the context is done via this reference. MLIRContext &context; OptionsBinder binder; + + // Debug configuration. + mlir::tracing::DebugConfig debugConfig; + std::optional debugHandlerInstall; + // PluginManagerOptions must initialize first because the session depends on // it. PluginManagerOptions pluginManagerOptions; @@ -402,6 +409,7 @@ Session::Session(GlobalInit &globalInit) // Bootstrap session options from the cl environment, if enabled. if (globalInit.usesCommandLine) { + debugConfig = mlir::tracing::DebugConfig::createFromCLOptions(); pluginManagerOptions = *globalInit.clPluginManagerOptions; bindingOptions = *globalInit.clBindingOptions; inputOptions = *globalInit.clInputOptions; @@ -417,6 +425,9 @@ Session::Session(GlobalInit &globalInit) #endif } + // Enable debug integration. + debugHandlerInstall.emplace(context, debugConfig); + // Register each options struct with the binder so we can manipulate // mnemonically via the API. 
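The CompilerDriver changes above hook MLIR's debug/tracing configuration into IREE's session: the CL options are registered globally, snapshotted per session via `createFromCLOptions()`, and installed on the context through `InstallDebugHandler`. A minimal stand-alone sketch of the same pattern in a bare MLIR tool (illustrative only, not the IREE driver, which routes this through its `Session` object):

```cpp
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/CommandLine.h"

int main(int argc, char **argv) {
  // Make the MLIR debug/tracing flags visible on this tool's command line.
  mlir::tracing::DebugConfig::registerCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, "debug-config example\n");

  mlir::MLIRContext context;
  // Snapshot whatever the user passed on the command line...
  mlir::tracing::DebugConfig debugConfig =
      mlir::tracing::DebugConfig::createFromCLOptions();
  // ...and attach the corresponding handlers to the context. They remain
  // installed for the lifetime of this RAII object.
  mlir::tracing::InstallDebugHandler installDebugHandler(context, debugConfig);

  // Parse IR and run pass pipelines as usual here.
  return 0;
}
```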
bindingOptions.bindOptions(binder); diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp index 9b4639bb0cc2..555601c4bcc4 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp @@ -1,120 +1,120 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" -#include "iree/compiler/dialects/iree_gpu.h" -#include "mlir-c/IR.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" - -bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) { - return llvm::isa( - unwrap(attr)); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory, - bool *noReduceSharedMemoryBankConflicts, - MlirAttribute *reorderWorkgroupsStrategy) { - mlir::MLIRContext *ctx = unwrap(mlirCtx); - mlir::Builder b(ctx); - auto prefetchSharedMemoryAttr = mlir::BoolAttr(); - if (prefetchSharedMemory) { - prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory); - } - auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr(); - if (noReduceSharedMemoryBankConflicts) { - noReduceSharedMemoryBankConflictsAttr = - b.getBoolAttr(*noReduceSharedMemoryBankConflicts); - } - auto strategyAttr = - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr(); - if (reorderWorkgroupsStrategy) { - strategyAttr = llvm::dyn_cast< - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( - unwrap(*reorderWorkgroupsStrategy)); - } - return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get( - ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr, - strategyAttr)); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getPrefetchSharedMemory()); -} - -MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( - MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts()); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getReorderWorkgroupsStrategy()); -} - -MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() { - return wrap( - mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID()); -} - -static_assert( - static_cast(ireeGPUReorderWorkgroupsStrategyEnumNone) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::None) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::Swizzle) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::Transpose) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == - mlir::iree_compiler::IREE::GPU:: - getMaxEnumValForReorderWorkgroupsStrategy(), - "ireeGPUReorderWorkgroupsStrategyEnum and " - "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions " - "have diverged"); - -bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) { - return 
llvm::isa< - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( - unwrap(attr)); -} - -MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() { - return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr:: - getTypeID()); -} - -MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet( - MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) { - mlir::MLIRContext *ctx = unwrap(mlirCtx); - return wrap( - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get( - ctx, static_cast< - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>( - value))); -} - -ireeGPUReorderWorkgroupsStrategyEnum -ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) { - assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) && - "attr is not a GPUReorderWorkgroupsStrategyAttr"); - return static_cast( - llvm::cast( - unwrap(attr)) - .getValue()); -} +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/dialects/iree_gpu.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" + +bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) { + return llvm::isa( + unwrap(attr)); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory, + bool *noReduceSharedMemoryBankConflicts, + MlirAttribute *reorderWorkgroupsStrategy) { + mlir::MLIRContext *ctx = unwrap(mlirCtx); + mlir::Builder b(ctx); + auto prefetchSharedMemoryAttr = mlir::BoolAttr(); + if (prefetchSharedMemory) { + prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory); + } + auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr(); + if (noReduceSharedMemoryBankConflicts) { + noReduceSharedMemoryBankConflictsAttr = + b.getBoolAttr(*noReduceSharedMemoryBankConflicts); + } + auto strategyAttr = + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr(); + if (reorderWorkgroupsStrategy) { + strategyAttr = llvm::dyn_cast< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( + unwrap(*reorderWorkgroupsStrategy)); + } + return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get( + ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr, + strategyAttr)); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getPrefetchSharedMemory()); +} + +MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( + MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts()); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getReorderWorkgroupsStrategy()); +} + +MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() { + return wrap( + mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID()); +} + +static_assert( + static_cast(ireeGPUReorderWorkgroupsStrategyEnumNone) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::None) && + static_cast(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::Swizzle) 
&& + static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::Transpose) && + static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == + mlir::iree_compiler::IREE::GPU:: + getMaxEnumValForReorderWorkgroupsStrategy(), + "ireeGPUReorderWorkgroupsStrategyEnum and " + "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions " + "have diverged"); + +bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) { + return llvm::isa< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( + unwrap(attr)); +} + +MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() { + return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr:: + getTypeID()); +} + +MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet( + MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) { + mlir::MLIRContext *ctx = unwrap(mlirCtx); + return wrap( + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get( + ctx, static_cast< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>( + value))); +} + +ireeGPUReorderWorkgroupsStrategyEnum +ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) { + assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) && + "attr is not a GPUReorderWorkgroupsStrategyAttr"); + return static_cast( + llvm::cast( + unwrap(attr)) + .getValue()); +} diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 7aca986d540b..d6cc75d9cefa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -78,6 +78,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:VectorDialect", ], ) @@ -86,6 +87,7 @@ iree_compiler_cc_library( name = "Common", srcs = [ "AddFastMathFlags.cpp", + "BlockDynamicDimensions.cpp", "BubbleUpOrdinalOps.cpp", "BufferizationAnalysis.cpp", "BufferizeCopyOnlyDispatchesPass.cpp", @@ -137,6 +139,7 @@ iree_compiler_cc_library( "RemoveSingleIterationLoop.cpp", "ReplaceSlowMinMaxOps.cpp", "SplitFullPartialTransferPass.cpp", + "TensorDynamicDimAnalysis.cpp", "TensorToVectorVectorizePad.cpp", "TestExecutablePreprocessing.cpp", "TestPartitionableLoopsInterface.cpp", @@ -155,6 +158,7 @@ iree_compiler_cc_library( "ExtractAddressComputation.h", "PassUtils.h", "Passes.h", + "TensorDynamicDimAnalysis.h", "TileSizeSelection.h", "Transforms.h", "UserConfig.h", @@ -176,6 +180,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", + "//compiler/src/iree/compiler/Dialect/Util/Analysis", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", @@ -191,6 +196,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:BufferizationTransforms", + "@llvm-project//mlir:DestinationStyleOpInterface", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp 
new file mode 100644 index 000000000000..7a45116f0abd --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -0,0 +1,302 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" +#include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +using TensorDivisibilityInfo = + llvm::SmallDenseMap; + +namespace { + +struct RemoveOptimizationBarrier final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(barrierOp, barrierOp.getOperands()); + return success(); + } +}; + +/// This pass is used to materialize information about dynamic dimensions of +/// `tensor` operands of an operation in the IR. If a dynamic dimension is +/// known to be a multiple of a compile-time constant value, this pass +/// expands the shape of the operands. For example if a `tensor` operand +/// is of shape `tensor<...x?x...>` and that dimension is known to be a +/// multiple of 16, this operand is expanded to `tensor<...x?x16x...>` where the +/// size of the new dynamic dimension is 1/16-th the size of the original +/// dynamic dimension size. This is done in two steps. +/// 1) Replace operands with such dynamic dimension with the result of a +/// `tensor.expand_shape/tensor.collapse_shape` pair +/// to materialize the new static dimension and immediately fold it away. A +/// optimization barrier is added in between to prevent these operations from +/// being folded. +/// 2) Use patterns that propagate the `tensor.collapse_shape` down to +/// manipulate the operation appropriately. This +/// allows re-using the (fairly complex) logic used to expand dimensions of +/// operations implemented in the propagation patterns. +/// At the end of the pass the optimization barriers are removed to fold away +/// any un-propagated `tensor.expand_shape/tensor.collapse_shape` patterns. +struct BlockDynamicDimensionsPass final + : impl::BlockDynamicDimensionsPassBase { + void runOnOperation() override; +}; +} // namespace + +/// Retrieve the divisibility information for dynamic dimensions of `v` if +/// known. 
+static TensorDivisibilityInfo +getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis, + Value v) { + TensorDivisibilityInfo divisibilityInfo; + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return divisibilityInfo; + } + + for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { + if (!tensorType.isDynamicDim(index)) + continue; + std::optional dimDivisibility = + dynamicDimAnalysis.getDivisibilityInfo(v, index); + if (!dimDivisibility) + continue; + divisibilityInfo[index] = std::move(dimDivisibility.value()); + } + + return divisibilityInfo; +} + +/// For a `v` if the dimension is known to be multiple of a compile-time static +/// value, insert +/// +/// ```mlir +/// %v_expand = tensor.expand_shape %v +/// %barrier = util.optimization.barrier %v +/// %v_collapse = tensor.collapse_shape %barrier +/// ``` +/// +/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are +/// inverses of each other. The `util.optimization.barrier` avoid these from +/// getting folded away during reshape propagation. Return the result of the +/// `tensor.collapse_shape generated. +static std::optional +blockDynamicDimensionsOfValue(RewriterBase &rewriter, + const TensorDivisibilityInfo &divisibilityInfo, + Value v) { + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return std::nullopt; + } + + // Check if we know that the operands have a divisibility information. + SmallVector outputShape; + SmallVector reassociation; + Location loc = v.getLoc(); + + for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { + reassociation.emplace_back(ReassociationIndices{}); + + // Check if this needs division. + if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) { + reassociation.back().push_back(outputShape.size()); + outputShape.push_back(rewriter.getIndexAttr(dim)); + continue; + } + + // Split the dynamic based on the divisibility info. + IREE::Util::ConstantIntDivisibility currDivisibility = + divisibilityInfo.lookup(index); + uint64_t factor = currDivisibility.sdiv(); + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + AffineExpr divExpr = s0.floorDiv(factor); + Value sourceDim = rewriter.create(loc, v, index).getResult(); + OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply( + rewriter, loc, divExpr, ArrayRef{sourceDim}); + OpFoldResult newStaticDim = rewriter.getIndexAttr(factor); + + reassociation.back().push_back(outputShape.size()); + reassociation.back().push_back(outputShape.size() + 1); + + outputShape.push_back(newDynamicDim); + outputShape.push_back(newStaticDim); + } + + auto staticOutputShape = + llvm::map_to_vector(outputShape, [](OpFoldResult ofr) { + if (auto staticShapeAttr = dyn_cast(ofr)) { + return cast(staticShapeAttr).getInt(); + } + return ShapedType::kDynamic; + }); + auto outputType = RankedTensorType::get( + staticOutputShape, tensorType.getElementType(), tensorType.getEncoding()); + + Value expandShape = rewriter.create( + loc, outputType, v, reassociation, outputShape); + Value barrier = + rewriter.create(loc, expandShape) + .getResult(0); + Value collapseShape = rewriter.create( + loc, tensorType, barrier, reassociation); + return collapseShape; +} + +/// For an operation, replace the operands at indices specified in +/// `limitToOperandIndices` with the result of +/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the +/// information about dynamic dimensions that are known to be a multiple of a +/// compile-time static value. 
For example, +/// +/// ```mlir +/// %1 = (..., %0, ...) : ... , tensor<4x?x6xf32> +/// ``` +/// +/// If the dynamic dimension is known to be a multiple of 16, then generate +/// +/// ```mlir +/// %expanded = tensor.expand_shape %0 : +/// tensor<4x?x5xf32> into tensor<4x?x16x6xf32> +/// %barrier = util.optimization.barrier %expanded +/// %collapsed = tensor.collapse_shape %barrier +/// : tensor<4x?x16x5xf32> into tensor<4x?x5xf32> +/// %1 = (..., %collaped, ...) : ... , tensor<4x?x6xf32> +/// ``` +static LogicalResult blockDynamicDimensions( + RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis, + Operation *operation, llvm::SmallDenseSet limitToOperandIndices) { + OpBuilder::InsertionGuard g(rewriter); + + for (OpOperand &operand : operation->getOpOperands()) { + if (!limitToOperandIndices.contains(operand.getOperandNumber())) + continue; + if (operand.get().getDefiningOp()) + continue; + TensorDivisibilityInfo operandDivisibilityInfo = + getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get()); + if (operandDivisibilityInfo.empty()) + continue; + std::optional newOperand = blockDynamicDimensionsOfValue( + rewriter, operandDivisibilityInfo, operand.get()); + if (newOperand) { + rewriter.modifyOpInPlace(operation, + [&]() { operand.set(newOperand.value()); }); + } + } + return success(); +} + +/// Insert `tensor.expand_shape` operations to materialize in IR information +/// about dynamic dimensions that are known to be a multiple of a compile-time +/// know value, for the operands of `iree_linalg_ext.attention` operation. +static LogicalResult +blockDynamicDimensions(RewriterBase &rewriter, + const TensorDynamicDimAnalysis &dynamicDimAnalysis, + IREE::LinalgExt::AttentionOp attentionOp) { + // Only block the q and k values. + llvm::SmallDenseSet prunedOperandsList; + prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber()); + prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber()); + return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp, + prunedOperandsList); +} + +void BlockDynamicDimensionsPass::runOnOperation() { + Operation *operation = getOperation(); + MLIRContext *context = &getContext(); + TensorDynamicDimAnalysis dynamicDimAnalysis(operation); + if (failed(dynamicDimAnalysis.run())) { + return signalPassFailure(); + } + + IRRewriter rewriter(context); + auto walkResult = operation->walk( + [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult { + rewriter.setInsertionPoint(attentionOp); + return blockDynamicDimensions(rewriter, dynamicDimAnalysis, + attentionOp); + }); + if (walkResult.wasInterrupted()) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "After blocking dimensions:\n"; + operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + { + RewritePatternSet bubbleExpandShapePatterns(context); + // Add patterns to "push down" the `tensor.collapse_shape` patterns (which + // are the dual of the patterns to "bubble up" `tensor.expand_shape` + // patterns) + linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; }; + linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, + controlFn); + IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); + // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operation and + // "pushed-down" `tensor.collapse_shape` operation with their interface + // bindings or `tensor.empty` operations. 
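// Illustrative only: a standalone sketch (not part of the pass above) of the
// shape arithmetic that blockDynamicDimensionsOfValue performs. A dynamic
// dimension whose runtime size is known to be a multiple of `factor` is split
// into the pair (size / factor, factor). All concrete numbers below are
// hypothetical.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <utility>

static std::pair<int64_t, int64_t> blockDim(int64_t runtimeSize,
                                            int64_t factor) {
  assert(runtimeSize % factor == 0 && "divisibility guaranteed by analysis");
  return {runtimeSize / factor, factor}; // {new dynamic size, new static size}
}

int main() {
  // A `?` dimension that is 4096 at runtime and known to be a multiple of 16
  // becomes (256, 16), i.e. tensor<...x?x...> reshapes to tensor<...x?x16x...>.
  auto [dynamicSize, staticSize] = blockDim(4096, 16);
  std::printf("%lld x %lld\n", (long long)dynamicSize, (long long)staticSize);
  return 0;
}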
+ populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns); + tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); + // Add some additional patterns that can simplify the IR and remove dead + // operations. + memref::populateResolveRankedShapedTypeResultDimsPatterns( + bubbleExpandShapePatterns); + populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns); + if (failed(applyPatternsAndFoldGreedily( + operation, std::move(bubbleExpandShapePatterns)))) { + operation->emitOpError( + "failed in application of bubble up expand shape patterns"); + return signalPassFailure(); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "After reshape propagation:\n"; + operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + // Delete the optimization barrier and run some further cleanup. + { + RewritePatternSet removeBarrierOpsPatterns(context); + removeBarrierOpsPatterns.insert(context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + removeBarrierOpsPatterns, context); + if (failed(applyPatternsAndFoldGreedily( + operation, std::move(removeBarrierOpsPatterns)))) { + operation->emitOpError("failed in cleanup patterns"); + return signalPassFailure(); + } + } + + return; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 764bc258c902..8f729de2f714 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -58,6 +58,7 @@ iree_cc_library( LLVMSupport MLIRAnalysis MLIRIR + MLIRSCFDialect MLIRVectorDialect iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect PUBLIC @@ -72,11 +73,13 @@ iree_cc_library( "ExtractAddressComputation.h" "PassUtils.h" "Passes.h" + "TensorDynamicDimAnalysis.h" "TileSizeSelection.h" "Transforms.h" "UserConfig.h" SRCS "AddFastMathFlags.cpp" + "BlockDynamicDimensions.cpp" "BubbleUpOrdinalOps.cpp" "BufferizationAnalysis.cpp" "BufferizeCopyOnlyDispatchesPass.cpp" @@ -128,6 +131,7 @@ iree_cc_library( "RemoveSingleIterationLoop.cpp" "ReplaceSlowMinMaxOps.cpp" "SplitFullPartialTransferPass.cpp" + "TensorDynamicDimAnalysis.cpp" "TensorToVectorVectorizePad.cpp" "TestExecutablePreprocessing.cpp" "TestPartitionableLoopsInterface.cpp" @@ -154,6 +158,7 @@ iree_cc_library( MLIRArithUtils MLIRBufferizationDialect MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface MLIRFuncDialect MLIRFuncTransforms MLIRFunctionInterfaces @@ -203,6 +208,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms + iree::compiler::Dialect::Util::Analysis iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp index 58b678ce6588..8998b11ccee4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include 
"iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" @@ -12,6 +13,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" @@ -26,10 +28,14 @@ namespace { using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect; +/// Pattern to set a lowering configuration on an IGEMM convolution. Searches +/// for a contraction with a linalg_ext.im2col producer, and calls the configFn +/// to set the configuration. +/// TODO(Max191): Use a funcOp walk instead of a pattern for this. struct SetIGEMMConfiguration final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn) + SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn) : OpRewritePattern(context), configFn(configFn) {} LogicalResult matchAndRewrite(linalg::GenericOp genericOp, @@ -67,7 +73,7 @@ struct SetIGEMMConfiguration final : OpRewritePattern { } private: - ConfigFn configFn; + IGEMMConfigFn configFn; }; class ConvolutionToIGEMMPass final @@ -75,91 +81,87 @@ class ConvolutionToIGEMMPass final public: using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase; - explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {} + ConvolutionToIGEMMPass(std::optional configFn, + std::optional controlFn) + : configFn(configFn), controlFn(controlFn) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { - MLIRContext *context = &getContext(); - - // Rewrite convolutions into a im2col and GEMM. - { - auto conv2dToIm2colControlFn = [](Operation *conv) { - // Don't transform convolutions that have a preset lowering config. - if (getLoweringConfig(conv)) { - return false; - } - return true; - }; - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns( - patterns, conv2dToIm2colControlFn); - patterns.add(context, configFn); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - - // The im2col transformation collapses some of the dimensions of the - // convolution operands. Try to push the reshape ops towards the boundaries - // of the function and fold with interface tensor ops. - // - // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and - // generate a multi-M dim contraction instead of collapsing and - // propagating reshapes. It should ultimately become a pass option to - // decide whether to collapse the contraction dimensions into a single - // M/N/K dimension. - { - RewritePatternSet bubbleCollapseShapePatterns(context); - linalg::ControlFusionFn bubbleUpExpansionControlFn = - [](OpOperand *fusedOperand) { - Operation *producer = fusedOperand->get().getDefiningOp(); - Operation *consumer = fusedOperand->getOwner(); - - // Block only if one of the operations has a lowering configuration - // which means it likely expects tiling specific to its original - // shape. 
- if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { - return false; - } - return true; - }; - linalg::populateFoldReshapeOpsByCollapsingPatterns( - bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); - // Add patterns to do some additional cleanup (on top of canonicalizations - // that can be done later) of reshape ops. - tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); - linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, - context); - tensor::CollapseShapeOp::getCanonicalizationPatterns( - bubbleCollapseShapePatterns, context); - tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, - context); - tensor::ExpandShapeOp::getCanonicalizationPatterns( - bubbleCollapseShapePatterns, context); - populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), std::move(bubbleCollapseShapePatterns)))) { - return signalPassFailure(); - } - } - } + void runOnOperation() override; private: - ConfigFn configFn = [](linalg::GenericOp genericOp, - IREE::LinalgExt::Im2colOp im2colOp) { - return failure(); - }; + std::optional configFn; + std::optional controlFn; }; } // namespace -std::unique_ptr> -createConvolutionToIGEMMPass(ConfigFn configFn) { - return std::make_unique(configFn); +LogicalResult +convertToIGEMMAndSetConfig(FunctionOpInterface funcOp, + std::optional configFn, + std::optional controlFn) { + // Rewrite convolutions into a im2col and GEMM. + MLIRContext *context = funcOp->getContext(); + { + RewritePatternSet patterns(context); + iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns, + controlFn); + if (configFn.has_value()) { + patterns.add(context, configFn.value()); + } + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + } + + // The im2col transformation collapses some of the dimensions of the + // convolution operands. Try to push the reshape ops towards the boundaries + // of the function and fold with interface tensor ops. + // + // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and + // generate a multi-M dim contraction instead of collapsing and + // propagating reshapes. It should ultimately become a pass option to + // decide whether to collapse the contraction dimensions into a single + // M/N/K dimension. + { + RewritePatternSet bubbleCollapseShapePatterns(context); + linalg::ControlFusionFn bubbleUpExpansionControlFn = + [](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + Operation *consumer = fusedOperand->getOwner(); + + // Block only if one of the operations has a lowering configuration + // which means it likely expects tiling specific to its original + // shape. + if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { + return false; + } + return true; + }; + linalg::populateFoldReshapeOpsByCollapsingPatterns( + bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); + // Add patterns to do some additional cleanup (on top of canonicalizations + // that can be done later) of reshape ops. 
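// Illustrative only: how a backend might call the helper exposed above,
// assuming the IGEMMConfigFn alias from Common/Transforms.h and a control-fn
// type compatible with std::function<bool(Operation *)>. The lambdas here are
// hypothetical placeholders, not IREE's actual configuration logic.
static LogicalResult runIGEMMOnFunction(FunctionOpInterface funcOp) {
  IGEMMConfigFn configFn = [](linalg::GenericOp genericOp,
                              IREE::LinalgExt::Im2colOp im2colOp) {
    // A real implementation would attach a lowering_config to genericOp here.
    return success();
  };
  auto controlFn = [](Operation *conv) {
    // Skip convolutions that already carry a preset lowering configuration.
    return !getLoweringConfig(conv);
  };
  return convertToIGEMMAndSetConfig(funcOp, configFn, controlFn);
}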
+ tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); + linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::ExpandShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(bubbleCollapseShapePatterns)))) { + return failure(); + } + } + return success(); +} + +void ConvolutionToIGEMMPass::runOnOperation() { + if (failed(convertToIGEMMAndSetConfig(getOperation()))) { + return signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index e8b18370a2c3..fed4470e8580 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -16,10 +17,13 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops" @@ -27,10 +31,17 @@ namespace mlir::iree_compiler { #define GEN_PASS_DEF_DECOMPOSEPACKUNPACKOPSPASS +#define GEN_PASS_DEF_DECOMPOSEBOUNDARYPACKUNPACKOPSPASS #include "iree/compiler/Codegen/Common/Passes.h.inc" +using PackUnPackControlFn = std::function; + namespace { +//===----------------------------------------------------------------------===// +// Shared rewrite patterns +//===----------------------------------------------------------------------===// + /// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers /// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops. 
struct LowerPackPattern : public OpRewritePattern { @@ -85,33 +96,14 @@ struct LowerUnPackPattern : public OpRewritePattern { std::optional controlFn; }; -struct DecomposePackUnPackOpsPass final - : impl::DecomposePackUnPackOpsPassBase { - using impl::DecomposePackUnPackOpsPassBase< - DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase; - explicit DecomposePackUnPackOpsPass( - bool tileOuterToOne, bool useOnlyReshapes, - std::optional controlFn) { - this->tileOuterToOne = tileOuterToOne; - this->useOnlyReshapes = useOnlyReshapes; - this->controlFn = controlFn; - } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } +//===----------------------------------------------------------------------===// +// Shared pass implementation +//===----------------------------------------------------------------------===// - void runOnOperation() override; - -private: - std::optional controlFn; -}; - -} // namespace - -void DecomposePackUnPackOpsPass::runOnOperation() { - MLIRContext *ctx = &getContext(); - auto funcOp = getOperation(); +static LogicalResult commonRunOnOperation( + MLIRContext *ctx, FunctionOpInterface funcOp, bool useOnlyReshapes, + bool tileOuterToOne, + std::optional controlFn = std::nullopt) { // Generalization patterns for outer unit dims have higher priority because // they do not generate reshape ops. if (!useOnlyReshapes) { @@ -122,7 +114,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() { funcOp.emitError( "failed to apply generalization patterns on pack/unpack ops for " "outer unit dims cases"); - return signalPassFailure(); + return failure(); } } @@ -135,7 +127,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() { funcOp.emitError( "failed to apply generalization patterns on pack/unpack ops for " "general cases."); - return signalPassFailure(); + return failure(); } } @@ -163,17 +155,24 @@ void DecomposePackUnPackOpsPass::runOnOperation() { builder.getIndexAttr(1)); return tileSizes; })); - funcOp->walk([&](tensor::PackOp op) { - if (controlFn && failed(controlFn.value()(op))) { - return; + { + WalkResult status = funcOp->walk([&](tensor::PackOp op) { + if (controlFn && failed(controlFn.value()(op))) { + return WalkResult::advance(); + } + FailureOr tileAndFuseResult = + scf::tileConsumerAndFuseProducersUsingSCF( + rewriter, cast(op.getOperation()), + packOptions); + if (failed(tileAndFuseResult)) + return WalkResult::interrupt(); + rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]); + return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return failure(); } - FailureOr tileAndFuseResult = - scf::tileConsumerAndFuseProducersUsingSCF( - rewriter, cast(op.getOperation()), packOptions); - if (failed(tileAndFuseResult)) - return signalPassFailure(); - rewriter.replaceOp(op, tileAndFuseResult->replacements[op.getResult()]); - }); + } auto unpackTilingOptions = scf::SCFTilingOptions().setTileSizeComputationFunction( @@ -191,17 +190,23 @@ void DecomposePackUnPackOpsPass::runOnOperation() { } return tileSizes; }); - funcOp->walk([&](tensor::UnPackOp op) { - if (controlFn && failed(controlFn.value()(op))) { - return; + { + WalkResult status = funcOp->walk([&](tensor::UnPackOp op) { + if (controlFn && failed(controlFn.value()(op))) { + return WalkResult::advance(); + } + FailureOr tilingResult = scf::tileUsingSCF( + rewriter, cast(op.getOperation()), + unpackTilingOptions); + if (failed(tilingResult)) + return WalkResult::interrupt(); + rewriter.replaceOp(op, tilingResult->replacements); + 
return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return failure(); } - FailureOr tilingResult = - scf::tileUsingSCF(rewriter, cast(op.getOperation()), - unpackTilingOptions); - if (failed(tilingResult)) - return signalPassFailure(); - rewriter.replaceOp(op, tilingResult->replacements); - }); + } LLVM_DEBUG({ llvm::dbgs() @@ -219,7 +224,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() { ctx->getOrLoadDialect()->getCanonicalizationPatterns( patterns); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); + return failure(); } } @@ -238,16 +243,114 @@ void DecomposePackUnPackOpsPass::runOnOperation() { linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx); } if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); + return failure(); } } + return success(); } -std::unique_ptr> -createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes, - std::optional controlFn) { - return std::make_unique( - tileOuterToOne, useOnlyReshapes, controlFn); +//===----------------------------------------------------------------------===// +// DecomposePackUnPackOpsPass +//===----------------------------------------------------------------------===// + +struct DecomposePackUnPackOpsPass final + : impl::DecomposePackUnPackOpsPassBase { + using impl::DecomposePackUnPackOpsPassBase< + DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase; + + void runOnOperation() override; +}; + +} // namespace + +void DecomposePackUnPackOpsPass::runOnOperation() { + if (failed(commonRunOnOperation(&getContext(), getOperation(), + useOnlyReshapes, tileOuterToOne))) { + return signalPassFailure(); + } +} + +//===----------------------------------------------------------------------===// +// DecomposeBoundaryPackUnPackOpsPass +//===----------------------------------------------------------------------===// + +namespace { + +struct DecomposeBoundaryPackUnPackOpsPass final + : impl::DecomposeBoundaryPackUnPackOpsPassBase< + DecomposeBoundaryPackUnPackOpsPass> { + using impl::DecomposeBoundaryPackUnPackOpsPassBase< + DecomposeBoundaryPackUnPackOpsPass>:: + DecomposeBoundaryPackUnPackOpsPassBase; + + void runOnOperation() override; +}; + +} // namespace + +/// Check if the given `op` is a pack or unpack op with padding. +static bool hasPadding(Operation *op) { + auto needsPad = [](ShapedType unpackedType, ArrayRef innerDimPos, + ArrayRef staticInnerTiles) { + for (auto [dimPos, tile] : llvm::zip_equal(innerDimPos, staticInnerTiles)) { + if (unpackedType.isDynamicDim(dimPos) || ShapedType::isDynamic(tile) || + unpackedType.getDimSize(dimPos) % tile != 0) { + return true; + } + } + return false; + }; + auto packOp = dyn_cast(op); + if (packOp && needsPad(packOp.getSourceType(), packOp.getInnerDimsPos(), + packOp.getStaticInnerTiles())) { + return true; + } + auto unPackOp = dyn_cast(op); + if (unPackOp && needsPad(unPackOp.getDestType(), unPackOp.getInnerDimsPos(), + unPackOp.getStaticInnerTiles())) { + return true; + } + return false; +} + +/// Control function for decomposing pack and unpack ops. Returns true if the +/// op is an unpadded pack or unpack op, and it is at the boundary of a +/// dispatch. The following conditions need to be met: +/// 1. The PackOp or UnPackOp must have no padding. +/// 2. If the op is a PackOp, then its producer must be a dispatch tensor load. +/// 3. If the op is an UnPackOp, then all of its consumers must be dispatch +/// tensor stores. 
+static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { + if (!isa(op) && !isa(op)) { + return failure(); + } + if (hasPadding(op)) { + return failure(); + } + + // If the producer is a dispatch tensor load, then the `op` is decomposable + // if it is a PackOp. + if (isa(op) && + op->getOperand(0).getDefiningOp()) { + return success(); + } + // If all consumers are dispatch tensor stores, then the `op` is decomposable + // if it is an UnPackOp. + if (isa(op) && + llvm::all_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + return success(); + } + return failure(); +} + +void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() { + if (failed(commonRunOnOperation(&getContext(), getOperation(), + /*useOnlyReshapes=*/true, tileOuterToOne, + isUnpaddedAndAtBoundary))) { + return signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 5133d9dfaa4a..0ef6e64d2c26 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -107,11 +107,11 @@ struct DistributeConstants final : OpDistributionPattern { Type elementType = constant.getType().getElementType(); auto vectorType = VectorType::get(layout.getDistributedShape(), elementType); - Operation *distirbutedOp = rewriter.create( + auto distributedOp = rewriter.create( constantOp.getLoc(), vectorType, SplatElementsAttr::get(vectorType, attr.getSplatValue())); replaceOpWithDistributedValues(rewriter, constantOp, - distirbutedOp->getResult(0)); + distributedOp->getResult(0)); return success(); } }; @@ -536,8 +536,10 @@ struct DistributeScfFor final : OpDistributionPattern { SmallVector newInitArgs; for (Value initArg : forOp.getInitArgs()) { if (auto vectorInitArg = dyn_cast(initArg)) { - initArg = - getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]); + if (isNonZeroRank(vectorInitArg)) { + initArg = + getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]); + } } newInitArgs.push_back(initArg); } @@ -582,8 +584,14 @@ struct DistributeScfFor final : OpDistributionPattern { SmallVector operands; for (Value operand : yieldOp->getOperands()) { if (auto vectorOperand = dyn_cast(operand)) { - operand = DistributionPattern::getDistributed(rewriter, vectorOperand, - signature[vectorOperand]); + // Distributing the operand requires it to have a non-zero rank, meaning + // it must have at least one dimension. If the vector has a non-zero + // rank, the operand is distributed according to the provided layout + // signature. 
+ if (isNonZeroRank(vectorOperand)) { + operand = DistributionPattern::getDistributed( + rewriter, vectorOperand, signature[vectorOperand]); + } } operands.push_back(operand); } @@ -606,8 +614,10 @@ struct DistributeScfFor final : OpDistributionPattern { for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) { Value val = bbArg; if (auto oldVectorInit = dyn_cast(oldInit)) { - val = rewriter.create( - oldVectorInit.getLoc(), oldVectorInit.getType(), val); + if (isNonZeroRank(oldVectorInit)) { + val = rewriter.create( + oldVectorInit.getLoc(), oldVectorInit.getType(), val); + } } replacements.push_back(val); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp index dc3078372f92..790484d2c565 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -9,6 +9,7 @@ #include #include "llvm/ADT/APInt.h" +#include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -20,51 +21,106 @@ using llvm::APIntOps::GreatestCommonDivisor; namespace mlir::iree_compiler { +template static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const GPUMMASchedule &schedule) { - os << "mSize: " << schedule.mSize << ", "; - os << "nSize: " << schedule.nSize << ", "; - os << "kSize: " << schedule.kSize << ", "; - os << "mTileCount: " << schedule.mTileCount << ", "; - os << "nTileCount: " << schedule.nTileCount << ", "; - os << "kTileCount: " << schedule.kTileCount << ", "; - os << "mWarpCount: " << schedule.mWarpCount << ", "; - os << "nWarpCount: " << schedule.nWarpCount; + const llvm::SmallVectorImpl &vector) { + os << "["; + llvm::interleaveComma(vector, os); + os << "]"; return os; } +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const GPUMMASchedule &schedule) { + os << "mSizes: " << schedule.mSize << ", "; + os << "nSizes: " << schedule.nSize << ", "; + os << "kSizes: " << schedule.kSize << ", "; + os << "mTileSizes: " << schedule.mTileSizes << ", "; + os << "nTileSizes: " << schedule.nTileSizes << ", "; + os << "kTileSizes: " << schedule.kTileSizes << ", "; + os << "mSubgroupCounts: " << schedule.mSubgroupCounts << ", "; + os << "nSubgroupCounts: " << schedule.nSubgroupCounts; + return os; +} + +// Shortened helper to compute the product of `values`. +static int64_t prod(ArrayRef values) { + return ShapedType::getNumElements(values); +} + static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth) { - int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount; - int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount; - int64_t tileK = schedule.kSize * schedule.kTileCount; + + int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) * + prod(schedule.mSubgroupCounts); + int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) * + prod(schedule.nSubgroupCounts); + int64_t tileK = schedule.kSize * prod(schedule.kTileSizes); return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8; } +/// Check that a GPUMMASchedule fits alignment restrictions. To be aligned, +/// the problem must be evenly divisible by the number of elements in the +/// schedule for each dimension. If `mustBeAligned` is false, then the innermost +/// problem dimension is allowed to be unaligned . 
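// Illustrative only: the shared-memory estimate above with hypothetical
// numbers. For a schedule with mSize = nSize = kSize = 16, mTileSizes = {2},
// nTileSizes = {2}, kTileSizes = {4}, mSubgroupCounts = nSubgroupCounts = {2},
// and 16-bit LHS/RHS element types:
//   tileM = 16 * 2 * 2 = 64, tileN = 16 * 2 * 2 = 64, tileK = 16 * 4 = 64
//   bytes = (64 * 64 * 16 + 64 * 64 * 16) / 8 = 16384
#include <cassert>
#include <cstdint>

static int64_t sharedMemoryBytes(int64_t tileM, int64_t tileN, int64_t tileK,
                                 int64_t lhsBits, int64_t rhsBits) {
  return (tileM * tileK * lhsBits + tileN * tileK * rhsBits) / 8;
}

int main() {
  assert(sharedMemoryBytes(64, 64, 64, 16, 16) == 16384);
  return 0;
}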
static bool isScheduleAligned(const GPUMatmulShapeType &problem, const GPUMMASchedule &schedule, bool mustBeAligned) { - auto alignedMSize = - mustBeAligned - ? problem.mSize - : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize; - auto alignedNSize = - mustBeAligned - ? problem.nSize - : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize; - auto alignedKSize = - mustBeAligned - ? problem.kSize - : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize; - bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount * - schedule.mWarpCount)) == 0; - bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount * - schedule.nWarpCount)) == 0; - bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0; + SmallVector alignedMSizes(problem.mSizes); + alignedMSizes.back() = + mustBeAligned ? problem.mSizes.back() + : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) * + schedule.mSize; + SmallVector alignedNSizes(problem.nSizes); + alignedNSizes.back() = + mustBeAligned ? problem.nSizes.back() + : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) * + schedule.nSize; + SmallVector alignedKSizes(problem.kSizes); + alignedKSizes.back() = + mustBeAligned ? problem.kSizes.back() + : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) * + schedule.kSize; + // Returns the number of elements in the schedule for each dimension. + auto getScheduleSizes = + [&](int64_t size, SmallVector tileCount, + std::optional> subgroupCount) { + SmallVector sizes = llvm::map_to_vector( + llvm::seq(tileCount.size()), [&](int64_t i) { + return subgroupCount ? tileCount[i] * subgroupCount.value()[i] + : tileCount[i]; + }); + sizes.back() *= size; + return sizes; + }; + // Checks whether the elements of `a` are evenly divisible by the + // corresponding elements of `b`. + auto areAligned = [](SmallVector a, SmallVector b) { + for (auto [aVal, bVal] : llvm::zip_equal(a, b)) { + if (aVal % bVal != 0) { + return false; + } + } + return true; + }; + bool isValidM = areAligned( + alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes, + schedule.mSubgroupCounts)); + bool isValidN = areAligned( + alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes, + schedule.nSubgroupCounts)); + bool isValidK = areAligned( + alignedKSizes, + getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt)); return isValidM && isValidN && isValidK; } +/// Returns whether or not a GPUMMASchedule is valid for the given problem. +/// This checks that: +/// - The problem is aligned to the schedule +/// - the number of threads in the schedule workgroup can be distributed +/// to a corresponding vector.transfer read in VectorDistribute. 
static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const GPUMMASchedule &schedule, bool mustBeAligned, int64_t subgroupSize, @@ -76,11 +132,13 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const int64_t kMaxVectorLoadBitWidth = 128; int64_t elemsPerThread = kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth(); - int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize; - - int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount; - int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount; - int64_t kWgSize = schedule.kSize * schedule.kTileCount; + int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) * + prod(schedule.nSubgroupCounts); + int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) * + prod(schedule.mSubgroupCounts); + int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) * + prod(schedule.nSubgroupCounts); + int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes); int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize; int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize; @@ -94,6 +152,10 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, return isAligned && isDistributableLhs && isDistributableRhs; } +/// Tries to fit the schedule into shared memory by decrementing the size of the +/// schedule dimensions from outermost to innermost until a valid schedule is +/// found. The schedule sizes are reduced in the order of mTileSizes, +/// nTileSizes, kTileSizes, mSubgroupCounts, nSubgroupCounts. static FailureOr fitScheduleInSharedMemory( GPUMatmulShapeType intrinsic, GPUMMASchedule schedule, llvm::function_ref isScheduleValid) { @@ -105,31 +167,35 @@ static FailureOr fitScheduleInSharedMemory( llvm::dbgs() << "Shrinking schedule...\n"; }); - auto decrementIfPossible = [](int64_t &c) -> LogicalResult { - if (c <= 1) { - return failure(); + auto decrementIfPossible = + [](SmallVector &sizes) -> LogicalResult { + for (int64_t &size : sizes) { + if (size <= 1) + continue; + --size; + return success(); } - --c; - return success(); + return failure(); }; // Attempt to shrink the schedule along one of the dimensions. // TODO: A better solution should probably factor problem.mSize / - // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors - // one at a time, preferably trying to keep the tile "generally square." - if (succeeded(decrementIfPossible(schedule.mTileCount))) { + // (mSubgroupCount * mTileCount * mSize) and then pop off the smallest + // factors one at a time, preferably trying to keep the tile "generally + // square." 
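// Illustrative only: the shrinking step above on a plain std::vector. The
// first entry greater than one (outermost first) is decremented; once every
// entry of a dimension reaches one, the next knob (nTileSizes, kTileSizes,
// then the subgroup counts) is tried instead.
#include <cassert>
#include <cstdint>
#include <vector>

static bool decrementIfPossible(std::vector<int64_t> &sizes) {
  for (int64_t &size : sizes) {
    if (size <= 1)
      continue;
    --size;
    return true;
  }
  return false;
}

int main() {
  std::vector<int64_t> sizes = {2, 2};
  // {2, 2} -> {1, 2} -> {1, 1} -> no further shrinking possible.
  while (decrementIfPossible(sizes)) {
  }
  assert(sizes[0] == 1 && sizes[1] == 1);
  return 0;
}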
+ if (succeeded(decrementIfPossible(schedule.mTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.nTileCount))) { + if (succeeded(decrementIfPossible(schedule.nTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.kTileCount))) { + if (succeeded(decrementIfPossible(schedule.kTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.mWarpCount))) { + if (succeeded(decrementIfPossible(schedule.mSubgroupCounts))) { continue; } - if (succeeded(decrementIfPossible(schedule.nWarpCount))) { + if (succeeded(decrementIfPossible(schedule.nSubgroupCounts))) { continue; } @@ -148,6 +214,9 @@ static FailureOr fitScheduleInSharedMemory( static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, bool canUpcastAcc, bool mustBeAligned) { + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) { return failure(); // Cannot use this intrinsic for mismatched types } @@ -161,17 +230,17 @@ static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, } } - if (mustBeAligned && (problem.mSize % intrinsic.mSize != 0 || - problem.nSize % intrinsic.nSize != 0 || - problem.kSize % intrinsic.kSize != 0)) { + if (mustBeAligned && (problem.mSizes.back() % intrinsic.mSizes[0] != 0 || + problem.nSizes.back() % intrinsic.nSizes[0] != 0 || + problem.kSizes.back() % intrinsic.kSizes[0] != 0)) { return failure(); // Cannot use this intrinsic for misaligned cases. } // Cannot use the intrinsic when the tile size is greater than problem size. // Because tiling is a no-op, and we can't infer tiling sizes from IR. - if (!mustBeAligned && - (problem.mSize < intrinsic.mSize || problem.nSize < intrinsic.nSize || - problem.kSize < intrinsic.kSize)) { + if (!mustBeAligned && (problem.mSizes.back() < intrinsic.mSizes[0] || + problem.nSizes.back() < intrinsic.nSizes[0] || + problem.kSizes.back() < intrinsic.kSizes[0])) { return failure(); } @@ -185,77 +254,123 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, const GPUMMAHeuristicSeeds &seeds, uint64_t intrinsicIndex) { - int64_t mTotalTileCount = llvm::divideCeil(problem.mSize, intrinsic.mSize); - int64_t nTotalTileCount = llvm::divideCeil(problem.nSize, intrinsic.nSize); - - int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup; + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); + // mTotalTileCounts and nTotalTileCounts represent the total number of + // intrinsics along the M or N dimensions needed to fill the problem size. 
+ // For example, if the problem is {M:[4, 16], N:[2, 32], K[3, 128]} for a + // 16x16x16 intrinsic, then: + // - mTotalTileCounts would be 4 * (16/16) = 4 + // - nTotalTileCounts would be 2 * (32/16) = 4 + SmallVector mTotalTileCounts = problem.mSizes; + SmallVector nTotalTileCounts = problem.nSizes; + mTotalTileCounts.back() = + llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]); + nTotalTileCounts.back() = + llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]); + + int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup; int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup; - // Assign more warps to the M dimension (used later) to balance thread + // Assign more subgroups to the M dimension (used later) to balance thread // counts along X and Y dimensions. - int64_t warpSqrt = - 1ull << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2)); - int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); - - int64_t mWarpCount = 0, nWarpCount = 0; - int64_t mTileCount = 0, nTileCount = 0; - - // See if the square root can divide mTotalTileCount. If so it means we can - // distribute to both dimensions evenly. Otherwise, try to distribute to N - // and then M. - if (mTotalTileCount > (warpSqrt * tileSqrt) && - mTotalTileCount % (warpSqrt * tileSqrt) == 0) { - mWarpCount = warpSqrt; - mTileCount = tileSqrt; - - remainingWarps /= warpSqrt; - remainingTiles /= tileSqrt; - - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - } else { - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - remainingTiles /= nTileCount; - - APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingWarps)); - mWarpCount = mGCD.getSExtValue(); - mTotalTileCount /= mWarpCount; - remainingWarps /= mWarpCount; - - mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingTiles)); - mTileCount = mGCD.getSExtValue(); + int mDim = problem.mSizes.size() - 1; + int nDim = problem.nSizes.size() - 1; + SmallVector mTileSizes(problem.mSizes.size(), 0), + nTileSizes(problem.nSizes.size(), 0), + mSubgroupCounts(problem.mSizes.size(), 0), + nSubgroupCounts(problem.nSizes.size(), 0); + // Start at the innermost nDim and mDim, and try to distribute evenly to M and + // N for each pair of M and N dims. Otherwise, distribute to N and then M. + while (mDim >= 0 || nDim >= 0) { + int64_t subgroupSqrt = + 1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2)); + int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); + + // See if the square root can divide mTotalTileCount. If so it means we can + // distribute to both dimensions evenly to minimize the number of global + // loads. Otherwise, try to distribute to N and then M. 
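// Illustrative only: the GCD-based "peel off" applied to each M/N dimension in
// the loop above, written as a standalone helper (std::gcd stands in for
// llvm::APIntOps::GreatestCommonDivisor). All concrete numbers are
// hypothetical.
#include <cassert>
#include <cstdint>
#include <numeric>

static int64_t peelOffFactor(int64_t &totalTileCount, int64_t &remainingBudget) {
  int64_t assigned = std::gcd(totalTileCount, remainingBudget);
  totalTileCount /= assigned;
  remainingBudget /= assigned;
  return assigned; // subgroup count (or tile size) assigned to this dimension
}

int main() {
  // With 4 tiles left along N and a budget of 4 subgroups, the whole budget is
  // used; with 6 tiles and 4 subgroups only a factor of 2 is, leaving 3 tiles
  // and 2 subgroups for the remaining dimensions.
  int64_t tiles = 6, subgroups = 4;
  int64_t assigned = peelOffFactor(tiles, subgroups);
  assert(assigned == 2 && tiles == 3 && subgroups == 2);
  (void)assigned;
  return 0;
}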
+ if (mDim >= 0 && nDim >= 0 && + mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) && + mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) { + mSubgroupCounts[mDim] = subgroupSqrt; + mTileSizes[mDim] = tileSqrt; + + remainingSubgroups /= subgroupSqrt; + remainingTiles /= tileSqrt; + + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } else { + if (nDim >= 0) { + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } + + if (mDim >= 0) { + APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingSubgroups)); + mSubgroupCounts[mDim] = mGCD.getSExtValue(); + mTotalTileCounts[mDim] /= mSubgroupCounts[mDim]; + remainingSubgroups /= mSubgroupCounts[mDim]; + + mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingTiles)); + mTileSizes[mDim] = mGCD.getSExtValue(); + remainingTiles /= mTileSizes[mDim]; + } + } + --mDim; + --nDim; } - const uint64_t kTotalTileCount = - llvm::divideCeil(problem.kSize, intrinsic.kSize); + // kTotalTileCounts is similar to m/nTotalTileCounts, representing the total + // number of intrinsics along the K dimensions needed to fill the problem. + // For the problem described above {M:[4, 16], N:[2, 32], K[3, 128]} with a + // 16x16x16 intrinsic, then: + // - kTotalTileCounts would be 3 * (128/16) = 24 + SmallVector kTotalTileCounts = problem.kSizes; + kTotalTileCounts.back() = + llvm::divideCeil(problem.kSizes.back(), intrinsic.kSizes[0]); + // Compute the ideal number of intrinsics along K per subgroup based on the + // seed. int64_t bestKTileCountPerSubgroup = seeds.bestKElementCountPerSubgroup ? llvm::divideCeil(seeds.bestKElementCountPerSubgroup, - intrinsic.kSize) + intrinsic.kSizes[0]) : seeds.bestKTileCountPerSubgroup; - APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount), - APInt(64, bestKTileCountPerSubgroup)); - int64_t kTileCount = kGCD.getSExtValue(); + SmallVector kTileSizes(problem.kSizes.size(), 0); + // Start at the innermost K dim, and tile each dim to try to satisfy the ideal + // K intrinsic count per subgroup with the overall product of K tile counts. 
+ int kDim = problem.kSizes.size() - 1; + while (kDim >= 0) { + APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCounts[kDim]), + APInt(64, bestKTileCountPerSubgroup)); + kTileSizes[kDim] = kGCD.getSExtValue(); + bestKTileCountPerSubgroup /= kTileSizes[kDim]; + --kDim; + } - return GPUMMASchedule{intrinsicIndex, intrinsic.mSize, intrinsic.nSize, - intrinsic.kSize, mWarpCount, nWarpCount, - mTileCount, nTileCount, kTileCount}; + return GPUMMASchedule{ + intrinsicIndex, intrinsic.mSizes[0], intrinsic.nSizes[0], + intrinsic.kSizes[0], mSubgroupCounts, nSubgroupCounts, + mTileSizes, nTileSizes, kTileSizes}; } FailureOr deduceMMASchedule( @@ -297,7 +412,6 @@ FailureOr deduceMMASchedule( return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; }; - return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule); } return failure(); @@ -309,7 +423,10 @@ FailureOr deduceAttentionSchedule( const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes, int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV, bool canUpcastAcc, bool mustBeAligned) { - + assert(pvMatmul.mSizes.size() == 1 && pvMatmul.nSizes.size() == 1 && + pvMatmul.kSizes.size() == 1 && qkMatmul.mSizes.size() == 1 && + qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 && + "unimplemented: multi M/N/K attention schedule"); for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) { if (failed(canTargetIntrinsic(qkMatmul, intrinsic, canUpcastAcc, mustBeAligned))) { @@ -329,7 +446,7 @@ FailureOr deduceAttentionSchedule( llvm::dbgs() << " " << schedule << "\n"; }); - int64_t intrinsicK = intrinsic.kSize; + int64_t intrinsicK = intrinsic.kSizes[0]; auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool { // Create a mma schedule for qkMatmul in attention. // qkMatmul.M = pvMatmul.M @@ -339,11 +456,11 @@ FailureOr deduceAttentionSchedule( schedule.mSize, schedule.kSize, intrinsicK, - /*mWarpCount=*/schedule.mWarpCount, - /*nWarpCount=*/1, - schedule.mTileCount, - schedule.kTileCount, - qkMatmul.kSize / intrinsicK}; + /*mSubgroupCount=*/schedule.mSubgroupCounts[0], + /*nSubgroupCount=*/1, + schedule.mTileSizes[0], + schedule.kTileSizes[0], + qkMatmul.kSizes[0] / intrinsicK}; bool isQKAligned = isValidMMASchedule(qkMatmul, qkSchedule, mustBeAligned, subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h index 8211443a2e12..13f6a56c1b6f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -10,15 +10,18 @@ namespace mlir::iree_compiler { /// Struct containing information about a matmul's shape and type. struct GPUMatmulShapeType { - int64_t mSize; - int64_t nSize; - int64_t kSize; + SmallVector mSizes; + SmallVector nSizes; + SmallVector kSizes; Type aType; Type bType; Type cType; GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c) - : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {} + : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {} + GPUMatmulShapeType(SmallVector m, SmallVector n, + SmallVector k, Type a, Type b, Type c) + : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {} }; /// Struct containing seed tile sizes for GPU MMA heuristics deduction logic. 
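// Illustrative only: constructing the new multi-dimensional problem shape,
// assuming an MLIRContext is available; the sizes mirror the
// {M:[4, 16], N:[2, 32], K:[3, 128]} example used in GPUHeuristics.cpp.
static GPUMatmulShapeType makeExampleProblem(MLIRContext *ctx) {
  Type f16 = Float16Type::get(ctx);
  Type f32 = Float32Type::get(ctx);
  return GPUMatmulShapeType(/*m=*/{4, 16}, /*n=*/{2, 32}, /*k=*/{3, 128},
                            /*a=*/f16, /*b=*/f16, /*c=*/f32);
}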
@@ -38,14 +41,42 @@ struct GPUMMAHeuristicSeeds { struct GPUMMASchedule { // Index of the chosen intrinsic into the list of given MMA intrinsics uint64_t index; - int64_t mSize; // Native MMA size along M dimension - int64_t nSize; // Native MMA size along N dimension - int64_t kSize; // Native MMA size along K dimension - int64_t mWarpCount; // Number of subgroups along M dimension - int64_t nWarpCount; // Number of subgroups along N dimension - int64_t mTileCount; // Number of tiles per subgroup along M dimension - int64_t nTileCount; // Number of tiles per subgroup along N dimension - int64_t kTileCount; // Number of tiles along K dimension + int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup. + int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup. + int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup. + + // Number of subgroups along each M and N dimension. + SmallVector mSubgroupCounts; + SmallVector nSubgroupCounts; + + // Tile sizes for each M, N, and K dimension. When there are multiple M, N, + // or K dimensions, the intrinsic sizes are targeted to the innermost + // dimension, and the outer dimensions can be thought of as unrolling factors + // along M, N, or K. + SmallVector mTileSizes; // M tile sizes per subgroup. + SmallVector nTileSizes; // N tile sizes per subgroup. + SmallVector kTileSizes; // K tile sizes. + + // Constructor for multi M, N, K dim schedules. + GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize, + int64_t kIntrinsicSize, SmallVector mSubgroupCounts, + SmallVector nSubgroupCounts, + SmallVector mTileSizes, + SmallVector nTileSizes, + SmallVector kTileSizes) + : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize), + kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts), + nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes), + nTileSizes(nTileSizes), kTileSizes(kTileSizes) {} + + // Constructor for single M, N, K dim schedules. + GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize, + int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup, + int64_t mTileSize, int64_t nTileSize, int64_t kTileSize) + : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize), + kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}), + nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}), + nTileSizes({nTileSize}), kTileSizes({kTileSize}) {} }; /// Returns a schedule for using one of the given MMA |intrinsics| to target the @@ -69,4 +100,7 @@ FailureOr deduceAttentionSchedule( bool transposedV = false, bool canUpcastAcc = false, bool mustBeAligned = true); +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const GPUMMASchedule &schedule); + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp index 5f0d60660b84..778cd082736a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include #include "iree/compiler/Codegen/Common/EncodingUtils.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" @@ -40,84 +41,142 @@ namespace mlir::iree_compiler { #define GEN_PASS_DEF_GPUMATERIALIZEHOSTENCODINGPASS #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc" -static bool hasIntrinsic(IREE::GPU::TargetAttr target, - IREE::GPU::MMAIntrinsic intrinsic) { - for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { - if (mma.getIntrinsic().getValue() == intrinsic) { - return true; +static IREE::GPU::MMAAttr chooseIntrinsicMMAAttr(TypeRange eTypes, + IREE::GPU::TargetWgpAttr wgp) { + IREE::GPU::MMAAttr candidateMma; + for (IREE::GPU::MMAAttr mma : wgp.getMma()) { + // Filter out intrinsics that don't match the element types of this matmul. + auto [et0, et1, et2] = mma.getABCElementTypes(); + if (et0 != eTypes[0] || et1 != eTypes[1] || et2 != eTypes[2]) { + continue; + } + // If multiple intrinsics are available for the given element types, we have + // to make a choice. On CDNA3, there may be an intrinsic with larger M/N and + // smaller K, which would optimize power, and an intrinsic with larger K, + // which would optimize performance when power is not the bottleneck. + // Currently we just choose the intrinsic maximizing K, but that can be + // revisited later. + if (candidateMma && candidateMma.getKSize() > mma.getKSize()) { + continue; } + candidateMma = mma; } - return false; + return candidateMma; } -static std::optional +static IREE::GPU::DataTiledMMAAttr chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target, IREE::Encoding::EncodingAttr encoding) { using namespace IREE::GPU; + if (!target) { + return {}; + } MLIRContext *ctx = target.getContext(); + IREE::GPU::TargetWgpAttr wgp = target.getWgp(); + if (!wgp.getMaxLoadInstructionBits() || !wgp.getVgprSpaceBits() || + !wgp.getSimdsPerWgp()) { + // Missing workgroup parameters: data tiling not supported on this target. + return {}; + } // // Step 1: select a MMAIntrinsic. // - const MMAIntrinsic candidateIntrinsics[] = { - MMAIntrinsic::MFMA_F32_16x16x4_F32, - MMAIntrinsic::MFMA_F32_16x16x16_F16, - MMAIntrinsic::MFMA_I32_16x16x32_I8, - }; - std::optional intrinsic; - for (MMAIntrinsic candidateIntrinsic : candidateIntrinsics) { - if (!hasIntrinsic(target, candidateIntrinsic)) { - continue; - } - auto [et0, et1, et2] = - MMAAttr::get(ctx, candidateIntrinsic).getABCElementTypes(); - if (et0 != eTypes[0] || et1 != eTypes[1] || et2 != eTypes[2]) { - continue; - } - intrinsic = candidateIntrinsic; - break; - } - if (!intrinsic) { - return std::nullopt; + MMAAttr intrinsicMma = chooseIntrinsicMMAAttr(eTypes, wgp); + if (!intrinsicMma) { + return {}; } // // Step 2: Select the unrolling factors for the generic case where there is no // narrow dimension. // - // These hardcoded constants should become functions querying `target`. - // - // Target ISA preferred load instruction size, in bits. - const int kLoadInstructionBits = 128; - // Target ISA preferred number of subgroups per block to get full utilization. - const int kNumSubgroups = 4; - // Number of register space bits to use for accumulators. Should typically be - // between 50% and 80% of total available register space, as the accumulator - // tends to be larger than the A and B matrix tiles. 
-  const int kMaxAccumulatorRegisterBits = 256 * 32;
-
-  MMAAttr intrinsicMma = MMAAttr::get(ctx, *intrinsic);
+
+  auto sizeInBits = [](VectorType type) -> int {
+    return type.getElementTypeBitWidth() * type.getNumElements();
+  };
+
   auto [intrinsicA, intrinsicB, intrinsicC] = intrinsicMma.getABCVectorTypes();
   // The unrollK factor serves to allow loads from the A and B matrices to use
   // the target ISA's vector loads. For instance, if the ISA has 128-bit loads
   // and each intrinsic consumes only 32 bits from A and B, then we want to set
   // unrollK=4 to turn 4 separate 32-bit loads into one 128-bit load.
-  const int unrollK =
-      kLoadInstructionBits /
-      std::min(
-          intrinsicA.getElementTypeBitWidth() * intrinsicA.getNumElements(),
-          intrinsicB.getElementTypeBitWidth() * intrinsicB.getNumElements());
+  int intrinsicLoadBits =
+      std::min(sizeInBits(intrinsicA), sizeInBits(intrinsicB));
+  if (*wgp.getMaxLoadInstructionBits() % intrinsicLoadBits != 0) {
+    // Never seen that case: the ISA does not have a suitable load instruction
+    // to feed that intrinsic?!
+    return {};
+  }
+  const int unrollK = *wgp.getMaxLoadInstructionBits() / intrinsicLoadBits;
+
   // The total amount of unrolling along the M and N dimensions is normally
   // limited only by the number of available registers, since larger M and N
   // yields higher arithmetic intensity. Here, we do not yet distinguish between
   // plain unrolling (more instructions on each thread) and
-  // unrolling-to-subgroups (more threads).
-  const int totalUnrollMN =
-      kMaxAccumulatorRegisterBits /
-      (intrinsicC.getElementTypeBitWidth() * intrinsicC.getNumElements());
-  const int totalUnrollM = static_cast<int>(
-      std::floor(std::sqrt(static_cast<float>(totalUnrollMN))));
-  const int totalUnrollN = totalUnrollMN / totalUnrollM;
+  // unrolling-to-subgroups (more threads), since expanding to more subgroups
+  // correspondingly divides the available register space between this many
+  // subgroups, making it cancel out of the equation here.
+  //
+  // We need to solve for two variables here, unroll_m and unroll_n, constrained
+  // by one quadratic equation expressing that the A, B and C tiles must fit in
+  // VGPR space. Since we have only 1 constraint for two variables, we
+  // self-impose a second constraint for now: that the unrolling shape should be
+  // square, i.e. unrollM == unrollN.
+  // TODO(#18850): that is suboptimal for narrow cases.
+  //
+  // Now we have only one variable, call it x, to solve for.
+
+  // The register space taken is:
+  //     A-tile: x * unrollK * sizeInBits(intrinsicA)
+  //     B-tile: x * unrollK * sizeInBits(intrinsicB)
+  //     C-tile: x^2 * sizeInBits(intrinsicC)
+  // So the equation to solve is:
+  //       x^2 * sizeInBits(intrinsicC)
+  //     + x   * unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB))
+  //    == wgp.getVgprSpaceBits()
+  float c2 = sizeInBits(intrinsicC);
+  float c1 = unrollK * (sizeInBits(intrinsicA) + sizeInBits(intrinsicB));
+  float c0 = -*wgp.getVgprSpaceBits(); // negative by construction.
+  // Now the equation to solve is: c2 * x^2 + c1 * x + c0 == 0.
+  float discriminant = c1 * c1 - 4 * c0 * c2; // positive, because c0 < 0.
+  // x = unique positive solution.
+  float x = (-c1 + std::sqrt(discriminant)) / (2 * c2);
+
+#ifndef NDEBUG
+  // Self-check quadratic solver. 10 epsilon is just a crude upper bound;
+  // In practice, cancellation results in check == 0 in current cases.
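+  // Illustrative example, assuming 64-bit per-thread A/B intrinsic tiles, a
+  // 128-bit C tile, unrollK = 2 and wgp.getVgprSpaceBits() = 16384: then
+  // c2 = 128, c1 = 256, c0 = -16384, so x = -1 + sqrt(129) ~= 10.4, and the
+  // power-of-two rounding below yields totalUnrollN = totalUnrollM = 8.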
+  float check = c2 * x * x + c1 * x + c0;
+  assert(std::abs(check) < 10 * FLT_EPSILON * std::abs(c0));
+#endif
+
+  // Now, looking geometrically at our unrolling space along the M and N
+  // dimensions, we solve the following problem in the (M,N)-plane: approximate
+  // a square of side length `x`, by a rectangle of side lengths `totalUnrollM`
+  // and `totalUnrollN`, under the constraints:
+  // 1. totalUnrollM * totalUnrollN <= x * x
+  //    * Reason: by construction of x, any larger area would exceed the
+  //      wgp.getVgprSpaceBits() budget.
+  // 2. totalUnrollM and totalUnrollN are powers of 2.
+  //    * Reason: that is a self-imposed constraint for now to avoid prematurely
+  //      entering excessive fine-tuning of unrolling factors. Also, since below
+  //      we will put all the unroll-to-subgroups in the N dimension, that
+  //      requires totalUnrollN to be a multiple of wgp.getSimdsPerWgp(),
+  //      which is typically a power of 2, specifically 4.
+  //      TODO(#18851): we will not always put all the unroll-to-subgroups on N.
+  // 3. totalUnrollN >= totalUnrollM.
+  //    * Reason: Just like the previous constraint, that is also motivated by
+  //      the code below currently putting all the unroll-to-subgroups in the N
+  //      dimension, which requires a sufficiently large totalUnrollN.
+  //      TODO(#18851): we will not always put all the unroll-to-subgroups on N.
+  //
+  // Set totalUnrollN = round x to nearest power of two, break ties away from 0
+  // per specification of std::round.
+  int totalUnrollN = std::exp2(std::round(std::log2(x)));
+  // Based on above constraint 1:
+  float unroundedMaxTotalUnrollM = x * x / totalUnrollN;
+  int totalUnrollM = std::exp2(std::floor(std::log2(unroundedMaxTotalUnrollM)));
+
   // Now we introduce unroll-to-subgroups. It doesn't change the overall tile
   // size, as it increases the number of subgroups but correspondingly decreases
   // the number of registers available to each subgroups. In other words, the
@@ -125,16 +184,18 @@ chooseDataTiledMMAAttr(TypeRange eTypes, IREE::GPU::TargetAttr target,
   // overall number of registers, not with how they are split between subgroups.
   //
   // For now for simplicity we put all the unroll-to-subgroups in the N
-  // dimension. That might be suboptimal, revisit later. That does simplify the
-  // below adjustments for narrow M/N, as we don't need to think about
-  // unroll-to-subgroups when making the narrowing adjustment.
+  // dimension. TODO(#18851): revisit that.
+  //
+  // That does simplify the below adjustments for narrow M/N, as we don't need
+  // to think about unroll-to-subgroups when making the narrowing adjustment.
   int unrollMToSubgroups = 1;
-  int unrollNToSubgroups = kNumSubgroups;
+  int unrollNToSubgroups = *wgp.getSimdsPerWgp();
   int unrollM = totalUnrollM / unrollMToSubgroups;
   int unrollN = totalUnrollN / unrollNToSubgroups;

   //
   // Step 3: Adjust the unrolling factors when there is a narrow dimension.
+  // TODO(#18850): dealing with narrow cases as a fix-up is suboptimal.
   //
   IREE::Encoding::MatmulNarrowDim narrowDim =
       IREE::Encoding::getMatmulNarrowDim(encoding);
@@ -177,7 +238,7 @@ materializeEncodingForTarget(RankedTensorType tensorType,
   } else {
     gpuTargetAttr = getCLGPUTarget(tensorType.getContext());
   }
-  std::optional<IREE::GPU::DataTiledMMAAttr> mma = chooseDataTiledMMAAttr(
+  IREE::GPU::DataTiledMMAAttr mma = chooseDataTiledMMAAttr(
       encoding.getElementTypesArray(), gpuTargetAttr, encoding);
   if (!mma) {
     return failure();
   }
@@ -187,11 +248,11 @@ materializeEncodingForTarget(RankedTensorType tensorType,
   // based on its operand index in the matmul.
auto rank = tensorType.getRank(); TileMxNxK innerTile; - std::tie(innerTile.M, innerTile.N, innerTile.K) = mma->getMNKShape(); + std::tie(innerTile.M, innerTile.N, innerTile.K) = mma.getMNKShape(); auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile); auto fragment = static_cast(encoding.getOperandIndex().getInt()); - encodingInfo.swizzle = getSwizzle(*mma, fragment); + encodingInfo.swizzle = getSwizzle(mma, fragment); return encodingInfo; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index e36ad993684f..c8b2edef15d2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -305,7 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern { auto vectorType = VectorType::get(distShape, elementType); VectorValue srcVector = dyn_cast(broadcastOp.getSource()); - if (!srcVector) { + // If the srcVector is a scalar (like f32) or a rank-0 vector (like + // vector), we proceed with the scalar distribution branch. + if (!srcVector || !isNonZeroRank(srcVector)) { // The way distribution currently works, there is no partial thread // distribution, so a scalar is available to all threads. Scalar // distribution is simply a broadcast from scalar to the distributed @@ -413,16 +415,10 @@ struct DistributeMultiReduction final DistributionSignature &signature, PatternRewriter &rewriter) const override { VectorValue srcVector = multiReduceOp.getSource(); - auto accVector = dyn_cast(multiReduceOp.getAcc()); - if (!accVector) { - return rewriter.notifyMatchFailure( - multiReduceOp, "unimplemented: scalar accumulator distribution"); - } - auto resVector = dyn_cast(multiReduceOp.getResult()); - if (!resVector) { - return rewriter.notifyMatchFailure( - multiReduceOp, "unimplemented: scalar result distribution"); - } + Value acc = multiReduceOp.getAcc(); + Value res = multiReduceOp.getResult(); + auto accVector = dyn_cast(acc); + auto resVector = dyn_cast(res); auto srcLayout = dyn_cast_or_null(signature[srcVector]); if (!srcLayout) { @@ -440,8 +436,14 @@ struct DistributeMultiReduction final VectorValue disSrc = getDistributed(rewriter, srcVector, signature[srcVector]); - VectorValue disAcc = - getDistributed(rewriter, accVector, signature[accVector]); + + Value disAcc; + if (accVector) { + disAcc = getDistributed(rewriter, accVector, signature[accVector]); + } else { + // Scalars are always distributed to all threads already. + disAcc = multiReduceOp.getAcc(); + } Location loc = multiReduceOp.getLoc(); @@ -462,7 +464,16 @@ struct DistributeMultiReduction final auto localReduction = rewriter.create( loc, disSrc, localInit, distributedReductionMask, multiReduceOp.getKind()); - auto locallyReduced = dyn_cast(localReduction.getResult()); + + VectorValue locallyReduced; + if (accVector) { + locallyReduced = dyn_cast(localReduction.getResult()); + } else { + // Broadcast scalar accumulator to vector. + VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, elemTy); + locallyReduced = rewriter.create( + loc, vecType, localReduction.getResult()); + } assert(locallyReduced && "result should have been a vector"); @@ -485,15 +496,30 @@ struct DistributeMultiReduction final // reduction. 
VectorValue unflattened = rewriter.create( loc, shaped, threadReduced.value()); + + if (!accVector) { + // Broadcast the scalar (e.g., f32) to a vector type (e.g., vector) + // because the following implementation requires the operand to be a + // vector. + disAcc = rewriter.create(loc, shaped, disAcc); + } + Value accReduction = vector::makeArithReduction( rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc); auto accReduced = dyn_cast(accReduction); if (!accReduced) { return failure(); } - replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); - return failure(); + if (resVector) { + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); + } else { + Value accReducedVal = rewriter.create( + loc, accReduction, ArrayRef{int64_t(0)}); + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal); + } + + return success(); } FailureOr doThreadReduction(RewriterBase &rewriter, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp index dd498fad50e8..5e50a956bd82 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp @@ -53,9 +53,15 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) { return; } } - setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get( - builder.getContext())); - return; + + // We only support thread tile size derivation of linalgOp and Im2colOp for + // now. + if (isa( + producer.getOperation())) { + setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get( + builder.getContext())); + return; + } } auto tensorType = dyn_cast(operand.getType()); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp index 807ab9d339eb..51898adc02d7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp @@ -18,6 +18,23 @@ namespace mlir::iree_compiler { namespace { +/// Check if AllocOp has a CollapseShapeOp user. +static bool hasCollapseShapeUser(memref::AllocOp allocOp) { + SmallVector users(allocOp->getUsers()); + while (!users.empty()) { + auto user = users.pop_back_val(); + if (isa(user)) { + return true; + } + if (isa(user)) { + for (auto u : user->getUsers()) { + users.push_back(u); + } + } + } + return false; +} + /// Pad out the inner dimension of the `memref.alloc` op in order reduce the /// chances to have bank conflicts when reading 2D shapes within shared memory. static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, @@ -28,6 +45,12 @@ static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, int64_t innerDim = allocOpShape.back(); if (ShapedType::isDynamic(innerDim)) return; + + // Return if we have CollapseShape op as an user as padding in that case is + // unsupported. 
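+  // (memref.collapse_shape requires the collapsed dimension groups to stay
+  // contiguous; padding the innermost dimension introduces a gap between rows,
+  // so a collapse of the padded allocation would no longer form a valid
+  // contiguous group.)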
+ if (hasCollapseShapeUser(allocOp)) + return; + Type elType = allocOp.getType().getElementType(); unsigned bitwidth = mlir::DataLayout::closest(allocOp).getTypeSizeInBits(elType); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index a8831809e25b..7e927b499077 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -132,14 +132,16 @@ void DistributionPattern::replaceOpWithDistributedValues( for (auto [opResult, replacement] : llvm::zip_equal(op->getOpResults(), values)) { // If this value is a vector type, it must be converted back to simd. - if (isa(replacement.getType())) { - auto oldResult = cast(opResult); - // Create a toSIMD op to convert the value back to the simd. - rewriter.setInsertionPointAfterValue(oldResult); - Value toSIMD = rewriter.create( - oldResult.getLoc(), oldResult.getType(), replacement); - // Add to replacements. - replacement = toSIMD; + if (auto replacementType = dyn_cast(replacement.getType())) { + if (replacementType.getRank() != 0) { + auto oldResult = cast(opResult); + // Create a toSIMD op to convert the value back to the simd. + rewriter.setInsertionPointAfterValue(oldResult); + Value toSIMD = rewriter.create( + oldResult.getLoc(), oldResult.getType(), replacement); + // Add to replacements. + replacement = toSIMD; + } } replacements.push_back(replacement); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir index fd97eaf051d5..90becb209c6c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir @@ -323,7 +323,7 @@ func.func @unset_encoding_ACC_dynamic_unroll8x8x4_MFMA_F32_16x16x4_F32() { #hal.pipeline.binding, #hal.pipeline.binding ]> -func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { +func.func @matmul_lowering_MFMA_F32_16x16x4_F32() { %c0 = arith.constant 0 : index %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index @@ -356,7 +356,7 @@ func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK: func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32 +// CHECK: func.func @matmul_lowering_MFMA_F32_16x16x4_F32 // CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) @@ -382,7 +382,7 @@ func.func @matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { #hal.pipeline.binding, #hal.pipeline.binding ]> -func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { +func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32() { %c0 = arith.constant 0 : index %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index @@ -416,7 +416,7 @@ func.func 
@batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> -// CHECK: func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32 +// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x4_F32 // CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) @@ -429,6 +429,8 @@ func.func @batch_matmul_lowering_unroll8x8x4_MFMA_F32_16x16x4_F32() { // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] +// ----- + //----------------------------------------------------------------------------- // 2. MFMA_I32_16x16x32_I8 //----------------------------------------------------------------------------- @@ -577,7 +579,7 @@ func.func @unset_encoding_ACC_unroll8x8x2_MFMA_I32_16x16x32_I8() { #hal.pipeline.binding ]> -func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() { +func.func @matmul_lowering_MFMA_I32_16x16x32_I8() { %c0 = arith.constant 0 : index %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index @@ -610,7 +612,7 @@ func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() { // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK: func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8 +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8 // CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) // CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) // CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) @@ -622,3 +624,569 @@ func.func @matmul_lowering_unroll8x8x2_MFMA_I32_16x16x32_I8() { // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] + +// ----- + +//------------------------------------------------------------------------- +// 3. Custom target parameters to test more MaterializeEncoding heuristics. +//------------------------------------------------------------------------- + +// Custom {max_load_instruction_bits = 64} => implied default {unroll_k = 1} (omitted in output) instead of {unroll_k = 2}. 
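+// (With this I8 intrinsic each thread reads 64 bits from A and from B, so the
+// heuristic computes unroll_k = max_load_instruction_bits / 64 = 64 / 64 = 1.)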
+ +#target_gfx942_except_max_load_instruction_bits_64 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 64, + simds_per_wgp = 4, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_64} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom {max_load_instruction_bits = 256} => {unroll_k = 4} instead of {unroll_k = 2}. 
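+// (Same 64-bit per-thread A/B reads as above, so unroll_k = 256 / 64 = 4.)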
+ +#target_gfx942_except_max_load_instruction_bits_256 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 256, + simds_per_wgp = 4, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64() attributes {hal.executable.target = #target_gfx942_except_max_load_instruction_bits_256} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_max_load_instruction_bits_64 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom {simds_per_wgp = 1} => implied default {unroll_n_to_subgroups = 1} (omitted in output) and {unroll_n = 8} instead of {unroll_n_to_subgroups = 4}. 
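+// (The register-budget solve still yields a total N unroll of 8; with only one
+// SIMD per workgroup none of it is assigned to subgroups, so it all becomes
+// plain unroll_n = 8.)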
+ +#target_gfx942_except_simds_per_wgp_1 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 1, + vgpr_space_bits = 16384 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1() attributes {hal.executable.target = #target_gfx942_except_simds_per_wgp_1} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_simds_per_wgp_1 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom 2x smaller {vgpr_space_bits = 8192} => smaller unroll_m and unroll_n + +#target_gfx942_except_vgpr_space_bits_8192 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + vgpr_space_bits = 8192 + > + >, + 
ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_8192} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_8192 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom 4x smaller {vgpr_space_bits = 4096} => smaller unroll_m and unroll_n + +#target_gfx942_except_vgpr_space_bits_4096 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + vgpr_space_bits = 4096 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_4096} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load 
layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_4096 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +// Custom smaller {vgpr_space_bits = 32768} => larger unroll_m and/or unroll_n + +#target_gfx942_except_vgpr_space_bits_32768 = #hal.executable.target<"rocm", "rocm-hsaco-fb", { + iree.gpu.target = #iree_gpu.target< + arch = "gfx942", features = "", wgp = < + compute = fp64|fp32|fp16|int64|int32|int16|int8, + storage = b64|b32|b16|b8, + subgroup = shuffle|arithmetic, + dot = dp4xi8toi32, + mma = [, , , , , ], + subgroup_size_choices = [64], + max_workgroup_sizes = [1024, 1024, 1024], + max_thread_count_per_workgroup = 1024, + max_workgroup_memory_bytes = 65536, + max_workgroup_counts = [2147483647, 2147483647, 2147483647], + max_load_instruction_bits = 128, + simds_per_wgp = 4, + vgpr_space_bits = 32768 + > + >, + ukernels = "none" +}> + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#encoding_lhs = #iree_encoding.encoding> +#encoding_rhs = #iree_encoding.encoding> +#encoding_result = #iree_encoding.encoding> +#pipeline_layout_3 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768() attributes {hal.executable.target = #target_gfx942_except_vgpr_space_bits_32768} { + %c0 = arith.constant 0 : index + %M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index + %N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index + %K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%M, %N} + %3 
= flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : !flow.dispatch.tensor>{%M, %N} + -> tensor + %6 = linalg.matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1] + : tensor + -> !flow.dispatch.tensor>{%M, %N} + return +} + +// CHECK: func.func @matmul_lowering_MFMA_I32_16x16x32_I8_custom_vgpr_space_bits_32768 +// CHECK: iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout + +// ----- + +//--------------------------------------------------------------------------- +// 4. Additional element types, testing only the multi_mma, not set_encoding. +//--------------------------------------------------------------------------- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +#encoding_lhs = #iree_encoding.encoding +#encoding_rhs = #iree_encoding.encoding +#encoding_result = #iree_encoding.encoding +#pipeline_layout_4 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() { + %c0 = arith.constant 0 : index + %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index + %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index + %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index + %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %N} + -> tensor + %6 = linalg.batch_matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : tensor + -> !flow.dispatch.tensor>{%B, %M, %N} + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ +// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) +// CHECK-DAG: 
%[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) +// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor +// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]], +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout +// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +#encoding_lhs = #iree_encoding.encoding +#encoding_rhs = #iree_encoding.encoding +#encoding_result = #iree_encoding.encoding +#pipeline_layout_4 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() { + %c0 = arith.constant 0 : index + %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index + %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index + %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index + %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %N} + -> tensor + %6 = linalg.batch_matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : tensor + -> !flow.dispatch.tensor>{%B, %M, %N} + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16 +// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) +// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) +// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load 
%[[ACC_BINDING]]{{.+}} -> tensor +// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]], +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout +// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir index f05b9925cd6c..1fd7682b58e6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir @@ -1047,3 +1047,95 @@ builtin.module attributes { transform.with_named_sequence } { // CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32 // Accumulator reduction // CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32> + +// ----- + +#nested = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [2, 2], + outer_tile = [1, 1], + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32 + return %0 : f32 +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims +// Local reduction +// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 +// Global reduction +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32 +// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32> + +// ----- + +#layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [2, 2], + outer_tile = [1, 1], + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @distribute_scf_for(%arr: memref<32x32xf16>, %a: vector<32x32xf16>) -> vector { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %cst = arith.constant dense<0.000000e+00> : vector + %cst_0 = arith.constant 0.0 : f16 + %out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %cst) -> (vector) { + %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> + %rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<32x32xf16> + %b = arith.addf %rootl, %a : vector<32x32xf16> + %c = arith.extf %b : vector<32x32xf16> to vector<32x32xf32> + %init = vector.extractelement %arg0[] : vector + %root_red = vector.multi_reduction, %c, %init [0, 1] : 
vector<32x32xf32> to f32 + %d = vector.broadcast %root_red : f32 to vector + scf.yield %d : vector + } + return %out : vector +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @distribute_scf_for +// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector +// CHECK: iter_args(%[[ARG0:.*]] = %[[ROOT]]) -> (vector) +// CHECK: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf16> -> vector<2x2x1x1x1x4xf16> +// CHECK: %[[B:.*]] = arith.addf %{{.*}}, %[[A]] +// CHECK: %[[C:.*]] = arith.extf %[[B]] +// CHECK-NEXT: %[[D:.*]] = vector.extractelement %[[ARG0]][] : vector +// Local reduction +// CHECK: vector.multi_reduction , %[[C]], %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 +// Global reduction +// CHECK: gpu.subgroup_reduce add %{{.*}} cluster(size = 16) : (f32) -> f32 +// CHECK-NEXT: gpu.subgroup_reduce add %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: vector.broadcast %[[D]] : f32 to vector<1xf32> +// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1xf32> diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir index f05cf7b1890b..643b12c01e39 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir @@ -82,3 +82,27 @@ func.func @no_promote_fill(%b: tensor<128x128xf32>) -> tensor<4x128xf32> { // CHECK-LABEL: func.func @no_promote_fill // CHECK-NOT: iree_gpu.derived_thread_config // CHECK: return + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0]}> + +func.func @promote_pad(%a : tensor<4x127xf32>, %b: tensor<128x128xf32>) -> tensor<4x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32> + %padded = tensor.pad %a low[0, 0] high[0, 1] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f32 + } : tensor<4x127xf32> to tensor<4x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%padded, %b : tensor<4x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32> + return %mm : tensor<4x128xf32> +} + +// Verify that pad is promoted with linalg.copy +// CHECK-LABEL: func.func @promote_pad +// CHECK: tensor.pad +// CHECK: linalg.copy +// CHECK-SAME: derived_thread_config +// CHECK: return diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index da40806ac73c..cf47ca9d47b5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -666,10 +666,10 @@ builtin.module attributes { transform.with_named_sequence } { } // CHECK-LABEL: func.func @resolve_constant_with_multiple_layout_uses // CHECK-SAME: (%[[ARG0:.+]]: vector<64x64xf16>, %[[ARG0:.+]]: 
vector<64x64xf16>) -// CHECK: %[[V0:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x16xf16> -// CHECK: %[[V1:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf16> -// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[V1]]{{.*}} : vector<2x2x8xf16> -// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[V0]]{{.*}} : vector<2x2x16xf16> +// CHECK: %[[V0:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf16> +// CHECK: %[[V1:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x16xf16> +// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[V0]]{{.*}} : vector<2x2x8xf16> +// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[V1]]{{.*}} : vector<2x2x16xf16> // CHECK: arith.addf %{{.+}}, %[[ADD0]]{{.*}} : vector<2x2x8xf16> transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir index befb2445ab24..b934772ffd34 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir @@ -47,6 +47,66 @@ func.func @pad_alloc_expand_shape(%a: memref<1024x1024xf32>) { return } +// ----- +// CHECK-LABEL: func.func @no_pad_alloc_collapse_shape +// CHECK: %[[A:.*]] = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[C:.*]] = memref.collapse_shape %[[A]] {{\[}}[0], [1, 2], [3, 4]] +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> into +// CHECK-SAME: memref<4x32x64xf32, #gpu.address_space> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VEC_READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST_0]] {in_bounds = [true]} : +// CHECK-SAME: memref<1024x1024xf32>, vector<4xf32> +// CHECK: vector.transfer_write %[[VEC_READ]], %[[C]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : +// CHECK-SAME: vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + + +func.func @no_pad_alloc_collapse_shape(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> + %1 = memref.collapse_shape %0 [[0], [1, 2], [3, 4]] + : memref<4x2x16x8x8xf32, #gpu.address_space> into memref<4x32x64xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : + memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %3, %1[%c0, %c0, %c0] {in_bounds = [true]} : + vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + return +} + +// ----- + +// CHECK-LABEL: func.func @no_pad_alloc_collapse_shape_throughsubview +// CHECK: %[[A:.*]] = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0, 0, 0] [4, 2, 16, 8, 8] [1, 1, 1, 1, 1] : +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> to +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[C:.*]] = memref.collapse_shape %[[S]] {{\[}}[0], [1, 2], [3, 4]] +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> into +// CHECK-SAME: memref<4x32x64xf32, #gpu.address_space> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VEC_READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true]} : +// CHECK-SAME: memref<1024x1024xf32>, 
vector<4xf32> +// CHECK: vector.transfer_write %[[VEC_READ]], %[[C]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : +// CHECK-SAME: vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + + +func.func @no_pad_alloc_collapse_shape_throughsubview(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> + %subview = memref.subview %0[0, 0, 0, 0, 0] [4, 2, 16, 8, 8] [1, 1, 1, 1, 1] + : memref<4x2x16x8x8xf32, #gpu.address_space> to memref<4x2x16x8x8xf32, #gpu.address_space> + %1 = memref.collapse_shape %subview [[0], [1, 2], [3, 4]] + : memref<4x2x16x8x8xf32, #gpu.address_space> into memref<4x32x64xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : + memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %3, %1[%c0, %c0, %c0] {in_bounds = [true]} : + vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + return +} + // ----- // CHECK-LABEL: func.func @pad_alloc_negative diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp index f300291149b2..ebaabcc56a89 100644 --- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp @@ -91,7 +91,8 @@ struct OptimizeVectorTransferPass final // Workaround, run loop invariant code motion before hoist redundant // vector transfer to workaround a bug upstream. loopInvariantCodeMotion(funcOp); - linalg::hoistRedundantVectorTransfers(cast(funcOp)); + linalg::hoistRedundantVectorTransfers(cast(funcOp), + /*verifyNonZeroTrip=*/true); } IRRewriter rewriter(funcOp->getContext()); vector::transferOpflowOpt(rewriter, funcOp); diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index 502a3cfb9024..eac457dc6280 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -60,24 +60,6 @@ std::unique_ptr> createConvertToDestinationPassingStylePass( bool useWARForCooperativeMatrixCodegen); -using ConfigFn = - std::function; -/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is -/// used to set lowering configurations on the resulting ops, if necessary. -std::unique_ptr> -createConvolutionToIGEMMPass(ConfigFn configFn); - -using PackUnPackControlFn = std::function; -/// Pass to decompose pack and unpack ops into pad/extract_slice and reshape -/// ops. If specified, `controlFn` controls which ops get decomposed. The -/// `controlFn` should be used with `useOnlyReshapes` set to true. -/// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*` -/// patterns and remove the need to have `useOnlyReshapes = true` when using -/// `controlFn`. -std::unique_ptr> -createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes, - std::optional controlFn); - std::unique_ptr createDecomposeSoftmaxPass(bool useFusion); /// Pass to perform linalg on tensor bufferization. 
The function passed into diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 6bb6c829ad48..5aa3ef414bcb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -19,6 +19,12 @@ def AddFastMathFlagsPass "given a floating-point mode."; } +def BlockDynamicDimensionsPass + : Pass<"iree-codegen-block-dynamic-dimensions"> { + let summary = "Expand dynamic dimensions that are known to be multiples of " + "statically known values."; +} + def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> { let summary = "Bubbles op ordinal ops to allow for workgroup count computation"; let description = [{ @@ -83,6 +89,10 @@ def ConvolutionToIGEMMPass : InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> { let summary = "Transforms convolution operations into an implicit GEMM format."; + let dependentDialects = [ + "tensor::TensorDialect", + "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect" + ]; } def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> { @@ -157,6 +167,27 @@ def DecomposePackUnPackOpsPass : Option<"useOnlyReshapes", "use-only-reshapes", "bool", "false", "Use decomposition into reshape ops, even when packing unit dimensions."> ]; + let dependentDialects = [ + "arith::ArithDialect", + "linalg::LinalgDialect", + "scf::SCFDialect", + "tensor::TensorDialect" + ]; +} + +def DecomposeBoundaryPackUnPackOpsPass : + InterfacePass<"iree-codegen-decompose-boundary-pack-unpack-ops", "mlir::FunctionOpInterface"> { + let summary = "Wrapper for DecomposePackUnPackOpsPass to decompose ops at function boundaries"; + let options = [ + Option<"tileOuterToOne", "tile-outer-to-one", "bool", "false", + "Always apply tiling to make outer dimension be ones"> + ]; + let dependentDialects = [ + "arith::ArithDialect", + "linalg::LinalgDialect", + "scf::SCFDialect", + "tensor::TensorDialect" + ]; } def DecomposeSoftmaxPass : diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp new file mode 100644 index 000000000000..b0e76678732e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp @@ -0,0 +1,236 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +#define DEBUG_TYPE "iree-codegen-dynamic-dim-analysis" + +namespace mlir::iree_compiler { + +//===---------------------------------------------------------------------===// +// Helper function to update tensor dynamic dimension info +//===---------------------------------------------------------------------===// + +static void +updateRangeInfo(TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo, + Value v, unsigned dim, const ConstantIntRanges &range) { + assert(!rangeInfo.contains({v, dim}) && + "overwriting existing dim range info"); + rangeInfo.insert({{v, dim}, + ConstantIntRanges(range.umin(), range.umax(), range.smin(), + range.smax())}); +} + +static void updateDivisibilityInfo( + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + Value v, unsigned dim, + const IREE::Util::ConstantIntDivisibility &divisibility) { + assert(!divisibilityInfo.contains({v, dim}) && + "overwriting existing dim divisibility info"); + divisibilityInfo[{v, dim}] = divisibility; +} + +// Update the dynamic dim analysis to record the range/divisibility information +// for `tensorValue` at dimension `dimIndex` based on the range/divisibility +// information of an integer/index value `dynamicDim`. +static void updateTensorDimInfo( + Value tensorValue, unsigned dimIndex, Value dynamicDim, + const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // Update range info. + auto *rangeState = + solver.lookupState(dynamicDim); + if (rangeState && !rangeState->getValue().isUninitialized()) { + updateRangeInfo(rangeInfo, tensorValue, dimIndex, + rangeState->getValue().getValue()); + } + + // Update solver info + auto *divisibilityState = + solver.lookupState(dynamicDim); + if (divisibilityState && !divisibilityState->getValue().isUninitialized()) { + updateDivisibilityInfo(divisibilityInfo, tensorValue, dimIndex, + divisibilityState->getValue().getValue()); + } +} + +//===---------------------------------------------------------------------===// +// Transfer functions for updating dynamic dimension of results of operation. +//===---------------------------------------------------------------------===// + +// Helper function to just transfer the range and divisibility information +// `source` value to `dest` value. +static void transferTensorDimInfo( + Value source, Value dest, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // expected that `source` and `dest` are of `RankedTensorType` and of the same + // type. 
+ assert(source.getType() == dest.getType()); + auto sourceType = cast(source.getType()); + for (auto index : llvm::seq(0, sourceType.getRank())) { + // Transfer range info + auto rangeIt = rangeInfo.find({source, index}); + if (rangeIt != rangeInfo.end()) { + updateRangeInfo(rangeInfo, dest, index, rangeIt->second); + } + + auto divisibilityIt = divisibilityInfo.find({source, index}); + if (divisibilityIt != divisibilityInfo.end()) { + updateDivisibilityInfo(divisibilityInfo, dest, index, + divisibilityIt->second); + } + } +} + +// Update the tensor dimension information for result of a +// `flow.dispatch.tensor.load` operation. +static void updateTensorDimInfo( + IREE::Flow::DispatchTensorLoadOp flowLoadOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // If there are no dynamic dimensions, nothing to do. + if (flowLoadOp.getType().hasStaticShape()) { + return; + } + // Check that all strides are 1. Abort otherwise + if (llvm::any_of(flowLoadOp.getMixedStrides(), + [](OpFoldResult s) { return !isConstantIntValue(s, 1); })) { + return; + } + + Value result = flowLoadOp.getResult(); + for (auto [index, size] : llvm::enumerate(flowLoadOp.getMixedSizes())) { + auto dynamicDim = dyn_cast(size); + if (!dynamicDim) { + continue; + } + updateTensorDimInfo(result, index, dynamicDim, solver, divisibilityInfo, + rangeInfo); + } +} + +// Update the tensor dimension information for result of a `tensor.empty` +// operation. +static void updateTensorDimInfo( + tensor::EmptyOp emptyOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + auto dimOperands = emptyOp.getOperands(); + if (dimOperands.empty()) { + return; + } + + Value result = emptyOp.getResult(); + auto resultType = cast(result.getType()); + int dimOperandIndex = 0; + for (auto [index, shape] : llvm::enumerate(resultType.getShape())) { + if (!ShapedType::isDynamic(shape)) + continue; + updateTensorDimInfo(result, index, dimOperands[dimOperandIndex++], solver, + divisibilityInfo, rangeInfo); + } +} + +// Update the tensor dimension information for results of an operation that +// implements the `DestinationStyleOpInterface`. +static void updateTensorDimInfo( + DestinationStyleOpInterface dstStyleOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + for (auto [index, result] : llvm::enumerate(dstStyleOp->getResults())) { + auto resultTensorType = dyn_cast(result.getType()); + if (!resultTensorType || resultTensorType.hasStaticShape()) { + continue; + } + Value source = dstStyleOp.getDpsInitOperand(index)->get(); + transferTensorDimInfo(source, result, solver, divisibilityInfo, rangeInfo); + } +} + +// Dispatch to the method that updates the dimension information for an +// operation. 
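// Note (descriptive, matching the TypeSwitch below): only
// flow.dispatch.tensor.load, tensor.empty, and ops implementing
// DestinationStyleOpInterface are handled; all other operations are skipped,
// so their results simply carry no range/divisibility information.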
+static void updateTensorDimInfo( + Operation *op, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + LLVM_DEBUG({ + llvm::dbgs() << "Start updating op\n"; + op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + TypeSwitch(op) + .Case([&](auto op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }) + .Case([&](auto op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }); + + LLVM_DEBUG({ + for (auto [resultIndex, result] : llvm::enumerate(op->getResults())) { + auto tensorType = dyn_cast(result.getType()); + if (!tensorType) + continue; + for (auto index : llvm::seq(0, tensorType.getRank())) { + std::optional range; + std::optional divisibility; + auto rangeIt = rangeInfo.find({result, index}); + if (rangeIt != rangeInfo.end()) { + range = rangeIt->second; + } + auto divisibilityIt = divisibilityInfo.find({result, index}); + if (divisibilityIt != divisibilityInfo.end()) { + divisibility = divisibilityIt->second; + } + if (!range && !divisibility) { + continue; + } + llvm::dbgs() << "\tDim Info: Result number : " << resultIndex + << ", dim " << index; + if (range) { + llvm::dbgs() << " : Range " << range.value(); + } + if (divisibility) { + llvm::dbgs() << " : Divisibility " << divisibility.value(); + } + llvm::dbgs() << "\n"; + } + } + }); +} + +TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp) + : rootOperation(rootOp) { + solver.load(); + solver.load(); + solver.load(); +} + +LogicalResult TensorDynamicDimAnalysis::run() { + if (failed(solver.initializeAndRun(rootOperation))) { + return failure(); + } + + // Walk the IR pre-order, forward and update the dynamic information for each + // tensor. + rootOperation->walk([&](Operation *op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }); + + return success(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h new file mode 100644 index 000000000000..13bdb5cac8d7 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h @@ -0,0 +1,65 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +namespace mlir::iree_compiler { + +/// Analysis to compute information about dynamic dimensions of tensors. +/// +/// Using the IntegerRangeAnalysis and the IntegerDivisibilityAnalysis +/// this analysis builds information about the range and divisibility of dynamic +/// dimensions of tensor operands in the program. The analysis can then be +/// queried to get the range and divisibility info for any tensor value for any +/// dynamic dimension. +/// TODO: This is not a dataflow analysis or does not update information on IR +/// changes. This could be potentially expensive and is really meant to be used +/// before any transformations to the dispatch. If this needs to be more +/// efficient then this needs to be converted to a data flow solver. 
+class TensorDynamicDimAnalysis { +public: + explicit TensorDynamicDimAnalysis(Operation *rootOperation); + + LogicalResult run(); + + using TensorDimDivisibilityInfo = + DenseMap, + IREE::Util::ConstantIntDivisibility>; + using TensorDimRangeInfo = + DenseMap, ConstantIntRanges>; + + std::optional getRangeInfo(Value v, + unsigned dimIndex) const { + auto it = rangeInfo.find({v, dimIndex}); + if (it == rangeInfo.end()) { + return std::nullopt; + } + return it->second; + } + + std::optional + getDivisibilityInfo(Value v, unsigned dimIndex) const { + auto it = divisibilityInfo.find({v, dimIndex}); + if (it == divisibilityInfo.end()) { + return std::nullopt; + } + return it->second; + } + +private: + DataFlowSolver solver; + + // Operation scope within which the analysis is run. + Operation *rootOperation; + + // Map of tensor value to integer divisibility information for each dimension. + TensorDimDivisibilityInfo divisibilityInfo; + TensorDimRangeInfo rangeInfo; +}; + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp index ebbe585bf53e..218b7f5217f1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp @@ -202,13 +202,16 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, llvm::SmallDenseSet droppedLoops; for (auto [index, lb, ub, step] : llvm::enumerate(mixedLbs, mixedUbs, mixedSteps)) { - if (!isa(lb) || !isa(ub) || !isa(step)) { + + std::optional lbVal = getConstantIntValue(lb); + std::optional ubVal = getConstantIntValue(ub); + std::optional stepVal = getConstantIntValue(step); + + if (!(lbVal && ubVal && stepVal)) { continue; } - int64_t lbVal = getConstantIntValue(lb).value(); - int64_t ubVal = getConstantIntValue(ub).value(); - int64_t stepVal = getConstantIntValue(step).value(); - if (CEILDIV(ubVal - lbVal, stepVal) == 1) { + + if (CEILDIV(ubVal.value() - lbVal.value(), stepVal.value()) == 1) { droppedLoops.insert(index); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h index 13cdbf577363..0a000348e22e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h @@ -18,6 +18,17 @@ struct OneShotBufferizationOptions; namespace mlir::iree_compiler { +using IGEMMConfigFn = + std::function; +using IGEMMControlFn = std::function; + +/// Converts conv_2d ops into linalg_ext.im2col + matmul, and sets a lowering +/// configuration on the matmul. +LogicalResult convertToIGEMMAndSetConfig( + FunctionOpInterface funcOp, + std::optional configFn = std::nullopt, + std::optional controlFn = std::nullopt); + /// Eliminates tensor.empty ops to avoid buffer allocations. 
LogicalResult eliminateEmptyTensors( RewriterBase &rewriter, Operation *op, diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index fa3786caf61d..28b75d1f7ef8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -13,6 +13,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Diagnostics.h" @@ -135,6 +136,9 @@ class EnforceLayout : public DataFlowAnalysis { RegionBranchPoint branchPoint, MutableArrayRef operands); + void visitRegionBranchTerminatorOpInterface(RegionBranchOpInterface branch, + RegionBranchPoint branchPoint); + DistributionLayout *getLatticeElement(Value val); MLIRContext *ctx; @@ -205,6 +209,7 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict( if (!opOperand.get().hasOneUse() && !vectorLayout && llvm::dyn_cast_or_null( opOperand.get().getDefiningOp())) { + builder.setInsertionPoint(opOperand.get().getDefiningOp()); Operation *copiedConstOp = builder.clone(*opOperand.get().getDefiningOp()); Value copiedConst = copiedConstOp->getResult(0); builder.replaceAllUsesExcept(opOperand.get(), copiedConst, @@ -661,6 +666,9 @@ static void enforceLayoutToMultiReductionOp( ArrayRef operandLattices, ArrayRef resultLattices, std::function update) { + if (resultLattices.empty()) { + return; + } // Reductions should always propagate value layout to result. Result can // enforce it's layout on init. const DistributionLayout *result = resultLattices[0]; @@ -726,9 +734,12 @@ static void enforceLayoutToBroadcastOp( auto resultShape = broadcast.getResultVectorType().getShape(); auto inputType = broadcast.getSourceType(); - assert(isa(inputType) && - "Scalar broadcast not supported for now."); - auto inputShape = cast(inputType).getShape(); + + VectorType inputVectorType = dyn_cast(inputType); + if (!inputVectorType) + return; + + auto inputShape = inputVectorType.getShape(); SmallVector reductionMask(resultShape.size(), false); // Set the trailing dimensions to be reduced. @@ -993,6 +1004,9 @@ void EnforceLayout::visitOperation(Operation *op) { if (auto branch = dyn_cast(op)) { visitRegionSuccessors(branch, RegionBranchPoint::parent(), branch->getOpOperands()); + + // Handle the propagation from scf.for to yield op. + visitRegionBranchTerminatorOpInterface(branch, RegionBranchPoint::parent()); return; } @@ -1085,6 +1099,43 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch, } } +void EnforceLayout::visitRegionBranchTerminatorOpInterface( + RegionBranchOpInterface branch, RegionBranchPoint branchPoint) { + SmallVector successors; + branch.getSuccessorRegions(branchPoint, successors); + if (!branch.hasLoop()) + return; + SmallVector resultLattices; + for (Value result : branch->getResults()) { + DistributionLayout *resultLattice = getLatticeElement(result); + if (resultLattice->isUninitialized()) + continue; + resultLattices.push_back(resultLattice); + } + + // We do not support multiple results yet. 
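// With a single loop-carried vector result, every vector-typed yield operand
// can be resolved against resultLattices[0] without ambiguity; handling
// multiple results would require pairing each yield operand with its
// corresponding result lattice.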
+ if (resultLattices.size() != 1) + return; + + for (RegionSuccessor successor : successors) { + if (Region *succ = successor.getSuccessor()) { + Operation *terminator = succ->back().getTerminator(); + if (scf::YieldOp yieldOp = dyn_cast(terminator)) { + for (Value operand : yieldOp.getOperands()) { + if (!isa(operand.getType())) { + continue; + } + DistributionLayout *forwardLattice = getLatticeElement(operand); + ChangeResult changed = forwardLattice->resolve(resultLattices[0]); + propagateIfChanged(forwardLattice, changed); + } + } + } + } + + return; +} + DistributionLayout *EnforceLayout::getLatticeElement(Value val) { // Add dependency of operation on the analysis state. assert(isa(val.getType()) && "Lattice value should be a vector"); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index b00a94a3e4e3..ab1a76ab2fc8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( "add_fmfs.mlir", "affinemin_canonicalization.mlir", "batch_matmuls.mlir", + "block_dynamic_dims.mlir", "bubble_up_ordinal_ops.mlir", "bufferize_copy_only_dispatches.mlir", "canonicalize_interface_load_store.mlir", @@ -31,6 +32,7 @@ iree_lit_test_suite( "convolutions.mlir", "erase_dead_alloc_and_stores.mlir", "decompose_affine_ops.mlir", + "decompose_boundary_pack_unpack_ops.mlir", "decompose_conv2d.mlir", "decompose_linalg_generic.mlir", "decompose_pack_unpack_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index fb27a4be8963..3ac6423c08c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -17,6 +17,7 @@ iree_lit_test_suite( "add_fmfs.mlir" "affinemin_canonicalization.mlir" "batch_matmuls.mlir" + "block_dynamic_dims.mlir" "bubble_up_ordinal_ops.mlir" "bufferize_copy_only_dispatches.mlir" "canonicalize_interface_load_store.mlir" @@ -26,6 +27,7 @@ iree_lit_test_suite( "convolution_to_igemm.mlir" "convolutions.mlir" "decompose_affine_ops.mlir" + "decompose_boundary_pack_unpack_ops.mlir" "decompose_conv2d.mlir" "decompose_linalg_generic.mlir" "decompose_pack_unpack_ops.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir new file mode 100644 index 000000000000..819c4128a546 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir @@ -0,0 +1,101 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding], flags = Indirect> +func.func @block_attention_dims() { + %c0 = arith.constant 0 : index + %cst = arith.constant 8.837890e-02 : f16 + %m_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %k2_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index + %0:2 = util.assume.int + %m_in, + %k2_in + : index, index + %m = flow.dispatch.workload.ordinal %0#0, 0 : index + %k2 = flow.dispatch.workload.ordinal %0#1, 1 : index + %q_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) 
alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%m} + %key_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%k2} + %value_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%k2} + %mask_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%m, %k2} + %output_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) flags(Indirect) + : !flow.dispatch.tensor>{%m} + %q = flow.dispatch.tensor.load %q_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%m} -> tensor<4x?x32x128xf16> + %key = flow.dispatch.tensor.load %key_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%k2} -> tensor<4x?x32x128xf16> + %value = flow.dispatch.tensor.load %value_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%k2} -> tensor<4x?x32x128xf16> + %mask = flow.dispatch.tensor.load %mask_in, offsets = [0, 0, 0, 0], sizes = [4, 32, %m, %k2], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%m, %k2} -> tensor<4x32x?x?xf16> + %1 = tensor.empty(%m) : tensor<4x?x32x128xf16> + %2 = tensor.empty(%m) : tensor<4x32x?x128xf16> + %attn = iree_linalg_ext.attention { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} + ins(%q, %key, %value, %cst, %mask : tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, f16, tensor<4x32x?x?xf16>) + outs(%2 : tensor<4x32x?x128xf16>) { + ^bb0(%b0 : f16) : + iree_linalg_ext.yield %b0 : f16 + }-> tensor<4x32x?x128xf16> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%attn : tensor<4x32x?x128xf16>) outs(%1 : tensor<4x?x32x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x?x32x128xf16> + flow.dispatch.tensor.store %result, %output_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1] + : tensor<4x?x32x128xf16> -> !flow.dispatch.tensor>{%m} + return +} +// CHECK-LABEL: func @block_attention_dims() +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index +// CHECK-DAG: %[[M:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 0 : index +// CHECK-DAG: %[[K2:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 1 : index +// CHECK-DAG: %[[M_DYNAMIC:.+]] = arith.divui %[[M]], %[[C16]] +// CHECK: %[[Q_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(0) +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[K2_DYNAMIC:.+]] = arith.divui %[[K2]], %[[C32]] +// CHECK: %[[K_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(1) +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[V_BINDING:.+]] = hal.interface.binding.subspan 
+// CHECK-SAME: binding(2) +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[MASK_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(3) +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]} +// CHECK: %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(4) +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[Q:.+]] = flow.dispatch.tensor.load %[[Q_BINDING]] +// CHECK-SAME: sizes = [4, %[[M_DYNAMIC]], 16, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[K:.+]] = flow.dispatch.tensor.load %[[K_BINDING]] +// CHECK-SAME: sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[V:.+]] = flow.dispatch.tensor.load %[[V_BINDING]] +// CHECK-SAME: sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[MASK:.+]] = flow.dispatch.tensor.load %[[MASK_BINDING]] +// CHECK-SAME: sizes = [4, 32, %[[M_DYNAMIC]], 16, %[[K2_DYNAMIC]], 32] +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]} +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK: ins(%[[Q]], %[[K]], %[[V]], %{{.+}}, %[[MASK]] : +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]] diff --git a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir index 50ca569bc8f1..6f1cc19b452e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir @@ -71,26 +71,36 @@ func.func @dont_fold_reshape_with_not_full_load() { // ----- #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding ]> -// CHECK-LABEL: func.func @dont_fold_dynamic_reshape() -func.func @dont_fold_dynamic_reshape() { +func.func @fold_dynamic_reshape() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %dim0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index %dim1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index %dim2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor>{%dim0, %dim1} - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor>{%dim2} + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor>{%dim2} %3 = flow.dispatch.tensor.load %1, offsets=[0, 0, 0], sizes =[%dim0, %dim1, 96], strides=[1, 1, 1] : !flow.dispatch.tensor>{%dim0, %dim1} -> tensor - // CHECK: tensor.collapse_shape - // CHECK: tensor.expand_shape %4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor into tensor %dyn = tensor.dim %4, %c0 : tensor %5 = tensor.expand_shape %4 [[0], [1, 2]] output_shape [%dyn, 12, 8] : tensor into tensor - flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, 12, 8], strides = [%c1, %c1, %c1] : tensor -> !flow.dispatch.tensor>{%dim2} + flow.dispatch.tensor.store %5, %2, offsets = [0, 0, 0], sizes = [%dim2, 12, 8], strides = [1, 1, 1] : tensor -> !flow.dispatch.tensor>{%dim2} return } +// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: func.func @fold_dynamic_reshape() +// CHECK-DAG: %[[CST0:.+]] = 
hal.interface.constant.load layout(#{{.+}}) ordinal(0) +// CHECK-DAG: %[[CST1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1) +// CHECK-DAG: %[[CST2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(2) +// CHECK: %[[COLLAPSED:.+]] = affine.apply #[[MAP]]()[%[[CST0]], %[[CST1]]] +// CHECK: %[[IN_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(0) : !flow.dispatch.tensor>{%[[COLLAPSED]]} +// CHECK: %[[OUT_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(1) : !flow.dispatch.tensor>{%[[CST2]]} +// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[IN_BINDING]] +// CHECK: flow.dispatch.tensor.store %[[IN]], %[[OUT_BINDING]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir index 3d5494e79244..3373fda8c326 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir @@ -69,25 +69,6 @@ module { // ----- -#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> -#config = #iree_codegen.lowering_config -func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %cst = arith.constant 0.0 : f32 - %empty = tensor.empty() : tensor<1x14x14x16xf32> - %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, - dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) - outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - return %0 : tensor<1x14x14x16xf32> -} -// CHECK: func.func public @conv_with_lowering_config -// CHECK-NOT: iree_linalg_ext.im2col -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: lowering_config - -// ----- - #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir new file mode 100644 index 000000000000..6ff5bed59060 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir @@ -0,0 +1,201 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-decompose-boundary-pack-unpack-ops))" --split-input-file %s | FileCheck %s -check-prefixes=CHECK + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @pack_at_source() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<16x16xf32> -> tensor<4x4x4x4xf32> + flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], 
strides = [1, 1, 1, 1] : tensor<4x4x4x4xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @pack_at_source +// CHECK-NOT: tensor.pack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @unpack_at_source() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<16x16xf32> + %copy = linalg.copy ins(%unpack : tensor<16x16xf32>) outs(%dest : tensor<16x16xf32>) -> tensor<16x16xf32> + flow.dispatch.tensor.store %copy, %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @unpack_at_source +// CHECK: tensor.unpack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @pack_at_dest() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %empty = tensor.empty() : tensor<16x16xf32> + %copy = linalg.copy ins(%src : tensor<16x16xf32>) outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32> + %pack = tensor.pack %copy inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<16x16xf32> -> tensor<4x4x4x4xf32> + flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : tensor<4x4x4x4xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @pack_at_dest +// CHECK: tensor.pack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @unpack_at_dest() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<16x16xf32> + flow.dispatch.tensor.store %unpack, %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @unpack_at_dest +// CHECK-NOT: tensor.unpack + 
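A rough sketch of the boundary condition these tests exercise, using illustrative helper names (assumptions, not the pass's actual internals): a tensor.pack is a candidate for decomposition only when its source comes directly from a flow.dispatch.tensor.load, and a tensor.unpack only when every one of its users is a flow.dispatch.tensor.store; the padded cases below additionally show that ops which would need padding or slicing are kept.

// Illustrative sketch only; helper names are assumptions, not IREE internals.
// The shape checks behind @padded_pack / @padded_unpack below are omitted.
static bool isPackAtBoundary(tensor::PackOp packOp) {
  // Candidate when the pack reads straight from a dispatch input and does not
  // introduce a padding value.
  return packOp.getSource().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>() &&
         !packOp.getPaddingValue();
}
static bool isUnPackAtBoundary(tensor::UnPackOp unPackOp) {
  // Candidate when every use feeds a dispatch output directly (contrast with
  // @multi_use_unpack_no_fold below, where one use goes through linalg.copy).
  return llvm::all_of(unPackOp->getUsers(), [](Operation *user) {
    return isa<IREE::Flow::DispatchTensorStoreOp>(user);
  });
}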
+// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @padded_pack() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [15, 15], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<15x15xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %pack = tensor.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<15x15xf32> -> tensor<4x4x4x4xf32> + flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : tensor<4x4x4x4xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @padded_pack +// CHECK: tensor.pack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @padded_unpack() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [15, 15], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<15x15xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<15x15xf32> + flow.dispatch.tensor.store %unpack, %1, offsets = [0, 0], sizes = [15, 15], strides = [1, 1] : tensor<15x15xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @padded_unpack +// CHECK: tensor.unpack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @load_non_full_slice() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<16x16xf32> -> tensor<4x4x4x4xf32> + flow.dispatch.tensor.store %pack, %1, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : tensor<4x4x4x4xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @load_non_full_slice +// CHECK-NOT: tensor.pack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +func.func @store_non_full_slice() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) 
binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<16x16xf32> + flow.dispatch.tensor.store %unpack, %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @store_non_full_slice +// CHECK-NOT: tensor.unpack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +func.func @multi_use_unpack_fold() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<16x16xf32> + flow.dispatch.tensor.store %unpack, %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + flow.dispatch.tensor.store %unpack, %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @multi_use_unpack_fold +// CHECK-NOT: tensor.unpack + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +func.func @multi_use_unpack_no_fold() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %src = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [4, 4, 4, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<4x4x4x4xf32> + %dest = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %dest2 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<16x16xf32> + %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %dest : tensor<4x4x4x4xf32> -> tensor<16x16xf32> + flow.dispatch.tensor.store %unpack, %1, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> !flow.dispatch.tensor> + %copy = linalg.copy ins(%unpack : tensor<16x16xf32>) outs(%dest2 : tensor<16x16xf32>) -> tensor<16x16xf32> + flow.dispatch.tensor.store %copy, %2, offsets = [0, 0], sizes = [16, 16], strides = [1, 1] : tensor<16x16xf32> -> 
!flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @multi_use_unpack_no_fold +// CHECK: tensor.unpack diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index 7dd745e5a7c3..fc9e85e3a764 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> { %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16> @@ -14,3 +14,75 @@ func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf // CHECK: linalg.copy // CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config // CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_collapse_into_loads_dynamic() -> tensor { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%0} -> tensor<2x?x32xf32> + %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor + return %3 : tensor +} +// CHECK-LABEL: func @fold_collapse_into_loads_dynamic() +// CHECK: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[CONST]]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0], sizes = [%[[SHAPE]], 32], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%0} -> tensor<2x?x32xf32> + %3 = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%0] + %4 = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [2, %3, 16, 32] : tensor<2x?x32xf32> into tensor<2x?x16x32xf32> + return %4 : tensor<2x?x16x32xf32> +} +// CHECK-LABEL: func @fold_expand_into_loads_dynamic() +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index +// CHECK-DAG: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C16]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: 
!flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [2, %[[SHAPE]], 16, 32], strides = [1, 1, 1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) { + %c0 = arith.constant 0 : index + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags("ReadOnly|Indirect") : !flow.dispatch.tensor>{%0} + %2 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor + flow.dispatch.tensor.store %2, %1, offsets = [0, 0], sizes = [%0, 32], strides = [1, 1] + : tensor -> !flow.dispatch.tensor>{%0} + return +} +// CHECK-LABEL: func @fold_collapse_into_stores_dynamic( +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C2]] +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} +// CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir index 09c4d2787bf1..6533a09e6d5a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir @@ -464,6 +464,47 @@ builtin.module attributes { transform.with_named_sequence } { // ----- +#contract_layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [3, 2], + outer_tile = [4, 1], + thread_tile = [2, 32], + element_tile = [4, 1], + + subgroup_strides = [0, 0], + thread_strides = [32, 1] +> + +// This test ensures that we are not running into ops not dominating constantOp operands after layout analysis. +// We simulate that by doing elmentwise op on the value with "layout" i.e scaled_lhs after scaled_rhs. +// If not handled properly, will generate constOp before "scaled_lhs", but would get used also by "scaled_rhs". 
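// The fix exercised here sets the builder's insertion point to just before the
// original constant's defining op when cloning it, so the cloned constant
// dominates both of its users.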
+builtin.module attributes { transform.with_named_sequence } { + func.func @handle_multiuse_constant(%lhs: vector<96x64xf16>, %rhs: vector<96x64xf16>, %arr: memref<96x64xf16>) -> () { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<1.562500e-02> : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %lhs_layout = iree_vector_ext.to_layout %lhs to layout(#contract_layout) : vector<96x64xf16> + + %scaled_rhs = arith.mulf %rhs, %cst_1 : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %scaled_lhs = arith.mulf %lhs_layout, %cst_1 : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %add = arith.addf %scaled_lhs, %scaled_rhs : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + vector.transfer_write %add, %arr[%c0, %c0] {in_bounds = [true, true]} : vector<96x64xf16>, memref<96x64xf16> + func.return + } + + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op + transform.yield + } +} + +// ----- + #layout = #iree_vector_ext.nested_layout< subgroup_tile = [2, 1, 1], batch_tile = [1, 2, 4], @@ -521,3 +562,36 @@ builtin.module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +#layout = #iree_vector_ext.layout<<[VECTORY], [16]>, <[BATCHY, VECTORX], [2, 8]>> + +// Propagate and enforce through scf.for +builtin.module attributes { transform.with_named_sequence } { + func.func @scffor(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cst = arith.constant dense<0.000000e+00> : vector + %cst_0 = arith.constant 0.0 : f16 + + %out = scf.for %iv = %c0 to %c1024 step %c1 iter_args(%arg1 = %cst) -> (vector) { + %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>, <[ BATCHY, VECTORX], [2, 8]>>}} + %rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<16x16xf16> + %init = vector.extractelement %arg1[] : vector + %root_red = vector.multi_reduction, %rootl, %init [0, 1] : vector<16x16xf16> to f16 + %c = vector.broadcast %root_red : f16 to vector + scf.yield %c : vector + } + + func.return %out : vector + } + + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp index 33ddd044d588..7ef46a6c0d9a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp @@ -144,38 +144,19 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic, return swizzle; } -// Returns the index of the dimension whose flattened 
size (flattening inner -// dimensions into it) matches the given `targetSize`. This is used to compute -// interleaving indices. -// -// Example: -// Input shape = [16, 8, 4, 4] -// Input targetSize = 16 -// -> Return 2, because the tail of the shape starting at index 2 is [4, 4], -// whose product equals targetSize. -static int64_t -getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape, - int64_t targetSize) { - int interleaveAt = 0; - int size = 1; - for (interleaveAt = shape.size() - 1; interleaveAt >= 0; --interleaveAt) { - assert(size <= targetSize); - assert((targetSize % size) == 0); - if (size == targetSize) { - break; +static int getInnermostNonInternalDimIdx( + const TileSwizzle::ExpandShapeDimVectorType &shape) { + for (int idx = shape.size() - 1; idx >= 0; --idx) { + if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) { + return idx; } - size *= shape[interleaveAt].size; } - return interleaveAt; + assert(false && "all dimensions are internal!"); + return 0; } TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, IREE::GPU::MMAFragment fragment) { - auto [aType, bType, cType] = mma.getABCElementTypes(); - int aBits = aType.getIntOrFloatBitWidth(); - int bBits = bType.getIntOrFloatBitWidth(); - // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded. - const int targetPreferredLoadBitWidth = 128; auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment); using Kind = TileSwizzle::Dim::Kind; switch (fragment) { @@ -184,9 +165,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, // Unroll on K with interleaving, then on M. if (mma.getUnrollK() > 1) { unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); - int interleavingIdx = getDimIdxForTargetSize( - swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * aBits)); + int interleavingIdx = + getInnermostNonInternalDimIdx(swizzle.expandShape[1]); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollM() > 1) { @@ -202,9 +182,8 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, // Unroll on K with interleaving, then on N. 
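// As in the LHS case above, the interleaving position is now the innermost
// expanded dimension that is not Internal to the intrinsic, instead of being
// derived from a hard-coded 128-bit preferred load width.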
if (mma.getUnrollK() > 1) { unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); - int interleavingIdx = getDimIdxForTargetSize( - swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * bBits)); + int interleavingIdx = + getInnermostNonInternalDimIdx(swizzle.expandShape[1]); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollN() > 1) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 45b3f4baae64..41c099f12809 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -212,7 +212,9 @@ getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) { static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, MMAIntrinsic type) { Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); + Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context); Type f16 = Float16Type::get(context); + Type bf16 = BFloat16Type::get(context); Type f32 = Float32Type::get(context); Type i8 = IntegerType::get(context, 8); @@ -228,9 +230,18 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::MFMA_F32_32x32x8_F16: { return OpaqueMmaLayout{32, 32, 8, f16, f16, f32}; } + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { + return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { + return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32}; + } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { + return OpaqueMmaLayout{16, 16, 32, f8E5M2FNUZ, f8E5M2FNUZ, f32}; + } case MMAIntrinsic::MFMA_I32_16x16x32_I8: { return OpaqueMmaLayout{16, 16, 32, i8, i8, i32}; } @@ -332,6 +343,45 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, bNLayout, cMLayout, cNLayout}; } + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { + // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> + // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]> + // #layout_a = #iree_vector_ext.layout<#outer, #inner> + // #layout_b = #iree_vector_ext.layout<#inner, #outer> + // #layout_c = #iree_vector_ext.layout<#inner, #outer> + + auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); + auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); + auto aMLayout = outer; + auto aKLayout = inner; + auto bKLayout = inner; + auto bNLayout = outer; + auto cMLayout = inner; + auto cNLayout = outer; + return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, + bNLayout, cMLayout, cNLayout}; + } + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { + // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]> + // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]> + // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX], + // [4, 2, 4]> + // #layout_a = #iree_vector_ext.layout<#outer, #inner1> + // #layout_b = #iree_vector_ext.layout<#inner1, #outer> + // #layout_c = #iree_vector_ext.layout<#inner2, #outer> + + auto outer = PerDimLayoutAttr::get(context, {laneX}, {32}); + auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4}); + auto aMLayout = outer; + auto aKLayout = inner; + auto bKLayout = inner; + auto bNLayout = outer; + auto cMLayout = + 
PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4}); + auto cNLayout = outer; + return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, + bNLayout, cMLayout, cNLayout}; + } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> @@ -458,20 +508,23 @@ MMAAttr::getABCVectorTypes() const { return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_I32_16x16x16_I8: - case MMAIntrinsic::MFMA_F32_16x16x16_F16: { + case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { auto aType = VectorType::get({4}, getAType()); auto bType = VectorType::get({4}, getBType()); auto cType = VectorType::get({4}, getCType()); return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_I32_32x32x8_I8: - case MMAIntrinsic::MFMA_F32_32x32x8_F16: { + case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { auto aType = VectorType::get({4}, getAType()); auto bType = VectorType::get({4}, getBType()); auto cType = VectorType::get({16}, getCType()); return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); @@ -514,10 +567,13 @@ int64_t MMAAttr::getBlockSize() const { switch (getIntrinsic().getValue()) { case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: case MMAIntrinsic::WMMA_F16_16x16x16_F16: @@ -534,10 +590,13 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { switch (intrinsic) { case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { return 64; @@ -577,6 +636,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: switch (fragment) { case MMAFragment::Lhs: return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, @@ -590,6 +650,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, } case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: switch (fragment) { case MMAFragment::Lhs: return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, @@ -602,6 +663,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, /*element=*/{4, 1}}; } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case 
MMAIntrinsic::MFMA_I32_16x16x32_I8: switch (fragment) { case MMAFragment::Lhs: @@ -696,9 +758,12 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto [m, n, k] = getMNKShape(); @@ -1280,9 +1345,6 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, llvm::errs() << "Getting mma layouts for:\n" << contractOp << "\n"; llvm::errs() << "For schedule: " << *this << "\n"; }); - if (opInfo.getKDims().size() != 1) { - return contractOp->emitError("Unimplemented: > 1 k dims"); - } int64_t rank = contractOp.getIteratorTypesArray().size(); auto mmaAttr = llvm::cast(getIntrinsic()); @@ -1450,6 +1512,10 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, aSubgroupSizes[dim] = subgroupMBasis[i]; aSubgroupStrides[dim] = subgroupMStrides[i]; } + for (auto [kDim, lhsKDim] : + llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) { + aBatchSizes[lhsKDim] = bounds[kDim]; + } aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK; auto aLayout = createNestedLayout(context, aRank, afm, afk, @@ -1470,6 +1536,10 @@ MMAScheduleAttr::getContractionLayout(VectorContractOpInfo &opInfo, bSubgroupSizes[dim] = subgroupNBasis[i]; bSubgroupStrides[dim] = subgroupNStrides[i]; } + for (auto [kDim, rhsKDim] : + llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) { + bBatchSizes[rhsKDim] = bounds[kDim]; + } bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK; auto bLayout = createNestedLayout(context, bRank, bfk, bfn, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index 1f2cad748c08..d04e9fefe5b9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -333,11 +333,17 @@ def IREEGPU_TargetWgpAttr : AttrDef { // The maximal number of threads per X/Y/Z dimension in one workgroup. "DenseI32ArrayAttr":$max_workgroup_sizes, // The maximal number of threads we can have in one workgroup. - "uint32_t":$max_thread_count_per_workgroup, + "int32_t":$max_thread_count_per_workgroup, // The maximal number of shared memory bytes we can allocate per workgroup. - "uint32_t":$max_workgroup_memory_bytes, - // Tthe maximum number of workgroups per X/Y/Z dimension in a dispatch. + "int32_t":$max_workgroup_memory_bytes, + // The maximum number of workgroups per X/Y/Z dimension in a dispatch. "DenseI32ArrayAttr":$max_workgroup_counts, + // Max load instruction size in bits. TODO(#18849): populate on all GPUs. + OptionalParameter<"std::optional">:$max_load_instruction_bits, + // Number of SIMDs per workgroup processor. TODO(#18849): populate on all GPUs. + OptionalParameter<"std::optional">:$simds_per_wgp, + // VGPR register space size in bits. TODO(#18849): populate on all GPUs. 
+ OptionalParameter<"std::optional">:$vgpr_space_bits, // An optional extra dict // This field allows to inject more features/limits not supported in the diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index d1a91597c79b..9d4ac2e9a4e1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -99,31 +99,56 @@ class IREEGPU_I32MmaEnumAttr } // Format: __xx_ -def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>; -def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>; -def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>; -def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>; -def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>; -def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>; +// Values: 0xABCD where: +// * A = vendor: +// * 0 = AMD +// * 1 = NVIDIA +// * B is architecture: +// * For AMD: +// * 0 = RDNA3 +// * 8 = CDNA2 +// * 9 = CDNA3 +// * C is A/B data type: +// * 0 = f32 +// * 1 = f16 +// * 2 = bf16 +// * 3 = f8e5m2 (and variants like fnuz). +// * 4 = f8e4m3 (and variants like fnuz). +// * 8 = i8 +// * D enumerates intrinsics for the same data type. +// +// CDNA3 instrinsics +def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>; +def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>; +def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>; +def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>; +def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>; +def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>; +def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>; +def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>; +def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>; // CDNA2 instrinsics -def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 6>; -def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 7>; +def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x0880>; +def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x0881>; // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs -def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 8>; -def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 9>; +def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x0010>; +def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x0011>; // TODO: The actual I8 instruction allows specifying (mixed) signedness. // This will need to become its own class of MMA attribute. 
-def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 10>; +def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x0080>; def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", "Descriptor for different MMA intrinsics", [ MFMA_F32_16x16x4_F32, MFMA_F32_16x16x16_F16, MFMA_F32_32x32x8_F16, + MFMA_F32_16x16x16_BF16, + MFMA_F32_32x32x8_BF16, MFMA_F32_16x16x32_F8E4M3FNUZ, + MFMA_F32_16x16x32_F8E5M2FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, MFMA_I32_16x16x16_I8, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 611a87454ecf..58bfdc0a028b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, return failure(); } - // For now we are not being smart and trying to reshape dimensions to allow - // for better usage of intrinsics, and instead are tiling all dimensions - // except the inner most m, n, and k dimensions to 1. - int64_t mDim = contractionDims.m.back(); - int64_t nDim = contractionDims.n.back(); - int64_t kDim = contractionDims.k.back(); - - // Dynamic dims are expected to be taken care of earlier in the pipeline. - if (ShapedType::isDynamic(bounds[mDim]) || - ShapedType::isDynamic(bounds[nDim]) || - ShapedType::isDynamic(bounds[kDim])) { + // TODO(Max191): add dynamic shape support for inner most dims. + if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) || + ShapedType::isDynamic(bounds[contractionDims.n.back()]) || + ShapedType::isDynamic(bounds[contractionDims.k.back()])) { return failure(); } + // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic + // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when + // computing an MMA schedule. 
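A standalone mock (not the actual IREE code) of the idea in the comment above and in the hunk that follows: only statically sized M/N/K dimensions are kept when deducing the MMA schedule, and the flat workgroup size becomes the subgroup size times the product of the M and N subgroup counts. Dynamic sizes are modeled here as -1 in place of ShapedType::kDynamic; the bounds, dimension indices, and schedule numbers are invented for illustration.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

static int64_t product(const std::vector<int64_t> &values) {
  int64_t result = 1;
  for (int64_t value : values)
    result *= value;
  return result;
}

int main() {
  // Hypothetical contraction: n covers loop dims {1, 2}, and dim 1 is dynamic
  // (-1 stands in for ShapedType::kDynamic).
  std::vector<int64_t> bounds = {128, -1, 64, 256};
  std::vector<int64_t> nDims = {1, 2};
  std::vector<int64_t> staticNDims;
  for (int64_t dim : nDims) {
    if (bounds[dim] != -1) // dynamic dims are tiled to 1 later and ignored here
      staticNDims.push_back(dim);
  }

  // Hypothetical schedule: 2 subgroups on M, 2 on N, subgroup size 64.
  const int64_t targetSubgroupSize = 64;
  std::vector<int64_t> mSubgroupCounts = {2};
  std::vector<int64_t> nSubgroupCounts = {2};
  const int64_t flatWorkgroupSize =
      targetSubgroupSize * product(mSubgroupCounts) * product(nSubgroupCounts);

  std::printf("static n dims kept: %zu, flat workgroup size: %lld\n",
              staticNDims.size(), (long long)flatWorkgroupSize);
  return 0;
}
```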
+ SmallVector mDims, nDims, kDims; + for (auto mDim : contractionDims.m) { + if (!ShapedType::isDynamic(bounds[mDim])) { + mDims.push_back(mDim); + } + } + for (auto nDim : contractionDims.n) { + if (!ShapedType::isDynamic(bounds[nDim])) { + nDims.push_back(nDim); + } + } + for (auto kDim : contractionDims.k) { + if (!ShapedType::isDynamic(bounds[kDim])) { + kDims.push_back(kDim); + } + } + + auto getDimBounds = [&](SmallVector dims) -> SmallVector { + return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; }); + }; + Value lhs = linalgOp.getDpsInputOperand(0)->get(); Value rhs = linalgOp.getDpsInputOperand(1)->get(); Value init = linalgOp.getDpsInitOperand(0)->get(); @@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); - GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], - lhsElemType, rhsElemType, initElemType}; + GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims), + getDimBounds(kDims), lhsElemType, + rhsElemType, initElemType}; SmallVector intrinsics; for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { @@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= 512 * 512) { + int64_t mSize = ShapedType::getNumElements(problem.mSizes); + int64_t nSize = ShapedType::getNumElements(problem.nSizes); + if (mSize * nSize <= 512 * 512) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // TODO: Drop this. This is only a consideration for other pipelines. SmallVector maps = linalgOp.getIndexingMapsArray(); bool transposedLhs = - kDim != + kDims.back() != llvm::cast(maps[0].getResults().back()).getPosition(); bool transposedRhs = - nDim != + nDims.back() != llvm::cast(maps[1].getResults().back()).getPosition(); // First try to find a schedule with an exactly matching intrinsic. @@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + int64_t flatWorkgroupSize = + targetSubgroupSize * + ShapedType::getNumElements(schedule->nSubgroupCounts) * + ShapedType::getNumElements(schedule->mSubgroupCounts); + std::array workgroupSize{flatWorkgroupSize, 1, 1}; SmallVector workgroupTileSizes(linalgOp.getNumLoops(), 0); SmallVector reductionTileSizes(linalgOp.getNumLoops(), 0); @@ -244,18 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, reductionTileSizes[k] = 1; } - // Compute the M/N dimension tile size by multiplying subgroup information. 
- workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; - workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; - - // Specify the subgroup tile sizes from the mma schedule. This is applied - subgroupTileSizes[mDim] = schedule->mTileCount; - subgroupTileSizes[nDim] = schedule->nTileCount; + // Adjust the inner bound size for packing to intrinsic shapes, since tiling + // happens after packing. + assert(bounds[mDims.back()] % schedule->mSize == 0 && + bounds[nDims.back()] % schedule->nSize == 0 && + "expected inner bound to be evenly divisible by schedule sizes."); + bounds[mDims.back()] /= schedule->mSize; + bounds[nDims.back()] /= schedule->nSize; + + // Compute the M/N dimension tile sizes by multiplying subgroup information. + for (auto [i, mDim] : llvm::enumerate(mDims)) { + workgroupTileSizes[mDim] = + schedule->mSubgroupCounts[i] * schedule->mTileSizes[i]; + subgroupTileSizes[mDim] = schedule->mTileSizes[i]; + } + for (auto [i, nDim] : llvm::enumerate(nDims)) { + workgroupTileSizes[nDim] = + schedule->nSubgroupCounts[i] * schedule->nTileSizes[i]; + subgroupTileSizes[nDim] = schedule->nTileSizes[i]; + } // Similarly the reduction tile size is just the post-packing tile count. - reductionTileSizes[kDim] = schedule->kTileCount; + for (auto [i, kDim] : llvm::enumerate(kDims)) { + reductionTileSizes[kDim] = schedule->kTileSizes[i]; + } IREE::GPU::MmaInterfaceAttr mmaKind = target.getWgp().getMma()[schedule->index]; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index ef04a2282c5e..5e8f031ff8ac 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -43,9 +43,12 @@ struct WgpDetails { // modes. Use duplicated values if the GPU only have one subgroup size. 
std::array subgroupSizeChoices; std::array maxWorkgroupSizes; - uint32_t maxThreadSize; - uint32_t maxWorkgroupMemoryBytes; + int32_t maxThreadSize; + int32_t maxWorkgroupMemoryBytes; std::array maxWorkgroupCounts; + std::optional maxLoadInstructionBits; + std::optional simdsPerWgp; + std::optional vgprSpaceBits; }; // Chip level feature/limit details @@ -109,6 +112,7 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes), wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts), + wgp->maxLoadInstructionBits, wgp->simdsPerWgp, wgp->vgprSpaceBits, DictionaryAttr{}); TargetChipAttr targetChip; @@ -132,7 +136,10 @@ const WgpDetails *getCDNA3WgpDetails() { MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16, MMAIntrinsic::MFMA_F32_32x32x8_F16, + MMAIntrinsic::MFMA_F32_16x16x16_BF16, + MMAIntrinsic::MFMA_F32_32x32x8_BF16, MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ, MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8, }; @@ -146,7 +153,10 @@ const WgpDetails *getCDNA3WgpDetails() { {1024, 1024, 1024}, 1024, 64 * 1024, - {0x7fffffff, 0x7fffffff, 0x7fffffff}}; + {0x7fffffff, 0x7fffffff, 0x7fffffff}, + /*maxLoadInstructionBits=*/128, + /*simdsPerWgp=*/4, + /*vgprSpaceBits=*/512 * 32}; return &cdna3Wgp; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 6f9983454af5..5111b7668958 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -792,11 +792,8 @@ static int getRegisterSpaceBitsIfKnown(IREE::HAL::ExecutableTargetAttr target) { return 16 * 128; } } else if (isAArch64(target)) { - // Can't determine register space size at compile time on SVE. - if (hasFeature(target, "+sve") || hasFeature(target, "+sve2")) { - return 0; - } - // 32 NEON registers (128-bit each). + // 32 NEON/SVE registers (at least 128-bit each, returns the base size for + // SVE). return 32 * 128; } else { // Don't know register space size as a compile-time constant on other diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp index 8a2e91c6a646..7bfe586beec5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp @@ -19,7 +19,8 @@ namespace { struct LLVMCPULinkExecutablesPass : public impl::LLVMCPULinkExecutablesPassBase { - LLVMCPULinkExecutablesPass() = default; + using impl::LLVMCPULinkExecutablesPassBase< + LLVMCPULinkExecutablesPass>::LLVMCPULinkExecutablesPassBase; void runOnOperation() override { auto moduleOp = getOperation(); auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); @@ -30,29 +31,36 @@ struct LLVMCPULinkExecutablesPass return; // Guess a module name, if needed, to make the output files readable. - auto moduleName = guessModuleName(moduleOp, "llvm_module"); + auto moduleName = guessModuleName(moduleOp, "module"); // Create our new "linked" hal.executable. 
- std::string linkedExecutableName = - llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu"); + SymbolTable moduleTable(moduleOp); + std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName); auto linkedExecutableOp = moduleBuilder.create( moduleOp.getLoc(), linkedExecutableName); linkedExecutableOp.setVisibility( sourceExecutableOps.front().getVisibility()); + moduleTable.insert(linkedExecutableOp); auto executableBuilder = OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); // Gather all unique executable targets - we may have multiple. auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); - for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) { + for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) { + // Only link the target specified. If none specified link all. + if (!target.empty() && targetAttr.getBackend().getValue() != target) { + continue; // not linking this target + } + // Add our hal.executable.variant with an empty module. std::string linkedVariantName = executableTargetAttrs.size() == 1 - ? attr.getSymbolNameFragment() - : llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index); + ? targetAttr.getSymbolNameFragment() + : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(), + index); auto linkedTargetOp = executableBuilder.create( - moduleOp.getLoc(), linkedVariantName, attr); + moduleOp.getLoc(), linkedVariantName, targetAttr); auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); targetBuilder.create(moduleOp.getLoc()); @@ -71,5 +79,6 @@ struct LLVMCPULinkExecutablesPass } } }; + } // namespace } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 0951fbba4273..9ef65e28e94f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -95,7 +95,7 @@ static llvm::cl::opt clEnableVectorContractCustomKernels( static llvm::cl::opt clTileDispatchUsingForall( "iree-llvmcpu-tile-dispatch-using-forall", llvm::cl::desc("Enable tile and distribute to workgroups using scf.forall"), - llvm::cl::init(false)); + llvm::cl::init(true)); // By default, IREE does not enable the Armv9-A streaming SVE mode in the // presence of scalable vectors (even when using `+sme`), as currently there's @@ -111,9 +111,8 @@ static llvm::cl::opt clForceArmStreaming( llvm::cl::init(false)); // TODO: Enable `TileDispatchUsingForall` for every pipeline. -static void addTileAndDistributePasses(OpPassManager &funcPassManager, - bool enableTileDispatchUsingForall) { - if (enableTileDispatchUsingForall || clTileDispatchUsingForall) { +static void addTileAndDistributePasses(OpPassManager &funcPassManager) { + if (clTileDispatchUsingForall) { funcPassManager.addPass( createTileAndDistributeToWorkgroupsUsingForallOpPass()); } else { @@ -346,8 +345,7 @@ void buildLLVMCPUVectorLoweringPipeline( void addCPUBufferOpsTileAndVectorizePipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // Skip tiling reduction loops because this is expected to apply on copy ops // only. 
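For context on the linking changes above: the pass now names the merged executable "<module>_linked" and only appends an index to variant names when more than one executable target is being linked. A small illustrative sketch of that naming scheme follows; the target symbol fragments are hypothetical examples, and "module" is just the fallback name the pass guesses when none is present.

```cpp
#include <cstdio>
#include <string>
#include <vector>

int main() {
  const std::string moduleName = "module"; // guessed fallback module name
  const std::string linkedExecutableName = moduleName + "_linked";

  // Hypothetical target symbol fragments. With a single target the variant
  // keeps its fragment; with several, an index suffix keeps names unique.
  const std::vector<std::string> targetFragments = {"embedded_elf_x86_64",
                                                    "embedded_elf_arm_64"};
  for (size_t index = 0; index < targetFragments.size(); ++index) {
    const std::string linkedVariantName =
        targetFragments.size() == 1
            ? targetFragments[index]
            : targetFragments[index] + "_" + std::to_string(index);
    std::printf("%s::%s\n", linkedExecutableName.c_str(),
                linkedVariantName.c_str());
  }
  return 0;
}
```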
@@ -384,8 +382,7 @@ void addCPUBufferOpsTileAndVectorizePipeline( void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); SmallVector allFusableLevels(tilingConfig.getFusableLevels()); // Apply tile and fuse to all the non-distribution fusable levels. Skip @@ -464,8 +461,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager, void addConvTileAndDecomposeExpertPassPipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // Run LLVMTileAndFuse firstly in case that we have fill + conv + generic // ops. At this stage, we do not apply vectorization. The reduction dim won't @@ -528,8 +524,7 @@ void addConvTileAndDecomposeExpertPassPipeline( void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); funcPassManager.addPass(createLLVMCPUTileAndFusePass( static_cast(tilingConfig.getVectorCommonParallelLevel()))); @@ -577,8 +572,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager, void addCPUDataTilingPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // The below two passes are nop if pack/unpack is not specified in ukernels // attribute. By default, they are disabled. @@ -621,8 +615,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager, void addCPULinalgExtTileAndVectorizePipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/false); + addTileAndDistributePasses(funcPassManager); funcPassManager.addPass( createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel())); // TODO: Remove the pass once we have PartialReductionOpInterface implemented @@ -661,8 +654,7 @@ void addCPULinalgExtTileAndVectorizePipeline( } void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/false); + addTileAndDistributePasses(funcPassManager); addCPUBufferizePasses(funcPassManager); } @@ -835,9 +827,12 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager, // NOTE: this runs on the top-level program module containing all // hal.executable ops. -void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager) { +void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager, + std::optional target) { // Link together executables. This may produce some IR duplication. - modulePassManager.addPass(createLLVMCPULinkExecutablesPass()); + LLVMCPULinkExecutablesPassOptions linkOptions; + linkOptions.target = target.value_or(""); + modulePassManager.addPass(createLLVMCPULinkExecutablesPass(linkOptions)); // Cleanup IR duplication. 
modulePassManager.addNestedPass( diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 42d4035260db..4696bc808118 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -12,6 +12,8 @@ #ifndef IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_ #define IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_ +#include + #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "mlir/Pass/Pass.h" @@ -156,7 +158,9 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager, //----------------------------------------------------------------------------// /// Populates passes needed to link HAL executables across LLVMCPU targets. -void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager); +void buildLLVMCPULinkingPassPipeline( + OpPassManager &modulePassManager, + std::optional target = std::nullopt); //----------------------------------------------------------------------------// // Register LLVMCPU Passes diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index c9aec6740923..12f90be95ee0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -69,6 +69,13 @@ def LLVMCPUEmitVectorizationRemarksPass : def LLVMCPULinkExecutablesPass : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; + let options = [ + Option< + "target", "target", + "std::string", "", + "Target backend name whose executables will be linked by this pass." + >, + ]; } def LLVMCPULowerExecutableTargetPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir index 757a039ed119..1308442f23bf 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir @@ -28,7 +28,7 @@ func.func @matmul_tensors() attributes {hal.executable.target = #executable_targ return } -// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @matmul_tensors() // CHECK-SAME: translation_info = #[[TRANSLATION]] @@ -118,7 +118,7 @@ func.func @matmul_tensors() attributes {hal.executable.target = #executable_targ return } -// DISABLE-ARM-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// DISABLE-ARM-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config // DISABLE-ARM-SME-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // DISABLE-ARM-SME: func.func @matmul_tensors() // DISABLE-ARM-SME-SAME: translation_info = #[[TRANSLATION]] @@ -179,8 +179,8 @@ func.func @matmul_with_fill() attributes {hal.executable.target = #executable_ta return } -// CHECK-DAG: #[[CONFIG1:.+]] = #iree_codegen.lowering_config -// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG1:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[CONFIG2:.+]] = #iree_codegen.lowering_config // CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @matmul_with_fill() // CHECK-SAME: translation_info = 
#[[TRANSLATION]] @@ -217,3 +217,34 @@ func.func @depthwise_conv() attributes {hal.executable.target = #executable_targ // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: linalg.depthwise_conv_2d_nhwc_hwc // CHECK-SAME: lowering_config = #[[CONFIG]] + +// ----- + +// Regression test. SVE isn't used (scalable vectorizaton of this op is not yet +// supported), but used to fail to compile when SVE was enabled due to tile +// sizes leading to large vectors. + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +#executable_target_embedded_elf_arm_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu_features = "+sve", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "aarch64-none-elf"}> +func.func @pooling_nchw_max(%arg0: !flow.dispatch.tensor>, %arg1: !flow.dispatch.tensor>) attributes {hal.executable.target = #executable_target_embedded_elf_arm_64_} { + %cst = arith.constant 0.0 : f32 + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor> + %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 64, 114, 114], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x64x114x114xf32> + %3 = tensor.empty() : tensor<1x64x56x56xf32> + %4 = tensor.empty() : tensor<3x3xf32> + %5 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32> + %6 = linalg.pooling_nchw_max {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%2, %4 : tensor<1x64x114x114xf32>, tensor<3x3xf32>) outs(%3 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32> + flow.dispatch.tensor.store %6, %1, offsets = [0, 0, 0, 0], sizes = [1, 64, 56, 56], strides = [1, 1, 1, 1] : tensor<1x64x56x56xf32> -> !flow.dispatch.tensor> + return +} + +// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config +// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info +// CHECK: func.func @pooling_nchw_max +// CHECK-SAME: translation_info = #[[TRANSLATION]] +// CHECK: linalg.pooling_nchw_max +// CHECK-SAME: lowering_config = #[[CONFIG]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index b074612adbc5..19af0c4155c3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -91,10 +91,13 @@ iree_compiler_cc_library( "ConvertToROCDL.cpp", "ExtractAddressComputationGPUPass.cpp", "KernelConfig.cpp", + "LLVMGPUAssignConstantOrdinals.cpp", "LLVMGPUCastAddressSpaceFunction.cpp", "LLVMGPUCastTypeToFitMMA.cpp", "LLVMGPUConfigureTensorLayouts.cpp", "LLVMGPUConfigureVectorLayouts.cpp", + "LLVMGPUConvolutionToIGEMM.cpp", + "LLVMGPULinkExecutables.cpp", "LLVMGPULowerExecutableTarget.cpp", "LLVMGPUPackSharedMemoryAlloc.cpp", "LLVMGPUPrefetching.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 6a92f60d7f04..aa2c5a56bea5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -76,10 +76,13 @@ iree_cc_library( "ConvertToROCDL.cpp" "ExtractAddressComputationGPUPass.cpp" "KernelConfig.cpp" + "LLVMGPUAssignConstantOrdinals.cpp" "LLVMGPUCastAddressSpaceFunction.cpp" "LLVMGPUCastTypeToFitMMA.cpp" 
"LLVMGPUConfigureTensorLayouts.cpp" "LLVMGPUConfigureVectorLayouts.cpp" + "LLVMGPUConvolutionToIGEMM.cpp" + "LLVMGPULinkExecutables.cpp" "LLVMGPULowerExecutableTarget.cpp" "LLVMGPUPackSharedMemoryAlloc.cpp" "LLVMGPUPrefetching.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ff002ace5b0f..0d9c7f9ad2e6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -301,6 +301,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -339,8 +344,9 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, return failure(); } - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -360,11 +366,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, } // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; // Tile all filter loop dimensions to 1. for (int64_t filterDim : convolutionDims->filterLoop) { @@ -386,8 +392,8 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -489,6 +495,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]); } + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. 
GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -509,7 +520,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= clGPUMatmulCThreshold) { + if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -573,16 +584,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -605,11 +611,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(), *contractionDims, workgroupTileSizes)); @@ -631,8 +637,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -772,22 +778,17 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // TODO: Due to a bug in layout configuration, we cannot set warp count on // the N dimension. This is however ok, because we generally do not want to // distribute subgroups on N dimension anyway. 
- if (schedule->nWarpCount != 1) { - schedule->nTileCount *= schedule->nWarpCount; - schedule->nWarpCount = 1; + if (schedule->nSubgroupCounts[0] != 1) { + schedule->nTileSizes[0] *= schedule->nSubgroupCounts[0]; + schedule->nSubgroupCounts[0] = 1; } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(opInfo.getDomainRank(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -811,11 +812,11 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize; MLIRContext *context = op.getContext(); SmallVector attrs; @@ -824,15 +825,33 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, attrs.emplace_back(StringAttr::get(context, "reduction"), b.getI64ArrayAttr(reductionTileSizes)); - auto configDict = DictionaryAttr::get(context, attrs); + SmallVector qkAttrs; + SmallVector pvAttrs; + + qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr())); + pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr())); + + auto qkAttrDict = b.getDictionaryAttr(qkAttrs); + auto pvAttrDict = b.getDictionaryAttr(pvAttrs); + + SmallVector decompositionConfig; + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict)); + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict)); + + DictionaryAttr decompositionConfigDict = + b.getDictionaryAttr(decompositionConfig); + + auto configDict = b.getDictionaryAttr(attrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); // Attach the MMA schedule as an attribute to the entry point export function // for later access in the pipeline. SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -842,6 +861,9 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs); + // Set attention decomposition control config. 
+ op.setDecompositionConfigAttr(decompositionConfigDict); + return setOpConfigAndEntryPointFnTranslation( entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute, workgroupSize, targetSubgroupSize, pipelineConfig); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp new file mode 100644 index 000000000000..c789b92a644c --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp @@ -0,0 +1,53 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPUASSIGNCONSTANTORDINALSPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +struct LLVMGPUAssignConstantOrdinalsPass + : public impl::LLVMGPUAssignConstantOrdinalsPassBase< + LLVMGPUAssignConstantOrdinalsPass> { + void runOnOperation() override { + auto variantOp = getOperation(); + + // Get a constant key -> ordinal mapping. + auto keyOrdinals = variantOp.gatherConstantOrdinals(); + if (keyOrdinals.empty()) + return; + + // Update placeholders to hold the concrete ordinal values. + // Eventually MLIR or LLVM will inline them. + auto moduleOp = variantOp.getInnerModule(); + for (auto globalOp : + llvm::make_early_inc_range(moduleOp.getOps())) { + auto keyAttr = globalOp->getAttr( + IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); + if (!keyAttr) + continue; + auto it = keyOrdinals.find(keyAttr); + if (it == keyOrdinals.end()) { + globalOp.emitOpError() + << "no constant block providing key '" << keyAttr << "'"; + return signalPassFailure(); + } + globalOp->removeAttr( + IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); + globalOp.setConstantAttr(UnitAttr::get(globalOp.getContext())); + globalOp.setValueAttr(IntegerAttr::get( + IntegerType::get(globalOp.getContext(), 32), it->second)); + } + } +}; +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp new file mode 100644 index 000000000000..b88696ab8f63 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp @@ -0,0 +1,66 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +/// Function for setting lowering configurations on contractions resulting from +/// the IGEMM transformation. 
This currently uses the TileAndFuse pipeline, and +/// tries to target MMA intrinsics. +static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp, + IREE::LinalgExt::Im2colOp im2colOp) { + auto funcOp = genericOp->getParentOfType(); + if (!funcOp) { + return genericOp.emitError("cannot find parent funcOp"); + } + IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); + if (!target) { + return funcOp.emitError("missing GPU target in parent funcOp"); + } + if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) { + return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp); + } + return success(); +} + +static bool llvmgpuControlFn(Operation *op) { + // Do not convert anything that already has a lowering configuration. + if (getLoweringConfig(op)) { + return false; + } + return true; +} + +struct LLVMGPUConvolutionToIGEMMPass final + : impl::LLVMGPUConvolutionToIGEMMPassBase { + using impl::LLVMGPUConvolutionToIGEMMPassBase< + LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase; + + void runOnOperation() override; +}; + +void LLVMGPUConvolutionToIGEMMPass::runOnOperation() { + if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn, + llvmgpuControlFn))) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp new file mode 100644 index 000000000000..5ffaff984b98 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp @@ -0,0 +1,123 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "iree/compiler/Codegen/Utils/LinkingUtils.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPULINKEXECUTABLESPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +// Returns true if the address space of a global symbol is private to the module +// scope it originates in. AMD and NVIDIA disagree on the naming but the values +// match. LLVM is a mess here. +static bool isSymbolAddressSpacePrivate(uint32_t addressSpace) { + return addressSpace == /*local*/ 3 || addressSpace == /*private*/ 5; +} + +static SymbolTable::Visibility +convertLinkageToVisibility(LLVM::Linkage linkage) { + switch (linkage) { + case LLVM::Linkage::Private: + return SymbolTable::Visibility::Private; + case LLVM::Linkage::External: + return SymbolTable::Visibility::Public; + default: + return SymbolTable::Visibility::Public; + } +} + +// Returns true if we are allowed to rename |op| as part of merging. +// The LLVMGPU lowering is super careful about assigning linkage so we err on +// the side of renaming (as 100% of usage today does not reference external +// things). 
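The helper defined just below applies that renaming policy. As a standalone illustration of the address-space check it relies on (defined earlier in this new file): LDS/"local" is address space 3 and scratch/"private" is address space 5, and those values agree across AMD and NVIDIA even though the names differ.

```cpp
#include <cstdint>
#include <cstdio>

// Same predicate as in the new pass: LDS/"local" lives in address space 3 and
// scratch/"private" in address space 5 on both vendors.
static bool isSymbolAddressSpacePrivate(uint32_t addressSpace) {
  return addressSpace == /*local*/ 3 || addressSpace == /*private*/ 5;
}

int main() {
  for (uint32_t addressSpace : {0u, 1u, 3u, 5u}) {
    std::printf("address space %u is module-private: %s\n", addressSpace,
                isSymbolAddressSpacePrivate(addressSpace) ? "yes" : "no");
  }
  return 0;
}
```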
+static bool allowRenamingPrivateLLVMSymbols(Operation *op) { + if (auto globalOp = dyn_cast(op)) { + if (isSymbolAddressSpacePrivate(globalOp.getAddrSpace())) { + return true; + } + return convertLinkageToVisibility(globalOp.getLinkage()) == + SymbolTable::Visibility::Private; + } else if (auto funcOp = dyn_cast(op)) { + return convertLinkageToVisibility(funcOp.getLinkage()) == + SymbolTable::Visibility::Private; + } + return SymbolTable::getSymbolVisibility(op) == + SymbolTable::Visibility::Private; +} + +struct LLVMGPULinkExecutablesPass + : public impl::LLVMGPULinkExecutablesPassBase { + using impl::LLVMGPULinkExecutablesPassBase< + LLVMGPULinkExecutablesPass>::LLVMGPULinkExecutablesPassBase; + void runOnOperation() override { + auto moduleOp = getOperation(); + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto sourceExecutableOps = + llvm::to_vector<8>(moduleOp.getOps()); + if (sourceExecutableOps.size() <= 1) + return; + + // Guess a module name, if needed, to make the output files readable. + auto moduleName = guessModuleName(moduleOp, "module"); + + // Create our new "linked" hal.executable. + SymbolTable moduleTable(moduleOp); + std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName); + auto linkedExecutableOp = moduleBuilder.create( + moduleOp.getLoc(), linkedExecutableName); + linkedExecutableOp.setVisibility( + sourceExecutableOps.front().getVisibility()); + moduleTable.insert(linkedExecutableOp); + auto executableBuilder = + OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); + + // Gather all unique executable targets - we may have multiple. + auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); + for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) { + // Only link the target specified. If none specified link all. + if (!target.empty() && targetAttr.getBackend().getValue() != target) { + continue; // not linking this target + } + + // Add our hal.executable.variant with an empty module. + std::string linkedVariantName = + executableTargetAttrs.size() == 1 + ? targetAttr.getSymbolNameFragment() + : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(), + index); + auto linkedTargetOp = + executableBuilder.create( + moduleOp.getLoc(), linkedVariantName, targetAttr); + auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); + targetBuilder.create(moduleOp.getLoc()); + + auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule, + mlir::ModuleOp linkedInnerModule, + DenseMap &symbolMap) { + return mergeModuleInto(sourceInnerModule, linkedInnerModule, symbolMap, + allowRenamingPrivateLLVMSymbols); + }; + + // Try linking together all executables in moduleOp. 
+ if (failed(linkExecutablesInto(moduleOp, sourceExecutableOps, + linkedExecutableOp, linkedTargetOp, + mergeModuleFn))) { + return signalPassFailure(); + } + } + } +}; +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp index dbcc5b1e54b6..24214940f30e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp @@ -27,25 +27,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final public: using impl::LLVMGPUPromoteMatmulToFitMMAPassBase< LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase; - explicit LLVMGPUPromoteMatmulToFitMMAPass( - const LLVMGPUMatmulPadOption &option) { - this->targetDimensions.setValue(option); - } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op, - ArrayRef paddingDims, - ArrayRef padToMultipleOf, bool noFold) const { - assert(paddingDims.size() == padToMultipleOf.size() && - "invalid pad multiples for padding dimensions"); - + ArrayRef padToMultipleOf) const { LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); - SmallVector nofoldFlags(op.getNumDpsInputs(), noFold); + SmallVector paddingDims = + llvm::to_vector(llvm::seq(padToMultipleOf.size())); SmallVector paddingValueAttributes; for (auto &operand : op->getOpOperands()) { @@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final .setPaddingDimensions(paddingDims) .setPaddingValues(paddingValueAttributes) .setPadToMultipleOf(padToMultipleOf) - .setNofoldFlags(nofoldFlags) .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None); FailureOr result = @@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final MLIRContext *ctx = &getContext(); auto funcOp = getOperation(); - // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so - // we can kick canonicalization patterns to fold outer tensor.pad ops away. - bool noFold = false; - utils::IteratorType targetIterType = utils::IteratorType::parallel; - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n"); - targetIterType = utils::IteratorType::parallel; - noFold = false; - break; - case LLVMGPUMatmulPadOption::ReductionDims: - LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n"); - targetIterType = utils::IteratorType::reduction; - noFold = true; - break; - default: // Unreachable. 
- assert(false); - break; - }; - SmallVector candidates; funcOp->walk([&](linalg::LinalgOp op) { if (linalg::isaContractionOpInterface(op)) { @@ -101,46 +73,27 @@ class LLVMGPUPromoteMatmulToFitMMAPass final IRRewriter rewriter(ctx); for (linalg::LinalgOp op : candidates) { - SmallVector padMultiples(op.getNumLoops(), 1); auto config = dyn_cast_or_null( getLoweringConfig(op)); - if (config) { - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Workgroup), op); - break; - case LLVMGPUMatmulPadOption::ReductionDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Reduction), op); - break; - default: - assert(false && "Unexpected target dimensions"); - break; - } + if (!config) { + continue; } - // Populate padding dimensions. - SmallVector paddingDimensions; - for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) { - if (iter == targetIterType) { - paddingDimensions.push_back(idx); - } - } + SmallVector wgTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Workgroup), op); + SmallVector redTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Reduction), op); - // Populate tile sizes. We pad to multiples of workgroup/reduction - // tile sizes based on the selected target tiling dimensions. - // This pass is ran after the select target tiling is done to pad - // all dimensions to the select tile sizes. - SmallVector padToMultipleOf; - for (int64_t dim : paddingDimensions) { - if (padMultiples[dim] != 0) { - padToMultipleOf.push_back(padMultiples[dim]); - } + // Populate padding dimensions to maximum of possible tile sizes. + SmallVector padToMultipleOf(op.getNumLoops(), 1); + for (auto [wgTile, redTile, padMultiple] : + llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) { + padMultiple = std::max({wgTile, redTile, padMultiple}); } + SmallVector paddingDimensions = + llvm::to_vector(llvm::seq(op.getNumLoops())); - padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf, - noFold); + padWithZeroValue(rewriter, op, padToMultipleOf); } { @@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final return signalPassFailure(); } } - - // XXX(hanchung): This is needed for pad op fusion, which will remove - // outer pad ops. I.e., it mainly wants to remove first pad op in the - // pad->extract_slice->pad chain, while the canonicalization pattern can - // only recognize slice->pad->slice->pad. 
- { - SmallVector padOps; - funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); }); - for (auto op : padOps) { - auto srcExtractSliceOp = - op.getSource().getDefiningOp(); - if (!srcExtractSliceOp) { - continue; - } - auto producerPadOp = - srcExtractSliceOp.getSource().getDefiningOp(); - if (!producerPadOp) { - continue; - } - auto src = producerPadOp.getSource() - .getDefiningOp(); - if (!src) { - continue; - } - - rewriter.setInsertionPointAfter(src); - SmallVector sizes = - tensor::getMixedSizes(rewriter, op.getLoc(), src); - SmallVector offsets(sizes.size(), - rewriter.getIndexAttr(0)); - SmallVector strides(sizes.size(), - rewriter.getIndexAttr(1)); - auto extractSliceOp = rewriter.create( - op.getLoc(), src.getResult(), offsets, sizes, strides); - rewriter.startOpModification(op); - producerPadOp.getSourceMutable().assign(extractSliceOp.getResult()); - rewriter.finalizeOpModification(op); - } - - RewritePatternSet patterns(ctx); - tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - } } }; } // namespace -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) { - return std::make_unique(option); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 3508e526925e..3c7eaf88eb46 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -19,6 +19,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" @@ -190,18 +191,23 @@ static void addBufferizePasses(OpPassManager &funcPassManager) { } static void tileAndDistributeToWorkgroup( - OpPassManager &funcPassManager, + OpPassManager &funcPassManager, bool useForall, std::optional convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) { - funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass( - kNumMaxParallelDims, - linalg::DistributionMethod::CyclicNumProcsEqNumIters)); - funcPassManager.addPass(createCSEPass()); - - if (convertToDpsOptions) { + if (useForall) { funcPassManager.addPass( - createConvertToDestinationPassingStylePass(*convertToDpsOptions)); + createTileAndDistributeToWorkgroupsUsingForallOpPass()); + } else { + funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass( + kNumMaxParallelDims, + linalg::DistributionMethod::CyclicNumProcsEqNumIters)); + funcPassManager.addPass(createCSEPass()); + if (convertToDpsOptions) { + funcPassManager.addPass( + createConvertToDestinationPassingStylePass(*convertToDpsOptions)); + } } + // TODO(#16421): Disable decomposition due to failure in bufferization. 
// funcPassManager.addPass( // IREE::LinalgExt::createTileAndDecomposeAttentionPass()); @@ -212,7 +218,8 @@ static void tileAndDistributeToWorkgroup( static void tileAndBufferize(OpPassManager &funcPassManager) { ConvertToDestinationPassingStylePassOptions options; options.useWARForCooperativeMatrixCodegen = true; - tileAndDistributeToWorkgroup(funcPassManager, options); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false, + /*convertToDpsOptions=*/options); addBufferizePasses(funcPassManager); } @@ -243,7 +250,7 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager, //===---------------------------------------------------------------------===// void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -323,22 +330,45 @@ static void addGPUBufferizePasses(OpPassManager &funcPassManager) { funcPassManager.addPass(createCSEPass()); } +/// Control function for decomposing pack and unpack ops. Returns true if the +/// op is a PackOp with a DispatchTensorLoadOp producer, or an UnPackOp with +/// only DispatchTensorStoreOp consumers. +LogicalResult isAtBoundary(Operation *op) { + if (isa(op)) { + if (isa_and_nonnull( + op->getOperand(0).getDefiningOp())) { + return success(); + } + } else if (isa(op)) { + if (llvm::all_of(op->getUsers(), [](Operation *user) { + return isa(user); + })) { + return success(); + } + } + return failure(); +} + void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &pipelineOptions) { - tileAndDistributeToWorkgroup(funcPassManager, - /*convertToDpsOptions=*/std::nullopt); - // Step 1. Promote matmul operands and pack to intrinsic shapes. funcPassManager.addPass(createGPUPromoteMatmulOperandsPass()); funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass()); + // Decompose packs and unpacks that are at the function boundary. + funcPassManager.addPass(createDecomposeBoundaryPackUnPackOpsPass()); - // Step 1.5. Expand result shapes of MultiMmaOps before reduction tiling. + // Step 1.5. Expand result shapes of MultiMmaOps before tiling, and + // propagate reshapes to the function boundary. { IREE::GPU::ConcretizeMmaShapesPassOptions options; options.concretizeInputs = false; options.concretizeResult = true; funcPassManager.addPass(IREE::GPU::createConcretizeMmaShapesPass()); } + funcPassManager.addPass(createPropagateReshapesByExpansionPass()); + + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true, + /*convertToDpsOptions=*/std::nullopt); // Step 2. Tile and fuse tileable ops to reduction loops. { @@ -350,10 +380,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, } // Step 3. Decompose pack and unpack ops and propagate the resulting reshapes. - funcPassManager.addPass( - createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false, - /*useOnlyReshapes=*/true, - /*controlFn=*/std::nullopt)); + funcPassManager.addPass(createDecomposePackUnPackOpsPass( + DecomposePackUnPackOpsPassOptions{/*tileOuterToOne=*/false, + /*useOnlyReshapes=*/true})); // Step 3.5. Expand the inner dimensions of MultiMma ops in preparation for // distribution to lanes. 
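The rewritten LLVMGPUPromoteMatmulToFitMMA.cpp earlier in this diff replaces the two-phase parallel/reduction padding option with a single per-loop pad multiple taken as the maximum of the workgroup and reduction tile sizes. A minimal standalone sketch of that computation, using made-up tile sizes rather than values read from a real lowering config:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical workgroup/reduction tile sizes for a matmul with 3 loops
  // (M, N, K); 0 means the dimension is untiled at that level.
  llvm::SmallVector<int64_t> wgTiles = {64, 64, 0};
  llvm::SmallVector<int64_t> redTiles = {0, 0, 32};

  // Same combination logic as the caller of padWithZeroValue in the new pass:
  // every loop is padded to the largest tile size seen at either level.
  llvm::SmallVector<int64_t> padToMultipleOf(wgTiles.size(), 1);
  for (auto [wgTile, redTile, padMultiple] :
       llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
    padMultiple = std::max({wgTile, redTile, padMultiple});
  }

  for (int64_t m : padToMultipleOf)
    std::printf("%lld ", static_cast<long long>(m)); // Prints: 64 64 32
  std::printf("\n");
  return 0;
}

Dimensions that are untiled at one level (size 0 there) still pick up the tile size from the other level, or fall back to 1, which is why the pass no longer needs the parallel/reduction targetDimensions option removed from Passes.td further down.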
@@ -469,7 +498,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, //===---------------------------------------------------------------------===// void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -506,7 +535,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -710,7 +739,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline( void addGPUTransposePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -815,7 +844,7 @@ static void addVectorBufferizePasses(OpPassManager &funcPassManager) { void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options, bool usePadToModelSharedMemcpy) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); ReorderWorkgroupsStrategy reorderStrategy = getReorderWorkgroupsStrategy(options.reorderStrategy); @@ -829,25 +858,20 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); + funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass()); } // Tile to reduction loops. 
{ GPUApplyTilingLevelPassOptions options; options.tilingLevel = IREE::GPU::TilingLevel::Reduction; + options.allowZeroSlices = true; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); funcPassManager.addPass(affine::createLoopCoalescingPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } - if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); - } - funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); @@ -915,7 +939,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, } void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createRematerializeParallelOpsPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createGPUTileReductionPass()); @@ -959,7 +983,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { } void addGPUPackUnPackPasses(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); @@ -967,10 +991,9 @@ void addGPUPackUnPackPasses(OpPassManager &funcPassManager) { funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); - funcPassManager.addPass( - createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true, - /*useOnlyReshapes=*/false, - /*controlFn=*/std::nullopt)); + funcPassManager.addPass(createDecomposePackUnPackOpsPass( + DecomposePackUnPackOpsPassOptions{/*tileOuterToOne=*/true, + /*useOnlyReshapes=*/false})); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); addGPUVectorizationPasses(funcPassManager); @@ -996,7 +1019,8 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { ConvertToDestinationPassingStylePassOptions dpsOptions; dpsOptions.useWARForCooperativeMatrixCodegen = true; - tileAndDistributeToWorkgroup(funcPassManager, dpsOptions); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false, + /*convertToDpsOptions=*/dpsOptions); if (options.enableUkernels) { funcPassManager.addPass(createGPULowerToUKernelsPass()); } @@ -1146,32 +1170,18 @@ void addGPUTransformDialectPasses(OpPassManager &funcPassManager, // Common Pass Pipelines //===----------------------------------------------------------------------===// -static LogicalResult igemmConfigFn(linalg::GenericOp genericOp, - IREE::LinalgExt::Im2colOp im2colOp) { - auto funcOp = genericOp->getParentOfType(); - if (!funcOp) { - return genericOp.emitError("cannot find parent funcOp"); - } - IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) { - return funcOp.emitError("missing GPU target in parent funcOp"); - } - if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) { - return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp); - } - return success(); -} - static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( OpPassManager &modulePassManager) { { FunctionLikeNest funcPassManager(modulePassManager); - 
funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() { - return createConvolutionToIGEMMPass(igemmConfigFn); - }); + funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, + createLLVMGPUConvolutionToIGEMMPass); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); addEncodingToNopPasses(funcPassManager); + funcPassManager.addPass(createBlockDynamicDimensionsPass); + funcPassManager.addPass(createCanonicalizerPass); + funcPassManager.addPass(createCSEPass); } modulePassManager.addPass(createMaterializeUserConfigsPass()); modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass()); @@ -1210,6 +1220,25 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager, }); } +// NOTE: this runs on the top-level program module containing all +// hal.executable ops. +void buildLLVMGPULinkingPassPipeline(OpPassManager &modulePassManager, + std::optional target) { + // Link together executables. This may produce some IR duplication. + LLVMGPULinkExecutablesPassOptions linkOptions; + linkOptions.target = target.value_or(""); + modulePassManager.addPass(createLLVMGPULinkExecutablesPass(linkOptions)); + + // Cleanup IR duplication. + modulePassManager.addNestedPass( + mlir::createCanonicalizerPass()); + + // Assign final executable constant and import ordinals. + auto &variantPassManager = modulePassManager.nest() + .nest(); + variantPassManager.addPass(createLLVMGPUAssignConstantOrdinalsPass()); +} + //===----------------------------------------------------------------------===// // ROCDL Pass Pipelines //===----------------------------------------------------------------------===// @@ -1218,9 +1247,8 @@ static void buildROCDLCodegenConfigurationPassPipelineImpl( OpPassManager &modulePassManager) { { FunctionLikeNest funcPassManager(modulePassManager); - funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() { - return createConvolutionToIGEMMPass(igemmConfigFn); - }); + funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, + createLLVMGPUConvolutionToIGEMMPass); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); } @@ -1289,6 +1317,13 @@ void registerCodegenLLVMGPUPasses() { [](OpPassManager &passManager) { buildLLVMGPUCodegenPassPipeline(passManager, true); }); + + static PassPipelineRegistration<> LLVMGPULinkingPipeline( + "iree-codegen-llvmgpu-linking-pipeline", + "Runs the LLVMGPU HAL executable linking pipeline", + [](OpPassManager &modulePassManager) { + buildLLVMGPULinkingPassPipeline(modulePassManager); + }); } //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index c1181776e8f7..e7132c7bbd08 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -12,6 +12,8 @@ #ifndef IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_ #define IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_ +#include + #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h" @@ -22,7 +24,7 @@ namespace mlir::iree_compiler { using IREE::GPU::GPUPipelineOptions; //----------------------------------------------------------------------------// -// LLVMGPU backend Pass Pipelines. 
+// LLVMGPU Backend Pass Pipelines //----------------------------------------------------------------------------// /// Lowering using SIMT CUDA core operations. @@ -99,14 +101,19 @@ verifyGPUMatmulPipeline(Operation *op, IREE::Codegen::TranslationInfoAttr translationInfo, ArrayRef workgroupSize); +//----------------------------------------------------------------------------// +// LLVMGPU Linking Passes and Pipelines +//----------------------------------------------------------------------------// + +/// Populates passes needed to link HAL executables across LLVMGPU targets. +void buildLLVMGPULinkingPassPipeline( + OpPassManager &modulePassManager, + std::optional target = std::nullopt); + //------------------------------------------------------------------------------ -// Wrappers that not use tablegen options. +// Wrappers that do not use tablegen options //------------------------------------------------------------------------------ -enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims }; -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option); - enum class GPUTensorCoreType { WMMA = 0, MMA_SYNC = 1, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index ef51a6a9a883..0b8df811a628 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -66,6 +66,11 @@ def ExtractAddressComputationGPUPass: Pass<"extract-address-computation-gpu"> { ]; } +def LLVMGPUAssignConstantOrdinalsPass : + Pass<"iree-llvmgpu-assign-constant-ordinals", "IREE::HAL::ExecutableVariantOp"> { + let summary = "Assigns executable constant ordinals across all LLVMGPU variants."; +} + def LLVMGPUCastAddressSpaceFunctionPass : Pass<"iree-llvmgpu-cast-address-space-function", "ModuleOp"> { let summary = "Cast address space to generic in CallOp and FuncOp"; @@ -87,6 +92,29 @@ def LLVMGPUConfigureVectorLayoutsPass : let summary = "Pass to set layouts for vector distribution"; } +def LLVMGPUConvolutionToIGEMMPass : + InterfacePass<"iree-llvmgpu-convolution-to-igemm", "mlir::FunctionOpInterface"> { + let summary = "Pass to convert conv_2d ops to igemm and set a lowering configuration."; + let dependentDialects = [ + "tensor::TensorDialect", + "iree_compiler::IREE::Codegen::IREECodegenDialect", + "iree_compiler::IREE::GPU::IREEGPUDialect", + "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect" + ]; +} + +def LLVMGPULinkExecutablesPass : + Pass<"iree-llvmgpu-link-executables", "mlir::ModuleOp"> { + let summary = "Links LLVMGPU HAL executables within the top-level program module."; + let options = [ + Option< + "target", "target", + "std::string", "", + "Target backend name whose executables will be linked by this pass." 
+ >, + ]; +} + def LLVMGPULowerExecutableTargetPass : InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline"; @@ -105,19 +133,6 @@ def LLVMGPUPrefetchSharedMemoryPass : def LLVMGPUPromoteMatmulToFitMMAPass : InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> { let summary = "Pass to promote contraction ops to fit mma shapes"; - let options = [ - Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption", - /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims", - "Select the strategy to control how multi_reduction is lowered.", - [{::llvm::cl::values( - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims, - "parallel", - "Pad all the parallel dims for contraction ops."), - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims, - "reduction", - "Pad all the reduction dims for contraction ops.") - )}]> - ]; } def LLVMGPUSelectLoweringStrategyPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 00bc6f967acf..1088035a5697 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( "amdgpu_chained_matmul.mlir", "amdgpu_contraction_distribution.mlir", "amdgpu_set_anchor_layouts.mlir", + "assign_constant_ordinals.mlir", "conv_pipeline_test_cuda.mlir", "conv_pipeline_test_rocm.mlir", "convert_to_nvvm.mlir", @@ -38,6 +39,7 @@ iree_lit_test_suite( "gpu_set_num_workgroups.mlir", "gpu_pipeline_generalize_named_ops.mlir", "gpu_pipeline_igemm.mlir", + "link_executables.mlir", "nvvm_extract_address_computation.mlir", "nvvm_pipeline_test.mlir", "nvvm_mma_sync_pipeline_test.mlir", @@ -49,6 +51,7 @@ iree_lit_test_suite( "legalize.mlir", "linalg_transform.mlir", "llvmgpu_bufferize.mlir", + "llvmgpu_convolution_to_igemm.mlir", "pack_pipeline_test.mlir", "pack_shared_memory_alloc.mlir", "prefetch_shared_memory.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 6be97c06d533..795ee25f3303 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -17,6 +17,7 @@ iree_lit_test_suite( "amdgpu_chained_matmul.mlir" "amdgpu_contraction_distribution.mlir" "amdgpu_set_anchor_layouts.mlir" + "assign_constant_ordinals.mlir" "cast_address_space_function.mlir" "cast_type_to_fit_mma.mlir" "config_custom_op.mlir" @@ -39,7 +40,9 @@ iree_lit_test_suite( "illegal_configuration.mlir" "legalize.mlir" "linalg_transform.mlir" + "link_executables.mlir" "llvmgpu_bufferize.mlir" + "llvmgpu_convolution_to_igemm.mlir" "nvvm_extract_address_computation.mlir" "nvvm_mma_sync_pipeline_test.mlir" "nvvm_pipeline_test.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 53952e953549..819b8826bb1d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -37,8 +37,76 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, 
%rhs: tensor // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 0, 0, 4] -// CHECK-SAME: subgroup = [0, 0, 4, 1, 0] -// CHECK-SAME: workgroup = [1, 1, 64, 64, 0] +// CHECK-SAME: subgroup = [1, 1, 4, 1, 0] +// CHECK-SAME: workgroup = [1, 1, 4, 4, 0] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : tensor<10x4x32x32xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16> + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor<10x4x32x32xf16> + return %7 : tensor<10x4x32x32xf16> +} + +// CHECK-LABEL: func.func @multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1] +// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [2, 2, 2, 2, 0, 0] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> +func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor, %rhs: tensor) -> tensor { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %d0 = tensor.dim %lhs, %c0 : tensor + %d2 = tensor.dim %rhs, %c0 : tensor + %5 = tensor.empty(%d0, %d2) : tensor + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor) -> tensor + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%6 : tensor) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor + return %7 : tensor +} + +// CHECK-LABEL: func.func @dynamic_multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 1, 1] +// CHECK-SAME: subgroup = [0, 1, 0, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [1, 2, 1, 1, 2, 0, 0] // ----- @@ -52,7 +120,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< } // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024 -// CHECK-SAME: #iree_codegen.translation_info // Verify that the fill does not have the lowering config propagated to it. 
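A note on the retuned configs in this test: workgroup tile sizes shrink by a factor of 16 (for example [1, 1, 64, 64, 0] above becomes [1, 1, 4, 4, 0], and [128, 128, 0] just below becomes [8, 8, 0]), which is consistent with workgroup and subgroup tiles now being counted in MMA-intrinsic tiles rather than elements once packing to intrinsics happens before workgroup tiling. That reading is inferred from the numbers, not stated in the diff; the sketch below converts such a config back to element counts, with the 16x16 M/N intrinsic shape and a 2048x10240 output assumed purely for illustration:

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed 16x16 M/N intrinsic shape and an assumed 2048x10240 output.
  const std::array<int64_t, 2> intrinsicMN = {16, 16};
  const std::array<int64_t, 2> problemMN = {2048, 10240};
  // Workgroup tile sizes in intrinsic tiles, as in the updated configs.
  const std::array<int64_t, 2> wgTiles = {4, 4};

  for (int i = 0; i < 2; ++i) {
    int64_t wgElems = wgTiles[i] * intrinsicMN[i]; // 4 * 16 = 64 elements.
    int64_t numWorkgroups = (problemMN[i] + wgElems - 1) / wgElems;
    std::printf("dim %d: %lld elements per workgroup, %lld workgroups\n", i,
                static_cast<long long>(wgElems),
                static_cast<long long>(numWorkgroups));
  }
  return 0;
}

Under those assumptions each workgroup covers a 64x64 element tile and the launch grid works out to 32x160 workgroups, which would line up with the scf.forall (...) in (32, 160) checks added to the pipeline tests later in this diff.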
@@ -63,7 +131,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 2] // CHECK-SAME: subgroup = [4, 4, 0] -// CHECK-SAME: workgroup = [128, 128, 0] +// CHECK-SAME: workgroup = [8, 8, 0] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 66ac37b6a370..2ebc85496759 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -50,18 +50,20 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x8xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x8xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 {{.*}} -> (vector<8x4xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16> -// CHECK-DAG: vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]] -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16> -// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16> -// CHECK: %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]] -// CHECK: scf.yield %[[MM]] -// CHECK: vector.transfer_write %[[LOOP]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 {{.*}} -> (vector<8x4xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16> +// CHECK-DAG: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16> +// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16> +// CHECK: %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]] +// CHECK: scf.yield %[[MM]] +// CHECK: vector.transfer_write %[[LOOP]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -71,7 +73,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -112,21 +114,23 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// 
CHECK-DAG: vector.transfer_write %[[RHS_RD]] -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> -// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> -// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32> -// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[LHS_RD]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]] +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> +// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> +// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32> +// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -136,7 +140,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [1, 64, 64, 0], + workgroup = [1, 4, 4, 0], reduction = [0, 0, 0, 2], subgroup = [1, 2, 2], mma_kind = #iree_gpu.mma_layout, @@ -154,11 +158,11 @@ hal.executable private @main { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x34x34x1280xf16> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x1280x1280xf16> - %5 = tensor.empty() : tensor<2x16x16x1280xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [11520, 1280], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<11520x1280xf16> + %5 = tensor.empty() : tensor<2x256x1280xf32> %6 = tensor.empty() : tensor<2x256x11520xf16> %7 = iree_linalg_ext.im2col strides = [2, 2] dilations = [1, 1] 
kernel_size = [3, 3] @@ -166,15 +170,13 @@ hal.executable private @main { batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%3 : tensor<2x34x34x1280xf16>) outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16> - %collapsed = tensor.collapse_shape %4 [[0, 1, 2], [3]] : tensor<3x3x1280x1280xf16> into tensor<11520x1280xf16> - %collapsed_0 = tensor.collapse_shape %5 [[0], [1, 2], [3]] : tensor<2x16x16x1280xf32> into tensor<2x256x1280xf32> - %8 = linalg.fill ins(%cst : f32) outs(%collapsed_0 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32> + %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32> %9 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} - ins(%7, %collapsed : tensor<2x256x11520xf16>, tensor<11520x1280xf16>) + ins(%7, %4 : tensor<2x256x11520xf16>, tensor<11520x1280xf16>) outs(%8 : tensor<2x256x1280xf32>) attrs = {lowering_config = #config} { ^bb0(%in: f16, %in_1: f16, %out: f32): %10 = arith.extf %in : f16 to f32 @@ -183,8 +185,7 @@ hal.executable private @main { %13 = arith.addf %12, %out : f32 linalg.yield %13 : f32 } -> tensor<2x256x1280xf32> - %expanded = tensor.expand_shape %9 [[0], [1, 2], [3]] output_shape [2, 16, 16, 1280] : tensor<2x256x1280xf32> into tensor<2x16x16x1280xf32> - flow.dispatch.tensor.store %expanded, %2, offsets = [0, 0, 0, 0], sizes = [2, 16, 16, 1280], strides = [1, 1, 1, 1] : tensor<2x16x16x1280xf32> -> !flow.dispatch.tensor> + flow.dispatch.tensor.store %9, %2, offsets = [0, 0, 0], sizes = [2, 256, 1280], strides = [1, 1, 1] : tensor<2x256x1280xf32> -> !flow.dispatch.tensor> return } } @@ -200,22 +201,24 @@ hal.executable private @main { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[RHS_RD]] -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16> -// CHECK-DAG: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16> -// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16> -// CHECK-DAG: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16> -// CHECK-DAG: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16> -// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32> -// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (2, 4, 20) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: 
vector.transfer_write %[[LHS_RD]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16> +// CHECK-DAG: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16> +// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16> +// CHECK-DAG: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16> +// CHECK-DAG: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16> +// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32> +// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -225,7 +228,7 @@ hal.executable private @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [1, 4, 16, 256, 0], + workgroup = [1, 4, 16, 16, 0], reduction = [0, 0, 0, 0, 2], subgroup = [1, 4, 1, 4, 0], mma_kind = #iree_gpu.mma_layout, @@ -287,6 +290,7 @@ hal.executable private @main { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: scf.forall ({{.*}}) in (2, 4, 5) { // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x4x1x4x4x1xf32>) // CHECK: gpu.barrier // CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> @@ -303,6 +307,7 @@ hal.executable private @main { // CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 2, 4, 3, 5] : vector<1x4x1x4x4x1xf32> to vector<1x4x1x4x4x1xf32> // CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<4x1x4x4x1xf32> from vector<1x4x1x4x4x1xf32> // CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -312,7 +317,7 @@ hal.executable private @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -353,21 +358,23 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> -// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> -// CHECK-DAG: 
vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> -// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32> -// CHECK: scf.yield -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32> -// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> +// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> +// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32> +// CHECK: scf.yield +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32> +// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -377,7 +384,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -419,9 +426,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x4 // CHECK-DAG: memref.alloc() : memref<64x10xf32, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x10xf32, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -431,7 +440,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -473,9 +482,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x32_f8 // CHECK-DAG: memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -485,7 +496,7 @@ 
hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [2, 2, 0], reduction = [0, 0, 2], subgroup = [1, 1], mma_kind = #iree_gpu.mma_layout, @@ -527,9 +538,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_32x32x16_i8 // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>) -// CHECK-COUNT-2: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>) +// CHECK-COUNT-2: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -539,7 +552,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -581,9 +594,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_wmma_f16_16x16x16_f16 // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>) -// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16> -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>) +// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16> +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -639,12 +654,14 @@ hal.executable public @main { // the producer's (convolution's) distributed scf.forall loop. 
// CHECK-LABEL: func @conv_nchw_fused // CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<1x1x1x1xf32, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: linalg.conv_2d_nchw_fchw -// CHECK-SAME: outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space>) -// CHECK: arith.addf -// CHECK: arith.cmpf -// CHECK: arith.select +// CHECK: scf.forall ({{.*}}) in (64, 14, 7) { +// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space>) +// CHECK: arith.addf +// CHECK: arith.cmpf +// CHECK: arith.select +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -715,11 +732,13 @@ hal.executable public @main { // CHECK: %[[LINID0:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]], %[[IDZ]]] // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINID0:.+]] into (%c4, %c8) : index, index // CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#0, %[[IDS]]#1] -// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 -// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32> -// CHECK: vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space> -// CHECK: vector.contract +// CHECK: scf.forall ({{.*}}) in (32, 98) { +// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 +// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32> +// CHECK: vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space> +// CHECK: vector.contract +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -736,7 +755,7 @@ hal.executable public @main { mma_kind = #iree_gpu.mma_layout, reduction = [0, 0, 4], subgroup = [2, 4, 0], - workgroup = [64, 128, 0], + workgroup = [4, 8, 0], promote_operands = [0, 1] }> @@ -1012,12 +1031,8 @@ hal.executable public @main { // CHECK-DAG: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<4x130xf32, #gpu.address_space> // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1000 step %c4 {{.*}} -> (vector<1x4xf32>) // CHECK: gpu.barrier - -// TODO: The fact that this read gets hoisted out of the subsequent for loop -// is a bug in LICM that does no verification that the loop has at least one -// trip. -// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<4xf32> // CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c32 +// CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<4xf32> // CHECK-NEXT: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC]] // CHECK: gpu.barrier // CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<4xf32> @@ -1072,6 +1087,7 @@ hal.executable public @main { // Verify that the write does not get hoisted out of the single threaded // for loop. 
-// CHECK: vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type> -// CHECK-NEXT: } +// CHECK: vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type> +// CHECK-NEXT: } +// CHECK-NEXT: } {mapping = [#iree_codegen.workgroup_mapping]} // CHECK-NEXT: return diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index 1c460f7bec9c..4334e79d6f88 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -166,6 +166,55 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> + +hal.executable @matmul_multiple_k { + hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export public @matmul_multiple_k layout(#pipeline_layout) { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @matmul_multiple_k() attributes {translation_info = #iree_codegen.translation_info, subgroup_m_count = 1, subgroup_n_count = 4>}>} { + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x128x64x2048xf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [10, 128, 64, 2048], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<10x128x64x2048xf16> + %5 = tensor.empty() : tensor<2x10x64x64xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16> + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<2x128x64x2048xf16>, tensor<10x128x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 0, 1, 128], workgroup = [1, 1, 64, 64, 0, 0]}>} { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor<2x10x64x64xf16> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 64, 64], strides = [1, 1, 1, 1] : tensor<2x10x64x64xf16> -> !flow.dispatch.tensor> + return + } + } + } +} + +// Check if we can handle multiple reduction dimensions and that they generate +// one 
coalesced loop. + +// CHECK-LABEL: func.func @matmul_multiple_k +// CHECK: scf.for %[[IV:.+]] = %c0 to %c2048 step %c1 +// CHECK: affine.delinearize_index %[[IV]] into (%c128, %c16) +// CHECK-COUNT-32: amdgpu.mfma +// CHECK: scf.yield +// CHECK-COUNT-4: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<4x1xf16>, memref<2x10x64x64xf16, #hal.descriptor_type> + +// ----- + // Basic f8, f8 -> f32 matmul. #config = #iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 256]}> @@ -462,7 +511,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]} // CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]] // CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]] -// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) +// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) // CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]] // CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]] // CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]] @@ -532,9 +581,11 @@ hal.executable public @pad_batch_matmul { // CHECK-SAME: memref<196x16x24xf32 // CHECK-SAME: vector<1x1x1xf32> // RHS +// The dynamic dimension should be removed after: +// https://github.com/llvm/llvm-project/pull/112236 // CHECK: vector.transfer_read -// CHECK-SAME: in_bounds = [true, true, false] -// CHECK-SAME: memref<1x8x24xf32 +// CHECK-SAME: in_bounds = [true, false, false] +// CHECK-SAME: memref<1x?x24xf32 // CHECK-SAME: vector<1x1x2xf32> // CHECK: scf.yield // OUTPUT @@ -637,7 +688,9 @@ hal.executable private @attention_20x4096x64x4096x64 { affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>], - lowering_config = #config} + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 @@ -702,7 +755,15 @@ hal.executable private @attention_multiple_m_transpose { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = 
{attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> @@ -760,7 +821,15 @@ hal.executable private @attention_mfma_32x32x8 { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir new file mode 100644 index 000000000000..8a133f91a8fb --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir @@ -0,0 +1,22 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-assign-constant-ordinals)))" --split-input-file %s | FileCheck %s + +hal.executable private @executable { + hal.executable.variant public @variant target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) { + hal.executable.constant.block(%device: !hal.device) -> i32 as "foo" { + %c0 = arith.constant 0 : i32 + hal.return %c0 : i32 + } + hal.executable.constant.block(%device: !hal.device) -> i32 as "bar" { + %c1 = arith.constant 1 : i32 + hal.return %c1 : i32 + } + builtin.module { + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_a(0 : i32) + llvm.mlir.global internal @__constant_ordinal_foo_a() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32 + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_b(0 : i32) + llvm.mlir.global internal @__constant_ordinal_foo_b() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32 + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_bar(1 : i32) + llvm.mlir.global internal @__constant_ordinal_bar() {addr_space = 4 : i32, hal.executable.constant.key = "bar", sym_visibility = "private"} : i32 + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir 
b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir index cfa8875c9685..daeaa225a265 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir @@ -49,7 +49,6 @@ func.func @matmul_static_dispatch_0() attributes {hal.executable.target = #execu // FOREACH-TO-GPU: %[[COND:.*]] = arith.andi %[[LT1]], %[[LT5]] : i1 // FOREACH-TO-GPU: scf.if %[[COND]] { // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDY]]] - // FOREACH-TO-GPU: affine.apply #{{.*}}()[%[[TIDX]]] // FOREACH-TO-GPU: linalg.fill // FOREACH-TO-GPU: } // FOREACH-TO-GPU: gpu.barrier diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir new file mode 100644 index 000000000000..5655992d1d8f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir @@ -0,0 +1,150 @@ +// RUN: iree-opt --iree-llvmgpu-link-executables --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-TARGET +// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="cuda"},iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-MULTI + +#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb"> + +// Expect a single executable with both exports and correct ordinals. +// CHECK: hal.executable private @link_executables_linked +// CHECK: hal.executable.variant public @rocm_hsaco_fb +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) + +// Expect one LLVM module with all globals and functions. +// Note that shared memory is duplicated but dynamic shared memory is not. 
+// CHECK: builtin.module +// CHECK-NEXT: llvm.mlir.global external @__dynamic_shared_memory__ +// CHECK-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>> +// CHECK-NEXT: llvm.func @export0 +// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> +// CHECK-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> +// CHECK: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>> +// CHECK-NEXT: llvm.func @export1 +// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> +// CHECK-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3> + +hal.executable private @executable0 { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<64 x i32>> + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> + %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> + llvm.return + } + } + } +} +hal.executable private @executable1 { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<128 x i32>> + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> + %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> + llvm.return + } + } + } +} + +// ----- + +#executable_target_cuda = #hal.executable.target<"cuda", "cuda-nvptx-fb"> +#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb"> + +// Expect a single executable with multiple variants when not specifying target. +// CHECK: hal.executable private @link_executables_linked +// CHECK: hal.executable.variant public @cuda_nvptx_fb_0 +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) +// CHECK: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) + +// Expect only one target be linked when specified. 
+// CHECK-TARGET: hal.executable private @link_executables_linked +// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK-TARGET: hal.executable.export public @export0 ordinal(0) +// CHECK-TARGET: hal.executable.export public @export1 ordinal(1) +// CHECK-TARGET: hal.executable private @executable0 +// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb +// CHECK-TARGET: hal.executable.export public @export0 ordinal(0) +// CHECK-TARGET: hal.executable private @executable1 +// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb +// CHECK-TARGET: hal.executable.export public @export1 ordinal(0) + +// Multiple applications of the pass per target should not conflict. +// CHECK-MULTI: hal.executable private @link_executables_linked_0 +// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK-MULTI: hal.executable.export public @export0 ordinal(0) +// CHECK-MULTI: hal.executable.export public @export1 ordinal(1) +// CHECK-MULTI: hal.executable private @link_executables_linked +// CHECK-MULTI: hal.executable.variant public @cuda_nvptx_fb_0 +// CHECK-MULTI: hal.executable.export public @export0 ordinal(0) +// CHECK-MULTI: hal.executable.export public @export1 ordinal(1) + +hal.executable private @executable0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } +} +hal.executable private @executable1 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir new file mode 100644 index 000000000000..9618281c699e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir @@ -0,0 +1,36 @@ +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline="builtin.module(func.func(iree-llvmgpu-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s + +#config = #iree_codegen.lowering_config +func.func public 
@conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1x14x14x16xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, + dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) + outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} +// CHECK: func.func public @conv_with_lowering_config +// CHECK-NOT: iree_linalg_ext.im2col +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: lowering_config + +// ----- + +func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: tensor<3x3x128x128xf32>) -> tensor<1x32x32x128xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1x32x32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32> + %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%arg0, %arg1: tensor<1x34x34x128xf32>, tensor<3x3x128x128xf32>) + outs(%fill: tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32> + return %0 : tensor<1x32x32x128xf32> +} +// CHECK: func.func public @set_lowering_config +// CHECK: iree_linalg_ext.im2col +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config< +// CHECK-SAME: {mma_kind = #iree_gpu.mma_layout, +// CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8], +// CHECK-SAME: subgroup = [1, 1, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir index bda4836eaec3..21bc2fc3cac3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir @@ -1,5 +1,4 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))" %s | FileCheck %s --check-prefixes=ALL,PARALLEL -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma))" %s | FileCheck %s #pipeline_layout = #hal.pipeline.layout, @@ -34,114 +33,20 @@ func.func @batch_matmul_f16() { flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> return } -// ALL-LABEL: func.func @batch_matmul_f16 -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = 
flow.dispatch.tensor.load %[[RHS_HANDLE]] -// PARALLEL: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// PARALLEL: } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> -// PARALLEL: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// PARALLEL: } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> -// PARALLEL: %[[FILL:.+]] = linalg.fill -// PARALLEL: %[[GEMM:.+]] = linalg.batch_matmul -// PARALLEL-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// PARALLEL-SAME: outs(%[[FILL]] +// CHECK-LABEL: func.func @batch_matmul_f16 +// CHECK: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] +// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] +// CHECK: } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16> +// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] +// CHECK: } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16> +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK: %[[GEMM:.+]] = linalg.batch_matmul +// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] +// CHECK-SAME: outs(%[[FILL]] -// The reduction dim is not tiled in the test case, so it pads it to the -// matmul intrinsic k. -// REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]] -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16> -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[FILL]] - -// ALL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] -// ALL: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] - -// ----- - -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -#map = affine_map<()[s0] -> (s0 * 64)> -#map1 = affine_map<()[s0] -> (s0 * 128)> -#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)> -#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)> -#map4 = affine_map<()[s0] -> (-s0 + 64)> -#map5 = affine_map<()[s0] -> (-s0 + 128)> -#map6 = affine_map<(d0) -> (-d0 + 1281, 64)> -func.func @batch_matmul_pad_reduction_after_tiling() { - %c64 = arith.constant 64 : index - %c1281 = arith.constant 1281 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %workgroup_id_z = hal.interface.workgroup.id[2] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %3 = affine.apply #map()[%workgroup_id_y] - %workgroup_id_x = 
hal.interface.workgroup.id[0] : index - %4 = affine.apply #map1()[%workgroup_id_x] - %5 = affine.min #map2()[%workgroup_id_y] - %6 = affine.min #map3()[%workgroup_id_x] - %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x?x1281xf16> - %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16> - %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1281x?xf16> - %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16> - %9 = affine.apply #map4()[%5] - %padded = tensor.pad %7 low[0, 0, 0] high[0, %9, 0] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> - %10 = affine.apply #map5()[%6] - %padded_2 = tensor.pad %8 low[0, 0, 0] high[0, 0, %10] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> - %11 = tensor.empty() : tensor<1x64x128xf16> - %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) { - %14 = affine.min #map6(%arg0) - %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16> - %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16> - %15 = linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - scf.yield %15 : tensor<1x64x128xf16> - } - %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16> - flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> - return -} -// The padding on parallel dims is a nop because they are already padded. Skip -// the check for the testcase. 
-// ALL-LABEL: func.func @batch_matmul_pad_reduction_after_tiling -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] -// REDUCTION: %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16> -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]] -// REDUCTION: %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]]) -// REDUCTION: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x64xf16> -// REDUCTION: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x128xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[ITER]] -// REDUCTION: scf.yield %[[GEMM]] -// REDUCTION: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]] -// REDUCTION: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] +// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] +// CHECK: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 16a1acf4316f..bbdec5c83f6d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -884,6 +884,11 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, Type lhsElem = getElementType(lhs); Type rhsElem = getElementType(rhs); Type initElem = getElementType(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. 
GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem); SmallVector intrinsics; @@ -921,8 +926,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; - std::array workgroupSize{schedule->nWarpCount * subgroupSize, - schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + subgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector vectorSizes(kIndex + 1, 0); if (isBM) @@ -934,21 +940,23 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, SmallVector subgroupTileSizes(lastParallelDim + 1, 0); if (isBM) subgroupTileSizes[bIndex] = 1; - subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex]; - subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex]; + subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex]; + subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); if (isBM) workgroupTileSizes[bIndex] = 1; - workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex]; - workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex]; + workgroupTileSizes[mIndex] = + schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex]; + workgroupTileSizes[nIndex] = + schedule->nSubgroupCounts[0] * subgroupTileSizes[nIndex]; // Also create one level for reduction. This is needed because of // SPIRVTileAndPromotePass requires it. // TODO(#10499): Consolidate tiling configuration across different pipelines. SmallVector reductionTileSizes; reductionTileSizes.append(kIndex, 0); - reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize); + reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize); TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes, reductionTileSizes, vectorSizes}; @@ -956,7 +964,7 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, // Don't do multibuffering if the inner reduction loop is folded out. 
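As a side note on the setCooperativeMatrixConfig changes in the hunk above: the workgroup size and tile sizes now come from per-dimension lists on the schedule, with only element [0] used until multi-dimension support lands (per the TODO). A minimal standalone sketch of that arithmetic follows, using assumed schedule numbers (2 M-subgroups, 4 N-subgroups, a 16x16x16 intrinsic, 4 K tiles, subgroup size 32) rather than anything taken from this patch; it is not IREE code.

    // Mirrors the formulas visible in the diff above (hypothetical values only):
    //   workgroupSize = {nSubgroupCounts[0] * subgroupSize, mSubgroupCounts[0], 1}
    //   subgroupTile  = tileSizes[0] * vector size (assumed here to be the intrinsic shape)
    //   workgroupTile = subgroupCounts[0] * subgroupTile
    //   reductionTile = kTileSizes[0] * kSize
    #include <array>
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t mSize = 16, nSize = 16, kSize = 16;          // assumed intrinsic
      int64_t mSubgroupCounts0 = 2, nSubgroupCounts0 = 4;  // assumed schedule
      int64_t mTileSizes0 = 1, nTileSizes0 = 1, kTileSizes0 = 4;
      int64_t subgroupSize = 32;                           // assumed target

      std::array<int64_t, 3> workgroupSize = {nSubgroupCounts0 * subgroupSize,
                                              mSubgroupCounts0, 1};
      int64_t subgroupTileM = mTileSizes0 * mSize;                // 16
      int64_t subgroupTileN = nTileSizes0 * nSize;                // 16
      int64_t workgroupTileM = mSubgroupCounts0 * subgroupTileM;  // 32
      int64_t workgroupTileN = nSubgroupCounts0 * subgroupTileN;  // 64
      int64_t reductionTileK = kTileSizes0 * kSize;               // 64

      std::printf("workgroup size {%lld, %lld, %lld}, workgroup tile %lldx%lld, "
                  "reduction tile %lld\n",
                  (long long)workgroupSize[0], (long long)workgroupSize[1],
                  (long long)workgroupSize[2], (long long)workgroupTileM,
                  (long long)workgroupTileN, (long long)reductionTileK);
      return 0;
    }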
auto pipelineDepth = softwarePipelineDepth; auto storeStage = softwarePipelineStoreStage; - if (schedule->kTileCount <= 1) { + if (schedule->kTileSizes[0] <= 1) { pipelineDepth = 0; storeStage = 0; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir index d97410b66230..d57d1631bd77 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir @@ -109,13 +109,13 @@ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #transla // CHECK: gpu.barrier // CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]] // CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Y]], 0] -// CHECK: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]] -// CHECK: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]] // CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] // CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]] -// CHECK: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]] -// CHECK: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]] -// CHECK: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]] // CHECK: %[[CT0:.+]] = vector.contract // CHECK-SAME: %[[READ0]], %[[READ2]], %[[READ4]] : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> // CHECK: %[[CT1:.+]] = vector.contract @@ -246,13 +246,13 @@ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #transla // CHECK: scf.for %[[IV_Z:.+]] = %[[ID_Z]] to %[[C1]] step %[[C1]] // CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]] // CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Z]], %[[IV_Y]], 0] [1, 16, 32] -// CHECK: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]] // CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] { // CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][%[[IV_Z]], 0, %[[IV_X]]] [1, 32, 16] -// CHECK: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]] -// CHECK: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]] +// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]] +// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]] // CHECK: %[[CT0:.+]] = vector.contract // CHECK-SAME: %[[READ0]], %[[READ2]], %[[READ4]] : 
vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> // CHECK: %[[CT1:.+]] = vector.contract @@ -369,13 +369,13 @@ func.func @matmul_256x1024x128_mixed_signedness_int8() { // CHECK: gpu.barrier // CHECK: scf.for %[[IV_Y:.+]] = %[[OFFSET_Y]] to %[[C32]] step %[[C32]] // CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Y]], 0] -// CHECK: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]] -// CHECK: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]] // CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] // CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][0, %[[IV_X]]] -// CHECK: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]] -// CHECK: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]] -// CHECK: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]]] // CHECK: %[[EXTUI0:.+]] = arith.extui %[[READ0]] : vector<16x16xi8> to vector<16x16xi32> // CHECK: %[[EXTUI1:.+]] = arith.extui %[[READ1]] : vector<16x16xi8> to vector<16x16xi32> // CHECK: %[[EXTSI0:.+]] = arith.extsi %[[READ2]] : vector<16x16xi8> to vector<16x16xi32> diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index 0b8c49c4be69..6d0d05277b11 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -36,20 +36,29 @@ namespace mlir::iree_compiler { -static bool isAllConstantValue(SmallVector ofrs, int64_t v) { +static bool isAllConstantValue(ArrayRef ofrs, int64_t v) { return llvm::all_of( ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); }); } -static bool isFullSlice(SmallVector mixedOffsets, - SmallVector mixedSizes, - SmallVector mixedStrides, - IREE::Flow::DispatchTensorType tensorType) { - std::optional> constSizes = - getConstantIntValues(mixedSizes); +static bool isFullSlice(ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + OpBuilder builder(tensorType.getContext()); + SmallVector tensorShape = llvm::to_vector(tensorType.getShape()); + SmallVector mixedTensorShape = + mlir::getMixedValues(tensorShape, dynamicDims, builder); return isAllConstantValue(mixedOffsets, 0) && - isAllConstantValue(mixedStrides, 1) && constSizes && - llvm::equal(tensorType.getShape(), *constSizes); + isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes; +} +static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + return isFullSlice( + sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(), + sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims); } static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands, @@ -546,14 +555,29 @@ void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) { namespace { -// TODO(antigainst): enable dynamic shape support once they are needed. 
-template -static std::optional getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) { - auto reshapeSrcType = llvm::cast(reshapeOp.getSrc().getType()); - auto reshapeDstType = llvm::cast(reshapeOp.getType()); - if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape()) - return std::nullopt; - return reshapeOp.getSrc(); +static SmallVector +inferCollapsedShape(RewriterBase &rewriter, Location loc, + RankedTensorType expandedType, + ArrayRef reassociations, + ValueRange expandedDynamicDims) { + ArrayRef expandedStaticShape = expandedType.getShape(); + SmallVector expandedMixedShape = + mlir::getMixedValues(expandedStaticShape, expandedDynamicDims, rewriter); + SmallVector collapsedShape; + unsigned expandedShapeDim = 0; + for (auto reassociation : reassociations) { + AffineExpr mulExpr = rewriter.getAffineSymbolExpr(0); + for (auto i : llvm::seq(1, reassociation.size())) { + mulExpr = mulExpr * rewriter.getAffineSymbolExpr(i); + } + auto collapsedDim = affine::makeComposedFoldedAffineApply( + rewriter, loc, mulExpr, + ArrayRef(expandedMixedShape) + .slice(expandedShapeDim, reassociation.size())); + collapsedShape.push_back(collapsedDim); + expandedShapeDim += reassociation.size(); + } + return collapsedShape; } /// Folds tensor.expand/collapse_shape into the source @@ -576,35 +600,38 @@ static std::optional getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) { /// !flow.dispatch.tensor> /// %0 = flow.dispatch.tensor.load %subspan : /// !flow.dispatch.tensor> -> tensor<864xf32> -template -struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct FoldCollapseShapeIntoInterfaceTensorLoad + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, + LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp, PatternRewriter &rewriter) const override { - std::optional reshapeSrc = - getStaticReshapeOpSrc(reshapeOp); - if (!reshapeSrc) - return failure(); - - auto loadOp = - reshapeSrc->template getDefiningOp(); + Value reshapeSrc = reshapeOp.getSrc(); + auto reshapeSrcType = cast(reshapeSrc.getType()); + auto loadOp = reshapeSrc.getDefiningOp(); if (!loadOp) return failure(); // Make sure we are loading the full incoming subspan. Otherwise we cannot // simply adjust the subspan's resultant type later. 
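For readers skimming the new inferCollapsedShape helper earlier in this hunk: each collapsed dimension is simply the product of the expanded sizes in its reassociation group, built as a folded affine apply so dynamic sizes are handled as well. A minimal standalone sketch of that shape arithmetic (not IREE code; dynamic sizes are modeled here as -1):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Product of the expanded sizes in each reassociation group; any dynamic
    // size (-1) makes the whole collapsed dimension dynamic.
    std::vector<int64_t>
    collapseShape(const std::vector<int64_t> &expanded,
                  const std::vector<std::vector<int64_t>> &reassociations) {
      std::vector<int64_t> collapsed;
      for (const auto &group : reassociations) {
        int64_t product = 1;
        bool dynamic = false;
        for (int64_t dim : group) {
          if (expanded[dim] == -1)
            dynamic = true;
          else
            product *= expanded[dim];
        }
        collapsed.push_back(dynamic ? -1 : product);
      }
      return collapsed;
    }

    int main() {
      // tensor<3x?x?x96xf32> with reassociation [[0, 1, 2, 3]] -> dynamic.
      auto c0 = collapseShape({3, -1, -1, 96}, {{0, 1, 2, 3}});
      // tensor<3x4x5x96xf32> with reassociation [[0], [1, 2, 3]] -> 3x1920.
      auto c1 = collapseShape({3, 4, 5, 96}, {{0}, {1, 2, 3}});
      std::printf("%lld | %lldx%lld\n", (long long)c0[0], (long long)c1[0],
                  (long long)c1[1]);
      return 0;
    }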
- if (!isFullSlice(loadOp.getMixedOffsets(), loadOp.getMixedSizes(), - loadOp.getMixedStrides(), loadOp.getSourceType())) { + if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) { return failure(); } - auto subspanOp = - loadOp.getSource() - .template getDefiningOp(); + auto subspanOp = loadOp.getSource() + .getDefiningOp(); if (!subspanOp) return failure(); - assert(subspanOp.getDynamicDims().empty()); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(subspanOp); + SmallVector collapsedShape = inferCollapsedShape( + rewriter, subspanOp.getLoc(), reshapeSrcType, + reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims()); + SmallVector collapsedStaticShape; + SmallVector collapsedDynamicShape; + dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape, + collapsedStaticShape); auto tensorAccess = llvm::cast(subspanOp.getType()) @@ -615,12 +642,111 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern { Value newSubspanOp = rewriter.create( subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(), subspanOp.getBinding(), subspanOp.getByteOffset(), - subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(), + collapsedDynamicShape, subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr()); + rewriter.setInsertionPoint(reshapeOp); + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), newSubspanOp, + collapsedDynamicShape); + + return success(); + } +}; + +/// Folds tensor.expand_shape into the source +/// hal.interface.binding.subspan. +/// +/// For example, this matches the following pattern: +/// +/// %subspan = hal.interface.binding.subspan ... : +/// !flow.dispatch.tensor> +/// %tensor = flow.dispatch.tensor.load %subspan : +/// !flow.dispatch.tensor> -> +/// tensor<3x3x1x96xf32> +/// %0 = linalg.expand_reshape %tensor [ +/// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +/// ] : tensor<3x3x1x96xf32> into tensor<864xf32> +/// +/// And turns it into: +/// +/// %subspan = hal.interface.binding.subspan ... : +/// !flow.dispatch.tensor> +/// %0 = flow.dispatch.tensor.load %subspan : +/// !flow.dispatch.tensor> -> tensor<864xf32> +struct FoldExpandShapeIntoInterfaceTensorLoad + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + Value reshapeSrc = reshapeOp.getSrc(); + auto loadOp = reshapeSrc.getDefiningOp(); + if (!loadOp) { + return failure(); + } + + // Make sure we are loading the full incoming subspan. Otherwise we cannot + // simply adjust the subspan's resultant type later. + if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) { + return failure(); + } + + // In the corner case where the expand_shape is the source of a store, dont + // fold with the load. 
Instead fold with the store to reduce the + // dimensionality + if (reshapeOp->hasOneUse()) { + if (auto storeOp = dyn_cast( + *reshapeOp->getUsers().begin())) { + if (isFullSlice(storeOp, storeOp.getTargetType(), + storeOp.getTargetDims())) { + return rewriter.notifyMatchFailure(reshapeOp, + "fold with store instead"); + } + } + } + + auto subspanOp = loadOp.getSource() + .getDefiningOp(); + if (!subspanOp) + return failure(); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(subspanOp); + + auto currDynamicDims = subspanOp.getDynamicDims(); + auto currStaticDims = loadOp.getType().getShape(); + auto currOfrDynamicDims = + mlir::getMixedValues(currStaticDims, currDynamicDims, rewriter); + std::optional> expandedDims = + mlir::inferExpandShapeOutputShape( + rewriter, subspanOp.getLoc(), reshapeOp.getType(), + reshapeOp.getReassociationIndices(), currOfrDynamicDims); + if (!expandedDims) { + return reshapeOp.emitOpError("failure in expanded shape"); + } + + auto tensorAccess = + llvm::cast(subspanOp.getType()) + .getAccess(); + auto newSubspanType = IREE::Flow::DispatchTensorType::get( + tensorAccess, reshapeOp.getResultType()); + + SmallVector expandedDynamicDims; + SmallVector expandedStaticDims; + dispatchIndexOpFoldResults(expandedDims.value(), expandedDynamicDims, + expandedStaticDims); + + Value newSubspanOp; + newSubspanOp = rewriter.create( + subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(), + subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicDims, + subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr()); + + rewriter.setInsertionPoint(reshapeOp); rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), newSubspanOp, - loadOp.getSourceDims()); + expandedDynamicDims); return success(); } @@ -652,8 +778,8 @@ struct FoldExpandShapeIntoInterfaceTensorStore PatternRewriter &rewriter) const override { // Make sure we are storing the full incoming subspan. Otherwise we cannot // simply adjust the subspan's resultant type later. - if (!isFullSlice(storeOp.getMixedOffsets(), storeOp.getMixedSizes(), - storeOp.getMixedStrides(), storeOp.getTargetType())) { + if (!isFullSlice(storeOp, storeOp.getTargetType(), + storeOp.getTargetDims())) { return failure(); } @@ -662,38 +788,136 @@ struct FoldExpandShapeIntoInterfaceTensorStore return failure(); } - // Dynamic shapes are currently unsupported. 
- std::optional reshapeSrc = - getStaticReshapeOpSrc(reshapeOp); - if (!reshapeSrc) - return failure(); + Value reshapeSrc = reshapeOp.getSrc(); + // If the source is a `flow.dispatch.tensor.load`, fold with the load + // instead to reduce dimensionality of the problem + if (auto loadOp = + reshapeSrc.getDefiningOp()) { + if (isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) { + return rewriter.notifyMatchFailure( + storeOp, "fold expand_shape with load instead"); + } + } - auto subspanOp = - storeOp.getTarget() - .template getDefiningOp(); + auto subspanOp = storeOp.getTarget() + .getDefiningOp(); if (!subspanOp) return failure(); - assert(subspanOp.getDynamicDims().empty()); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subspanOp); + SmallVector collapsedShape = inferCollapsedShape( + rewriter, subspanOp.getLoc(), reshapeOp.getResultType(), + reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims()); + SmallVector collapsedStaticShape; + SmallVector collapsedDynamicShape; + dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape, + collapsedStaticShape); auto tensorAccess = llvm::cast(subspanOp.getType()) .getAccess(); - auto newSubspanType = IREE::Flow::DispatchTensorType::get( - tensorAccess, reshapeSrc->getType()); + auto newSubspanType = + IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrc.getType()); - Value newSubspanOp; - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(subspanOp); - newSubspanOp = rewriter.create( - subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(), - subspanOp.getBinding(), subspanOp.getByteOffset(), - subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(), - subspanOp.getDescriptorFlagsAttr()); + Value newSubspanOp = rewriter.create( + subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(), + subspanOp.getBinding(), subspanOp.getByteOffset(), + collapsedDynamicShape, subspanOp.getAlignmentAttr(), + subspanOp.getDescriptorFlagsAttr()); + + rewriter.setInsertionPoint(storeOp); + rewriter.replaceOpWithNewOp( + storeOp, reshapeSrc, newSubspanOp, collapsedDynamicShape); + + return success(); + } +}; + +/// Folds tensor.collapse_shape into the source hal.interface.binding.subspan. +/// +/// For example, this matches the following pattern: +/// +/// %subspan = hal.interface.binding.subspan ... : +/// !flow.dispatch.tensor> +/// %0 = tensor.collapse_shape %tensor [[0, 1, 2, 3]] +/// : tensor<3x?x?x96xf32> into tensor +/// %tensor = flow.dispatch.tensor.store %0, %subspan : +/// tensor -> !flow.dispatch.tensor>{%dim} +/// +/// And turns it into: +/// +/// %subspan = hal.interface.binding.subspan ... : +/// !flow.dispatch.tensor> +/// %0 = flow.dispatch.tensor.store %tensor, %subspan : +/// tensor<3x?x?x96xf32> -> +/// !flow.dispatch.tensor>{%d0, %d1} +/// +/// TODO: This handles full slices. The pattern below +/// (`FoldCollapseShapeIntoTensorInsertSlice`) handles cases where the slic is +/// not a full slice, but requires the shapes to be static. This pattern handles +/// dynamic shapes as well. Combine the two (if possible, it isnt clear that it +/// is possible) +struct FoldCollapseShapeIntoInterfaceTensorStoreFullSlice + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp, + PatternRewriter &rewriter) const override { + // Make sure we are storing the full incoming subspan. Otherwise we cannot + // simply adjust the subspan's resultant type later. 
+ if (!isFullSlice(storeOp, storeOp.getTargetType(), + storeOp.getTargetDims())) { + return failure(); } + auto reshapeOp = + storeOp.getValue().getDefiningOp(); + if (!reshapeOp) { + return failure(); + } + auto subspanOp = storeOp.getTarget() + .getDefiningOp(); + if (!subspanOp) + return failure(); + + Value reshapeSrc = reshapeOp.getSrc(); + auto reshapeSrcType = cast(reshapeSrc.getType()); + + // Compute the type and dynamic dims of the interface binding. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(subspanOp); + auto dynamicDims = subspanOp.getDynamicDims(); + ArrayRef staticShape = reshapeOp.getType().getShape(); + SmallVector mixedShape = + mlir::getMixedValues(staticShape, dynamicDims, rewriter); + std::optional> expandedShape = + mlir::inferExpandShapeOutputShape( + rewriter, subspanOp.getLoc(), + cast(reshapeSrc.getType()), + reshapeOp.getReassociationIndices(), mixedShape); + if (!expandedShape) { + return rewriter.notifyMatchFailure( + storeOp, "failed to compute expand shape for interface binding"); + } + SmallVector expandedStaticShape; + SmallVector expandedDynamicShape; + dispatchIndexOpFoldResults(*expandedShape, expandedDynamicShape, + expandedStaticShape); + + auto tensorAccess = + cast(subspanOp.getType()).getAccess(); + auto newSubspanType = + IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrcType); + + auto newSubspanOp = rewriter.create( + subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(), + subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicShape, + subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr()); + + rewriter.setInsertionPoint(storeOp); rewriter.replaceOpWithNewOp( - storeOp, *reshapeSrc, newSubspanOp, storeOp.getTargetDims()); + storeOp, reshapeSrc, newSubspanOp, expandedDynamicShape); return success(); } @@ -840,12 +1064,11 @@ struct FoldCollapseShapeIntoInterfaceTensorStore } // namespace void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) { - patterns.insert, - FoldReshapeIntoInterfaceTensorLoad>( - patterns.getContext()); - patterns.insert( - patterns.getContext()); - patterns.insert( + patterns.insert( patterns.getContext()); } diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 5eb4519ec8ce..e996aba997b1 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -212,6 +212,10 @@ getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { return computeFn; } +bool isNonZeroRank(TypedValue val) { + return val.getType().getRank() != 0; +} + //===----------------------------------------------------------------------===// // GPU workgroup memory //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index cdbc297cb4c1..4e7c108f7c19 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -105,6 +105,9 @@ FailureOr> getGPUTileSize(mlir::FunctionOpInterface funcOp, FailureOr getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel); +/// Returns true iff the rank of the input value 'val' is non-zero. 
+bool isNonZeroRank(TypedValue val); + //===----------------------------------------------------------------------===// // GPU workgroup memory //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index ad4e543d70b3..003d3f759d0e 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -67,7 +67,8 @@ renameWithDisambiguatedName(Operation *op, Operation *moduleOp, // symbol tracked in |targetSymbolMap|. LogicalResult mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, - DenseMap &targetSymbolMap) { + DenseMap &targetSymbolMap, + std::function canRenameSymbol) { auto &sourceBlock = sourceModuleOp->getRegion(0).front(); auto &targetBlock = targetModuleOp->getRegion(0).front(); SymbolTable sourceSymbolTable(sourceModuleOp); @@ -90,15 +91,19 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, // use the existing target op. continue; } - if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) { + if (canRenameSymbol(symbolOp)) { // Since the source symbol is private we can rename it as all uses // are known to be local to the source module. renameWithDisambiguatedName(sourceOp, sourceModuleOp, targetSymbolMap, &sourceSymbolTable); } else { // The source symbol has 'nested' or 'public' visibility. - if (SymbolTable::getSymbolVisibility(targetOp) != - SymbolTable::Visibility::Private) { + if (canRenameSymbol(targetOp)) { + // Keep the original name for our new op, rename the target op. + renameWithDisambiguatedName(targetOp, targetModuleOp, + targetSymbolMap, + /*optionalSymbolTable=*/nullptr); + } else { // Oops! Both symbols are public and we can't safely rename either. // If you hit this with ops that you think are safe to rename, mark // them private. @@ -109,11 +114,6 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, // where that isn't true. return sourceOp->emitError() << "multiple public symbols with the name: " << symbolName; - } else { - // Keep the original name for our new op, rename the target op. - renameWithDisambiguatedName(targetOp, targetModuleOp, - targetSymbolMap, - /*optionalSymbolTable=*/nullptr); } } } diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h index cf4ca4db47b5..a33f168d08c2 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h @@ -19,6 +19,11 @@ gatherExecutableTargets(ArrayRef executableOps); // TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version. // Only difference is one has the symbol map that we don't even need. +static inline bool allowRenamingPrivateSymbols(Operation *op) { + return SymbolTable::getSymbolVisibility(op) == + SymbolTable::Visibility::Private; +} + // Destructively merges |sourceModuleOp| into |targetModuleOp|. // |targetSymbolMap| is updated with the new symbols. // @@ -29,7 +34,9 @@ gatherExecutableTargets(ArrayRef executableOps); // symbol tracked in |targetSymbolMap|. 
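The canRenameSymbol hook threaded through mergeModuleInto above decides which of two colliding symbols may be renamed; the rename itself just picks a fresh, suffixed name (the @__shared_memory___0 in the new link_executables test shows the result). A minimal standalone sketch of that disambiguation idea, not the IREE implementation:

    #include <cstdio>
    #include <set>
    #include <string>

    // Append an increasing suffix until the candidate no longer collides with
    // any symbol already present in the target module.
    std::string disambiguate(const std::string &name,
                             const std::set<std::string> &takenNames) {
      if (!takenNames.count(name))
        return name;
      int counter = 0;
      std::string candidate;
      do {
        candidate = name + "_" + std::to_string(counter++);
      } while (takenNames.count(candidate));
      return candidate;
    }

    int main() {
      std::set<std::string> taken = {"__shared_memory__", "__shared_memory___0"};
      // Prints "__shared_memory___1": both earlier names are already taken.
      std::printf("%s\n", disambiguate("__shared_memory__", taken).c_str());
      return 0;
    }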
LogicalResult mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, - DenseMap &targetSymbolMap); + DenseMap &targetSymbolMap, + std::function canRenameSymbol = + allowRenamingPrivateSymbols); // Links all executables for the current target found in |moduleOp| into // |linkedExecutableOp|. Functions will be moved into |linkedModuleOp|. diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp index 83ffaf879a16..94cc749df9ba 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp @@ -18,7 +18,7 @@ std::pair VectorContractOpInfo::getOperandMNIndex() const { // Returns the (LHS K, RHS K) dimension index pair. std::pair VectorContractOpInfo::getOperandKIndex() const { - return std::make_pair(lhsKDim, rhsKDim); + return std::make_pair(lhsKDim.back(), rhsKDim.back()); } // Returns the result (M, N) dimension index pair. @@ -55,9 +55,12 @@ VectorContractOpInfo::inferFromIndexingMaps(ArrayRef maps) { opInfo.outNDims.push_back( *maps[2].getResultPosition(getAffineDimExpr(n, ctx))); } - int64_t k = contractionDims.k.back(); - opInfo.lhsKDim = *maps[0].getResultPosition(getAffineDimExpr(k, ctx)); - opInfo.rhsKDim = *maps[1].getResultPosition(getAffineDimExpr(k, ctx)); + for (auto k : contractionDims.k) { + opInfo.lhsKDim.push_back( + *maps[0].getResultPosition(getAffineDimExpr(k, ctx))); + opInfo.rhsKDim.push_back( + *maps[1].getResultPosition(getAffineDimExpr(k, ctx))); + } opInfo.lhsUnitDims = maps[0].getBroadcastDims(); opInfo.rhsUnitDims = maps[1].getBroadcastDims(); diff --git a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h index 101bf27ed5af..b8bde250d120 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.h @@ -49,9 +49,9 @@ class VectorContractOpInfo { int64_t getBatchCount() const { return contractionDims.batch.size(); } SmallVector lhsMDims; - int64_t lhsKDim; + SmallVector lhsKDim; SmallVector rhsNDims; - int64_t rhsKDim; + SmallVector rhsKDim; SmallVector outMDims; SmallVector outNDims; diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index 5e5b2efed3cc..f19a665d9258 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -1347,6 +1347,27 @@ LogicalResult verifyDispatchWorkgroupInfoOp(Operation *op, uint64_t dimension) { return success(); } +//===----------------------------------------------------------------------===// +// flow.dispatch.workload.ordinal +//===----------------------------------------------------------------------===// + +void DispatchWorkloadOrdinalOp::inferResultDivisibility( + ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivisibility) { + if (argDivs[0].isUninitialized()) { + setResultDivisibility(getResult(), + IREE::Util::ConstantIntDivisibility(1, 1)); + return; + } + setResultDivisibility(getResult(), argDivs[0].getValue()); +} + +void DispatchWorkloadOrdinalOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRange) { + assert(!argRanges.empty() && "expected range of input to be set"); + setResultRange(getResult(), argRanges[0]); +} + //===----------------------------------------------------------------------===// // flow.executable 
//===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td index 301ce8b15b28..69d8cc419382 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -16,6 +16,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1741,7 +1742,11 @@ def FLOW_DispatchWorkgroupCountFromSliceOp : } def FLOW_DispatchWorkloadOrdinalOp : - FLOW_PureOp<"dispatch.workload.ordinal"> { + FLOW_PureOp<"dispatch.workload.ordinal", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { let arguments = (ins Index:$operand, IndexAttr:$ordinal diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 976a4cabccd1..3c1ebd7c9864 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -336,6 +336,11 @@ def HAL_CollectiveElementType_Float16 : I32EnumAttrCase<"Float16", 8, "f16">; def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">; def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">; def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">; +def HAL_CollectiveElementType_Float8E5M2 : I32EnumAttrCase<"Float8E5M2", 12, "f8E5M2">; +def HAL_CollectiveElementType_Float8E4M3 : I32EnumAttrCase<"Float8E4M3", 13, "f8E4M3">; +def HAL_CollectiveElementType_Float8E5M2FNUZ : I32EnumAttrCase<"Float8E5M2FNUZ", 14, "f8E5M2FNUZ">; +def HAL_CollectiveElementType_Float8E4M3FNUZ : I32EnumAttrCase<"Float8E4M3FNUZ", 15, "f8E4M3FNUZ">; + def HAL_CollectiveElementTypeAttr : I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [ HAL_CollectiveElementType_Sint8, @@ -350,6 +355,10 @@ def HAL_CollectiveElementTypeAttr : HAL_CollectiveElementType_Float32, HAL_CollectiveElementType_Float64, HAL_CollectiveElementType_BFloat16, + HAL_CollectiveElementType_Float8E5M2, + HAL_CollectiveElementType_Float8E4M3, + HAL_CollectiveElementType_Float8E5M2FNUZ, + HAL_CollectiveElementType_Float8E4M3FNUZ ]> { let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index d06c2dc892c8..81f8da846eb1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -22,6 +22,27 @@ namespace mlir::iree_compiler::IREE::HAL { +namespace { + +// We arbitrarily say that unbounded dimensions in a torch program cannot +// exceed 53 bits, making the maximum safe dimension 9007199254740991. The +// astute reader will note that this is also the maximum safe value in +// JavaScript, which also "happens" to be the largest mantissa value in a +// 64-bit double. We need a maximum and in the absence of a better choice, +// with this one we are at least in good company. This limit is also used +// in the frontends. 
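The constant declared on the next line encodes the bound described in this comment. As a quick standalone check of the arithmetic (not part of the patch): 2^53 - 1 is the largest "safe" integer for a 64-bit double, so any dimension up to that bound converts to a double and back without loss.

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr uint64_t kMaxDim = (static_cast<uint64_t>(1) << 53) - 1;
      static_assert(kMaxDim == 9007199254740991ull, "matches the comment above");
      // Exactly representable in a double: the round trip is lossless.
      double asDouble = static_cast<double>(kMaxDim);
      std::printf("%llu -> %.0f -> %llu\n", (unsigned long long)kMaxDim, asDouble,
                  (unsigned long long)asDouble);
      return 0;
    }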
+static constexpr uint64_t MAX_DIM_VALUE = (static_cast(1) << 53) - 1; + +// Similarly we use a very conservative maximum rank value for specifying +// ranges of runtime rank resolution functions. Various frameworks have hard +// and practical limits ranging from 32 (numpy) to hundreds. At the time of +// writing, PyTorch throws weird errors if trying to print a tensor with a rank +// greater than 992. We really just want a smallish integer value to bound +// arithmetic, so we use an arbitrary maximum. +static constexpr uint64_t MAX_RANK_VALUE = 4096; + +} // namespace + //===----------------------------------------------------------------------===// // custom($descriptor_type) //===----------------------------------------------------------------------===// @@ -878,6 +899,10 @@ enum class NumericalType : uint32_t { kFloatIEEE = kFloat | 0x01, kFloatBrain = kFloat | 0x02, kFloatComplex = kFloat | 0x03, + kFloat8E5M2 = kFloat | 0x04, + kFloat8E4M3 = kFloat | 0x05, + kFloat8E5M2FNUZ = kFloat | 0x06, + kFloat8E4M3FNUZ = kFloat | 0x07, }; constexpr inline int32_t makeElementTypeValue(NumericalType numericalType, @@ -905,7 +930,14 @@ std::optional ElementTypeOp::getTypeValue(Type type) { return makeElementTypeValue(numericalType, intType.getWidth()); } else if (auto floatType = llvm::dyn_cast_if_present(type)) { switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) { + case APFloat::S_Float8E5M2: + return makeElementTypeValue(NumericalType::kFloat8E5M2, 8); + case APFloat::S_Float8E4M3: + return makeElementTypeValue(NumericalType::kFloat8E4M3, 8); + case APFloat::S_Float8E5M2FNUZ: + return makeElementTypeValue(NumericalType::kFloat8E5M2FNUZ, 8); case APFloat::S_Float8E4M3FNUZ: + return makeElementTypeValue(NumericalType::kFloat8E4M3FNUZ, 8); case APFloat::S_IEEEhalf: case APFloat::S_IEEEsingle: case APFloat::S_IEEEdouble: @@ -1013,6 +1045,30 @@ void BufferViewBufferOp::getAsmResultNames( setNameFn(getResult(), "buffer"); } +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewDimOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_DIM_VALUE)))); +} + +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewRankOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_RANK_VALUE)))); +} + //===----------------------------------------------------------------------===// // hal.channel.create //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h index ae58127959bb..16dd46bc5e17 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include 
"mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index fdd43b7a5e72..9e370a10c22b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -18,6 +18,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1010,7 +1011,10 @@ def HAL_BufferViewEncodingTypeOp : HAL_PureOp<"buffer_view.encoding_type"> { }]; } -def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { +def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view rank query}]; let description = [{ Returns the rank of the buffer view. @@ -1030,7 +1034,10 @@ def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { }]; } -def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> { +def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view dimension value query}]; let description = [{ Returns the value of the given dimension. diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index b6cd4cc53677..7fc985bf67ab 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -79,7 +79,8 @@ static Value reciprocalValue(OpBuilder &b, Location loc, Value input, } static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap, - AffineMap outputMap, Value value, Value output) { + AffineMap outputMap, Value value, Value output, + bool clampToFPRange) { SmallVector compressedMaps = compressUnusedDims(SmallVector{inputMap, outputMap}); inputMap = compressedMaps[0]; @@ -94,19 +95,23 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap, auto srcTy = cast(args[0].getType()); auto dstTy = cast(args[1].getType()); - double mxDbl = - APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false) - .convertToDouble(); + Value input = args[0]; - // Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs. - // We do not clamp for `-MAX` because this function meant to only be - // used by attention's exp2 who's value is always > 0. - Value mx = builder.create( - loc, builder.getFloatAttr(srcTy, mxDbl)); - Value clamp = b.create(loc, mx, args[0]); + if (clampToFPRange) { + double mxDbl = + APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false) + .convertToDouble(); + + // Clamp input to dstTy(usually `fp8`) MAX value to prevent NaNs. + // We do not clamp for `-MAX` because this function meant to only be + // used by attention's exp2 who's value is always > 0. + Value mx = builder.create( + loc, builder.getFloatAttr(srcTy, mxDbl)); + input = b.create(loc, mx, input); + } // Convert scale to the same datatype as input. 
- Value trunc = convertScalarToDtype(b, loc, clamp, dstTy, + Value trunc = convertScalarToDtype(b, loc, input, dstTy, /*isUnsignedCast=*/false); b.create(loc, trunc); }); @@ -294,43 +299,29 @@ static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize, } //===----------------------------------------------------------------------===// -// OnlineAttentionOp +// Attention Helpers //===----------------------------------------------------------------------===// -FailureOr> -OnlineAttentionOp::decomposeOperation(OpBuilder &b) { - Location loc = getLoc(); - Value query = getQuery(); - Value key = getKey(); - Value value = getValue(); - std::optional mask = getMask(); - Value oldAcc = getOutput(); - Value oldMax = getMax(); - Value oldSum = getSum(); - Type elementType = getElementTypeOrSelf(getOutput().getType()); - - FailureOr maybeOpInfo = - AttentionOpDetail::get(getIndexingMapsArray()); - assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); - AttentionOpDetail opInfo = maybeOpInfo.value(); - - SmallVector sizes = llvm::map_to_vector( - getIterationDomain(b), [](Range x) { return x.size; }); - +Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, + Value key, Value scale, std::optional mask, + AffineMap qMap, AffineMap kMap, AffineMap sMap, + std::optional maskMap, + SmallVector iterationDomain, + Type sElementType, Region &elementwiseRegion, + DictionaryAttr qkAttrs, bool lowPrecision) { + MLIRContext *ctx = b.getContext(); // Since we use exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). - Value scale = getScale(); Value log2e = b.create( loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = b.create(loc, scale, log2e); auto qETy = getElementTypeOrSelf(query.getType()); - auto vETy = getElementTypeOrSelf(value.getType()); - AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(), - /*symbolCount=*/0, getContext()); + AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(), + /*symbolCount=*/0, ctx); // In the original algorithm, the scaling is done after the softmax: // softmax(Q @ K.T * scale) @ V @@ -340,43 +331,40 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // iteration of the loop. This is only valid for f16 or f32 types as f8 // is extremely limited on its dynamic range therefore this would // significantly affect numerics. - if (qETy.getIntOrFloatBitWidth() > 8) { - AffineMap qMap = getQueryMap(); + if (!lowPrecision) { query = elementwiseValueInPlace(b, loc, qMap, scaleMap, query, scale); } - // ---- Matmul 1 ---- + // ---- QK Matmul ---- // Get sizes for S. - AffineMap sMap = opInfo.getSMap(); SmallVector sSizes; for (AffineExpr dimExpr : sMap.getResults()) { int dim = cast(dimExpr).getPosition(); - sSizes.push_back(sizes[dim]); + sSizes.push_back(iterationDomain[dim]); } // S = Q @ K // SMap = QMap @ KMap - Value emptyS = b.create(loc, sSizes, elementType); - Value sZero = b.create(loc, b.getZeroAttr(elementType)); + Value emptyS = b.create(loc, sSizes, sElementType); + Value sZero = b.create(loc, b.getZeroAttr(sElementType)); Value s = b.create(loc, sZero, emptyS).getResult(0); - s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s); - - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. 
- s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr()); + s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s); + if (qkAttrs) { + s.getDefiningOp()->setAttrs(qkAttrs); + } - s = applyPostQKMatmulElementwise(b, loc, getRegion(), s); + s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s); - if (qETy.getIntOrFloatBitWidth() <= 8) { + if (lowPrecision) { // For low bit-depth types we perform post Q @ K scaling. This is to avoid // losing numerical precision due to the low dynamic range of fp8 types when // pre applying the sclaing. AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size()); AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(), - /*symbolCount=*/0, getContext()); + /*symbolCount=*/0, ctx); s = elementwiseValueInPlace(b, loc, sMap, scaleMap, s, scale); @@ -389,16 +377,176 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false) .convertToDouble(); Value offset = b.create( - loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx)); + loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx)); s = elementwiseValueInPlace(b, loc, sMap, scaleMap, s, offset); } // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + } + + return s; +} + +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { + Location loc = getLoc(); + Value query = getQuery(); + Value key = getKey(); + Value value = getValue(); + std::optional mask = getMask(); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } + Value output = getOutput(); + + FailureOr maybeOpInfo = + AttentionOpDetail::get(getIndexingMapsArray()); + assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); + AttentionOpDetail opInfo = maybeOpInfo.value(); + + SmallVector sizes = llvm::map_to_vector( + getIterationDomain(b), [](Range x) { return x.size; }); + + AffineMap qMap = getQueryMap(); + AffineMap kMap = getKeyMap(); + AffineMap sMap = opInfo.getSMap(); + + auto qETy = getElementTypeOrSelf(query.getType()); + bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8; + + // We compute output of first matmul in f32. 
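+  // Keeping S and the row max/sum in f32 keeps the softmax reduction numerically stable even when Q/K are f16 or f8; P is truncated to the value element type just before the P @ V matmul when the element types differ.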
+ Type f32Type = b.getF32Type(); + + // ---- QK Matmul + elementwise math ---- + Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, + kMap, sMap, getMaskMap(), sizes, f32Type, + getRegion(), qkAttrs, lowPrecision); + + // ---- Softmax ---- + + AffineMap accMap = getOutputMap(); + + llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false); + for (auto dim : opInfo.getK2Dims()) { + projectedK2Dims.set(dim); + } + + AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults(); + AffineMap sumMap = maxMap; + + SmallVector rowRedSize = + applyPermutationMap(maxMap, sizes); + + Value rowRedEmpty = b.create(loc, rowRedSize, f32Type); + + Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, + getElementTypeOrSelf(output), b, loc, + /*useOnlyFiniteValue=*/true); + Value maxInit = + arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc, + /*useOnlyFiniteValue=*/true); + Value sumInit = + arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc); + + Value accFill = + b.create(loc, ValueRange{accInit}, output).getResult(0); + Value maxFill = + b.create(loc, ValueRange{maxInit}, rowRedEmpty) + .getResult(0); + Value sumFill = + b.create(loc, ValueRange{sumInit}, rowRedEmpty) + .getResult(0); + + // max = rowMax(S) + Value max = reduce(b, loc, sMap, maxMap, s, maxFill); + + // P = exp2(S - max) + AffineMap pMap = sMap; + Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + + // sum = rowSum(P) + Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); + + // P = P / sum + p = elementwiseValueInPlace(b, loc, pMap, sumMap, p, sum); + + // ---- Scale and truncate LHS to match RHS ---- + SmallVector sSizes; + for (AffineExpr dimExpr : sMap.getResults()) { + int dim = cast(dimExpr).getPosition(); + sSizes.push_back(sizes[dim]); + } + + auto pETy = getElementTypeOrSelf(p.getType()); + auto vETy = getElementTypeOrSelf(value.getType()); + if (pETy != vETy && isa(vETy)) { + Value convertP = b.create(loc, sSizes, vETy); + p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision); } + // result = P @ V + acc + Value result = + computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill); + if (pvAttrs) { + result.getDefiningOp()->setAttrs(pvAttrs); + } + + return SmallVector{result}; +} + +//===----------------------------------------------------------------------===// +// OnlineAttentionOp +//===----------------------------------------------------------------------===// + +FailureOr> +OnlineAttentionOp::decomposeOperation(OpBuilder &b) { + Location loc = getLoc(); + Value query = getQuery(); + Value key = getKey(); + Value value = getValue(); + std::optional mask = getMask(); + Value oldAcc = getOutput(); + Value oldMax = getMax(); + Value oldSum = getSum(); + Type elementType = getElementTypeOrSelf(getOutput().getType()); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } + + FailureOr maybeOpInfo = + AttentionOpDetail::get(getIndexingMapsArray()); + assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); + AttentionOpDetail opInfo = maybeOpInfo.value(); + + SmallVector sizes = llvm::map_to_vector( + getIterationDomain(b), [](Range x) { return x.size; }); + + AffineMap qMap = getQueryMap(); + AffineMap kMap = getKeyMap(); + AffineMap sMap = opInfo.getSMap(); + + auto qETy = getElementTypeOrSelf(query.getType()); + bool lowPrecision = 
qETy.getIntOrFloatBitWidth() <= 8; + + // ---- QK Matmul + elementwise math ---- + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, elementType, getRegion(), qkAttrs, lowPrecision); + // TODO: This decomposition should be in a seperate op called // "online softmax". // ---- Online Softmax ---- @@ -429,10 +577,17 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap accMap = getOutputMap(); // ---- Scale and truncate LHS to match RHS ---- + SmallVector sSizes; + for (AffineExpr dimExpr : sMap.getResults()) { + int dim = cast(dimExpr).getPosition(); + sSizes.push_back(sizes[dim]); + } + auto pETy = getElementTypeOrSelf(p.getType()); + auto vETy = getElementTypeOrSelf(value.getType()); if (pETy != vETy && isa(vETy)) { Value convertP = b.create(loc, sSizes, vETy); - p = truncateFloat(b, loc, pMap, pMap, p, convertP); + p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision); } Value newAcc = elementwiseValueInPlace(b, loc, accMap, normMap, @@ -442,9 +597,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // newAcc = P @ V + newAcc newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc); - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. - newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr()); + if (pvAttrs) { + newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs); + } return SmallVector{newAcc, newMax, newSum}; } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 6abaec41f91a..77a2d518acb2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1213,7 +1213,7 @@ void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output, - indexingMaps); + indexingMaps, DictionaryAttr()); } LogicalResult AttentionOp::verify() { @@ -1388,7 +1388,7 @@ void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output, - max, sum, indexingMaps); + max, sum, indexingMaps, DictionaryAttr()); } LogicalResult OnlineAttentionOp::verify() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index e097ce5a9089..3b46114abe5e 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -475,6 +475,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", ["getIndexingMapsForResults", "getIndexingMapsForOperands", "getStaticLoopRanges"]>, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods:$mask, AnyShaped:$output, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -558,6 +560,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV 
matmul after decomposition. + static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } @@ -612,7 +620,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", AnyShaped:$output, AnyShaped:$max, AnyShaped:$sum, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -679,6 +688,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV matmul after decomposition. + static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel index ade61354f05a..aff5921fd57b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/BUILD.bazel @@ -17,6 +17,7 @@ iree_lit_test_suite( srcs = enforce_glob( [ "canonicalize.mlir", + "decompose_aggregate_op.mlir", "invalid.mlir", "roundtrip.mlir", ], diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt index f6d6730ad5b2..36bdf43dc97c 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "canonicalize.mlir" + "decompose_aggregate_op.mlir" "invalid.mlir" "roundtrip.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir similarity index 51% rename from compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir rename to compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir index 19bd2ca12411..fae9e5b76b23 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/decompose_aggregate_op.mlir @@ -1,4 +1,79 @@ -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-linalg-ext-decompose-attention),canonicalize,cse)" %s | FileCheck %s +// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s + +// Spec to decompose custom op. 
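+// (With --split-input-file each `// -----` block below is run as its own module, so every block repeats its transform spec.)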
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>, + %rhs1 : tensor, %rhs2 : tensor, %scalar : f32, + %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>) + -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + %0:2 = iree_linalg_ext.custom_op { + indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>, + affine_map<(d0, d1)[s0, s1] -> (s0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (s1, d1)>, + affine_map<(d0, d1)[s0, s1] -> ()>, + affine_map<(d0, d1)[s0, s1] -> (d0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (d0, d1)>], + iterator_types = [#iree_linalg_ext.iterator_type, + #iree_linalg_ext.iterator_type]} + ins(%lhs1, %rhs1, %rhs2, %scalar + : tensor<1000000x?xf32>, tensor, tensor, f32) + outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + ^bb0(%t0 : tensor, %t1 : tensor, %t2 : tensor, + %s : f32, %t3 : tensor, %t4 : tensor) : + %0 = linalg.matmul ins(%t0, %t1 : tensor, tensor) + outs(%t3 : tensor) -> tensor + %1 = linalg.matmul ins(%0, %t2 : tensor, tensor) + outs(%t4 : tensor) -> tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%1, %s : tensor, f32) outs(%1 : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 :f32): + %3 = arith.addf %b0, %b2 : f32 + linalg.yield %3 : f32 + } -> tensor + iree_linalg_ext.yield %0, %2 : tensor, tensor + } -> tensor<1000000x?xf32>, tensor<1000000x?xf32> + return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32> +} + +// CHECK-LABEL: func @custom_op_decomposition( +// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK: %[[MATMUL1:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] : +// CHECK-SAME: outs(%[[INIT1]] : +// CHECK: %[[MATMUL2:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] : +// CHECK-SAME: outs(%[[INIT2]] : +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] : +// CHECK-SAME: outs(%[[MATMUL2]] : +// CHECK: return %[[MATMUL1]], %[[GENERIC]] + +// ----- + +// Spec to decompose online attention op. 
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> @@ -8,6 +83,89 @@ #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> func.func @attention_f16(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>) + -> (tensor<192x1024x64xf32>) { + %scale = arith.constant 1.0 : f16 + + %out = iree_linalg_ext.attention + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output : tensor<192x1024x64xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32> + + return %out : tensor<192x1024x64xf32> +} + +// We just want to check if we are using the correct algorithm +// CHECK-LABEL: @attention_f16 +// Q = Q * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// max = rowMax(S) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.maximumf +// CHECK: linalg.yield +// P = exp2(S - max) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK: math.exp2 +// CHECK: linalg.yield +// sum = rowSum(P) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.addf +// CHECK: linalg.yield +// P = P /= sum +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.divf +// CHECK: linalg.yield +// truncf P : f32 to f16 +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.truncf +// CHECK: linalg.yield +// newAcc = P @ V +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, %key: tensor<192x1024x64xf16>, %value: tensor<192x1024x64xf16>, %output: tensor<192x1024x64xf32>, @@ -30,7 +188,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, // We just want to check if we are using the correct algorithm and the // correct number of extf/truncfs are emitted. 
-// CHECK-LABEL: @attention_f16 +// CHECK-LABEL: @online_attention_f16 // Q = Q * scale // CHECK: linalg.generic // CHECK: arith.mulf @@ -83,6 +241,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> @@ -90,7 +257,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> -func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, +func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, %key: tensor<192x1024x64xf8E4M3FNUZ>, %value: tensor<192x1024x64xf8E4M3FNUZ>, %output: tensor<192x1024x64xf32>, @@ -111,7 +278,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> } -// CHECK-LABEL: @attention_f8 +// CHECK-LABEL: @online_attention_f8 // S = Q @ K // CHECK: linalg.generic // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 @@ -176,6 +343,15 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> @@ -184,7 +360,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> -func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, +func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, %key: tensor<192x1024x64xf8E4M3FNUZ>, %value: tensor<192x1024x64xf8E4M3FNUZ>, %mask: tensor<192x1024x1024xf8E4M3FNUZ>, @@ -205,7 +381,7 @@ func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> } -// CHECK-LABEL: @attention_f8_masked +// CHECK-LABEL: @online_attention_f8_masked // S = Q @ K // CHECK: linalg.generic // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp index 6e699fda1f2e..131ff3e5437b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp @@ -38,7 +38,7 @@ static Value createMul(Location loc, 
Value x, Value y, OpBuilder &builder) { namespace { -using ControlFnTy = std::optional>; +using ControlFnTy = std::function; // Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) // and linalg.matmul. @@ -78,7 +78,8 @@ class ConvertConv2DNhwcHwcf final public: using OpRewritePattern::OpRewritePattern; - ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn) + ConvertConv2DNhwcHwcf(MLIRContext *context, + std::optional controlFn) : OpRewritePattern(context), controlFn(controlFn) {} @@ -192,7 +193,7 @@ class ConvertConv2DNhwcHwcf final } private: - ControlFnTy controlFn; + std::optional controlFn; }; // For nchw, because the channels are to the left of the image shape dimensions, @@ -204,7 +205,8 @@ class ConvertConv2DNchwFchw final public: using OpRewritePattern::OpRewritePattern; - ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn) + ConvertConv2DNchwFchw(MLIRContext *context, + std::optional controlFn) : OpRewritePattern(context), controlFn(controlFn) {} @@ -314,7 +316,7 @@ class ConvertConv2DNchwFchw final } private: - ControlFnTy controlFn; + std::optional controlFn; }; struct ConvertConv2DToIm2ColOpPass final @@ -335,7 +337,7 @@ struct ConvertConv2DToIm2ColOpPass final } // namespace void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns, - ControlFnTy controlFn) { + std::optional controlFn) { patterns.insert( patterns.getContext(), std::move(controlFn)); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h index 1e858df14e2f..cc894b3edecd 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h @@ -28,8 +28,10 @@ LogicalResult splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp, const TopkSplitReductionControlFn &splitReductionFn); -// Patterns to convert linalg convolution ops into a gemm with an im2col -// op and reshapes on the inputs. +/// Patterns to convert linalg convolution ops into a gemm with an im2col +/// op and reshapes on the inputs. +/// TODO(Max191): Maybe move to transforms and use a funcOp walk instead of a +/// rewrite pattern for this. 
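+/// When provided, `controlFn` is meant to act as a filter over candidate convolution ops; convolutions it rejects are left unchanged.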
void populateConv2DToIm2colOpPatterns( RewritePatternSet &patterns, std::optional> controlFn = std::nullopt); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index 0aa3a37aa5fe..d9a48736fdd4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -106,7 +106,8 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(), mask, accFill, maxFill, sumFill, - rewriter.getAffineMapArrayAttr(indexingMaps)); + rewriter.getAffineMapArrayAttr(indexingMaps), + attnOp.getDecompositionConfigAttr()); rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(), onlineAttn.getRegion().begin()); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel index efe463a65949..6ba9d5cd801d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel @@ -20,9 +20,7 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir", "convert_to_loops.mlir", "convert_to_online_attention.mlir", - "decompose_aggregate_op.mlir", "decompose_im2col.mlir", - "decompose_online_attention.mlir", "decompose_winograd.mlir", "distribution.mlir", "pad_contraction_to_block_size.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt index 3288c1443dfd..a912973cb2f7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt @@ -18,9 +18,7 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir" "convert_to_loops.mlir" "convert_to_online_attention.mlir" - "decompose_aggregate_op.mlir" "decompose_im2col.mlir" - "decompose_online_attention.mlir" "decompose_winograd.mlir" "distribution.mlir" "pad_contraction_to_block_size.mlir" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir deleted file mode 100644 index 80b0b7a693e3..000000000000 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir +++ /dev/null @@ -1,62 +0,0 @@ -// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s - -func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>, - %rhs1 : tensor, %rhs2 : tensor, %scalar : f32, - %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>) - -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) { - %0:2 = iree_linalg_ext.custom_op { - indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>, - affine_map<(d0, d1)[s0, s1] -> (s0, s1)>, - affine_map<(d0, d1)[s0, s1] -> (s1, d1)>, - affine_map<(d0, d1)[s0, s1] -> ()>, - affine_map<(d0, d1)[s0, s1] -> (d0, s1)>, - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>], - iterator_types = [#iree_linalg_ext.iterator_type, - #iree_linalg_ext.iterator_type]} - ins(%lhs1, %rhs1, %rhs2, %scalar - : tensor<1000000x?xf32>, tensor, tensor, f32) - 
outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) { - ^bb0(%t0 : tensor, %t1 : tensor, %t2 : tensor, - %s : f32, %t3 : tensor, %t4 : tensor) : - %0 = linalg.matmul ins(%t0, %t1 : tensor, tensor) - outs(%t3 : tensor) -> tensor - %1 = linalg.matmul ins(%0, %t2 : tensor, tensor) - outs(%t4 : tensor) -> tensor - %2 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> ()>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%1, %s : tensor, f32) outs(%1 : tensor) { - ^bb0(%b0 : f32, %b1 : f32, %b2 :f32): - %3 = arith.addf %b0, %b2 : f32 - linalg.yield %3 : f32 - } -> tensor - iree_linalg_ext.yield %0, %2 : tensor, tensor - } -> tensor<1000000x?xf32>, tensor<1000000x?xf32> - return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32> -} -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () - transform.yield - } -} -// CHECK-LABEL: func @custom_op_decomposition( -// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32 -// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK: %[[MATMUL1:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] : -// CHECK-SAME: outs(%[[INIT1]] : -// CHECK: %[[MATMUL2:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] : -// CHECK-SAME: outs(%[[INIT2]] : -// CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] : -// CHECK-SAME: outs(%[[MATMUL2]] : -// CHECK: return %[[MATMUL1]], %[[GENERIC]] diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir index 81f683f26cdf..af2286a55139 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i16 { stream.executable.export public @__builtin_fill_i16 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i16(%value: i16, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i16) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git 
a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir index 43b0829e99e9..758591f4159e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i32 { stream.executable.export public @__builtin_fill_i32 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i32(%value: i32, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i32) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir index 96e527a20f0f..5d7d686b5bd1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i64 { stream.executable.export public @__builtin_fill_i64 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i64(%value: i64, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i64) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count0} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir index 7005ded9aee4..c2c642dd53bb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i8 { stream.executable.export public @__builtin_fill_i8 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = 
flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i8(%value: i8, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i8) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir index a94cdf1d6cf7..139788921a8a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i16 { stream.executable.export public @__builtin_splat_i16 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i16(%value: i16, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i16) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir index 07f3b4cb1b54..a1f19b894e7a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i32 { stream.executable.export public @__builtin_splat_i32 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i32(%value: i32, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : 
tensor %1 = linalg.fill ins(%value : i32) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir index 7d94e51a26d7..4d25d358c7b4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i64 { stream.executable.export public @__builtin_splat_i64 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i64(%value: i64, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i64) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count0} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir index 5e5f8cb261d7..d0c6dc046f1e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i8 { stream.executable.export public @__builtin_splat_i8 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i8(%value: i8, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i8) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp index 91f5c0ffff3f..92a3457dfa22 100644 --- 
a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp @@ -224,6 +224,11 @@ struct ConvertToStreamPass final // for all SSA values we'll use during conversion are available. AffinityAnalysis affinityAnalysis(getOperation()); if (failed(affinityAnalysis.run())) { + getOperation().emitError() + << "affinity analysis failed to converge (input program may have " "invalid affinities assigned); use " "`--iree-stream-annotate-input-affinities` to help identify the " "invalid affinities"; return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 1924f423ef66..f78817cb03af 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -493,3 +493,33 @@ util.func @util_align_zero(%arg0 : i64) -> i64 { %rem16 = arith.remui %0, %c16 : i64 util.return %rem16 : i64 } + +// ----- + +util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 9007199254740991 : index + %0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +} + +// ----- + +util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 4096 : index + %0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +} diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp index 4b03b60f5779..59ddb577fb51 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp @@ -115,6 +115,30 @@ struct BubbleUpExtract : OpRewritePattern { } }; +/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst, +/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users. +/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the +/// former can be folded away.
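+/// Example (sketch):
+///   %fill  = linalg.fill ins(%cst) outs(%init)
+///   %slice = tensor.extract_slice %fill [...]
+/// becomes
+///   %init_slice = tensor.extract_slice %init [...]
+///   %slice      = linalg.fill ins(%cst) outs(%init_slice)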
+struct SwapExtractSliceOfFill final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, + PatternRewriter &rewriter) const override { + auto fillOp = extractOp.getSource().getDefiningOp(); + if (!fillOp) + return failure(); + + auto newExtractOp = rewriter.create( + extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], + extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides()); + rewriter.replaceOpWithNewOp( + extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()}); + return success(); + } +}; + struct BubbleUpExtractSlicesPass : impl::BubbleUpExtractSlicesPassBase { void runOnOperation() override { @@ -122,6 +146,8 @@ struct BubbleUpExtractSlicesPass { RewritePatternSet patterns(context); patterns.insert(context); + patterns.insert(context); + tensor::populateFoldTensorEmptyPatterns(patterns, false); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index e866022eb9a9..b38b1a593001 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -547,6 +547,14 @@ isFusableWithConsumer(OpOperand &fusedOperand, return false; } + // TODO: Enable grouped convolution and depthwise pooling fusion. + // Right now, this goes through the default CPU pipeline and not through + // CONVTilingExpert. + if (isa(producer)) { + return false; + } + auto producerFusionOp = dyn_cast(producer); auto consumerFusionOp = diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index 845485667d38..a78b6b83876b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp( return true; } +/// Check that a given operation is "horizontal" to the group. The operation +/// is horizontal if the backward slice of the operation does not contain any op +/// from the group. +static bool isHorizontalToGroup(Operation *op, + const llvm::SetVector &currGroup, + const DominanceInfo &dominanceInfo, + Operation *seedOp) { + BackwardSliceOptions options; + // Limit the slice to the seed to make sure the slice is small. + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, seedOp); + }; + llvm::SetVector slice; + getBackwardSlice(op, &slice, options); + return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return slice.contains(groupedOp); + }); +} + /// Get user of operation that is a truncate operation.
static std::optional getTruncateOp(Operation *op, @@ -131,8 +149,8 @@ getTruncateOp(Operation *op, if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) { return std::nullopt; } - if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(), - dominanceInfo, seedTruncateOp.value())) { + if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo, + seedTruncateOp.value())) { return std::nullopt; } } @@ -208,8 +226,7 @@ static std::optional getHorizontalFusionGroupMembers( if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) { return false; } - if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo, - seedOp)) { + if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) { return false; } return true; @@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter, return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0); } +/// During horizontal fusion, there might be operands of the fused operations +/// whose definitions are interspersed between the fused operations. For groups +/// chosen to fuse horizontally, such operations can be moved before the +/// seed contraction operation (where the fused operation is generated). +template +static LogicalResult +moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, + Operation *insertionPoint, DominanceInfo &dominanceInfo, + ArrayRef ignoreOperations = {}) { + BackwardSliceOptions options; + llvm::DenseSet ignoreOperationsSet; + ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, insertionPoint) && + !ignoreOperationsSet.contains(op); + }; + // Set inclusive to true cause the slice is computed from the operand, and + // we want to include the defining op (which is the point here) + options.inclusive = true; + + llvm::SetVector slice; + for (auto op : operations) { + for (auto operand : op->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + } + + mlir::topologicalSort(slice); + for (auto op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + /// On finding this pattern /// ``` /// %0 = linalg.matmul ins(%arg0, %arg1) diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index d79d5145e77d..9d9d477c9a57 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -16,13 +16,9 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -49,55 +45,25 @@ static llvm::cl::opt clLinalgMaxConstantFoldElements( llvm::cl::desc("Maximum number of elements to try to constant fold."), llvm::cl::init(0)); -static Operation *getMostDominantUse(Operation *op, - const DominanceInfo &dominanceInfo) { - auto uses = 
op->getUses(); - auto it = llvm::find_if(uses, [&](OpOperand &source) { - Operation *sourceOp = source.getOwner(); - - return llvm::all_of(uses, [&](OpOperand &target) { - Operation *targetOp = target.getOwner(); - return dominanceInfo.dominates(sourceOp, targetOp); - }); - }); - if (it != uses.end()) { - return it->getOwner(); - } - return nullptr; -} - /// Check if any of the use dominates all other uses of the operation. -static Operation *getFusableUse(Operation *op, - const DominanceInfo &dominanceInfo) { +static std::optional getFusableUse(Operation *op, + DominanceInfo &dominanceInfo) { auto uses = op->getUses(); - Operation *fusableUse = nullptr; for (OpOperand &source : uses) { Operation *sourceOp = source.getOwner(); - - bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) { + bool dominatesAllUsers = true; + for (OpOperand &target : uses) { Operation *targetOp = target.getOwner(); - return !isa(targetOp) || - dominanceInfo.dominates(sourceOp, targetOp); - }); - if (dominatesAllFusableOps) { - fusableUse = sourceOp; - break; + if (!dominanceInfo.dominates(sourceOp, targetOp)) { + dominatesAllUsers = false; + break; + } + } + if (dominatesAllUsers) { + return &source; } } - Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo); - if (!fusableUse || !mostDominantOp) { - return nullptr; - } - - // If `fusableUse` dominates all other users, there's nothing else to do. - if (fusableUse == mostDominantOp) { - return fusableUse; - } - - SmallVector users(op->getUsers().begin(), op->getUsers().end()); - return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp) - ? fusableUse - : nullptr; + return std::nullopt; } static OpOperand *getFirstUseInConsumer(Operation *producer, @@ -125,7 +91,6 @@ static SmallVector getAllUsesInConsumer(Operation *producer, /// using elementwise fusion. static LogicalResult doMultiUseFusion(Operation *rootOp, llvm::SetVector &fusableOps, - const DominanceInfo &dominanceInfo, RewriterBase &rewriter) { assert(rootOp && "root op cant be null"); @@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp, Operation *consumerOp = rootOp; OpBuilder::InsertionGuard g(rewriter); for (Operation *producerOp : llvm::reverse(fusedOpsVec)) { - Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo); // Fuse all uses from producer -> consumer. It has been checked // before that all uses are fusable. while (OpOperand *fusedOperand = getFirstUseInConsumer(producerOp, consumerOp)) { rewriter.setInsertionPoint(consumerOp); - - if (consumerOp != mostDominantUser && - failed(moveOperandDefs(rewriter, ArrayRef{consumerOp}, - mostDominantUser, dominanceInfo))) { - return rewriter.notifyMatchFailure(consumerOp, - "failed to move operand defs"); - } - rewriter.moveOpBefore(consumerOp, mostDominantUser); FailureOr fusionResult = linalg::fuseElementwiseOps(rewriter, fusedOperand); if (failed(fusionResult)) { @@ -234,8 +190,9 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, } // 6. Check that the `genericOp` dominates all uses of `producer`. 
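As an illustrative aside (not part of the patch itself): the simplified `getFusableUse` above returns the use whose owner dominates every other user, or nothing if no such use exists. A standalone C++ sketch of that selection, again with straight-line positions as an assumed stand-in for dominance (in a real CFG, two users on different branches may leave no dominating use):

```c++
// Illustrative only: pick the use whose owner comes before (dominates) every
// other user of the producer. Returns -1 when no such use exists.
#include <iostream>
#include <vector>

static int getFusableUse(const std::vector<int> &userPositions) {
  for (size_t i = 0; i < userPositions.size(); ++i) {
    bool dominatesAllUsers = true;
    for (size_t j = 0; j < userPositions.size(); ++j) {
      if (userPositions[i] > userPositions[j]) { // user i comes after user j
        dominatesAllUsers = false;
        break;
      }
    }
    if (dominatesAllUsers) return static_cast<int>(i);
  }
  return -1;
}

int main() {
  // Users of a producer at program positions 7, 4, and 9: the user at
  // position 4 dominates the other two, so index 1 is the fusable use.
  std::cout << getFusableUse({7, 4, 9}) << "\n"; // prints 1
  return 0;
}
```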
- Operation *fusableUse = getFusableUse(producer, dominanceInfo); - if (!fusableUse || fusableUse != genericOp) { + std::optional fusableUse = + getFusableUse(producer, dominanceInfo); + if (!fusableUse || fusableUse.value()->getOwner() != genericOp) { continue; } @@ -275,8 +232,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, IRRewriter rewriter(context); for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) { - if (failed( - doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) { + if (failed(doMultiUseFusion(it->first, it->second, rewriter))) { return funcOp->emitOpError("failed multi use fusion"); } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 238c866fe461..c428091f6cf8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -10,11 +10,7 @@ #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h" #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/RegionUtils.h" namespace mlir::iree_compiler::DispatchCreation { @@ -101,33 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, - Operation *seedOp) { - assert(dominanceInfo.properlyDominates(seedOp, op) && - op->getParentRegion() == seedOp->getParentRegion()); - BackwardSliceOptions options; - // Limit the slice to the seed to make sure the slice is small. - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, seedOp); - }; - llvm::SetVector slice; - getBackwardSlice(op, &slice, options); - - // `getBackwardSlice` doesnt track uses from within an ops region, so make - // sure there are no values defined above. - for (Operation *sliceOp : slice) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - if (usesValuesFromAbove) { - return false; - } - } - - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { - return slice.contains(groupedOp); - }); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index 6526badfea31..1d9c9306f7ae 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -10,10 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" namespace mlir::iree_compiler::DispatchCreation { @@ -23,44 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation { bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool fuseMultiReduction); -/// Check that a given operation is "horizontal" to the group. 
The operation -/// is horizontal if the program slice of the operation (from op back to seedOp) -/// does not contain any op from the group. -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, Operation *seedOp); - -/// Moves the operands and transitive defs for each op in `operations` directly -/// after `insertionPoint`. Note: this does not check if it is legal to move the -/// operands. -template -static LogicalResult -moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, - Operation *insertionPoint, const DominanceInfo &dominanceInfo, - ArrayRef ignoreOperations = {}) { - BackwardSliceOptions options; - llvm::DenseSet ignoreOperationsSet; - ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, insertionPoint) && - !ignoreOperationsSet.contains(op); - }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; - - llvm::SetVector slice; - for (auto op : operations) { - assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } - } - - mlir::topologicalSort(slice); - for (auto op : slice) { - rewriter.moveOpBefore(op, insertionPoint); - } - return success(); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index d5150ccae7c1..afee21cbbcd8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -49,10 +49,10 @@ static llvm::cl::opt clEnableFusePaddingIntoLinalgProducerOps( static llvm::cl::opt clPadFactor( "iree-dispatch-creation-pad-factor", - llvm::cl::desc( - "Provides padding size hints that will be attached to " - "encodings. This only affects the experimental data tiling " - "path in Flow with iree-dispatch-creation-experimental-data-tiling."), + llvm::cl::desc("Provides padding size hints that will be attached to " + "encodings. This only affects the experimental data tiling " + "path in DispatchCreation with " + "iree-dispatch-creation-experimental-data-tiling."), llvm::cl::init(32)); static llvm::cl::opt clEnablePadHandling( @@ -337,14 +337,14 @@ void registerDispatchCreationPasses() { } void registerDispatchCreationPipelines() { - PassPipelineRegistration flowDispatchRegionCreationPipeline( + PassPipelineRegistration dispatchCreationPipeline( "iree-dispatch-creation-pipeline", "Flag used to run passes that form dispatch regions", [](OpPassManager &passManager, const TransformOptions &transformOptions) { buildDispatchCreationPassPipeline(passManager, transformOptions); }); - PassPipelineRegistration<> flowDispatchRegionFormationPreprocessingPipeline( + PassPipelineRegistration<> dispatchCreationPreprocessingPipeline( "iree-dispatch-creation-preprocessing-pipeline", "Flag used to run preprocessing passes that run passes before dispatch " "region formation. 
Used only for testing", diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir index a5b7ea13ee27..56fa91d7b2d6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir @@ -94,3 +94,24 @@ util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> ( // CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>) // CHECK: util.return %[[GENERIC1]], %[[GENERIC0]] + +util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> { + %cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ + %cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ + %1 = tensor.empty() : tensor<2x320x128x128xf8E4M3FNUZ> + %2 = linalg.fill ins(%cst_2 : f8E4M3FNUZ) outs(%1 : tensor<2x320x128x128xf8E4M3FNUZ>) -> tensor<2x320x128x128xf8E4M3FNUZ> + %3 = tensor.empty() : tensor<2x320x130x130xf8E4M3FNUZ> + %4 = linalg.fill ins(%cst_1 : f8E4M3FNUZ) outs(%3 : tensor<2x320x130x130xf8E4M3FNUZ>) -> tensor<2x320x130x130xf8E4M3FNUZ> + %extracted_slice_1 = tensor.extract_slice %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x130x130xf8E4M3FNUZ> to tensor<2x320x128x130xf8E4M3FNUZ> + %inserted_slice_1 = tensor.insert_slice %2 into %extracted_slice_1[0, 0, 0, 1] [2, 320, 128, 128] [1, 1, 1, 1] : tensor<2x320x128x128xf8E4M3FNUZ> into tensor<2x320x128x130xf8E4M3FNUZ> + %inserted_slice_2 = tensor.insert_slice %inserted_slice_1 into %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x128x130xf8E4M3FNUZ> into tensor<2x320x130x130xf8E4M3FNUZ> + util.return %inserted_slice_2 : tensor<2x320x130x130xf8E4M3FNUZ> +} + +// CHECK-LABEL: @bubble_up_extract_fill_multi_use +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK: %[[EMPTY1:.+]] = tensor.empty +// CHECK: %[[FILL2:.+]] = linalg.fill +// CHECK-NOT: %[[SLICE:.+]] = tensor.extract_slice +// CHECK: %[[EMPTY2:.+]] = tensor.empty +// CHECK: %[[FILL3:.+]] = linalg.fill diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c76fa0653635..cc3e159ca943 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -139,28 +139,3 @@ util.func public @math_sin() { // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0, // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1, - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.addf %arg2, %cst : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} - %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = 
["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> -} -// CHECK-LABEL: util.func public @fuse_by_moving_consumer -// CHECK: linalg.generic -// CHECK-NOT: linalg.generic diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp index a8b4becfff2b..7128dbdfc03b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp @@ -6,7 +6,8 @@ #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -26,134 +27,51 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { LogicalResult matchAndRewrite(Conv2DOpType convOp, PatternRewriter &rewriter) const override { - auto inputShapeType = llvm::dyn_cast( - convOp.getDpsInputOperand(0)->get().getType()); auto filterShapeType = llvm::dyn_cast( convOp.getDpsInputOperand(1)->get().getType()); - auto outputShapeType = llvm::dyn_cast( - convOp.getDpsInitOperand(0)->get().getType()); - - const bool isNCHW = isa(convOp); - const bool isNHWC = isa(convOp); - if (!isNCHW & !isNHWC) + if (!filterShapeType) return failure(); - if (!inputShapeType || !filterShapeType || !outputShapeType) - return failure(); + constexpr bool isNCHW = + std::is_same_v; + constexpr bool isNHWC = + std::is_same_v; + static_assert(isNCHW || isNHWC); - auto inputShape = inputShapeType.getShape(); auto filterShape = filterShapeType.getShape(); - auto outputShape = outputShapeType.getShape(); + + constexpr int64_t numLoops = 7; // Adjusting dimension indices based on Conv2DOpType. - const int nIndex = 0; - const int kcIndex = isNHWC ? 2 : 1; - const int kfIndex = isNHWC ? 3 : 0; - const int khIndex = isNHWC ? 0 : 2; - const int kwIndex = isNHWC ? 1 : 3; - const int ohIndex = isNHWC ? 1 : 2; - const int owIndex = isNHWC ? 2 : 3; - const int ocIndex = isNHWC ? 3 : 1; - - bool isInputHWDynamic = ShapedType::isDynamic(inputShape[ohIndex]) && - ShapedType::isDynamic(inputShape[owIndex]); - - // We cannot merge the width and height if they are both dynamic as we - // cannot expand them back to their dynamic values. - if (isInputHWDynamic) - return failure(); + constexpr int khIndex = isNHWC ? 0 : 2; + constexpr int kwIndex = isNHWC ? 1 : 3; + constexpr int khLoopIndex = isNHWC ? 4 : 5; + constexpr int kwLoopIndex = isNHWC ? 5 : 6; if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) return failure(); - // TODO(ataei): Support conversion to linalg.batch_matmul. 
- if (inputShape[0] != 1) - return failure(); - - if (!llvm::all_of(convOp.getStrides(), [](APInt element) { - return element.getSExtValue() == 1; - })) - return failure(); - if (!llvm::all_of(convOp.getDilations(), [](APInt element) { - return element.getSExtValue() == 1; - })) - return failure(); - - auto combineDims = [](int64_t a, int64_t b) { - if (ShapedType::isDynamic(a) || ShapedType::isDynamic(b)) - return ShapedType::kDynamic; - return a * b; - }; - - SmallVector reassociationInputOutputIndices; - SmallVector reassociationFilterIndices; - SmallVector reshapedInputShape(2, 0); - SmallVector reshapedFilterShape(2, 0); - SmallVector reshapedOutputShape(2, 0); - if (isNHWC) { - // Generate reassociation indices. - reassociationInputOutputIndices = {{nIndex, ohIndex, owIndex}, {ocIndex}}; - reassociationFilterIndices = {{khIndex, kwIndex, kcIndex}, {kfIndex}}; - - // Generate matmul shapes from 1x1 conv. - reshapedInputShape = { - combineDims(inputShape[ohIndex], inputShape[owIndex]), - inputShape[ocIndex]}; - reshapedFilterShape = {filterShape[kcIndex], filterShape[kfIndex]}; - reshapedOutputShape = { - combineDims(outputShape[ohIndex], outputShape[owIndex]), - outputShape[ocIndex]}; - } else if (isNCHW) { - // Generate reassociation indices. - reassociationInputOutputIndices = {{nIndex, ocIndex}, {ohIndex, owIndex}}; - reassociationFilterIndices = {{kfIndex}, {kcIndex, khIndex, kwIndex}}; - - // Generate matmul shapes from 1x1 conv. - reshapedInputShape = { - inputShape[ocIndex], - combineDims(inputShape[ohIndex], inputShape[owIndex])}; - reshapedFilterShape = {filterShape[kfIndex], filterShape[kcIndex]}; - reshapedOutputShape = { - outputShape[ocIndex], - combineDims(outputShape[ohIndex], outputShape[owIndex])}; + SmallVector dimReplacements; + for (int i = 0; i < numLoops; i++) { + if (llvm::is_contained({khLoopIndex, kwLoopIndex}, i)) { + dimReplacements.push_back( + getAffineConstantExpr(0, rewriter.getContext())); + } else { + dimReplacements.push_back(getAffineDimExpr(i, rewriter.getContext())); + } } - auto reshapedInputType = RankedTensorType::get( - reshapedInputShape, inputShapeType.getElementType()); - - auto reshapedFilterType = RankedTensorType::get( - reshapedFilterShape, filterShapeType.getElementType()); - - auto reshapedOutputType = RankedTensorType::get( - reshapedOutputShape, outputShapeType.getElementType()); - - Value input = convOp.getDpsInputOperand(0)->get(); - Value filter = convOp.getDpsInputOperand(1)->get(); - Value output = convOp.getDpsInitOperand(0)->get(); - auto loc = convOp.getLoc(); - - Value reshapedInput = rewriter.create( - loc, reshapedInputType, input, reassociationInputOutputIndices); - Value reshapedFilter = rewriter.create( - loc, reshapedFilterType, filter, reassociationFilterIndices); - Value reshapedOutput = rewriter.create( - loc, reshapedOutputType, output, reassociationInputOutputIndices); - - SmallVector matmulInput; - if (isNHWC) { - matmulInput = {reshapedInput, reshapedFilter}; - } else if (isNCHW) { - matmulInput = {reshapedFilter, reshapedInput}; - } - auto matmulResult = rewriter.create( - loc, reshapedOutputType, matmulInput, ArrayRef{reshapedOutput}); - - auto reshapedResult = rewriter.create( - loc, outputShapeType, matmulResult.getResults()[0], - reassociationInputOutputIndices); - - rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); - + SmallVector newMaps = convOp.getIndexingMapsArray(); + AffineMap inputMap = newMaps[0]; + SmallVector newExprs = + llvm::map_to_vector(inputMap.getResults(), [&](AffineExpr resultExpr) { 
+ return resultExpr.replaceDims(dimReplacements); + }); + newMaps[0] = AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), + newExprs, rewriter.getContext()); + + auto genericOp = linalg::generalizeNamedOp(rewriter, convOp).value(); + genericOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(newMaps)); return success(); } }; diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp index 92293bc156ba..99f6268a47b0 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp @@ -30,6 +30,34 @@ struct GeneralizeLinalgNamedOpsPass }; } // namespace +/// Returns true if `linalgOp` is a Conv2DNchwFchwOp or Conv2DNhwcHwcfOp with +/// all strides equal to 1 and with a kernel height and width of 1. +static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) { + auto NCHWOp = dyn_cast<linalg::Conv2DNchwFchwOp>(linalgOp.getOperation()); + auto NHWCOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()); + + if (!NCHWOp && !NHWCOp) + return false; + + DenseIntElementsAttr strides = + NCHWOp ? NCHWOp.getStrides() : NHWCOp.getStrides(); + if (!llvm::all_of( + strides, [](APInt element) { return element.getSExtValue() == 1; })) { + return false; + } + + auto filterShapeType = llvm::dyn_cast<RankedTensorType>( + linalgOp.getDpsInputOperand(1)->get().getType()); + if (!filterShapeType) + return false; + + // Adjust dimension indices based on the convolution layout. + const int khIndex = NHWCOp ? 0 : 2; + const int kwIndex = NHWCOp ? 1 : 3; + auto filterShape = filterShapeType.getShape(); + return filterShape[khIndex] == 1 && filterShape[kwIndex] == 1; +} + void GeneralizeLinalgNamedOpsPass::runOnOperation() { auto funcOp = getOperation(); SmallVector<linalg::LinalgOp> namedOpCandidates; @@ -44,7 +72,8 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { linalg::LogOp, linalg::MapOp, linalg::MaxOp, linalg::MulOp, linalg::NegFOp, linalg::ReduceOp, linalg::SubOp, linalg::TransposeOp>( - linalgOp.getOperation())) { + linalgOp.getOperation()) || + isConvFoldableToContraction(linalgOp)) { namedOpCandidates.push_back(linalgOp); } }); diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp index 46a985f4bb9b..4ce2d92d5748 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp @@ -85,9 +85,9 @@ class MaterializeHomogeneousEncodingsPass executableTarget.getBackend() == "rocm") { passManager.addPass(createGPUMaterializeHostEncodingPass()); FunctionLikeNest(passManager).addPass([&]() { - return createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false, - /*useOnlyReshapes=*/true, - /*controlFn=*/std::nullopt); + return createDecomposePackUnPackOpsPass( + DecomposePackUnPackOpsPassOptions{/*tileOuterToOne=*/false, + /*useOnlyReshapes=*/true}); }); } else { addNopPipeline(passManager); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 4f9a33e22a2f..bd61d4b6ce76 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -101,8 +101,7 @@ void buildGlobalOptimizationPassPipeline( .addPass(IREE::Flow::createCanonicalizerPass)
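As an illustrative aside (not part of the patch itself): both the old reshape-to-`linalg.matmul` rewrite and the new generalize-and-pin-the-kh/kw-dims approach rest on the same arithmetic fact, namely that a 1x1, stride-1, dilation-1 NHWC convolution computes exactly a matmul over the collapsed N*H*W dimension. A small self-contained C++ check of that property, with made-up shapes:

```c++
// Illustrative only: a 1x1, stride-1 NHWC conv gives the same values as a
// matmul of the input viewed as (N*H*W, C) with the filter viewed as (C, F).
#include <cassert>
#include <vector>

int main() {
  const int N = 1, H = 4, W = 5, C = 2, F = 7;
  std::vector<float> input(N * H * W * C), filter(C * F); // filter is 1x1xCxF
  for (size_t i = 0; i < input.size(); ++i) input[i] = 0.5f * i;
  for (size_t i = 0; i < filter.size(); ++i) filter[i] = 0.25f * i;

  // Direct 1x1 convolution: out[n,h,w,f] = sum_c in[n,h,w,c] * filt[c,f].
  std::vector<float> convOut(N * H * W * F, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int f = 0; f < F; ++f)
          for (int c = 0; c < C; ++c)
            convOut[((n * H + h) * W + w) * F + f] +=
                input[((n * H + h) * W + w) * C + c] * filter[c * F + f];

  // Matmul view: (N*H*W, C) x (C, F) -> (N*H*W, F) on the same flat buffers.
  std::vector<float> matmulOut(N * H * W * F, 0.0f);
  for (int row = 0; row < N * H * W; ++row)
    for (int f = 0; f < F; ++f)
      for (int c = 0; c < C; ++c)
        matmulOut[row * F + f] += input[row * C + c] * filter[c * F + f];

  assert(convOut == matmulOut); // identical results, no data movement needed
  return 0;
}
```

The generalization approach keeps the data in place and only rewrites the input indexing map, which is why the dynamic and strided test cases below no longer need to bail out.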
.addPass(createRemoveZeroExtentTensorsPass) .addPass(createDetachElementwiseFromNamedOpsPass) - .addPass(mlir::createLinalgNamedOpConversionPass) - .addPass(createConvert1X1FilterConv2DToMatmulPass); + .addPass(mlir::createLinalgNamedOpConversionPass); mainPassManager.addPass(createEraseUnusedLinalgOperandsPass()); // Expand tensor shapes into SSA values and optimize the whole program. diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir index 980db9329b4e..607f137b87b0 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s +// RUN: iree-opt --split-input-file --mlir-print-local-scope -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s util.func public @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> { %0 = tensor.empty() : tensor<1x4x5x7xf32> @@ -9,20 +9,15 @@ util.func public @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x util.return %1 : tensor<1x4x5x7xf32> } -// CHECK: @nhwc_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32> -// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32> -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x4x5x7xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] output_shape [1, 4, 5, 7] : tensor<20x7xf32> into tensor<1x4x5x7xf32> -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- -// CHECK: @dynamic_nhwc_conv_2d util.func public @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> { %c2 = arith.constant 2 : index %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32> @@ -34,34 +29,12 @@ util.func public @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: ten util.return %1 : tensor<1x4x?x7xf32> } -// CHECK: %[[INPUT:.+]]: tensor<1x4x?x2xf32> -// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32> -// CHECK: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] -// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D2]]) : tensor<1x4x?x7xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x2xf32> into tensor -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32> -// CHECK: 
%[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x7xf32> into tensor -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] - -// ----- - -util.func public @fail_dynamic_nhwc_conv_2d(%input: tensor<1x?x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x?x?x7xf32> { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %d1 = tensor.dim %input, %c1 : tensor<1x?x?x2xf32> - %d2 = tensor.dim %input, %c2 : tensor<1x?x?x2xf32> - %0 = tensor.empty(%d1, %d2) : tensor<1x?x?x7xf32> - %1 = linalg.conv_2d_nhwc_hwcf { - dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> - } ins(%input, %filter : tensor<1x?x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x?x?x7xf32>) -> tensor<1x?x?x7xf32> - util.return %1 : tensor<1x?x?x7xf32> -} - -// CHECK: @fail_dynamic_nhwc_conv_2d -// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-LABEL: @dynamic_nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- @@ -73,16 +46,12 @@ util.func public @nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32> util.return %1 : tensor<1x7x4x5xf32> } -// CHECK: @nchw_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x2x4x5xf32> -// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32> -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x7x4x5xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x5xf32> into tensor<2x20xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x5xf32> into tensor<7x20xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x20xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x20xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] output_shape [1, 7, 4, 5] : tensor<7x20xf32> into tensor<1x7x4x5xf32> -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- @@ -97,33 +66,27 @@ util.func public @dynamic_nchw_conv_2d(%input: tensor<1x2x4x?xf32>, %filter: ten util.return %1 : tensor<1x7x4x?xf32> } -// CHECK: @dynamic_nchw_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x2x4x?xf32> -// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32> -// CHECK: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D3]]) : tensor<1x7x4x?xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : 
tensor<1x2x4x?xf32> into tensor<2x?xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x?xf32> into tensor<7x?xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x?xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x?xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @dynamic_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- -util.func public @fail_dynamic_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x?x?xf32> { - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %d2 = tensor.dim %input, %c2 : tensor<1x2x?x?xf32> - %d3 = tensor.dim %input, %c3 : tensor<1x2x?x?xf32> +util.func public @strided_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>, %d2 : index, %d3 : index) -> tensor<1x7x?x?xf32> { %0 = tensor.empty(%d2, %d3) : tensor<1x7x?x?xf32> %1 = linalg.conv_2d_nchw_fchw { dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + strides = dense<2> : tensor<2xi64> } ins(%input, %filter : tensor<1x2x?x?xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x?x?xf32>) -> tensor<1x7x?x?xf32> util.return %1 : tensor<1x7x?x?xf32> } -// CHECK: @fail_dynamic_nchw_conv_2d -// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-LABEL: @strided_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 * 2, d3 * 2)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir index 5111152b7b0d..f3f0f8a0eb9b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s util.func public @generalize_op(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -34,3 +34,35 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor, %arg // CHECK: %[[ADD:.+]] = linalg.add // CHECK: flow.return %[[ADD]] // CHECK: util.return %[[DISPATCH]] + +// ----- + +util.func public @generalize_1x1_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> { + %c2 = arith.constant 2 : index + %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32> + %0 = tensor.empty(%d2) : tensor<1x4x?x7xf32> + %1 = linalg.conv_2d_nhwc_hwcf { + dilations = dense<1> : 
tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter : tensor<1x4x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x4x?x7xf32>) -> tensor<1x4x?x7xf32> + util.return %1 : tensor<1x4x?x7xf32> +} + +// CHECK-LABEL: @generalize_1x1_nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK: util.return %[[RESULT]] + +// ----- + +util.func public @generalize_1x1_nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x4x5xf32> { + %0 = tensor.empty() : tensor<1x7x4x5xf32> + %1 = linalg.conv_2d_nchw_fchw { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32> + util.return %1 : tensor<1x7x4x5xf32> +} + +// CHECK-LABEL: @generalize_1x1_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK: util.return %[[RESULT]] diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index de1f73603873..809fe8ffd1da 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -347,7 +347,7 @@ class FuncFuncOpPattern : public OpConversionPattern { // Allowlist of function attributes to retain when importing funcs. constexpr const char *kRetainedAttributes[] = { "iree.reflection", "stream.affinity", "vm.fallback", - "vm.signature", "vm.version", + "vm.signature", "vm.version", "nosideeffects", }; auto retainedAttributes = ArrayRef( kRetainedAttributes, diff --git a/compiler/src/iree/compiler/InputConversion/Common/test/iree_import_public.mlir b/compiler/src/iree/compiler/InputConversion/Common/test/iree_import_public.mlir index 0c48dbc93538..d51c8b60a136 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/test/iree_import_public.mlir +++ b/compiler/src/iree/compiler/InputConversion/Common/test/iree_import_public.mlir @@ -13,6 +13,13 @@ func.func @noinline_func() -> () attributes {noinline} { return } +// ----- +// CHECK-LABEL: util.func public @nosideeffects_func +// CHECK: nosideeffects +func.func @nosideeffects_func() -> () attributes {nosideeffects} { + return +} + // ----- // CHECK-LABEL: util.func public @b_func // CHECK-SAME: (%arg0: !hal.buffer, %arg1: !hal.buffer) -> (!hal.buffer, !hal.buffer) diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index ba415b3fb656..922e50882775 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -242,16 +242,16 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, return llvm::divideCeil(value, padTo) * padTo - value; }; - if (mSize % intrinsic.mSize != 0) { - mPadding = getPadding(mSize, intrinsic.mSize); + if (mSize % intrinsic.mSizes[0] != 0) { + mPadding = getPadding(mSize, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0) { - nPadding = getPadding(nSize, intrinsic.nSize); + if (nSize % intrinsic.nSizes[0] != 0) { + nPadding = getPadding(nSize, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0) { - kPadding = getPadding(kSize, intrinsic.kSize); + if (kSize % intrinsic.kSizes[0] != 0) { + kPadding = getPadding(kSize, intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) { @@ -381,7 +381,7 @@ static void 
padContractionLikeOp( for (GPUMatmulShapeType &intrinsic : intrinsics) { std::optional mPadding, nPadding, kPadding; SmallVector> dimsToExpandCandidate; - if (mSize % intrinsic.mSize != 0 || ShapedType::isDynamic(mSize)) { + if (mSize % intrinsic.mSizes[0] != 0 || ShapedType::isDynamic(mSize)) { OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize); if (ShapedType::isDynamic(mSize)) { auto mOperandDimPair = getSrcOperandAndDim(mDim); @@ -390,12 +390,12 @@ static void padContractionLikeOp( auto [mOperand, mOperandDim] = mOperandDimPair.value(); mSizeExpr = rewriter.create(loc, mOperand, mOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSize); + dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSizes[0]); } - mPadding = getPadding(mSizeExpr, intrinsic.mSize); + mPadding = getPadding(mSizeExpr, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0 || ShapedType::isDynamic(nSize)) { + if (nSize % intrinsic.nSizes[0] != 0 || ShapedType::isDynamic(nSize)) { OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize); if (ShapedType::isDynamic(nSize)) { auto nOperandDimPair = getSrcOperandAndDim(nDim); @@ -404,12 +404,12 @@ static void padContractionLikeOp( auto [nOperand, nOperandDim] = nOperandDimPair.value(); nSizeExpr = rewriter.create(loc, nOperand, nOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSize); + dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSizes[0]); } - nPadding = getPadding(nSizeExpr, intrinsic.nSize); + nPadding = getPadding(nSizeExpr, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0 || ShapedType::isDynamic(kSize)) { + if (kSize % intrinsic.kSizes[0] != 0 || ShapedType::isDynamic(kSize)) { OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize); if (ShapedType::isDynamic(kSize)) { auto kOperandDimPair = getSrcOperandAndDim(kDim); @@ -418,9 +418,9 @@ static void padContractionLikeOp( auto [kOperand, kOperandDim] = kOperandDimPair.value(); kSizeExpr = rewriter.create(loc, kOperand, kOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSize); + dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSizes[0]); } - kPadding = getPadding(kSizeExpr, intrinsic.kSize); + kPadding = getPadding(kSizeExpr, intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) { diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 7c717d641ebf..7a813c2a51cd 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -114,6 +114,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index d9033ba99cee..8ad934245c3f 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Doesn't use bazel_to_cmake because of various special logic throughout. +# That there's various special logic throughout is _bad_. Don't replicate this. # Enable compiler targets based on options. 
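Returning to the PadToIntrinsics hunk above, as an illustrative aside (not part of the patch itself): the change only re-indexes the intrinsic shape as `mSizes[0]`/`nSizes[0]`/`kSizes[0]`; the padding arithmetic is unchanged and is simply a round-up to the next multiple of the intrinsic size. A tiny standalone C++ sketch, assuming a 16-wide M intrinsic:

```c++
// Illustrative only: round-up-to-multiple padding used when aligning a
// matmul/conv dimension to an MMA intrinsic size such as mSizes[0] = 16.
#include <cstdint>
#include <iostream>

static int64_t getPadding(int64_t value, int64_t padTo) {
  // ceil(value / padTo) * padTo - value
  return (value + padTo - 1) / padTo * padTo - value;
}

int main() {
  const int64_t mSize = 100, intrinsicM = 16;
  const int64_t mPadding =
      (mSize % intrinsicM != 0) ? getPadding(mSize, intrinsicM) : 0;
  std::cout << "pad M from " << mSize << " by " << mPadding << " to "
            << (mSize + mPadding) << "\n"; // pad M from 100 by 12 to 112
  return 0;
}
```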
set(IREE_COMPILER_TARGETS "") @@ -95,6 +96,7 @@ iree_cc_library( MLIRLinalgTransforms MLIRMLProgramDialect MLIRQuantDialect + MLIRROCDLDialect MLIRSCFDialect MLIRSCFToGPU MLIRSCFTransforms diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index e399e63dffe8..ce2b67134989 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -29,6 +29,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" @@ -82,6 +83,7 @@ inline void registerMlirDialects(DialectRegistry ®istry) { pdl_interp::PDLInterpDialect, scf::SCFDialect, quant::QuantDialect, + ROCDL::ROCDLDialect, spirv::SPIRVDialect, arm_neon::ArmNeonDialect, arm_sve::ArmSVEDialect, diff --git a/docs/website/docs/developers/debugging/model-development.md b/docs/website/docs/developers/debugging/model-development.md index ea7bbbae0116..c7ba61a71ae3 100644 --- a/docs/website/docs/developers/debugging/model-development.md +++ b/docs/website/docs/developers/debugging/model-development.md @@ -147,7 +147,7 @@ if crashes, bugs, numerical issues, etc. can be reproduced at that scale. Some existing test suites can be found at these locations: * -* +* * * * diff --git a/docs/website/docs/developers/general/testing-guide.md b/docs/website/docs/developers/general/testing-guide.md index 070085dae253..50957e24e38f 100644 --- a/docs/website/docs/developers/general/testing-guide.md +++ b/docs/website/docs/developers/general/testing-guide.md @@ -412,6 +412,18 @@ repository also contains tests for many machine learning models. Some of these tests are planned to be migrated into [iree-org/iree-test-suites](https://github.com/iree-org/iree-test-suites). +### linalg operator tests + +Tests for operators in the MLIR linalg dialect like `matmul`, and `convolution` +are being migrated from folders like +[`tests/e2e/matmul/`](https://github.com/iree-org/iree/tree/main/tests/e2e/matmul) +in the +[iree-org/iree](https://github.com/iree-org/iree) repository to +[`linalg_ops/`](https://github.com/iree-org/iree-test-suites/tree/main/linalg_ops) +in the +[iree-org/iree-test-suites](https://github.com/iree-org/iree-test-suites) +repository. + ### ONNX operator tests Tests for individual ONNX operators are included at @@ -483,3 +495,12 @@ The workflow job that failed should then upload a new config file as an committed: ![image](https://github.com/user-attachments/assets/b5dbdcb4-4c0a-4ff2-adc6-9021614179b2) + +### ONNX model tests + +Tests for ONNX models are included at +[`onnx_models/`](https://github.com/iree-org/iree-test-suites/tree/main/onnx_models) +in the +[iree-org/iree-test-suites](https://github.com/iree-org/iree-test-suites) +repository. These tests use models from the upstream +[onnx/models](https://github.com/onnx/models) repository. 
diff --git a/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md b/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md index ab793f2b6d54..d461d640ef9a 100644 --- a/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md +++ b/docs/website/docs/guides/deployment-configurations/snippets/_iree-compiler-from-release.md @@ -4,7 +4,7 @@ [published to PyPI](https://pypi.org/user/google-iree-pypi-deploy/). ``` shell - python -m pip install iree-compiler iree-runtime + python -m pip install iree-compiler ``` === ":material-alert: Nightly releases" @@ -15,7 +15,7 @@ ``` shell python -m pip install \ --find-links https://iree.dev/pip-release-links.html \ - --upgrade iree-compiler iree-runtime + --upgrade iree-compiler ``` !!! tip diff --git a/docs/website/docs/guides/ml-frameworks/onnx.md b/docs/website/docs/guides/ml-frameworks/onnx.md index 5abe0112f862..d2dda65fe9f7 100644 --- a/docs/website/docs/guides/ml-frameworks/onnx.md +++ b/docs/website/docs/guides/ml-frameworks/onnx.md @@ -118,7 +118,8 @@ graph LR | Code samples | | | -- | -- | Generated op tests | [iree-test-suites `onnx_ops`](https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops) -Curated op and model tests | [SHARK-TestSuite `e2eshark/onnx`](https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/onnx) +Public model tests | [iree-test-suites `onnx_models`](https://github.com/iree-org/iree-test-suites/tree/main/onnx_models) +Curated op and model tests | SHARK-TestSuite [`e2eshark/onnx`](https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/onnx) and [`alt_e2eshark/onnx_tests`](https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests) Importer tests | [torch-mlir `test/python/onnx_importer`](https://github.com/llvm/torch-mlir/tree/main/test/python/onnx_importer) ## :octicons-question-16: Troubleshooting diff --git a/experimental/web/sample_static/device_multithreaded.c b/experimental/web/sample_static/device_multithreaded.c index c70924bdc4bd..8b5ba39f6c78 100644 --- a/experimental/web/sample_static/device_multithreaded.c +++ b/experimental/web/sample_static/device_multithreaded.c @@ -18,7 +18,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator, // Register the statically linked executable library. const iree_hal_executable_library_query_fn_t libraries[] = { - mnist_linked_llvm_cpu_library_query, + mnist_linked_library_query, }; iree_hal_executable_loader_t* library_loader = NULL; iree_status_t status = iree_hal_static_library_loader_create( diff --git a/experimental/web/sample_static/device_sync.c b/experimental/web/sample_static/device_sync.c index 3fbe3eed0bf6..f072903b963f 100644 --- a/experimental/web/sample_static/device_sync.c +++ b/experimental/web/sample_static/device_sync.c @@ -15,7 +15,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator, // Register the statically linked executable library. 
const iree_hal_executable_library_query_fn_t libraries[] = { - mnist_linked_llvm_cpu_library_query, + mnist_linked_library_query, }; iree_hal_executable_loader_t* library_loader = NULL; iree_status_t status = iree_hal_static_library_loader_create( diff --git a/experimental/webgpu/nop_semaphore.c b/experimental/webgpu/nop_semaphore.c index d4151ee29990..65d26486567b 100644 --- a/experimental/webgpu/nop_semaphore.c +++ b/experimental/webgpu/nop_semaphore.c @@ -38,8 +38,8 @@ iree_status_t iree_hal_webgpu_nop_semaphore_create( iree_hal_resource_initialize(&iree_hal_webgpu_nop_semaphore_vtable, &semaphore->resource); semaphore->host_allocator = host_allocator; - iree_atomic_store_int64(&semaphore->value, initial_value, - iree_memory_order_seq_cst); + iree_atomic_store(&semaphore->value, initial_value, + iree_memory_order_seq_cst); *out_semaphore = (iree_hal_semaphore_t*)semaphore; } @@ -63,8 +63,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_query( iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); - *out_value = - iree_atomic_load_int64(&semaphore->value, iree_memory_order_seq_cst); + *out_value = iree_atomic_load(&semaphore->value, iree_memory_order_seq_cst); return iree_ok_status(); } @@ -72,8 +71,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_signal( iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); - iree_atomic_store_int64(&semaphore->value, new_value, - iree_memory_order_seq_cst); + iree_atomic_store(&semaphore->value, new_value, iree_memory_order_seq_cst); return iree_ok_status(); } @@ -88,7 +86,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_wait( iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); uint64_t current_value = - iree_atomic_load_int64(&semaphore->value, iree_memory_order_seq_cst); + iree_atomic_load(&semaphore->value, iree_memory_order_seq_cst); if (current_value < value) { return iree_make_status( IREE_STATUS_FAILED_PRECONDITION, diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt index 5ba7b84fbcf9..35a85d190a27 100644 --- a/runtime/bindings/python/CMakeLists.txt +++ b/runtime/bindings/python/CMakeLists.txt @@ -4,21 +4,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -if(NOT nanobind_FOUND) - find_package(nanobind CONFIG QUIET) - if(NOT nanobind_FOUND) - execute_process( - COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE NB_DIR - RESULT_VARIABLE RC) - if(RC AND NOT RC EQUAL 0) - message(WARNING "Probing for nanobind failed. 
Please install the project's Python dependencies or '${Python_EXECUTABLE} -m pip install nanobind'") - endif() - list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") - endif() - find_package(nanobind CONFIG REQUIRED) -endif() +# nanobind +include(FetchContent) +FetchContent_Declare( + nanobind + GIT_REPOSITORY https://github.com/wjakob/nanobind.git + GIT_TAG 784efa2a0358a4dc5432c74f5685ee026e20f2b6 # 2.2.0 +) +FetchContent_MakeAvailable(nanobind) set(_EXTRA_INSTALL_TOOL_TARGETS) set(_TRACY_ENABLED OFF) diff --git a/runtime/bindings/python/iree/runtime/build_requirements.txt b/runtime/bindings/python/iree/runtime/build_requirements.txt index e3f07b20f0bb..0a8dca312c7d 100644 --- a/runtime/bindings/python/iree/runtime/build_requirements.txt +++ b/runtime/bindings/python/iree/runtime/build_requirements.txt @@ -5,12 +5,7 @@ pip>=21.3 setuptools>=62.4.0 -nanobind==2.2.0 numpy>=2.0.0b1 requests>=2.28.0 wheel>=0.36.2 sympy==1.12.1 - -# TODO: nanobind is used in the runtime but the compiler uses pybind and -# removing this breaks CI bots; remove this. -pybind11==2.13.6 diff --git a/runtime/pyproject.toml b/runtime/pyproject.toml index 3259f736c8db..16567a601072 100644 --- a/runtime/pyproject.toml +++ b/runtime/pyproject.toml @@ -3,7 +3,6 @@ requires = [ "setuptools>=42", "wheel", "cmake", - "nanobind==2.2.0", "ninja", "numpy>=2.0.0b1", "packaging", diff --git a/runtime/src/iree/base/internal/atomics.h b/runtime/src/iree/base/internal/atomics.h index 731d9eef510e..f428731506a5 100644 --- a/runtime/src/iree/base/internal/atomics.h +++ b/runtime/src/iree/base/internal/atomics.h @@ -86,47 +86,6 @@ extern "C" { #endif // IREE_COMPILER_* -// If the compiler can automatically determine the types: -#ifdef iree_atomic_load_auto - -#define iree_atomic_load_int32 iree_atomic_load_auto -#define iree_atomic_store_int32 iree_atomic_store_auto -#define iree_atomic_fetch_add_int32 iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_int32 iree_atomic_fetch_sub_auto -#define iree_atomic_fetch_and_int32 iree_atomic_fetch_and_auto -#define iree_atomic_fetch_or_int32 iree_atomic_fetch_or_auto -#define iree_atomic_fetch_xor_int32 iree_atomic_fetch_xor_auto -#define iree_atomic_exchange_int32 iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_int32 \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_weak_auto - -#define iree_atomic_load_int64 iree_atomic_load_auto -#define iree_atomic_store_int64 iree_atomic_store_auto -#define iree_atomic_fetch_add_int64 iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_int64 iree_atomic_fetch_sub_auto -#define iree_atomic_fetch_and_int64 iree_atomic_fetch_and_auto -#define iree_atomic_fetch_or_int64 iree_atomic_fetch_or_auto -#define iree_atomic_fetch_xor_int64 iree_atomic_fetch_xor_auto -#define iree_atomic_exchange_int64 iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_int64 \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_weak_auto - -#define iree_atomic_load_intptr iree_atomic_load_auto -#define iree_atomic_store_intptr iree_atomic_store_auto -#define iree_atomic_fetch_add_intptr iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_intptr iree_atomic_fetch_sub_auto -#define iree_atomic_exchange_intptr iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_intptr \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_intptr \ - 
iree_atomic_compare_exchange_weak_auto - -#endif // iree_atomic_load_auto - //============================================================================== // Reference count atomics //============================================================================== @@ -140,10 +99,10 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // should use IREE_ATOMIC_VAR_INIT, but apparently this has to be fixed // at call sites (where the variables are initialized in the first place). #define iree_atomic_ref_count_init_value(count_ptr, value) \ - iree_atomic_store_int32(count_ptr, value, iree_memory_order_relaxed) + iree_atomic_store((count_ptr), (value), iree_memory_order_relaxed) #define iree_atomic_ref_count_init(count_ptr) \ - iree_atomic_ref_count_init_value(count_ptr, 1) + iree_atomic_ref_count_init_value((count_ptr), 1) // Why relaxed order: // https://www.boost.org/doc/libs/1_57_0/doc/html/atomic/usage_examples.html#boost_atomic.usage_examples.example_reference_counters.discussion @@ -155,9 +114,9 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // value (unlike iree_atomic_ref_count_dec), so we make sure that it does not, // which allows the implementation to use faster atomic instructions where // available, e.g. STADD on ARMv8.1-a. -#define iree_atomic_ref_count_inc(count_ptr) \ - do { \ - iree_atomic_fetch_add_int32(count_ptr, 1, iree_memory_order_relaxed); \ +#define iree_atomic_ref_count_inc(count_ptr) \ + do { \ + iree_atomic_fetch_add((count_ptr), 1, iree_memory_order_relaxed); \ } while (false) // For now we stick to acq_rel order. TODO: should we follow Boost's advice? @@ -169,13 +128,13 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // may be a pessimization... I would like to hear a second opinion on this, // particularly regarding how x86-centric this might be. #define iree_atomic_ref_count_dec(count_ptr) \ - iree_atomic_fetch_sub_int32(count_ptr, 1, iree_memory_order_acq_rel) + iree_atomic_fetch_sub((count_ptr), 1, iree_memory_order_acq_rel) // memory_order_acquire order ensures that this sees decrements from // iree_atomic_ref_count_dec. On the other hand, there is no ordering with // iree_atomic_ref_count_inc. #define iree_atomic_ref_count_load(count_ptr) \ - iree_atomic_load_int32(count_ptr, iree_memory_order_acquire) + iree_atomic_load((count_ptr), iree_memory_order_acquire) // Aborts the program if the given reference count value is not 1. 
// This should be avoided in all situations but those where continuing execution diff --git a/runtime/src/iree/base/internal/atomics_clang.h b/runtime/src/iree/base/internal/atomics_clang.h index 44514e05c742..afa7a3352017 100644 --- a/runtime/src/iree/base/internal/atomics_clang.h +++ b/runtime/src/iree/base/internal/atomics_clang.h @@ -33,37 +33,38 @@ typedef enum iree_memory_order_e { typedef _Atomic int32_t iree_atomic_int32_t; typedef _Atomic int64_t iree_atomic_int64_t; +typedef _Atomic uint32_t iree_atomic_uint32_t; +typedef _Atomic uint64_t iree_atomic_uint64_t; // TODO(#3453): check for __int128 support before using // typedef _Atomic __int128 iree_atomic_int128_t; typedef _Atomic intptr_t iree_atomic_intptr_t; -#define iree_atomic_load_auto(object, order) \ - __c11_atomic_load((object), (order)) -#define iree_atomic_store_auto(object, desired, order) \ +#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) + +#define iree_atomic_load(object, order) __c11_atomic_load((object), (order)) +#define iree_atomic_store(object, desired, order) \ __c11_atomic_store((object), (desired), (order)) -#define iree_atomic_fetch_add_auto(object, operand, order) \ +#define iree_atomic_fetch_add(object, operand, order) \ __c11_atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ +#define iree_atomic_fetch_sub(object, operand, order) \ __c11_atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ +#define iree_atomic_fetch_and(object, operand, order) \ __c11_atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ +#define iree_atomic_fetch_or(object, operand, order) \ __c11_atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ +#define iree_atomic_fetch_xor(object, operand, order) \ __c11_atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ +#define iree_atomic_exchange(object, operand, order) \ __c11_atomic_exchange((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ (order_succ), (order_fail)) -#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) - #ifdef __cplusplus } // extern "C" #endif diff --git a/runtime/src/iree/base/internal/atomics_disabled.h b/runtime/src/iree/base/internal/atomics_disabled.h index 5c0a7cad6ff5..5dbb272f4748 100644 --- a/runtime/src/iree/base/internal/atomics_disabled.h +++ b/runtime/src/iree/base/internal/atomics_disabled.h @@ -16,12 +16,8 @@ #if IREE_SYNCHRONIZATION_DISABLE_UNSAFE -#ifdef __cplusplus -extern "C" { -#endif - typedef enum iree_memory_order_e { - iree_memory_order_relaxed, + iree_memory_order_relaxed = 0u, iree_memory_order_consume, 
iree_memory_order_acquire, iree_memory_order_release, @@ -33,65 +29,197 @@ typedef enum iree_memory_order_e { typedef int32_t iree_atomic_int32_t; typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; // TODO(#3453): check for __int128 support before using // typedef __int128 iree_atomic_int128_t; typedef intptr_t iree_atomic_intptr_t; -#define iree_atomic_load_int32(object, order) (*(object)) -#define iree_atomic_store_int32(object, desired, order) (*(object) = (desired)) -#define iree_atomic_fetch_add_int32(object, operand, order) \ - iree_atomic_fetch_add_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_sub_int32(object, operand, order) \ - iree_atomic_fetch_add_int32_impl((volatile iree_atomic_int32_t*)(object), \ - -(int32_t)(operand)) -#define iree_atomic_fetch_and_int32(object, operand, order) \ - iree_atomic_fetch_and_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_or_int32(object, operand, order) \ - iree_atomic_fetch_or_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_xor_int32(object, operand, order) \ - iree_atomic_fetch_xor_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_exchange_int32(object, desired, order) \ - iree_atomic_fetch_exchange_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t)(desired)) -#define iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired)) -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_strong_int32 - -#define iree_atomic_load_int64(object, order) (*(object)) -#define iree_atomic_store_int64(object, desired, order) (*(object) = (desired)) -#define iree_atomic_fetch_add_int64(object, operand, order) \ - iree_atomic_fetch_add_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_sub_int64(object, operand, order) \ - iree_atomic_fetch_add_int64_impl((volatile iree_atomic_int64_t*)(object), \ - -(int64_t)(operand)) -#define iree_atomic_fetch_and_int64(object, operand, order) \ - iree_atomic_fetch_and_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_or_int64(object, operand, order) \ - iree_atomic_fetch_or_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_xor_int64(object, operand, order) \ - iree_atomic_fetch_xor_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_exchange_int64(object, desired, order) \ - iree_atomic_fetch_exchange_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t)(desired)) -#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired)) -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_strong_int64 +#define iree_atomic_thread_fence(order) + +#ifdef __cplusplus + +extern "C++" { + +#define iree_atomic_load(object, order) (*(object)) +#define iree_atomic_store(object, desired, order) (*(object) = (desired)) +#define 
iree_atomic_fetch_add(object, operand, order) \ + iree_atomic_fetch_add_impl((object), (operand)) +#define iree_atomic_fetch_sub(object, operand, order) \ + iree_atomic_fetch_sub_impl((object), (operand)) +#define iree_atomic_fetch_and(object, operand, order) \ + iree_atomic_fetch_and_impl((object), (operand)) +#define iree_atomic_fetch_or(object, operand, order) \ + iree_atomic_fetch_or_impl((object), (operand)) +#define iree_atomic_fetch_xor(object, operand, order) \ + iree_atomic_fetch_xor_impl((object), (operand)) +#define iree_atomic_exchange(object, desired, order) \ + iree_atomic_fetch_exchange_impl((object), (desired)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + iree_atomic_compare_exchange_impl((object), (expected), (desired)) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong + +template +static inline T iree_atomic_fetch_add_impl(volatile T* object, V operand) { + T original = *object; + *object += operand; + return original; +} + +template +static inline T iree_atomic_fetch_sub_impl(volatile T* object, V operand) { + T original = *object; + *object -= operand; + return original; +} + +template +static inline T iree_atomic_fetch_and_impl(volatile T* object, V operand) { + T original = *object; + *object &= operand; + return original; +} + +template +static inline T iree_atomic_fetch_or_impl(volatile T* object, V operand) { + T original = *object; + *object |= operand; + return original; +} + +template +static inline T iree_atomic_fetch_xor_impl(volatile T* object, V operand) { + T original = *object; + *object ^= operand; + return original; +} + +template +static inline T iree_atomic_fetch_exchange_impl(volatile T* object, V desired) { + T original = *object; + *object = desired; + return original; +} + +template +static inline bool iree_atomic_compare_exchange_impl(volatile T* object, + V* expected, V desired) { + if (*object == *expected) { + *object = desired; + return true; + } else { + *expected = *object; + return false; + } +} + +} // extern "C" + +#else + +#define iree_atomic_load(object, order) (*(object)) +#define iree_atomic_store(object, desired, order) (*(object) = (desired)) +#define iree_atomic_fetch_add(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_add_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_add_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_add_uint32_impl( \ + (volatile iree_atomic_uint32_t*)(object), \ + (uint32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_add_uint64_impl( \ + (volatile iree_atomic_uint64_t*)(object), \ + (uint64_t)(operand))) +#define iree_atomic_fetch_sub(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_sub_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_sub_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_sub_uint32_impl( \ + (volatile iree_atomic_uint32_t*)(object), \ + (uint32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_sub_uint64_impl( \ + (volatile iree_atomic_uint64_t*)(object), \ + (uint64_t)(operand))) +#define iree_atomic_fetch_and(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: 
iree_atomic_fetch_and_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_and_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_and_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_and_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_or(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_or_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_or_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_or_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_or_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_xor(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_xor_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_xor_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_xor_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_xor_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_exchange(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_int64_t *: iree_atomic_fetch_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_compare_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired)), \ + iree_atomic_int64_t *: iree_atomic_compare_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired)), \ + iree_atomic_uint32_t *: iree_atomic_compare_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired)), \ + iree_atomic_uint64_t *: iree_atomic_compare_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired))) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong static inline int32_t iree_atomic_fetch_add_int32_impl( volatile iree_atomic_int32_t* object, int32_t operand) { @@ -100,6 +228,27 @@ static inline int32_t iree_atomic_fetch_add_int32_impl( return original; } +static inline int32_t iree_atomic_fetch_sub_int32_impl( + volatile iree_atomic_int32_t* object, int32_t operand) { + int32_t original = *object; + *object -= operand; + return 
original; +} + +static inline int32_t iree_atomic_fetch_add_uint32_impl( + volatile iree_atomic_int32_t* object, uint32_t operand) { + uint32_t original = *object; + *object += operand; + return original; +} + +static inline int32_t iree_atomic_fetch_sub_uint32_impl( + volatile iree_atomic_uint32_t* object, uint32_t operand) { + uint32_t original = *object; + *object -= operand; + return original; +} + static inline int32_t iree_atomic_fetch_and_int32_impl( volatile iree_atomic_int32_t* object, int32_t operand) { int32_t original = *object; @@ -146,6 +295,27 @@ static inline int64_t iree_atomic_fetch_add_int64_impl( return original; } +static inline int64_t iree_atomic_fetch_sub_int64_impl( + volatile iree_atomic_int64_t* object, int64_t operand) { + int64_t original = *object; + *object -= operand; + return original; +} + +static inline int64_t iree_atomic_fetch_add_uint64_impl( + volatile iree_atomic_uint64_t* object, uint64_t operand) { + uint64_t original = *object; + *object += operand; + return original; +} + +static inline int64_t iree_atomic_fetch_sub_uint64_impl( + volatile iree_atomic_uint64_t* object, uint64_t operand) { + uint64_t original = *object; + *object -= operand; + return original; +} + static inline int64_t iree_atomic_fetch_and_int64_impl( volatile iree_atomic_int64_t* object, int64_t operand) { int64_t original = *object; @@ -185,59 +355,7 @@ static inline bool iree_atomic_compare_exchange_int64_impl( } } -// There are no pointer-width atomic ops in MSVC so we need to specialize based -// on the pointer size. -#if defined(IREE_PTR_SIZE_32) -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int32((iree_atomic_int32_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32( \ - (iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#else -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int64((iree_atomic_int64_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) 
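// Illustrative aside, not part of this change: the generic compare-exchange
// macros that replace the typed/intptr variants removed here keep the usual
// C11-style contract (return true on success, refresh |expected| on failure),
// so a retry loop reads the same on every backend. Sketch only; |flag| and
// BIT are hypothetical names:
//   // Given: iree_atomic_int32_t flag;
//   int32_t expected = iree_atomic_load(&flag, iree_memory_order_acquire);
//   while (!iree_atomic_compare_exchange_weak(
//       &flag, &expected, expected | BIT, iree_memory_order_acq_rel,
//       iree_memory_order_acquire)) {
//     // On failure |expected| now holds the value observed in |flag|; retry.
//   }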
-#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64( \ - (iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#endif // IREE_PTR_SIZE_32 - -#define iree_atomic_thread_fence(order) - -#ifdef __cplusplus -} // extern "C" -#endif +#endif // __cplusplus #endif // IREE_SYNCHRONIZATION_DISABLE_UNSAFE diff --git a/runtime/src/iree/base/internal/atomics_gcc.h b/runtime/src/iree/base/internal/atomics_gcc.h index d413b9816253..728add728612 100644 --- a/runtime/src/iree/base/internal/atomics_gcc.h +++ b/runtime/src/iree/base/internal/atomics_gcc.h @@ -34,6 +34,8 @@ typedef enum iree_memory_order_e { typedef int32_t iree_atomic_int32_t; typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; // typedef __int128 iree_atomic_int128_t; typedef intptr_t iree_atomic_intptr_t; @@ -45,47 +47,47 @@ typedef intptr_t iree_atomic_intptr_t; #define __iree_auto_type __auto_type #endif -#define iree_atomic_load_auto(object, order) \ +static inline void iree_atomic_thread_fence(int order) { + // Ignore error where TSan does not support atomic thread fence. + IREE_DISABLE_COMPILER_TSAN_ERRORS() + __atomic_thread_fence(order); + IREE_RESTORE_COMPILER_TSAN_ERRORS() +} + +#define iree_atomic_load(object, order) \ __extension__({ \ __iree_auto_type __atomic_load_ptr = (object); \ __typeof__(*__atomic_load_ptr) __atomic_load_tmp; \ __atomic_load(__atomic_load_ptr, &__atomic_load_tmp, (order)); \ __atomic_load_tmp; \ }) -#define iree_atomic_store_auto(object, desired, order) \ +#define iree_atomic_store(object, desired, order) \ __extension__({ \ __iree_auto_type __atomic_store_ptr = (object); \ __typeof__(*__atomic_store_ptr) __atomic_store_tmp = (desired); \ __atomic_store(__atomic_store_ptr, &__atomic_store_tmp, (order)); \ }) -#define iree_atomic_fetch_add_auto(object, operand, order) \ +#define iree_atomic_fetch_add(object, operand, order) \ __atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ +#define iree_atomic_fetch_sub(object, operand, order) \ __atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ +#define iree_atomic_fetch_and(object, operand, order) \ __atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ +#define iree_atomic_fetch_or(object, operand, order) \ __atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ +#define iree_atomic_fetch_xor(object, operand, order) \ __atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ +#define iree_atomic_exchange(object, operand, order) \ __atomic_exchange_n((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - 
__atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ (order_succ), (order_fail)) -static inline void iree_atomic_thread_fence(int order) { - // Ignore error where TSan does not support atomic thread fence. - IREE_DISABLE_COMPILER_TSAN_ERRORS() - __atomic_thread_fence(order); - IREE_RESTORE_COMPILER_TSAN_ERRORS() -} - #ifdef __cplusplus } // extern "C" #endif diff --git a/runtime/src/iree/base/internal/atomics_msvc.h b/runtime/src/iree/base/internal/atomics_msvc.h index 5cfbf43eb3f6..2af2798c0a13 100644 --- a/runtime/src/iree/base/internal/atomics_msvc.h +++ b/runtime/src/iree/base/internal/atomics_msvc.h @@ -16,12 +16,141 @@ #if defined(IREE_COMPILER_MSVC) -#ifdef __cplusplus +// TODO(benvanik): make MSVC's C11 atomic support work. +// It's difficult to detect and has some weird configuration assertions around +// mixed C and C++ code. Support is only present when the +// `/experimental:c11atomics` but that is ignored on /TP (C++) compilation. +// __STDC_NO_ATOMICS__ is not unset when included/enabled so we can't use the +// standard check. Hopefully that'd be fixed if it ever leaves experimental. +#define IREE_ATOMIC_USE_MSVC_C11 0 +#if IREE_ATOMIC_USE_MSVC_C11 +#include +#endif // IREE_ATOMIC_USE_MSVC_C11 + +#if IREE_ATOMIC_USE_MSVC_C11 && defined(atomic_init) + +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = _Atomic_memory_order_relaxed, + iree_memory_order_consume = _Atomic_memory_order_consume, + iree_memory_order_acquire = _Atomic_memory_order_acquire, + iree_memory_order_release = _Atomic_memory_order_release, + iree_memory_order_acq_rel = _Atomic_memory_order_acq_rel, + iree_memory_order_seq_cst = _Atomic_memory_order_seq_cst, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef _Atomic int32_t iree_atomic_int32_t; +typedef _Atomic int64_t iree_atomic_int64_t; +typedef _Atomic uint32_t iree_atomic_uint32_t; +typedef _Atomic uint64_t iree_atomic_uint64_t; +// TODO(#3453): check for __int128 support before using +// typedef _Atomic __int128 iree_atomic_int128_t; +typedef _Atomic intptr_t iree_atomic_intptr_t; + +#define iree_atomic_thread_fence(order) atomic_thread_fence(order) + +#define iree_atomic_load(object, order) __c11_atomic_load((object), (order)) +#define iree_atomic_store(object, desired, order) \ + __c11_atomic_store((object), (desired), (order)) +#define iree_atomic_fetch_add(object, operand, order) \ + __c11_atomic_fetch_add((object), (operand), (order)) +#define iree_atomic_fetch_sub(object, operand, order) \ + __c11_atomic_fetch_sub((object), (operand), (order)) +#define iree_atomic_fetch_and(object, operand, order) \ + __c11_atomic_fetch_and((object), (operand), (order)) +#define iree_atomic_fetch_or(object, operand, order) \ + __c11_atomic_fetch_or((object), (operand), (order)) +#define iree_atomic_fetch_xor(object, operand, order) \ + __c11_atomic_fetch_xor((object), (operand), (order)) +#define iree_atomic_exchange(object, operand, order) \ + __c11_atomic_exchange((object), (operand), (order)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ + (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + 
__c11_atomic_compare_exchange_weak((object), (expected), (desired), \ + (order_succ), (order_fail)) + +#elif __cplusplus + +// When compiling for C++ we reinterpret atomics as std::atomic. This relies +// on std::atomic on primitive types being lock-free such that the memory for +// each atomic is just the atomic value. We need this special path because MSVC +// doesn't support C features like _Generic in C++. + +extern "C++" { +#include +} // extern "C++" + extern "C" { -#endif typedef enum iree_memory_order_e { - iree_memory_order_relaxed, + iree_memory_order_relaxed = _Atomic_memory_order_relaxed, + iree_memory_order_consume = _Atomic_memory_order_consume, + iree_memory_order_acquire = _Atomic_memory_order_acquire, + iree_memory_order_release = _Atomic_memory_order_release, + iree_memory_order_acq_rel = _Atomic_memory_order_acq_rel, + iree_memory_order_seq_cst = _Atomic_memory_order_seq_cst, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef std::atomic iree_atomic_int32_t; +typedef std::atomic iree_atomic_int64_t; +typedef std::atomic iree_atomic_uint32_t; +typedef std::atomic iree_atomic_uint64_t; +typedef std::atomic iree_atomic_intptr_t; + +#define iree_atomic_thread_fence(order) std::atomic_thread_fence(order) + +#define iree_atomic_load(object, order) \ + std::atomic_load_explicit((object), (std::memory_order)(order)) +#define iree_atomic_store(object, desired, order) \ + std::atomic_store_explicit((object), (desired), (std::memory_order)(order)) +#define iree_atomic_fetch_add(object, operand, order) \ + std::atomic_fetch_add_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_sub(object, operand, order) \ + std::atomic_fetch_sub_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_and(object, operand, order) \ + std::atomic_fetch_and_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_or(object, operand, order) \ + std::atomic_fetch_or_explicit((object), (operand), (std::memory_order)(order)) +#define iree_atomic_fetch_xor(object, operand, order) \ + std::atomic_fetch_xor_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_exchange(object, operand, order) \ + std::atomic_exchange_explicit((object), (operand), (std::memory_order)(order)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + std::atomic_compare_exchange_strong_explicit( \ + (object), (expected), (desired), (std::memory_order)(order_succ), \ + (std::memory_order)(order_fail)) +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + std::atomic_compare_exchange_weak_explicit((object), (expected), (desired), \ + (std::memory_order)(order_succ), \ + (std::memory_order)(order_fail)) + +} // extern "C" + +#else + +// When compiling in C we can use _Generic to automatically route to the +// builtins that change their name based on the atomic type. This implementation +// is not good: it ignores memory order entirely and uses the full barrier +// implied by any of the _Interlocked* builtins. There are some variants of the +// builtins that we could use based on the order but their support across +// targets differs. Hopefully ~soon we can use C11 atomics directly and drop +// this code path. 
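// Illustrative sketch, not part of this change: with the _Generic mapping
// defined below, a call such as
//   iree_atomic_int32_t counter;  // hypothetical variable
//   int32_t prev = iree_atomic_fetch_add(&counter, 1, iree_memory_order_relaxed);
// selects the iree_atomic_int32_t* branch and expands to a full-barrier
//   _InterlockedExchangeAdd((volatile int32_t*)(&counter), 1)
// which returns the previous value; the memory order argument is accepted for
// API compatibility but, as noted above, is ignored on this code path.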
+ +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = 0u, iree_memory_order_consume, iree_memory_order_acquire, iree_memory_order_release, @@ -29,72 +158,131 @@ typedef enum iree_memory_order_e { iree_memory_order_seq_cst, } iree_memory_order_t; -#define IREE_ATOMIC_VAR_INIT(value) \ - { (value) } - -typedef struct { - int32_t __val; -} iree_atomic_int32_t; -typedef struct { - int64_t __val; -} iree_atomic_int64_t; -// typedef __declspec(align(16)) struct { -// uint64_t __val[2]; -// } iree_atomic_int128_t; -typedef struct { - intptr_t __val; -} iree_atomic_intptr_t; - -#define iree_atomic_load_int32(object, order) \ - InterlockedExchangeAdd((volatile LONG*)object, 0) -#define iree_atomic_store_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define iree_atomic_fetch_add_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, operand) -#define iree_atomic_fetch_sub_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, -((int32_t)(operand))) -#define iree_atomic_fetch_and_int32(object, operand, order) \ - InterlockedAnd((volatile LONG*)object, operand) -#define iree_atomic_fetch_or_int32(object, operand, order) \ - InterlockedOr((volatile LONG*)object, operand) -#define iree_atomic_fetch_xor_int32(object, operand, order) \ - InterlockedXor((volatile LONG*)object, operand) -#define iree_atomic_exchange_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_strong_int32 - -#define iree_atomic_load_int64(object, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, 0) -#define iree_atomic_store_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, (LONG64)desired) -#define iree_atomic_fetch_add_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, (LONG64)operand) -#define iree_atomic_fetch_sub_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, -(operand)) -#define iree_atomic_fetch_and_int64(object, operand, order) \ - InterlockedAnd64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_or_int64(object, operand, order) \ - InterlockedOr64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_xor_int64(object, operand, order) \ - InterlockedXor64((volatile LONG64*)object, operand) -#define iree_atomic_exchange_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, desired) -#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_strong_int64 +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef int32_t iree_atomic_int32_t; +typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; +typedef intptr_t iree_atomic_intptr_t; #define iree_atomic_thread_fence(order) MemoryBarrier() +#define 
iree_atomic_load(object, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), 0), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), 0), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), 0), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), 0)) +#define iree_atomic_store(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchange((volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_int64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: _InterlockedExchange( \ + (volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_fetch_add(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_sub(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + -((int32_t)(operand))), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + -((int64_t)(operand))), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + -((int32_t)(operand))), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + -((int64_t)(operand)))) +#define iree_atomic_fetch_and(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedAnd((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedAnd64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedAnd((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedAnd64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_or(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedOr((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedOr64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedOr((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedOr64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_xor(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedXor((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedXor64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedXor((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedXor64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_exchange(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchange((volatile int32_t*)(object), \ + 
(int32_t)(desired)), \ + iree_atomic_int64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: _InterlockedExchange( \ + (volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_compare_exchange_strong_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_int64_t *: iree_atomic_compare_exchange_strong_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_uint32_t *: iree_atomic_compare_exchange_strong_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_uint64_t *: iree_atomic_compare_exchange_strong_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired), \ + (order_succ), (order_fail))) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong + static inline bool iree_atomic_compare_exchange_strong_int32_impl( volatile iree_atomic_int32_t* object, int32_t* expected, int32_t desired, iree_memory_order_t order_succ, iree_memory_order_t order_fail) { @@ -123,59 +311,7 @@ static inline bool iree_atomic_compare_exchange_strong_int64_impl( } } -#define iree_atomic_thread_fence(order) MemoryBarrier() - -// There are no pointer-width atomic ops in MSVC so we need to specialize based -// on the pointer size. 
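// Editorial note, not part of this change: the pointer-width specializations
// removed below are no longer needed once the macros dispatch with _Generic,
// presumably because iree_atomic_intptr_t aliases one of the fixed-width
// atomic typedefs on the targets this path supports; call sites simply drop
// the _intptr suffix, as the atomics_test.cc changes later in this diff show
// (e.g. iree_atomic_load_intptr(&value, order) becomes
// iree_atomic_load(&value, order)).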
-#if defined(IREE_PTR_SIZE_32) -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int32((iree_atomic_int32_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32( \ - (iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#else -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int64((iree_atomic_int64_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64( \ - (iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#endif // IREE_PTR_SIZE_32 - -#ifdef __cplusplus -} // extern "C" -#endif +#endif // IREE_ATOMIC_USE_MSVC_C11 #endif // IREE_COMPILER_MSVC diff --git a/runtime/src/iree/base/internal/atomics_test.cc b/runtime/src/iree/base/internal/atomics_test.cc index a9fce2f2173e..d78890c674a7 100644 --- a/runtime/src/iree/base/internal/atomics_test.cc +++ b/runtime/src/iree/base/internal/atomics_test.cc @@ -21,9 +21,9 @@ TEST(AtomicPtr, LoadStore) { intptr_t ptr_0 = 0x0; intptr_t ptr_1 = 0x1; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); - iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); - EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); + iree_atomic_store(&value, ptr_1, iree_memory_order_seq_cst); + EXPECT_EQ(ptr_1, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, AddSub) { @@ -31,15 +31,15 @@ TEST(AtomicPtr, AddSub) { intptr_t ptr_1 = 0x1; intptr_t ptr_2 = 0x2; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, 
iree_atomic_fetch_add_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_fetch_add_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_2, iree_atomic_fetch_sub_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_fetch_sub_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, + iree_atomic_fetch_add(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_fetch_add(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, + iree_atomic_fetch_sub(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_fetch_sub(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, Exchange) { @@ -47,11 +47,11 @@ TEST(AtomicPtr, Exchange) { intptr_t ptr_1 = 0x1; intptr_t ptr_2 = 0x2; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, iree_atomic_exchange_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_exchange_intptr(&value, ptr_2, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, + iree_atomic_exchange(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_exchange(&value, ptr_2, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, CompareExchange) { @@ -62,31 +62,31 @@ TEST(AtomicPtr, CompareExchange) { intptr_t ptr_expected = 0; // OK: value == ptr_0, CAS(ptr_0 -> ptr_1) - iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_0, iree_memory_order_seq_cst); ptr_expected = ptr_0; - EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_1, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_TRUE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_1, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_0, ptr_expected); - EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, iree_atomic_load(&value, iree_memory_order_seq_cst)); // OK: value == ptr_1, CAS(ptr_1 -> ptr_2) - iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_1, iree_memory_order_seq_cst); ptr_expected = ptr_1; - EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_TRUE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_2, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_1, ptr_expected); - EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_load(&value, iree_memory_order_seq_cst)); // FAIL: value == ptr_0, CAS(ptr_1 -> ptr_2) - iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_0, iree_memory_order_seq_cst); ptr_expected = ptr_1; - EXPECT_FALSE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_FALSE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_2, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_0, ptr_expected); - EXPECT_EQ(ptr_0, 
iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicRefCount, IncDec) { diff --git a/runtime/src/iree/base/internal/dynamic_library_win32.c b/runtime/src/iree/base/internal/dynamic_library_win32.c index af6e4e80b8ef..2cbdd07f6416 100644 --- a/runtime/src/iree/base/internal/dynamic_library_win32.c +++ b/runtime/src/iree/base/internal/dynamic_library_win32.c @@ -91,7 +91,7 @@ static iree_status_t iree_dynamic_library_make_temp_file_path( static iree_atomic_int32_t next_unique_id = IREE_ATOMIC_VAR_INIT(0); // relaxed because we only care about uniqueness, we don't care about ordering // of accesses to unique_id w.r.t. other memory operations. - uint32_t unique_id = (uint32_t)iree_atomic_fetch_add_int32( + uint32_t unique_id = (uint32_t)iree_atomic_fetch_add( &next_unique_id, 1, iree_memory_order_relaxed); // Allocate storage for the full file path and format it in. diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h index 58dd88d13ea5..1e71e0d4553b 100644 --- a/runtime/src/iree/base/internal/math.h +++ b/runtime/src/iree/base/internal/math.h @@ -275,7 +275,7 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { // Define some helper constants for working with a floating-point format with // the given number of {exponent,mantissa} bits. -#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits) \ +#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits, bias_tweak) \ const int prefix##exp_bits IREE_ATTRIBUTE_UNUSED = ebits; \ const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = mbits; \ const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = ebits + mbits; \ @@ -287,7 +287,7 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { const int prefix##exp_mask IREE_ATTRIBUTE_UNUSED = \ (1u << prefix##sign_shift) - (1u << prefix##exp_shift); \ const int prefix##exp_bias IREE_ATTRIBUTE_UNUSED = \ - (1u << (prefix##exp_bits - 1)) - 1; + bias_tweak + (1u << (prefix##exp_bits - 1)) - 1; // Generic conversion from any less-than-32-bit floating-point format to f32. // The `src` value is typed as a uint32_t for genericity but occupies only the @@ -295,39 +295,54 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { // unused. static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, int mantissa_bits, - bool have_infinity) { - IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits) - IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23) + bool have_infinity, + int bias_tweak, + bool nan_as_neg_zero) { + IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits, bias_tweak) + IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23, 0) const uint32_t src_sign = src & src_sign_mask; const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift); const uint32_t src_exp = src & src_exp_mask; const uint32_t src_mantissa = src & src_mantissa_mask; - uint32_t f32_exp = 0; - uint32_t f32_mantissa = 0; + // Initializing f32_exp and f32_mantissa for the case of normal finite values. + // Below we will overload that in other cases. + uint32_t f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias) + << f32_exp_shift; + uint32_t f32_mantissa = src_mantissa + << (f32_mantissa_bits - src_mantissa_bits); if (src_exp == src_exp_mask) { - // No infinities => more large finite values. - if (!have_infinity && src_mantissa != src_mantissa_mask) { - float sign = (src & src_sign_mask) ? 
-1.0f : 1.0f; - return sign * 2 * (1u << src_exp_bits) * - ((1u << src_mantissa_bits) + src_mantissa); + // Top exponent value normally means infinity or NaN. + if (have_infinity) { + // NaN or Inf case. + f32_exp = f32_exp_mask; + if (src_mantissa) { + f32_mantissa = f32_mantissa_mask; // Quiet NaN. + } else { + f32_mantissa = 0; // Inf. + } + } else { + // No infinities => more large finite values, unless this is a NaN. + bool is_finite = src_mantissa != src_mantissa_mask || nan_as_neg_zero; + if (is_finite) { + f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias) + << f32_exp_shift; + f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits); + } else { + // NaN. Generate a quiet NaN. + f32_exp = f32_exp_mask; + f32_mantissa = f32_mantissa_mask; + } } - // NaN or Inf case. - f32_exp = f32_exp_mask; - if (src_mantissa) { - // NaN. Generate a quiet NaN. + } else if (src_exp == 0) { + // Zero or subnormal. Generate zero, except in one case: if the source type + // encodes NaN as signed zero, we handle that now. + if (nan_as_neg_zero && src == src_sign_mask) { + f32_exp = f32_exp_mask; f32_mantissa = f32_mantissa_mask; } else { - // Inf. Leave zero mantissa. + f32_exp = 0; + f32_mantissa = 0; } - } else if (src_exp == 0) { - // Zero or subnormal. Generate zero. Leave zero mantissa. - } else { - // Normal finite value. - int arithmetic_src_exp = src_exp >> src_exp_shift; - int arithmetic_f32_exp = arithmetic_src_exp + (1 << (f32_exp_bits - 1)) - - (1 << (src_exp_bits - 1)); - f32_exp = arithmetic_f32_exp << f32_exp_shift; - f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits); } const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa; float f32_value; @@ -340,28 +355,34 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, // genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits. // The upper bits of the return value are unused. static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( - float value, int exp_bits, int mantissa_bits, bool have_infinity) { - IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits) - IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23) + float value, int exp_bits, int mantissa_bits, bool have_infinity, + int bias_tweak, bool nan_as_neg_zero) { + IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits, bias_tweak) + IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23, 0) uint32_t u32_value; memcpy(&u32_value, &value, sizeof value); const uint32_t f32_sign = u32_value & f32_sign_mask; - const uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift); + uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift); const uint32_t f32_exp = u32_value & f32_exp_mask; const uint32_t f32_mantissa = u32_value & f32_mantissa_mask; uint32_t dst_exp = 0; uint32_t dst_mantissa = 0; + bool generate_nan = false; if (f32_exp >= f32_exp_mask) { // NaN or Inf case. dst_exp = dst_exp_mask; if (f32_mantissa || !have_infinity) { // NaN. Generate a quiet NaN. - dst_mantissa = dst_mantissa_mask; + generate_nan = true; } else { // Inf. Leave zero mantissa. } } else if (f32_exp == 0) { // Zero or subnormal. Generate zero. Leave zero mantissa. + if (nan_as_neg_zero) { + // The destination has no signed zero. Avoid accidentally generating NaN. + dst_sign = 0; + } } else { // Normal finite value. 
int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias; @@ -373,7 +394,7 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( dst_exp = dst_exp_mask; if (!have_infinity) { // Generate NaN. - dst_mantissa = dst_mantissa_mask; + generate_nan = true; } } else if (arithmetic_exp < -(1 << (dst_exp_bits - 1))) { // Underflow. Generate zero. Leave zero mantissa. @@ -401,38 +422,52 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( biased_f32_mantissa = 0; ++arithmetic_exp; } - // In the !have_infinity case, arithmetic_exp might have been the top - // value already, so incrementing it may have overflown it. - if (!have_infinity && arithmetic_exp > (1 << (dst_exp_bits - 1))) { - dst_exp = dst_exp_mask; - dst_mantissa = dst_mantissa_mask; - } else { - // The exponent increment in the above if() branch may cause overflow. - // This is exercised by converting 65520.0f from f32 to f16. No special - // handling is needed for this case: the above if() branch already set - // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as - // needed for infinite values. - dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift; - dst_mantissa = - biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits); + // The exponent increment in the above if() branch may cause overflow. + // This is exercised by converting 65520.0f from f32 to f16. When the + // destination type has infinities, no special handling is needed for this + // case: the above if() branch already set biased_f32_mantissa=0, so we + // will be generating a 0 mantissa, as needed for infinite values. The one + // case where special handling is needed is when the destination type has + // no infinities and we need to generate NaN. + dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift; + dst_mantissa = + biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits); + if (!have_infinity && dst_exp > dst_exp_mask) { + generate_nan = true; } } } - uint32_t dst_value = dst_sign | dst_exp | dst_mantissa; - return dst_value; + if (generate_nan) { + if (nan_as_neg_zero) { + return dst_sign_mask; + } else { + return dst_sign | dst_exp_mask | dst_mantissa_mask; + } + } else { + if (nan_as_neg_zero && dst_exp == 0 && dst_mantissa == 0) { + // Negative zero needs to be rounded to positive zero to avoid + // accidentally producing NaN when negative-zero is the NaN encoding. + return 0; + } else { + return dst_sign | dst_exp | dst_mantissa; + } + } } #define IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(NAME, INT_TYPE, EXP_BITS, \ - MANTISSA_BITS, HAVE_INFINITY) \ + MANTISSA_BITS, HAVE_INFINITY, \ + BIAS_TWEAK, NAN_AS_NEG_ZERO) \ /* Converts a to a 32-bit C `float`. */ \ static inline float iree_math_##NAME##_to_f32(INT_TYPE src) { \ return iree_math_make_f32_from_bits(src, EXP_BITS, MANTISSA_BITS, \ - HAVE_INFINITY); \ + HAVE_INFINITY, BIAS_TWEAK, \ + NAN_AS_NEG_ZERO); \ } \ /* Truncates a 32-bit C `float`, rounding to nearest even. */ \ static inline INT_TYPE iree_math_f32_to_##NAME(float value) { \ return iree_math_truncate_f32_to_bits_rounding_to_nearest_even( \ - value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY); \ + value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY, BIAS_TWEAK, \ + NAN_AS_NEG_ZERO); \ } \ /* Round-trip f32->f32 rounding via the narrow float type */ \ static inline float iree_math_round_to_nearest_##NAME(float value) { \ @@ -441,16 +476,44 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( // IEEE half-precision a.k.a. 
float16, // https://en.wikipedia.org/wiki/Half-precision_floating-point_format -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // Bfloat16, https://en.wikipedia.org/wiki/Bfloat16_floating-point_format -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // F8E5M2 type, https://arxiv.org/abs/2209.05433 -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // F8E4M3 type, https://arxiv.org/abs/2209.05433. IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3, uint8_t, 4, 3, - /*have_infinity=*/false) + /*have_infinity=*/false, /*bias_tweak=*/0, + /*nan_as_neg_zero=*/false) + +// F8E5M2FNUZ type, found in some AMD GPUs (MI300), called "BF8" there. +// Quoting LLVM's APFloat.h: +// 8-bit floating point number mostly following IEEE-754 conventions +// and bit layout S1E5M2 described in https://arxiv.org/abs/2206.02915, +// with expanded range and with no infinity or signed zero. +// NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero). +// This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1) +// that IEEE precedent would imply. +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2fnuz, uint8_t, 5, 2, + /*have_infinity=*/false, /*bias_tweak=*/1, + /*nan_as_neg_zero=*/true) + +// F8E4M3FNUZ type, found in some AMD GPUs (MI300), called "FP8" there. +// Quoting LLVM's APFloat.h: +// 8-bit floating point number mostly following IEEE-754 conventions +// and bit layout S1E4M3 described in https://arxiv.org/abs/2206.02915, +// with expanded range and with no infinity or signed zero. +// NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero). +// This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1) +// that IEEE precedent would imply. +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3fnuz, uint8_t, 4, 3, + /*have_infinity=*/false, /*bias_tweak=*/1, + /*nan_as_neg_zero=*/true) #endif // IREE_BASE_INTERNAL_MATH_H_ diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc index b3548d762e45..347d7d2ba5c3 100644 --- a/runtime/src/iree/base/internal/math_test.cc +++ b/runtime/src/iree/base/internal/math_test.cc @@ -523,4 +523,200 @@ TEST(F8E4M3ConversionTest, F32ToF8E4M3ToF32) { EXPECT_NE(nan, nan); } +//============================================================================== +// F8E5M2FNUZ support +//============================================================================== + +TEST(F8E5M2FNUZConversionTest, F32ToF8E5M2FNUZ) { + constexpr float kF8E5M2FNUZMax = 57344.f; + constexpr float kF8E5M2FNUZMin = 1.f / 32768.f; + // Within range, normal truncation. 
+ EXPECT_EQ(0x38, iree_math_f32_to_f8e5m2fnuz(0.25f)); + EXPECT_EQ(0xDA, iree_math_f32_to_f8e5m2fnuz(-100.375f)); + EXPECT_EQ(0x7E, iree_math_f32_to_f8e5m2fnuz(49152.f)); + EXPECT_EQ(0xFE, iree_math_f32_to_f8e5m2fnuz(-49152.f)); + EXPECT_EQ(0x7F, iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMax)); + EXPECT_EQ(0x04, iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin)); + EXPECT_EQ(0x84, iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin)); + // No infinities, so they convert to NaN, encoded as negative zero. + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(INFINITY)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(-INFINITY)); + // Overflow. + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(FLT_MAX)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(-FLT_MAX)); + // Underflow + EXPECT_EQ(0, iree_math_f32_to_f8e5m2fnuz(FLT_MIN)); + EXPECT_EQ(0, iree_math_f32_to_f8e5m2fnuz(-FLT_MIN)); // No negative zero. + // Denormals may or may not get flushed to zero. Accept both ways. + uint32_t positive_denormal = iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin / 2); + EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x02); + uint32_t negative_denormal = iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin / 2); + // No negative zero. + EXPECT_TRUE(negative_denormal == 0x0 || negative_denormal == 0x02); +} + +TEST(F8E5M2FNUZConversionTest, F32ToF8E5M2ToF32FNUZ) { + constexpr float kF8E5M2FNUZMax = 57344.f; + constexpr float kF8E5M2FNUZMin = 1.f / 32768.f; + // Within range, should just round. + EXPECT_EQ(0.25f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(0.25f))); + EXPECT_EQ(-0.25f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-0.25f))); + EXPECT_EQ(96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(100.375f))); + EXPECT_EQ(-96.f, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-100.375f))); + EXPECT_EQ(96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(96.f))); + EXPECT_EQ(-96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-96.f))); + EXPECT_EQ(kF8E5M2FNUZMax, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax))); + EXPECT_EQ(-kF8E5M2FNUZMax, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMax))); + EXPECT_EQ(kF8E5M2FNUZMin, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin))); + EXPECT_EQ(-kF8E5M2FNUZMin, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin))); + // Powers of two should always be exactly representable across the + // exponent range. + EXPECT_EQ(32768.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(32768.f))); + EXPECT_EQ(-32768.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-32768.f))); + // Overflow + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(FLT_MAX)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-FLT_MAX)))); + EXPECT_GT(kF8E5M2FNUZMax + 1.f, + iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax + 1.f))); + // Underflow + EXPECT_EQ(0.0f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(FLT_MIN))); + EXPECT_EQ(0.0f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-FLT_MIN))); + // Denormals may or may not get flushed to zero. Accept both ways. 
+ float positive_denormal = iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin / 2)); + EXPECT_TRUE(positive_denormal == 0.0f || + positive_denormal == 3.05175781e-05f); + // Inf and NaN. No infinities, so we get NaN. + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(NAN)))); +} + +//============================================================================== +// F8E4M3FNUZ support +//============================================================================== + +TEST(F8E4M3FNUZConversionTest, F32ToF8E4M3FNUZ) { + // See https://arxiv.org/pdf/2209.05433.pdf, Table 1. + // The F8E4M3 format is special: it has no infinities, and has some larger + // finite values instead. + constexpr float kF8E4M3FNUZMax = 240.f; + constexpr float kF8E4M3FNUZMin = 1.f / 128.f; + // Within range, normal truncation. + EXPECT_EQ(0x30, iree_math_f32_to_f8e4m3fnuz(0.25f)); + EXPECT_EQ(0xF5, iree_math_f32_to_f8e4m3fnuz(-100.375f)); + // Extra large finite values thanks to not having infinities. + EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax)); + EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3fnuz(247.0f)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMax)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3fnuz(-247.0f)); + // First value that overflows. + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(248.0f)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-248.0f)); + // Min normal values. + EXPECT_EQ(0x08, iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin)); + EXPECT_EQ(0x88, iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin)); + // Infinity + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(INFINITY)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-INFINITY)); + // Overflow + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(FLT_MAX)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-FLT_MAX)); + // Test some round-to-nearest-even behavior. + EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3fnuz(136.0f)); + EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fnuz(152.0f)); + EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fnuz(168.0f)); + EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3fnuz(184.0f)); + // Underflow + EXPECT_EQ(0, iree_math_f32_to_f8e4m3fnuz(FLT_MIN)); + EXPECT_EQ(0, iree_math_f32_to_f8e4m3fnuz(-FLT_MIN)); + // Denormals may or may not get flushed to zero. Accept both ways. + uint32_t positive_denormal = iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin / 2); + EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x04); + uint32_t negative_denormal = iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin / 2); + EXPECT_TRUE(negative_denormal == 0 || negative_denormal == 0x84); +} + +TEST(F8E4M3FNUZConversionTest, F32ToF8E4M3ToF32FNUZ) { + // See https://arxiv.org/pdf/2209.05433.pdf, Table 1. + // The F8E4M3 format is special: it has no infinities, and has some larger + // finite values instead. + constexpr float kF8E4M3FNUZMax = 240.f; + constexpr float kF8E4M3FNUZMin = 1.f / 128.f; + // Within range, should just round. 
+ EXPECT_EQ(0.25f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(0.25f))); + EXPECT_EQ(-0.25f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-0.25f))); + EXPECT_EQ(104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(100.375f))); + EXPECT_EQ(-104.f, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-100.375f))); + EXPECT_EQ(104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(100.4f))); + EXPECT_EQ(-104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-100.4f))); + EXPECT_EQ(kF8E4M3FNUZMax, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax))); + EXPECT_EQ(-kF8E4M3FNUZMax, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMax))); + EXPECT_EQ(kF8E4M3FNUZMin, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin))); + EXPECT_EQ(-kF8E4M3FNUZMin, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin))); + // Powers of two should always be exactly representable across the + // exponent range. + EXPECT_EQ(128.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(128.f))); + EXPECT_EQ(-128.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-128.f))); + // Overflow + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(FLT_MAX)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-FLT_MAX)))); + EXPECT_GT(kF8E4M3FNUZMax + 1.f, + iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax + 1.f))); + // Underflow + EXPECT_EQ(0.0f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(FLT_MIN))); + EXPECT_EQ(0.0f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-FLT_MIN))); + // Denormals may or may not get flushed to zero. Accept both ways. + float positive_denormal = iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin / 2)); + EXPECT_TRUE(positive_denormal == 0.0f || + positive_denormal == 3.05175781e-05f); + // Inf and Nan + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-INFINITY)))); + // Check that the result is a Nan with nan != nan. + float nan = iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(NAN)); + EXPECT_NE(nan, nan); +} + } // namespace diff --git a/runtime/src/iree/base/internal/synchronization.c b/runtime/src/iree/base/internal/synchronization.c index 65fb0d1a93e8..960a70c3be9b 100644 --- a/runtime/src/iree/base/internal/synchronization.c +++ b/runtime/src/iree/base/internal/synchronization.c @@ -447,8 +447,7 @@ void iree_slim_mutex_initialize(iree_slim_mutex_t* out_mutex) { void iree_slim_mutex_deinitialize(iree_slim_mutex_t* mutex) { // Assert unlocked (callers must ensure the mutex is no longer in use). - SYNC_ASSERT( - iree_atomic_load_int32(&mutex->value, iree_memory_order_acquire) == 0); + SYNC_ASSERT(iree_atomic_load(&mutex->value, iree_memory_order_acquire) == 0); } // Helper to perform a compare_exchange operation on mutex->value, internally @@ -467,9 +466,9 @@ static bool iree_slim_mutex_try_lock_compare_exchange( // more about efficiency in the uncontended case than we care about avoiding // spurious failure. Also, some callers are calling this in a loop, where they // would want the weak form anyway. 
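The FNUZ helpers exercised by the math.h and math_test.cc changes above differ from the IEEE-style f8 types in two user-visible ways: there are no infinities (out-of-range values overflow to NaN), and NaN occupies the bit pattern that an IEEE layout would use for negative zero, with the exponent bias shifted up by one. A minimal standalone sketch of those encodings, assuming the usual iree/base/internal/math.h include path and reusing only values already exercised by the tests above:

#include <assert.h>
#include <math.h>
#include <stdio.h>

#include "iree/base/internal/math.h"

int main(void) {
  // 0x80 would be -0.0f in an IEEE-style layout; for the FNUZ types it is NaN.
  assert(iree_math_f32_to_f8e4m3fnuz(INFINITY) == 0x80);
  assert(isnan(iree_math_f8e4m3fnuz_to_f32(0x80)));
  // Largest finite f8e4m3fnuz value is 240; larger inputs overflow to NaN.
  assert(iree_math_f32_to_f8e4m3fnuz(240.0f) == 0x7F);
  assert(iree_math_f32_to_f8e4m3fnuz(248.0f) == 0x80);
  // Round-tripping an in-range value simply rounds to the nearest encoding.
  printf("0.25f -> 0x%02X -> %f\n",
         (unsigned)iree_math_f32_to_f8e4m3fnuz(0.25f),
         iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(0.25f)));
  return 0;
}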
- return iree_atomic_compare_exchange_weak_int32( - &mutex->value, expected, desired, iree_memory_order_acquire, - iree_memory_order_relaxed); + return iree_atomic_compare_exchange_weak(&mutex->value, expected, desired, + iree_memory_order_acquire, + iree_memory_order_relaxed); } void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) @@ -490,8 +489,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) // This uses relaxed order because this is an internal intermediate step and // we only need atomicity here. value = - iree_atomic_fetch_add_int32(&mutex->value, 1, iree_memory_order_relaxed) + - 1; + iree_atomic_fetch_add(&mutex->value, 1, iree_memory_order_relaxed) + 1; while (true) { // While the lock is available: try to acquire it for this thread. @@ -513,8 +511,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) int spin_count = 100; for (int i = 0; i < spin_count && iree_slim_mutex_is_locked(value); ++i) { iree_processor_yield(); - value = - iree_atomic_load_int32(&mutex->value, iree_memory_order_relaxed); + value = iree_atomic_load(&mutex->value, iree_memory_order_relaxed); } } @@ -523,7 +520,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) // NOTE: we don't care about wait failure here as we are going to loop // and check again anyway. iree_futex_wait(&mutex->value, value, IREE_TIME_INFINITE_FUTURE); - value = iree_atomic_load_int32(&mutex->value, iree_memory_order_relaxed); + value = iree_atomic_load(&mutex->value, iree_memory_order_relaxed); } } } @@ -541,8 +538,8 @@ void iree_slim_mutex_unlock(iree_slim_mutex_t* mutex) IREE_DISABLE_THREAD_SAFETY_ANALYSIS { // Refer to the iree_slim_mutex_t struct comment, "Notes on atomics". // Transition 1->0 (unlocking with no waiters) or 2->1 (with waiters). - if (iree_atomic_fetch_sub_int32(&mutex->value, iree_slim_mutex_value(1), - iree_memory_order_release) != + if (iree_atomic_fetch_sub(&mutex->value, iree_slim_mutex_value(1), + iree_memory_order_release) != iree_slim_mutex_value(1)) { // One (or more) waiters; wake a single one to avoid a thundering herd of // multiple threads all waking and trying to grab the lock (as only one will @@ -749,14 +746,14 @@ void iree_notification_initialize(iree_notification_t* out_notification) { void iree_notification_deinitialize(iree_notification_t* notification) { // Assert no more waiters (callers must tear down waiters first). SYNC_ASSERT( - (iree_atomic_load_int64(¬ification->value, iree_memory_order_acquire) & + (iree_atomic_load(¬ification->value, iree_memory_order_acquire) & IREE_NOTIFICATION_WAITER_MASK) == 0); } void iree_notification_post(iree_notification_t* notification, int32_t count) { - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_EPOCH_INC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_EPOCH_INC, + iree_memory_order_acq_rel); // Ensure we have at least one waiter; wake up to |count| of them. 
if (IREE_UNLIKELY(previous_value & IREE_NOTIFICATION_WAITER_MASK)) { iree_futex_wake(iree_notification_epoch_address(notification), count); @@ -765,9 +762,9 @@ void iree_notification_post(iree_notification_t* notification, int32_t count) { iree_wait_token_t iree_notification_prepare_wait( iree_notification_t* notification) { - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_INC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_INC, + iree_memory_order_acq_rel); return (iree_wait_token_t)(previous_value >> IREE_NOTIFICATION_EPOCH_SHIFT); } @@ -779,8 +776,7 @@ typedef enum iree_notification_result_e { static iree_notification_result_t iree_notification_test_wait_condition( iree_notification_t* notification, iree_wait_token_t wait_token) { - return (iree_atomic_load_int64(¬ification->value, - iree_memory_order_acquire) >> + return (iree_atomic_load(¬ification->value, iree_memory_order_acquire) >> IREE_NOTIFICATION_EPOCH_SHIFT) != wait_token ? IREE_NOTIFICATION_RESULT_RESOLVED : IREE_NOTIFICATION_RESULT_UNRESOLVED; @@ -830,9 +826,9 @@ bool iree_notification_commit_wait(iree_notification_t* notification, // TODO(benvanik): benchmark under real workloads. // iree_memory_order_relaxed would suffice for correctness but the faster // the waiter count gets to 0 the less likely we'll wake on the futex. - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_DEC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_DEC, + iree_memory_order_acq_rel); SYNC_ASSERT((previous_value & IREE_NOTIFICATION_WAITER_MASK) != 0); return result == IREE_NOTIFICATION_RESULT_RESOLVED; @@ -842,9 +838,9 @@ void iree_notification_cancel_wait(iree_notification_t* notification) { // TODO(benvanik): benchmark under real workloads. // iree_memory_order_relaxed would suffice for correctness but the faster // the waiter count gets to 0 the less likely we'll wake on the futex. - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_DEC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_DEC, + iree_memory_order_acq_rel); SYNC_ASSERT((previous_value & IREE_NOTIFICATION_WAITER_MASK) != 0); } diff --git a/runtime/src/iree/base/internal/threading_darwin.c b/runtime/src/iree/base/internal/threading_darwin.c index 52932f848816..dc4b5f8ef81e 100644 --- a/runtime/src/iree/base/internal/threading_darwin.c +++ b/runtime/src/iree/base/internal/threading_darwin.c @@ -104,9 +104,8 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, thread->entry_arg = entry_arg; iree_strncpy_s(thread->name, IREE_ARRAYSIZE(thread->name), params.name.data, iree_min(params.name.size, IREE_ARRAYSIZE(thread->name) - 1)); - iree_atomic_store_int32(&thread->is_suspended, - params.create_suspended ? 1 : 0, - iree_memory_order_relaxed); + iree_atomic_store(&thread->is_suspended, params.create_suspended ? 1 : 0, + iree_memory_order_relaxed); pthread_attr_t thread_attr; pthread_attr_init(&thread_attr); @@ -239,7 +238,7 @@ void iree_thread_resume(iree_thread_t* thread) { // always balance suspend/resume or else we'll mess with any // debuggers/profilers that may be suspending threads for their own uses. 
int32_t expected = 1; - if (iree_atomic_compare_exchange_strong_int32( + if (iree_atomic_compare_exchange_strong( &thread->is_suspended, &expected, 0, iree_memory_order_acq_rel, iree_memory_order_relaxed /* expected is unused */)) { thread_resume(thread->mach_port); diff --git a/runtime/src/iree/base/internal/threading_pthreads.c b/runtime/src/iree/base/internal/threading_pthreads.c index 1686fd16a060..3f15987be768 100644 --- a/runtime/src/iree/base/internal/threading_pthreads.c +++ b/runtime/src/iree/base/internal/threading_pthreads.c @@ -51,8 +51,8 @@ static void iree_thread_set_priority_class( static bool iree_thread_resumed_predicate(void* arg) { iree_thread_t* thread = (iree_thread_t*)arg; - return iree_atomic_load_int32(&thread->suspend_count, - iree_memory_order_acquire) == 0; + return iree_atomic_load(&thread->suspend_count, iree_memory_order_acquire) == + 0; } #if defined(IREE_PLATFORM_EMSCRIPTEN) @@ -99,8 +99,8 @@ static void* iree_thread_start_routine(void* param) { IREE_TRACE_SET_THREAD_NAME(thread->name); // Wait until we resume if we were created suspended. - while (iree_atomic_load_int32(&thread->suspend_count, - iree_memory_order_acquire) > 0) { + while (iree_atomic_load(&thread->suspend_count, iree_memory_order_acquire) > + 0) { iree_notification_await(&thread->suspend_barrier, iree_thread_resumed_predicate, thread, iree_infinite_timeout()); @@ -335,8 +335,8 @@ void iree_thread_request_affinity(iree_thread_t* thread, void iree_thread_resume(iree_thread_t* thread) { IREE_TRACE_ZONE_BEGIN(z0); - if (iree_atomic_exchange_int32(&thread->suspend_count, 0, - iree_memory_order_acq_rel) == 1) { + if (iree_atomic_exchange(&thread->suspend_count, 0, + iree_memory_order_acq_rel) == 1) { iree_notification_post(&thread->suspend_barrier, IREE_ALL_WAITERS); } diff --git a/runtime/src/iree/base/internal/threading_test.cc b/runtime/src/iree/base/internal/threading_test.cc index 8ee5a96b7fa6..1fd973083e22 100644 --- a/runtime/src/iree/base/internal/threading_test.cc +++ b/runtime/src/iree/base/internal/threading_test.cc @@ -34,12 +34,11 @@ TEST(ThreadTest, Lifetime) { iree_atomic_int32_t value; iree_notification_t barrier; } entry_data; - iree_atomic_store_int32(&entry_data.value, 123, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 123, iree_memory_order_relaxed); iree_notification_initialize(&entry_data.barrier); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_acq_rel); iree_notification_post(&entry_data->barrier, IREE_ALL_WAITERS); return 0; }; @@ -55,8 +54,8 @@ TEST(ThreadTest, Lifetime) { &entry_data.barrier, +[](void* entry_arg) -> bool { auto* entry_data = reinterpret_cast(entry_arg); - return iree_atomic_load_int32(&entry_data->value, - iree_memory_order_relaxed) == (123 + 1); + return iree_atomic_load(&entry_data->value, + iree_memory_order_relaxed) == (123 + 1); }, &entry_data, iree_infinite_timeout()); @@ -76,12 +75,11 @@ TEST(ThreadTest, CreateSuspended) { iree_atomic_int32_t value; iree_notification_t barrier; } entry_data; - iree_atomic_store_int32(&entry_data.value, 123, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 123, iree_memory_order_relaxed); iree_notification_initialize(&entry_data.barrier); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - 
iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_acq_rel); iree_notification_post(&entry_data->barrier, IREE_ALL_WAITERS); return 0; }; @@ -95,11 +93,11 @@ TEST(ThreadTest, CreateSuspended) { // the value. I can't think of a good way to test this, though, so we'll just // wait a moment here and assume that if the thread was able to run it would // have during this wait. - ASSERT_EQ(123, iree_atomic_load_int32(&entry_data.value, - iree_memory_order_seq_cst)); + ASSERT_EQ(123, + iree_atomic_load(&entry_data.value, iree_memory_order_seq_cst)); std::this_thread::sleep_for(std::chrono::milliseconds(150)); - ASSERT_EQ(123, iree_atomic_load_int32(&entry_data.value, - iree_memory_order_seq_cst)); + ASSERT_EQ(123, + iree_atomic_load(&entry_data.value, iree_memory_order_seq_cst)); // Resume the thread and wait for it to finish its work. iree_thread_resume(thread); @@ -107,8 +105,8 @@ TEST(ThreadTest, CreateSuspended) { &entry_data.barrier, +[](void* entry_arg) -> bool { auto* entry_data = reinterpret_cast(entry_arg); - return iree_atomic_load_int32(&entry_data->value, - iree_memory_order_relaxed) == (123 + 1); + return iree_atomic_load(&entry_data->value, + iree_memory_order_relaxed) == (123 + 1); }, &entry_data, iree_infinite_timeout()); iree_thread_release(thread); @@ -126,11 +124,10 @@ TEST(ThreadTest, PriorityOverride) { struct entry_data_t { iree_atomic_int32_t value; } entry_data; - iree_atomic_store_int32(&entry_data.value, 0, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 0, iree_memory_order_relaxed); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_release); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_release); return 0; }; @@ -150,8 +147,7 @@ TEST(ThreadTest, PriorityOverride) { thread, IREE_THREAD_PRIORITY_CLASS_LOWEST); // Wait for the thread to finish. - while (iree_atomic_load_int32(&entry_data.value, iree_memory_order_acquire) != - 1) { + while (iree_atomic_load(&entry_data.value, iree_memory_order_acquire) != 1) { iree_thread_yield(); } diff --git a/runtime/src/iree/base/internal/threading_win32.c b/runtime/src/iree/base/internal/threading_win32.c index 6166ce288175..64ddca614da2 100644 --- a/runtime/src/iree/base/internal/threading_win32.c +++ b/runtime/src/iree/base/internal/threading_win32.c @@ -143,9 +143,8 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, thread->entry_arg = entry_arg; strncpy_s(thread->name, IREE_ARRAYSIZE(thread->name), params.name.data, min(params.name.size, IREE_ARRAYSIZE(thread->name) - 1)); - iree_atomic_store_int32(&thread->is_suspended, - params.create_suspended ? 1 : 0, - iree_memory_order_relaxed); + iree_atomic_store(&thread->is_suspended, params.create_suspended ? 1 : 0, + iree_memory_order_relaxed); iree_thread_override_list_initialize(iree_thread_set_priority_class, params.priority_class, thread->allocator, &thread->qos_override_list); @@ -304,7 +303,7 @@ void iree_thread_resume(iree_thread_t* thread) { // always balance suspend/resume or else we'll mess with any // debuggers/profilers that may be suspending threads for their own uses. 
int32_t expected = 1; - if (iree_atomic_compare_exchange_strong_int32( + if (iree_atomic_compare_exchange_strong( &thread->is_suspended, &expected, 0, iree_memory_order_acq_rel, iree_memory_order_relaxed /* expected is unused */)) { ResumeThread(thread->handle); diff --git a/runtime/src/iree/base/internal/wait_handle_inproc.c b/runtime/src/iree/base/internal/wait_handle_inproc.c index e3192595e177..7f92797b1bc8 100644 --- a/runtime/src/iree/base/internal/wait_handle_inproc.c +++ b/runtime/src/iree/base/internal/wait_handle_inproc.c @@ -240,7 +240,7 @@ static bool iree_wait_set_check(const iree_wait_set_check_params_t* params) { iree_wait_handle_t* wait_handle = ¶ms->set->handles[i]; iree_futex_handle_t* futex = (iree_futex_handle_t*)wait_handle->value.local_futex; - if (iree_atomic_load_int64(&futex->value, iree_memory_order_acquire) != 0) { + if (iree_atomic_load(&futex->value, iree_memory_order_acquire) != 0) { ++ready_count; if (params->wake_handle) { *params->wake_handle = *wait_handle; @@ -292,7 +292,7 @@ iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns, } static bool iree_futex_handle_check(iree_futex_handle_t* futex) { - return iree_atomic_load_int64(&futex->value, iree_memory_order_acquire) != 0; + return iree_atomic_load(&futex->value, iree_memory_order_acquire) != 0; } iree_status_t iree_wait_one(iree_wait_handle_t* handle, @@ -335,8 +335,8 @@ iree_status_t iree_event_initialize(bool initial_state, if (iree_status_is_ok(status)) { out_event->type = IREE_WAIT_PRIMITIVE_TYPE_LOCAL_FUTEX; out_event->value.local_futex = (void*)futex; - iree_atomic_store_int64(&futex->value, initial_state ? 1 : 0, - iree_memory_order_release); + iree_atomic_store(&futex->value, initial_state ? 1 : 0, + iree_memory_order_release); iree_notification_initialize(&futex->notification); } @@ -358,8 +358,7 @@ void iree_event_set(iree_event_t* event) { // Try to transition from unset -> set. // No-op if already set and otherwise we successfully signaled the event and // need to notify all waiters. - if (iree_atomic_exchange_int64(&futex->value, 1, iree_memory_order_release) == - 0) { + if (iree_atomic_exchange(&futex->value, 1, iree_memory_order_release) == 0) { // Notify those waiting on just this event. iree_notification_post(&futex->notification, IREE_ALL_WAITERS); // Notify any multi-waits that may have this event as part of their set. @@ -371,7 +370,7 @@ void iree_event_reset(iree_event_t* event) { if (!event) return; iree_futex_handle_t* futex = (iree_futex_handle_t*)event->value.local_futex; if (!futex) return; - iree_atomic_store_int64(&futex->value, 0, iree_memory_order_release); + iree_atomic_store(&futex->value, 0, iree_memory_order_release); } #endif // IREE_WAIT_API == IREE_WAIT_API_INPROC diff --git a/runtime/src/iree/hal/buffer_view.h b/runtime/src/iree/hal/buffer_view.h index 96b9fd487ce5..b5c4861dcdd0 100644 --- a/runtime/src/iree/hal/buffer_view.h +++ b/runtime/src/iree/hal/buffer_view.h @@ -48,6 +48,14 @@ enum iree_hal_numerical_type_bits_t { IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x02u, // Paired (real, imag) complex number in floating-point format. IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x03u, + // Ad-hoc entries for the zoo of low-bit-depth float types. They are special + // in that there are many different types sharing the same size. 
+ IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2 = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x04u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3 = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x05u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2_FNUZ = + IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x06u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3_FNUZ = + IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x07u, }; typedef uint8_t iree_hal_numerical_type_t; @@ -148,6 +156,10 @@ enum iree_hal_element_types_t { IREE_HAL_ELEMENT_TYPE_BFLOAT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN, 16), // NOLINT IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX, 64), // NOLINT IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX, 128), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2_FNUZ, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3_FNUZ, 8), // NOLINT }; typedef uint32_t iree_hal_element_type_t; // clang-format on diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h index a158082b36c3..b745761cf6d9 100644 --- a/runtime/src/iree/hal/cts/semaphore_submission_test.h +++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h @@ -882,7 +882,7 @@ TEST_F(SemaphoreSubmissionTest, PropagateFailSignal) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore2, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); signal_thread.join(); diff --git a/runtime/src/iree/hal/cts/semaphore_test.h b/runtime/src/iree/hal/cts/semaphore_test.h index 54e907e47004..7d0592f1921a 100644 --- a/runtime/src/iree/hal/cts/semaphore_test.h +++ b/runtime/src/iree/hal/cts/semaphore_test.h @@ -406,7 +406,7 @@ TEST_F(SemaphoreTest, FailThenWait) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); iree_hal_semaphore_release(semaphore); @@ -431,7 +431,7 @@ TEST_F(SemaphoreTest, WaitThenFail) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); signal_thread.join(); @@ -467,7 +467,7 @@ TEST_F(SemaphoreTest, MultiWaitThenFail) { uint64_t value = 1234; iree_status_t semaphore1_query_status = iree_hal_semaphore_query(semaphore1, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(semaphore1_query_status, status); // semaphore2 must not have changed. 
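The EXPECT_GE assertions here, like the matching >= checks in the CUDA and HIP semaphores further down, encode the rule that a failed semaphore reports a payload value at or above IREE_HAL_SEMAPHORE_FAILURE_VALUE rather than exactly that constant. A small sketch of the caller-side check, where example_semaphore_has_failed is a hypothetical helper name:

static bool example_semaphore_has_failed(iree_hal_semaphore_t* semaphore) {
  uint64_t value = 0;
  iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
  // Either the query status or the payload range can signal the failure.
  bool failed = !iree_status_is_ok(query_status) ||
                value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE;
  iree_status_ignore(query_status);
  return failed;
}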
@@ -511,7 +511,7 @@ TEST_F(SemaphoreTest, DeviceMultiWaitThenFail) { uint64_t value = 1234; iree_status_t semaphore1_query_status = iree_hal_semaphore_query(semaphore1, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(semaphore1_query_status, status); // semaphore2 must not have changed. diff --git a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c index fb86efe7e815..0c0cf41e6ba9 100644 --- a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c @@ -325,7 +325,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_ABORTED); @@ -350,7 +350,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_make_status(IREE_STATUS_ABORTED); } iree_slim_mutex_unlock(&semaphore->mutex); @@ -444,7 +444,7 @@ iree_status_t iree_hal_cuda_semaphore_multi_wait( iree_hal_cuda_semaphore_t* semaphore = iree_hal_cuda_semaphore_cast(semaphore_list.semaphores[i]); iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); status = iree_make_status(IREE_STATUS_ABORTED); break; diff --git a/runtime/src/iree/hal/drivers/cuda/memory_pools.c b/runtime/src/iree/hal/drivers/cuda/memory_pools.c index 236ffaac840b..1e34422478f5 100644 --- a/runtime/src/iree/hal/drivers/cuda/memory_pools.c +++ b/runtime/src/iree/hal/drivers/cuda/memory_pools.c @@ -121,8 +121,8 @@ static void iree_hal_cuda_memory_pool_track_alloc( iree_atomic_int64_t* bytes_allocated = is_device_local ? 
&pools->statistics.device_bytes_allocated : &pools->statistics.host_bytes_allocated; - iree_atomic_fetch_add_int64(bytes_allocated, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_allocated, allocation_size, + iree_memory_order_relaxed); }); } @@ -141,8 +141,8 @@ static void iree_hal_cuda_memory_pool_track_free( : &pools->statistics.host_bytes_freed; iree_device_size_t allocation_size = iree_hal_buffer_allocation_size(buffer); - iree_atomic_fetch_add_int64(bytes_freed, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_freed, allocation_size, + iree_memory_order_relaxed); }); } @@ -150,13 +150,13 @@ void iree_hal_cuda_memory_pools_merge_statistics( iree_hal_cuda_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { IREE_STATISTICS({ - statistics->device_bytes_allocated = iree_atomic_load_int64( + statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); - statistics->host_bytes_allocated = iree_atomic_load_int64( + statistics->host_bytes_allocated = iree_atomic_load( &pools->statistics.host_bytes_allocated, iree_memory_order_relaxed); - statistics->device_bytes_freed = iree_atomic_load_int64( + statistics->device_bytes_freed = iree_atomic_load( &pools->statistics.device_bytes_freed, iree_memory_order_relaxed); - statistics->host_bytes_freed = iree_atomic_load_int64( + statistics->host_bytes_freed = iree_atomic_load( &pools->statistics.host_bytes_freed, iree_memory_order_relaxed); if (pools->device_local) { cuuint64_t pool_peak = 0; diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 926eb54ce5f7..de10b09125ec 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -323,7 +323,7 @@ static iree_status_t iree_hal_hip_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_ABORTED); @@ -346,7 +346,7 @@ static iree_status_t iree_hal_hip_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_make_status(IREE_STATUS_ABORTED); } iree_slim_mutex_unlock(&semaphore->mutex); @@ -440,7 +440,7 @@ iree_status_t iree_hal_hip_semaphore_multi_wait( iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(semaphore_list.semaphores[i]); iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); status = iree_make_status(IREE_STATUS_ABORTED); break; diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index e599cf62daa0..89e27fafdfd1 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -121,8 +121,8 @@ static void iree_hal_hip_memory_pool_track_alloc( iree_atomic_int64_t* bytes_allocated = is_device_local ? 
&pools->statistics.device_bytes_allocated : &pools->statistics.host_bytes_allocated; - iree_atomic_fetch_add_int64(bytes_allocated, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_allocated, allocation_size, + iree_memory_order_relaxed); }); } @@ -141,8 +141,8 @@ static void iree_hal_hip_memory_pool_track_free( : &pools->statistics.host_bytes_freed; iree_device_size_t allocation_size = iree_hal_buffer_allocation_size(buffer); - iree_atomic_fetch_add_int64(bytes_freed, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_freed, allocation_size, + iree_memory_order_relaxed); }); } @@ -150,13 +150,13 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { IREE_STATISTICS({ - statistics->device_bytes_allocated = iree_atomic_load_int64( + statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); - statistics->host_bytes_allocated = iree_atomic_load_int64( + statistics->host_bytes_allocated = iree_atomic_load( &pools->statistics.host_bytes_allocated, iree_memory_order_relaxed); - statistics->device_bytes_freed = iree_atomic_load_int64( + statistics->device_bytes_freed = iree_atomic_load( &pools->statistics.device_bytes_freed, iree_memory_order_relaxed); - statistics->host_bytes_freed = iree_atomic_load_int64( + statistics->host_bytes_freed = iree_atomic_load( &pools->statistics.host_bytes_freed, iree_memory_order_relaxed); if (pools->device_local) { diff --git a/runtime/src/iree/hal/drivers/metal/shared_event.m b/runtime/src/iree/hal/drivers/metal/shared_event.m index f741f2ea3a63..716306c215bb 100644 --- a/runtime/src/iree/hal/drivers/metal/shared_event.m +++ b/runtime/src/iree/hal/drivers/metal/shared_event.m @@ -231,7 +231,7 @@ iree_status_t iree_hal_metal_shared_event_multi_wait( // Create an atomic to count how many semaphores have signaled. Mark it as `__block` so different // threads are sharing the same data via reference. __block iree_atomic_int32_t wait_count; - iree_atomic_store_int32(&wait_count, 0, iree_memory_order_release); + iree_atomic_store(&wait_count, 0, iree_memory_order_release); // The total count we are expecting to see. iree_host_size_t total_count = (wait_mode == IREE_HAL_WAIT_MODE_ALL) ? semaphore_list->count : 1; // Theoretically we don't really need to mark the semaphore handle as __block given that the @@ -253,7 +253,7 @@ iree_status_t iree_hal_metal_shared_event_multi_wait( // Fail as a whole if any participating semaphore failed. if (v >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) did_fail = true; - int32_t old_value = iree_atomic_fetch_add_int32( + int32_t old_value = iree_atomic_fetch_add( &wait_count, 1, iree_memory_order_release); // The last signaled semaphore send out the notification. // Atomic fetch add returns the old value, so need to +1. 
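The bulk of the runtime changes in this patch swap the width-suffixed atomic helpers (iree_atomic_load_int32, iree_atomic_fetch_add_int64, and so on) for generic forms that deduce the operand width from the iree_atomic_* variable type. A minimal sketch of the new spelling, with the example_* names being hypothetical:

#include "iree/base/internal/atomics.h"

// Returns the previous counter value, as fetch_add always does.
static int32_t example_bump(iree_atomic_int32_t* counter) {
  // Previously spelled iree_atomic_fetch_add_int32(counter, 1, ...).
  return iree_atomic_fetch_add(counter, 1, iree_memory_order_acq_rel);
}

static bool example_try_claim(iree_atomic_int32_t* flag) {
  int32_t expected = 0;
  // The weak form may fail spuriously; callers that retry in a loop (as the
  // slim mutex above does) prefer it for the cheaper uncontended path.
  return iree_atomic_compare_exchange_weak(flag, &expected, 1,
                                           iree_memory_order_acquire,
                                           iree_memory_order_relaxed);
}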
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.m b/runtime/src/iree/hal/drivers/metal/staging_buffer.m index ca0128f78890..e83e622e868b 100644 --- a/runtime/src/iree/hal/drivers/metal/staging_buffer.m +++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.m @@ -37,8 +37,7 @@ iree_status_t iree_hal_metal_staging_buffer_initialize( out_staging_buffer->host_buffer = metal_buffer.contents; iree_slim_mutex_initialize(&out_staging_buffer->offset_mutex); out_staging_buffer->offset = 0; - iree_atomic_store_int32(&out_staging_buffer->pending_command_buffers, 0, - iree_memory_order_relaxed); + iree_atomic_store(&out_staging_buffer->pending_command_buffers, 0, iree_memory_order_relaxed); IREE_TRACE_ZONE_END(z0); return iree_ok_status(); @@ -97,14 +96,13 @@ void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* stagin void iree_hal_metal_staging_buffer_increase_command_buffer_refcount( iree_hal_metal_staging_buffer_t* staging_buffer) { - iree_atomic_fetch_add_int32(&staging_buffer->pending_command_buffers, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&staging_buffer->pending_command_buffers, 1, iree_memory_order_relaxed); } void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount( iree_hal_metal_staging_buffer_t* staging_buffer) { - if (iree_atomic_fetch_sub_int32(&staging_buffer->pending_command_buffers, 1, - iree_memory_order_acq_rel) == 1) { + if (iree_atomic_fetch_sub(&staging_buffer->pending_command_buffers, 1, + iree_memory_order_acq_rel) == 1) { iree_hal_metal_staging_buffer_reset(staging_buffer); } } diff --git a/runtime/src/iree/hal/drivers/null/README.md b/runtime/src/iree/hal/drivers/null/README.md index 3c3e4200334b..3dcb62a512a9 100644 --- a/runtime/src/iree/hal/drivers/null/README.md +++ b/runtime/src/iree/hal/drivers/null/README.md @@ -17,8 +17,10 @@ fill (memset) you can often implement copy (memcpy) as well at the same time. `experimental/` folder if going in-tree. 1. Find/replace `{Null}` with the friendly name of your driver (e.g. `Vulkan`). 1. Find/replace `_null_` with the C name of your driver (e.g. `vulkan`). +1. Find/replace `_NULL_` with the upper C name of your driver (e.g. `VULKAN`). 1. Find/replace `// TODO(null):` with your github ID, your driver name, or a GitHub issue number tracking driver creation (e.g. `// TODO(#1234):`). +1. Find/replace `iree/hal/drivers/null/` with your source path. ## Build Setup diff --git a/runtime/src/iree/hal/drivers/null/allocator.c b/runtime/src/iree/hal/drivers/null/allocator.c index f84f00257ee5..e1c91ce7c35d 100644 --- a/runtime/src/iree/hal/drivers/null/allocator.c +++ b/runtime/src/iree/hal/drivers/null/allocator.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/buffer.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_allocator_t +//===----------------------------------------------------------------------===// + // TODO(null): use one ID per address space or pool - each shows as a different // track in tracing tools. 
#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING @@ -33,6 +37,7 @@ iree_status_t iree_hal_null_allocator_create( iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); + *out_allocator = NULL; iree_hal_null_allocator_t* allocator = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/allocator.h b/runtime/src/iree/hal/drivers/null/allocator.h index c0286bac6041..299c9c96c44f 100644 --- a/runtime/src/iree/hal/drivers/null/allocator.h +++ b/runtime/src/iree/hal/drivers/null/allocator.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_allocator_t +//===----------------------------------------------------------------------===// + // Creates a {Null} buffer allocator used for persistent allocations. iree_status_t iree_hal_null_allocator_create( iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); diff --git a/runtime/src/iree/hal/drivers/null/buffer.c b/runtime/src/iree/hal/drivers/null/buffer.c index f6eeecb11f20..6e676526b1b5 100644 --- a/runtime/src/iree/hal/drivers/null/buffer.c +++ b/runtime/src/iree/hal/drivers/null/buffer.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/buffer.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_buffer_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_buffer_t { iree_hal_buffer_t base; iree_hal_buffer_release_callback_t release_callback; @@ -33,8 +37,8 @@ iree_status_t iree_hal_null_buffer_wrap( iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_buffer = NULL; iree_hal_null_buffer_t* buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/buffer.h b/runtime/src/iree/hal/drivers/null/buffer.h index 7e492f4d49d7..edf2e457e5e8 100644 --- a/runtime/src/iree/hal/drivers/null/buffer.h +++ b/runtime/src/iree/hal/drivers/null/buffer.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_buffer_t +//===----------------------------------------------------------------------===// + // Wraps a {Null} allocation in an iree_hal_buffer_t. 
iree_status_t iree_hal_null_buffer_wrap( iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, diff --git a/runtime/src/iree/hal/drivers/null/channel.c b/runtime/src/iree/hal/drivers/null/channel.c index 195c3d5786cd..0d2915b066fa 100644 --- a/runtime/src/iree/hal/drivers/null/channel.c +++ b/runtime/src/iree/hal/drivers/null/channel.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/channel.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_channel_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_channel_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -34,8 +38,8 @@ iree_status_t iree_hal_null_channel_create(iree_hal_channel_params_t params, iree_allocator_t host_allocator, iree_hal_channel_t** out_channel) { IREE_ASSERT_ARGUMENT(out_channel); - *out_channel = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_channel = NULL; iree_hal_null_channel_t* channel = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/channel.h b/runtime/src/iree/hal/drivers/null/channel.h index 83c4ef1aef88..efa7c10c5e62 100644 --- a/runtime/src/iree/hal/drivers/null/channel.h +++ b/runtime/src/iree/hal/drivers/null/channel.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_channel_t +//===----------------------------------------------------------------------===// + // Creates a {Null} HAL collective channel using the given |params|. iree_status_t iree_hal_null_channel_create(iree_hal_channel_params_t params, iree_allocator_t host_allocator, diff --git a/runtime/src/iree/hal/drivers/null/command_buffer.c b/runtime/src/iree/hal/drivers/null/command_buffer.c index 9d474d44cb96..4f8fe822ccf1 100644 --- a/runtime/src/iree/hal/drivers/null/command_buffer.c +++ b/runtime/src/iree/hal/drivers/null/command_buffer.c @@ -10,6 +10,10 @@ #include "iree/hal/drivers/null/channel.h" #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_command_buffer_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_command_buffer_t { iree_hal_command_buffer_t base; iree_allocator_t host_allocator; @@ -31,8 +35,8 @@ iree_status_t iree_hal_null_command_buffer_create( iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(out_command_buffer); - *out_command_buffer = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_command_buffer = NULL; iree_hal_null_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/command_buffer.h b/runtime/src/iree/hal/drivers/null/command_buffer.h index cca92367dd82..d8ab61d89175 100644 --- a/runtime/src/iree/hal/drivers/null/command_buffer.h +++ b/runtime/src/iree/hal/drivers/null/command_buffer.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_command_buffer_t +//===----------------------------------------------------------------------===// + // Creates {Null} command buffer. 
iree_status_t iree_hal_null_command_buffer_create( iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, diff --git a/runtime/src/iree/hal/drivers/null/device.c b/runtime/src/iree/hal/drivers/null/device.c index aaa7b1591fd5..ce122400d313 100644 --- a/runtime/src/iree/hal/drivers/null/device.c +++ b/runtime/src/iree/hal/drivers/null/device.c @@ -17,6 +17,10 @@ #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_device_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_device_t { iree_hal_resource_t resource; iree_string_view_t identifier; @@ -60,8 +64,8 @@ iree_status_t iree_hal_null_device_create( iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_device); - *out_device = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_device = NULL; // Verify the parameters prior to creating resources. IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/device.h b/runtime/src/iree/hal/drivers/null/device.h index aa70db6408d6..18978668bbf7 100644 --- a/runtime/src/iree/hal/drivers/null/device.h +++ b/runtime/src/iree/hal/drivers/null/device.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_device_t +//===----------------------------------------------------------------------===// + // NOTE: nothing in the skeleton implementation. Device creation and adoption is // part of the public API header. This header can contain internal types and // functions. diff --git a/runtime/src/iree/hal/drivers/null/driver.c b/runtime/src/iree/hal/drivers/null/driver.c index 94be18a45364..78cf511a6999 100644 --- a/runtime/src/iree/hal/drivers/null/driver.c +++ b/runtime/src/iree/hal/drivers/null/driver.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_driver_t +//===----------------------------------------------------------------------===// + // TODO(null): if it's possible to have more than one device use real IDs. // This is a placeholder for this skeleton that just indicates the first and // only device. @@ -57,8 +61,8 @@ IREE_API_EXPORT iree_status_t iree_hal_null_driver_create( iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_driver = NULL; // TODO(null): verify options; this may be moved after any libraries are // loaded so the verification can use underlying implementation queries. diff --git a/runtime/src/iree/hal/drivers/null/driver.h b/runtime/src/iree/hal/drivers/null/driver.h index 84b12c1beac8..1938778056d8 100644 --- a/runtime/src/iree/hal/drivers/null/driver.h +++ b/runtime/src/iree/hal/drivers/null/driver.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_driver_t +//===----------------------------------------------------------------------===// + // NOTE: nothing in the skeleton implementation. Driver creation and adoption is // part of the public API header. This header can contain internal types and // functions. 
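Beyond the new //===...===// section banners, the skeleton's create functions are all reshuffled the same way: the *out_x = NULL initialization now sits after IREE_TRACE_ZONE_BEGIN(z0) instead of before it. A sketch of the resulting shape, with iree_hal_example_thing_* as hypothetical stand-ins for the _null_ types touched above and below:

iree_status_t iree_hal_example_thing_create(
    iree_allocator_t host_allocator, iree_hal_example_thing_t** out_thing) {
  IREE_ASSERT_ARGUMENT(out_thing);
  IREE_TRACE_ZONE_BEGIN(z0);
  *out_thing = NULL;  // Cleared inside the zone, matching the diffs here.

  iree_hal_example_thing_t* thing = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_allocator_malloc(host_allocator, sizeof(*thing), (void**)&thing));
  // ... initialize |thing| fields ...

  *out_thing = thing;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}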
diff --git a/runtime/src/iree/hal/drivers/null/event.c b/runtime/src/iree/hal/drivers/null/event.c index 5f1e413ca204..fabbe45b1311 100644 --- a/runtime/src/iree/hal/drivers/null/event.c +++ b/runtime/src/iree/hal/drivers/null/event.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/event.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_event_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_event_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -23,8 +27,8 @@ iree_status_t iree_hal_null_event_create( iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags, iree_allocator_t host_allocator, iree_hal_event_t** out_event) { IREE_ASSERT_ARGUMENT(out_event); - *out_event = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_event = NULL; iree_hal_null_event_t* event = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/event.h b/runtime/src/iree/hal/drivers/null/event.h index 68c11f44d1e6..ca7f364e458e 100644 --- a/runtime/src/iree/hal/drivers/null/event.h +++ b/runtime/src/iree/hal/drivers/null/event.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_event_t +//===----------------------------------------------------------------------===// + // WIP API and may change. Mostly ignored for now. iree_status_t iree_hal_null_event_create( iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags, diff --git a/runtime/src/iree/hal/drivers/null/executable.c b/runtime/src/iree/hal/drivers/null/executable.c index a90d697d9d8d..3301d6cd767a 100644 --- a/runtime/src/iree/hal/drivers/null/executable.c +++ b/runtime/src/iree/hal/drivers/null/executable.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_executable_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -24,8 +28,8 @@ iree_status_t iree_hal_null_executable_create( iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); - *out_executable = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_executable = NULL; // Allocate storage for the executable and its associated data structures. iree_hal_null_executable_t* executable = NULL; diff --git a/runtime/src/iree/hal/drivers/null/executable.h b/runtime/src/iree/hal/drivers/null/executable.h index 0107e1a14d4a..0ae87aefb947 100644 --- a/runtime/src/iree/hal/drivers/null/executable.h +++ b/runtime/src/iree/hal/drivers/null/executable.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_t +//===----------------------------------------------------------------------===// + // Creates a {Null} executable from a binary in memory. Each executable may // contain multiple entry points and be composed of several modules presented to // the HAL as a single instance. 
See iree_hal_executable_params_t for more diff --git a/runtime/src/iree/hal/drivers/null/executable_cache.c b/runtime/src/iree/hal/drivers/null/executable_cache.c index d4f0ad6ad066..a7c6f4b7cec1 100644 --- a/runtime/src/iree/hal/drivers/null/executable_cache.c +++ b/runtime/src/iree/hal/drivers/null/executable_cache.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_cache_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_executable_cache_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -26,8 +30,8 @@ iree_status_t iree_hal_null_executable_cache_create( iree_string_view_t identifier, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(out_executable_cache); - *out_executable_cache = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_executable_cache = NULL; iree_hal_null_executable_cache_t* executable_cache = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/executable_cache.h b/runtime/src/iree/hal/drivers/null/executable_cache.h index 519b8c05e18a..b4af9e76cc28 100644 --- a/runtime/src/iree/hal/drivers/null/executable_cache.h +++ b/runtime/src/iree/hal/drivers/null/executable_cache.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_cache_t +//===----------------------------------------------------------------------===// + // Creates a no-op executable cache that does not cache at all. // This is useful to isolate pipeline caching behavior and verify compilation // behavior. 
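For reference, the no-op executable cache declared above is created from just a string identifier and a host allocator. A hedged usage sketch follows; the wrapper function and the "default" identifier are illustrative, not part of this change.

#include "iree/hal/drivers/null/executable_cache.h"

// Sketch only: shows the call shape of the creation function declared above.
static iree_status_t acquire_null_executable_cache(
    iree_allocator_t host_allocator) {
  iree_hal_executable_cache_t* executable_cache = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_null_executable_cache_create(
      iree_make_cstring_view("default"), host_allocator, &executable_cache));
  // ... prepare executables through the cache as usual ...
  iree_hal_executable_cache_release(executable_cache);
  return iree_ok_status();
}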
diff --git a/runtime/src/iree/hal/drivers/null/semaphore.c b/runtime/src/iree/hal/drivers/null/semaphore.c index 25ec7dc99fbb..b397c85fe2c1 100644 --- a/runtime/src/iree/hal/drivers/null/semaphore.c +++ b/runtime/src/iree/hal/drivers/null/semaphore.c @@ -29,8 +29,8 @@ iree_status_t iree_hal_null_semaphore_create( uint64_t initial_value, iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(out_semaphore); - *out_semaphore = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_semaphore = NULL; iree_hal_null_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc index f75b2c0bbdb1..631f138a1c26 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc @@ -68,8 +68,7 @@ iree_status_t iree_hal_vulkan_native_semaphore_create( &semaphore->base); semaphore->logical_device = logical_device; semaphore->handle = handle; - iree_atomic_store_intptr(&semaphore->failure_status, 0, - iree_memory_order_release); + iree_atomic_store(&semaphore->failure_status, 0, iree_memory_order_release); *out_semaphore = &semaphore->base; } else { logical_device->syms()->vkDestroySemaphore(*logical_device, handle, @@ -87,7 +86,7 @@ static void iree_hal_vulkan_native_semaphore_destroy( iree_allocator_t host_allocator = semaphore->logical_device->host_allocator(); IREE_TRACE_ZONE_BEGIN(z0); - iree_status_ignore((iree_status_t)iree_atomic_load_intptr( + iree_status_ignore((iree_status_t)iree_atomic_load( &semaphore->failure_status, iree_memory_order_acquire)); semaphore->logical_device->syms()->vkDestroySemaphore( @@ -127,7 +126,7 @@ static iree_status_t iree_hal_vulkan_native_semaphore_query( // If the semaphore failed then clone the status so we can report it. if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { - iree_status_t failure_status = (iree_status_t)iree_atomic_load_intptr( + iree_status_t failure_status = (iree_status_t)iree_atomic_load( &semaphore->failure_status, iree_memory_order_acquire); if (iree_status_is_ok(failure_status)) { return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, @@ -178,7 +177,7 @@ static void iree_hal_vulkan_native_semaphore_fail( // Try to set our local status - we only preserve the first failure so only // do this if we are going from a valid semaphore to a failed one. iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( &semaphore->failure_status, (intptr_t*)&old_status, (intptr_t)status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { diff --git a/runtime/src/iree/hal/local/executable_plugin_manager.c b/runtime/src/iree/hal/local/executable_plugin_manager.c index 6d41c76df5d0..2739aa9f26c6 100644 --- a/runtime/src/iree/hal/local/executable_plugin_manager.c +++ b/runtime/src/iree/hal/local/executable_plugin_manager.c @@ -432,8 +432,8 @@ static iree_status_t iree_hal_executable_plugin_manager_register( // Get the next provider slot. Note that we don't yet increment it as we need // to put the provider in there first. 
- int32_t slot = iree_atomic_load_int32(&manager->provider_count, - iree_memory_order_acquire); + int32_t slot = + iree_atomic_load(&manager->provider_count, iree_memory_order_acquire); if (slot >= manager->capacity) { iree_slim_mutex_unlock(&manager->mutex); return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, @@ -449,8 +449,7 @@ static iree_status_t iree_hal_executable_plugin_manager_register( } // Mark the slot as valid now that the provider is in it. - iree_atomic_fetch_add_int32(&manager->provider_count, 1, - iree_memory_order_release); + iree_atomic_fetch_add(&manager->provider_count, 1, iree_memory_order_release); iree_slim_mutex_unlock(&manager->mutex); return iree_ok_status(); @@ -506,8 +505,8 @@ static iree_status_t iree_hal_executable_plugin_manager_resolve( // but that's ok: multithreaded registration/resolution is non-deterministic // by nature. Not holding the lock here means we allow multiple threads to // resolve imports at the same time. - int32_t provider_count = iree_atomic_load_int32(&manager->provider_count, - iree_memory_order_acquire); + int32_t provider_count = + iree_atomic_load(&manager->provider_count, iree_memory_order_acquire); // Scan in reverse registration order so that more recently registered // providers get queried first. try_resolve will populate any function diff --git a/runtime/src/iree/hal/semaphore.h b/runtime/src/iree/hal/semaphore.h index 8cc073bfcbf9..52571ed048fd 100644 --- a/runtime/src/iree/hal/semaphore.h +++ b/runtime/src/iree/hal/semaphore.h @@ -30,10 +30,6 @@ enum iree_hal_semaphore_flag_bits_t { }; typedef uint32_t iree_hal_semaphore_flags_t; -//===----------------------------------------------------------------------===// -// iree_hal_semaphore_t -//===----------------------------------------------------------------------===// - // The maximum valid payload value of an iree_hal_semaphore_t. // Payload values larger than this indicate that the semaphore has failed. // @@ -56,8 +52,66 @@ typedef uint32_t iree_hal_semaphore_flags_t; // https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference #define IREE_HAL_SEMAPHORE_MAX_VALUE (2147483647ull - 1) +// The minimum value for a semaphore that indicates failure. Any value +// greater-than-or-equal-to (>=) this indicates the semaphore has failed. +// +// If the upper bit 63 is set then the value represents an iree_status_t. +// Use iree_hal_semaphore_failure_as_status to convert a payload value to a +// status. Not all implementations do (or can) support encoding statuses and may +// only ever be able to set a failing semaphore to this value. #define IREE_HAL_SEMAPHORE_FAILURE_VALUE (IREE_HAL_SEMAPHORE_MAX_VALUE + 1) +// Bit indicating that a failing semaphore value can be interpreted as an +// iree_status_t. +#define IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT 0x8000000000000000ull + +// Returns a semaphore payload value that encodes the given |status|. +// Ownership of the status is transferred to the semaphore and it must be +// freed by a consumer. Not all implementations can support failure status +// payloads and this should only be used by those implementations that can. +static inline uint64_t iree_hal_status_as_semaphore_failure( + iree_status_t status) { + return IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT | + (((uint64_t)status) >> 1); +} + +// Returns OK if the |value| does not indicate an error. +// Returns an error status if the semaphore payload value represents a failure. 
+// If the payload contains an encoded iree_status_t it will be cloned and the +// new copy will be returned to the caller. +static inline iree_status_t iree_hal_semaphore_failure_as_status( + uint64_t value) { + if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) { + // The top bits of a pointer are sign-extended from bit 47 so we can + // restore the top bit by left-shifting the upper bits and then + // right-shifting with sign extension. We only use a single bit today and + // so bit 62 should still be the original value of the pointer. + // Note that if the status is just a code (no allocated pointer) this + // clone is a no-op and the code will be returned without an allocation. + // + // See: + // https://en.wikipedia.org/wiki/X86-64#Canonical_form_addresses + return iree_status_clone((iree_status_t)(((int64_t)value << 1) >> 1)); + } else { + return iree_status_from_code(IREE_STATUS_INTERNAL); + } + } else { + return iree_ok_status(); + } +} + +// Frees an iree_status_t encoded in a semaphore |value|, if any. +static inline void iree_hal_semaphore_failure_free(uint64_t value) { + if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) { + iree_status_free((iree_status_t)(((int64_t)value << 1) >> 1)); + } +} + +//===----------------------------------------------------------------------===// +// iree_hal_semaphore_t +//===----------------------------------------------------------------------===// + // Synchronization mechanism for host->device, device->host, host->host, // and device->device notification. Semaphores behave like Vulkan timeline // semaphores (or D3D12 fences) and contain a monotonically increasing diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c index 11cd2ce7b14f..9b097973178e 100644 --- a/runtime/src/iree/hal/string_util.c +++ b/runtime/src/iree/hal/string_util.c @@ -134,6 +134,18 @@ IREE_API_EXPORT iree_status_t iree_hal_parse_element_type( numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED; } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("ui"))) { numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED; + } else if (iree_string_view_equal(str_value, IREE_SV("f8E5M2"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E4M3"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E5M2FNUZ"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E4M3FNUZ"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ; + return iree_ok_status(); } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("f"))) { numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE; } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("bf"))) { @@ -164,6 +176,37 @@ IREE_API_EXPORT iree_status_t iree_hal_parse_element_type( IREE_API_EXPORT iree_status_t iree_hal_format_element_type( iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, char* buffer, iree_host_size_t* out_buffer_length) { + const char* special_name = NULL; + switch (element_type) { + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + special_name = "f8E5M2"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + special_name = "f8E4M3"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + 
special_name = "f8E5M2FNUZ"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: + special_name = "f8E4M3FNUZ"; + break; + default: + break; + } + if (special_name) { + int n = snprintf(buffer, buffer_capacity, "%s", special_name); + if (n < 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "snprintf failed"); + } + if (out_buffer_length) { + *out_buffer_length = n; + } + return n >= buffer_capacity + ? iree_status_from_code(IREE_STATUS_OUT_OF_RANGE) + : iree_ok_status(); + } + if (out_buffer_length) { *out_buffer_length = 0; } @@ -366,6 +409,38 @@ static iree_status_t iree_hal_parse_element_unsafe( return iree_string_view_atoi_uint64(data_str, (uint64_t*)out_data) ? iree_ok_status() : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e5m2(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e4m3(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e5m2fnuz(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e4m3fnuz(temp_float); + return iree_ok_status(); + } case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: { float temp = 0; if (!iree_string_view_atof(data_str, &temp)) { diff --git a/runtime/src/iree/hal/string_util_test.cc b/runtime/src/iree/hal/string_util_test.cc index 2d134fdbf9e6..8de9fe58f482 100644 --- a/runtime/src/iree/hal/string_util_test.cc +++ b/runtime/src/iree/hal/string_util_test.cc @@ -608,6 +608,14 @@ TEST(ElementTypeStringUtilTest, ParseElementType) { IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_16))); EXPECT_THAT(ParseElementType("bf16"), IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_BFLOAT_16))); + EXPECT_THAT(ParseElementType("f8E5M2"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2))); + EXPECT_THAT(ParseElementType("f8E4M3"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3))); + EXPECT_THAT(ParseElementType("f8E5M2FNUZ"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ))); + EXPECT_THAT(ParseElementType("f8E4M3FNUZ"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ))); EXPECT_THAT(ParseElementType("x64"), IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_OPAQUE_64))); EXPECT_THAT(ParseElementType("*64"), @@ -635,8 +643,18 @@ TEST(ElementTypeStringUtilTest, FormatElementType) { IsOkAndHolds(Eq("ui16"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_32), IsOkAndHolds(Eq("f32"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_16), + IsOkAndHolds(Eq("f16"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_BFLOAT_16), IsOkAndHolds(Eq("bf16"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2), + IsOkAndHolds(Eq("f8E5M2"))); + 
EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3), + IsOkAndHolds(Eq("f8E4M3"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ), + IsOkAndHolds(Eq("f8E5M2FNUZ"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ), + IsOkAndHolds(Eq("f8E4M3FNUZ"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_OPAQUE_64), IsOkAndHolds(Eq("*64"))); EXPECT_THAT(FormatElementType(iree_hal_make_element_type( diff --git a/runtime/src/iree/hal/utils/deferred_work_queue.c b/runtime/src/iree/hal/utils/deferred_work_queue.c index b4b2285c972f..e41fe3523778 100644 --- a/runtime/src/iree/hal/utils/deferred_work_queue.c +++ b/runtime/src/iree/hal/utils/deferred_work_queue.c @@ -393,9 +393,9 @@ static void iree_hal_deferred_work_queue_working_area_initialize( iree_notification_initialize(&working_area->state_notification); iree_hal_deferred_work_queue_ready_action_list_deinitialize( &working_area->ready_worklist, host_allocator); - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); } static void iree_hal_deferred_work_queue_working_area_deinitialize( @@ -413,9 +413,9 @@ static void iree_hal_deferred_work_queue_completion_area_initialize( iree_notification_initialize(&completion_area->state_notification); iree_hal_deferred_work_queue_completion_list_initialize( &completion_area->completion_list); - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); } static void iree_hal_deferred_work_queue_completion_area_deinitialize( @@ -557,17 +557,17 @@ static iree_hal_deferred_work_queue_t* iree_hal_deferred_work_queue_cast( static void iree_hal_deferred_work_queue_notify_worker_thread( iree_hal_deferred_work_queue_working_area_t* working_area) { - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, + iree_memory_order_release); iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS); } static void iree_hal_deferred_work_queue_notify_completion_thread( iree_hal_deferred_work_queue_completion_area_t* completion_area) { - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, + iree_memory_order_release); iree_notification_post(&completion_area->state_notification, IREE_ALL_WAITERS); } @@ -1236,14 +1236,14 @@ iree_status_t iree_hal_deferred_work_queue_issue( static bool iree_hal_deferred_work_queue_worker_has_incoming_request( iree_hal_deferred_work_queue_working_area_t* working_area) { - iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( - &working_area->worker_state, iree_memory_order_acquire); + iree_hal_deferred_work_queue_worker_state_t value = + iree_atomic_load(&working_area->worker_state, iree_memory_order_acquire); return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } static bool iree_hal_deferred_work_queue_completion_has_incoming_request( iree_hal_deferred_work_queue_completion_area_t* completion_area) { 
- iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( + iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load( &completion_area->worker_state, iree_memory_order_acquire); return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } @@ -1369,9 +1369,9 @@ static int iree_hal_deferred_work_queue_completion_execute( // sure that we don't accidentally ignore new workload pushed after done // ready list processing but before overwriting the state from this worker // thread. - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); iree_hal_deferred_work_queue_worker_process_completion(actions); iree_slim_mutex_lock(&actions->action_mutex); @@ -1424,9 +1424,9 @@ static int iree_hal_deferred_work_queue_worker_execute( // sure that we don't accidentally ignore new workload pushed after done // ready list processing but before overwriting the state from this worker // thread. - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); iree_hal_deferred_work_queue_worker_process_ready_list(actions); diff --git a/runtime/src/iree/hal/utils/file_transfer.c b/runtime/src/iree/hal/utils/file_transfer.c index cee1df6ebe2c..2bc8decf2f9a 100644 --- a/runtime/src/iree/hal/utils/file_transfer.c +++ b/runtime/src/iree/hal/utils/file_transfer.c @@ -242,8 +242,8 @@ static iree_status_t iree_hal_transfer_operation_create( // steps are part of this transfer. IREE_TRACE({ static iree_atomic_int32_t next_trace_id = IREE_ATOMIC_VAR_INIT(0); - operation->trace_id = iree_atomic_fetch_add_int32( - &next_trace_id, 1, iree_memory_order_seq_cst); + operation->trace_id = + iree_atomic_fetch_add(&next_trace_id, 1, iree_memory_order_seq_cst); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation->trace_id); }); diff --git a/runtime/src/iree/modules/io/parameters/module.c b/runtime/src/iree/modules/io/parameters/module.c index 655c1bae9f8f..c5dfffdd0614 100644 --- a/runtime/src/iree/modules/io/parameters/module.c +++ b/runtime/src/iree/modules/io/parameters/module.c @@ -489,6 +489,7 @@ IREE_API_EXPORT iree_status_t iree_io_parameters_module_create( .destroy = iree_io_parameters_module_destroy, .alloc_state = iree_io_parameters_module_alloc_state, .free_state = iree_io_parameters_module_free_state, + .fork_state = iree_io_parameters_module_fork_state, .notify = iree_io_parameters_module_notify, }; diff --git a/runtime/src/iree/schemas/BUILD.bazel b/runtime/src/iree/schemas/BUILD.bazel index a8fbfcab8b12..e98a425424ee 100644 --- a/runtime/src/iree/schemas/BUILD.bazel +++ b/runtime/src/iree/schemas/BUILD.bazel @@ -20,6 +20,13 @@ FLATCC_ARGS = [ "--json", ] +iree_flatbuffer_c_library( + name = "amdgpu_executable_def_c_fbs", + srcs = ["amdgpu_executable_def.fbs"], + flatcc_args = FLATCC_ARGS, + includes = ["executable_debug_info.fbs"], +) + iree_flatbuffer_c_library( name = "bytecode_module_def_c_fbs", srcs = ["bytecode_module_def.fbs"], @@ -70,6 +77,7 @@ iree_flatbuffer_c_library( iree_build_test( name = "schema_build_test", targets = [ + ":amdgpu_executable_def_c_fbs", ":bytecode_module_def_c_fbs", ":cuda_executable_def_c_fbs", ":executable_debug_info_c_fbs", diff --git a/runtime/src/iree/schemas/CMakeLists.txt 
b/runtime/src/iree/schemas/CMakeLists.txt index 574b2cac4578..f30430df0789 100644 --- a/runtime/src/iree/schemas/CMakeLists.txt +++ b/runtime/src/iree/schemas/CMakeLists.txt @@ -10,6 +10,21 @@ iree_add_all_subdirs() +flatbuffer_c_library( + NAME + amdgpu_executable_def_c_fbs + SRCS + "amdgpu_executable_def.fbs" + FLATCC_ARGS + "--reader" + "--builder" + "--verifier" + "--json" + INCLUDES + "executable_debug_info.fbs" + PUBLIC +) + flatbuffer_c_library( NAME bytecode_module_def_c_fbs diff --git a/runtime/src/iree/schemas/amdgpu_executable_def.fbs b/runtime/src/iree/schemas/amdgpu_executable_def.fbs new file mode 100644 index 000000000000..43efdb0a34dc --- /dev/null +++ b/runtime/src/iree/schemas/amdgpu_executable_def.fbs @@ -0,0 +1,63 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include "iree/schemas/executable_debug_info.fbs"; + +namespace iree.hal.amdgpu; + +// 'AMDGPU v1 Executable'. +file_identifier "AMD1"; +file_extension "amd1"; + +// A struct for the kernel block size along each dimension. +struct Dims { + x:uint32; + y:uint32; + z:uint32; +} + +// Describes the behavior of each binding. +enum BindingBits:uint64 (bit_flags) { + READ_ONLY = 0, // 1u << 0 + INDIRECT = 1, // 1u << 1 +} + +// Information about an exported function on the executable. +table ExportDef { + // String name of the exported function symbol in the module. + symbol_name:string; + + // Workgroup size for the export. + workgroup_size:Dims; + + // Total number of 32-bit push constants used by the export. + constant_count:uint32; + + // Binding count and flags for each binding. + binding_flags:[BindingBits]; + + // Optional debug information related to the export. + debug_info:iree.hal.debug.ExportDef; +} + +// A library containing one or more exported functions. +table ModuleDef { + // AMD ELF image for loading an hsa_executable_t. + image:string; +} + +table ExecutableDef { + // Exported functions in canonical executable entry point order. + exports:[ExportDef]; + + // Modules containing executable code. + modules:[ModuleDef]; + + // Embedded source files sorted ascending by path. 
+ source_files:[iree.hal.debug.SourceFileDef]; +} + +root_type ExecutableDef; diff --git a/runtime/src/iree/task/affinity_set.h b/runtime/src/iree/task/affinity_set.h index 3dbf756d7519..dfe6a7a5293e 100644 --- a/runtime/src/iree/task/affinity_set.h +++ b/runtime/src/iree/task/affinity_set.h @@ -61,25 +61,25 @@ typedef iree_atomic_int64_t iree_atomic_task_affinity_set_t; static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_load( iree_atomic_task_affinity_set_t* set, iree_memory_order_t order) { - return iree_atomic_load_int64(set, order); + return iree_atomic_load(set, order); } static inline void iree_atomic_task_affinity_set_store( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - iree_atomic_store_int64(set, value, order); + iree_atomic_store(set, value, order); } static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_fetch_and( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - return iree_atomic_fetch_and_int64(set, value, order); + return iree_atomic_fetch_and(set, value, order); } static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_fetch_or( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - return iree_atomic_fetch_or_int64(set, value, order); + return iree_atomic_fetch_or(set, value, order); } #ifdef __cplusplus diff --git a/runtime/src/iree/task/executor.c b/runtime/src/iree/task/executor.c index ff3280aaf4d2..6fc98e279e4c 100644 --- a/runtime/src/iree/task/executor.c +++ b/runtime/src/iree/task/executor.c @@ -103,10 +103,9 @@ iree_status_t iree_task_executor_create(iree_task_executor_options_t options, IREE_TRACE({ static iree_atomic_int32_t executor_id = IREE_ATOMIC_VAR_INIT(0); char trace_name[32]; - int trace_name_length = - snprintf(trace_name, sizeof(trace_name), "iree-executor-%d", - iree_atomic_fetch_add_int32(&executor_id, 1, - iree_memory_order_seq_cst)); + int trace_name_length = snprintf( + trace_name, sizeof(trace_name), "iree-executor-%d", + iree_atomic_fetch_add(&executor_id, 1, iree_memory_order_seq_cst)); IREE_LEAK_CHECK_DISABLE_PUSH(); executor->trace_name = malloc(trace_name_length + 1); memcpy((void*)executor->trace_name, trace_name, trace_name_length + 1); @@ -540,8 +539,7 @@ static iree_task_t* iree_task_executor_try_steal_task_from_affinity_set( worker_index += offset + 1; mask = iree_shr(mask, offset + 1); iree_task_worker_t* victim_worker = &executor->workers[victim_index]; - if (iree_atomic_load_int32(&victim_worker->state, - iree_memory_order_acquire) != + if (iree_atomic_load(&victim_worker->state, iree_memory_order_acquire) != IREE_TASK_WORKER_STATE_RUNNING) { return NULL; } diff --git a/runtime/src/iree/task/executor_demo.cc b/runtime/src/iree/task/executor_demo.cc index 63dba4ce0192..972d16b114a7 100644 --- a/runtime/src/iree/task/executor_demo.cc +++ b/runtime/src/iree/task/executor_demo.cc @@ -89,8 +89,8 @@ extern "C" int main(int argc, char* argv[]) { IREE_TRACE_SCOPE_NAMED("tile0"); IREE_ASSERT_EQ(0, user_context); simulate_work(tile_context); - iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&tile_context->statistics->reserved, 1, + iree_memory_order_relaxed); return iree_ok_status(); }, 0), @@ -107,8 +107,8 @@ extern "C" int main(int argc, char* argv[]) { IREE_TRACE_SCOPE_NAMED("tile1"); IREE_ASSERT_EQ(0, user_context); simulate_work(tile_context); - 
iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&tile_context->statistics->reserved, 1, + iree_memory_order_relaxed); return iree_ok_status(); }, 0), diff --git a/runtime/src/iree/task/poller.c b/runtime/src/iree/task/poller.c index e314379dc3be..e04aa3bcf162 100644 --- a/runtime/src/iree/task/poller.c +++ b/runtime/src/iree/task/poller.c @@ -32,8 +32,8 @@ iree_status_t iree_task_poller_initialize( // thread as it performs the initial resume of the wait thread. We'll need to // check in enqueue to see if the wait thread needs to be resumed. // initial_state = IREE_TASK_POLLER_STATE_SUSPENDED; - iree_atomic_store_int32(&out_poller->state, initial_state, - iree_memory_order_release); + iree_atomic_store(&out_poller->state, initial_state, + iree_memory_order_release); // Acquire an event we can use to wake the wait thread from other threads. iree_status_t status = iree_event_pool_acquire( @@ -83,7 +83,7 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { // If the thread is already in the exiting/zombie state we don't need to do // anything. iree_task_poller_state_t prev_state = - (iree_task_poller_state_t)iree_atomic_exchange_int32( + (iree_task_poller_state_t)iree_atomic_exchange( &poller->state, IREE_TASK_POLLER_STATE_EXITING, iree_memory_order_acq_rel); switch (prev_state) { @@ -93,8 +93,8 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { break; case IREE_TASK_POLLER_STATE_ZOMBIE: // Poller already exited; reset state to ZOMBIE. - iree_atomic_store_int32(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, + iree_memory_order_release); break; default: // Poller now set to EXITING and should exit soon. @@ -111,7 +111,7 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { // Returns true if the wait thread is in the zombie state (exited and awaiting // teardown). static bool iree_task_poller_is_zombie(iree_task_poller_t* poller) { - return iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + return iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_ZOMBIE; } @@ -240,8 +240,8 @@ static iree_task_poller_prepare_result_t iree_task_poller_prepare_task( // scan of tasks. wait_status_code = IREE_STATUS_OK; } else if (task->cancellation_flag != NULL && - iree_atomic_load_int32(task->cancellation_flag, - iree_memory_order_acquire) != 0) { + iree_atomic_load(task->cancellation_flag, + iree_memory_order_acquire) != 0) { // Task was cancelled by the user (or a wait-any). These retire without // failure and it's up to the user to handle what happens to them. wait_status_code = IREE_STATUS_CANCELLED; @@ -313,8 +313,8 @@ static iree_task_poller_prepare_result_t iree_task_poller_prepare_task( // If this was part of a wait-any operation then set the cancellation flag // such that other waits are cancelled. if (iree_any_bit_set(task->header.flags, IREE_TASK_FLAG_WAIT_ANY)) { - if (iree_atomic_fetch_add_int32(task->cancellation_flag, 1, - iree_memory_order_release) == 0) { + if (iree_atomic_fetch_add(task->cancellation_flag, 1, + iree_memory_order_release) == 0) { // Ensure we scan again to clean up any potentially cancelled tasks. // If this was task 4 in a wait-any list then tasks 0-3 need to be // retired. @@ -429,7 +429,7 @@ static void iree_task_poller_wake_task(iree_task_poller_t* poller, // wait handles were resolved. 
static void iree_task_poller_commit_wait(iree_task_poller_t* poller, iree_time_t deadline_ns) { - if (iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + if (iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_EXITING) { // Thread exit requested - don't block shutdown. return; @@ -486,7 +486,7 @@ static void iree_task_poller_commit_wait(iree_task_poller_t* poller, static void iree_task_poller_pump_until_exit(iree_task_poller_t* poller) { while (true) { // Check state to see if we've been asked to exit. - if (iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + if (iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_EXITING) { // Thread exit requested - cancel pumping. break; @@ -536,8 +536,8 @@ static int iree_task_poller_main(iree_task_poller_t* poller) { // to exit while suspended/still starting up, so check that here before we // mess with any data structures. const bool should_run = - iree_atomic_exchange_int32(&poller->state, IREE_TASK_POLLER_STATE_RUNNING, - iree_memory_order_acq_rel) != + iree_atomic_exchange(&poller->state, IREE_TASK_POLLER_STATE_RUNNING, + iree_memory_order_acq_rel) != IREE_TASK_POLLER_STATE_EXITING; if (IREE_LIKELY(should_run)) { // << work happens here >> @@ -545,8 +545,8 @@ static int iree_task_poller_main(iree_task_poller_t* poller) { } IREE_TRACE_ZONE_END(thread_zone); - iree_atomic_store_int32(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, + iree_memory_order_release); iree_notification_post(&poller->state_notification, IREE_ALL_WAITERS); return 0; } diff --git a/runtime/src/iree/task/scope.c b/runtime/src/iree/task/scope.c index 3ccf6ae5dfea..a777d3dc6067 100644 --- a/runtime/src/iree/task/scope.c +++ b/runtime/src/iree/task/scope.c @@ -49,12 +49,12 @@ void iree_task_scope_deinitialize(iree_task_scope_t* scope) { memset(scope->name, 0xCD, sizeof(scope->name)); // In most cases the status will have been consumed by the scope owner. 
- iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + iree_status_t status = (iree_status_t)iree_atomic_exchange( &scope->permanent_status, (intptr_t)NULL, iree_memory_order_acquire); IREE_IGNORE_ERROR(status); - while (iree_atomic_load_int32(&scope->pending_idle_notification_posts, - iree_memory_order_acquire)) { + while (iree_atomic_load(&scope->pending_idle_notification_posts, + iree_memory_order_acquire)) { iree_thread_yield(); } iree_notification_deinitialize(&scope->idle_notification); @@ -74,14 +74,14 @@ iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( } bool iree_task_scope_has_failed(iree_task_scope_t* scope) { - return iree_atomic_load_intptr(&scope->permanent_status, - iree_memory_order_acquire) != 0; + return iree_atomic_load(&scope->permanent_status, + iree_memory_order_acquire) != 0; } iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope) { iree_status_t old_status = iree_ok_status(); iree_status_t new_status = iree_ok_status(); - while (!iree_atomic_compare_exchange_strong_intptr( + while (!iree_atomic_compare_exchange_strong( &scope->permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_acquire /* old_status is actually used */)) { @@ -114,7 +114,7 @@ static void iree_task_scope_try_set_status(iree_task_scope_t* scope, } iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( &scope->permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { @@ -140,16 +140,16 @@ void iree_task_scope_begin(iree_task_scope_t* scope) { // relaxed because this 'begin' call will be paired with a 'end' call that // will perform the release-store, and this value is only read by // 'deinitialize'. - iree_atomic_store_int32(&scope->pending_idle_notification_posts, 1, - iree_memory_order_relaxed); + iree_atomic_store(&scope->pending_idle_notification_posts, 1, + iree_memory_order_relaxed); } void iree_task_scope_end(iree_task_scope_t* scope) { if (iree_atomic_ref_count_dec(&scope->pending_submissions) == 1) { // All submissions have completed in this scope - notify any waiters. iree_notification_post(&scope->idle_notification, IREE_ALL_WAITERS); - iree_atomic_store_int32(&scope->pending_idle_notification_posts, 0, - iree_memory_order_release); + iree_atomic_store(&scope->pending_idle_notification_posts, 0, + iree_memory_order_release); } } diff --git a/runtime/src/iree/task/task.c b/runtime/src/iree/task/task.c index ae4fbf99d5b3..d0e40103e814 100644 --- a/runtime/src/iree/task/task.c +++ b/runtime/src/iree/task/task.c @@ -39,13 +39,13 @@ void iree_task_set_completion_task(iree_task_t* task, iree_task_t* completion_task) { IREE_ASSERT(!task->completion_task); task->completion_task = completion_task; - iree_atomic_fetch_add_int32(&completion_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } bool iree_task_is_ready(iree_task_t* task) { - if (iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire) > 0) { + if (iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire) > 0) { // At least one dependency is still pending. 
return false; } @@ -62,7 +62,7 @@ static void iree_task_try_set_status(iree_atomic_intptr_t* permanent_status, z0, iree_status_code_string(iree_status_code(new_status))); iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { @@ -102,16 +102,15 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { // tasks in the appropriate order: if we had a DAG of A -> B, C -> D we must // discard respecting the same topological ordering. - IREE_ASSERT_EQ(0, iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire)); + IREE_ASSERT_EQ(0, iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire)); // Almost all tasks will have a completion task; some may have additional // dependent tasks (like barriers) that will be handled below. const bool completion_task_ready = task->completion_task && - iree_atomic_fetch_sub_int32( - &task->completion_task->pending_dependency_count, 1, - iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&task->completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { iree_task_list_push_back(discard_worklist, task->completion_task); } @@ -147,8 +146,8 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { static void iree_task_retire(iree_task_t* task, iree_task_submission_t* pending_submission, iree_status_t status) { - IREE_ASSERT_EQ(0, iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire)); + IREE_ASSERT_EQ(0, iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire)); // Decrement the pending count on the completion task, if any. iree_task_t* completion_task = task->completion_task; @@ -159,8 +158,8 @@ static void iree_task_retire(iree_task_t* task, iree_task_cleanup(task, IREE_STATUS_OK); bool completion_task_ready = completion_task && - iree_atomic_fetch_sub_int32(&completion_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { // This was the last pending dependency and the completion task is ready // to run. @@ -180,8 +179,8 @@ static void iree_task_retire(iree_task_t* task, bool completion_task_ready = completion_task && - iree_atomic_fetch_sub_int32(&completion_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { // This was the last pending dependency and we know that we can safely // abort the completion task by discarding. @@ -239,7 +238,7 @@ void iree_task_call_initialize(iree_task_scope_t* scope, iree_task_call_t* out_task) { iree_task_initialize(IREE_TASK_TYPE_CALL, scope, &out_task->header); out_task->closure = closure; - iree_atomic_store_intptr(&out_task->status, 0, iree_memory_order_release); + iree_atomic_store(&out_task->status, 0, iree_memory_order_release); } void iree_task_call_execute(iree_task_call_t* task, @@ -272,9 +271,9 @@ void iree_task_call_execute(iree_task_call_t* task, // Check to see if there are no pending dependencies before retiring; the // dependency count can go up if new nested tasks were enqueued. 
- if (iree_atomic_load_int32(&task->header.pending_dependency_count, - iree_memory_order_acquire) == 0) { - iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + if (iree_atomic_load(&task->header.pending_dependency_count, + iree_memory_order_acquire) == 0) { + iree_status_t status = (iree_status_t)iree_atomic_exchange( &task->status, 0, iree_memory_order_acq_rel); iree_task_retire(&task->header, pending_submission, status); } @@ -295,8 +294,8 @@ void iree_task_barrier_initialize(iree_task_scope_t* scope, out_task->dependent_tasks = dependent_tasks; for (iree_host_size_t i = 0; i < out_task->dependent_task_count; ++i) { iree_task_t* dependent_task = out_task->dependent_tasks[i]; - iree_atomic_fetch_add_int32(&dependent_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } } @@ -314,8 +313,8 @@ void iree_task_barrier_set_dependent_tasks( task->dependent_tasks = dependent_tasks; for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[i]; - iree_atomic_fetch_add_int32(&dependent_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } } @@ -329,8 +328,8 @@ static void iree_task_barrier_discard(iree_task_barrier_t* task, for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[i]; const bool dependent_task_ready = - iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (dependent_task_ready) { // The dependent task has retired and can now be discard. iree_task_list_push_back(discard_worklist, dependent_task); @@ -348,8 +347,8 @@ void iree_task_barrier_retire(iree_task_barrier_t* task, for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[task->dependent_task_count - i - 1]; - if (iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1) { + if (iree_atomic_fetch_sub(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1) { // The dependent task has retired and can now be made ready. iree_task_submission_enqueue(pending_submission, dependent_task); } @@ -530,13 +529,13 @@ static void iree_task_dispatch_initialize_base( memcpy(out_task->workgroup_size, workgroup_size, sizeof(out_task->workgroup_size)); out_task->local_memory_size = 0; - iree_atomic_store_intptr(&out_task->status, 0, iree_memory_order_release); + iree_atomic_store(&out_task->status, 0, iree_memory_order_release); memset(&out_task->statistics, 0, sizeof(out_task->statistics)); IREE_TRACE({ static iree_atomic_int64_t next_dispatch_id = IREE_ATOMIC_VAR_INIT(0); - out_task->dispatch_id = iree_atomic_fetch_add_int64( - &next_dispatch_id, 1ll, iree_memory_order_acq_rel); + out_task->dispatch_id = iree_atomic_fetch_add(&next_dispatch_id, 1ll, + iree_memory_order_acq_rel); }); } @@ -597,8 +596,7 @@ void iree_task_dispatch_issue(iree_task_dispatch_t* dispatch_task, #endif // IREE_HAL_VERBOSE_TRACING_ENABLE // Setup the iteration space for shards to pull work from the complete grid. 
- iree_atomic_store_int32(&dispatch_task->tile_index, 0, - iree_memory_order_relaxed); + iree_atomic_store(&dispatch_task->tile_index, 0, iree_memory_order_relaxed); dispatch_task->tile_count = workgroup_count[0] * workgroup_count[1] * workgroup_count[2]; @@ -672,7 +670,7 @@ void iree_task_dispatch_retire(iree_task_dispatch_t* dispatch_task, // any other has hit an error; failure in a dispatch should be so exceedingly // rare that allowing some shards to complete after one encounters an error is // not a problem. - iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + iree_status_t status = (iree_status_t)iree_atomic_exchange( &dispatch_task->status, 0, iree_memory_order_acq_rel); iree_task_retire(&dispatch_task->header, pending_submission, status); @@ -763,9 +761,9 @@ void iree_task_dispatch_shard_execute( const uint32_t tiles_per_reservation = dispatch_task->tiles_per_reservation; // relaxed order because we only care about atomic increments, not about // ordering of tile_index accesses w.r.t. other memory accesses. - uint32_t tile_base = iree_atomic_fetch_add_int32(&dispatch_task->tile_index, - tiles_per_reservation, - iree_memory_order_relaxed); + uint32_t tile_base = + iree_atomic_fetch_add(&dispatch_task->tile_index, tiles_per_reservation, + iree_memory_order_relaxed); while (tile_base < tile_count) { const uint32_t tile_range = iree_min(tile_base + tiles_per_reservation, tile_count); @@ -813,9 +811,9 @@ void iree_task_dispatch_shard_execute( } // Try to grab the next slice of tiles. - tile_base = iree_atomic_fetch_add_int32(&dispatch_task->tile_index, - tiles_per_reservation, - iree_memory_order_relaxed); + tile_base = + iree_atomic_fetch_add(&dispatch_task->tile_index, tiles_per_reservation, + iree_memory_order_relaxed); } abort_shard: diff --git a/runtime/src/iree/task/task_test_dispatch.cc b/runtime/src/iree/task/task_test_dispatch.cc index 3324b6cc464e..b18c26e790ec 100644 --- a/runtime/src/iree/task/task_test_dispatch.cc +++ b/runtime/src/iree/task/task_test_dispatch.cc @@ -35,8 +35,7 @@ class GridCoverage { bool Verify() { fflush(stdout); for (iree_host_size_t i = 0; i < workgroup_count_; ++i) { - if (iree_atomic_load_int32(&storage_[i], iree_memory_order_seq_cst) != - 1) { + if (iree_atomic_load(&storage_[i], iree_memory_order_seq_cst) != 1) { return false; } } @@ -52,8 +51,8 @@ class GridCoverage { tile_context->workgroup_count[0]) + tile_context->workgroup_xyz[1] * tile_context->workgroup_count[0] + tile_context->workgroup_xyz[0]; - iree_atomic_fetch_add_int32(&coverage->storage_[slot], 1, - iree_memory_order_seq_cst); + iree_atomic_fetch_add(&coverage->storage_[slot], 1, + iree_memory_order_seq_cst); // Useful when testing large grids: // printf("%u, %u, %u\n", tile_context->workgroup_xyz[0], diff --git a/runtime/src/iree/task/worker.c b/runtime/src/iree/task/worker.c index 5bebaa50fc09..e0e1efd82085 100644 --- a/runtime/src/iree/task/worker.c +++ b/runtime/src/iree/task/worker.c @@ -48,8 +48,8 @@ iree_status_t iree_task_worker_initialize( iree_task_queue_initialize(&out_worker->local_task_queue); iree_task_worker_state_t initial_state = IREE_TASK_WORKER_STATE_RUNNING; - iree_atomic_store_int32(&out_worker->state, initial_state, - iree_memory_order_release); + iree_atomic_store(&out_worker->state, initial_state, + iree_memory_order_release); iree_thread_create_params_t thread_params; memset(&thread_params, 0, sizeof(thread_params)); @@ -78,14 +78,14 @@ void iree_task_worker_request_exit(iree_task_worker_t* worker) { // If the thread is already in the 
exiting/zombie state we don't need to do // anything. iree_task_worker_state_t prev_state = - (iree_task_worker_state_t)iree_atomic_exchange_int32( + (iree_task_worker_state_t)iree_atomic_exchange( &worker->state, IREE_TASK_WORKER_STATE_EXITING, iree_memory_order_acq_rel); switch (prev_state) { case IREE_TASK_WORKER_STATE_ZOMBIE: // Worker already exited; reset state to ZOMBIE. - iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, + iree_memory_order_release); break; default: // Worker now set to EXITING and should exit soon. @@ -101,7 +101,7 @@ void iree_task_worker_request_exit(iree_task_worker_t* worker) { // Returns true if the worker is in the zombie state (exited and awaiting // teardown). static bool iree_task_worker_is_zombie(iree_task_worker_t* worker) { - return iree_atomic_load_int32(&worker->state, iree_memory_order_acquire) == + return iree_atomic_load(&worker->state, iree_memory_order_acquire) == IREE_TASK_WORKER_STATE_ZOMBIE; } @@ -310,7 +310,7 @@ static void iree_task_worker_pump_until_exit(iree_task_worker_t* worker) { iree_task_worker_mark_active(worker); // Check state to see if we've been asked to exit. - if (iree_atomic_load_int32(&worker->state, iree_memory_order_acquire) == + if (iree_atomic_load(&worker->state, iree_memory_order_acquire) == IREE_TASK_WORKER_STATE_EXITING) { // Thread exit requested - cancel pumping. iree_notification_cancel_wait(&worker->wake_notification); @@ -395,8 +395,8 @@ static int iree_task_worker_main(iree_task_worker_t* worker) { // to exit while suspended/still starting up, so check that here before we // mess with any data structures. const bool should_run = - iree_atomic_exchange_int32(&worker->state, IREE_TASK_WORKER_STATE_RUNNING, - iree_memory_order_acq_rel) != + iree_atomic_exchange(&worker->state, IREE_TASK_WORKER_STATE_RUNNING, + iree_memory_order_acq_rel) != IREE_TASK_WORKER_STATE_EXITING; if (IREE_LIKELY(should_run)) { // << work happens here >> @@ -407,8 +407,8 @@ static int iree_task_worker_main(iree_task_worker_t* worker) { iree_task_worker_mark_idle(worker); IREE_TRACE_ZONE_END(thread_zone); - iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, + iree_memory_order_release); iree_notification_post(&worker->state_notification, IREE_ALL_WAITERS); return 0; } diff --git a/runtime/src/iree/vm/context.c b/runtime/src/iree/vm/context.c index d55e67fb99f3..3a1fc239e999 100644 --- a/runtime/src/iree/vm/context.c +++ b/runtime/src/iree/vm/context.c @@ -51,8 +51,8 @@ static iree_vm_context_id_t iree_vm_context_allocate_id(void) { static iree_atomic_int32_t next_context_id = IREE_ATOMIC_VAR_INIT(1); // relaxed because we only care about atomic increments, not ordering w.r.t. // other memory accesses. - uint32_t context_id = iree_atomic_fetch_add_int32(&next_context_id, 1, - iree_memory_order_relaxed); + uint32_t context_id = + iree_atomic_fetch_add(&next_context_id, 1, iree_memory_order_relaxed); #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_FIBERS // This is what we pass to Tracy as the fiber name. // The string must remain live for the lifetime of the process. 
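Editor's note: most of the runtime hunks throughout this change are a mechanical rename from the type-suffixed atomics (iree_atomic_load_int32, iree_atomic_fetch_add_int64, iree_atomic_exchange_intptr, ...) to the generic iree_atomic_* forms that infer the width from the atomic variable's type. A small sketch of the generic API as used above; the counter and function names are illustrative, and the include path is an assumption about where the iree_atomic_* definitions live.

#include "iree/base/internal/atomics.h"  // assumed header path

// Illustrative monotonically increasing ID, mirroring the pattern used for
// the trace/context IDs in the hunks above.
static iree_atomic_int32_t next_id = IREE_ATOMIC_VAR_INIT(1);

static uint32_t allocate_id(void) {
  // Generic form: the _int32 suffix is gone and the macro dispatches on the
  // pointed-to atomic type. Relaxed order suffices for a pure counter.
  return iree_atomic_fetch_add(&next_id, 1, iree_memory_order_relaxed);
}

static void reset_id(void) {
  // Stores and loads follow the same generic pattern with explicit orders.
  iree_atomic_store(&next_id, 1, iree_memory_order_release);
  (void)iree_atomic_load(&next_id, iree_memory_order_acquire);
}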
diff --git a/runtime/src/iree/vm/invocation.c b/runtime/src/iree/vm/invocation.c index 2ba5bab75ab3..d3fe20ac0f12 100644 --- a/runtime/src/iree/vm/invocation.c +++ b/runtime/src/iree/vm/invocation.c @@ -226,8 +226,8 @@ static iree_vm_invocation_id_t iree_vm_invoke_allocate_id( // The string must remain live for the lifetime of the process. // TODO(benvanik): name it based on the function? static iree_atomic_int32_t next_invocation_id = IREE_ATOMIC_VAR_INIT(1); - uint32_t invocation_id = iree_atomic_fetch_add_int32( - &next_invocation_id, 1, iree_memory_order_relaxed); + uint32_t invocation_id = iree_atomic_fetch_add(&next_invocation_id, 1, + iree_memory_order_relaxed); IREE_LEAK_CHECK_DISABLE_PUSH(); char* name = (char*)malloc(32); snprintf(name, 32, "invoke-%04d", invocation_id - 1); diff --git a/runtime/src/iree/vm/ref.c b/runtime/src/iree/vm/ref.c index 3d5f2552b585..fe3313620075 100644 --- a/runtime/src/iree/vm/ref.c +++ b/runtime/src/iree/vm/ref.c @@ -12,15 +12,15 @@ // Useful debugging tool: #if 0 -static inline volatile iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( void* ptr, iree_vm_ref_type_t type); -static inline volatile iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( iree_vm_ref_t* ref); static void iree_vm_ref_trace(const char* msg, iree_vm_ref_t* ref) { if (!ref->ptr) return; - volatile iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_string_view_t name = iree_vm_ref_type_name(ref->type); fprintf(stderr, "%s %.*s 0x%p %d\n", msg, (int)name.size, name.data, ref->ptr, iree_atomic_ref_count_load(counter)); @@ -28,7 +28,7 @@ static void iree_vm_ref_trace(const char* msg, iree_vm_ref_t* ref) { static void iree_vm_ref_ptr_trace(const char* msg, void* ptr, iree_vm_ref_type_t type) { if (!ptr) return; - volatile iree_atomic_ref_count_t* counter = + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); iree_string_view_t name = iree_vm_ref_type_name(type); fprintf(stderr, "%s %.*s 0x%p %d\n", msg, (int)name.size, name.data, ptr, @@ -45,19 +45,18 @@ iree_vm_ref_type_name(iree_vm_ref_type_t type) { return iree_vm_ref_type_descriptor(type)->type_name; } -static inline volatile iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( void* ptr, iree_vm_ref_type_t type) { IREE_VM_REF_ASSERT(ptr); IREE_VM_REF_ASSERT(type_descriptor); - return (volatile iree_atomic_ref_count_t*)ptr + - (type & IREE_VM_REF_TYPE_TAG_BIT_MASK); + return (iree_atomic_ref_count_t*)ptr + (type & IREE_VM_REF_TYPE_TAG_BIT_MASK); } -static inline volatile iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( iree_vm_ref_t* ref) { IREE_VM_REF_ASSERT(ref); IREE_VM_REF_ASSERT(ref->ptr); - return (volatile iree_atomic_ref_count_t*)ref->ptr + + return (iree_atomic_ref_count_t*)ref->ptr + (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK); } @@ -65,8 +64,7 @@ IREE_API_EXPORT void iree_vm_ref_object_retain(void* ptr, iree_vm_ref_type_t type) { if (!ptr) return; IREE_VM_REF_ASSERT(type); - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_raw_counter_ptr(ptr, type); + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); iree_atomic_ref_count_inc(counter); iree_vm_ref_ptr_trace("RETAIN", ptr, type); } @@ 
-76,8 +74,7 @@ IREE_API_EXPORT void iree_vm_ref_object_release(void* ptr, if (!ptr) return; IREE_VM_REF_ASSERT(type); iree_vm_ref_ptr_trace("RELEASE", ptr, type); - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_raw_counter_ptr(ptr, type); + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); if (iree_atomic_ref_count_dec(counter) == 1) { const iree_vm_ref_type_descriptor_t* descriptor = iree_vm_ref_type_descriptor(type); @@ -130,8 +127,7 @@ IREE_API_EXPORT iree_status_t iree_vm_ref_wrap_retain(void* ptr, out_ref->ptr = ptr; out_ref->type = type; if (out_ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(out_ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(out_ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("WRAP RETAIN", out_ref); } @@ -142,8 +138,7 @@ IREE_API_EXPORT iree_status_t iree_vm_ref_wrap_retain(void* ptr, IREE_API_EXPORT void iree_vm_ref_retain_inplace(iree_vm_ref_t* ref) { IREE_VM_REF_ASSERT(ref); if (ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("RETAIN", ref); } @@ -157,8 +152,7 @@ IREE_API_EXPORT void iree_vm_ref_retain(iree_vm_ref_t* ref, IREE_VM_REF_ASSERT(out_ref); iree_vm_ref_t temp_ref = *ref; if (ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("RETAIN", ref); } @@ -217,7 +211,7 @@ IREE_API_EXPORT void iree_vm_ref_release(iree_vm_ref_t* ref) { if (ref->type == IREE_VM_REF_TYPE_NULL || ref->ptr == NULL) return; iree_vm_ref_trace("RELEASE", ref); - volatile iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); if (iree_atomic_ref_count_dec(counter) == 1) { const iree_vm_ref_type_descriptor_t* descriptor = iree_vm_ref_type_descriptor(ref->type); diff --git a/runtime/src/iree/vm/ref_test.cc b/runtime/src/iree/vm/ref_test.cc index 68eaa5eb5dc5..5260749b31aa 100644 --- a/runtime/src/iree/vm/ref_test.cc +++ b/runtime/src/iree/vm/ref_test.cc @@ -73,9 +73,9 @@ static iree_vm_ref_t MakeRef(InstancePtr& instance, const char* type_name) { // WARNING: this is an implementation detail and must never be relied on - it's // only here to test the expected behavior. 
static int32_t ReadCounter(iree_vm_ref_t* ref) { - return iree_atomic_load_int32((iree_atomic_ref_count_t*)ref->ptr + - (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK), - iree_memory_order_seq_cst); + return iree_atomic_load((iree_atomic_ref_count_t*)ref->ptr + + (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK), + iree_memory_order_seq_cst); } } // namespace diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel index 635ee0cc3213..a82bfb691047 100644 --- a/tests/e2e/matmul/BUILD.bazel +++ b/tests/e2e/matmul/BUILD.bazel @@ -360,6 +360,7 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [ generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=small", ], target_backends_and_drivers = [ @@ -367,9 +368,9 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [ ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f32", "f32"), ]] ########################################################################### @@ -383,6 +384,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulSimt", ], @@ -411,6 +413,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCore", ], @@ -437,6 +440,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", ], tags = [ # CUDA cuInit fails with sanitizer on. @@ -461,6 +465,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", ], tags = [ # CUDA cuInit fails with sanitizer on. @@ -486,6 +491,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync", ], @@ -513,6 +519,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCore", ], @@ -540,6 +547,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync", ], @@ -566,6 +574,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, ], tags = [ # CUDA cuInit fails with sanitizer on. 
@@ -580,8 +589,8 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("f32", "f32"), ]] ########################################################################### @@ -598,6 +607,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=easy_large_static", "--compilation_info=SPIRVVectorizeMali", ], @@ -611,10 +621,10 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f16", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f16", "f32"), + ("f32", "f32"), ]] [iree_generated_e2e_runner_test( @@ -625,6 +635,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=easy_large_static", "--compilation_info=SPIRVVectorizeNVIDIA", ], @@ -637,10 +648,10 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f16", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f16", "f32"), + ("f32", "f32"), ]] iree_generated_e2e_runner_test( @@ -651,6 +662,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=SPIRVCooperativeMatrixVectorize", ], diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index 36e1255c5bfd..f2294345984f 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -927,6 +927,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=small" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test @@ -948,6 +949,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=small" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test @@ -969,6 +971,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulSimt" TEST_RUNNER @@ -994,6 +997,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCore" TEST_RUNNER @@ -1021,6 +1025,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1046,6 +1051,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1071,6 +1077,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync" TEST_RUNNER @@ -1098,6 +1105,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" 
"--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCore" TEST_RUNNER @@ -1125,6 +1133,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync" TEST_RUNNER @@ -1152,6 +1161,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1177,6 +1187,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1201,6 +1212,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1225,6 +1237,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1249,6 +1262,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1273,6 +1287,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1297,6 +1312,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1321,6 +1337,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVCooperativeMatrixVectorize" TEST_RUNNER @@ -1526,7 +1543,7 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_f16_large_cdna3_mfma_data_tiled + e2e_matmul_rocm_f16_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1555,7 +1572,36 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_i8_large_cdna3_mfma_data_tiled + e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=bf16" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_i8_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1584,7 +1630,7 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_f32_large_cdna3_mfma_data_tiled + e2e_matmul_rocm_f32_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1611,6 +1657,64 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + 
GENERATOR_ARGS + "--lhs_rhs_type=f8E5M2FNUZ" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f8E4M3FNUZ" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + endif() elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11") diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 30d210dedec0..cd6f8ebea6d3 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -27,8 +27,11 @@ class MatrixElemTypeId(enum.Enum): I32 = "i32" F32 = "f32" F16 = "f16" - F8E4M3FNUZ = "f8E4M3FNUZ" BF16 = "bf16" + F8E5M2 = "f8E5M2" + F8E4M3 = "f8E4M3" + F8E5M2FNUZ = "f8E5M2FNUZ" + F8E4M3FNUZ = "f8E4M3FNUZ" # Enumerates of the collections of shapes that we can generate tests for. @@ -542,20 +545,6 @@ def int_or_DYN(s: DimSize): return s.value or "DYN" -# Gets friendlier form/type that we can use as arg types which we can cast into the target_type. -def cast_argtype_if_required(target_type: MatrixElemTypeId): - if target_type == MatrixElemTypeId.F8E4M3FNUZ: - return MatrixElemTypeId.F32 - return target_type - - -# Gets the op needed to cast/convert from the friendly form/type into the target_type. -def get_castback_from_arg_op(target_type: MatrixElemTypeId): - if target_type == MatrixElemTypeId.F8E4M3FNUZ: - return "arith.truncf" - return ValueError(f"Unhandled castback type of {target_type}") - - # Describes the fully resolved shape dimensions of all 3 input matrices, # LHS, RHS, and Accumulator, in a testcase. 
# Each value is a string, which may either represent a positive integer such as "123", @@ -656,9 +645,8 @@ def generate_function( acc_r = int_or_question_mark(shapes.acc_rows) acc_c = int_or_question_mark(shapes.acc_cols) - casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type) - lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>" - rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>" + lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>" + rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>" acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>" if transpose_rhs: @@ -677,15 +665,6 @@ def generate_function( func_definition = func_definition + compilation_info_string generate_function.compilation_index += 1 compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" - if casted_lhs_rhs_type != lhs_rhs_type: - castback_op = get_castback_from_arg_op(lhs_rhs_type) - compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>" - compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>" - compute = ( - f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n" - f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n" - f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}" - ) if shape.accumulate: signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}" import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view" @@ -815,9 +794,8 @@ def generate_call( rhs_shape = [shape.k, shape.n] transpose_rhs = 0 - casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type) - op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type) - op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type) + op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type) + op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type) if shape.accumulate: op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type) # TODO(#16168): there's a bug with in-place input->output aliasing and @@ -905,17 +883,26 @@ def parse_arguments(): parser.add_argument( "--lhs_rhs_type", type=str, - choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"], - help="Numeric type of input matrices", + choices=[ + "i32", + "i8", + "f32", + "f16", + "bf16", + "f8E5M2", + "f8E4M3", + "f8E5M2FNUZ", + "f8E4M3FNUZ", + ], + help="Numeric type of input LHS and RHS matrices", required=True, ) parser.add_argument( "--acc_type", type=str, choices=["i32", "f32", "f16", "bf16"], - help="Numeric type of input matrices", - default="", - required=False, + help="Numeric type of the accumulator and result matrices", + required=True, ) parser.add_argument( "--shapes", @@ -992,24 +979,9 @@ def write_calls_file(functions, calls, filename, requirements): file.write(module_definition) -# For now, the accumulator type can always be inferred from the input LHS/RHS -# type, so we do that. That is temporary: eventually there will be cases -# where the same input types are used with different accumulator types, e.g. -# f16 inputs with both f16 and f32 accumulator. 
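Aside (not part of the patch): with --acc_type now an explicit, required generator argument, each test spells out pairs such as (i8, i32) and (f16/bf16/f8*, f32) instead of inferring the accumulator from the input type. The motivation for widening is the usual one for narrow element types; an illustrative reference loop follows (the helper name is hypothetical, not taken from the generator or the test runner):

#include <stdint.h>

// i8 inputs accumulate into i32 so a long dot product cannot overflow;
// f16/bf16/f8 inputs accumulate into f32 for the analogous precision reason.
static int32_t dot_i8_i8_i32(const int8_t* lhs, const int8_t* rhs,
                             int32_t k_size) {
  int32_t acc = 0;
  for (int32_t k = 0; k < k_size; ++k) {
    acc += (int32_t)lhs[k] * (int32_t)rhs[k];
  }
  return acc;
}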
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId): - if acc_type != MatrixElemTypeId.NONE: - return acc_type - if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ: - return MatrixElemTypeId.F32 - if lhs_rhs_type == MatrixElemTypeId.I8: - return MatrixElemTypeId.I32 - return lhs_rhs_type - - def main(args): lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type) acc_type = MatrixElemTypeId(args.acc_type) - acc_type = infer_acc_type(lhs_rhs_type, acc_type) shapes_id = ShapesId(args.shapes) compilation_info_id = CompilationInfoId(args.compilation_info) diff --git a/tests/e2e/stablehlo_models/CMakeLists.txt b/tests/e2e/stablehlo_models/CMakeLists.txt index f12f2fa970f2..896a852e4640 100644 --- a/tests/e2e/stablehlo_models/CMakeLists.txt +++ b/tests/e2e/stablehlo_models/CMakeLists.txt @@ -42,7 +42,7 @@ iree_static_linker_test( SRC "mnist_fake_weights.mlir" STATIC_LIB_PREFIX - mnist_fake_weights_linked_llvm_cpu + mnist_fake_weights_linked ENTRY_FUNCTION "predict" FUNCTION_INPUTS @@ -57,7 +57,7 @@ iree_static_linker_test( SRC "mnist_fake_weights.mlir" STATIC_LIB_PREFIX - mnist_fake_weights_linked_llvm_cpu + mnist_fake_weights_linked ENTRY_FUNCTION "predict" FUNCTION_INPUTS diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json index 0bbd604384f8..f8ca790fe5b1 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json @@ -101,8 +101,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", "onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -127,7 +125,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample_bicubic", "onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1", "onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1", @@ -395,13 +392,6 @@ "onnx/node/generated/test_softsign_example", "onnx/node/generated/test_stft", "onnx/node/generated/test_stft_with_window", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_only_bigrams_skip0", - "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_levelempty", - "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_uniandbigrams_skip5", "onnx/node/generated/test_training_dropout", "onnx/node/generated/test_training_dropout_default", "onnx/node/generated/test_training_dropout_default_mask", @@ -427,6 +417,7 @@ "onnx/node/generated/test_constantofshape_float_ones", "onnx/node/generated/test_constantofshape_int_shape_zero", "onnx/node/generated/test_constantofshape_int_zeros", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dropout_default_mask_ratio", "onnx/node/generated/test_gridsample_nearest", 
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1", @@ -447,6 +438,8 @@ "onnx/node/generated/test_reduce_min_empty_set", "onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero", "onnx/node/generated/test_resize_downsample_scales_linear_align_corners", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape_clip_start", "onnx/node/generated/test_shape_end_1", "onnx/node/generated/test_shape_start_1", diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json index 9901e017948c..79cc5a9a4add 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json @@ -105,8 +105,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", "onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -131,7 +129,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample_bicubic", "onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1", "onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1", @@ -442,6 +439,7 @@ "onnx/node/generated/test_constantofshape_float_ones", "onnx/node/generated/test_constantofshape_int_shape_zero", "onnx/node/generated/test_constantofshape_int_zeros", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dropout_default_mask_ratio", "onnx/node/generated/test_eyelike_populate_off_main_diagonal", "onnx/node/generated/test_eyelike_with_dtype", @@ -490,6 +488,8 @@ "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded", "onnx/node/generated/test_resize_downsample_scales_linear_align_corners", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape", "onnx/node/generated/test_shape_clip_end", "onnx/node/generated/test_shape_clip_start", @@ -501,6 +501,7 @@ "onnx/node/generated/test_shape_start_negative_1", "onnx/node/generated/test_size", "onnx/node/generated/test_size_example", + "onnx/node/generated/test_slice_default_axes", "onnx/node/generated/test_split_zero_size_splits_opset13", "onnx/node/generated/test_split_zero_size_splits_opset18", "onnx/node/generated/test_top_k", diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json index e6d9a4a4e201..8c31c26a421a 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json @@ -136,8 +136,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", 
"onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -162,7 +160,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample", "onnx/node/generated/test_gridsample_aligncorners_true", "onnx/node/generated/test_gridsample_bicubic", @@ -528,6 +525,7 @@ "onnx/node/generated/test_constantofshape_int_zeros", "onnx/node/generated/test_convinteger_with_padding", "onnx/node/generated/test_convinteger_without_padding", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dequantizelinear_int16", "onnx/node/generated/test_dequantizelinear_uint16", "onnx/node/generated/test_dropout_default_mask_ratio", @@ -536,6 +534,7 @@ "onnx/node/generated/test_einsum_batch_diagonal", "onnx/node/generated/test_einsum_batch_matmul", "onnx/node/generated/test_einsum_transpose", + "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_eyelike_with_dtype", "onnx/node/generated/test_isinf_float16", "onnx/node/generated/test_isnan_float16", @@ -596,6 +595,8 @@ "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_example_expanded", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape_clip_start", "onnx/node/generated/test_shape_end_1", "onnx/node/generated/test_shape_start_1", diff --git a/third_party/hsa-runtime-headers b/third_party/hsa-runtime-headers index c4fb247e2861..ffa0dc3307be 160000 --- a/third_party/hsa-runtime-headers +++ b/third_party/hsa-runtime-headers @@ -1 +1 @@ -Subproject commit c4fb247e28616c51d37a45f2c0056ed5f4df0555 +Subproject commit ffa0dc3307be5472ccdf7c9825c3dc68340649de diff --git a/third_party/llvm-project b/third_party/llvm-project index 922992a22f7c..864902e9b4d8 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 922992a22f7c87c192cf96606038df3cf20d6404 +Subproject commit 864902e9b4d8bc6d3f0852d5c475e3dc97dd8335 diff --git a/third_party/pybind11 b/third_party/pybind11 deleted file mode 160000 index a2e59f0e7065..000000000000 --- a/third_party/pybind11 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4 diff --git a/third_party/torch-mlir b/third_party/torch-mlir index 45bb17ebfe5e..140cad5659bb 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 45bb17ebfe5e9cdcfd2cfabf850d9dec7127c5ab +Subproject commit 140cad5659bb779bb1f5de1888566db5b5d21236 diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc index 230956041cdc..ce589e20851d 100644 --- a/tools/testing/e2e/iree-e2e-matmul-test.cc +++ b/tools/testing/e2e/iree-e2e-matmul-test.cc @@ -128,6 +128,29 @@ static void reference_matmul_bf16_bf16_f32_f32( result_data[n + m * n_size] = acc; } +#define REFERENCE_MATMUL_F8(LHSTYPE, RHSTYPE) \ + static void reference_matmul_##LHSTYPE##_##RHSTYPE##_f32_f32( \ + iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \ + iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \ + iree_hal_element_type_t acc_type, bool transpose_rhs, \ + const uint8_t* lhs_data, const uint8_t* rhs_data, const float* acc_data, \ + float* result_data, iree_hal_dim_t m, 
iree_hal_dim_t n) { \ + float acc = acc_data ? acc_data[n + m * n_size] : 0; \ + for (iree_hal_dim_t k = 0; k < k_size; ++k) { \ + float lhs_float = \ + iree_math_##LHSTYPE##_to_f32(lhs_data[k + m * k_size]); \ + float rhs_float = iree_math_##RHSTYPE##_to_f32( \ + rhs_data[transpose_rhs ? k + n * k_size : n + k * n_size]); \ + acc += lhs_float * rhs_float; \ + } \ + result_data[n + m * n_size] = acc; \ + } + +REFERENCE_MATMUL_F8(f8e5m2, f8e5m2) +REFERENCE_MATMUL_F8(f8e4m3, f8e4m3) +REFERENCE_MATMUL_F8(f8e5m2fnuz, f8e5m2fnuz) +REFERENCE_MATMUL_F8(f8e4m3fnuz, f8e4m3fnuz) + // Helper for reference_matmul. // Computes one element in the result matrix. static iree_status_t reference_matmul_element( @@ -185,6 +208,34 @@ static iree_status_t reference_matmul_element( m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, (const uint16_t*)lhs_data, (const uint16_t*)rhs_data, (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e5m2_f8e5m2_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e4m3_f8e4m3_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e5m2fnuz_f8e5m2fnuz_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e4m3fnuz_f8e4m3fnuz_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); } else { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled combination of element types in matmul"); diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c index a7119dcba771..c54c7190cdb6 100644 --- a/tools/testing/e2e/test_utils.c +++ b/tools/testing/e2e/test_utils.c @@ -93,6 +93,36 @@ iree_test_utils_e2e_value_t iree_test_utils_value_make_i32(int32_t value) { return result; } +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2; + result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3; + result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ( + uint16_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ; + 
result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ( + uint16_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ; + result.f8_u8 = value; + return result; +} + iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) { iree_test_utils_e2e_value_t result; result.type = IREE_TEST_UTILS_VALUE_TYPE_F16; @@ -123,6 +153,14 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element( return iree_test_utils_value_make_i16(((int16_t*)data)[index]); } else if (iree_hal_element_type_is_integer(result_type, 32)) { return iree_test_utils_value_make_i32(((int32_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) { + return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) { + return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) { + return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) { + return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) { return iree_test_utils_value_make_f16(((uint16_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) { @@ -147,6 +185,22 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize, return snprintf(buf, bufsize, "%" PRIi32, value.i32); case IREE_TEST_UTILS_VALUE_TYPE_I64: return snprintf(buf, bufsize, "%" PRIi64, value.i64); + case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e5m2_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e4m3_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e5m2fnuz_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e4m3fnuz_to_f32(value.f8_u8)); case IREE_TEST_UTILS_VALUE_TYPE_F16: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? 
"%.5g" : "%.4g", @@ -257,6 +311,18 @@ void iree_test_utils_write_element(iree_hal_element_type_t element_type, case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: *(uint16_t*)dst = iree_math_f32_to_bf16((float)value); break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + *(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + *(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + *(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: + *(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value); + break; WRITE_ELEMENT_CASE(FLOAT_32, float) WRITE_ELEMENT_CASE(FLOAT_64, double) // clang-format on @@ -296,6 +362,10 @@ void iree_test_utils_get_min_max_for_element_type( *max = +4; break; case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: *min = -2; *max = +2; break; diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h index f095537112e9..46d99f11df13 100644 --- a/tools/testing/e2e/test_utils.h +++ b/tools/testing/e2e/test_utils.h @@ -48,6 +48,11 @@ typedef enum iree_test_utils_value_type_e { IREE_TEST_UTILS_VALUE_TYPE_F64 = 7, // bfloat16 IREE_TEST_UTILS_VALUE_TYPE_BF16 = 8, + // 8-bit float types. + IREE_TEST_UTILS_VALUE_TYPE_F8E5M2 = 9, + IREE_TEST_UTILS_VALUE_TYPE_F8E4M3 = 10, + IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ = 11, + IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ = 12, } iree_test_utils_value_type_t; // Maximum size, in bytes, of any value type we can represent. @@ -64,6 +69,7 @@ typedef struct iree_test_utils_value_t { float f32; uint16_t f16_u16; uint16_t bf16_u16; + uint8_t f8_u8; double f64; uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all // value types