From b47ed8c93d82b2030db4592f57a78afd470dbc08 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Mon, 7 Aug 2023 12:05:34 -0700 Subject: [PATCH] [CPU] Add `contract` fast-math-flag to arith operations (#14551) This patch adds the `contract` FMF to some arith operations so that they can be folded into an fma instruction. We are doing this by default as we are lowering matmul ops by default to fmas. We will add different fp modes to have more control on fp optimizations depending on the tolerance to fp errors. --- .../Codegen/Common/AddFastMathFlags.cpp | 44 +++++++++++++++++++ .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../iree/compiler/Codegen/Common/PassDetail.h | 1 + .../src/iree/compiler/Codegen/Common/Passes.h | 3 ++ .../iree/compiler/Codegen/Common/Passes.td | 7 +++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../Codegen/Common/test/add_fmfs.mlir | 13 ++++++ .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 1 + 10 files changed, 73 insertions(+) create mode 100644 compiler/src/iree/compiler/Codegen/Common/AddFastMathFlags.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/add_fmfs.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/AddFastMathFlags.cpp b/compiler/src/iree/compiler/Codegen/Common/AddFastMathFlags.cpp new file mode 100644 index 000000000000..35115db36435 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/AddFastMathFlags.cpp @@ -0,0 +1,44 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/PassDetail.h" +#include "iree/compiler/Codegen/Common/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +#define DEBUG_TYPE "iree-codegen-add-fast-math-flags" + +using namespace mlir; +using namespace mlir::iree_compiler; + +/// Add `contract` FMF to operations that support it. +static void addContractFMF(Operation *op) { + LLVM::FastmathFlags contract = LLVM::FastmathFlags::contract; + TypeSwitch(op) + .Case( + [&](auto llvmOp) { llvmOp.setFastmathFlags(contract); }); +} + +namespace { + +/// Add the corresponding fast-math flags to operations given a floating-point +/// optimization mode. +// TODO: For now we only allow default flags, such as arithmetic reassociation. +struct AddFastMathFlagsPass + : public AddFastMathFlagsBase { +public: + using AddFastMathFlagsBase::AddFastMathFlagsBase; + + void runOnOperation() override { + getOperation()->walk([](Operation *op) { addContractFMF(op); }); + } +}; + +} // namespace + +std::unique_ptr> +mlir::iree_compiler::createAddFastMathFlagsPass() { + return std::make_unique(); +} diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index bd10121be77f..75b0e01c2975 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -143,6 +143,7 @@ iree_compiler_cc_library( iree_compiler_cc_library( name = "Common", srcs = [ + "AddFastMathFlags.cpp", "BubbleUpOrdinalOps.cpp", "BufferizationAnalysis.cpp", "BufferizeCopyOnlyDispatchesPass.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index d43a2ea8df7e..6df643e9288b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -118,6 +118,7 @@ iree_cc_library( "Transforms.h" "UserConfig.h" SRCS + "AddFastMathFlags.cpp" "BubbleUpOrdinalOps.cpp" "BufferizationAnalysis.cpp" "BufferizeCopyOnlyDispatchesPass.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/PassDetail.h b/compiler/src/iree/compiler/Codegen/Common/PassDetail.h index 2f74e83019ff..f0d4b10c7793 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PassDetail.h +++ b/compiler/src/iree/compiler/Codegen/Common/PassDetail.h @@ -10,6 +10,7 @@ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Pass/Pass.h" diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index 71dd8c6c1e75..bd23f3380589 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -14,6 +14,7 @@ #include "iree/compiler/Codegen/Dialect/IREECodegenAttrs.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -37,6 +38,8 @@ void addIREEComprehensiveBufferizePasses( std::nullopt, std::optional memCpyFn = std::nullopt); +std::unique_ptr> createAddFastMathFlagsPass(); + /// Pass to bubble up ordinal operations to allow workgroup count computation /// based on slices to correlate back to workload computation. std::unique_ptr createBubbleUpOrdinalOpsPass(); diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 2cc6b777062c..dbed920ce661 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -13,6 +13,13 @@ include "mlir/Pass/PassBase.td" // Common passes for all backends (keep alphabetical) //===---------------------------------------------------------------------===// +def AddFastMathFlags + : Pass<"iree-codegen-add-fast-math-flags", "LLVM::LLVMFuncOp"> { + let summary = "Add fast math flags to all the operations supporting them, " + "given a floating-point mode."; + let constructor = "mlir::iree_compiler::createAddFastMathFlagsPass()"; +} + def BubbleUpOrdinalOps : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> { let summary = "Bubbles op ordinal ops to allow for workgroup count computation"; let constructor = "mlir::iree_compiler::createBubbleUpOrdinalOpsPass()"; diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index c2f533a996b8..ecd46603c00f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "add_fmfs.mlir", "affinemin_canonicalization.mlir", "batch_matmuls.mlir", "bubble_up_ordinal_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index a971351ed343..e1159169b363 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "add_fmfs.mlir" "affinemin_canonicalization.mlir" "batch_matmuls.mlir" "bubble_up_ordinal_ops.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/add_fmfs.mlir b/compiler/src/iree/compiler/Codegen/Common/test/add_fmfs.mlir new file mode 100644 index 000000000000..47a76cf5c9f0 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/add_fmfs.mlir @@ -0,0 +1,13 @@ +// RUN: iree-opt -iree-codegen-add-fast-math-flags --split-input-file %s | FileCheck %s + +// LABEL: llvm.func @fmfs +llvm.func @fmfs() -> f32 { + %c3 = llvm.mlir.constant(3.000000e+00 : f32) : f32 + %c6 = llvm.mlir.constant(6.000000e+00 : f32) : f32 + %mul = llvm.fmul %c3, %c3 : f32 + %add = llvm.fadd %c3, %c6 : f32 + llvm.return %add : f32 +} + +// CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index fd1d90b71674..20f0c3f13a99 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -752,6 +752,7 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) { passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); + passManager.addNestedPass(createAddFastMathFlagsPass()); } void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {