From 4d20b82812951fd971f930f43167c11e46da1c25 Mon Sep 17 00:00:00 2001
From: Ben Vanik
Date: Thu, 24 Oct 2024 08:36:20 -0700
Subject: [PATCH 01/45] Emit an error when affinity analysis fails. (#18883)

Includes the flag I'd tell anyone to use if they filed a bug.

Fixes #18878.
---
 .../compiler/Dialect/Stream/Transforms/ConvertToStream.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index 91f5c0ffff3f..92a3457dfa22 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -224,6 +224,11 @@ struct ConvertToStreamPass final
     // for all SSA values we'll use during conversion are available.
     AffinityAnalysis affinityAnalysis(getOperation());
     if (failed(affinityAnalysis.run())) {
+      getOperation().emitError()
+          << "affinity analysis failed to converge (input program may have "
+             "invalid affinities assigned); use "
+             "`--iree-stream-annotate-input-affinities` to help identify the "
+             "invalid affinities";
       return signalPassFailure();
     }

From c3fae2f7443908e91a00eb65a4f330ad1f71e63f Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Thu, 24 Oct 2024 08:36:54 -0700
Subject: [PATCH 02/45] [LLVMGPU] Use forall workgroup distribution in
 TileAndFuse pipeline (#18565)

This switches the TileAndFuse pipeline to use scf.forall distribution.
Using scf.forall distribution also requires some changes to the pass
ordering in the TileAndFuse pipeline, which is also handled by this PR:
1. The main difference is that PackToIntrinsics happens before workgroup
   distribution. Otherwise, collapse_shape ops can end up at the end of
   the workgroup forall, and an extra buffer is created.
2. Pack decomposition is now staged, with packs/unpacks at the function
   boundaries being decomposed early before workgroup distribution, and
   the rest being decomposed after reduction tiling as before. This
   prevents unpacks being fused into the workgroup forall and causing
   the same problem as in (1).
3. `ConcretizeMmaShapes` now runs before workgroup tiling as well, so
   the resulting collapse_shape on the multi_mma op result can be
   propagated to the function boundary before any tiling. This is also
   to avoid the same problem as in (1).

The lowering configs on the MMA path have also changed, since they now
need to account for inner tile sizes of packing.

depends on https://github.com/iree-org/iree/pull/18852

Signed-off-by: Max Dawkins
---
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   |   6 +-
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |  73 ++++--
 .../test/ROCDL/config_tile_and_fuse.mlir      |   4 +-
 .../test/ROCDL/pipeline_tile_and_fuse.mlir    | 225 ++++++++++--------
 ...tile_and_vectorize_to_cooperative_ops.mlir |  10 +-
 5 files changed, 183 insertions(+), 135 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 611a87454ecf..ca23b0ca6e06 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -245,10 +245,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   }
 
   // Compute the M/N dimension tile size by multiplying subgroup information.
- workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; - workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount; + workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount; // Specify the subgroup tile sizes from the mma schedule. This is applied subgroupTileSizes[mDim] = schedule->mTileCount; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index e8c3de89f80e..76b1af3204be 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -19,6 +19,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" @@ -190,18 +191,23 @@ static void addBufferizePasses(OpPassManager &funcPassManager) { } static void tileAndDistributeToWorkgroup( - OpPassManager &funcPassManager, + OpPassManager &funcPassManager, bool useForall, std::optional convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) { - funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass( - kNumMaxParallelDims, - linalg::DistributionMethod::CyclicNumProcsEqNumIters)); - funcPassManager.addPass(createCSEPass()); - - if (convertToDpsOptions) { + if (useForall) { funcPassManager.addPass( - createConvertToDestinationPassingStylePass(*convertToDpsOptions)); + createTileAndDistributeToWorkgroupsUsingForallOpPass()); + } else { + funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass( + kNumMaxParallelDims, + linalg::DistributionMethod::CyclicNumProcsEqNumIters)); + funcPassManager.addPass(createCSEPass()); + if (convertToDpsOptions) { + funcPassManager.addPass( + createConvertToDestinationPassingStylePass(*convertToDpsOptions)); + } } + // TODO(#16421): Disable decomposition due to failure in bufferization. // funcPassManager.addPass( // IREE::LinalgExt::createTileAndDecomposeAttentionPass()); @@ -212,7 +218,8 @@ static void tileAndDistributeToWorkgroup( static void tileAndBufferize(OpPassManager &funcPassManager) { ConvertToDestinationPassingStylePassOptions options; options.useWARForCooperativeMatrixCodegen = true; - tileAndDistributeToWorkgroup(funcPassManager, options); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false, + /*convertToDpsOptions=*/options); addBufferizePasses(funcPassManager); } @@ -243,7 +250,7 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager, //===---------------------------------------------------------------------===// void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -323,22 +330,45 @@ static void addGPUBufferizePasses(OpPassManager &funcPassManager) { funcPassManager.addPass(createCSEPass()); } +/// Control function for decomposing pack and unpack ops. Returns true if the +/// op is a PackOp with a DispatchTensorLoadOp producer, or an UnPackOp with +/// only DispatchTensorStoreOp consumers. 
+LogicalResult isAtBoundary(Operation *op) { + if (isa(op)) { + if (isa_and_nonnull( + op->getOperand(0).getDefiningOp())) { + return success(); + } + } else if (isa(op)) { + if (llvm::all_of(op->getUsers(), [](Operation *user) { + return isa(user); + })) { + return success(); + } + } + return failure(); +} + void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &pipelineOptions) { - tileAndDistributeToWorkgroup(funcPassManager, - /*convertToDpsOptions=*/std::nullopt); - // Step 1. Promote matmul operands and pack to intrinsic shapes. funcPassManager.addPass(createGPUPromoteMatmulOperandsPass()); funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass()); + // Decompose packs and unpacks that are at the function boundary. + funcPassManager.addPass(createDecomposeBoundaryPackUnPackOpsPass()); - // Step 1.5. Expand result shapes of MultiMmaOps before reduction tiling. + // Step 1.5. Expand result shapes of MultiMmaOps before tiling, and + // propagate reshapes to the function boundary. { IREE::GPU::ConcretizeMmaShapesPassOptions options; options.concretizeInputs = false; options.concretizeResult = true; funcPassManager.addPass(IREE::GPU::createConcretizeMmaShapesPass()); } + funcPassManager.addPass(createPropagateReshapesByExpansionPass()); + + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true, + /*convertToDpsOptions=*/std::nullopt); // Step 2. Tile and fuse tileable ops to reduction loops. { @@ -468,7 +498,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager, //===---------------------------------------------------------------------===// void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -505,7 +535,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) { void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -709,7 +739,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline( void addGPUTransposePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCanonicalizerPass()); @@ -814,7 +844,7 @@ static void addVectorBufferizePasses(OpPassManager &funcPassManager) { void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options, bool usePadToModelSharedMemcpy) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); ReorderWorkgroupsStrategy reorderStrategy = getReorderWorkgroupsStrategy(options.reorderStrategy); @@ -914,7 +944,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, } void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createRematerializeParallelOpsPass()); 
funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createGPUTileReductionPass()); @@ -958,7 +988,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) { } void addGPUPackUnPackPasses(OpPassManager &funcPassManager) { - tileAndDistributeToWorkgroup(funcPassManager); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); @@ -994,7 +1024,8 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager, const GPUPipelineOptions &options) { ConvertToDestinationPassingStylePassOptions dpsOptions; dpsOptions.useWARForCooperativeMatrixCodegen = true; - tileAndDistributeToWorkgroup(funcPassManager, dpsOptions); + tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false, + /*convertToDpsOptions=*/dpsOptions); if (options.enableUkernels) { funcPassManager.addPass(createGPULowerToUKernelsPass()); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 53952e953549..b98e85a79713 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -38,7 +38,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 0, 0, 4] // CHECK-SAME: subgroup = [0, 0, 4, 1, 0] -// CHECK-SAME: workgroup = [1, 1, 64, 64, 0] +// CHECK-SAME: workgroup = [1, 1, 4, 4, 0] // ----- @@ -63,7 +63,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 2] // CHECK-SAME: subgroup = [4, 4, 0] -// CHECK-SAME: workgroup = [128, 128, 0] +// CHECK-SAME: workgroup = [8, 8, 0] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 0dc8b0f245a5..912acf310b26 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -50,18 +50,20 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x8xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x8xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 {{.*}} -> (vector<8x4xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16> -// CHECK-DAG: vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]] -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16> -// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16> -// CHECK: %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]] -// CHECK: scf.yield %[[MM]] -// CHECK: vector.transfer_write %[[LOOP]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1280 step %c4 
{{.*}} -> (vector<8x4xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<2xf16> +// CHECK-DAG: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC:[A-Za-z0-9]+]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<2xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]], %[[RHS_ALLOC:[A-Za-z0-9]+]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_MM:.+]] = vector.transfer_read %[[LHS_ALLOC]]{{.*}} vector<8x4xf16> +// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read %[[RHS_ALLOC]]{{.*}} vector<4x4xf16> +// CHECK: %[[MM:.+]] = vector.contract {{.*}} %[[LHS_MM]], %[[RHS_MM]] +// CHECK: scf.yield %[[MM]] +// CHECK: vector.transfer_write %[[LOOP]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -71,7 +73,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -112,21 +114,23 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[RHS_RD]] -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> -// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> -// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32> -// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[LHS_RD]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]] +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> +// CHECK-DAG: vector.transfer_read {{.*}} #gpu.address_space>, vector<2x1x2x4xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x4xf16> +// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 1, 3] : vector<2x2x4x1xf32> to vector<2x4x2x1xf32> +// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -136,7 +140,7 @@ hal.executable public @main 
{ #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [1, 64, 64, 0], + workgroup = [1, 4, 4, 0], reduction = [0, 0, 0, 2], subgroup = [1, 2, 2], mma_kind = #iree_gpu.mma_layout, @@ -154,11 +158,11 @@ hal.executable private @main { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x34x34x1280xf16> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x1280x1280xf16> - %5 = tensor.empty() : tensor<2x16x16x1280xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [11520, 1280], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<11520x1280xf16> + %5 = tensor.empty() : tensor<2x256x1280xf32> %6 = tensor.empty() : tensor<2x256x11520xf16> %7 = iree_linalg_ext.im2col strides = [2, 2] dilations = [1, 1] kernel_size = [3, 3] @@ -166,15 +170,13 @@ hal.executable private @main { batch_pos = [0] m_pos = [1, 2] k_pos = [3] ins(%3 : tensor<2x34x34x1280xf16>) outs(%6 : tensor<2x256x11520xf16>) -> tensor<2x256x11520xf16> - %collapsed = tensor.collapse_shape %4 [[0, 1, 2], [3]] : tensor<3x3x1280x1280xf16> into tensor<11520x1280xf16> - %collapsed_0 = tensor.collapse_shape %5 [[0], [1, 2], [3]] : tensor<2x16x16x1280xf32> into tensor<2x256x1280xf32> - %8 = linalg.fill ins(%cst : f32) outs(%collapsed_0 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32> + %8 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x1280xf32>) -> tensor<2x256x1280xf32> %9 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} - ins(%7, %collapsed : tensor<2x256x11520xf16>, tensor<11520x1280xf16>) + ins(%7, %4 : tensor<2x256x11520xf16>, tensor<11520x1280xf16>) outs(%8 : tensor<2x256x1280xf32>) attrs = {lowering_config = #config} { ^bb0(%in: f16, %in_1: f16, %out: f32): %10 = arith.extf %in : f16 to f32 @@ -183,8 +185,7 @@ hal.executable private @main { %13 = arith.addf %12, %out : f32 linalg.yield %13 : f32 } -> tensor<2x256x1280xf32> - %expanded = tensor.expand_shape %9 [[0], [1, 2], [3]] output_shape [2, 16, 16, 1280] : tensor<2x256x1280xf32> into tensor<2x16x16x1280xf32> - flow.dispatch.tensor.store %expanded, %2, offsets = [0, 0, 0, 0], sizes = [2, 16, 16, 1280], strides = [1, 1, 1, 1] : tensor<2x16x16x1280xf32> -> !flow.dispatch.tensor> + flow.dispatch.tensor.store %9, %2, offsets = [0, 0, 0], sizes = [2, 256, 1280], strides = [1, 1, 1] : tensor<2x256x1280xf32> -> !flow.dispatch.tensor> return } } @@ -200,22 +201,24 @@ hal.executable private @main { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : 
index // CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[LHS_RD]] -// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_write %[[RHS_RD]] -// CHECK: gpu.barrier -// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16> -// CHECK-DAG: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16> -// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16> -// CHECK-DAG: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16> -// CHECK-DAG: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16> -// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32> -// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (2, 4, 20) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x2x2x4x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[LHS_RD]] +// CHECK-DAG: %[[RHS_RD:.+]] = vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_write %[[RHS_RD]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[LHS_MM0:.+]] = vector.transfer_read {{.*}} vector<2x1x2x4xf16> +// CHECK-DAG: %[[LHS_MM1:.+]] = vector.broadcast {{.*}} vector<2x1x2x4xf16> to vector<1x2x1x2x4xf16> +// CHECK-DAG: %[[RHS_MM:.+]] = vector.transfer_read {{.*}} vector<2x4x2x1xf16> +// CHECK-DAG: vector.transpose %[[LHS_MM1]], [0, 1, 3, 2, 4] : vector<1x2x1x2x4xf16> to vector<1x2x2x1x4xf16> +// CHECK-DAG: vector.transpose %[[RHS_MM]], [0, 2, 3, 1] : vector<2x4x2x1xf16> to vector<2x2x1x4xf16> +// CHECK-COUNT-4: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 3, 2, 4] : vector<1x2x2x4x1xf32> to vector<1x2x4x2x1xf32> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<2x4x2x1xf32> from vector<1x2x4x2x1xf32> +// CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -225,7 +228,7 @@ hal.executable private @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [1, 4, 16, 256, 0], + workgroup = [1, 4, 16, 16, 0], reduction = [0, 0, 0, 0, 2], subgroup = [1, 4, 1, 4, 0], mma_kind = #iree_gpu.mma_layout, @@ -287,6 +290,7 @@ hal.executable private @main { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: scf.forall ({{.*}}) in (2, 4, 1, 5) { // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x4x1x4x4x1xf32>) // CHECK: gpu.barrier // CHECK-DAG: %[[LHS_RD:.+]] = 
vector.transfer_read %[[B0]]{{.*}} vector<8xf16> @@ -303,6 +307,7 @@ hal.executable private @main { // CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 2, 4, 3, 5] : vector<1x4x1x4x4x1xf32> to vector<1x4x1x4x4x1xf32> // CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<4x1x4x4x1xf32> from vector<1x4x1x4x4x1xf32> // CHECK: vector.transfer_write %[[EXTRACT]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -312,7 +317,7 @@ hal.executable private @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -353,21 +358,23 @@ hal.executable public @main { // CHECK-DAG: %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>) -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> -// CHECK: gpu.barrier -// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> -// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> -// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> -// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32> -// CHECK: scf.yield -// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32> -// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x8x1x1xf32>) +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B0]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK-DAG: vector.transfer_read %[[B1]]{{.*}} vector<8xf16> +// CHECK: gpu.barrier +// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> +// CHECK-DAG: vector.transfer_read {{.*}} vector<2x1x2x16xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> +// CHECK-DAG: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<2x1x2x16xf16> +// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32> +// CHECK: scf.yield +// CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 2, 3, 1, 4] : vector<2x2x8x1x1xf32> to vector<2x8x1x2x1xf32> +// CHECK: vector.transfer_write %[[LOOP_T]], %[[B2]] +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -377,7 +384,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -419,9 +426,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x4 // CHECK-DAG: memref.alloc() : memref<64x10xf32, 
#gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x10xf32, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c320 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -431,7 +440,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -473,9 +482,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_16x16x32_f8 // CHECK-DAG: memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x72xf8E4M3FNUZ, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) -// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c40 step %c2 {{.*}} -> (vector<2x2x4x1xf32>) +// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -485,7 +496,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [2, 2, 0], reduction = [0, 0, 2], subgroup = [1, 1], mma_kind = #iree_gpu.mma_layout, @@ -527,9 +538,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_mfma_32x32x16_i8 // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>) -// CHECK-COUNT-2: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32 -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>) +// CHECK-COUNT-2: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32 +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -539,7 +552,7 @@ hal.executable public @main { #hal.pipeline.binding ]> #config = #iree_gpu.lowering_config<{ - workgroup = [64, 64, 0], + workgroup = [4, 4, 0], reduction = [0, 0, 2], subgroup = [2, 2], mma_kind = #iree_gpu.mma_layout, @@ -581,9 +594,11 @@ hal.executable public @main { // CHECK-LABEL: func @matmul_transpose_b_wmma_f16_16x16x16_f16 // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> // CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>) -// CHECK-COUNT-8: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16> -// CHECK: scf.yield +// CHECK: scf.forall ({{.*}}) in (32, 160) { +// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x16x1x1xf16>) +// CHECK-COUNT-8: 
amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16> +// CHECK: scf.yield +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -639,12 +654,14 @@ hal.executable public @main { // the producer's (convolution's) distributed scf.forall loop. // CHECK-LABEL: func @conv_nchw_fused // CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<1x1x1x1xf32, #gpu.address_space> -// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 -// CHECK: linalg.conv_2d_nchw_fchw -// CHECK-SAME: outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space>) -// CHECK: arith.addf -// CHECK: arith.cmpf -// CHECK: arith.select +// CHECK: scf.forall ({{.*}}) in (64, 14, 7) { +// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c1 +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space>) +// CHECK: arith.addf +// CHECK: arith.cmpf +// CHECK: arith.select +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -715,11 +732,13 @@ hal.executable public @main { // CHECK: %[[LINID0:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]], %[[IDZ]]] // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINID0:.+]] into (%c4, %c8) : index, index // CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#0, %[[IDS]]#1] -// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) -// CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 -// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32> -// CHECK: vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space> -// CHECK: vector.contract +// CHECK: scf.forall ({{.*}}) in (32, 98) { +// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) +// CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 +// CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} : memref<128x256xf32, {{.*}}storage_buffer>>, vector<4xf32> +// CHECK: vector.transfer_write %[[READ]], %{{.*}} : vector<4xf32>, memref<4x6xf32, #gpu.address_space> +// CHECK: vector.contract +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- @@ -736,7 +755,7 @@ hal.executable public @main { mma_kind = #iree_gpu.mma_layout, reduction = [0, 0, 4], subgroup = [2, 4, 0], - workgroup = [64, 128, 0], + workgroup = [4, 8, 0], promote_operands = [0, 1] }> @@ -1012,7 +1031,6 @@ hal.executable public @main { // CHECK-DAG: %[[RHS_ALLOC:.+]] = memref.alloc() : memref<4x130xf32, #gpu.address_space> // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c1000 step %c4 {{.*}} -> (vector<1x4xf32>) // CHECK: gpu.barrier - // CHECK: scf.for %{{.*}} = %{{.*}} to %c1 step %c32 // CHECK: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<4xf32> // CHECK-NEXT: vector.transfer_write %[[LHS_RD]], %[[LHS_ALLOC]] @@ -1069,6 +1087,7 @@ hal.executable public @main { // Verify that the write does not get hoisted out of the single threaded // for loop. 
-// CHECK: vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type> -// CHECK-NEXT: } +// CHECK: vector.transfer_write %{{.*}}, %[[B2]]{{.*}} memref<10x1xf32, #hal.descriptor_type> +// CHECK-NEXT: } +// CHECK-NEXT: } {mapping = [#iree_codegen.workgroup_mapping]} // CHECK-NEXT: return diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir index 5cc0b7054198..d57d1631bd77 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir @@ -248,11 +248,11 @@ func.func @matmul_256x1024x128_div_add() attributes {translation_info = #transla // CHECK: %[[LHS_VIEW:.+]] = memref.subview %[[LHS_ALLOC]][%[[IV_Z]], %[[IV_Y]], 0] [1, 16, 32] // CHECK: scf.for %[[IV_X:.+]] = %[[OFFSET_X]] to %[[C32]] step %[[C32]] { // CHECK: %[[RHS_VIEW:.+]] = memref.subview %[[RHS_ALLOC]][%[[IV_Z]], 0, %[[IV_X]]] [1, 32, 16] -// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]] -// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] -// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]] -// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[LHS_VIEW]][%[[C0]], %[[C0]], %[[C16]]] +// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[READ3:.+]] = vector.transfer_read %[[RHS_VIEW]][%[[C0]], %[[C16]], %[[C0]]] +// CHECK-DAG: %[[READ4:.+]] = vector.transfer_read %{{.+}}[%[[C0]], %[[C0]], %[[C0]]] // CHECK: %[[CT0:.+]] = vector.contract // CHECK-SAME: %[[READ0]], %[[READ2]], %[[READ4]] : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> // CHECK: %[[CT1:.+]] = vector.contract From abe3f893bc6027fe38dc89093666ac6776f070e1 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Thu, 24 Oct 2024 09:39:13 -0700 Subject: [PATCH 03/45] Add conversions for 1x1 conv_2d to matmul (#18736) Convert1X1FilterConv2DToMatmul: handle dynamic cases and conversion to `linalg.generic` representing a broadcasted batch matmul. This pass is kept because it is used in plugins. GeneralizeLinalgNamedOps: generalize conv ops to `linalg.generic` ops when possible. Converting more ops to linalg.generic ops allows for better reshape propagation and fusion opportunities. Also, removed Convert1X1FilterConv2DToMatmulPass from global optimization because generalize named ops would have already generalized any convolutions that were possible to convert. 
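As a rough illustration only (this snippet is not part of the change; the
tensor shapes are borrowed from the existing conv1x1 test and the indexing
maps mirror the updated CHECK lines below), a generalized 1x1 NHWC conv ends
up looking like this, with the unit kernel dims d4/d5 folded to constant 0 in
the input map:

```mlir
// Sketch only: maps match the updated test expectations; d4/d5 are the unit
// kernel dims and d6 is the reduced input channel.
#inp = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
#fil = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
#out = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
%1 = linalg.generic {
       indexing_maps = [#inp, #fil, #out],
       iterator_types = ["parallel", "parallel", "parallel", "parallel",
                         "reduction", "reduction", "reduction"]}
       ins(%input, %filter : tensor<1x4x5x2xf32>, tensor<1x1x2x7xf32>)
       outs(%0 : tensor<1x4x5x7xf32>) {
     ^bb0(%in: f32, %f: f32, %acc: f32):
       // Multiply-accumulate over the collapsed 1x1 kernel / channel dims.
       %m = arith.mulf %in, %f : f32
       %a = arith.addf %acc, %m : f32
       linalg.yield %a : f32
     } -> tensor<1x4x5x7xf32>
```

Folding the kernel dims into the indexing map (rather than collapsing shapes)
is what lets the dynamic spatial cases go through and gives the reshape
propagation and fusion opportunities mentioned above.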
Signed-off-by: Ian Wood --- .github/workflows/pkgci_regression_test.yml | 8 +- .../Convert1X1FilterConv2DToMatmul.cpp | 148 ++++-------------- .../GeneralizeLinalgNamedOps.cpp | 31 +++- .../compiler/GlobalOptimization/Passes.cpp | 3 +- .../test/conv1x1_to_matmul.mlir | 103 ++++-------- .../test/generalize_named_ops.mlir | 34 +++- 6 files changed, 134 insertions(+), 193 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 0748ec51859b..fb94905c1b29 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -220,9 +220,9 @@ jobs: --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 337.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1527 \ --goldendispatch-rocm-clip 1139 \ - --goldendispatch-rocm-vae 248 \ + --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2280000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ @@ -241,9 +241,9 @@ jobs: --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 80.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1527 \ --goldendispatch-rocm-clip 1139 \ - --goldendispatch-rocm-vae 248 \ + --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2270000 \ --goldensize-rocm-clip-bytes 860000 \ --goldensize-rocm-vae-bytes 840000 \ diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp index a8b4becfff2b..7128dbdfc03b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp @@ -6,7 +6,8 @@ #include "iree/compiler/GlobalOptimization/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -26,134 +27,51 @@ class Convert1x1FilterConvToMatmul : public OpRewritePattern { LogicalResult matchAndRewrite(Conv2DOpType convOp, PatternRewriter &rewriter) const override { - auto inputShapeType = llvm::dyn_cast( - convOp.getDpsInputOperand(0)->get().getType()); auto filterShapeType = llvm::dyn_cast( convOp.getDpsInputOperand(1)->get().getType()); - auto outputShapeType = llvm::dyn_cast( - convOp.getDpsInitOperand(0)->get().getType()); - - const bool isNCHW = isa(convOp); - const bool isNHWC = isa(convOp); - if (!isNCHW & !isNHWC) + if (!filterShapeType) return failure(); - if (!inputShapeType || !filterShapeType || !outputShapeType) - return failure(); + constexpr bool isNCHW = + std::is_same_v; + constexpr bool isNHWC = + std::is_same_v; + static_assert(isNCHW || isNHWC); - auto inputShape = inputShapeType.getShape(); auto filterShape = filterShapeType.getShape(); - auto outputShape = outputShapeType.getShape(); + + constexpr int64_t numLoops = 7; // Adjusting dimension indices based on Conv2DOpType. - const int nIndex = 0; - const int kcIndex = isNHWC ? 2 : 1; - const int kfIndex = isNHWC ? 3 : 0; - const int khIndex = isNHWC ? 0 : 2; - const int kwIndex = isNHWC ? 1 : 3; - const int ohIndex = isNHWC ? 1 : 2; - const int owIndex = isNHWC ? 2 : 3; - const int ocIndex = isNHWC ? 
3 : 1; - - bool isInputHWDynamic = ShapedType::isDynamic(inputShape[ohIndex]) && - ShapedType::isDynamic(inputShape[owIndex]); - - // We cannot merge the width and height if they are both dynamic as we - // cannot expand them back to their dynamic values. - if (isInputHWDynamic) - return failure(); + constexpr int khIndex = isNHWC ? 0 : 2; + constexpr int kwIndex = isNHWC ? 1 : 3; + constexpr int khLoopIndex = isNHWC ? 4 : 5; + constexpr int kwLoopIndex = isNHWC ? 5 : 6; if (filterShape[khIndex] != 1 || filterShape[kwIndex] != 1) return failure(); - // TODO(ataei): Support conversion to linalg.batch_matmul. - if (inputShape[0] != 1) - return failure(); - - if (!llvm::all_of(convOp.getStrides(), [](APInt element) { - return element.getSExtValue() == 1; - })) - return failure(); - if (!llvm::all_of(convOp.getDilations(), [](APInt element) { - return element.getSExtValue() == 1; - })) - return failure(); - - auto combineDims = [](int64_t a, int64_t b) { - if (ShapedType::isDynamic(a) || ShapedType::isDynamic(b)) - return ShapedType::kDynamic; - return a * b; - }; - - SmallVector reassociationInputOutputIndices; - SmallVector reassociationFilterIndices; - SmallVector reshapedInputShape(2, 0); - SmallVector reshapedFilterShape(2, 0); - SmallVector reshapedOutputShape(2, 0); - if (isNHWC) { - // Generate reassociation indices. - reassociationInputOutputIndices = {{nIndex, ohIndex, owIndex}, {ocIndex}}; - reassociationFilterIndices = {{khIndex, kwIndex, kcIndex}, {kfIndex}}; - - // Generate matmul shapes from 1x1 conv. - reshapedInputShape = { - combineDims(inputShape[ohIndex], inputShape[owIndex]), - inputShape[ocIndex]}; - reshapedFilterShape = {filterShape[kcIndex], filterShape[kfIndex]}; - reshapedOutputShape = { - combineDims(outputShape[ohIndex], outputShape[owIndex]), - outputShape[ocIndex]}; - } else if (isNCHW) { - // Generate reassociation indices. - reassociationInputOutputIndices = {{nIndex, ocIndex}, {ohIndex, owIndex}}; - reassociationFilterIndices = {{kfIndex}, {kcIndex, khIndex, kwIndex}}; - - // Generate matmul shapes from 1x1 conv. 
- reshapedInputShape = { - inputShape[ocIndex], - combineDims(inputShape[ohIndex], inputShape[owIndex])}; - reshapedFilterShape = {filterShape[kfIndex], filterShape[kcIndex]}; - reshapedOutputShape = { - outputShape[ocIndex], - combineDims(outputShape[ohIndex], outputShape[owIndex])}; + SmallVector dimReplacements; + for (int i = 0; i < numLoops; i++) { + if (llvm::is_contained({khLoopIndex, kwLoopIndex}, i)) { + dimReplacements.push_back( + getAffineConstantExpr(0, rewriter.getContext())); + } else { + dimReplacements.push_back(getAffineDimExpr(i, rewriter.getContext())); + } } - auto reshapedInputType = RankedTensorType::get( - reshapedInputShape, inputShapeType.getElementType()); - - auto reshapedFilterType = RankedTensorType::get( - reshapedFilterShape, filterShapeType.getElementType()); - - auto reshapedOutputType = RankedTensorType::get( - reshapedOutputShape, outputShapeType.getElementType()); - - Value input = convOp.getDpsInputOperand(0)->get(); - Value filter = convOp.getDpsInputOperand(1)->get(); - Value output = convOp.getDpsInitOperand(0)->get(); - auto loc = convOp.getLoc(); - - Value reshapedInput = rewriter.create( - loc, reshapedInputType, input, reassociationInputOutputIndices); - Value reshapedFilter = rewriter.create( - loc, reshapedFilterType, filter, reassociationFilterIndices); - Value reshapedOutput = rewriter.create( - loc, reshapedOutputType, output, reassociationInputOutputIndices); - - SmallVector matmulInput; - if (isNHWC) { - matmulInput = {reshapedInput, reshapedFilter}; - } else if (isNCHW) { - matmulInput = {reshapedFilter, reshapedInput}; - } - auto matmulResult = rewriter.create( - loc, reshapedOutputType, matmulInput, ArrayRef{reshapedOutput}); - - auto reshapedResult = rewriter.create( - loc, outputShapeType, matmulResult.getResults()[0], - reassociationInputOutputIndices); - - rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); - + SmallVector newMaps = convOp.getIndexingMapsArray(); + AffineMap inputMap = newMaps[0]; + SmallVector newExprs = + llvm::map_to_vector(inputMap.getResults(), [&](AffineExpr resultExpr) { + return resultExpr.replaceDims(dimReplacements); + }); + newMaps[0] = AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(), + newExprs, rewriter.getContext()); + + auto genericOp = linalg::generalizeNamedOp(rewriter, convOp).value(); + genericOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(newMaps)); return success(); } }; diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp index 92293bc156ba..99f6268a47b0 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp @@ -30,6 +30,34 @@ struct GeneralizeLinalgNamedOpsPass }; } // namespace +/// Returns true of `linalgOp` is a Conv2DNchwFchwOp or Conv2DNhwcHwcfOp with +/// all strides equal to 1 and with a kernel height and width of 1 +static bool isConvFoldableToContraction(linalg::LinalgOp linalgOp) { + auto NCHWOp = dyn_cast(linalgOp.getOperation()); + auto NHWCOp = dyn_cast(linalgOp.getOperation()); + + if (!NCHWOp && !NHWCOp) + return false; + + DenseIntElementsAttr strides = + NCHWOp ? 
NCHWOp.getStrides() : NHWCOp.getStrides(); + if (!llvm::all_of( + strides, [](APInt element) { return element.getSExtValue() == 1; })) { + return false; + } + + auto filterShapeType = llvm::dyn_cast( + linalgOp.getDpsInputOperand(1)->get().getType()); + if (!filterShapeType) + return false; + + // Adjusting dimension indices based on Conv2DOpType. + const int khIndex = NHWCOp ? 0 : 2; + const int kwIndex = NHWCOp ? 1 : 3; + auto filterShape = filterShapeType.getShape(); + return filterShape[khIndex] == 1 && filterShape[kwIndex] == 1; +} + void GeneralizeLinalgNamedOpsPass::runOnOperation() { auto funcOp = getOperation(); SmallVector namedOpCandidates; @@ -44,7 +72,8 @@ void GeneralizeLinalgNamedOpsPass::runOnOperation() { linalg::LogOp, linalg::MapOp, linalg::MaxOp, linalg::MulOp, linalg::NegFOp, linalg::ReduceOp, linalg::SubOp, linalg::TransposeOp>( - linalgOp.getOperation())) { + linalgOp.getOperation()) || + isConvFoldableToContraction(linalgOp)) { namedOpCandidates.push_back(linalgOp); } }); diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index 4f9a33e22a2f..bd61d4b6ce76 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -101,8 +101,7 @@ void buildGlobalOptimizationPassPipeline( .addPass(IREE::Flow::createCanonicalizerPass) .addPass(createRemoveZeroExtentTensorsPass) .addPass(createDetachElementwiseFromNamedOpsPass) - .addPass(mlir::createLinalgNamedOpConversionPass) - .addPass(createConvert1X1FilterConv2DToMatmulPass); + .addPass(mlir::createLinalgNamedOpConversionPass); mainPassManager.addPass(createEraseUnusedLinalgOperandsPass()); // Expand tensor shapes into SSA values and optimize the whole program. 
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir index 980db9329b4e..607f137b87b0 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/conv1x1_to_matmul.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --split-input-file -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s +// RUN: iree-opt --split-input-file --mlir-print-local-scope -iree-global-opt-convert-1x1-filter-conv2d-to-matmul %s | FileCheck %s util.func public @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x5x7xf32> { %0 = tensor.empty() : tensor<1x4x5x7xf32> @@ -9,20 +9,15 @@ util.func public @nhwc_conv_2d(%input: tensor<1x4x5x2xf32>, %filter: tensor<1x1x util.return %1 : tensor<1x4x5x7xf32> } -// CHECK: @nhwc_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32> -// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32> -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x4x5x7xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] output_shape [1, 4, 5, 7] : tensor<20x7xf32> into tensor<1x4x5x7xf32> -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- -// CHECK: @dynamic_nhwc_conv_2d util.func public @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> { %c2 = arith.constant 2 : index %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32> @@ -34,34 +29,12 @@ util.func public @dynamic_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: ten util.return %1 : tensor<1x4x?x7xf32> } -// CHECK: %[[INPUT:.+]]: tensor<1x4x?x2xf32> -// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32> -// CHECK: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] -// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D2]]) : tensor<1x4x?x7xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x2xf32> into tensor -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x?x7xf32> into tensor -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] - -// ----- - -util.func public 
@fail_dynamic_nhwc_conv_2d(%input: tensor<1x?x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x?x?x7xf32> { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %d1 = tensor.dim %input, %c1 : tensor<1x?x?x2xf32> - %d2 = tensor.dim %input, %c2 : tensor<1x?x?x2xf32> - %0 = tensor.empty(%d1, %d2) : tensor<1x?x?x7xf32> - %1 = linalg.conv_2d_nhwc_hwcf { - dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> - } ins(%input, %filter : tensor<1x?x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x?x?x7xf32>) -> tensor<1x?x?x7xf32> - util.return %1 : tensor<1x?x?x7xf32> -} - -// CHECK: @fail_dynamic_nhwc_conv_2d -// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-LABEL: @dynamic_nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- @@ -73,16 +46,12 @@ util.func public @nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32> util.return %1 : tensor<1x7x4x5xf32> } -// CHECK: @nchw_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x2x4x5xf32> -// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32> -// CHECK: %[[OUTPUT:.+]] = tensor.empty() : tensor<1x7x4x5xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x5xf32> into tensor<2x20xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x5xf32> into tensor<7x20xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], %[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x20xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x20xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] output_shape [1, 7, 4, 5] : tensor<7x20xf32> into tensor<1x7x4x5xf32> -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- @@ -97,33 +66,27 @@ util.func public @dynamic_nchw_conv_2d(%input: tensor<1x2x4x?xf32>, %filter: ten util.return %1 : tensor<1x7x4x?xf32> } -// CHECK: @dynamic_nchw_conv_2d -// CHECK: %[[INPUT:.+]]: tensor<1x2x4x?xf32> -// CHECK: %[[FILTER:.+]]: tensor<7x2x1x1xf32> -// CHECK: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[OUTPUT:.+]] = tensor.empty(%[[D3]]) : tensor<1x7x4x?xf32> -// CHECK: %[[RESHAPED_INPUT:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x2x4x?xf32> into tensor<2x?xf32> -// CHECK: %[[RESHAPED_FILTER:.+]] = tensor.collapse_shape %[[FILTER]] {{\[}}[0], [1, 2, 3]] : tensor<7x2x1x1xf32> into tensor<7x2xf32> -// CHECK: %[[RESHAPED_OUTPUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2, 3]] : tensor<1x7x4x?xf32> into tensor<7x?xf32> -// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_FILTER]], 
%[[RESHAPED_INPUT]] : tensor<7x2xf32>, tensor<2x?xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<7x?xf32>) -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1], [2, 3]] -// CHECK: util.return %[[RESULT]] +// CHECK-LABEL: @dynamic_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2, d3)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] // ----- -util.func public @fail_dynamic_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x?x?xf32> { - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %d2 = tensor.dim %input, %c2 : tensor<1x2x?x?xf32> - %d3 = tensor.dim %input, %c3 : tensor<1x2x?x?xf32> +util.func public @strided_nchw_conv_2d(%input: tensor<1x2x?x?xf32>, %filter: tensor<7x2x1x1xf32>, %d2 : index, %d3 : index) -> tensor<1x7x?x?xf32> { %0 = tensor.empty(%d2, %d3) : tensor<1x7x?x?xf32> %1 = linalg.conv_2d_nchw_fchw { dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + strides = dense<2> : tensor<2xi64> } ins(%input, %filter : tensor<1x2x?x?xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x?x?xf32>) -> tensor<1x7x?x?xf32> util.return %1 : tensor<1x7x?x?xf32> } -// CHECK: @fail_dynamic_nchw_conv_2d -// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-LABEL: @strided_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 * 2, d3 * 2)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)> +// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +// CHECK: util.return %[[RESULT]] diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir index 5111152b7b0d..f3f0f8a0eb9b 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/generalize_named_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-global-opt-generalize-linalg-named-ops))" --split-input-file %s | FileCheck %s util.func public @generalize_op(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -34,3 +34,35 @@ util.func public @no_generalize_op_within_dispatch(%arg0 : tensor, %arg // CHECK: %[[ADD:.+]] = linalg.add // CHECK: flow.return %[[ADD]] // CHECK: util.return %[[DISPATCH]] + +// ----- + +util.func public @generalize_1x1_nhwc_conv_2d(%input: tensor<1x4x?x2xf32>, %filter: tensor<1x1x2x7xf32>) -> tensor<1x4x?x7xf32> { + %c2 = arith.constant 2 : index + %d2 = tensor.dim %input, %c2 : tensor<1x4x?x2xf32> + %0 = tensor.empty(%d2) : tensor<1x4x?x7xf32> + %1 = linalg.conv_2d_nhwc_hwcf { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter : tensor<1x4x?x2xf32>, tensor<1x1x2x7xf32>) outs(%0 : tensor<1x4x?x7xf32>) -> tensor<1x4x?x7xf32> + util.return %1 : tensor<1x4x?x7xf32> +} + +// CHECK-LABEL: @generalize_1x1_nhwc_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK: util.return %[[RESULT]] + +// ----- + +util.func public 
@generalize_1x1_nchw_conv_2d(%input: tensor<1x2x4x5xf32>, %filter: tensor<7x2x1x1xf32>) -> tensor<1x7x4x5xf32> { + %0 = tensor.empty() : tensor<1x7x4x5xf32> + %1 = linalg.conv_2d_nchw_fchw { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter : tensor<1x2x4x5xf32>, tensor<7x2x1x1xf32>) outs(%0 : tensor<1x7x4x5xf32>) -> tensor<1x7x4x5xf32> + util.return %1 : tensor<1x7x4x5xf32> +} + +// CHECK-LABEL: @generalize_1x1_nchw_conv_2d +// CHECK: %[[RESULT:.*]] = linalg.generic +// CHECK: util.return %[[RESULT]] From a762328d6516c72f74906409b64e6f919afc3d57 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 24 Oct 2024 14:07:15 -0400 Subject: [PATCH 04/45] Support 8-bit floats in the runtime (#18885) There are 2 commits in this PR: 1. `math.h` conversions for 8-bit float types. This was previously sent for review as https://github.com/iree-org/iree/pull/17989 but is now folded into this PR. 2. General runtime/HAL support. --------- Signed-off-by: Benoit Jacob --- runtime/src/iree/base/internal/math.h | 173 +++++++++++------ runtime/src/iree/base/internal/math_test.cc | 196 ++++++++++++++++++++ runtime/src/iree/hal/buffer_view.h | 12 ++ runtime/src/iree/hal/string_util.c | 75 ++++++++ runtime/src/iree/hal/string_util_test.cc | 18 ++ 5 files changed, 419 insertions(+), 55 deletions(-) diff --git a/runtime/src/iree/base/internal/math.h b/runtime/src/iree/base/internal/math.h index 58dd88d13ea5..1e71e0d4553b 100644 --- a/runtime/src/iree/base/internal/math.h +++ b/runtime/src/iree/base/internal/math.h @@ -275,7 +275,7 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { // Define some helper constants for working with a floating-point format with // the given number of {exponent,mantissa} bits. -#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits) \ +#define IREE_MATH_FP_FORMAT_CONSTANTS(prefix, ebits, mbits, bias_tweak) \ const int prefix##exp_bits IREE_ATTRIBUTE_UNUSED = ebits; \ const int prefix##mantissa_bits IREE_ATTRIBUTE_UNUSED = mbits; \ const int prefix##sign_shift IREE_ATTRIBUTE_UNUSED = ebits + mbits; \ @@ -287,7 +287,7 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { const int prefix##exp_mask IREE_ATTRIBUTE_UNUSED = \ (1u << prefix##sign_shift) - (1u << prefix##exp_shift); \ const int prefix##exp_bias IREE_ATTRIBUTE_UNUSED = \ - (1u << (prefix##exp_bits - 1)) - 1; + bias_tweak + (1u << (prefix##exp_bits - 1)) - 1; // Generic conversion from any less-than-32-bit floating-point format to f32. // The `src` value is typed as a uint32_t for genericity but occupies only the @@ -295,39 +295,54 @@ static inline uint64_t iree_math_round_up_to_pow2_u64(uint64_t n) { // unused. static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, int mantissa_bits, - bool have_infinity) { - IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits) - IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23) + bool have_infinity, + int bias_tweak, + bool nan_as_neg_zero) { + IREE_MATH_FP_FORMAT_CONSTANTS(src_, exp_bits, mantissa_bits, bias_tweak) + IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23, 0) const uint32_t src_sign = src & src_sign_mask; const uint32_t f32_sign = src_sign << (f32_sign_shift - src_sign_shift); const uint32_t src_exp = src & src_exp_mask; const uint32_t src_mantissa = src & src_mantissa_mask; - uint32_t f32_exp = 0; - uint32_t f32_mantissa = 0; + // Initializing f32_exp and f32_mantissa for the case of normal finite values. + // Below we will overload that in other cases. 
+ uint32_t f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias) + << f32_exp_shift; + uint32_t f32_mantissa = src_mantissa + << (f32_mantissa_bits - src_mantissa_bits); if (src_exp == src_exp_mask) { - // No infinities => more large finite values. - if (!have_infinity && src_mantissa != src_mantissa_mask) { - float sign = (src & src_sign_mask) ? -1.0f : 1.0f; - return sign * 2 * (1u << src_exp_bits) * - ((1u << src_mantissa_bits) + src_mantissa); + // Top exponent value normally means infinity or NaN. + if (have_infinity) { + // NaN or Inf case. + f32_exp = f32_exp_mask; + if (src_mantissa) { + f32_mantissa = f32_mantissa_mask; // Quiet NaN. + } else { + f32_mantissa = 0; // Inf. + } + } else { + // No infinities => more large finite values, unless this is a NaN. + bool is_finite = src_mantissa != src_mantissa_mask || nan_as_neg_zero; + if (is_finite) { + f32_exp = ((src_exp >> src_exp_shift) + f32_exp_bias - src_exp_bias) + << f32_exp_shift; + f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits); + } else { + // NaN. Generate a quiet NaN. + f32_exp = f32_exp_mask; + f32_mantissa = f32_mantissa_mask; + } } - // NaN or Inf case. - f32_exp = f32_exp_mask; - if (src_mantissa) { - // NaN. Generate a quiet NaN. + } else if (src_exp == 0) { + // Zero or subnormal. Generate zero, except in one case: if the source type + // encodes NaN as signed zero, we handle that now. + if (nan_as_neg_zero && src == src_sign_mask) { + f32_exp = f32_exp_mask; f32_mantissa = f32_mantissa_mask; } else { - // Inf. Leave zero mantissa. + f32_exp = 0; + f32_mantissa = 0; } - } else if (src_exp == 0) { - // Zero or subnormal. Generate zero. Leave zero mantissa. - } else { - // Normal finite value. - int arithmetic_src_exp = src_exp >> src_exp_shift; - int arithmetic_f32_exp = arithmetic_src_exp + (1 << (f32_exp_bits - 1)) - - (1 << (src_exp_bits - 1)); - f32_exp = arithmetic_f32_exp << f32_exp_shift; - f32_mantissa = src_mantissa << (f32_mantissa_bits - src_mantissa_bits); } const uint32_t u32_value = f32_sign | f32_exp | f32_mantissa; float f32_value; @@ -340,28 +355,34 @@ static inline float iree_math_make_f32_from_bits(uint32_t src, int exp_bits, // genericity but occupies only the bottom (1 + exp_bits + mantissa_bits) bits. // The upper bits of the return value are unused. static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( - float value, int exp_bits, int mantissa_bits, bool have_infinity) { - IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits) - IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23) + float value, int exp_bits, int mantissa_bits, bool have_infinity, + int bias_tweak, bool nan_as_neg_zero) { + IREE_MATH_FP_FORMAT_CONSTANTS(dst_, exp_bits, mantissa_bits, bias_tweak) + IREE_MATH_FP_FORMAT_CONSTANTS(f32_, 8, 23, 0) uint32_t u32_value; memcpy(&u32_value, &value, sizeof value); const uint32_t f32_sign = u32_value & f32_sign_mask; - const uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift); + uint32_t dst_sign = f32_sign >> (f32_sign_shift - dst_sign_shift); const uint32_t f32_exp = u32_value & f32_exp_mask; const uint32_t f32_mantissa = u32_value & f32_mantissa_mask; uint32_t dst_exp = 0; uint32_t dst_mantissa = 0; + bool generate_nan = false; if (f32_exp >= f32_exp_mask) { // NaN or Inf case. dst_exp = dst_exp_mask; if (f32_mantissa || !have_infinity) { // NaN. Generate a quiet NaN. - dst_mantissa = dst_mantissa_mask; + generate_nan = true; } else { // Inf. Leave zero mantissa. } } else if (f32_exp == 0) { // Zero or subnormal. 
Generate zero. Leave zero mantissa. + if (nan_as_neg_zero) { + // The destination has no signed zero. Avoid accidentally generating NaN. + dst_sign = 0; + } } else { // Normal finite value. int arithmetic_exp = (f32_exp >> f32_exp_shift) - f32_exp_bias; @@ -373,7 +394,7 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( dst_exp = dst_exp_mask; if (!have_infinity) { // Generate NaN. - dst_mantissa = dst_mantissa_mask; + generate_nan = true; } } else if (arithmetic_exp < -(1 << (dst_exp_bits - 1))) { // Underflow. Generate zero. Leave zero mantissa. @@ -401,38 +422,52 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( biased_f32_mantissa = 0; ++arithmetic_exp; } - // In the !have_infinity case, arithmetic_exp might have been the top - // value already, so incrementing it may have overflown it. - if (!have_infinity && arithmetic_exp > (1 << (dst_exp_bits - 1))) { - dst_exp = dst_exp_mask; - dst_mantissa = dst_mantissa_mask; - } else { - // The exponent increment in the above if() branch may cause overflow. - // This is exercised by converting 65520.0f from f32 to f16. No special - // handling is needed for this case: the above if() branch already set - // biased_f32_mantissa=0, so we will be generating a 0 mantissa, as - // needed for infinite values. - dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift; - dst_mantissa = - biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits); + // The exponent increment in the above if() branch may cause overflow. + // This is exercised by converting 65520.0f from f32 to f16. When the + // destination type has infinities, no special handling is needed for this + // case: the above if() branch already set biased_f32_mantissa=0, so we + // will be generating a 0 mantissa, as needed for infinite values. The one + // case where special handling is needed is when the destination type has + // no infinities and we need to generate NaN. + dst_exp = (arithmetic_exp + dst_exp_bias) << dst_exp_shift; + dst_mantissa = + biased_f32_mantissa >> (f32_mantissa_bits - dst_mantissa_bits); + if (!have_infinity && dst_exp > dst_exp_mask) { + generate_nan = true; } } } - uint32_t dst_value = dst_sign | dst_exp | dst_mantissa; - return dst_value; + if (generate_nan) { + if (nan_as_neg_zero) { + return dst_sign_mask; + } else { + return dst_sign | dst_exp_mask | dst_mantissa_mask; + } + } else { + if (nan_as_neg_zero && dst_exp == 0 && dst_mantissa == 0) { + // Negative zero needs to be rounded to positive zero to avoid + // accidentally producing NaN when negative-zero is the NaN encoding. + return 0; + } else { + return dst_sign | dst_exp | dst_mantissa; + } + } } #define IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(NAME, INT_TYPE, EXP_BITS, \ - MANTISSA_BITS, HAVE_INFINITY) \ + MANTISSA_BITS, HAVE_INFINITY, \ + BIAS_TWEAK, NAN_AS_NEG_ZERO) \ /* Converts a to a 32-bit C `float`. */ \ static inline float iree_math_##NAME##_to_f32(INT_TYPE src) { \ return iree_math_make_f32_from_bits(src, EXP_BITS, MANTISSA_BITS, \ - HAVE_INFINITY); \ + HAVE_INFINITY, BIAS_TWEAK, \ + NAN_AS_NEG_ZERO); \ } \ /* Truncates a 32-bit C `float`, rounding to nearest even. 
*/ \ static inline INT_TYPE iree_math_f32_to_##NAME(float value) { \ return iree_math_truncate_f32_to_bits_rounding_to_nearest_even( \ - value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY); \ + value, EXP_BITS, MANTISSA_BITS, HAVE_INFINITY, BIAS_TWEAK, \ + NAN_AS_NEG_ZERO); \ } \ /* Round-trip f32->f32 rounding via the narrow float type */ \ static inline float iree_math_round_to_nearest_##NAME(float value) { \ @@ -441,16 +476,44 @@ static inline uint32_t iree_math_truncate_f32_to_bits_rounding_to_nearest_even( // IEEE half-precision a.k.a. float16, // https://en.wikipedia.org/wiki/Half-precision_floating-point_format -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f16, uint16_t, 5, 10, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // Bfloat16, https://en.wikipedia.org/wiki/Bfloat16_floating-point_format -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(bf16, uint16_t, 8, 7, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // F8E5M2 type, https://arxiv.org/abs/2209.05433 -IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true) +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2, uint8_t, 5, 2, /*have_infinity=*/true, + /*bias_tweak=*/0, /*nan_as_neg_zero=*/false) // F8E4M3 type, https://arxiv.org/abs/2209.05433. IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3, uint8_t, 4, 3, - /*have_infinity=*/false) + /*have_infinity=*/false, /*bias_tweak=*/0, + /*nan_as_neg_zero=*/false) + +// F8E5M2FNUZ type, found in some AMD GPUs (MI300), called "BF8" there. +// Quoting LLVM's APFloat.h: +// 8-bit floating point number mostly following IEEE-754 conventions +// and bit layout S1E5M2 described in https://arxiv.org/abs/2206.02915, +// with expanded range and with no infinity or signed zero. +// NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero). +// This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1) +// that IEEE precedent would imply. +IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e5m2fnuz, uint8_t, 5, 2, + /*have_infinity=*/false, /*bias_tweak=*/1, + /*nan_as_neg_zero=*/true) + +// F8E4M3FNUZ type, found in some AMD GPUs (MI300), called "FP8" there. +// Quoting LLVM's APFloat.h: +// 8-bit floating point number mostly following IEEE-754 conventions +// and bit layout S1E4M3 described in https://arxiv.org/abs/2206.02915, +// with expanded range and with no infinity or signed zero. +// NaN is represented as negative zero. (FN -> Finite, UZ -> unsigned zero). +// This format's exponent bias is 8, instead of the 7 (2 ** (4 - 1) - 1) +// that IEEE precedent would imply. 
+IREE_MATH_MAKE_FLOAT_TYPE_HELPERS(f8e4m3fnuz, uint8_t, 4, 3, + /*have_infinity=*/false, /*bias_tweak=*/1, + /*nan_as_neg_zero=*/true) #endif // IREE_BASE_INTERNAL_MATH_H_ diff --git a/runtime/src/iree/base/internal/math_test.cc b/runtime/src/iree/base/internal/math_test.cc index b3548d762e45..347d7d2ba5c3 100644 --- a/runtime/src/iree/base/internal/math_test.cc +++ b/runtime/src/iree/base/internal/math_test.cc @@ -523,4 +523,200 @@ TEST(F8E4M3ConversionTest, F32ToF8E4M3ToF32) { EXPECT_NE(nan, nan); } +//============================================================================== +// F8E5M2FNUZ support +//============================================================================== + +TEST(F8E5M2FNUZConversionTest, F32ToF8E5M2FNUZ) { + constexpr float kF8E5M2FNUZMax = 57344.f; + constexpr float kF8E5M2FNUZMin = 1.f / 32768.f; + // Within range, normal truncation. + EXPECT_EQ(0x38, iree_math_f32_to_f8e5m2fnuz(0.25f)); + EXPECT_EQ(0xDA, iree_math_f32_to_f8e5m2fnuz(-100.375f)); + EXPECT_EQ(0x7E, iree_math_f32_to_f8e5m2fnuz(49152.f)); + EXPECT_EQ(0xFE, iree_math_f32_to_f8e5m2fnuz(-49152.f)); + EXPECT_EQ(0x7F, iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMax)); + EXPECT_EQ(0x04, iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin)); + EXPECT_EQ(0x84, iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin)); + // No infinities, so they convert to NaN, encoded as negative zero. + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(INFINITY)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(-INFINITY)); + // Overflow. + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(FLT_MAX)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e5m2fnuz(-FLT_MAX)); + // Underflow + EXPECT_EQ(0, iree_math_f32_to_f8e5m2fnuz(FLT_MIN)); + EXPECT_EQ(0, iree_math_f32_to_f8e5m2fnuz(-FLT_MIN)); // No negative zero. + // Denormals may or may not get flushed to zero. Accept both ways. + uint32_t positive_denormal = iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin / 2); + EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x02); + uint32_t negative_denormal = iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin / 2); + // No negative zero. + EXPECT_TRUE(negative_denormal == 0x0 || negative_denormal == 0x02); +} + +TEST(F8E5M2FNUZConversionTest, F32ToF8E5M2ToF32FNUZ) { + constexpr float kF8E5M2FNUZMax = 57344.f; + constexpr float kF8E5M2FNUZMin = 1.f / 32768.f; + // Within range, should just round. + EXPECT_EQ(0.25f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(0.25f))); + EXPECT_EQ(-0.25f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-0.25f))); + EXPECT_EQ(96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(100.375f))); + EXPECT_EQ(-96.f, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-100.375f))); + EXPECT_EQ(96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(96.f))); + EXPECT_EQ(-96.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-96.f))); + EXPECT_EQ(kF8E5M2FNUZMax, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax))); + EXPECT_EQ(-kF8E5M2FNUZMax, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMax))); + EXPECT_EQ(kF8E5M2FNUZMin, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin))); + EXPECT_EQ(-kF8E5M2FNUZMin, iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(-kF8E5M2FNUZMin))); + // Powers of two should always be exactly representable across the + // exponent range. 
+ EXPECT_EQ(32768.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(32768.f))); + EXPECT_EQ(-32768.f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-32768.f))); + // Overflow + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(FLT_MAX)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-FLT_MAX)))); + EXPECT_GT(kF8E5M2FNUZMax + 1.f, + iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMax + 1.f))); + // Underflow + EXPECT_EQ(0.0f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(FLT_MIN))); + EXPECT_EQ(0.0f, + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-FLT_MIN))); + // Denormals may or may not get flushed to zero. Accept both ways. + float positive_denormal = iree_math_f8e5m2fnuz_to_f32( + iree_math_f32_to_f8e5m2fnuz(kF8E5M2FNUZMin / 2)); + EXPECT_TRUE(positive_denormal == 0.0f || + positive_denormal == 3.05175781e-05f); + // Inf and NaN. No infinities, so we get NaN. + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(-INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e5m2fnuz_to_f32(iree_math_f32_to_f8e5m2fnuz(NAN)))); +} + +//============================================================================== +// F8E4M3FNUZ support +//============================================================================== + +TEST(F8E4M3FNUZConversionTest, F32ToF8E4M3FNUZ) { + // See https://arxiv.org/pdf/2209.05433.pdf, Table 1. + // The F8E4M3 format is special: it has no infinities, and has some larger + // finite values instead. + constexpr float kF8E4M3FNUZMax = 240.f; + constexpr float kF8E4M3FNUZMin = 1.f / 128.f; + // Within range, normal truncation. + EXPECT_EQ(0x30, iree_math_f32_to_f8e4m3fnuz(0.25f)); + EXPECT_EQ(0xF5, iree_math_f32_to_f8e4m3fnuz(-100.375f)); + // Extra large finite values thanks to not having infinities. + EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax)); + EXPECT_EQ(0x7F, iree_math_f32_to_f8e4m3fnuz(247.0f)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMax)); + EXPECT_EQ(0xFF, iree_math_f32_to_f8e4m3fnuz(-247.0f)); + // First value that overflows. + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(248.0f)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-248.0f)); + // Min normal values. + EXPECT_EQ(0x08, iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin)); + EXPECT_EQ(0x88, iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin)); + // Infinity + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(INFINITY)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-INFINITY)); + // Overflow + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(FLT_MAX)); + EXPECT_EQ(0x80, iree_math_f32_to_f8e4m3fnuz(-FLT_MAX)); + // Test some round-to-nearest-even behavior. + EXPECT_EQ(0x78, iree_math_f32_to_f8e4m3fnuz(136.0f)); + EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fnuz(152.0f)); + EXPECT_EQ(0x7A, iree_math_f32_to_f8e4m3fnuz(168.0f)); + EXPECT_EQ(0x7C, iree_math_f32_to_f8e4m3fnuz(184.0f)); + // Underflow + EXPECT_EQ(0, iree_math_f32_to_f8e4m3fnuz(FLT_MIN)); + EXPECT_EQ(0, iree_math_f32_to_f8e4m3fnuz(-FLT_MIN)); + // Denormals may or may not get flushed to zero. Accept both ways. 
+ uint32_t positive_denormal = iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin / 2); + EXPECT_TRUE(positive_denormal == 0 || positive_denormal == 0x04); + uint32_t negative_denormal = iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin / 2); + EXPECT_TRUE(negative_denormal == 0 || negative_denormal == 0x84); +} + +TEST(F8E4M3FNUZConversionTest, F32ToF8E4M3ToF32FNUZ) { + // See https://arxiv.org/pdf/2209.05433.pdf, Table 1. + // The F8E4M3 format is special: it has no infinities, and has some larger + // finite values instead. + constexpr float kF8E4M3FNUZMax = 240.f; + constexpr float kF8E4M3FNUZMin = 1.f / 128.f; + // Within range, should just round. + EXPECT_EQ(0.25f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(0.25f))); + EXPECT_EQ(-0.25f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-0.25f))); + EXPECT_EQ(104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(100.375f))); + EXPECT_EQ(-104.f, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-100.375f))); + EXPECT_EQ(104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(100.4f))); + EXPECT_EQ(-104.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-100.4f))); + EXPECT_EQ(kF8E4M3FNUZMax, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax))); + EXPECT_EQ(-kF8E4M3FNUZMax, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMax))); + EXPECT_EQ(kF8E4M3FNUZMin, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin))); + EXPECT_EQ(-kF8E4M3FNUZMin, iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(-kF8E4M3FNUZMin))); + // Powers of two should always be exactly representable across the + // exponent range. + EXPECT_EQ(128.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(128.f))); + EXPECT_EQ(-128.f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-128.f))); + // Overflow + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(FLT_MAX)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-FLT_MAX)))); + EXPECT_GT(kF8E4M3FNUZMax + 1.f, + iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMax + 1.f))); + // Underflow + EXPECT_EQ(0.0f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(FLT_MIN))); + EXPECT_EQ(0.0f, + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-FLT_MIN))); + // Denormals may or may not get flushed to zero. Accept both ways. + float positive_denormal = iree_math_f8e4m3fnuz_to_f32( + iree_math_f32_to_f8e4m3fnuz(kF8E4M3FNUZMin / 2)); + EXPECT_TRUE(positive_denormal == 0.0f || + positive_denormal == 3.05175781e-05f); + // Inf and Nan + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(INFINITY)))); + EXPECT_TRUE(std::isnan( + iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(-INFINITY)))); + // Check that the result is a Nan with nan != nan. + float nan = iree_math_f8e4m3fnuz_to_f32(iree_math_f32_to_f8e4m3fnuz(NAN)); + EXPECT_NE(nan, nan); +} + } // namespace diff --git a/runtime/src/iree/hal/buffer_view.h b/runtime/src/iree/hal/buffer_view.h index 96b9fd487ce5..b5c4861dcdd0 100644 --- a/runtime/src/iree/hal/buffer_view.h +++ b/runtime/src/iree/hal/buffer_view.h @@ -48,6 +48,14 @@ enum iree_hal_numerical_type_bits_t { IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x02u, // Paired (real, imag) complex number in floating-point format. 
IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x03u, + // Ad-hoc entries for the zoo of low-bit-depth float types. They are special + // in that there are many different types sharing the same size. + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2 = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x04u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3 = IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x05u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2_FNUZ = + IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x06u, + IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3_FNUZ = + IREE_HAL_NUMERICAL_TYPE_FLOAT | 0x07u, }; typedef uint8_t iree_hal_numerical_type_t; @@ -148,6 +156,10 @@ enum iree_hal_element_types_t { IREE_HAL_ELEMENT_TYPE_BFLOAT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN, 16), // NOLINT IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX, 64), // NOLINT IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_COMPLEX, 128), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E5M2_FNUZ, 8), // NOLINT + IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_8_E4M3_FNUZ, 8), // NOLINT }; typedef uint32_t iree_hal_element_type_t; // clang-format on diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c index 11cd2ce7b14f..9b097973178e 100644 --- a/runtime/src/iree/hal/string_util.c +++ b/runtime/src/iree/hal/string_util.c @@ -134,6 +134,18 @@ IREE_API_EXPORT iree_status_t iree_hal_parse_element_type( numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED; } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("ui"))) { numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED; + } else if (iree_string_view_equal(str_value, IREE_SV("f8E5M2"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E4M3"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E5M2FNUZ"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ; + return iree_ok_status(); + } else if (iree_string_view_equal(str_value, IREE_SV("f8E4M3FNUZ"))) { + *out_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ; + return iree_ok_status(); } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("f"))) { numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE; } else if (iree_string_view_consume_prefix(&str_value, IREE_SV("bf"))) { @@ -164,6 +176,37 @@ IREE_API_EXPORT iree_status_t iree_hal_parse_element_type( IREE_API_EXPORT iree_status_t iree_hal_format_element_type( iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, char* buffer, iree_host_size_t* out_buffer_length) { + const char* special_name = NULL; + switch (element_type) { + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + special_name = "f8E5M2"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + special_name = "f8E4M3"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + special_name = "f8E5M2FNUZ"; + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: + special_name = "f8E4M3FNUZ"; + break; + 
default: + break; + } + if (special_name) { + int n = snprintf(buffer, buffer_capacity, "%s", special_name); + if (n < 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "snprintf failed"); + } + if (out_buffer_length) { + *out_buffer_length = n; + } + return n >= buffer_capacity + ? iree_status_from_code(IREE_STATUS_OUT_OF_RANGE) + : iree_ok_status(); + } + if (out_buffer_length) { *out_buffer_length = 0; } @@ -366,6 +409,38 @@ static iree_status_t iree_hal_parse_element_unsafe( return iree_string_view_atoi_uint64(data_str, (uint64_t*)out_data) ? iree_ok_status() : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e5m2(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e4m3(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e5m2fnuz(temp_float); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: { + float temp_float = 0; + if (!iree_string_view_atof(data_str, &temp_float)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *(uint8_t*)out_data = (uint8_t)iree_math_f32_to_f8e4m3fnuz(temp_float); + return iree_ok_status(); + } case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: { float temp = 0; if (!iree_string_view_atof(data_str, &temp)) { diff --git a/runtime/src/iree/hal/string_util_test.cc b/runtime/src/iree/hal/string_util_test.cc index 2d134fdbf9e6..8de9fe58f482 100644 --- a/runtime/src/iree/hal/string_util_test.cc +++ b/runtime/src/iree/hal/string_util_test.cc @@ -608,6 +608,14 @@ TEST(ElementTypeStringUtilTest, ParseElementType) { IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_16))); EXPECT_THAT(ParseElementType("bf16"), IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_BFLOAT_16))); + EXPECT_THAT(ParseElementType("f8E5M2"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2))); + EXPECT_THAT(ParseElementType("f8E4M3"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3))); + EXPECT_THAT(ParseElementType("f8E5M2FNUZ"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ))); + EXPECT_THAT(ParseElementType("f8E4M3FNUZ"), + IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ))); EXPECT_THAT(ParseElementType("x64"), IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_OPAQUE_64))); EXPECT_THAT(ParseElementType("*64"), @@ -635,8 +643,18 @@ TEST(ElementTypeStringUtilTest, FormatElementType) { IsOkAndHolds(Eq("ui16"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_32), IsOkAndHolds(Eq("f32"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_16), + IsOkAndHolds(Eq("f16"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_BFLOAT_16), IsOkAndHolds(Eq("bf16"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2), + IsOkAndHolds(Eq("f8E5M2"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3), + IsOkAndHolds(Eq("f8E4M3"))); + 
EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ), + IsOkAndHolds(Eq("f8E5M2FNUZ"))); + EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ), + IsOkAndHolds(Eq("f8E4M3FNUZ"))); EXPECT_THAT(FormatElementType(IREE_HAL_ELEMENT_TYPE_OPAQUE_64), IsOkAndHolds(Eq("*64"))); EXPECT_THAT(FormatElementType(iree_hal_make_element_type( From 2291b3801fde8be787f3dfea53805bc874d72024 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 24 Oct 2024 14:07:37 -0400 Subject: [PATCH 05/45] Support 8-bit floats in the compiler. (#18886) This is a step in a series of PRs adding support for 8-bit flows. It's sandwiched between https://github.com/iree-org/iree/pull/18885 and subsequent PRs that will actually do something useful with this. --------- Signed-off-by: Benoit Jacob --- compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td | 9 +++++++++ compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 976a4cabccd1..3c1ebd7c9864 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -336,6 +336,11 @@ def HAL_CollectiveElementType_Float16 : I32EnumAttrCase<"Float16", 8, "f16">; def HAL_CollectiveElementType_Float32 : I32EnumAttrCase<"Float32", 9, "f32">; def HAL_CollectiveElementType_Float64 : I32EnumAttrCase<"Float64", 10, "f64">; def HAL_CollectiveElementType_BFloat16 : I32EnumAttrCase<"BFloat16", 11, "bf16">; +def HAL_CollectiveElementType_Float8E5M2 : I32EnumAttrCase<"Float8E5M2", 12, "f8E5M2">; +def HAL_CollectiveElementType_Float8E4M3 : I32EnumAttrCase<"Float8E4M3", 13, "f8E4M3">; +def HAL_CollectiveElementType_Float8E5M2FNUZ : I32EnumAttrCase<"Float8E5M2FNUZ", 14, "f8E5M2FNUZ">; +def HAL_CollectiveElementType_Float8E4M3FNUZ : I32EnumAttrCase<"Float8E4M3FNUZ", 15, "f8E4M3FNUZ">; + def HAL_CollectiveElementTypeAttr : I32EnumAttr<"CollectiveElementType", "valid CollectiveElementType", [ HAL_CollectiveElementType_Sint8, @@ -350,6 +355,10 @@ def HAL_CollectiveElementTypeAttr : HAL_CollectiveElementType_Float32, HAL_CollectiveElementType_Float64, HAL_CollectiveElementType_BFloat16, + HAL_CollectiveElementType_Float8E5M2, + HAL_CollectiveElementType_Float8E4M3, + HAL_CollectiveElementType_Float8E5M2FNUZ, + HAL_CollectiveElementType_Float8E4M3FNUZ ]> { let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index d06c2dc892c8..f1a820fb245b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -878,6 +878,10 @@ enum class NumericalType : uint32_t { kFloatIEEE = kFloat | 0x01, kFloatBrain = kFloat | 0x02, kFloatComplex = kFloat | 0x03, + kFloat8E5M2 = kFloat | 0x04, + kFloat8E4M3 = kFloat | 0x05, + kFloat8E5M2FNUZ = kFloat | 0x06, + kFloat8E4M3FNUZ = kFloat | 0x07, }; constexpr inline int32_t makeElementTypeValue(NumericalType numericalType, @@ -905,7 +909,14 @@ std::optional ElementTypeOp::getTypeValue(Type type) { return makeElementTypeValue(numericalType, intType.getWidth()); } else if (auto floatType = llvm::dyn_cast_if_present(type)) { switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) { + case APFloat::S_Float8E5M2: + return makeElementTypeValue(NumericalType::kFloat8E5M2, 8); + case APFloat::S_Float8E4M3: + return 
makeElementTypeValue(NumericalType::kFloat8E4M3, 8); + case APFloat::S_Float8E5M2FNUZ: + return makeElementTypeValue(NumericalType::kFloat8E5M2FNUZ, 8); case APFloat::S_Float8E4M3FNUZ: + return makeElementTypeValue(NumericalType::kFloat8E4M3FNUZ, 8); case APFloat::S_IEEEhalf: case APFloat::S_IEEEsingle: case APFloat::S_IEEEdouble: From 4ad834bfcebf45ac94f3b1e2397326f5a5f46006 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 24 Oct 2024 15:18:13 -0400 Subject: [PATCH 06/45] Support F8E5M2FNUZ MFMA on CDNA3 (#18887) F8E4M3FNUZ was already there. --------- Signed-off-by: Benoit Jacob --- .../ROCM/test/target_device_features.mlir | 4 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 9 ++++ .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 43 ++++++++++++++----- .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 1 + 4 files changed, 44 insertions(+), 13 deletions(-) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 726f52551744..76818168f648 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -15,7 +15,7 @@ // GFX942: target = #iree_gpu.target, , , , , ], +// GFX942-SAME: mma = [, , , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, // GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], @@ -26,7 +26,7 @@ // GFX941-SAME: features = "+sramecc,-xnack" // GFX940: target = #iree_gpu.target, , , , , ], +// GFX940-SAME: mma = [, , , , , , ], // GFX1100: target = #iree_gpu.target, , ] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index ff6493b06eb4..d9c26ae21e8e 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -212,6 +212,7 @@ getContractionLayout(vector::ContractionOp contract, ConcreteMmaLayout layout) { static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, MMAIntrinsic type) { Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); + Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context); Type f16 = Float16Type::get(context); Type f32 = Float32Type::get(context); @@ -231,6 +232,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: { + return OpaqueMmaLayout{16, 16, 32, f8E5M2FNUZ, f8E5M2FNUZ, f32}; + } case MMAIntrinsic::MFMA_I32_16x16x32_I8: { return OpaqueMmaLayout{16, 16, 32, i8, i8, i32}; } @@ -472,6 +476,7 @@ MMAAttr::getABCVectorTypes() const { return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { auto aType = VectorType::get({8}, getAType()); auto bType = VectorType::get({8}, getBType()); @@ -518,6 +523,7 @@ int64_t MMAAttr::getBlockSize() const { case MMAIntrinsic::MFMA_F32_32x32x8_F16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: case 
MMAIntrinsic::WMMA_F16_16x16x16_F16: @@ -538,6 +544,7 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { case MMAIntrinsic::MFMA_F32_32x32x8_F16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { return 64; @@ -602,6 +609,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, /*element=*/{4, 1}}; } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: switch (fragment) { case MMAFragment::Lhs: @@ -699,6 +707,7 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: + case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: case MMAIntrinsic::MFMA_I32_32x32x16_I8: { auto [m, n, k] = getMNKShape(); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index d1a91597c79b..d1c84a8d9eb1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -99,24 +99,44 @@ class IREEGPU_I32MmaEnumAttr } // Format: __xx_ -def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>; -def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>; -def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>; -def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>; -def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>; -def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>; +// Values: 0xABCD where: +// * A = vendor: +// * 0 = AMD +// * 1 = NVIDIA +// * B is architecture: +// * For AMD: +// * 0 = RDNA3 +// * 8 = CDNA2 +// * 9 = CDNA3 +// * C is A/B data type: +// * 0 = f32 +// * 1 = f16 +// * 2 = bf16 +// * 3 = f8e5m2 (and variants like fnuz). +// * 4 = f8e4m3 (and variants like fnuz). +// * 8 = i8 +// * D enumerates intrinsics for the same data type. 
+// +// CDNA3 instrinsics +def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>; +def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>; +def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>; +def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>; +def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>; +def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>; +def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 0x0981>; // CDNA2 instrinsics -def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 6>; -def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 7>; +def MFMA_I32_16x16x16_I8 : I32EnumAttrCase<"MFMA_I32_16x16x16_I8", 0x0880>; +def MFMA_I32_32x32x8_I8 : I32EnumAttrCase<"MFMA_I32_32x32x8_I8", 0x0881>; // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs -def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 8>; -def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 9>; +def WMMA_F32_16x16x16_F16 : I32EnumAttrCase<"WMMA_F32_16x16x16_F16", 0x0010>; +def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 0x0011>; // TODO: The actual I8 instruction allows specifying (mixed) signedness. // This will need to become its own class of MMA attribute. -def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 10>; +def WMMA_I32_16x16x16_I8 : I32EnumAttrCase<"WMMA_I32_16x16x16_I8", 0x0080>; def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", "Descriptor for different MMA intrinsics", [ @@ -124,6 +144,7 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_F32_16x16x16_F16, MFMA_F32_32x32x8_F16, MFMA_F32_16x16x32_F8E4M3FNUZ, + MFMA_F32_16x16x32_F8E5M2FNUZ, MFMA_I32_16x16x32_I8, MFMA_I32_32x32x16_I8, MFMA_I32_16x16x16_I8, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index 4fa5074e67a4..c187f44b0512 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -137,6 +137,7 @@ const WgpDetails *getCDNA3WgpDetails() { MMAIntrinsic::MFMA_F32_16x16x16_F16, MMAIntrinsic::MFMA_F32_32x32x8_F16, MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, + MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ, MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8, }; From 225baf2ddec3a9ac1313ae7b43b7325ada47e66d Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 24 Oct 2024 15:44:48 -0400 Subject: [PATCH 07/45] Add e2e tests for F8E5M2FNUZ and F8E4M3FNUZ data-tiled MFMA on CDNA3 (#18888) Signed-off-by: Benoit Jacob --- tests/e2e/matmul/CMakeLists.txt | 64 ++++++++++++++++- tests/e2e/matmul/generate_e2e_matmul_tests.py | 23 +++++- tools/testing/e2e/iree-e2e-matmul-test.cc | 51 ++++++++++++++ tools/testing/e2e/test_utils.c | 70 +++++++++++++++++++ tools/testing/e2e/test_utils.h | 6 ++ 5 files changed, 209 insertions(+), 5 deletions(-) diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index 36e1255c5bfd..aeb18c361dd9 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -1526,7 +1526,7 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_f16_large_cdna3_mfma_data_tiled + 
e2e_matmul_rocm_f16_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1555,7 +1555,7 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_i8_large_cdna3_mfma_data_tiled + e2e_matmul_rocm_i8_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1584,7 +1584,7 @@ iree_generated_e2e_runner_test( iree_generated_e2e_runner_test( NAME - e2e_matmul_rocm_f32_large_cdna3_mfma_data_tiled + e2e_matmul_rocm_f32_cdna3_mfma_data_tiled TEST_TYPE matmul GENERATOR @@ -1611,6 +1611,64 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_f8E5M2FNUZ_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f8E5M2FNUZ" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_f8E4M3FNUZ_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f8E4M3FNUZ" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + endif() elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11") diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 30d210dedec0..b5dac41e5b18 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -27,8 +27,11 @@ class MatrixElemTypeId(enum.Enum): I32 = "i32" F32 = "f32" F16 = "f16" - F8E4M3FNUZ = "f8E4M3FNUZ" BF16 = "bf16" + F8E5M2 = "f8E5M2" + F8E4M3 = "f8E4M3" + F8E5M2FNUZ = "f8E5M2FNUZ" + F8E4M3FNUZ = "f8E4M3FNUZ" # Enumerates of the collections of shapes that we can generate tests for. 
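
For readers cross-checking the renumbered MMAIntrinsic cases above, the 0xABCD value scheme documented in the IREEGPUEnums.td comment can be decoded mechanically. The following is an illustrative sketch only (plain Python, not part of the patches in this series; the helper name decode_mma_intrinsic and the lookup tables are invented here, while the nibble meanings come straight from that comment):

  # Decode the 0xABCD MMAIntrinsic values per the scheme documented above.
  VENDORS = {0x0: "AMD", 0x1: "NVIDIA"}
  AMD_ARCHS = {0x0: "RDNA3", 0x8: "CDNA2", 0x9: "CDNA3"}
  AB_TYPES = {0x0: "f32", 0x1: "f16", 0x2: "bf16", 0x3: "f8e5m2", 0x4: "f8e4m3", 0x8: "i8"}

  def decode_mma_intrinsic(value: int) -> str:
      vendor = (value >> 12) & 0xF
      arch = (value >> 8) & 0xF
      ab_type = (value >> 4) & 0xF
      variant = value & 0xF
      arch_name = AMD_ARCHS[arch] if vendor == 0x0 else hex(arch)
      return f"{VENDORS[vendor]} {arch_name} {AB_TYPES[ab_type]} #{variant}"

  # MFMA_F32_16x16x32_F8E4M3FNUZ is 0x0940 in the enum above:
  print(decode_mma_intrinsic(0x0940))  # AMD CDNA3 f8e4m3 #0
  # MFMA_I32_32x32x8_I8 (CDNA2) is 0x0881:
  print(decode_mma_intrinsic(0x0881))  # AMD CDNA2 i8 #1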
@@ -905,7 +908,17 @@ def parse_arguments(): parser.add_argument( "--lhs_rhs_type", type=str, - choices=["i32", "i8", "f32", "f16", "f8E4M3FNUZ", "bf16"], + choices=[ + "i32", + "i8", + "f32", + "f16", + "bf16", + "f8E5M2", + "f8E4M3", + "f8E5M2FNUZ", + "f8E4M3FNUZ", + ], help="Numeric type of input matrices", required=True, ) @@ -999,6 +1012,12 @@ def write_calls_file(functions, calls, filename, requirements): def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId): if acc_type != MatrixElemTypeId.NONE: return acc_type + if lhs_rhs_type == MatrixElemTypeId.F8E5M2: + return MatrixElemTypeId.F32 + if lhs_rhs_type == MatrixElemTypeId.F8E4M3: + return MatrixElemTypeId.F32 + if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ: + return MatrixElemTypeId.F32 if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ: return MatrixElemTypeId.F32 if lhs_rhs_type == MatrixElemTypeId.I8: diff --git a/tools/testing/e2e/iree-e2e-matmul-test.cc b/tools/testing/e2e/iree-e2e-matmul-test.cc index 230956041cdc..ce589e20851d 100644 --- a/tools/testing/e2e/iree-e2e-matmul-test.cc +++ b/tools/testing/e2e/iree-e2e-matmul-test.cc @@ -128,6 +128,29 @@ static void reference_matmul_bf16_bf16_f32_f32( result_data[n + m * n_size] = acc; } +#define REFERENCE_MATMUL_F8(LHSTYPE, RHSTYPE) \ + static void reference_matmul_##LHSTYPE##_##RHSTYPE##_f32_f32( \ + iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \ + iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \ + iree_hal_element_type_t acc_type, bool transpose_rhs, \ + const uint8_t* lhs_data, const uint8_t* rhs_data, const float* acc_data, \ + float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { \ + float acc = acc_data ? acc_data[n + m * n_size] : 0; \ + for (iree_hal_dim_t k = 0; k < k_size; ++k) { \ + float lhs_float = \ + iree_math_##LHSTYPE##_to_f32(lhs_data[k + m * k_size]); \ + float rhs_float = iree_math_##RHSTYPE##_to_f32( \ + rhs_data[transpose_rhs ? k + n * k_size : n + k * n_size]); \ + acc += lhs_float * rhs_float; \ + } \ + result_data[n + m * n_size] = acc; \ + } + +REFERENCE_MATMUL_F8(f8e5m2, f8e5m2) +REFERENCE_MATMUL_F8(f8e4m3, f8e4m3) +REFERENCE_MATMUL_F8(f8e5m2fnuz, f8e5m2fnuz) +REFERENCE_MATMUL_F8(f8e4m3fnuz, f8e4m3fnuz) + // Helper for reference_matmul. // Computes one element in the result matrix. 
static iree_status_t reference_matmul_element( @@ -185,6 +208,34 @@ static iree_status_t reference_matmul_element( m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, (const uint16_t*)lhs_data, (const uint16_t*)rhs_data, (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2 && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e5m2_f8e5m2_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3 && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e4m3_f8e4m3_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e5m2fnuz_f8e5m2fnuz_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); + } else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ && + rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ && + acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_matmul_f8e4m3fnuz_f8e4m3fnuz_f32_f32( + m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, + (const uint8_t*)lhs_data, (const uint8_t*)rhs_data, + (const float*)acc_data, (float*)result_data, m, n); } else { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled combination of element types in matmul"); diff --git a/tools/testing/e2e/test_utils.c b/tools/testing/e2e/test_utils.c index a7119dcba771..c54c7190cdb6 100644 --- a/tools/testing/e2e/test_utils.c +++ b/tools/testing/e2e/test_utils.c @@ -93,6 +93,36 @@ iree_test_utils_e2e_value_t iree_test_utils_value_make_i32(int32_t value) { return result; } +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2; + result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3; + result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ( + uint16_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ; + result.f8_u8 = value; + return result; +} + +iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ( + uint16_t value) { + iree_test_utils_e2e_value_t result; + result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ; + result.f8_u8 = value; + return result; +} + iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) { iree_test_utils_e2e_value_t result; result.type = IREE_TEST_UTILS_VALUE_TYPE_F16; @@ -123,6 +153,14 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element( return iree_test_utils_value_make_i16(((int16_t*)data)[index]); } else if (iree_hal_element_type_is_integer(result_type, 32)) { return 
iree_test_utils_value_make_i32(((int32_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) { + return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) { + return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) { + return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]); + } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) { + return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) { return iree_test_utils_value_make_f16(((uint16_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) { @@ -147,6 +185,22 @@ int iree_test_utils_snprintf_value(char* buf, size_t bufsize, return snprintf(buf, bufsize, "%" PRIi32, value.i32); case IREE_TEST_UTILS_VALUE_TYPE_I64: return snprintf(buf, bufsize, "%" PRIi64, value.i64); + case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e5m2_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e4m3_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e5m2fnuz_to_f32(value.f8_u8)); + case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ: + return snprintf(buf, bufsize, + precision == PRECISION_HIGH ? "%.3g" : "%.2g", + iree_math_f8e4m3fnuz_to_f32(value.f8_u8)); case IREE_TEST_UTILS_VALUE_TYPE_F16: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? "%.5g" : "%.4g", @@ -257,6 +311,18 @@ void iree_test_utils_write_element(iree_hal_element_type_t element_type, case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: *(uint16_t*)dst = iree_math_f32_to_bf16((float)value); break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + *(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + *(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + *(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: + *(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value); + break; WRITE_ELEMENT_CASE(FLOAT_32, float) WRITE_ELEMENT_CASE(FLOAT_64, double) // clang-format on @@ -296,6 +362,10 @@ void iree_test_utils_get_min_max_for_element_type( *max = +4; break; case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ: + case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ: *min = -2; *max = +2; break; diff --git a/tools/testing/e2e/test_utils.h b/tools/testing/e2e/test_utils.h index f095537112e9..46d99f11df13 100644 --- a/tools/testing/e2e/test_utils.h +++ b/tools/testing/e2e/test_utils.h @@ -48,6 +48,11 @@ typedef enum iree_test_utils_value_type_e { IREE_TEST_UTILS_VALUE_TYPE_F64 = 7, // bfloat16 IREE_TEST_UTILS_VALUE_TYPE_BF16 = 8, + // 8-bit float types. + IREE_TEST_UTILS_VALUE_TYPE_F8E5M2 = 9, + IREE_TEST_UTILS_VALUE_TYPE_F8E4M3 = 10, + IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ = 11, + IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ = 12, } iree_test_utils_value_type_t; // Maximum size, in bytes, of any value type we can represent. 
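
The 8-bit FNUZ paths added to the test utilities above route through the iree_math_f32_to_f8e4m3fnuz / iree_math_f8e4m3fnuz_to_f32 helpers (and their E5M2 counterparts) introduced in math.h earlier in this series. As a quick cross-check of that layout, here is a stand-alone illustrative sketch (plain Python; the decode_fnuz helper is invented for this note) that reproduces a few of the byte encodings asserted in math_test.cc, using the documented +1 bias tweak and NaN encoded as negative zero:

  def decode_fnuz(byte: int, exp_bits: int, mantissa_bits: int) -> float:
      # NaN is encoded as negative zero (0x80); the format has no infinities.
      if byte == 0x80:
          return float("nan")
      sign = -1.0 if byte & 0x80 else 1.0
      exp = (byte >> mantissa_bits) & ((1 << exp_bits) - 1)
      mantissa = byte & ((1 << mantissa_bits) - 1)
      bias = 1 << (exp_bits - 1)  # IEEE-style bias plus the +1 bias tweak
      if exp == 0:  # subnormal
          return sign * (mantissa / (1 << mantissa_bits)) * 2.0 ** (1 - bias)
      return sign * (1.0 + mantissa / (1 << mantissa_bits)) * 2.0 ** (exp - bias)

  assert decode_fnuz(0x30, exp_bits=4, mantissa_bits=3) == 0.25     # f8E4M3FNUZ
  assert decode_fnuz(0x7F, exp_bits=4, mantissa_bits=3) == 240.0    # f8E4M3FNUZ max
  assert decode_fnuz(0x38, exp_bits=5, mantissa_bits=2) == 0.25     # f8E5M2FNUZ
  assert decode_fnuz(0x7F, exp_bits=5, mantissa_bits=2) == 57344.0  # f8E5M2FNUZ max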
@@ -64,6 +69,7 @@ typedef struct iree_test_utils_value_t { float f32; uint16_t f16_u16; uint16_t bf16_u16; + uint8_t f8_u8; double f64; uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all // value types From 8ce8bed38b562e9a3d649494c55cc16f5feb1b48 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 24 Oct 2024 16:43:33 -0400 Subject: [PATCH 08/45] Simplifications in e2e matmul tests (#18889) Two commits: 1. Stop inferring `acc_type`. Require specifying it. Only a few tests were relying on the inferrence. 2. Stop special-casing narrow float types (only using f32 as ABI type, generating `arith.truncf` internally). This was only needed when these narrow float types were not supported in the rest of IREE. Signed-off-by: Benoit Jacob --- tests/e2e/matmul/BUILD.bazel | 38 ++++++++---- tests/e2e/matmul/CMakeLists.txt | 17 ++++++ tests/e2e/matmul/generate_e2e_matmul_tests.py | 61 +++---------------- 3 files changed, 49 insertions(+), 67 deletions(-) diff --git a/tests/e2e/matmul/BUILD.bazel b/tests/e2e/matmul/BUILD.bazel index 635ee0cc3213..a82bfb691047 100644 --- a/tests/e2e/matmul/BUILD.bazel +++ b/tests/e2e/matmul/BUILD.bazel @@ -360,6 +360,7 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [ generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=small", ], target_backends_and_drivers = [ @@ -367,9 +368,9 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [ ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f32", "f32"), ]] ########################################################################### @@ -383,6 +384,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulSimt", ], @@ -411,6 +413,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCore", ], @@ -437,6 +440,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", ], tags = [ # CUDA cuInit fails with sanitizer on. @@ -461,6 +465,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", ], tags = [ # CUDA cuInit fails with sanitizer on. 
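
Since the generator no longer infers an accumulator type, every test suite above now passes --acc_type explicitly next to --lhs_rhs_type. A minimal sketch of the pairings used by the updated BUILD.bazel / CMakeLists.txt rules in this series (plain Python, illustrative only; EXPLICIT_ACC_TYPES and generator_args are invented names):

  # (lhs_rhs_type -> acc_type) pairs passed explicitly by the updated suites.
  EXPLICIT_ACC_TYPES = {
      "i8": "i32",
      "f16": "f32",
      "f32": "f32",
      "f8E5M2FNUZ": "f32",
      "f8E4M3FNUZ": "f32",
  }

  def generator_args(lhs_rhs_type: str) -> list[str]:
      acc_type = EXPLICIT_ACC_TYPES[lhs_rhs_type]
      return [f"--lhs_rhs_type={lhs_rhs_type}", f"--acc_type={acc_type}"]

  assert generator_args("i8") == ["--lhs_rhs_type=i8", "--acc_type=i32"]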
@@ -486,6 +491,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f32", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync", ], @@ -513,6 +519,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCore", ], @@ -540,6 +547,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync", ], @@ -566,6 +574,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, ], tags = [ # CUDA cuInit fails with sanitizer on. @@ -580,8 +589,8 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("f32", "f32"), ]] ########################################################################### @@ -598,6 +607,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=easy_large_static", "--compilation_info=SPIRVVectorizeMali", ], @@ -611,10 +621,10 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f16", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f16", "f32"), + ("f32", "f32"), ]] [iree_generated_e2e_runner_test( @@ -625,6 +635,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=%s" % lhs_rhs_type, + "--acc_type=%s" % acc_type, "--shapes=easy_large_static", "--compilation_info=SPIRVVectorizeNVIDIA", ], @@ -637,10 +648,10 @@ iree_generated_e2e_runner_test( ], test_runner = "//tools/testing/e2e:iree-e2e-matmul-test", test_type = "matmul", -) for lhs_rhs_type in [ - "i8", - "f16", - "f32", +) for (lhs_rhs_type, acc_type) in [ + ("i8", "i32"), + ("f16", "f32"), + ("f32", "f32"), ]] iree_generated_e2e_runner_test( @@ -651,6 +662,7 @@ iree_generated_e2e_runner_test( generator = ":generate_e2e_matmul_tests", generator_args = [ "--lhs_rhs_type=f16", + "--acc_type=f32", "--shapes=easy_large_static", "--compilation_info=SPIRVCooperativeMatrixVectorize", ], diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index aeb18c361dd9..98a4ff19b6d4 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -927,6 +927,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=small" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test @@ -948,6 +949,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=small" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test @@ -969,6 +971,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulSimt" TEST_RUNNER @@ -994,6 +997,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" 
GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCore" TEST_RUNNER @@ -1021,6 +1025,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1046,6 +1051,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1071,6 +1077,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync" TEST_RUNNER @@ -1098,6 +1105,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCore" TEST_RUNNER @@ -1125,6 +1133,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync" TEST_RUNNER @@ -1152,6 +1161,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" TEST_RUNNER iree_tools_testing_e2e_iree-e2e-matmul-test TARGET_BACKENDS @@ -1177,6 +1187,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1201,6 +1212,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1225,6 +1237,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeMali" TEST_RUNNER @@ -1249,6 +1262,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=i8" + "--acc_type=i32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1273,6 +1287,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1297,6 +1312,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f32" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVVectorizeNVIDIA" TEST_RUNNER @@ -1321,6 +1337,7 @@ iree_generated_e2e_runner_test( "generate_e2e_matmul_tests.py" GENERATOR_ARGS "--lhs_rhs_type=f16" + "--acc_type=f32" "--shapes=easy_large_static" "--compilation_info=SPIRVCooperativeMatrixVectorize" TEST_RUNNER diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index b5dac41e5b18..cd6f8ebea6d3 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -545,20 +545,6 @@ def int_or_DYN(s: DimSize): return s.value or "DYN" -# Gets friendlier form/type that we can use as arg types which we can cast into the target_type. 
-def cast_argtype_if_required(target_type: MatrixElemTypeId): - if target_type == MatrixElemTypeId.F8E4M3FNUZ: - return MatrixElemTypeId.F32 - return target_type - - -# Gets the op needed to cast/convert from the friendly form/type into the target_type. -def get_castback_from_arg_op(target_type: MatrixElemTypeId): - if target_type == MatrixElemTypeId.F8E4M3FNUZ: - return "arith.truncf" - return ValueError(f"Unhandled castback type of {target_type}") - - # Describes the fully resolved shape dimensions of all 3 input matrices, # LHS, RHS, and Accumulator, in a testcase. # Each value is a string, which may either represent a positive integer such as "123", @@ -659,9 +645,8 @@ def generate_function( acc_r = int_or_question_mark(shapes.acc_rows) acc_c = int_or_question_mark(shapes.acc_cols) - casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type) - lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>" - rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>" + lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>" + rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>" acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>" if transpose_rhs: @@ -680,15 +665,6 @@ def generate_function( func_definition = func_definition + compilation_info_string generate_function.compilation_index += 1 compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n" - if casted_lhs_rhs_type != lhs_rhs_type: - castback_op = get_castback_from_arg_op(lhs_rhs_type) - compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>" - compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>" - compute = ( - f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n" - f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n" - f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}" - ) if shape.accumulate: signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}" import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view" @@ -818,9 +794,8 @@ def generate_call( rhs_shape = [shape.k, shape.n] transpose_rhs = 0 - casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type) - op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type) - op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type) + op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type) + op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type) if shape.accumulate: op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type) # TODO(#16168): there's a bug with in-place input->output aliasing and @@ -919,16 +894,15 @@ def parse_arguments(): "f8E5M2FNUZ", "f8E4M3FNUZ", ], - help="Numeric type of input matrices", + help="Numeric type of input LHS and RHS matrices", required=True, ) parser.add_argument( "--acc_type", type=str, choices=["i32", "f32", "f16", "bf16"], - help="Numeric type of input matrices", - default="", - required=False, + help="Numeric type of the accumulator and result matrices", + required=True, ) parser.add_argument( "--shapes", @@ -1005,30 +979,9 @@ def write_calls_file(functions, calls, 
filename, requirements): file.write(module_definition) -# For now, the accumulator type can always be inferred from the input LHS/RHS -# type, so we do that. That is temporary: eventually there will be cases -# where the same input types are used with different accumulator types, e.g. -# f16 inputs with both f16 and f32 accumulator. -def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId): - if acc_type != MatrixElemTypeId.NONE: - return acc_type - if lhs_rhs_type == MatrixElemTypeId.F8E5M2: - return MatrixElemTypeId.F32 - if lhs_rhs_type == MatrixElemTypeId.F8E4M3: - return MatrixElemTypeId.F32 - if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ: - return MatrixElemTypeId.F32 - if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ: - return MatrixElemTypeId.F32 - if lhs_rhs_type == MatrixElemTypeId.I8: - return MatrixElemTypeId.I32 - return lhs_rhs_type - - def main(args): lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type) acc_type = MatrixElemTypeId(args.acc_type) - acc_type = infer_acc_type(lhs_rhs_type, acc_type) shapes_id = ShapesId(args.shapes) compilation_info_id = CompilationInfoId(args.compilation_info) From aef6e1fc8fe85f4a57bc06b0272ada18975494fb Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:12:54 -0500 Subject: [PATCH 09/45] [GPU] Bail out in GPUReduceBankConflicts if we have collapse_shape user (#18863) This is unsupported by upstream and can lead to a compiler error. https://github.com/llvm/llvm-project/issues/112994 Progress towards: https://github.com/iree-org/iree/issues/18858 --------- Signed-off-by: Nirvedh --- .../Common/GPU/GPUReduceBankConflicts.cpp | 23 +++++++ .../GPU/test/reduce_bank_conflicts.mlir | 60 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp index 807ab9d339eb..51898adc02d7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp @@ -18,6 +18,23 @@ namespace mlir::iree_compiler { namespace { +/// Check if AllocOp has a CollapseShapeOp user. +static bool hasCollapseShapeUser(memref::AllocOp allocOp) { + SmallVector users(allocOp->getUsers()); + while (!users.empty()) { + auto user = users.pop_back_val(); + if (isa(user)) { + return true; + } + if (isa(user)) { + for (auto u : user->getUsers()) { + users.push_back(u); + } + } + } + return false; +} + /// Pad out the inner dimension of the `memref.alloc` op in order reduce the /// chances to have bank conflicts when reading 2D shapes within shared memory. static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, @@ -28,6 +45,12 @@ static void padAlloc(MLIRContext *context, memref::AllocOp allocOp, int64_t innerDim = allocOpShape.back(); if (ShapedType::isDynamic(innerDim)) return; + + // Return if we have CollapseShape op as an user as padding in that case is + // unsupported. 
+ if (hasCollapseShapeUser(allocOp)) + return; + Type elType = allocOp.getType().getElementType(); unsigned bitwidth = mlir::DataLayout::closest(allocOp).getTypeSizeInBits(elType); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir index befb2445ab24..b934772ffd34 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir @@ -47,6 +47,66 @@ func.func @pad_alloc_expand_shape(%a: memref<1024x1024xf32>) { return } +// ----- +// CHECK-LABEL: func.func @no_pad_alloc_collapse_shape +// CHECK: %[[A:.*]] = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[C:.*]] = memref.collapse_shape %[[A]] {{\[}}[0], [1, 2], [3, 4]] +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> into +// CHECK-SAME: memref<4x32x64xf32, #gpu.address_space> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VEC_READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST_0]] {in_bounds = [true]} : +// CHECK-SAME: memref<1024x1024xf32>, vector<4xf32> +// CHECK: vector.transfer_write %[[VEC_READ]], %[[C]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : +// CHECK-SAME: vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + + +func.func @no_pad_alloc_collapse_shape(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> + %1 = memref.collapse_shape %0 [[0], [1, 2], [3, 4]] + : memref<4x2x16x8x8xf32, #gpu.address_space> into memref<4x32x64xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : + memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %3, %1[%c0, %c0, %c0] {in_bounds = [true]} : + vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + return +} + +// ----- + +// CHECK-LABEL: func.func @no_pad_alloc_collapse_shape_throughsubview +// CHECK: %[[A:.*]] = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[S:.*]] = memref.subview %[[A]][0, 0, 0, 0, 0] [4, 2, 16, 8, 8] [1, 1, 1, 1, 1] : +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> to +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> +// CHECK: %[[C:.*]] = memref.collapse_shape %[[S]] {{\[}}[0], [1, 2], [3, 4]] +// CHECK-SAME: memref<4x2x16x8x8xf32, #gpu.address_space> into +// CHECK-SAME: memref<4x32x64xf32, #gpu.address_space> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VEC_READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true]} : +// CHECK-SAME: memref<1024x1024xf32>, vector<4xf32> +// CHECK: vector.transfer_write %[[VEC_READ]], %[[C]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : +// CHECK-SAME: vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + + +func.func @no_pad_alloc_collapse_shape_throughsubview(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x2x16x8x8xf32, #gpu.address_space> + %subview = memref.subview %0[0, 0, 0, 0, 0] [4, 2, 16, 8, 8] [1, 1, 1, 1, 1] + : memref<4x2x16x8x8xf32, #gpu.address_space> to memref<4x2x16x8x8xf32, #gpu.address_space> + %1 = memref.collapse_shape %subview [[0], [1, 2], [3, 4]] + : memref<4x2x16x8x8xf32, #gpu.address_space> into memref<4x32x64xf32, 
#gpu.address_space> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : + memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %3, %1[%c0, %c0, %c0] {in_bounds = [true]} : + vector<4xf32>, memref<4x32x64xf32, #gpu.address_space> + return +} + // ----- // CHECK-LABEL: func.func @pad_alloc_negative From e96e3c004011550c6b4ecc1520e0f022984c4b17 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:56:38 -0700 Subject: [PATCH 10/45] [VectorLayout] Fix insertion of new constOp for non dominate issue. (#18894) The main motivation of this patch is to resolve an issue where the same constOp is used by multiple operations, with a twist: the first time the constOp needs a layout is on an op that topologically comes after other ops that use the constOp. This generates a copy of the constOp right before that latter op, which is problematic because the copy is also used by the ops that come before it. Previously, the test added here produced this error: ``` within split at contraction_layout.mlir:1 offset :24:10: note: see current operation: %9 = "arith.addf"(%8, %6) <{fastmath = #arith.fastmath}> : (vector<96x64xf16>, vector<96x64xf16>) -> vector<96x64xf16> within split at contraction_layout.mlir:1 offset :22:19: error: operand #1 does not dominate this use %scaled_rhs = arith.mulf %read_1, %cst_1 : vector<96x64xf16> ``` While minor, this is also problematic because the error seems to stop layout analysis (though not fatally), such that vector distribution fails in some cases, making it hard to debug. Signed-off-by: Stanley Winata --- .../GPU/test/gpu_vector_distribution.mlir | 8 ++-- .../Codegen/Common/VectorLayoutAnalysis.cpp | 1 + .../Common/test/vector_layout_analysis.mlir | 41 +++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index da40806ac73c..cf47ca9d47b5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -666,10 +666,10 @@ builtin.module attributes { transform.with_named_sequence } { } // CHECK-LABEL: func.func @resolve_constant_with_multiple_layout_uses // CHECK-SAME: (%[[ARG0:.+]]: vector<64x64xf16>, %[[ARG0:.+]]: vector<64x64xf16>) -// CHECK: %[[V0:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x16xf16> -// CHECK: %[[V1:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf16> -// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[V1]]{{.*}} : vector<2x2x8xf16> -// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[V0]]{{.*}} : vector<2x2x16xf16> +// CHECK: %[[V0:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf16> +// CHECK: %[[V1:.+]] = arith.constant dense<0.000000e+00> : vector<2x2x16xf16> +// CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[V0]]{{.*}} : vector<2x2x8xf16> +// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[V1]]{{.*}} : vector<2x2x16xf16> // CHECK: arith.addf %{{.+}}, %[[ADD0]]{{.*}} : vector<2x2x8xf16> transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index
fa3786caf61d..deabc58165fb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -205,6 +205,7 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict( if (!opOperand.get().hasOneUse() && !vectorLayout && llvm::dyn_cast_or_null( opOperand.get().getDefiningOp())) { + builder.setInsertionPoint(opOperand.get().getDefiningOp()); Operation *copiedConstOp = builder.clone(*opOperand.get().getDefiningOp()); Value copiedConst = copiedConstOp->getResult(0); builder.replaceAllUsesExcept(opOperand.get(), copiedConst, diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir index 09c4d2787bf1..28e8ab01d89f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir @@ -464,6 +464,47 @@ builtin.module attributes { transform.with_named_sequence } { // ----- +#contract_layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [3, 2], + outer_tile = [4, 1], + thread_tile = [2, 32], + element_tile = [4, 1], + + subgroup_strides = [0, 0], + thread_strides = [32, 1] +> + +// This test ensures that we are not running into ops not dominating constantOp operands after layout analysis. +// We simulate that by doing elmentwise op on the value with "layout" i.e scaled_lhs after scaled_rhs. +// If not handled properly, will generate constOp before "scaled_lhs", but would get used also by "scaled_rhs". +builtin.module attributes { transform.with_named_sequence } { + func.func @handle_multiuse_constant(%lhs: vector<96x64xf16>, %rhs: vector<96x64xf16>, %arr: memref<96x64xf16>) -> () { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<1.562500e-02> : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %lhs_layout = iree_vector_ext.to_layout %lhs to layout(#contract_layout) : vector<96x64xf16> + + %scaled_rhs = arith.mulf %rhs, %cst_1 : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %scaled_lhs = arith.mulf %lhs_layout, %cst_1 : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + %add = arith.addf %scaled_lhs, %scaled_rhs : vector<96x64xf16> + // expected-remark @above {{thread_strides = [32, 1]}} + vector.transfer_write %add, %arr[%c0, %c0] {in_bounds = [true, true]} : vector<96x64xf16>, memref<96x64xf16> + func.return + } + + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op + transform.yield + } +} + +// ----- + #layout = #iree_vector_ext.nested_layout< subgroup_tile = [2, 1, 1], batch_tile = [1, 2, 4], From 3b751a4d2797d29422e08327b1a53933448a26fd Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Fri, 25 Oct 2024 08:00:21 +0530 Subject: [PATCH 11/45] [LLVMCPU] Enable tileDispatchUsingForall as default (#18777) --- .../Common/TileDispatchUsingForall.cpp | 13 +++++---- .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 28 +++++++------------ .../test/ROCDL/pipeline_tile_and_fuse.mlir | 4 +-- .../Dialect/Stream/Builtins/fill_i64.mlir | 9 +++--- .../Dialect/Stream/Builtins/splat_i64.mlir | 9 
+++--- .../DispatchCreation/FormDispatchRegions.cpp | 8 ++++++ .../onnx_ops/onnx_ops_cpu_llvm_sync.json | 7 ----- 7 files changed, 38 insertions(+), 40 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp index ebbe585bf53e..218b7f5217f1 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp @@ -202,13 +202,16 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, llvm::SmallDenseSet droppedLoops; for (auto [index, lb, ub, step] : llvm::enumerate(mixedLbs, mixedUbs, mixedSteps)) { - if (!isa(lb) || !isa(ub) || !isa(step)) { + + std::optional lbVal = getConstantIntValue(lb); + std::optional ubVal = getConstantIntValue(ub); + std::optional stepVal = getConstantIntValue(step); + + if (!(lbVal && ubVal && stepVal)) { continue; } - int64_t lbVal = getConstantIntValue(lb).value(); - int64_t ubVal = getConstantIntValue(ub).value(); - int64_t stepVal = getConstantIntValue(step).value(); - if (CEILDIV(ubVal - lbVal, stepVal) == 1) { + + if (CEILDIV(ubVal.value() - lbVal.value(), stepVal.value()) == 1) { droppedLoops.insert(index); } } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 0951fbba4273..71b3aec7389f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -95,7 +95,7 @@ static llvm::cl::opt clEnableVectorContractCustomKernels( static llvm::cl::opt clTileDispatchUsingForall( "iree-llvmcpu-tile-dispatch-using-forall", llvm::cl::desc("Enable tile and distribute to workgroups using scf.forall"), - llvm::cl::init(false)); + llvm::cl::init(true)); // By default, IREE does not enable the Armv9-A streaming SVE mode in the // presence of scalable vectors (even when using `+sme`), as currently there's @@ -111,9 +111,8 @@ static llvm::cl::opt clForceArmStreaming( llvm::cl::init(false)); // TODO: Enable `TileDispatchUsingForall` for every pipeline. -static void addTileAndDistributePasses(OpPassManager &funcPassManager, - bool enableTileDispatchUsingForall) { - if (enableTileDispatchUsingForall || clTileDispatchUsingForall) { +static void addTileAndDistributePasses(OpPassManager &funcPassManager) { + if (clTileDispatchUsingForall) { funcPassManager.addPass( createTileAndDistributeToWorkgroupsUsingForallOpPass()); } else { @@ -346,8 +345,7 @@ void buildLLVMCPUVectorLoweringPipeline( void addCPUBufferOpsTileAndVectorizePipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // Skip tiling reduction loops because this is expected to apply on copy ops // only. @@ -384,8 +382,7 @@ void addCPUBufferOpsTileAndVectorizePipeline( void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); SmallVector allFusableLevels(tilingConfig.getFusableLevels()); // Apply tile and fuse to all the non-distribution fusable levels. 
Skip @@ -464,8 +461,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager, void addConvTileAndDecomposeExpertPassPipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // Run LLVMTileAndFuse firstly in case that we have fill + conv + generic // ops. At this stage, we do not apply vectorization. The reduction dim won't @@ -528,8 +524,7 @@ void addConvTileAndDecomposeExpertPassPipeline( void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); funcPassManager.addPass(createLLVMCPUTileAndFusePass( static_cast(tilingConfig.getVectorCommonParallelLevel()))); @@ -577,8 +572,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager, void addCPUDataTilingPipeline(OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/true); + addTileAndDistributePasses(funcPassManager); // The below two passes are nop if pack/unpack is not specified in ukernels // attribute. By default, they are disabled. @@ -621,8 +615,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager, void addCPULinalgExtTileAndVectorizePipeline( OpPassManager &funcPassManager, TilingConfig &tilingConfig, LLVMCPUPipelineOptions &pipelineOpt) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/false); + addTileAndDistributePasses(funcPassManager); funcPassManager.addPass( createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel())); // TODO: Remove the pass once we have PartialReductionOpInterface implemented @@ -661,8 +654,7 @@ void addCPULinalgExtTileAndVectorizePipeline( } void addCPUDefaultPassPipeline(OpPassManager &funcPassManager) { - addTileAndDistributePasses(funcPassManager, - /*enableTileDispatchUsingForall=*/false); + addTileAndDistributePasses(funcPassManager); addCPUBufferizePasses(funcPassManager); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 912acf310b26..2ebc85496759 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -290,7 +290,7 @@ hal.executable private @main { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C720:.+]] = arith.constant 720 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK: scf.forall ({{.*}}) in (2, 4, 1, 5) { +// CHECK: scf.forall ({{.*}}) in (2, 4, 5) { // CHECK: %[[LOOP:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[C720]] step %[[C2]] {{.*}} -> (vector<1x4x1x4x4x1xf32>) // CHECK: gpu.barrier // CHECK-DAG: %[[LHS_RD:.+]] = vector.transfer_read %[[B0]]{{.*}} vector<8xf16> @@ -307,7 +307,7 @@ hal.executable private @main { // CHECK: %[[LOOP_T:.+]] = vector.transpose %[[LOOP]], [0, 1, 2, 4, 3, 5] : vector<1x4x1x4x4x1xf32> to vector<1x4x1x4x4x1xf32> // CHECK: %[[EXTRACT:.+]] = vector.extract %[[LOOP_T]][0] : vector<4x1x4x4x1xf32> from vector<1x4x1x4x4x1xf32> // CHECK: vector.transfer_write 
%[[EXTRACT]], %[[B2]] -// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} +// CHECK: } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} // ----- diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir index 96e527a20f0f..5d7d686b5bd1 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i64.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i64 { stream.executable.export public @__builtin_fill_i64 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i64(%value: i64, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i64) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count0} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir index 7d94e51a26d7..4d25d358c7b4 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i64.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i64 { stream.executable.export public @__builtin_splat_i64 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i64(%value: i64, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i64) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count0} return } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index e866022eb9a9..b38b1a593001 100644 --- 
a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -547,6 +547,14 @@ isFusableWithConsumer(OpOperand &fusedOperand, return false; } + // TODO: Enable grouped convolution and depthwise pooling fusion. + // Right now, this is going through the default CPU pipeline and not through + // CONVTilingExpert. + if (isa(producer)) { + return false; + } + auto producerFusionOp = dyn_cast(producer); auto consumerFusionOp = diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json index a025431d7af4..f8ca790fe5b1 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json @@ -392,13 +392,6 @@ "onnx/node/generated/test_softsign_example", "onnx/node/generated/test_stft", "onnx/node/generated/test_stft_with_window", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip0", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_onlybigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_only_bigrams_skip0", - "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_levelempty", - "onnx/node/generated/test_tfidfvectorizer_tf_onlybigrams_skip5", - "onnx/node/generated/test_tfidfvectorizer_tf_uniandbigrams_skip5", "onnx/node/generated/test_training_dropout", "onnx/node/generated/test_training_dropout_default", "onnx/node/generated/test_training_dropout_default_mask", From 1fc6e5b62c6fe91a8bdc387205c9d211a7a3151c Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 25 Oct 2024 09:47:11 -0400 Subject: [PATCH 12/45] Add CDNA3 MFMA BF16 intrinsics.
(#18892) Signed-off-by: Benoit Jacob --- .../ROCM/test/target_device_features.mlir | 4 +- .../GPU/test/gpu_materialize_encoding.mlir | 60 +++++++++++++++++++ .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 60 ++++++++++++++++++- .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 4 ++ .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 2 + tests/e2e/matmul/CMakeLists.txt | 29 +++++++++ 6 files changed, 155 insertions(+), 4 deletions(-) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 76818168f648..578cd5921c59 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -15,7 +15,7 @@ // GFX942: target = #iree_gpu.target, , , , , , ], +// GFX942-SAME: mma = [, , , , , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, // GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647], @@ -26,7 +26,7 @@ // GFX941-SAME: features = "+sramecc,-xnack" // GFX940: target = #iree_gpu.target, , , , , , ], +// GFX940-SAME: mma = [, , , , , , , , ], // GFX1100: target = #iree_gpu.target, , ] diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir index e20fc3b88608..90becb209c6c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir @@ -1130,3 +1130,63 @@ func.func @batch_matmul_lowering_MFMA_F32_16x16x32_F8E4M3FNUZ() { // CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +#encoding_lhs = #iree_encoding.encoding +#encoding_rhs = #iree_encoding.encoding +#encoding_result = #iree_encoding.encoding +#pipeline_layout_4 = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16() { + %c0 = arith.constant 0 : index + %B = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(0) : index + %M = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(1) : index + %N = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(2) : index + %K = hal.interface.constant.load layout(#pipeline_layout_4) ordinal(3) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(0) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %K} + %1 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(1) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %K, %N} + %2 = hal.interface.binding.subspan layout(#pipeline_layout_4) binding(2) alignment(64) offset(%c0) + : !flow.dispatch.tensor>{%B, %M, %N} + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [%B, %M, %K], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %K} + -> tensor + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [%B, %K, %N], strides = [1, 1, 1] + : 
!flow.dispatch.tensor>{%B, %K, %N} + -> tensor + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : !flow.dispatch.tensor>{%B, %M, %N} + -> tensor + %6 = linalg.batch_matmul + ins(%3, %4 : tensor, + tensor) + outs(%5 : tensor) + -> tensor + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0], sizes = [%B, %M, %N], strides = [1, 1, 1] + : tensor + -> !flow.dispatch.tensor>{%B, %M, %N} + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK: func.func @batch_matmul_lowering_MFMA_F32_16x16x16_BF16 +// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0) +// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1) +// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2) +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor +// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor +// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]], +// CHECK-SAME: iterator_types = [#iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type, #iree_gpu.iterator_type] +// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout +// CHECK: flow.dispatch.tensor.store %[[MMA]], %[[ACC_BINDING]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index d9c26ae21e8e..41c099f12809 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -214,6 +214,7 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context); Type f8E5M2FNUZ = Float8E5M2FNUZType::get(context); Type f16 = Float16Type::get(context); + Type bf16 = BFloat16Type::get(context); Type f32 = Float32Type::get(context); Type i8 = IntegerType::get(context, 8); @@ -229,6 +230,12 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, case MMAIntrinsic::MFMA_F32_32x32x8_F16: { return OpaqueMmaLayout{32, 32, 8, f16, f16, f32}; } + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { + return OpaqueMmaLayout{16, 16, 16, bf16, bf16, f32}; + } + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { + return OpaqueMmaLayout{32, 32, 8, bf16, bf16, f32}; + } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: { return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32}; } @@ -336,6 +343,45 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, bNLayout, cMLayout, cNLayout}; } + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { + // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> + // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]> + // #layout_a = #iree_vector_ext.layout<#outer, #inner> + // #layout_b = #iree_vector_ext.layout<#inner, #outer> + // #layout_c = #iree_vector_ext.layout<#inner, #outer> + + auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); + auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); + auto aMLayout = outer; 
+ auto aKLayout = inner; + auto bKLayout = inner; + auto bNLayout = outer; + auto cMLayout = inner; + auto cNLayout = outer; + return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, + bNLayout, cMLayout, cNLayout}; + } + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { + // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]> + // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]> + // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX], + // [4, 2, 4]> + // #layout_a = #iree_vector_ext.layout<#outer, #inner1> + // #layout_b = #iree_vector_ext.layout<#inner1, #outer> + // #layout_c = #iree_vector_ext.layout<#inner2, #outer> + + auto outer = PerDimLayoutAttr::get(context, {laneX}, {32}); + auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4}); + auto aMLayout = outer; + auto aKLayout = inner; + auto bKLayout = inner; + auto bNLayout = outer; + auto cMLayout = + PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4}); + auto cNLayout = outer; + return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, + bNLayout, cMLayout, cNLayout}; + } case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: { // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> @@ -462,14 +508,16 @@ MMAAttr::getABCVectorTypes() const { return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_I32_16x16x16_I8: - case MMAIntrinsic::MFMA_F32_16x16x16_F16: { + case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { auto aType = VectorType::get({4}, getAType()); auto bType = VectorType::get({4}, getBType()); auto cType = VectorType::get({4}, getCType()); return std::make_tuple(aType, bType, cType); } case MMAIntrinsic::MFMA_I32_32x32x8_I8: - case MMAIntrinsic::MFMA_F32_32x32x8_F16: { + case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { auto aType = VectorType::get({4}, getAType()); auto bType = VectorType::get({4}, getBType()); auto cType = VectorType::get({16}, getCType()); @@ -519,8 +567,10 @@ int64_t MMAAttr::getBlockSize() const { switch (getIntrinsic().getValue()) { case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: @@ -540,8 +590,10 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) { switch (intrinsic) { case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: @@ -584,6 +636,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: switch (fragment) { case MMAFragment::Lhs: return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*tstrides=*/{1, 16}, @@ -597,6 +650,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic, } case MMAIntrinsic::MFMA_I32_32x32x8_I8: case 
MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: switch (fragment) { case MMAFragment::Lhs: return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*tstrides=*/{1, 32}, @@ -704,8 +758,10 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, } case MMAIntrinsic::MFMA_I32_16x16x16_I8: case MMAIntrinsic::MFMA_F32_16x16x16_F16: + case MMAIntrinsic::MFMA_F32_16x16x16_BF16: case MMAIntrinsic::MFMA_I32_32x32x8_I8: case MMAIntrinsic::MFMA_F32_32x32x8_F16: + case MMAIntrinsic::MFMA_F32_32x32x8_BF16: case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: case MMAIntrinsic::MFMA_I32_16x16x32_I8: diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index d1c84a8d9eb1..9d4ac2e9a4e1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -121,6 +121,8 @@ class IREEGPU_I32MmaEnumAttr def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>; def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>; def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>; +def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>; +def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>; def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>; def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x0940>; def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 0x0980>; @@ -143,6 +145,8 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", MFMA_F32_16x16x4_F32, MFMA_F32_16x16x16_F16, MFMA_F32_32x32x8_F16, + MFMA_F32_16x16x16_BF16, + MFMA_F32_32x32x8_BF16, MFMA_F32_16x16x32_F8E4M3FNUZ, MFMA_F32_16x16x32_F8E5M2FNUZ, MFMA_I32_16x16x32_I8, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index c187f44b0512..5e8f031ff8ac 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -136,6 +136,8 @@ const WgpDetails *getCDNA3WgpDetails() { MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16, MMAIntrinsic::MFMA_F32_32x32x8_F16, + MMAIntrinsic::MFMA_F32_16x16x16_BF16, + MMAIntrinsic::MFMA_F32_32x32x8_BF16, MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ, MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ, MMAIntrinsic::MFMA_I32_16x16x32_I8, diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index 98a4ff19b6d4..f2294345984f 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -1570,6 +1570,35 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_bf16_cdna3_mfma_data_tiled + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=bf16" + "--acc_type=f32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + LABELS + "noasan" + "nomsan" + "notsan" 
+ "noubsan" + "requires-gpu-cdna3" +) + iree_generated_e2e_runner_test( NAME e2e_matmul_rocm_i8_cdna3_mfma_data_tiled From 1aa58257eb65db2a01227fb60b409b219f4790ff Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 25 Oct 2024 16:43:50 +0100 Subject: [PATCH 13/45] [LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (#18771) Since https://github.com/iree-org/iree/pull/18748 tensor.pad can be fused in with tiling. This patch combines the parallel and reduction padding passes into a single pass that pads at once, and the pads are later fused during tiling. --- .../LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp | 131 +++--------------- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 9 +- .../iree/compiler/Codegen/LLVMGPU/Passes.h | 4 - .../iree/compiler/Codegen/LLVMGPU/Passes.td | 13 -- .../pipeline_vector_distribute_gfx940.mlir | 8 +- .../test/promote_matmul_to_fit_mma.mlir | 129 +++-------------- 6 files changed, 41 insertions(+), 253 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp index dbcc5b1e54b6..24214940f30e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp @@ -27,25 +27,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final public: using impl::LLVMGPUPromoteMatmulToFitMMAPassBase< LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase; - explicit LLVMGPUPromoteMatmulToFitMMAPass( - const LLVMGPUMatmulPadOption &option) { - this->targetDimensions.setValue(option); - } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op, - ArrayRef paddingDims, - ArrayRef padToMultipleOf, bool noFold) const { - assert(paddingDims.size() == padToMultipleOf.size() && - "invalid pad multiples for padding dimensions"); - + ArrayRef padToMultipleOf) const { LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); - SmallVector nofoldFlags(op.getNumDpsInputs(), noFold); + SmallVector paddingDims = + llvm::to_vector(llvm::seq(padToMultipleOf.size())); SmallVector paddingValueAttributes; for (auto &operand : op->getOpOperands()) { @@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final .setPaddingDimensions(paddingDims) .setPaddingValues(paddingValueAttributes) .setPadToMultipleOf(padToMultipleOf) - .setNofoldFlags(nofoldFlags) .setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None); FailureOr result = @@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final MLIRContext *ctx = &getContext(); auto funcOp = getOperation(); - // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so - // we can kick canonicalization patterns to fold outer tensor.pad ops away. - bool noFold = false; - utils::IteratorType targetIterType = utils::IteratorType::parallel; - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n"); - targetIterType = utils::IteratorType::parallel; - noFold = false; - break; - case LLVMGPUMatmulPadOption::ReductionDims: - LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n"); - targetIterType = utils::IteratorType::reduction; - noFold = true; - break; - default: // Unreachable. 
- assert(false); - break; - }; - SmallVector candidates; funcOp->walk([&](linalg::LinalgOp op) { if (linalg::isaContractionOpInterface(op)) { @@ -101,46 +73,27 @@ class LLVMGPUPromoteMatmulToFitMMAPass final IRRewriter rewriter(ctx); for (linalg::LinalgOp op : candidates) { - SmallVector padMultiples(op.getNumLoops(), 1); auto config = dyn_cast_or_null( getLoweringConfig(op)); - if (config) { - switch (targetDimensions) { - case LLVMGPUMatmulPadOption::ParallelDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Workgroup), op); - break; - case LLVMGPUMatmulPadOption::ReductionDims: - padMultiples = config.getStaticTilingLevelSizes( - static_cast(IREE::GPU::TilingLevel::Reduction), op); - break; - default: - assert(false && "Unexpected target dimensions"); - break; - } + if (!config) { + continue; } - // Populate padding dimensions. - SmallVector paddingDimensions; - for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) { - if (iter == targetIterType) { - paddingDimensions.push_back(idx); - } - } + SmallVector wgTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Workgroup), op); + SmallVector redTiles = config.getStaticTilingLevelSizes( + static_cast(IREE::GPU::TilingLevel::Reduction), op); - // Populate tile sizes. We pad to multiples of workgroup/reduction - // tile sizes based on the selected target tiling dimensions. - // This pass is ran after the select target tiling is done to pad - // all dimensions to the select tile sizes. - SmallVector padToMultipleOf; - for (int64_t dim : paddingDimensions) { - if (padMultiples[dim] != 0) { - padToMultipleOf.push_back(padMultiples[dim]); - } + // Populate padding dimensions to maximum of possible tile sizes. + SmallVector padToMultipleOf(op.getNumLoops(), 1); + for (auto [wgTile, redTile, padMultiple] : + llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) { + padMultiple = std::max({wgTile, redTile, padMultiple}); } + SmallVector paddingDimensions = + llvm::to_vector(llvm::seq(op.getNumLoops())); - padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf, - noFold); + padWithZeroValue(rewriter, op, padToMultipleOf); } { @@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final return signalPassFailure(); } } - - // XXX(hanchung): This is needed for pad op fusion, which will remove - // outer pad ops. I.e., it mainly wants to remove first pad op in the - // pad->extract_slice->pad chain, while the canonicalization pattern can - // only recognize slice->pad->slice->pad. 
- { - SmallVector padOps; - funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); }); - for (auto op : padOps) { - auto srcExtractSliceOp = - op.getSource().getDefiningOp(); - if (!srcExtractSliceOp) { - continue; - } - auto producerPadOp = - srcExtractSliceOp.getSource().getDefiningOp(); - if (!producerPadOp) { - continue; - } - auto src = producerPadOp.getSource() - .getDefiningOp(); - if (!src) { - continue; - } - - rewriter.setInsertionPointAfter(src); - SmallVector sizes = - tensor::getMixedSizes(rewriter, op.getLoc(), src); - SmallVector offsets(sizes.size(), - rewriter.getIndexAttr(0)); - SmallVector strides(sizes.size(), - rewriter.getIndexAttr(1)); - auto extractSliceOp = rewriter.create( - op.getLoc(), src.getResult(), offsets, sizes, strides); - rewriter.startOpModification(op); - producerPadOp.getSourceMutable().assign(extractSliceOp.getResult()); - rewriter.finalizeOpModification(op); - } - - RewritePatternSet patterns(ctx); - tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - } } }; } // namespace -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) { - return std::make_unique(option); -} - } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 76b1af3204be..51fcc6b3996c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -858,25 +858,20 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createCSEPass()); if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); + funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass()); } // Tile to reduction loops. { GPUApplyTilingLevelPassOptions options; options.tilingLevel = IREE::GPU::TilingLevel::Reduction; + options.allowZeroSlices = true; funcPassManager.addPass(createGPUApplyTilingLevelPass(options)); funcPassManager.addPass(affine::createLoopCoalescingPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); } - if (usePadToModelSharedMemcpy) { - LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims; - funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option)); - } - funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass()); funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index c1181776e8f7..d9325647a50d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -103,10 +103,6 @@ verifyGPUMatmulPipeline(Operation *op, // Wrappers that not use tablegen options. 
//------------------------------------------------------------------------------ -enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims }; -std::unique_ptr> -createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option); - enum class GPUTensorCoreType { WMMA = 0, MMA_SYNC = 1, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index ef51a6a9a883..815a82f28d8d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -105,19 +105,6 @@ def LLVMGPUPrefetchSharedMemoryPass : def LLVMGPUPromoteMatmulToFitMMAPass : InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> { let summary = "Pass to promote contraction ops to fit mma shapes"; - let options = [ - Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption", - /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims", - "Select the strategy to control how multi_reduction is lowered.", - [{::llvm::cl::values( - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims, - "parallel", - "Pad all the parallel dims for contraction ops."), - clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims, - "reduction", - "Pad all the reduction dims for contraction ops.") - )}]> - ]; } def LLVMGPUSelectLoweringStrategyPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index 610e114c81e3..d21faf8867b1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -511,7 +511,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]} // CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]] // CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]] -// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) +// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>) // CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]] // CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]] // CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]] @@ -581,9 +581,11 @@ hal.executable public @pad_batch_matmul { // CHECK-SAME: memref<196x16x24xf32 // CHECK-SAME: vector<1x1x1xf32> // RHS +// The dynamic dimension should be removed after: +// https://github.com/llvm/llvm-project/pull/112236 // CHECK: vector.transfer_read -// CHECK-SAME: in_bounds = [true, true, false] -// CHECK-SAME: memref<1x8x24xf32 +// CHECK-SAME: in_bounds = [true, false, false] +// CHECK-SAME: memref<1x?x24xf32 // CHECK-SAME: vector<1x1x2xf32> // CHECK: scf.yield // OUTPUT diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir index bda4836eaec3..21bc2fc3cac3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir @@ -1,5 +1,4 @@ -// 
RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))" %s | FileCheck %s --check-prefixes=ALL,PARALLEL -// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma))" %s | FileCheck %s #pipeline_layout = #hal.pipeline.layout, @@ -34,114 +33,20 @@ func.func @batch_matmul_f16() { flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> return } -// ALL-LABEL: func.func @batch_matmul_f16 -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] -// PARALLEL: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// PARALLEL: } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> -// PARALLEL: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// PARALLEL: } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> -// PARALLEL: %[[FILL:.+]] = linalg.fill -// PARALLEL: %[[GEMM:.+]] = linalg.batch_matmul -// PARALLEL-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// PARALLEL-SAME: outs(%[[FILL]] +// CHECK-LABEL: func.func @batch_matmul_f16 +// CHECK: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> +// CHECK: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] +// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] +// CHECK: } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16> +// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] +// CHECK: } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16> +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK: %[[GEMM:.+]] = linalg.batch_matmul +// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] +// CHECK-SAME: outs(%[[FILL]] -// The reduction dim is not tiled in the test case, so it pads it to the -// matmul intrinsic k. 
-// REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]] -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]] -// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16> -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]] -// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[FILL]] - -// ALL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] -// ALL: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] - -// ----- - -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -#map = affine_map<()[s0] -> (s0 * 64)> -#map1 = affine_map<()[s0] -> (s0 * 128)> -#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)> -#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)> -#map4 = affine_map<()[s0] -> (-s0 + 64)> -#map5 = affine_map<()[s0] -> (-s0 + 128)> -#map6 = affine_map<(d0) -> (-d0 + 1281, 64)> -func.func @batch_matmul_pad_reduction_after_tiling() { - %c64 = arith.constant 64 : index - %c1281 = arith.constant 1281 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.000000e+00 : f16 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %workgroup_id_z = hal.interface.workgroup.id[2] : index - %workgroup_id_y = hal.interface.workgroup.id[1] : index - %3 = affine.apply #map()[%workgroup_id_y] - %workgroup_id_x = hal.interface.workgroup.id[0] : index - %4 = affine.apply #map1()[%workgroup_id_x] - %5 = affine.min #map2()[%workgroup_id_y] - %6 = affine.min #map3()[%workgroup_id_x] - %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x?x1281xf16> - %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16> - %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1281x?xf16> - %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16> - %9 = affine.apply #map4()[%5] - %padded = tensor.pad %7 low[0, 0, 0] high[0, %9, 0] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16> - %10 = affine.apply #map5()[%6] - %padded_2 = tensor.pad %8 low[0, 0, 0] high[0, 0, %10] { - ^bb0(%arg0: index, %arg1: index, %arg2: index): - tensor.yield %cst : f16 - } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16> - %11 = tensor.empty() : tensor<1x64x128xf16> - %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) { - %14 = affine.min #map6(%arg0) - %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16> - %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16> - %15 = 
linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16> - scf.yield %15 : tensor<1x64x128xf16> - } - %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16> - flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor> - return -} -// The padding on parallel dims is a nop because they are already padded. Skip -// the check for the testcase. -// ALL-LABEL: func.func @batch_matmul_pad_reduction_after_tiling -// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> -// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> -// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]] -// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]] -// REDUCTION: %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16> -// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]] -// REDUCTION: %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]]) -// REDUCTION: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x64xf16> -// REDUCTION: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]] -// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x128xf16> -// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul -// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]] -// REDUCTION-SAME: outs(%[[ITER]] -// REDUCTION: scf.yield %[[GEMM]] -// REDUCTION: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]] -// REDUCTION: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] +// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]] +// CHECK: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]] From c6b3592d44274a3f90bcf4655d05876c0c2def76 Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:30:10 -0700 Subject: [PATCH 14/45] [Dispatch Creation] Bubble up ExtractSliceOp with FillOp when the latter has multiple consumers (#18896) Signed-off-by: nithinsubbiah --- .../BubbleUpExtractSlices.cpp | 26 +++++++++++++++++++ .../test/bubble_up_extract_slice.mlir | 21 +++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp index 4b03b60f5779..59ddb577fb51 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp @@ -115,6 +115,30 @@ struct BubbleUpExtract : OpRewritePattern { } }; +/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst, +/// tensor.extract_slice(%init)) even when the linalg.fill has multiple users. +/// Bubbles up tensor.extract_slice when encountered with linalg.fill and the +/// former can be folded away. 
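+///
+/// Illustrative example (hypothetical IR, not taken from this patch's tests):
+///   %fill  = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x4xf32>) -> tensor<4x4xf32>
+///   %slice = tensor.extract_slice %fill[0, 0] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+/// is rewritten into:
+///   %slice_of_init = tensor.extract_slice %empty[0, 0] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+///   %slice = linalg.fill ins(%cst : f32) outs(%slice_of_init : tensor<2x2xf32>) -> tensor<2x2xf32>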
+struct SwapExtractSliceOfFill final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, + PatternRewriter &rewriter) const override { + auto fillOp = extractOp.getSource().getDefiningOp(); + if (!fillOp) + return failure(); + + auto newExtractOp = rewriter.create( + extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], + extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides()); + rewriter.replaceOpWithNewOp( + extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()}); + return success(); + } +}; + struct BubbleUpExtractSlicesPass : impl::BubbleUpExtractSlicesPassBase { void runOnOperation() override { @@ -122,6 +146,8 @@ struct BubbleUpExtractSlicesPass { RewritePatternSet patterns(context); patterns.insert(context); + patterns.insert(context); + tensor::populateFoldTensorEmptyPatterns(patterns, false); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir index a5b7ea13ee27..56fa91d7b2d6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir @@ -94,3 +94,24 @@ util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> ( // CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>) // CHECK: util.return %[[GENERIC1]], %[[GENERIC0]] + +util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> { + %cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ + %cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ + %1 = tensor.empty() : tensor<2x320x128x128xf8E4M3FNUZ> + %2 = linalg.fill ins(%cst_2 : f8E4M3FNUZ) outs(%1 : tensor<2x320x128x128xf8E4M3FNUZ>) -> tensor<2x320x128x128xf8E4M3FNUZ> + %3 = tensor.empty() : tensor<2x320x130x130xf8E4M3FNUZ> + %4 = linalg.fill ins(%cst_1 : f8E4M3FNUZ) outs(%3 : tensor<2x320x130x130xf8E4M3FNUZ>) -> tensor<2x320x130x130xf8E4M3FNUZ> + %extracted_slice_1 = tensor.extract_slice %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x130x130xf8E4M3FNUZ> to tensor<2x320x128x130xf8E4M3FNUZ> + %inserted_slice_1 = tensor.insert_slice %2 into %extracted_slice_1[0, 0, 0, 1] [2, 320, 128, 128] [1, 1, 1, 1] : tensor<2x320x128x128xf8E4M3FNUZ> into tensor<2x320x128x130xf8E4M3FNUZ> + %inserted_slice_2 = tensor.insert_slice %inserted_slice_1 into %4[0, 0, 1, 0] [2, 320, 128, 130] [1, 1, 1, 1] : tensor<2x320x128x130xf8E4M3FNUZ> into tensor<2x320x130x130xf8E4M3FNUZ> + util.return %inserted_slice_2 : tensor<2x320x130x130xf8E4M3FNUZ> +} + +// CHECK-LABEL: @bubble_up_extract_fill_multi_use +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK: %[[EMPTY1:.+]] = tensor.empty +// CHECK: %[[FILL2:.+]] = linalg.fill +// CHECK-NOT: %[[SLICE:.+]] = tensor.extract_slice +// CHECK: %[[EMPTY2:.+]] = tensor.empty +// CHECK: %[[FILL3:.+]] = linalg.fill From 55c55620b4c1f611b9a656c7e6bc115acbd765e7 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 25 Oct 2024 10:36:33 -0700 Subject: [PATCH 15/45] [LLVMGPU][NFC] Create LLVMGPU pass for IGEMM (#18871) This PR refactors the ConvolutionToIGEMM pass to a shared transform function, and creates a new pass for LLVMGPU. 
This keeps the lowering config details in LLVMGPU separate from the common pass, and removes the need for passing a control function or config function in the pass constructor. This is also a precursor to adding some more complex logic in the control function for LLVMGPU, which will be added in a later PR. --------- Signed-off-by: Max Dawkins --- .../Codegen/Common/ConvolutionToIGEMM.cpp | 162 +++++++++--------- .../src/iree/compiler/Codegen/Common/Passes.h | 7 - .../iree/compiler/Codegen/Common/Passes.td | 4 + .../iree/compiler/Codegen/Common/Transforms.h | 11 ++ .../Common/test/convolution_to_igemm.mlir | 19 -- .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 1 + .../LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp | 66 +++++++ .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 26 +-- .../iree/compiler/Codegen/LLVMGPU/Passes.td | 11 ++ .../compiler/Codegen/LLVMGPU/test/BUILD.bazel | 1 + .../Codegen/LLVMGPU/test/CMakeLists.txt | 1 + .../test/llvmgpu_convolution_to_igemm.mlir | 36 ++++ .../Transforms/ConvertConv2DToIm2ColOp.cpp | 14 +- .../Dialect/LinalgExt/Transforms/Passes.h | 6 +- 15 files changed, 230 insertions(+), 136 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp index 58b678ce6588..8998b11ccee4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" @@ -12,6 +13,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" @@ -26,10 +28,14 @@ namespace { using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect; +/// Pattern to set a lowering configuration on an IGEMM convolution. Searches +/// for a contraction with a linalg_ext.im2col producer, and calls the configFn +/// to set the configuration. +/// TODO(Max191): Use a funcOp walk instead of a pattern for this. 
struct SetIGEMMConfiguration final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn) + SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn) : OpRewritePattern(context), configFn(configFn) {} LogicalResult matchAndRewrite(linalg::GenericOp genericOp, @@ -67,7 +73,7 @@ struct SetIGEMMConfiguration final : OpRewritePattern { } private: - ConfigFn configFn; + IGEMMConfigFn configFn; }; class ConvolutionToIGEMMPass final @@ -75,91 +81,87 @@ class ConvolutionToIGEMMPass final public: using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase; - explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {} + ConvolutionToIGEMMPass(std::optional configFn, + std::optional controlFn) + : configFn(configFn), controlFn(controlFn) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { - MLIRContext *context = &getContext(); - - // Rewrite convolutions into a im2col and GEMM. - { - auto conv2dToIm2colControlFn = [](Operation *conv) { - // Don't transform convolutions that have a preset lowering config. - if (getLoweringConfig(conv)) { - return false; - } - return true; - }; - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns( - patterns, conv2dToIm2colControlFn); - patterns.add(context, configFn); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } - - // The im2col transformation collapses some of the dimensions of the - // convolution operands. Try to push the reshape ops towards the boundaries - // of the function and fold with interface tensor ops. - // - // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and - // generate a multi-M dim contraction instead of collapsing and - // propagating reshapes. It should ultimately become a pass option to - // decide whether to collapse the contraction dimensions into a single - // M/N/K dimension. - { - RewritePatternSet bubbleCollapseShapePatterns(context); - linalg::ControlFusionFn bubbleUpExpansionControlFn = - [](OpOperand *fusedOperand) { - Operation *producer = fusedOperand->get().getDefiningOp(); - Operation *consumer = fusedOperand->getOwner(); - - // Block only if one of the operations has a lowering configuration - // which means it likely expects tiling specific to its original - // shape. - if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { - return false; - } - return true; - }; - linalg::populateFoldReshapeOpsByCollapsingPatterns( - bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); - // Add patterns to do some additional cleanup (on top of canonicalizations - // that can be done later) of reshape ops. 
- tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); - linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, - context); - tensor::CollapseShapeOp::getCanonicalizationPatterns( - bubbleCollapseShapePatterns, context); - tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, - context); - tensor::ExpandShapeOp::getCanonicalizationPatterns( - bubbleCollapseShapePatterns, context); - populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), std::move(bubbleCollapseShapePatterns)))) { - return signalPassFailure(); - } - } - } + void runOnOperation() override; private: - ConfigFn configFn = [](linalg::GenericOp genericOp, - IREE::LinalgExt::Im2colOp im2colOp) { - return failure(); - }; + std::optional configFn; + std::optional controlFn; }; } // namespace -std::unique_ptr> -createConvolutionToIGEMMPass(ConfigFn configFn) { - return std::make_unique(configFn); +LogicalResult +convertToIGEMMAndSetConfig(FunctionOpInterface funcOp, + std::optional configFn, + std::optional controlFn) { + // Rewrite convolutions into a im2col and GEMM. + MLIRContext *context = funcOp->getContext(); + { + RewritePatternSet patterns(context); + iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns, + controlFn); + if (configFn.has_value()) { + patterns.add(context, configFn.value()); + } + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return failure(); + } + } + + // The im2col transformation collapses some of the dimensions of the + // convolution operands. Try to push the reshape ops towards the boundaries + // of the function and fold with interface tensor ops. + // + // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and + // generate a multi-M dim contraction instead of collapsing and + // propagating reshapes. It should ultimately become a pass option to + // decide whether to collapse the contraction dimensions into a single + // M/N/K dimension. + { + RewritePatternSet bubbleCollapseShapePatterns(context); + linalg::ControlFusionFn bubbleUpExpansionControlFn = + [](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + Operation *consumer = fusedOperand->getOwner(); + + // Block only if one of the operations has a lowering configuration + // which means it likely expects tiling specific to its original + // shape. + if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { + return false; + } + return true; + }; + linalg::populateFoldReshapeOpsByCollapsingPatterns( + bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); + // Add patterns to do some additional cleanup (on top of canonicalizations + // that can be done later) of reshape ops. 
+ tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); + linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::ExpandShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); + if (failed(applyPatternsAndFoldGreedily( + funcOp, std::move(bubbleCollapseShapePatterns)))) { + return failure(); + } + } + return success(); +} + +void ConvolutionToIGEMMPass::runOnOperation() { + if (failed(convertToIGEMMAndSetConfig(getOperation()))) { + return signalPassFailure(); + } } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index 94192d52a76b..eac457dc6280 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -60,13 +60,6 @@ std::unique_ptr> createConvertToDestinationPassingStylePass( bool useWARForCooperativeMatrixCodegen); -using ConfigFn = - std::function; -/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is -/// used to set lowering configurations on the resulting ops, if necessary. -std::unique_ptr> -createConvolutionToIGEMMPass(ConfigFn configFn); - std::unique_ptr createDecomposeSoftmaxPass(bool useFusion); /// Pass to perform linalg on tensor bufferization. The function passed into diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index ff281d6e385d..6a5a9b5578c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -83,6 +83,10 @@ def ConvolutionToIGEMMPass : InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> { let summary = "Transforms convolution operations into an implicit GEMM format."; + let dependentDialects = [ + "tensor::TensorDialect", + "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect" + ]; } def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> { diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h index 13cdbf577363..0a000348e22e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h @@ -18,6 +18,17 @@ struct OneShotBufferizationOptions; namespace mlir::iree_compiler { +using IGEMMConfigFn = + std::function; +using IGEMMControlFn = std::function; + +/// Converts conv_2d ops into linalg_ext.im2col + matmul, and sets a lowering +/// configuration on the matmul. +LogicalResult convertToIGEMMAndSetConfig( + FunctionOpInterface funcOp, + std::optional configFn = std::nullopt, + std::optional controlFn = std::nullopt); + /// Eliminates tensor.empty ops to avoid buffer allocations. 
LogicalResult eliminateEmptyTensors( RewriterBase &rewriter, Operation *op, diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir index 3d5494e79244..3373fda8c326 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir @@ -69,25 +69,6 @@ module { // ----- -#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> -#config = #iree_codegen.lowering_config -func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %cst = arith.constant 0.0 : f32 - %empty = tensor.empty() : tensor<1x14x14x16xf32> - %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, - dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) - outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - return %0 : tensor<1x14x14x16xf32> -} -// CHECK: func.func public @conv_with_lowering_config -// CHECK-NOT: iree_linalg_ext.im2col -// CHECK: linalg.conv_2d_nhwc_hwcf -// CHECK-SAME: lowering_config - -// ----- - #map = affine_map<(d0, d1, d2) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index b074612adbc5..3d8c7a2088b0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -95,6 +95,7 @@ iree_compiler_cc_library( "LLVMGPUCastTypeToFitMMA.cpp", "LLVMGPUConfigureTensorLayouts.cpp", "LLVMGPUConfigureVectorLayouts.cpp", + "LLVMGPUConvolutionToIGEMM.cpp", "LLVMGPULowerExecutableTarget.cpp", "LLVMGPUPackSharedMemoryAlloc.cpp", "LLVMGPUPrefetching.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 6a92f60d7f04..9016d63b6f24 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -80,6 +80,7 @@ iree_cc_library( "LLVMGPUCastTypeToFitMMA.cpp" "LLVMGPUConfigureTensorLayouts.cpp" "LLVMGPUConfigureVectorLayouts.cpp" + "LLVMGPUConvolutionToIGEMM.cpp" "LLVMGPULowerExecutableTarget.cpp" "LLVMGPUPackSharedMemoryAlloc.cpp" "LLVMGPUPrefetching.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp new file mode 100644 index 000000000000..b88696ab8f63 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConvolutionToIGEMM.cpp @@ -0,0 +1,66 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/Transforms.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +/// Function for setting lowering configurations on contractions resulting from +/// the IGEMM transformation. This currently uses the TileAndFuse pipeline, and +/// tries to target MMA intrinsics. +static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp, + IREE::LinalgExt::Im2colOp im2colOp) { + auto funcOp = genericOp->getParentOfType(); + if (!funcOp) { + return genericOp.emitError("cannot find parent funcOp"); + } + IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); + if (!target) { + return funcOp.emitError("missing GPU target in parent funcOp"); + } + if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) { + return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp); + } + return success(); +} + +static bool llvmgpuControlFn(Operation *op) { + // Do not convert anything that already has a lowering configuration. + if (getLoweringConfig(op)) { + return false; + } + return true; +} + +struct LLVMGPUConvolutionToIGEMMPass final + : impl::LLVMGPUConvolutionToIGEMMPassBase { + using impl::LLVMGPUConvolutionToIGEMMPassBase< + LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase; + + void runOnOperation() override; +}; + +void LLVMGPUConvolutionToIGEMMPass::runOnOperation() { + if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn, + llvmgpuControlFn))) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 51fcc6b3996c..aab73c952c5f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1170,29 +1170,12 @@ void addGPUTransformDialectPasses(OpPassManager &funcPassManager, // Common Pass Pipelines //===----------------------------------------------------------------------===// -static LogicalResult igemmConfigFn(linalg::GenericOp genericOp, - IREE::LinalgExt::Im2colOp im2colOp) { - auto funcOp = genericOp->getParentOfType(); - if (!funcOp) { - return genericOp.emitError("cannot find parent funcOp"); - } - IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp); - if (!target) { - return funcOp.emitError("missing GPU target in parent funcOp"); - } - if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) { - return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp); - } - return success(); -} - static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( OpPassManager &modulePassManager) { { FunctionLikeNest funcPassManager(modulePassManager); - funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() { - return createConvolutionToIGEMMPass(igemmConfigFn); - }); + funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, + createLLVMGPUConvolutionToIGEMMPass); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); 
addCommonTargetExecutablePreprocessingPasses(funcPassManager); addEncodingToNopPasses(funcPassManager); @@ -1242,9 +1225,8 @@ static void buildROCDLCodegenConfigurationPassPipelineImpl( OpPassManager &modulePassManager) { { FunctionLikeNest funcPassManager(modulePassManager); - funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, []() { - return createConvolutionToIGEMMPass(igemmConfigFn); - }); + funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, + createLLVMGPUConvolutionToIGEMMPass); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index 815a82f28d8d..aa6b55253734 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -87,6 +87,17 @@ def LLVMGPUConfigureVectorLayoutsPass : let summary = "Pass to set layouts for vector distribution"; } +def LLVMGPUConvolutionToIGEMMPass : + InterfacePass<"iree-llvmgpu-convolution-to-igemm", "mlir::FunctionOpInterface"> { + let summary = "Pass to convert conv_2d ops to igemm and set a lowering configuration."; + let dependentDialects = [ + "tensor::TensorDialect", + "iree_compiler::IREE::Codegen::IREECodegenDialect", + "iree_compiler::IREE::GPU::IREEGPUDialect", + "iree_compiler::IREE::LinalgExt::IREELinalgExtDialect" + ]; +} + def LLVMGPULowerExecutableTargetPass : InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 00bc6f967acf..40973205380e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -49,6 +49,7 @@ iree_lit_test_suite( "legalize.mlir", "linalg_transform.mlir", "llvmgpu_bufferize.mlir", + "llvmgpu_convolution_to_igemm.mlir", "pack_pipeline_test.mlir", "pack_shared_memory_alloc.mlir", "prefetch_shared_memory.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 6be97c06d533..2a86fd3507f4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -40,6 +40,7 @@ iree_lit_test_suite( "legalize.mlir" "linalg_transform.mlir" "llvmgpu_bufferize.mlir" + "llvmgpu_convolution_to_igemm.mlir" "nvvm_extract_address_computation.mlir" "nvvm_mma_sync_pipeline_test.mlir" "nvvm_pipeline_test.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir new file mode 100644 index 000000000000..1fa2bae99a8e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir @@ -0,0 +1,36 @@ +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --pass-pipeline="builtin.module(func.func(iree-llvmgpu-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s + +#config = #iree_codegen.lowering_config +func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { + %cst = arith.constant 0.0 : f32 + %empty = 
tensor.empty() : tensor<1x14x14x16xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, + dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) + outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} +// CHECK: func.func public @conv_with_lowering_config +// CHECK-NOT: iree_linalg_ext.im2col +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: lowering_config + +// ----- + +func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: tensor<3x3x128x128xf32>) -> tensor<1x32x32x128xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1x32x32x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32> + %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%arg0, %arg1: tensor<1x34x34x128xf32>, tensor<3x3x128x128xf32>) + outs(%fill: tensor<1x32x32x128xf32>) -> tensor<1x32x32x128xf32> + return %0 : tensor<1x32x32x128xf32> +} +// CHECK: func.func public @set_lowering_config +// CHECK: iree_linalg_ext.im2col +// CHECK: linalg.generic +// CHECK-SAME: lowering_config = #iree_gpu.lowering_config< +// CHECK-SAME: {mma_kind = #iree_gpu.mma_layout, +// CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8], +// CHECK-SAME: subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp index 6e699fda1f2e..131ff3e5437b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp @@ -38,7 +38,7 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { namespace { -using ControlFnTy = std::optional>; +using ControlFnTy = std::function; // Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) // and linalg.matmul. 
@@ -78,7 +78,8 @@ class ConvertConv2DNhwcHwcf final public: using OpRewritePattern::OpRewritePattern; - ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn) + ConvertConv2DNhwcHwcf(MLIRContext *context, + std::optional controlFn) : OpRewritePattern(context), controlFn(controlFn) {} @@ -192,7 +193,7 @@ class ConvertConv2DNhwcHwcf final } private: - ControlFnTy controlFn; + std::optional controlFn; }; // For nchw, because the channels are to the left of the image shape dimensions, @@ -204,7 +205,8 @@ class ConvertConv2DNchwFchw final public: using OpRewritePattern::OpRewritePattern; - ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn) + ConvertConv2DNchwFchw(MLIRContext *context, + std::optional controlFn) : OpRewritePattern(context), controlFn(controlFn) {} @@ -314,7 +316,7 @@ class ConvertConv2DNchwFchw final } private: - ControlFnTy controlFn; + std::optional controlFn; }; struct ConvertConv2DToIm2ColOpPass final @@ -335,7 +337,7 @@ struct ConvertConv2DToIm2ColOpPass final } // namespace void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns, - ControlFnTy controlFn) { + std::optional controlFn) { patterns.insert( patterns.getContext(), std::move(controlFn)); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h index 1e858df14e2f..cc894b3edecd 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h @@ -28,8 +28,10 @@ LogicalResult splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp, const TopkSplitReductionControlFn &splitReductionFn); -// Patterns to convert linalg convolution ops into a gemm with an im2col -// op and reshapes on the inputs. +/// Patterns to convert linalg convolution ops into a gemm with an im2col +/// op and reshapes on the inputs. +/// TODO(Max191): Maybe move to transforms and use a funcOp walk instead of a +/// rewrite pattern for this. void populateConv2DToIm2colOpPatterns( RewritePatternSet &patterns, std::optional> controlFn = std::nullopt); From 0c2c627747586ed39ce7b1f6bfc9d8b83c4a4e69 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:11:03 -0700 Subject: [PATCH 16/45] [NFC] Update old naming from flow to dispatch creation (#18904) Update naming from `Flow` -> `DispatchCreation` in Passes.cpp Signed-off-by: Ian Wood --- .../src/iree/compiler/DispatchCreation/Passes.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index d5150ccae7c1..afee21cbbcd8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -49,10 +49,10 @@ static llvm::cl::opt clEnableFusePaddingIntoLinalgProducerOps( static llvm::cl::opt clPadFactor( "iree-dispatch-creation-pad-factor", - llvm::cl::desc( - "Provides padding size hints that will be attached to " - "encodings. This only affects the experimental data tiling " - "path in Flow with iree-dispatch-creation-experimental-data-tiling."), + llvm::cl::desc("Provides padding size hints that will be attached to " + "encodings. 
This only affects the experimental data tiling "
+                       "path in DispatchCreation with "
+                       "iree-dispatch-creation-experimental-data-tiling."),
     llvm::cl::init(32));
 
 static llvm::cl::opt clEnablePadHandling(
@@ -337,14 +337,14 @@ void registerDispatchCreationPasses() {
 }
 
 void registerDispatchCreationPipelines() {
-  PassPipelineRegistration flowDispatchRegionCreationPipeline(
+  PassPipelineRegistration dispatchCreationPipeline(
       "iree-dispatch-creation-pipeline",
       "Flag used to run passes that form dispatch regions",
       [](OpPassManager &passManager,
          const TransformOptions &transformOptions) {
        buildDispatchCreationPassPipeline(passManager, transformOptions);
      });
 
-  PassPipelineRegistration<> flowDispatchRegionFormationPreprocessingPipeline(
+  PassPipelineRegistration<> dispatchCreationPreprocessingPipeline(
       "iree-dispatch-creation-preprocessing-pipeline",
       "Flag used to run preprocessing passes that run passes before dispatch "
       "region formation. Used only for testing",

From 03c744ead1482abde3ee9e70293215c5b557c629 Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Fri, 25 Oct 2024 16:37:27 -0700
Subject: [PATCH 17/45] [GPU] Support multiple contraction dims in MmaSchedules
 (#18720)

This adds support for multiple M, N, and K dims in problems when
deducing a GPUMMASchedule. The new heuristic is similar to the old one,
but works on pairs of M and N dims. For example:
```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xM0xN1xN0>
```
This will try to distribute the seeded tile counts to `M0` and `N0`
(first attempting to distribute evenly, and then distributing to N
followed by M), and then distribute the residual counts to `M1` and
`N1`. The K tile counts will be partitioned to `K0` first, and then the
residual tile counts will be partitioned to `K1`.

This PR also updates the config selection logic for the TileAndFuse
pipeline to make use of the multiple contraction dimensions in mma
schedules. 
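
To make the distribution step concrete, here is a minimal, self-contained
sketch of how a single dimension consumes the shared subgroup and tile
budgets. The helper name `distributeOneDim` is hypothetical; the actual
logic lives in `getOptimalMMASchedule` in GPUHeuristics.cpp and additionally
handles the even split across an M/N pair and the unaligned fallback.
```
// Illustrative sketch only: greedy GCD-based split of one problem dimension
// into a subgroup count and a per-subgroup tile size.
#include <cstdint>
#include <numeric>

struct DimDistribution {
  int64_t subgroupCount;
  int64_t tileSize;
};

// `tileCount` is the number of MMA intrinsics needed along this dimension.
static DimDistribution distributeOneDim(int64_t tileCount,
                                        int64_t &remainingSubgroups,
                                        int64_t &remainingTiles) {
  DimDistribution dist;
  // Give the dimension as many subgroups as evenly divide its tile count.
  dist.subgroupCount = std::gcd(tileCount, remainingSubgroups);
  tileCount /= dist.subgroupCount;
  remainingSubgroups /= dist.subgroupCount;
  // Then give each subgroup as many tiles as evenly divide what is left.
  dist.tileSize = std::gcd(tileCount, remainingTiles);
  remainingTiles /= dist.tileSize;
  return dist;
}
```
Because the walk starts at the innermost M/N pair, outer dimensions only
receive whatever subgroup and tile budget remains after the inner pair has
been distributed.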
--------- Signed-off-by: Max Dawkins --- .../Codegen/Common/GPU/GPUHeuristics.cpp | 361 ++++++++++++------ .../Codegen/Common/GPU/GPUHeuristics.h | 58 ++- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 98 +++-- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 79 ++-- .../test/ROCDL/config_tile_and_fuse.mlir | 72 +++- .../test/llvmgpu_convolution_to_igemm.mlir | 2 +- .../compiler/Codegen/SPIRV/KernelConfig.cpp | 24 +- .../Preprocessing/Common/PadToIntrinsics.cpp | 30 +- 8 files changed, 492 insertions(+), 232 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp index dc3078372f92..790484d2c565 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp @@ -9,6 +9,7 @@ #include #include "llvm/ADT/APInt.h" +#include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -20,51 +21,106 @@ using llvm::APIntOps::GreatestCommonDivisor; namespace mlir::iree_compiler { +template static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const GPUMMASchedule &schedule) { - os << "mSize: " << schedule.mSize << ", "; - os << "nSize: " << schedule.nSize << ", "; - os << "kSize: " << schedule.kSize << ", "; - os << "mTileCount: " << schedule.mTileCount << ", "; - os << "nTileCount: " << schedule.nTileCount << ", "; - os << "kTileCount: " << schedule.kTileCount << ", "; - os << "mWarpCount: " << schedule.mWarpCount << ", "; - os << "nWarpCount: " << schedule.nWarpCount; + const llvm::SmallVectorImpl &vector) { + os << "["; + llvm::interleaveComma(vector, os); + os << "]"; return os; } +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const GPUMMASchedule &schedule) { + os << "mSizes: " << schedule.mSize << ", "; + os << "nSizes: " << schedule.nSize << ", "; + os << "kSizes: " << schedule.kSize << ", "; + os << "mTileSizes: " << schedule.mTileSizes << ", "; + os << "nTileSizes: " << schedule.nTileSizes << ", "; + os << "kTileSizes: " << schedule.kTileSizes << ", "; + os << "mSubgroupCounts: " << schedule.mSubgroupCounts << ", "; + os << "nSubgroupCounts: " << schedule.nSubgroupCounts; + return os; +} + +// Shortened helper to compute the product of `values`. +static int64_t prod(ArrayRef values) { + return ShapedType::getNumElements(values); +} + static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule, int64_t lhsBitwidth, int64_t rhsBitwidth) { - int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount; - int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount; - int64_t tileK = schedule.kSize * schedule.kTileCount; + + int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) * + prod(schedule.mSubgroupCounts); + int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) * + prod(schedule.nSubgroupCounts); + int64_t tileK = schedule.kSize * prod(schedule.kTileSizes); return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8; } +/// Check that a GPUMMASchedule fits alignment restrictions. To be aligned, +/// the problem must be evenly divisible by the number of elements in the +/// schedule for each dimension. If `mustBeAligned` is false, then the innermost +/// problem dimension is allowed to be unaligned . 
static bool isScheduleAligned(const GPUMatmulShapeType &problem, const GPUMMASchedule &schedule, bool mustBeAligned) { - auto alignedMSize = - mustBeAligned - ? problem.mSize - : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize; - auto alignedNSize = - mustBeAligned - ? problem.nSize - : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize; - auto alignedKSize = - mustBeAligned - ? problem.kSize - : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize; - bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount * - schedule.mWarpCount)) == 0; - bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount * - schedule.nWarpCount)) == 0; - bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0; + SmallVector alignedMSizes(problem.mSizes); + alignedMSizes.back() = + mustBeAligned ? problem.mSizes.back() + : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) * + schedule.mSize; + SmallVector alignedNSizes(problem.nSizes); + alignedNSizes.back() = + mustBeAligned ? problem.nSizes.back() + : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) * + schedule.nSize; + SmallVector alignedKSizes(problem.kSizes); + alignedKSizes.back() = + mustBeAligned ? problem.kSizes.back() + : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) * + schedule.kSize; + // Returns the number of elements in the schedule for each dimension. + auto getScheduleSizes = + [&](int64_t size, SmallVector tileCount, + std::optional> subgroupCount) { + SmallVector sizes = llvm::map_to_vector( + llvm::seq(tileCount.size()), [&](int64_t i) { + return subgroupCount ? tileCount[i] * subgroupCount.value()[i] + : tileCount[i]; + }); + sizes.back() *= size; + return sizes; + }; + // Checks whether the elements of `a` are evenly divisible by the + // corresponding elements of `b`. + auto areAligned = [](SmallVector a, SmallVector b) { + for (auto [aVal, bVal] : llvm::zip_equal(a, b)) { + if (aVal % bVal != 0) { + return false; + } + } + return true; + }; + bool isValidM = areAligned( + alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes, + schedule.mSubgroupCounts)); + bool isValidN = areAligned( + alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes, + schedule.nSubgroupCounts)); + bool isValidK = areAligned( + alignedKSizes, + getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt)); return isValidM && isValidN && isValidK; } +/// Returns whether or not a GPUMMASchedule is valid for the given problem. +/// This checks that: +/// - The problem is aligned to the schedule +/// - the number of threads in the schedule workgroup can be distributed +/// to a corresponding vector.transfer read in VectorDistribute. 
static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const GPUMMASchedule &schedule, bool mustBeAligned, int64_t subgroupSize, @@ -76,11 +132,13 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const int64_t kMaxVectorLoadBitWidth = 128; int64_t elemsPerThread = kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth(); - int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize; - - int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount; - int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount; - int64_t kWgSize = schedule.kSize * schedule.kTileCount; + int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) * + prod(schedule.nSubgroupCounts); + int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) * + prod(schedule.mSubgroupCounts); + int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) * + prod(schedule.nSubgroupCounts); + int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes); int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize; int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize; @@ -94,6 +152,10 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, return isAligned && isDistributableLhs && isDistributableRhs; } +/// Tries to fit the schedule into shared memory by decrementing the size of the +/// schedule dimensions from outermost to innermost until a valid schedule is +/// found. The schedule sizes are reduced in the order of mTileSizes, +/// nTileSizes, kTileSizes, mSubgroupCounts, nSubgroupCounts. static FailureOr fitScheduleInSharedMemory( GPUMatmulShapeType intrinsic, GPUMMASchedule schedule, llvm::function_ref isScheduleValid) { @@ -105,31 +167,35 @@ static FailureOr fitScheduleInSharedMemory( llvm::dbgs() << "Shrinking schedule...\n"; }); - auto decrementIfPossible = [](int64_t &c) -> LogicalResult { - if (c <= 1) { - return failure(); + auto decrementIfPossible = + [](SmallVector &sizes) -> LogicalResult { + for (int64_t &size : sizes) { + if (size <= 1) + continue; + --size; + return success(); } - --c; - return success(); + return failure(); }; // Attempt to shrink the schedule along one of the dimensions. // TODO: A better solution should probably factor problem.mSize / - // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors - // one at a time, preferably trying to keep the tile "generally square." - if (succeeded(decrementIfPossible(schedule.mTileCount))) { + // (mSubgroupCount * mTileCount * mSize) and then pop off the smallest + // factors one at a time, preferably trying to keep the tile "generally + // square." 
+ if (succeeded(decrementIfPossible(schedule.mTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.nTileCount))) { + if (succeeded(decrementIfPossible(schedule.nTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.kTileCount))) { + if (succeeded(decrementIfPossible(schedule.kTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.mWarpCount))) { + if (succeeded(decrementIfPossible(schedule.mSubgroupCounts))) { continue; } - if (succeeded(decrementIfPossible(schedule.nWarpCount))) { + if (succeeded(decrementIfPossible(schedule.nSubgroupCounts))) { continue; } @@ -148,6 +214,9 @@ static FailureOr fitScheduleInSharedMemory( static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, bool canUpcastAcc, bool mustBeAligned) { + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) { return failure(); // Cannot use this intrinsic for mismatched types } @@ -161,17 +230,17 @@ static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, } } - if (mustBeAligned && (problem.mSize % intrinsic.mSize != 0 || - problem.nSize % intrinsic.nSize != 0 || - problem.kSize % intrinsic.kSize != 0)) { + if (mustBeAligned && (problem.mSizes.back() % intrinsic.mSizes[0] != 0 || + problem.nSizes.back() % intrinsic.nSizes[0] != 0 || + problem.kSizes.back() % intrinsic.kSizes[0] != 0)) { return failure(); // Cannot use this intrinsic for misaligned cases. } // Cannot use the intrinsic when the tile size is greater than problem size. // Because tiling is a no-op, and we can't infer tiling sizes from IR. - if (!mustBeAligned && - (problem.mSize < intrinsic.mSize || problem.nSize < intrinsic.nSize || - problem.kSize < intrinsic.kSize)) { + if (!mustBeAligned && (problem.mSizes.back() < intrinsic.mSizes[0] || + problem.nSizes.back() < intrinsic.nSizes[0] || + problem.kSizes.back() < intrinsic.kSizes[0])) { return failure(); } @@ -185,77 +254,123 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, const GPUMMAHeuristicSeeds &seeds, uint64_t intrinsicIndex) { - int64_t mTotalTileCount = llvm::divideCeil(problem.mSize, intrinsic.mSize); - int64_t nTotalTileCount = llvm::divideCeil(problem.nSize, intrinsic.nSize); - - int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup; + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); + // mTotalTileCounts and nTotalTileCounts represent the total number of + // intrinsics along the M or N dimensions needed to fill the problem size. 
+ // For example, if the problem is {M:[4, 16], N:[2, 32], K[3, 128]} for a + // 16x16x16 intrinsic, then: + // - mTotalTileCounts would be 4 * (16/16) = 4 + // - nTotalTileCounts would be 2 * (32/16) = 4 + SmallVector mTotalTileCounts = problem.mSizes; + SmallVector nTotalTileCounts = problem.nSizes; + mTotalTileCounts.back() = + llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]); + nTotalTileCounts.back() = + llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]); + + int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup; int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup; - // Assign more warps to the M dimension (used later) to balance thread + // Assign more subgroups to the M dimension (used later) to balance thread // counts along X and Y dimensions. - int64_t warpSqrt = - 1ull << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2)); - int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); - - int64_t mWarpCount = 0, nWarpCount = 0; - int64_t mTileCount = 0, nTileCount = 0; - - // See if the square root can divide mTotalTileCount. If so it means we can - // distribute to both dimensions evenly. Otherwise, try to distribute to N - // and then M. - if (mTotalTileCount > (warpSqrt * tileSqrt) && - mTotalTileCount % (warpSqrt * tileSqrt) == 0) { - mWarpCount = warpSqrt; - mTileCount = tileSqrt; - - remainingWarps /= warpSqrt; - remainingTiles /= tileSqrt; - - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - } else { - APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingWarps)); - nWarpCount = nGCD.getSExtValue(); - nTotalTileCount /= nWarpCount; - remainingWarps /= nWarpCount; - - nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), - APInt(64, remainingTiles)); - nTileCount = nGCD.getSExtValue(); - remainingTiles /= nTileCount; - - APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingWarps)); - mWarpCount = mGCD.getSExtValue(); - mTotalTileCount /= mWarpCount; - remainingWarps /= mWarpCount; - - mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), - APInt(64, remainingTiles)); - mTileCount = mGCD.getSExtValue(); + int mDim = problem.mSizes.size() - 1; + int nDim = problem.nSizes.size() - 1; + SmallVector mTileSizes(problem.mSizes.size(), 0), + nTileSizes(problem.nSizes.size(), 0), + mSubgroupCounts(problem.mSizes.size(), 0), + nSubgroupCounts(problem.nSizes.size(), 0); + // Start at the innermost nDim and mDim, and try to distribute evenly to M and + // N for each pair of M and N dims. Otherwise, distribute to N and then M. + while (mDim >= 0 || nDim >= 0) { + int64_t subgroupSqrt = + 1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2)); + int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); + + // See if the square root can divide mTotalTileCount. If so it means we can + // distribute to both dimensions evenly to minimize the number of global + // loads. Otherwise, try to distribute to N and then M. 
+ if (mDim >= 0 && nDim >= 0 && + mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) && + mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) { + mSubgroupCounts[mDim] = subgroupSqrt; + mTileSizes[mDim] = tileSqrt; + + remainingSubgroups /= subgroupSqrt; + remainingTiles /= tileSqrt; + + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } else { + if (nDim >= 0) { + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } + + if (mDim >= 0) { + APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingSubgroups)); + mSubgroupCounts[mDim] = mGCD.getSExtValue(); + mTotalTileCounts[mDim] /= mSubgroupCounts[mDim]; + remainingSubgroups /= mSubgroupCounts[mDim]; + + mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingTiles)); + mTileSizes[mDim] = mGCD.getSExtValue(); + remainingTiles /= mTileSizes[mDim]; + } + } + --mDim; + --nDim; } - const uint64_t kTotalTileCount = - llvm::divideCeil(problem.kSize, intrinsic.kSize); + // kTotalTileCounts is similar to m/nTotalTileCounts, representing the total + // number of intrinsics along the K dimensions needed to fill the problem. + // For the problem described above {M:[4, 16], N:[2, 32], K[3, 128]} with a + // 16x16x16 intrinsic, then: + // - kTotalTileCounts would be 3 * (128/16) = 24 + SmallVector kTotalTileCounts = problem.kSizes; + kTotalTileCounts.back() = + llvm::divideCeil(problem.kSizes.back(), intrinsic.kSizes[0]); + // Compute the ideal number of intrinsics along K per subgroup based on the + // seed. int64_t bestKTileCountPerSubgroup = seeds.bestKElementCountPerSubgroup ? llvm::divideCeil(seeds.bestKElementCountPerSubgroup, - intrinsic.kSize) + intrinsic.kSizes[0]) : seeds.bestKTileCountPerSubgroup; - APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount), - APInt(64, bestKTileCountPerSubgroup)); - int64_t kTileCount = kGCD.getSExtValue(); + SmallVector kTileSizes(problem.kSizes.size(), 0); + // Start at the innermost K dim, and tile each dim to try to satisfy the ideal + // K intrinsic count per subgroup with the overall product of K tile counts. 
+ int kDim = problem.kSizes.size() - 1; + while (kDim >= 0) { + APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCounts[kDim]), + APInt(64, bestKTileCountPerSubgroup)); + kTileSizes[kDim] = kGCD.getSExtValue(); + bestKTileCountPerSubgroup /= kTileSizes[kDim]; + --kDim; + } - return GPUMMASchedule{intrinsicIndex, intrinsic.mSize, intrinsic.nSize, - intrinsic.kSize, mWarpCount, nWarpCount, - mTileCount, nTileCount, kTileCount}; + return GPUMMASchedule{ + intrinsicIndex, intrinsic.mSizes[0], intrinsic.nSizes[0], + intrinsic.kSizes[0], mSubgroupCounts, nSubgroupCounts, + mTileSizes, nTileSizes, kTileSizes}; } FailureOr deduceMMASchedule( @@ -297,7 +412,6 @@ FailureOr deduceMMASchedule( return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; }; - return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule); } return failure(); @@ -309,7 +423,10 @@ FailureOr deduceAttentionSchedule( const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes, int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV, bool canUpcastAcc, bool mustBeAligned) { - + assert(pvMatmul.mSizes.size() == 1 && pvMatmul.nSizes.size() == 1 && + pvMatmul.kSizes.size() == 1 && qkMatmul.mSizes.size() == 1 && + qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 && + "unimplemented: multi M/N/K attention schedule"); for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) { if (failed(canTargetIntrinsic(qkMatmul, intrinsic, canUpcastAcc, mustBeAligned))) { @@ -329,7 +446,7 @@ FailureOr deduceAttentionSchedule( llvm::dbgs() << " " << schedule << "\n"; }); - int64_t intrinsicK = intrinsic.kSize; + int64_t intrinsicK = intrinsic.kSizes[0]; auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool { // Create a mma schedule for qkMatmul in attention. // qkMatmul.M = pvMatmul.M @@ -339,11 +456,11 @@ FailureOr deduceAttentionSchedule( schedule.mSize, schedule.kSize, intrinsicK, - /*mWarpCount=*/schedule.mWarpCount, - /*nWarpCount=*/1, - schedule.mTileCount, - schedule.kTileCount, - qkMatmul.kSize / intrinsicK}; + /*mSubgroupCount=*/schedule.mSubgroupCounts[0], + /*nSubgroupCount=*/1, + schedule.mTileSizes[0], + schedule.kTileSizes[0], + qkMatmul.kSizes[0] / intrinsicK}; bool isQKAligned = isValidMMASchedule(qkMatmul, qkSchedule, mustBeAligned, subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h index 8211443a2e12..13f6a56c1b6f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -10,15 +10,18 @@ namespace mlir::iree_compiler { /// Struct containing information about a matmul's shape and type. struct GPUMatmulShapeType { - int64_t mSize; - int64_t nSize; - int64_t kSize; + SmallVector mSizes; + SmallVector nSizes; + SmallVector kSizes; Type aType; Type bType; Type cType; GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c) - : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {} + : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {} + GPUMatmulShapeType(SmallVector m, SmallVector n, + SmallVector k, Type a, Type b, Type c) + : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {} }; /// Struct containing seed tile sizes for GPU MMA heuristics deduction logic. 
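
To make the new multi-dimensional schedule deduction concrete, the following standalone C++ sketch shows the GCD-based budgeting idea used by the new getOptimalMMASchedule: per-dimension intrinsic tile counts are peeled, innermost first, against the remaining subgroup and tile budgets. It is illustrative only; the names, the simplified loop, and the lack of the interleaved M/N and square-seeking handling are assumptions, not the IREE implementation.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Simplified stand-in for one dimension group (M or N) of a GPUMMASchedule.
struct DimSchedule {
  std::vector<int64_t> subgroupCounts;
  std::vector<int64_t> tileSizes;
};

// Walk a dimension group from innermost to outermost, peeling factors of the
// remaining subgroup and tile budgets off each per-dim intrinsic tile count.
// This is a sketch of the GCD-based budgeting only; the real heuristic also
// interleaves M and N and tries to keep tiles "generally square".
static DimSchedule distribute(std::vector<int64_t> tileCounts,
                              int64_t &remainingSubgroups,
                              int64_t &remainingTiles) {
  DimSchedule s;
  s.subgroupCounts.assign(tileCounts.size(), 1);
  s.tileSizes.assign(tileCounts.size(), 1);
  for (int i = static_cast<int>(tileCounts.size()) - 1; i >= 0; --i) {
    int64_t sg = std::gcd(tileCounts[i], remainingSubgroups);
    s.subgroupCounts[i] = sg;
    tileCounts[i] /= sg;
    remainingSubgroups /= sg;

    int64_t tile = std::gcd(tileCounts[i], remainingTiles);
    s.tileSizes[i] = tile;
    remainingTiles /= tile;
  }
  return s;
}

int main() {
  // Problem {M:[4, 16], N:[2, 32]} with a 16x16x16 intrinsic: the per-dim
  // intrinsic tile counts are M = [4, 16/16] = [4, 1], N = [2, 32/16] = [2, 2].
  std::vector<int64_t> mCounts = {4, 1}, nCounts = {2, 2};
  int64_t remainingSubgroups = 4; // plays the role of bestSubgroupCountPerWorkgroup
  int64_t remainingTiles = 8;     // plays the role of bestMNTileCountPerSubgroup
  DimSchedule n = distribute(nCounts, remainingSubgroups, remainingTiles);
  DimSchedule m = distribute(mCounts, remainingSubgroups, remainingTiles);
  std::cout << "mSubgroupCounts = [" << m.subgroupCounts[0] << ", "
            << m.subgroupCounts[1] << "], mTileSizes = [" << m.tileSizes[0]
            << ", " << m.tileSizes[1] << "]\n";
  std::cout << "nSubgroupCounts = [" << n.subgroupCounts[0] << ", "
            << n.subgroupCounts[1] << "], nTileSizes = [" << n.tileSizes[0]
            << ", " << n.tileSizes[1] << "]\n";
  return 0;
}
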
@@ -38,14 +41,42 @@ struct GPUMMAHeuristicSeeds {
 struct GPUMMASchedule {
   // Index of the chosen intrinsic into the list of given MMA intrinsics
   uint64_t index;
-  int64_t mSize; // Native MMA size along M dimension
-  int64_t nSize; // Native MMA size along N dimension
-  int64_t kSize; // Native MMA size along K dimension
-  int64_t mWarpCount; // Number of subgroups along M dimension
-  int64_t nWarpCount; // Number of subgroups along N dimension
-  int64_t mTileCount; // Number of tiles per subgroup along M dimension
-  int64_t nTileCount; // Number of tiles per subgroup along N dimension
-  int64_t kTileCount; // Number of tiles along K dimension
+  int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup.
+  int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup.
+  int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup.
+
+  // Number of subgroups along each M and N dimension.
+  SmallVector<int64_t> mSubgroupCounts;
+  SmallVector<int64_t> nSubgroupCounts;
+
+  // Tile sizes for each M, N, and K dimension. When there are multiple M, N,
+  // or K dimensions, the intrinsic sizes are targeted to the innermost
+  // dimension, and the outer dimensions can be thought of as unrolling factors
+  // along M, N, or K.
+  SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
+  SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
+  SmallVector<int64_t> kTileSizes; // K tile sizes.
+
+  // Constructor for multi M, N, K dim schedules.
+  GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+                 int64_t kIntrinsicSize, SmallVector<int64_t> mSubgroupCounts,
+                 SmallVector<int64_t> nSubgroupCounts,
+                 SmallVector<int64_t> mTileSizes,
+                 SmallVector<int64_t> nTileSizes,
+                 SmallVector<int64_t> kTileSizes)
+      : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+        kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
+        nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
+        nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}
+
+  // Constructor for single M, N, K dim schedules. 
+ GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize, + int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup, + int64_t mTileSize, int64_t nTileSize, int64_t kTileSize) + : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize), + kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}), + nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}), + nTileSizes({nTileSize}), kTileSizes({kTileSize}) {} }; /// Returns a schedule for using one of the given MMA |intrinsics| to target the @@ -69,4 +100,7 @@ FailureOr deduceAttentionSchedule( bool transposedV = false, bool canUpcastAcc = false, bool mustBeAligned = true); +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const GPUMMASchedule &schedule); + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index ca23b0ca6e06..58bfdc0a028b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, return failure(); } - // For now we are not being smart and trying to reshape dimensions to allow - // for better usage of intrinsics, and instead are tiling all dimensions - // except the inner most m, n, and k dimensions to 1. - int64_t mDim = contractionDims.m.back(); - int64_t nDim = contractionDims.n.back(); - int64_t kDim = contractionDims.k.back(); - - // Dynamic dims are expected to be taken care of earlier in the pipeline. - if (ShapedType::isDynamic(bounds[mDim]) || - ShapedType::isDynamic(bounds[nDim]) || - ShapedType::isDynamic(bounds[kDim])) { + // TODO(Max191): add dynamic shape support for inner most dims. + if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) || + ShapedType::isDynamic(bounds[contractionDims.n.back()]) || + ShapedType::isDynamic(bounds[contractionDims.k.back()])) { return failure(); } + // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic + // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when + // computing an MMA schedule. 
+ SmallVector mDims, nDims, kDims; + for (auto mDim : contractionDims.m) { + if (!ShapedType::isDynamic(bounds[mDim])) { + mDims.push_back(mDim); + } + } + for (auto nDim : contractionDims.n) { + if (!ShapedType::isDynamic(bounds[nDim])) { + nDims.push_back(nDim); + } + } + for (auto kDim : contractionDims.k) { + if (!ShapedType::isDynamic(bounds[kDim])) { + kDims.push_back(kDim); + } + } + + auto getDimBounds = [&](SmallVector dims) -> SmallVector { + return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; }); + }; + Value lhs = linalgOp.getDpsInputOperand(0)->get(); Value rhs = linalgOp.getDpsInputOperand(1)->get(); Value init = linalgOp.getDpsInitOperand(0)->get(); @@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); - GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], - lhsElemType, rhsElemType, initElemType}; + GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims), + getDimBounds(kDims), lhsElemType, + rhsElemType, initElemType}; SmallVector intrinsics; for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { @@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= 512 * 512) { + int64_t mSize = ShapedType::getNumElements(problem.mSizes); + int64_t nSize = ShapedType::getNumElements(problem.nSizes); + if (mSize * nSize <= 512 * 512) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // TODO: Drop this. This is only a consideration for other pipelines. SmallVector maps = linalgOp.getIndexingMapsArray(); bool transposedLhs = - kDim != + kDims.back() != llvm::cast(maps[0].getResults().back()).getPosition(); bool transposedRhs = - nDim != + nDims.back() != llvm::cast(maps[1].getResults().back()).getPosition(); // First try to find a schedule with an exactly matching intrinsic. @@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + int64_t flatWorkgroupSize = + targetSubgroupSize * + ShapedType::getNumElements(schedule->nSubgroupCounts) * + ShapedType::getNumElements(schedule->mSubgroupCounts); + std::array workgroupSize{flatWorkgroupSize, 1, 1}; SmallVector workgroupTileSizes(linalgOp.getNumLoops(), 0); SmallVector reductionTileSizes(linalgOp.getNumLoops(), 0); @@ -244,16 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, reductionTileSizes[k] = 1; } - // Compute the M/N dimension tile size by multiplying subgroup information. 
- workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount; - workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount; - - // Specify the subgroup tile sizes from the mma schedule. This is applied - subgroupTileSizes[mDim] = schedule->mTileCount; - subgroupTileSizes[nDim] = schedule->nTileCount; + // Adjust the inner bound size for packing to intrinsic shapes, since tiling + // happens after packing. + assert(bounds[mDims.back()] % schedule->mSize == 0 && + bounds[nDims.back()] % schedule->nSize == 0 && + "expected inner bound to be evenly divisible by schedule sizes."); + bounds[mDims.back()] /= schedule->mSize; + bounds[nDims.back()] /= schedule->nSize; + + // Compute the M/N dimension tile sizes by multiplying subgroup information. + for (auto [i, mDim] : llvm::enumerate(mDims)) { + workgroupTileSizes[mDim] = + schedule->mSubgroupCounts[i] * schedule->mTileSizes[i]; + subgroupTileSizes[mDim] = schedule->mTileSizes[i]; + } + for (auto [i, nDim] : llvm::enumerate(nDims)) { + workgroupTileSizes[nDim] = + schedule->nSubgroupCounts[i] * schedule->nTileSizes[i]; + subgroupTileSizes[nDim] = schedule->nTileSizes[i]; + } // Similarly the reduction tile size is just the post-packing tile count. - reductionTileSizes[kDim] = schedule->kTileCount; + for (auto [i, kDim] : llvm::enumerate(kDims)) { + reductionTileSizes[kDim] = schedule->kTileSizes[i]; + } IREE::GPU::MmaInterfaceAttr mmaKind = target.getWgp().getMma()[schedule->index]; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ff002ace5b0f..4b64cda3adc9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -301,6 +301,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -339,8 +344,9 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, return failure(); } - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -360,11 +366,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, } // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; // Tile all filter loop dimensions to 1. 
for (int64_t filterDim : convolutionDims->filterLoop) { @@ -386,8 +392,8 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -489,6 +495,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]); } + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -509,7 +520,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= clGPUMatmulCThreshold) { + if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -573,16 +584,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -605,11 +611,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(), *contractionDims, workgroupTileSizes)); @@ -631,8 +637,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. 
SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -772,22 +778,17 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // TODO: Due to a bug in layout configuration, we cannot set warp count on // the N dimension. This is however ok, because we generally do not want to // distribute subgroups on N dimension anyway. - if (schedule->nWarpCount != 1) { - schedule->nTileCount *= schedule->nWarpCount; - schedule->nWarpCount = 1; + if (schedule->nSubgroupCounts[0] != 1) { + schedule->nTileSizes[0] *= schedule->nSubgroupCounts[0]; + schedule->nSubgroupCounts[0] = 1; } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(opInfo.getDomainRank(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -811,11 +812,11 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize; MLIRContext *context = op.getContext(); SmallVector attrs; @@ -831,8 +832,8 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. 
SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index b98e85a79713..819b8826bb1d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -37,11 +37,79 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 0, 0, 4] -// CHECK-SAME: subgroup = [0, 0, 4, 1, 0] +// CHECK-SAME: subgroup = [1, 1, 4, 1, 0] // CHECK-SAME: workgroup = [1, 1, 4, 4, 0] // ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : tensor<10x4x32x32xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16> + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor<10x4x32x32xf16> + return %7 : tensor<10x4x32x32xf16> +} + +// CHECK-LABEL: func.func @multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1] +// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [2, 2, 2, 2, 0, 0] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> +func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor, %rhs: tensor) -> tensor { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %d0 = tensor.dim %lhs, %c0 : tensor + %d2 = tensor.dim %rhs, %c0 : tensor + %5 = tensor.empty(%d0, %d2) : tensor + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor) -> tensor + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%6 : tensor) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor + return %7 : tensor +} + +// CHECK-LABEL: func.func 
@dynamic_multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 1, 1] +// CHECK-SAME: subgroup = [0, 1, 0, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [1, 2, 1, 1, 2, 0, 0] + +// ----- + func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>) -> tensor<1024x1024xf32> { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index @@ -52,7 +120,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< } // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024 -// CHECK-SAME: #iree_codegen.translation_info // Verify that the fill does not have the lowering config propagated to it. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir index 1fa2bae99a8e..9618281c699e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir @@ -33,4 +33,4 @@ func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: ten // CHECK-SAME: lowering_config = #iree_gpu.lowering_config< // CHECK-SAME: {mma_kind = #iree_gpu.mma_layout, // CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8], -// CHECK-SAME: subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> +// CHECK-SAME: subgroup = [1, 1, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 16a1acf4316f..bbdec5c83f6d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -884,6 +884,11 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, Type lhsElem = getElementType(lhs); Type rhsElem = getElementType(rhs); Type initElem = getElementType(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. 
GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem); SmallVector intrinsics; @@ -921,8 +926,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; - std::array workgroupSize{schedule->nWarpCount * subgroupSize, - schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + subgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector vectorSizes(kIndex + 1, 0); if (isBM) @@ -934,21 +940,23 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, SmallVector subgroupTileSizes(lastParallelDim + 1, 0); if (isBM) subgroupTileSizes[bIndex] = 1; - subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex]; - subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex]; + subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex]; + subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); if (isBM) workgroupTileSizes[bIndex] = 1; - workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex]; - workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex]; + workgroupTileSizes[mIndex] = + schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex]; + workgroupTileSizes[nIndex] = + schedule->nSubgroupCounts[0] * subgroupTileSizes[nIndex]; // Also create one level for reduction. This is needed because of // SPIRVTileAndPromotePass requires it. // TODO(#10499): Consolidate tiling configuration across different pipelines. SmallVector reductionTileSizes; reductionTileSizes.append(kIndex, 0); - reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize); + reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize); TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes, reductionTileSizes, vectorSizes}; @@ -956,7 +964,7 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, // Don't do multibuffering if the inner reduction loop is folded out. 
auto pipelineDepth = softwarePipelineDepth; auto storeStage = softwarePipelineStoreStage; - if (schedule->kTileCount <= 1) { + if (schedule->kTileSizes[0] <= 1) { pipelineDepth = 0; storeStage = 0; } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index ba415b3fb656..922e50882775 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -242,16 +242,16 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, return llvm::divideCeil(value, padTo) * padTo - value; }; - if (mSize % intrinsic.mSize != 0) { - mPadding = getPadding(mSize, intrinsic.mSize); + if (mSize % intrinsic.mSizes[0] != 0) { + mPadding = getPadding(mSize, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0) { - nPadding = getPadding(nSize, intrinsic.nSize); + if (nSize % intrinsic.nSizes[0] != 0) { + nPadding = getPadding(nSize, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0) { - kPadding = getPadding(kSize, intrinsic.kSize); + if (kSize % intrinsic.kSizes[0] != 0) { + kPadding = getPadding(kSize, intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) { @@ -381,7 +381,7 @@ static void padContractionLikeOp( for (GPUMatmulShapeType &intrinsic : intrinsics) { std::optional mPadding, nPadding, kPadding; SmallVector> dimsToExpandCandidate; - if (mSize % intrinsic.mSize != 0 || ShapedType::isDynamic(mSize)) { + if (mSize % intrinsic.mSizes[0] != 0 || ShapedType::isDynamic(mSize)) { OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize); if (ShapedType::isDynamic(mSize)) { auto mOperandDimPair = getSrcOperandAndDim(mDim); @@ -390,12 +390,12 @@ static void padContractionLikeOp( auto [mOperand, mOperandDim] = mOperandDimPair.value(); mSizeExpr = rewriter.create(loc, mOperand, mOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSize); + dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSizes[0]); } - mPadding = getPadding(mSizeExpr, intrinsic.mSize); + mPadding = getPadding(mSizeExpr, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0 || ShapedType::isDynamic(nSize)) { + if (nSize % intrinsic.nSizes[0] != 0 || ShapedType::isDynamic(nSize)) { OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize); if (ShapedType::isDynamic(nSize)) { auto nOperandDimPair = getSrcOperandAndDim(nDim); @@ -404,12 +404,12 @@ static void padContractionLikeOp( auto [nOperand, nOperandDim] = nOperandDimPair.value(); nSizeExpr = rewriter.create(loc, nOperand, nOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSize); + dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSizes[0]); } - nPadding = getPadding(nSizeExpr, intrinsic.nSize); + nPadding = getPadding(nSizeExpr, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0 || ShapedType::isDynamic(kSize)) { + if (kSize % intrinsic.kSizes[0] != 0 || ShapedType::isDynamic(kSize)) { OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize); if (ShapedType::isDynamic(kSize)) { auto kOperandDimPair = getSrcOperandAndDim(kDim); @@ -418,9 +418,9 @@ static void padContractionLikeOp( auto [kOperand, kOperandDim] = kOperandDimPair.value(); kSizeExpr = rewriter.create(loc, kOperand, kOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSize); + dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSizes[0]); } - kPadding = getPadding(kSizeExpr, intrinsic.kSize); + kPadding = getPadding(kSizeExpr, 
intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) { From 9731fed369ef59ee1e120f9f1315a7b9981e515c Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Mon, 28 Oct 2024 07:00:57 -0700 Subject: [PATCH 18/45] Pass to block dynamic dimensions of operands of `iree_linalg_ext.attention`. (#18874) The use of `IntegerRangeAnalysis` and `IntegerDivisibilityAnalysis` gives range and divisibility information for constants passed to the dispatch. This can be used to infer the range and divisibility information for all tensor values in the dispatch. This PR adds an analysis to do this. This analysis is then used to expand the dimensions of operands of the attention operation that are dynamic, but are known to be divisible by a compile-time static value. This gets the operations into a form that can be compiled by the AMDGPU backend and target the mfma intrinsics. Signed-off-by: MaheshRavishankar --------- Signed-off-by: MaheshRavishankar --- .../iree/compiler/Codegen/Common/BUILD.bazel | 5 + .../Codegen/Common/BlockDynamicDimensions.cpp | 302 ++++++++++++++++++ .../compiler/Codegen/Common/CMakeLists.txt | 5 + .../iree/compiler/Codegen/Common/Passes.td | 6 + .../Common/TensorDynamicDimAnalysis.cpp | 236 ++++++++++++++ .../Codegen/Common/TensorDynamicDimAnalysis.h | 65 ++++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../Common/test/block_dynamic_dims.mlir | 101 ++++++ .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 3 + .../iree/compiler/Dialect/Flow/IR/FlowOps.cpp | 21 ++ .../iree/compiler/Dialect/Flow/IR/FlowOps.td | 7 +- 12 files changed, 752 insertions(+), 1 deletion(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 7aca986d540b..d9d23b22dc31 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -86,6 +86,7 @@ iree_compiler_cc_library( name = "Common", srcs = [ "AddFastMathFlags.cpp", + "BlockDynamicDimensions.cpp", "BubbleUpOrdinalOps.cpp", "BufferizationAnalysis.cpp", "BufferizeCopyOnlyDispatchesPass.cpp", @@ -137,6 +138,7 @@ iree_compiler_cc_library( "RemoveSingleIterationLoop.cpp", "ReplaceSlowMinMaxOps.cpp", "SplitFullPartialTransferPass.cpp", + "TensorDynamicDimAnalysis.cpp", "TensorToVectorVectorizePad.cpp", "TestExecutablePreprocessing.cpp", "TestPartitionableLoopsInterface.cpp", @@ -155,6 +157,7 @@ iree_compiler_cc_library( "ExtractAddressComputation.h", "PassUtils.h", "Passes.h", + "TensorDynamicDimAnalysis.h", "TileSizeSelection.h", "Transforms.h", "UserConfig.h", @@ -176,6 +179,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", + "//compiler/src/iree/compiler/Dialect/Util/Analysis", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect", @@ -191,6 +195,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:BufferizationDialect", 
"@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:BufferizationTransforms", + "@llvm-project//mlir:DestinationStyleOpInterface", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp new file mode 100644 index 000000000000..7a45116f0abd --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp @@ -0,0 +1,302 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" +#include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-codegen-block-dynamic-dimensions" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_BLOCKDYNAMICDIMENSIONSPASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +using TensorDivisibilityInfo = + llvm::SmallDenseMap; + +namespace { + +struct RemoveOptimizationBarrier final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::Util::OptimizationBarrierOp barrierOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(barrierOp, barrierOp.getOperands()); + return success(); + } +}; + +/// This pass is used to materialize information about dynamic dimensions of +/// `tensor` operands of an operation in the IR. If a dynamic dimension is +/// known to be a multiple of a compile-time constant value, this pass +/// expands the shape of the operands. For example if a `tensor` operand +/// is of shape `tensor<...x?x...>` and that dimension is known to be a +/// multiple of 16, this operand is expanded to `tensor<...x?x16x...>` where the +/// size of the new dynamic dimension is 1/16-th the size of the original +/// dynamic dimension size. This is done in two steps. +/// 1) Replace operands with such dynamic dimension with the result of a +/// `tensor.expand_shape/tensor.collapse_shape` pair +/// to materialize the new static dimension and immediately fold it away. A +/// optimization barrier is added in between to prevent these operations from +/// being folded. +/// 2) Use patterns that propagate the `tensor.collapse_shape` down to +/// manipulate the operation appropriately. This +/// allows re-using the (fairly complex) logic used to expand dimensions of +/// operations implemented in the propagation patterns. +/// At the end of the pass the optimization barriers are removed to fold away +/// any un-propagated `tensor.expand_shape/tensor.collapse_shape` patterns. +struct BlockDynamicDimensionsPass final + : impl::BlockDynamicDimensionsPassBase { + void runOnOperation() override; +}; +} // namespace + +/// Retrieve the divisibility information for dynamic dimensions of `v` if +/// known. 
+static TensorDivisibilityInfo +getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis, + Value v) { + TensorDivisibilityInfo divisibilityInfo; + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return divisibilityInfo; + } + + for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { + if (!tensorType.isDynamicDim(index)) + continue; + std::optional dimDivisibility = + dynamicDimAnalysis.getDivisibilityInfo(v, index); + if (!dimDivisibility) + continue; + divisibilityInfo[index] = std::move(dimDivisibility.value()); + } + + return divisibilityInfo; +} + +/// For a `v` if the dimension is known to be multiple of a compile-time static +/// value, insert +/// +/// ```mlir +/// %v_expand = tensor.expand_shape %v +/// %barrier = util.optimization.barrier %v +/// %v_collapse = tensor.collapse_shape %barrier +/// ``` +/// +/// where the generated `tensor.expand_shape` and `tensor.collapse_shape` are +/// inverses of each other. The `util.optimization.barrier` avoid these from +/// getting folded away during reshape propagation. Return the result of the +/// `tensor.collapse_shape generated. +static std::optional +blockDynamicDimensionsOfValue(RewriterBase &rewriter, + const TensorDivisibilityInfo &divisibilityInfo, + Value v) { + auto tensorType = dyn_cast(v.getType()); + if (!tensorType) { + return std::nullopt; + } + + // Check if we know that the operands have a divisibility information. + SmallVector outputShape; + SmallVector reassociation; + Location loc = v.getLoc(); + + for (auto [index, dim] : llvm::enumerate(tensorType.getShape())) { + reassociation.emplace_back(ReassociationIndices{}); + + // Check if this needs division. + if (!tensorType.isDynamicDim(index) || !divisibilityInfo.contains(index)) { + reassociation.back().push_back(outputShape.size()); + outputShape.push_back(rewriter.getIndexAttr(dim)); + continue; + } + + // Split the dynamic based on the divisibility info. + IREE::Util::ConstantIntDivisibility currDivisibility = + divisibilityInfo.lookup(index); + uint64_t factor = currDivisibility.sdiv(); + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + AffineExpr divExpr = s0.floorDiv(factor); + Value sourceDim = rewriter.create(loc, v, index).getResult(); + OpFoldResult newDynamicDim = affine::makeComposedFoldedAffineApply( + rewriter, loc, divExpr, ArrayRef{sourceDim}); + OpFoldResult newStaticDim = rewriter.getIndexAttr(factor); + + reassociation.back().push_back(outputShape.size()); + reassociation.back().push_back(outputShape.size() + 1); + + outputShape.push_back(newDynamicDim); + outputShape.push_back(newStaticDim); + } + + auto staticOutputShape = + llvm::map_to_vector(outputShape, [](OpFoldResult ofr) { + if (auto staticShapeAttr = dyn_cast(ofr)) { + return cast(staticShapeAttr).getInt(); + } + return ShapedType::kDynamic; + }); + auto outputType = RankedTensorType::get( + staticOutputShape, tensorType.getElementType(), tensorType.getEncoding()); + + Value expandShape = rewriter.create( + loc, outputType, v, reassociation, outputShape); + Value barrier = + rewriter.create(loc, expandShape) + .getResult(0); + Value collapseShape = rewriter.create( + loc, tensorType, barrier, reassociation); + return collapseShape; +} + +/// For an operation, replace the operands at indices specified in +/// `limitToOperandIndices` with the result of +/// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the +/// information about dynamic dimensions that are known to be a multiple of a +/// compile-time static value. 
For example, +/// +/// ```mlir +/// %1 = (..., %0, ...) : ... , tensor<4x?x6xf32> +/// ``` +/// +/// If the dynamic dimension is known to be a multiple of 16, then generate +/// +/// ```mlir +/// %expanded = tensor.expand_shape %0 : +/// tensor<4x?x5xf32> into tensor<4x?x16x6xf32> +/// %barrier = util.optimization.barrier %expanded +/// %collapsed = tensor.collapse_shape %barrier +/// : tensor<4x?x16x5xf32> into tensor<4x?x5xf32> +/// %1 = (..., %collaped, ...) : ... , tensor<4x?x6xf32> +/// ``` +static LogicalResult blockDynamicDimensions( + RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis, + Operation *operation, llvm::SmallDenseSet limitToOperandIndices) { + OpBuilder::InsertionGuard g(rewriter); + + for (OpOperand &operand : operation->getOpOperands()) { + if (!limitToOperandIndices.contains(operand.getOperandNumber())) + continue; + if (operand.get().getDefiningOp()) + continue; + TensorDivisibilityInfo operandDivisibilityInfo = + getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get()); + if (operandDivisibilityInfo.empty()) + continue; + std::optional newOperand = blockDynamicDimensionsOfValue( + rewriter, operandDivisibilityInfo, operand.get()); + if (newOperand) { + rewriter.modifyOpInPlace(operation, + [&]() { operand.set(newOperand.value()); }); + } + } + return success(); +} + +/// Insert `tensor.expand_shape` operations to materialize in IR information +/// about dynamic dimensions that are known to be a multiple of a compile-time +/// know value, for the operands of `iree_linalg_ext.attention` operation. +static LogicalResult +blockDynamicDimensions(RewriterBase &rewriter, + const TensorDynamicDimAnalysis &dynamicDimAnalysis, + IREE::LinalgExt::AttentionOp attentionOp) { + // Only block the q and k values. + llvm::SmallDenseSet prunedOperandsList; + prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber()); + prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber()); + return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp, + prunedOperandsList); +} + +void BlockDynamicDimensionsPass::runOnOperation() { + Operation *operation = getOperation(); + MLIRContext *context = &getContext(); + TensorDynamicDimAnalysis dynamicDimAnalysis(operation); + if (failed(dynamicDimAnalysis.run())) { + return signalPassFailure(); + } + + IRRewriter rewriter(context); + auto walkResult = operation->walk( + [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult { + rewriter.setInsertionPoint(attentionOp); + return blockDynamicDimensions(rewriter, dynamicDimAnalysis, + attentionOp); + }); + if (walkResult.wasInterrupted()) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "After blocking dimensions:\n"; + operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + { + RewritePatternSet bubbleExpandShapePatterns(context); + // Add patterns to "push down" the `tensor.collapse_shape` patterns (which + // are the dual of the patterns to "bubble up" `tensor.expand_shape` + // patterns) + linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; }; + linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns, + controlFn); + IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns( + bubbleExpandShapePatterns, controlFn); + // Add patterns to fold the "bubbled-up" `tensor.expand_shape` operation and + // "pushed-down" `tensor.collapse_shape` operation with their interface + // bindings or `tensor.empty` operations. 
+ populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns); + tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); + // Add some additional patterns that can simplify the IR and remove dead + // operations. + memref::populateResolveRankedShapedTypeResultDimsPatterns( + bubbleExpandShapePatterns); + populateRemoveDeadMemAllocPatterns(bubbleExpandShapePatterns); + if (failed(applyPatternsAndFoldGreedily( + operation, std::move(bubbleExpandShapePatterns)))) { + operation->emitOpError( + "failed in application of bubble up expand shape patterns"); + return signalPassFailure(); + } + } + + LLVM_DEBUG({ + llvm::dbgs() << "After reshape propagation:\n"; + operation->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + // Delete the optimization barrier and run some further cleanup. + { + RewritePatternSet removeBarrierOpsPatterns(context); + removeBarrierOpsPatterns.insert(context); + tensor::ExpandShapeOp::getCanonicalizationPatterns(removeBarrierOpsPatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + removeBarrierOpsPatterns, context); + if (failed(applyPatternsAndFoldGreedily( + operation, std::move(removeBarrierOpsPatterns)))) { + operation->emitOpError("failed in cleanup patterns"); + return signalPassFailure(); + } + } + + return; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 764bc258c902..ee7c406d51c8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -72,11 +72,13 @@ iree_cc_library( "ExtractAddressComputation.h" "PassUtils.h" "Passes.h" + "TensorDynamicDimAnalysis.h" "TileSizeSelection.h" "Transforms.h" "UserConfig.h" SRCS "AddFastMathFlags.cpp" + "BlockDynamicDimensions.cpp" "BubbleUpOrdinalOps.cpp" "BufferizationAnalysis.cpp" "BufferizeCopyOnlyDispatchesPass.cpp" @@ -128,6 +130,7 @@ iree_cc_library( "RemoveSingleIterationLoop.cpp" "ReplaceSlowMinMaxOps.cpp" "SplitFullPartialTransferPass.cpp" + "TensorDynamicDimAnalysis.cpp" "TensorToVectorVectorizePad.cpp" "TestExecutablePreprocessing.cpp" "TestPartitionableLoopsInterface.cpp" @@ -154,6 +157,7 @@ iree_cc_library( MLIRArithUtils MLIRBufferizationDialect MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface MLIRFuncDialect MLIRFuncTransforms MLIRFunctionInterfaces @@ -203,6 +207,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms + iree::compiler::Dialect::Util::Analysis iree::compiler::Dialect::Util::IR iree::compiler::Utils PUBLIC diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 6a5a9b5578c0..5aa3ef414bcb 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -19,6 +19,12 @@ def AddFastMathFlagsPass "given a floating-point mode."; } +def BlockDynamicDimensionsPass + : Pass<"iree-codegen-block-dynamic-dimensions"> { + let summary = "Expand dynamic dimensions that are known to be multiples of " + "statically known values."; +} + def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> { let summary = "Bubbles op ordinal ops to allow for workgroup count computation"; let description = [{ diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp 
b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp new file mode 100644 index 000000000000..b0e76678732e --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp @@ -0,0 +1,236 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +#define DEBUG_TYPE "iree-codegen-dynamic-dim-analysis" + +namespace mlir::iree_compiler { + +//===---------------------------------------------------------------------===// +// Helper function to update tensor dynamic dimension info +//===---------------------------------------------------------------------===// + +static void +updateRangeInfo(TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo, + Value v, unsigned dim, const ConstantIntRanges &range) { + assert(!rangeInfo.contains({v, dim}) && + "overwriting existing dim range info"); + rangeInfo.insert({{v, dim}, + ConstantIntRanges(range.umin(), range.umax(), range.smin(), + range.smax())}); +} + +static void updateDivisibilityInfo( + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + Value v, unsigned dim, + const IREE::Util::ConstantIntDivisibility &divisibility) { + assert(!divisibilityInfo.contains({v, dim}) && + "overwriting existing dim divisibility info"); + divisibilityInfo[{v, dim}] = divisibility; +} + +// Update the dynamic dim analysis to record the range/divisibility information +// for `tensorValue` at dimension `dimIndex` based on the range/divisibility +// information of an integer/index value `dynamicDim`. +static void updateTensorDimInfo( + Value tensorValue, unsigned dimIndex, Value dynamicDim, + const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // Update range info. + auto *rangeState = + solver.lookupState(dynamicDim); + if (rangeState && !rangeState->getValue().isUninitialized()) { + updateRangeInfo(rangeInfo, tensorValue, dimIndex, + rangeState->getValue().getValue()); + } + + // Update solver info + auto *divisibilityState = + solver.lookupState(dynamicDim); + if (divisibilityState && !divisibilityState->getValue().isUninitialized()) { + updateDivisibilityInfo(divisibilityInfo, tensorValue, dimIndex, + divisibilityState->getValue().getValue()); + } +} + +//===---------------------------------------------------------------------===// +// Transfer functions for updating dynamic dimension of results of operation. +//===---------------------------------------------------------------------===// + +// Helper function to just transfer the range and divisibility information +// `source` value to `dest` value. 
+static void transferTensorDimInfo( + Value source, Value dest, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // expected that `source` and `dest` are of `RankedTensorType` and of the same + // type. + assert(source.getType() == dest.getType()); + auto sourceType = cast(source.getType()); + for (auto index : llvm::seq(0, sourceType.getRank())) { + // Transfer range info + auto rangeIt = rangeInfo.find({source, index}); + if (rangeIt != rangeInfo.end()) { + updateRangeInfo(rangeInfo, dest, index, rangeIt->second); + } + + auto divisibilityIt = divisibilityInfo.find({source, index}); + if (divisibilityIt != divisibilityInfo.end()) { + updateDivisibilityInfo(divisibilityInfo, dest, index, + divisibilityIt->second); + } + } +} + +// Update the tensor dimension information for result of a +// `flow.dispatch.tensor.load` operation. +static void updateTensorDimInfo( + IREE::Flow::DispatchTensorLoadOp flowLoadOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + // If there are no dynamic dimensions, nothing to do. + if (flowLoadOp.getType().hasStaticShape()) { + return; + } + // Check that all strides are 1. Abort otherwise + if (llvm::any_of(flowLoadOp.getMixedStrides(), + [](OpFoldResult s) { return !isConstantIntValue(s, 1); })) { + return; + } + + Value result = flowLoadOp.getResult(); + for (auto [index, size] : llvm::enumerate(flowLoadOp.getMixedSizes())) { + auto dynamicDim = dyn_cast(size); + if (!dynamicDim) { + continue; + } + updateTensorDimInfo(result, index, dynamicDim, solver, divisibilityInfo, + rangeInfo); + } +} + +// Update the tensor dimension information for result of a `tensor.empty` +// operation. +static void updateTensorDimInfo( + tensor::EmptyOp emptyOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + auto dimOperands = emptyOp.getOperands(); + if (dimOperands.empty()) { + return; + } + + Value result = emptyOp.getResult(); + auto resultType = cast(result.getType()); + int dimOperandIndex = 0; + for (auto [index, shape] : llvm::enumerate(resultType.getShape())) { + if (!ShapedType::isDynamic(shape)) + continue; + updateTensorDimInfo(result, index, dimOperands[dimOperandIndex++], solver, + divisibilityInfo, rangeInfo); + } +} + +// Update the tensor dimension information for results of an operation that +// implements the `DestinationStyleOpInterface`. +static void updateTensorDimInfo( + DestinationStyleOpInterface dstStyleOp, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + for (auto [index, result] : llvm::enumerate(dstStyleOp->getResults())) { + auto resultTensorType = dyn_cast(result.getType()); + if (!resultTensorType || resultTensorType.hasStaticShape()) { + continue; + } + Value source = dstStyleOp.getDpsInitOperand(index)->get(); + transferTensorDimInfo(source, result, solver, divisibilityInfo, rangeInfo); + } +} + +// Dispatch to the method that updates the dimension information for an +// operation. 
+static void updateTensorDimInfo( + Operation *op, const DataFlowSolver &solver, + TensorDynamicDimAnalysis::TensorDimDivisibilityInfo &divisibilityInfo, + TensorDynamicDimAnalysis::TensorDimRangeInfo &rangeInfo) { + LLVM_DEBUG({ + llvm::dbgs() << "Start updating op\n"; + op->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + TypeSwitch(op) + .Case([&](auto op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }) + .Case([&](auto op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }); + + LLVM_DEBUG({ + for (auto [resultIndex, result] : llvm::enumerate(op->getResults())) { + auto tensorType = dyn_cast(result.getType()); + if (!tensorType) + continue; + for (auto index : llvm::seq(0, tensorType.getRank())) { + std::optional range; + std::optional divisibility; + auto rangeIt = rangeInfo.find({result, index}); + if (rangeIt != rangeInfo.end()) { + range = rangeIt->second; + } + auto divisibilityIt = divisibilityInfo.find({result, index}); + if (divisibilityIt != divisibilityInfo.end()) { + divisibility = divisibilityIt->second; + } + if (!range && !divisibility) { + continue; + } + llvm::dbgs() << "\tDim Info: Result number : " << resultIndex + << ", dim " << index; + if (range) { + llvm::dbgs() << " : Range " << range.value(); + } + if (divisibility) { + llvm::dbgs() << " : Divisibility " << divisibility.value(); + } + llvm::dbgs() << "\n"; + } + } + }); +} + +TensorDynamicDimAnalysis::TensorDynamicDimAnalysis(Operation *rootOp) + : rootOperation(rootOp) { + solver.load(); + solver.load(); + solver.load(); +} + +LogicalResult TensorDynamicDimAnalysis::run() { + if (failed(solver.initializeAndRun(rootOperation))) { + return failure(); + } + + // Walk the IR pre-order, forward and update the dynamic information for each + // tensor. + rootOperation->walk([&](Operation *op) { + updateTensorDimInfo(op, solver, divisibilityInfo, rangeInfo); + }); + + return success(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h new file mode 100644 index 000000000000..13bdb5cac8d7 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.h @@ -0,0 +1,65 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" + +namespace mlir::iree_compiler { + +/// Analysis to compute information about dynamic dimensions of tensors. +/// +/// Using the IntegerRangeAnalysis and the IntegerDivisibilityAnalysis +/// this analysis builds information about the range and divisibility of dynamic +/// dimensions of tensor operands in the program. The analysis can then be +/// queried to get the range and divisibility info for any tensor value for any +/// dynamic dimension. +/// TODO: This is not a dataflow analysis or does not update information on IR +/// changes. This could be potentially expensive and is really meant to be used +/// before any transformations to the dispatch. If this needs to be more +/// efficient then this needs to be converted to a data flow solver. 
+class TensorDynamicDimAnalysis { +public: + explicit TensorDynamicDimAnalysis(Operation *rootOperation); + + LogicalResult run(); + + using TensorDimDivisibilityInfo = + DenseMap, + IREE::Util::ConstantIntDivisibility>; + using TensorDimRangeInfo = + DenseMap, ConstantIntRanges>; + + std::optional getRangeInfo(Value v, + unsigned dimIndex) const { + auto it = rangeInfo.find({v, dimIndex}); + if (it == rangeInfo.end()) { + return std::nullopt; + } + return it->second; + } + + std::optional + getDivisibilityInfo(Value v, unsigned dimIndex) const { + auto it = divisibilityInfo.find({v, dimIndex}); + if (it == divisibilityInfo.end()) { + return std::nullopt; + } + return it->second; + } + +private: + DataFlowSolver solver; + + // Operation scope within which the analysis is run. + Operation *rootOperation; + + // Map of tensor value to integer divisibility information for each dimension. + TensorDimDivisibilityInfo divisibilityInfo; + TensorDimRangeInfo rangeInfo; +}; + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index 7879b5809950..ab1a76ab2fc8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( "add_fmfs.mlir", "affinemin_canonicalization.mlir", "batch_matmuls.mlir", + "block_dynamic_dims.mlir", "bubble_up_ordinal_ops.mlir", "bufferize_copy_only_dispatches.mlir", "canonicalize_interface_load_store.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 832319e9d9df..3ac6423c08c0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -17,6 +17,7 @@ iree_lit_test_suite( "add_fmfs.mlir" "affinemin_canonicalization.mlir" "batch_matmuls.mlir" + "block_dynamic_dims.mlir" "bubble_up_ordinal_ops.mlir" "bufferize_copy_only_dispatches.mlir" "canonicalize_interface_load_store.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir new file mode 100644 index 000000000000..819c4128a546 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir @@ -0,0 +1,101 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding, + #hal.pipeline.binding], flags = Indirect> +func.func @block_attention_dims() { + %c0 = arith.constant 0 : index + %cst = arith.constant 8.837890e-02 : f16 + %m_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %k2_in = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index + %0:2 = util.assume.int + %m_in, + %k2_in + : index, index + %m = flow.dispatch.workload.ordinal %0#0, 0 : index + %k2 = flow.dispatch.workload.ordinal %0#1, 1 : index + %q_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%m} + %key_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : 
!flow.dispatch.tensor>{%k2} + %value_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%k2} + %mask_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") + : !flow.dispatch.tensor>{%m, %k2} + %output_in = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) flags(Indirect) + : !flow.dispatch.tensor>{%m} + %q = flow.dispatch.tensor.load %q_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%m} -> tensor<4x?x32x128xf16> + %key = flow.dispatch.tensor.load %key_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%k2} -> tensor<4x?x32x128xf16> + %value = flow.dispatch.tensor.load %value_in, offsets = [0, 0, 0, 0], sizes = [4, %k2, 32, 128], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%k2} -> tensor<4x?x32x128xf16> + %mask = flow.dispatch.tensor.load %mask_in, offsets = [0, 0, 0, 0], sizes = [4, 32, %m, %k2], strides = [1, 1, 1, 1] + : !flow.dispatch.tensor>{%m, %k2} -> tensor<4x32x?x?xf16> + %1 = tensor.empty(%m) : tensor<4x?x32x128xf16> + %2 = tensor.empty(%m) : tensor<4x32x?x128xf16> + %attn = iree_linalg_ext.attention { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5, d1, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>]} + ins(%q, %key, %value, %cst, %mask : tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, tensor<4x?x32x128xf16>, f16, tensor<4x32x?x?xf16>) + outs(%2 : tensor<4x32x?x128xf16>) { + ^bb0(%b0 : f16) : + iree_linalg_ext.yield %b0 : f16 + }-> tensor<4x32x?x128xf16> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%attn : tensor<4x32x?x128xf16>) outs(%1 : tensor<4x?x32x128xf16>) { + ^bb0(%in: f16, %out: f16): + linalg.yield %in : f16 + } -> tensor<4x?x32x128xf16> + flow.dispatch.tensor.store %result, %output_in, offsets = [0, 0, 0, 0], sizes = [4, %m, 32, 128], strides = [1, 1, 1, 1] + : tensor<4x?x32x128xf16> -> !flow.dispatch.tensor>{%m} + return +} +// CHECK-LABEL: func @block_attention_dims() +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index +// CHECK-DAG: %[[M:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 0 : index +// CHECK-DAG: %[[K2:.+]] = flow.dispatch.workload.ordinal %{{.+}}, 1 : index +// CHECK-DAG: %[[M_DYNAMIC:.+]] = arith.divui %[[M]], %[[C16]] +// CHECK: %[[Q_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(0) +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[K2_DYNAMIC:.+]] = arith.divui %[[K2]], %[[C32]] +// CHECK: %[[K_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(1) +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[V_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(2) +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[MASK_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(3) +// CHECK-SAME: 
!flow.dispatch.tensor>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]} +// CHECK: %[[OUTPUT_BINDING:.+]] = hal.interface.binding.subspan +// CHECK-SAME: binding(4) +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[Q:.+]] = flow.dispatch.tensor.load %[[Q_BINDING]] +// CHECK-SAME: sizes = [4, %[[M_DYNAMIC]], 16, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]]} +// CHECK: %[[K:.+]] = flow.dispatch.tensor.load %[[K_BINDING]] +// CHECK-SAME: sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[V:.+]] = flow.dispatch.tensor.load %[[V_BINDING]] +// CHECK-SAME: sizes = [4, %[[K2_DYNAMIC]], 32, 32, 128] +// CHECK-SAME: !flow.dispatch.tensor>{%[[K2_DYNAMIC]]} +// CHECK: %[[MASK:.+]] = flow.dispatch.tensor.load %[[MASK_BINDING]] +// CHECK-SAME: sizes = [4, 32, %[[M_DYNAMIC]], 16, %[[K2_DYNAMIC]], 32] +// CHECK-SAME: !flow.dispatch.tensor>{%[[M_DYNAMIC]], %[[K2_DYNAMIC]]} +// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention +// CHECK: ins(%[[Q]], %[[K]], %[[V]], %{{.+}}, %[[MASK]] : +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index aab73c952c5f..86f65e1b0cc9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1179,6 +1179,9 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); addEncodingToNopPasses(funcPassManager); + funcPassManager.addPass(createBlockDynamicDimensionsPass); + funcPassManager.addPass(createCanonicalizerPass); + funcPassManager.addPass(createCSEPass); } modulePassManager.addPass(createMaterializeUserConfigsPass()); modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass()); diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index 5e5b2efed3cc..f19a665d9258 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -1347,6 +1347,27 @@ LogicalResult verifyDispatchWorkgroupInfoOp(Operation *op, uint64_t dimension) { return success(); } +//===----------------------------------------------------------------------===// +// flow.dispatch.workload.ordinal +//===----------------------------------------------------------------------===// + +void DispatchWorkloadOrdinalOp::inferResultDivisibility( + ArrayRef argDivs, + IREE::Util::SetIntDivisibilityFn setResultDivisibility) { + if (argDivs[0].isUninitialized()) { + setResultDivisibility(getResult(), + IREE::Util::ConstantIntDivisibility(1, 1)); + return; + } + setResultDivisibility(getResult(), argDivs[0].getValue()); +} + +void DispatchWorkloadOrdinalOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRange) { + assert(!argRanges.empty() && "expected range of input to be set"); + setResultRange(getResult(), argRanges[0]); +} + //===----------------------------------------------------------------------===// // flow.executable //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td index 301ce8b15b28..69d8cc419382 100644 --- 
a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td +++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td @@ -16,6 +16,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1741,7 +1742,11 @@ def FLOW_DispatchWorkgroupCountFromSliceOp : } def FLOW_DispatchWorkloadOrdinalOp : - FLOW_PureOp<"dispatch.workload.ordinal"> { + FLOW_PureOp<"dispatch.workload.ordinal", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { let arguments = (ins Index:$operand, IndexAttr:$ordinal From 7f14078bf688bc720095271420cb136df4ca639a Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 28 Oct 2024 10:30:23 -0700 Subject: [PATCH 19/45] Adding build config for HSA runtime headers. (#18909) --- build_tools/bazel/workspace.bzl | 7 +++++ .../bazel_to_cmake/bazel_to_cmake_targets.py | 1 + .../hsa-runtime-headers/BUILD.overlay | 16 +++++++++++ .../hsa-runtime-headers/CMakeLists.txt | 28 +++++++++++++++++++ third_party/hsa-runtime-headers | 2 +- 5 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 build_tools/third_party/hsa-runtime-headers/BUILD.overlay create mode 100644 build_tools/third_party/hsa-runtime-headers/CMakeLists.txt diff --git a/build_tools/bazel/workspace.bzl b/build_tools/bazel/workspace.bzl index 654649508d3c..bb39437ea561 100644 --- a/build_tools/bazel/workspace.bzl +++ b/build_tools/bazel/workspace.bzl @@ -147,6 +147,13 @@ def configure_iree_submodule_deps(iree_repo_alias = "@", iree_path = "./"): path = paths.join(iree_path, "third_party/nccl"), ) + maybe( + native.new_local_repository, + name = "hsa_runtime_headers", + build_file = iree_repo_alias + "//:build_tools/third_party/hsa-runtime-headers/BUILD.overlay", + path = paths.join(iree_path, "third_party/hsa-runtime-headers"), + ) + maybe( native.new_local_repository, name = "webgpu_headers", diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index cecc21777f5f..0c0469eb335c 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -113,6 +113,7 @@ def __init__(self, repo_map: Dict[str, str]): "@com_google_googletest//:gtest": ["gmock", "gtest"], "@spirv_cross//:spirv_cross_lib": ["spirv-cross-msl"], "@cpuinfo": ["${IREE_CPUINFO_TARGET}"], + "@hsa_runtime_headers": ["hsa_runtime::headers"], "@webgpu_headers": [], } ) diff --git a/build_tools/third_party/hsa-runtime-headers/BUILD.overlay b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay new file mode 100644 index 000000000000..b3e0b85dd47b --- /dev/null +++ b/build_tools/third_party/hsa-runtime-headers/BUILD.overlay @@ -0,0 +1,16 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "hsa_runtime_headers", + hdrs = glob([ + "include/hsa/*.h", + ]), + include_prefix = "third_party/hsa-runtime-headers/", + includes = ["include"], +) diff --git a/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt b/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt new file mode 100644 index 000000000000..e939e9895122 --- /dev/null +++ b/build_tools/third_party/hsa-runtime-headers/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(HSA_RUNTIME_HEADERS_ROOT "${IREE_ROOT_DIR}/third_party/hsa-runtime-headers/") + +external_cc_library( + PACKAGE + hsa_runtime + NAME + headers + ROOT + ${HSA_RUNTIME_HEADERS_ROOT} + SYSTEM_INCLUDES + ${HSA_RUNTIME_HEADERS_ROOT}/include/ + PUBLIC +) + +iree_install_targets( + TARGETS + hsa_runtime_headers + COMPONENT + IREEBundledLibraries + EXPORT_SET + Runtime +) diff --git a/third_party/hsa-runtime-headers b/third_party/hsa-runtime-headers index c4fb247e2861..ffa0dc3307be 160000 --- a/third_party/hsa-runtime-headers +++ b/third_party/hsa-runtime-headers @@ -1 +1 @@ -Subproject commit c4fb247e28616c51d37a45f2c0056ed5f4df0555 +Subproject commit ffa0dc3307be5472ccdf7c9825c3dc68340649de From bb7ece71ffe955d34494703d6f9d47d92b334efe Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 28 Oct 2024 10:31:01 -0700 Subject: [PATCH 20/45] Minor comment/style tweaks to the null HAL driver. (#18911) --- .github/CODEOWNERS | 1 - runtime/src/iree/hal/drivers/null/README.md | 2 ++ runtime/src/iree/hal/drivers/null/allocator.c | 5 +++++ runtime/src/iree/hal/drivers/null/allocator.h | 4 ++++ runtime/src/iree/hal/drivers/null/buffer.c | 6 +++++- runtime/src/iree/hal/drivers/null/buffer.h | 4 ++++ runtime/src/iree/hal/drivers/null/channel.c | 6 +++++- runtime/src/iree/hal/drivers/null/channel.h | 4 ++++ runtime/src/iree/hal/drivers/null/command_buffer.c | 6 +++++- runtime/src/iree/hal/drivers/null/command_buffer.h | 4 ++++ runtime/src/iree/hal/drivers/null/device.c | 6 +++++- runtime/src/iree/hal/drivers/null/device.h | 4 ++++ runtime/src/iree/hal/drivers/null/driver.c | 6 +++++- runtime/src/iree/hal/drivers/null/driver.h | 4 ++++ runtime/src/iree/hal/drivers/null/event.c | 6 +++++- runtime/src/iree/hal/drivers/null/event.h | 4 ++++ runtime/src/iree/hal/drivers/null/executable.c | 6 +++++- runtime/src/iree/hal/drivers/null/executable.h | 4 ++++ runtime/src/iree/hal/drivers/null/executable_cache.c | 6 +++++- runtime/src/iree/hal/drivers/null/executable_cache.h | 4 ++++ runtime/src/iree/hal/drivers/null/semaphore.c | 2 +- 21 files changed, 84 insertions(+), 10 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11cf13cc35c2..6edcff51e7ad 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -32,7 +32,6 @@ # Experimental # It's experimental, but we still don't want any old directory added here. 
/experimental/ @benvanik @stellaraccident -/experimental/rocm/ @benvanik /experimental/web/ @ScottTodd /experimental/webgpu/ @benvanik @ScottTodd diff --git a/runtime/src/iree/hal/drivers/null/README.md b/runtime/src/iree/hal/drivers/null/README.md index 3c3e4200334b..3dcb62a512a9 100644 --- a/runtime/src/iree/hal/drivers/null/README.md +++ b/runtime/src/iree/hal/drivers/null/README.md @@ -17,8 +17,10 @@ fill (memset) you can often implement copy (memcpy) as well at the same time. `experimental/` folder if going in-tree. 1. Find/replace `{Null}` with the friendly name of your driver (e.g. `Vulkan`). 1. Find/replace `_null_` with the C name of your driver (e.g. `vulkan`). +1. Find/replace `_NULL_` with the upper C name of your driver (e.g. `VULKAN`). 1. Find/replace `// TODO(null):` with your github ID, your driver name, or a GitHub issue number tracking driver creation (e.g. `// TODO(#1234):`). +1. Find/replace `iree/hal/drivers/null/` with your source path. ## Build Setup diff --git a/runtime/src/iree/hal/drivers/null/allocator.c b/runtime/src/iree/hal/drivers/null/allocator.c index f84f00257ee5..e1c91ce7c35d 100644 --- a/runtime/src/iree/hal/drivers/null/allocator.c +++ b/runtime/src/iree/hal/drivers/null/allocator.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/buffer.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_allocator_t +//===----------------------------------------------------------------------===// + // TODO(null): use one ID per address space or pool - each shows as a different // track in tracing tools. #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING @@ -33,6 +37,7 @@ iree_status_t iree_hal_null_allocator_create( iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); + *out_allocator = NULL; iree_hal_null_allocator_t* allocator = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/allocator.h b/runtime/src/iree/hal/drivers/null/allocator.h index c0286bac6041..299c9c96c44f 100644 --- a/runtime/src/iree/hal/drivers/null/allocator.h +++ b/runtime/src/iree/hal/drivers/null/allocator.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_allocator_t +//===----------------------------------------------------------------------===// + // Creates a {Null} buffer allocator used for persistent allocations. 
iree_status_t iree_hal_null_allocator_create( iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); diff --git a/runtime/src/iree/hal/drivers/null/buffer.c b/runtime/src/iree/hal/drivers/null/buffer.c index f6eeecb11f20..6e676526b1b5 100644 --- a/runtime/src/iree/hal/drivers/null/buffer.c +++ b/runtime/src/iree/hal/drivers/null/buffer.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/buffer.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_buffer_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_buffer_t { iree_hal_buffer_t base; iree_hal_buffer_release_callback_t release_callback; @@ -33,8 +37,8 @@ iree_status_t iree_hal_null_buffer_wrap( iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_buffer = NULL; iree_hal_null_buffer_t* buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/buffer.h b/runtime/src/iree/hal/drivers/null/buffer.h index 7e492f4d49d7..edf2e457e5e8 100644 --- a/runtime/src/iree/hal/drivers/null/buffer.h +++ b/runtime/src/iree/hal/drivers/null/buffer.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_buffer_t +//===----------------------------------------------------------------------===// + // Wraps a {Null} allocation in an iree_hal_buffer_t. iree_status_t iree_hal_null_buffer_wrap( iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, diff --git a/runtime/src/iree/hal/drivers/null/channel.c b/runtime/src/iree/hal/drivers/null/channel.c index 195c3d5786cd..0d2915b066fa 100644 --- a/runtime/src/iree/hal/drivers/null/channel.c +++ b/runtime/src/iree/hal/drivers/null/channel.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/channel.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_channel_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_channel_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -34,8 +38,8 @@ iree_status_t iree_hal_null_channel_create(iree_hal_channel_params_t params, iree_allocator_t host_allocator, iree_hal_channel_t** out_channel) { IREE_ASSERT_ARGUMENT(out_channel); - *out_channel = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_channel = NULL; iree_hal_null_channel_t* channel = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/channel.h b/runtime/src/iree/hal/drivers/null/channel.h index 83c4ef1aef88..efa7c10c5e62 100644 --- a/runtime/src/iree/hal/drivers/null/channel.h +++ b/runtime/src/iree/hal/drivers/null/channel.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_channel_t +//===----------------------------------------------------------------------===// + // Creates a {Null} HAL collective channel using the given |params|. 
iree_status_t iree_hal_null_channel_create(iree_hal_channel_params_t params, iree_allocator_t host_allocator, diff --git a/runtime/src/iree/hal/drivers/null/command_buffer.c b/runtime/src/iree/hal/drivers/null/command_buffer.c index 9d474d44cb96..4f8fe822ccf1 100644 --- a/runtime/src/iree/hal/drivers/null/command_buffer.c +++ b/runtime/src/iree/hal/drivers/null/command_buffer.c @@ -10,6 +10,10 @@ #include "iree/hal/drivers/null/channel.h" #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_command_buffer_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_command_buffer_t { iree_hal_command_buffer_t base; iree_allocator_t host_allocator; @@ -31,8 +35,8 @@ iree_status_t iree_hal_null_command_buffer_create( iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(out_command_buffer); - *out_command_buffer = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_command_buffer = NULL; iree_hal_null_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/command_buffer.h b/runtime/src/iree/hal/drivers/null/command_buffer.h index cca92367dd82..d8ab61d89175 100644 --- a/runtime/src/iree/hal/drivers/null/command_buffer.h +++ b/runtime/src/iree/hal/drivers/null/command_buffer.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_command_buffer_t +//===----------------------------------------------------------------------===// + // Creates {Null} command buffer. iree_status_t iree_hal_null_command_buffer_create( iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, diff --git a/runtime/src/iree/hal/drivers/null/device.c b/runtime/src/iree/hal/drivers/null/device.c index aaa7b1591fd5..ce122400d313 100644 --- a/runtime/src/iree/hal/drivers/null/device.c +++ b/runtime/src/iree/hal/drivers/null/device.c @@ -17,6 +17,10 @@ #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_device_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_device_t { iree_hal_resource_t resource; iree_string_view_t identifier; @@ -60,8 +64,8 @@ iree_status_t iree_hal_null_device_create( iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_device); - *out_device = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_device = NULL; // Verify the parameters prior to creating resources. IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/device.h b/runtime/src/iree/hal/drivers/null/device.h index aa70db6408d6..18978668bbf7 100644 --- a/runtime/src/iree/hal/drivers/null/device.h +++ b/runtime/src/iree/hal/drivers/null/device.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_device_t +//===----------------------------------------------------------------------===// + // NOTE: nothing in the skeleton implementation. Device creation and adoption is // part of the public API header. This header can contain internal types and // functions. 
diff --git a/runtime/src/iree/hal/drivers/null/driver.c b/runtime/src/iree/hal/drivers/null/driver.c index 94be18a45364..78cf511a6999 100644 --- a/runtime/src/iree/hal/drivers/null/driver.c +++ b/runtime/src/iree/hal/drivers/null/driver.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_driver_t +//===----------------------------------------------------------------------===// + // TODO(null): if it's possible to have more than one device use real IDs. // This is a placeholder for this skeleton that just indicates the first and // only device. @@ -57,8 +61,8 @@ IREE_API_EXPORT iree_status_t iree_hal_null_driver_create( iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(options); IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_driver = NULL; // TODO(null): verify options; this may be moved after any libraries are // loaded so the verification can use underlying implementation queries. diff --git a/runtime/src/iree/hal/drivers/null/driver.h b/runtime/src/iree/hal/drivers/null/driver.h index 84b12c1beac8..1938778056d8 100644 --- a/runtime/src/iree/hal/drivers/null/driver.h +++ b/runtime/src/iree/hal/drivers/null/driver.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_driver_t +//===----------------------------------------------------------------------===// + // NOTE: nothing in the skeleton implementation. Driver creation and adoption is // part of the public API header. This header can contain internal types and // functions. diff --git a/runtime/src/iree/hal/drivers/null/event.c b/runtime/src/iree/hal/drivers/null/event.c index 5f1e413ca204..fabbe45b1311 100644 --- a/runtime/src/iree/hal/drivers/null/event.c +++ b/runtime/src/iree/hal/drivers/null/event.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/event.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_event_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_event_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -23,8 +27,8 @@ iree_status_t iree_hal_null_event_create( iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags, iree_allocator_t host_allocator, iree_hal_event_t** out_event) { IREE_ASSERT_ARGUMENT(out_event); - *out_event = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_event = NULL; iree_hal_null_event_t* event = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/event.h b/runtime/src/iree/hal/drivers/null/event.h index 68c11f44d1e6..ca7f364e458e 100644 --- a/runtime/src/iree/hal/drivers/null/event.h +++ b/runtime/src/iree/hal/drivers/null/event.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_event_t +//===----------------------------------------------------------------------===// + // WIP API and may change. Mostly ignored for now. 
iree_status_t iree_hal_null_event_create( iree_hal_queue_affinity_t queue_affinity, iree_hal_event_flags_t flags, diff --git a/runtime/src/iree/hal/drivers/null/executable.c b/runtime/src/iree/hal/drivers/null/executable.c index a90d697d9d8d..3301d6cd767a 100644 --- a/runtime/src/iree/hal/drivers/null/executable.c +++ b/runtime/src/iree/hal/drivers/null/executable.c @@ -6,6 +6,10 @@ #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_executable_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -24,8 +28,8 @@ iree_status_t iree_hal_null_executable_create( iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); - *out_executable = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_executable = NULL; // Allocate storage for the executable and its associated data structures. iree_hal_null_executable_t* executable = NULL; diff --git a/runtime/src/iree/hal/drivers/null/executable.h b/runtime/src/iree/hal/drivers/null/executable.h index 0107e1a14d4a..0ae87aefb947 100644 --- a/runtime/src/iree/hal/drivers/null/executable.h +++ b/runtime/src/iree/hal/drivers/null/executable.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_t +//===----------------------------------------------------------------------===// + // Creates a {Null} executable from a binary in memory. Each executable may // contain multiple entry points and be composed of several modules presented to // the HAL as a single instance. 
See iree_hal_executable_params_t for more diff --git a/runtime/src/iree/hal/drivers/null/executable_cache.c b/runtime/src/iree/hal/drivers/null/executable_cache.c index d4f0ad6ad066..a7c6f4b7cec1 100644 --- a/runtime/src/iree/hal/drivers/null/executable_cache.c +++ b/runtime/src/iree/hal/drivers/null/executable_cache.c @@ -8,6 +8,10 @@ #include "iree/hal/drivers/null/executable.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_cache_t +//===----------------------------------------------------------------------===// + typedef struct iree_hal_null_executable_cache_t { iree_hal_resource_t resource; iree_allocator_t host_allocator; @@ -26,8 +30,8 @@ iree_status_t iree_hal_null_executable_cache_create( iree_string_view_t identifier, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(out_executable_cache); - *out_executable_cache = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_executable_cache = NULL; iree_hal_null_executable_cache_t* executable_cache = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( diff --git a/runtime/src/iree/hal/drivers/null/executable_cache.h b/runtime/src/iree/hal/drivers/null/executable_cache.h index 519b8c05e18a..b4af9e76cc28 100644 --- a/runtime/src/iree/hal/drivers/null/executable_cache.h +++ b/runtime/src/iree/hal/drivers/null/executable_cache.h @@ -10,6 +10,10 @@ #include "iree/base/api.h" #include "iree/hal/api.h" +//===----------------------------------------------------------------------===// +// iree_hal_null_executable_cache_t +//===----------------------------------------------------------------------===// + // Creates a no-op executable cache that does not cache at all. // This is useful to isolate pipeline caching behavior and verify compilation // behavior. diff --git a/runtime/src/iree/hal/drivers/null/semaphore.c b/runtime/src/iree/hal/drivers/null/semaphore.c index 25ec7dc99fbb..b397c85fe2c1 100644 --- a/runtime/src/iree/hal/drivers/null/semaphore.c +++ b/runtime/src/iree/hal/drivers/null/semaphore.c @@ -29,8 +29,8 @@ iree_status_t iree_hal_null_semaphore_create( uint64_t initial_value, iree_hal_semaphore_flags_t flags, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(out_semaphore); - *out_semaphore = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_semaphore = NULL; iree_hal_null_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( From 4823dc0e2e6a9de8289421eb87b58ebce8620198 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 28 Oct 2024 10:54:26 -0700 Subject: [PATCH 21/45] Adding HAL semaphore support for statuses-as-failure-payloads. (#18912) This allows an implementation to have a single atomic value for a semaphore that encodes the user payload or an error payload that optionally references an iree_status_t object. Implementations not using the status feature can ignore it but must perform a greater-than-or-equal check on `IREE_HAL_SEMAPHORE_FAILURE_VALUE` instead of equality. 
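
Illustration only (not part of this change): a minimal driver-side sketch of the intended usage, assuming a hypothetical `my_semaphore_t` with a mutex-guarded 64-bit payload similar to the CUDA/HIP semaphores updated below. The `my_*` names are invented for the example; the helpers it calls are the ones this patch adds to `iree/hal/semaphore.h`.

```c
#include "iree/base/api.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"

// Hypothetical driver-side semaphore state: one 64-bit payload that holds
// either the user timeline value or an encoded failure.
typedef struct my_semaphore_t {
  iree_slim_mutex_t mutex;
  uint64_t current_value;
} my_semaphore_t;

// Marks the semaphore as failed. Ownership of |status| transfers into the
// payload; a consumer must eventually release it with
// iree_hal_semaphore_failure_free.
static void my_semaphore_fail(my_semaphore_t* semaphore, iree_status_t status) {
  iree_slim_mutex_lock(&semaphore->mutex);
  semaphore->current_value = iree_hal_status_as_semaphore_failure(status);
  iree_slim_mutex_unlock(&semaphore->mutex);
}

// Queries the payload. Note the >= comparison against
// IREE_HAL_SEMAPHORE_FAILURE_VALUE: a bare failure equals the sentinel while
// status-encoded failures are strictly greater, so equality checks miss them.
static iree_status_t my_semaphore_query(my_semaphore_t* semaphore,
                                        uint64_t* out_value) {
  iree_slim_mutex_lock(&semaphore->mutex);
  uint64_t value = semaphore->current_value;
  iree_slim_mutex_unlock(&semaphore->mutex);
  *out_value = value;
  if (value < IREE_HAL_SEMAPHORE_FAILURE_VALUE) return iree_ok_status();
  // Clones the embedded status (or returns INTERNAL for a bare sentinel).
  return iree_hal_semaphore_failure_as_status(value);
}
```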
--- .../iree/hal/cts/semaphore_submission_test.h | 2 +- runtime/src/iree/hal/cts/semaphore_test.h | 8 +-- .../iree/hal/drivers/cuda/event_semaphore.c | 6 +- .../iree/hal/drivers/hip/event_semaphore.c | 6 +- runtime/src/iree/hal/semaphore.h | 62 +++++++++++++++++-- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h index a158082b36c3..b745761cf6d9 100644 --- a/runtime/src/iree/hal/cts/semaphore_submission_test.h +++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h @@ -882,7 +882,7 @@ TEST_F(SemaphoreSubmissionTest, PropagateFailSignal) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore2, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); signal_thread.join(); diff --git a/runtime/src/iree/hal/cts/semaphore_test.h b/runtime/src/iree/hal/cts/semaphore_test.h index 54e907e47004..7d0592f1921a 100644 --- a/runtime/src/iree/hal/cts/semaphore_test.h +++ b/runtime/src/iree/hal/cts/semaphore_test.h @@ -406,7 +406,7 @@ TEST_F(SemaphoreTest, FailThenWait) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); iree_hal_semaphore_release(semaphore); @@ -431,7 +431,7 @@ TEST_F(SemaphoreTest, WaitThenFail) { EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED); uint64_t value = 1234; iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(query_status, status); signal_thread.join(); @@ -467,7 +467,7 @@ TEST_F(SemaphoreTest, MultiWaitThenFail) { uint64_t value = 1234; iree_status_t semaphore1_query_status = iree_hal_semaphore_query(semaphore1, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(semaphore1_query_status, status); // semaphore2 must not have changed. @@ -511,7 +511,7 @@ TEST_F(SemaphoreTest, DeviceMultiWaitThenFail) { uint64_t value = 1234; iree_status_t semaphore1_query_status = iree_hal_semaphore_query(semaphore1, &value); - EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); + EXPECT_GE(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE); CheckStatusContains(semaphore1_query_status, status); // semaphore2 must not have changed. 
diff --git a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c index fb86efe7e815..0c0cf41e6ba9 100644 --- a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c @@ -325,7 +325,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_ABORTED); @@ -350,7 +350,7 @@ static iree_status_t iree_hal_cuda_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_make_status(IREE_STATUS_ABORTED); } iree_slim_mutex_unlock(&semaphore->mutex); @@ -444,7 +444,7 @@ iree_status_t iree_hal_cuda_semaphore_multi_wait( iree_hal_cuda_semaphore_t* semaphore = iree_hal_cuda_semaphore_cast(semaphore_list.semaphores[i]); iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); status = iree_make_status(IREE_STATUS_ABORTED); break; diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 926eb54ce5f7..de10b09125ec 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -323,7 +323,7 @@ static iree_status_t iree_hal_hip_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_ABORTED); @@ -346,7 +346,7 @@ static iree_status_t iree_hal_hip_semaphore_wait( } iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { status = iree_make_status(IREE_STATUS_ABORTED); } iree_slim_mutex_unlock(&semaphore->mutex); @@ -440,7 +440,7 @@ iree_status_t iree_hal_hip_semaphore_multi_wait( iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(semaphore_list.semaphores[i]); iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { iree_slim_mutex_unlock(&semaphore->mutex); status = iree_make_status(IREE_STATUS_ABORTED); break; diff --git a/runtime/src/iree/hal/semaphore.h b/runtime/src/iree/hal/semaphore.h index 8cc073bfcbf9..52571ed048fd 100644 --- a/runtime/src/iree/hal/semaphore.h +++ b/runtime/src/iree/hal/semaphore.h @@ -30,10 +30,6 @@ enum iree_hal_semaphore_flag_bits_t { }; typedef uint32_t iree_hal_semaphore_flags_t; -//===----------------------------------------------------------------------===// -// iree_hal_semaphore_t -//===----------------------------------------------------------------------===// - // The maximum valid payload value of an iree_hal_semaphore_t. // Payload values larger than this indicate that the semaphore has failed. 
// @@ -56,8 +52,66 @@ typedef uint32_t iree_hal_semaphore_flags_t; // https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference #define IREE_HAL_SEMAPHORE_MAX_VALUE (2147483647ull - 1) +// The minimum value for a semaphore that indicates failure. Any value +// greater-than-or-equal-to (>=) this indicates the semaphore has failed. +// +// If the upper bit 63 is set then the value represents an iree_status_t. +// Use iree_hal_semaphore_failure_as_status to convert a payload value to a +// status. Not all implementations do (or can) support encoding statuses and may +// only ever be able to set a failing semaphore to this value. #define IREE_HAL_SEMAPHORE_FAILURE_VALUE (IREE_HAL_SEMAPHORE_MAX_VALUE + 1) +// Bit indicating that a failing semaphore value can be interpreted as an +// iree_status_t. +#define IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT 0x8000000000000000ull + +// Returns a semaphore payload value that encodes the given |status|. +// Ownership of the status is transferred to the semaphore and it must be +// freed by a consumer. Not all implementations can support failure status +// payloads and this should only be used by those implementations that can. +static inline uint64_t iree_hal_status_as_semaphore_failure( + iree_status_t status) { + return IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT | + (((uint64_t)status) >> 1); +} + +// Returns OK if the |value| does not indicate an error. +// Returns an error status if the semaphore payload value represents a failure. +// If the payload contains an encoded iree_status_t it will be cloned and the +// new copy will be returned to the caller. +static inline iree_status_t iree_hal_semaphore_failure_as_status( + uint64_t value) { + if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) { + // The top bits of a pointer are sign-extended from bit 47 so we can + // restore the top bit by left-shifting the upper bits and then + // right-shifting with sign extension. We only use a single bit today and + // so bit 62 should still be the original value of the pointer. + // Note that if the status is just a code (no allocated pointer) this + // clone is a no-op and the code will be returned without an allocation. + // + // See: + // https://en.wikipedia.org/wiki/X86-64#Canonical_form_addresses + return iree_status_clone((iree_status_t)(((int64_t)value << 1) >> 1)); + } else { + return iree_status_from_code(IREE_STATUS_INTERNAL); + } + } else { + return iree_ok_status(); + } +} + +// Frees an iree_status_t encoded in a semaphore |value|, if any. +static inline void iree_hal_semaphore_failure_free(uint64_t value) { + if (value & IREE_HAL_SEMAPHORE_FAILURE_VALUE_STATUS_BIT) { + iree_status_free((iree_status_t)(((int64_t)value << 1) >> 1)); + } +} + +//===----------------------------------------------------------------------===// +// iree_hal_semaphore_t +//===----------------------------------------------------------------------===// + // Synchronization mechanism for host->device, device->host, host->host, // and device->device notification. Semaphores behave like Vulkan timeline // semaphores (or D3D12 fences) and contain a monotonically increasing From f8b84141b215a4b1e5d01758ec486554cbc4c819 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 28 Oct 2024 10:54:33 -0700 Subject: [PATCH 22/45] Modernizing iree_atomic_*. (#18910) C11's _Generic lets us avoid the need for specifying the type in the name and more closely match the C11 atomic syntax. 
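
For readers unfamiliar with the technique, a minimal self-contained sketch of the _Generic dispatch this relies on; the `my_atomic_*` names are invented for the example and the memory-order argument is omitted for brevity, so this is not the IREE implementation itself.

```c
#include <stdint.h>

typedef struct { int32_t value; } my_atomic_int32_t;
typedef struct { int64_t value; } my_atomic_int64_t;

// Width-specific helpers (plain loads here, standing in for whatever the
// backend actually does).
static inline int32_t my_atomic_load_int32(my_atomic_int32_t* object) {
  return object->value;
}
static inline int64_t my_atomic_load_int64(my_atomic_int64_t* object) {
  return object->value;
}

// _Generic selects the helper from the static type of |object|, so call sites
// read like C11 <stdatomic.h> without encoding the width in the macro name.
#define my_atomic_load(object)                    \
  _Generic((object),                              \
      my_atomic_int32_t *: my_atomic_load_int32,  \
      my_atomic_int64_t *: my_atomic_load_int64)(object)
```

A call like `my_atomic_load(&some_counter)` expands to the matching typed helper, which is what lets the separate `_int32`/`_int64`/`_intptr` macro variants collapse into a single spelling.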
This assumes that any C compiler we have that goes down the disabled atomics path supports _Generic (modern GCC, Clang, and MSVC all have for awhile). This allows us to drop-in replace C11-style atomics (useful in the new AMDGPU backend) and on MSVC will allow us to use their implementation when it's ready (it's way better than the Interlocked solution we have now). --- experimental/webgpu/nop_semaphore.c | 12 +- runtime/src/iree/base/internal/atomics.h | 55 +-- .../src/iree/base/internal/atomics_clang.h | 35 +- .../src/iree/base/internal/atomics_disabled.h | 344 ++++++++++------ runtime/src/iree/base/internal/atomics_gcc.h | 44 ++- runtime/src/iree/base/internal/atomics_msvc.h | 374 ++++++++++++------ .../src/iree/base/internal/atomics_test.cc | 64 +-- .../base/internal/dynamic_library_win32.c | 2 +- .../src/iree/base/internal/synchronization.c | 50 ++- .../src/iree/base/internal/threading_darwin.c | 7 +- .../iree/base/internal/threading_pthreads.c | 12 +- .../src/iree/base/internal/threading_test.cc | 34 +- .../src/iree/base/internal/threading_win32.c | 7 +- .../iree/base/internal/wait_handle_inproc.c | 13 +- .../src/iree/hal/drivers/cuda/memory_pools.c | 16 +- .../src/iree/hal/drivers/hip/memory_pools.c | 16 +- .../src/iree/hal/drivers/metal/shared_event.m | 4 +- .../iree/hal/drivers/metal/staging_buffer.m | 10 +- .../hal/drivers/vulkan/native_semaphore.cc | 9 +- .../hal/local/executable_plugin_manager.c | 11 +- .../src/iree/hal/utils/deferred_work_queue.c | 42 +- runtime/src/iree/hal/utils/file_transfer.c | 4 +- runtime/src/iree/task/affinity_set.h | 8 +- runtime/src/iree/task/executor.c | 10 +- runtime/src/iree/task/executor_demo.cc | 8 +- runtime/src/iree/task/poller.c | 32 +- runtime/src/iree/task/scope.c | 22 +- runtime/src/iree/task/task.c | 78 ++-- runtime/src/iree/task/task_test_dispatch.cc | 7 +- runtime/src/iree/task/worker.c | 22 +- runtime/src/iree/vm/context.c | 4 +- runtime/src/iree/vm/invocation.c | 4 +- runtime/src/iree/vm/ref.c | 34 +- runtime/src/iree/vm/ref_test.cc | 6 +- 34 files changed, 794 insertions(+), 606 deletions(-) diff --git a/experimental/webgpu/nop_semaphore.c b/experimental/webgpu/nop_semaphore.c index d4151ee29990..65d26486567b 100644 --- a/experimental/webgpu/nop_semaphore.c +++ b/experimental/webgpu/nop_semaphore.c @@ -38,8 +38,8 @@ iree_status_t iree_hal_webgpu_nop_semaphore_create( iree_hal_resource_initialize(&iree_hal_webgpu_nop_semaphore_vtable, &semaphore->resource); semaphore->host_allocator = host_allocator; - iree_atomic_store_int64(&semaphore->value, initial_value, - iree_memory_order_seq_cst); + iree_atomic_store(&semaphore->value, initial_value, + iree_memory_order_seq_cst); *out_semaphore = (iree_hal_semaphore_t*)semaphore; } @@ -63,8 +63,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_query( iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); - *out_value = - iree_atomic_load_int64(&semaphore->value, iree_memory_order_seq_cst); + *out_value = iree_atomic_load(&semaphore->value, iree_memory_order_seq_cst); return iree_ok_status(); } @@ -72,8 +71,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_signal( iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); - iree_atomic_store_int64(&semaphore->value, new_value, - iree_memory_order_seq_cst); + iree_atomic_store(&semaphore->value, new_value, iree_memory_order_seq_cst); 
return iree_ok_status(); } @@ -88,7 +86,7 @@ static iree_status_t iree_hal_webgpu_nop_semaphore_wait( iree_hal_webgpu_nop_semaphore_t* semaphore = iree_hal_webgpu_nop_semaphore_cast(base_semaphore); uint64_t current_value = - iree_atomic_load_int64(&semaphore->value, iree_memory_order_seq_cst); + iree_atomic_load(&semaphore->value, iree_memory_order_seq_cst); if (current_value < value) { return iree_make_status( IREE_STATUS_FAILED_PRECONDITION, diff --git a/runtime/src/iree/base/internal/atomics.h b/runtime/src/iree/base/internal/atomics.h index 731d9eef510e..f428731506a5 100644 --- a/runtime/src/iree/base/internal/atomics.h +++ b/runtime/src/iree/base/internal/atomics.h @@ -86,47 +86,6 @@ extern "C" { #endif // IREE_COMPILER_* -// If the compiler can automatically determine the types: -#ifdef iree_atomic_load_auto - -#define iree_atomic_load_int32 iree_atomic_load_auto -#define iree_atomic_store_int32 iree_atomic_store_auto -#define iree_atomic_fetch_add_int32 iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_int32 iree_atomic_fetch_sub_auto -#define iree_atomic_fetch_and_int32 iree_atomic_fetch_and_auto -#define iree_atomic_fetch_or_int32 iree_atomic_fetch_or_auto -#define iree_atomic_fetch_xor_int32 iree_atomic_fetch_xor_auto -#define iree_atomic_exchange_int32 iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_int32 \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_weak_auto - -#define iree_atomic_load_int64 iree_atomic_load_auto -#define iree_atomic_store_int64 iree_atomic_store_auto -#define iree_atomic_fetch_add_int64 iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_int64 iree_atomic_fetch_sub_auto -#define iree_atomic_fetch_and_int64 iree_atomic_fetch_and_auto -#define iree_atomic_fetch_or_int64 iree_atomic_fetch_or_auto -#define iree_atomic_fetch_xor_int64 iree_atomic_fetch_xor_auto -#define iree_atomic_exchange_int64 iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_int64 \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_weak_auto - -#define iree_atomic_load_intptr iree_atomic_load_auto -#define iree_atomic_store_intptr iree_atomic_store_auto -#define iree_atomic_fetch_add_intptr iree_atomic_fetch_add_auto -#define iree_atomic_fetch_sub_intptr iree_atomic_fetch_sub_auto -#define iree_atomic_exchange_intptr iree_atomic_exchange_auto -#define iree_atomic_compare_exchange_strong_intptr \ - iree_atomic_compare_exchange_strong_auto -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_weak_auto - -#endif // iree_atomic_load_auto - //============================================================================== // Reference count atomics //============================================================================== @@ -140,10 +99,10 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // should use IREE_ATOMIC_VAR_INIT, but apparently this has to be fixed // at call sites (where the variables are initialized in the first place). 
#define iree_atomic_ref_count_init_value(count_ptr, value) \ - iree_atomic_store_int32(count_ptr, value, iree_memory_order_relaxed) + iree_atomic_store((count_ptr), (value), iree_memory_order_relaxed) #define iree_atomic_ref_count_init(count_ptr) \ - iree_atomic_ref_count_init_value(count_ptr, 1) + iree_atomic_ref_count_init_value((count_ptr), 1) // Why relaxed order: // https://www.boost.org/doc/libs/1_57_0/doc/html/atomic/usage_examples.html#boost_atomic.usage_examples.example_reference_counters.discussion @@ -155,9 +114,9 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // value (unlike iree_atomic_ref_count_dec), so we make sure that it does not, // which allows the implementation to use faster atomic instructions where // available, e.g. STADD on ARMv8.1-a. -#define iree_atomic_ref_count_inc(count_ptr) \ - do { \ - iree_atomic_fetch_add_int32(count_ptr, 1, iree_memory_order_relaxed); \ +#define iree_atomic_ref_count_inc(count_ptr) \ + do { \ + iree_atomic_fetch_add((count_ptr), 1, iree_memory_order_relaxed); \ } while (false) // For now we stick to acq_rel order. TODO: should we follow Boost's advice? @@ -169,13 +128,13 @@ typedef iree_atomic_int32_t iree_atomic_ref_count_t; // may be a pessimization... I would like to hear a second opinion on this, // particularly regarding how x86-centric this might be. #define iree_atomic_ref_count_dec(count_ptr) \ - iree_atomic_fetch_sub_int32(count_ptr, 1, iree_memory_order_acq_rel) + iree_atomic_fetch_sub((count_ptr), 1, iree_memory_order_acq_rel) // memory_order_acquire order ensures that this sees decrements from // iree_atomic_ref_count_dec. On the other hand, there is no ordering with // iree_atomic_ref_count_inc. #define iree_atomic_ref_count_load(count_ptr) \ - iree_atomic_load_int32(count_ptr, iree_memory_order_acquire) + iree_atomic_load((count_ptr), iree_memory_order_acquire) // Aborts the program if the given reference count value is not 1. 
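// Minimal usage sketch for the ref count helpers above (hypothetical
// my_object_* names, assuming the common convention that the count starts at
// 1 for the creator): iree_atomic_ref_count_dec returns the value prior to
// the decrement, so observing 1 means this call released the last reference.
typedef struct my_object_t {
  iree_atomic_ref_count_t ref_count;  // set up with iree_atomic_ref_count_init
  // ... other object state ...
} my_object_t;

static void my_object_retain(my_object_t* object) {
  iree_atomic_ref_count_inc(&object->ref_count);
}

static void my_object_release(my_object_t* object) {
  if (iree_atomic_ref_count_dec(&object->ref_count) == 1) {
    // Last reference released: safe to deinitialize and free |object| here.
  }
}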
// This should be avoided in all situations but those where continuing execution diff --git a/runtime/src/iree/base/internal/atomics_clang.h b/runtime/src/iree/base/internal/atomics_clang.h index 44514e05c742..afa7a3352017 100644 --- a/runtime/src/iree/base/internal/atomics_clang.h +++ b/runtime/src/iree/base/internal/atomics_clang.h @@ -33,37 +33,38 @@ typedef enum iree_memory_order_e { typedef _Atomic int32_t iree_atomic_int32_t; typedef _Atomic int64_t iree_atomic_int64_t; +typedef _Atomic uint32_t iree_atomic_uint32_t; +typedef _Atomic uint64_t iree_atomic_uint64_t; // TODO(#3453): check for __int128 support before using // typedef _Atomic __int128 iree_atomic_int128_t; typedef _Atomic intptr_t iree_atomic_intptr_t; -#define iree_atomic_load_auto(object, order) \ - __c11_atomic_load((object), (order)) -#define iree_atomic_store_auto(object, desired, order) \ +#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) + +#define iree_atomic_load(object, order) __c11_atomic_load((object), (order)) +#define iree_atomic_store(object, desired, order) \ __c11_atomic_store((object), (desired), (order)) -#define iree_atomic_fetch_add_auto(object, operand, order) \ +#define iree_atomic_fetch_add(object, operand, order) \ __c11_atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ +#define iree_atomic_fetch_sub(object, operand, order) \ __c11_atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ +#define iree_atomic_fetch_and(object, operand, order) \ __c11_atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ +#define iree_atomic_fetch_or(object, operand, order) \ __c11_atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ +#define iree_atomic_fetch_xor(object, operand, order) \ __c11_atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ +#define iree_atomic_exchange(object, operand, order) \ __c11_atomic_exchange((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ (order_succ), (order_fail)) -#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) - #ifdef __cplusplus } // extern "C" #endif diff --git a/runtime/src/iree/base/internal/atomics_disabled.h b/runtime/src/iree/base/internal/atomics_disabled.h index 5c0a7cad6ff5..5dbb272f4748 100644 --- a/runtime/src/iree/base/internal/atomics_disabled.h +++ b/runtime/src/iree/base/internal/atomics_disabled.h @@ -16,12 +16,8 @@ #if IREE_SYNCHRONIZATION_DISABLE_UNSAFE -#ifdef __cplusplus -extern "C" { -#endif - typedef enum iree_memory_order_e { - iree_memory_order_relaxed, + iree_memory_order_relaxed = 0u, iree_memory_order_consume, 
iree_memory_order_acquire, iree_memory_order_release, @@ -33,65 +29,197 @@ typedef enum iree_memory_order_e { typedef int32_t iree_atomic_int32_t; typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; // TODO(#3453): check for __int128 support before using // typedef __int128 iree_atomic_int128_t; typedef intptr_t iree_atomic_intptr_t; -#define iree_atomic_load_int32(object, order) (*(object)) -#define iree_atomic_store_int32(object, desired, order) (*(object) = (desired)) -#define iree_atomic_fetch_add_int32(object, operand, order) \ - iree_atomic_fetch_add_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_sub_int32(object, operand, order) \ - iree_atomic_fetch_add_int32_impl((volatile iree_atomic_int32_t*)(object), \ - -(int32_t)(operand)) -#define iree_atomic_fetch_and_int32(object, operand, order) \ - iree_atomic_fetch_and_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_or_int32(object, operand, order) \ - iree_atomic_fetch_or_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_fetch_xor_int32(object, operand, order) \ - iree_atomic_fetch_xor_int32_impl((volatile iree_atomic_int32_t*)(object), \ - (int32_t)(operand)) -#define iree_atomic_exchange_int32(object, desired, order) \ - iree_atomic_fetch_exchange_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t)(desired)) -#define iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired)) -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_strong_int32 - -#define iree_atomic_load_int64(object, order) (*(object)) -#define iree_atomic_store_int64(object, desired, order) (*(object) = (desired)) -#define iree_atomic_fetch_add_int64(object, operand, order) \ - iree_atomic_fetch_add_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_sub_int64(object, operand, order) \ - iree_atomic_fetch_add_int64_impl((volatile iree_atomic_int64_t*)(object), \ - -(int64_t)(operand)) -#define iree_atomic_fetch_and_int64(object, operand, order) \ - iree_atomic_fetch_and_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_or_int64(object, operand, order) \ - iree_atomic_fetch_or_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_fetch_xor_int64(object, operand, order) \ - iree_atomic_fetch_xor_int64_impl((volatile iree_atomic_int64_t*)(object), \ - (int64_t)(operand)) -#define iree_atomic_exchange_int64(object, desired, order) \ - iree_atomic_fetch_exchange_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t)(desired)) -#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired)) -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_strong_int64 +#define iree_atomic_thread_fence(order) + +#ifdef __cplusplus + +extern "C++" { + +#define iree_atomic_load(object, order) (*(object)) +#define iree_atomic_store(object, desired, order) (*(object) = (desired)) +#define 
iree_atomic_fetch_add(object, operand, order) \ + iree_atomic_fetch_add_impl((object), (operand)) +#define iree_atomic_fetch_sub(object, operand, order) \ + iree_atomic_fetch_sub_impl((object), (operand)) +#define iree_atomic_fetch_and(object, operand, order) \ + iree_atomic_fetch_and_impl((object), (operand)) +#define iree_atomic_fetch_or(object, operand, order) \ + iree_atomic_fetch_or_impl((object), (operand)) +#define iree_atomic_fetch_xor(object, operand, order) \ + iree_atomic_fetch_xor_impl((object), (operand)) +#define iree_atomic_exchange(object, desired, order) \ + iree_atomic_fetch_exchange_impl((object), (desired)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + iree_atomic_compare_exchange_impl((object), (expected), (desired)) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong + +template +static inline T iree_atomic_fetch_add_impl(volatile T* object, V operand) { + T original = *object; + *object += operand; + return original; +} + +template +static inline T iree_atomic_fetch_sub_impl(volatile T* object, V operand) { + T original = *object; + *object -= operand; + return original; +} + +template +static inline T iree_atomic_fetch_and_impl(volatile T* object, V operand) { + T original = *object; + *object &= operand; + return original; +} + +template +static inline T iree_atomic_fetch_or_impl(volatile T* object, V operand) { + T original = *object; + *object |= operand; + return original; +} + +template +static inline T iree_atomic_fetch_xor_impl(volatile T* object, V operand) { + T original = *object; + *object ^= operand; + return original; +} + +template +static inline T iree_atomic_fetch_exchange_impl(volatile T* object, V desired) { + T original = *object; + *object = desired; + return original; +} + +template +static inline bool iree_atomic_compare_exchange_impl(volatile T* object, + V* expected, V desired) { + if (*object == *expected) { + *object = desired; + return true; + } else { + *expected = *object; + return false; + } +} + +} // extern "C" + +#else + +#define iree_atomic_load(object, order) (*(object)) +#define iree_atomic_store(object, desired, order) (*(object) = (desired)) +#define iree_atomic_fetch_add(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_add_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_add_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_add_uint32_impl( \ + (volatile iree_atomic_uint32_t*)(object), \ + (uint32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_add_uint64_impl( \ + (volatile iree_atomic_uint64_t*)(object), \ + (uint64_t)(operand))) +#define iree_atomic_fetch_sub(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_sub_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_sub_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_sub_uint32_impl( \ + (volatile iree_atomic_uint32_t*)(object), \ + (uint32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_sub_uint64_impl( \ + (volatile iree_atomic_uint64_t*)(object), \ + (uint64_t)(operand))) +#define iree_atomic_fetch_and(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: 
iree_atomic_fetch_and_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_and_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_and_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_and_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_or(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_or_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_or_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_or_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_or_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_xor(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_xor_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: iree_atomic_fetch_xor_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_xor_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_xor_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_exchange(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_fetch_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_int64_t *: iree_atomic_fetch_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: iree_atomic_fetch_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: iree_atomic_fetch_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_compare_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired)), \ + iree_atomic_int64_t *: iree_atomic_compare_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired)), \ + iree_atomic_uint32_t *: iree_atomic_compare_exchange_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired)), \ + iree_atomic_uint64_t *: iree_atomic_compare_exchange_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired))) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong static inline int32_t iree_atomic_fetch_add_int32_impl( volatile iree_atomic_int32_t* object, int32_t operand) { @@ -100,6 +228,27 @@ static inline int32_t iree_atomic_fetch_add_int32_impl( return original; } +static inline int32_t iree_atomic_fetch_sub_int32_impl( + volatile iree_atomic_int32_t* object, int32_t operand) { + int32_t original = *object; + *object -= operand; + return 
original; +} + +static inline int32_t iree_atomic_fetch_add_uint32_impl( + volatile iree_atomic_int32_t* object, uint32_t operand) { + uint32_t original = *object; + *object += operand; + return original; +} + +static inline int32_t iree_atomic_fetch_sub_uint32_impl( + volatile iree_atomic_uint32_t* object, uint32_t operand) { + uint32_t original = *object; + *object -= operand; + return original; +} + static inline int32_t iree_atomic_fetch_and_int32_impl( volatile iree_atomic_int32_t* object, int32_t operand) { int32_t original = *object; @@ -146,6 +295,27 @@ static inline int64_t iree_atomic_fetch_add_int64_impl( return original; } +static inline int64_t iree_atomic_fetch_sub_int64_impl( + volatile iree_atomic_int64_t* object, int64_t operand) { + int64_t original = *object; + *object -= operand; + return original; +} + +static inline int64_t iree_atomic_fetch_add_uint64_impl( + volatile iree_atomic_uint64_t* object, uint64_t operand) { + uint64_t original = *object; + *object += operand; + return original; +} + +static inline int64_t iree_atomic_fetch_sub_uint64_impl( + volatile iree_atomic_uint64_t* object, uint64_t operand) { + uint64_t original = *object; + *object -= operand; + return original; +} + static inline int64_t iree_atomic_fetch_and_int64_impl( volatile iree_atomic_int64_t* object, int64_t operand) { int64_t original = *object; @@ -185,59 +355,7 @@ static inline bool iree_atomic_compare_exchange_int64_impl( } } -// There are no pointer-width atomic ops in MSVC so we need to specialize based -// on the pointer size. -#if defined(IREE_PTR_SIZE_32) -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int32((iree_atomic_int32_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32( \ - (iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#else -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int64((iree_atomic_int64_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) 
-#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64( \ - (iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#endif // IREE_PTR_SIZE_32 - -#define iree_atomic_thread_fence(order) - -#ifdef __cplusplus -} // extern "C" -#endif +#endif // __cplusplus #endif // IREE_SYNCHRONIZATION_DISABLE_UNSAFE diff --git a/runtime/src/iree/base/internal/atomics_gcc.h b/runtime/src/iree/base/internal/atomics_gcc.h index d413b9816253..728add728612 100644 --- a/runtime/src/iree/base/internal/atomics_gcc.h +++ b/runtime/src/iree/base/internal/atomics_gcc.h @@ -34,6 +34,8 @@ typedef enum iree_memory_order_e { typedef int32_t iree_atomic_int32_t; typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; // typedef __int128 iree_atomic_int128_t; typedef intptr_t iree_atomic_intptr_t; @@ -45,47 +47,47 @@ typedef intptr_t iree_atomic_intptr_t; #define __iree_auto_type __auto_type #endif -#define iree_atomic_load_auto(object, order) \ +static inline void iree_atomic_thread_fence(int order) { + // Ignore error where TSan does not support atomic thread fence. + IREE_DISABLE_COMPILER_TSAN_ERRORS() + __atomic_thread_fence(order); + IREE_RESTORE_COMPILER_TSAN_ERRORS() +} + +#define iree_atomic_load(object, order) \ __extension__({ \ __iree_auto_type __atomic_load_ptr = (object); \ __typeof__(*__atomic_load_ptr) __atomic_load_tmp; \ __atomic_load(__atomic_load_ptr, &__atomic_load_tmp, (order)); \ __atomic_load_tmp; \ }) -#define iree_atomic_store_auto(object, desired, order) \ +#define iree_atomic_store(object, desired, order) \ __extension__({ \ __iree_auto_type __atomic_store_ptr = (object); \ __typeof__(*__atomic_store_ptr) __atomic_store_tmp = (desired); \ __atomic_store(__atomic_store_ptr, &__atomic_store_tmp, (order)); \ }) -#define iree_atomic_fetch_add_auto(object, operand, order) \ +#define iree_atomic_fetch_add(object, operand, order) \ __atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ +#define iree_atomic_fetch_sub(object, operand, order) \ __atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ +#define iree_atomic_fetch_and(object, operand, order) \ __atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ +#define iree_atomic_fetch_or(object, operand, order) \ __atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ +#define iree_atomic_fetch_xor(object, operand, order) \ __atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ +#define iree_atomic_exchange(object, operand, order) \ __atomic_exchange_n((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - 
__atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ (order_succ), (order_fail)) -static inline void iree_atomic_thread_fence(int order) { - // Ignore error where TSan does not support atomic thread fence. - IREE_DISABLE_COMPILER_TSAN_ERRORS() - __atomic_thread_fence(order); - IREE_RESTORE_COMPILER_TSAN_ERRORS() -} - #ifdef __cplusplus } // extern "C" #endif diff --git a/runtime/src/iree/base/internal/atomics_msvc.h b/runtime/src/iree/base/internal/atomics_msvc.h index 5cfbf43eb3f6..2af2798c0a13 100644 --- a/runtime/src/iree/base/internal/atomics_msvc.h +++ b/runtime/src/iree/base/internal/atomics_msvc.h @@ -16,12 +16,141 @@ #if defined(IREE_COMPILER_MSVC) -#ifdef __cplusplus +// TODO(benvanik): make MSVC's C11 atomic support work. +// It's difficult to detect and has some weird configuration assertions around +// mixed C and C++ code. Support is only present when the +// `/experimental:c11atomics` but that is ignored on /TP (C++) compilation. +// __STDC_NO_ATOMICS__ is not unset when included/enabled so we can't use the +// standard check. Hopefully that'd be fixed if it ever leaves experimental. +#define IREE_ATOMIC_USE_MSVC_C11 0 +#if IREE_ATOMIC_USE_MSVC_C11 +#include +#endif // IREE_ATOMIC_USE_MSVC_C11 + +#if IREE_ATOMIC_USE_MSVC_C11 && defined(atomic_init) + +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = _Atomic_memory_order_relaxed, + iree_memory_order_consume = _Atomic_memory_order_consume, + iree_memory_order_acquire = _Atomic_memory_order_acquire, + iree_memory_order_release = _Atomic_memory_order_release, + iree_memory_order_acq_rel = _Atomic_memory_order_acq_rel, + iree_memory_order_seq_cst = _Atomic_memory_order_seq_cst, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef _Atomic int32_t iree_atomic_int32_t; +typedef _Atomic int64_t iree_atomic_int64_t; +typedef _Atomic uint32_t iree_atomic_uint32_t; +typedef _Atomic uint64_t iree_atomic_uint64_t; +// TODO(#3453): check for __int128 support before using +// typedef _Atomic __int128 iree_atomic_int128_t; +typedef _Atomic intptr_t iree_atomic_intptr_t; + +#define iree_atomic_thread_fence(order) atomic_thread_fence(order) + +#define iree_atomic_load(object, order) __c11_atomic_load((object), (order)) +#define iree_atomic_store(object, desired, order) \ + __c11_atomic_store((object), (desired), (order)) +#define iree_atomic_fetch_add(object, operand, order) \ + __c11_atomic_fetch_add((object), (operand), (order)) +#define iree_atomic_fetch_sub(object, operand, order) \ + __c11_atomic_fetch_sub((object), (operand), (order)) +#define iree_atomic_fetch_and(object, operand, order) \ + __c11_atomic_fetch_and((object), (operand), (order)) +#define iree_atomic_fetch_or(object, operand, order) \ + __c11_atomic_fetch_or((object), (operand), (order)) +#define iree_atomic_fetch_xor(object, operand, order) \ + __c11_atomic_fetch_xor((object), (operand), (order)) +#define iree_atomic_exchange(object, operand, order) \ + __c11_atomic_exchange((object), (operand), (order)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ + (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + 
__c11_atomic_compare_exchange_weak((object), (expected), (desired), \ + (order_succ), (order_fail)) + +#elif __cplusplus + +// When compiling for C++ we reinterpret atomics as std::atomic. This relies +// on std::atomic on primitive types being lock-free such that the memory for +// each atomic is just the atomic value. We need this special path because MSVC +// doesn't support C features like _Generic in C++. + +extern "C++" { +#include +} // extern "C++" + extern "C" { -#endif typedef enum iree_memory_order_e { - iree_memory_order_relaxed, + iree_memory_order_relaxed = _Atomic_memory_order_relaxed, + iree_memory_order_consume = _Atomic_memory_order_consume, + iree_memory_order_acquire = _Atomic_memory_order_acquire, + iree_memory_order_release = _Atomic_memory_order_release, + iree_memory_order_acq_rel = _Atomic_memory_order_acq_rel, + iree_memory_order_seq_cst = _Atomic_memory_order_seq_cst, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef std::atomic iree_atomic_int32_t; +typedef std::atomic iree_atomic_int64_t; +typedef std::atomic iree_atomic_uint32_t; +typedef std::atomic iree_atomic_uint64_t; +typedef std::atomic iree_atomic_intptr_t; + +#define iree_atomic_thread_fence(order) std::atomic_thread_fence(order) + +#define iree_atomic_load(object, order) \ + std::atomic_load_explicit((object), (std::memory_order)(order)) +#define iree_atomic_store(object, desired, order) \ + std::atomic_store_explicit((object), (desired), (std::memory_order)(order)) +#define iree_atomic_fetch_add(object, operand, order) \ + std::atomic_fetch_add_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_sub(object, operand, order) \ + std::atomic_fetch_sub_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_and(object, operand, order) \ + std::atomic_fetch_and_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_fetch_or(object, operand, order) \ + std::atomic_fetch_or_explicit((object), (operand), (std::memory_order)(order)) +#define iree_atomic_fetch_xor(object, operand, order) \ + std::atomic_fetch_xor_explicit((object), (operand), \ + (std::memory_order)(order)) +#define iree_atomic_exchange(object, operand, order) \ + std::atomic_exchange_explicit((object), (operand), (std::memory_order)(order)) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + std::atomic_compare_exchange_strong_explicit( \ + (object), (expected), (desired), (std::memory_order)(order_succ), \ + (std::memory_order)(order_fail)) +#define iree_atomic_compare_exchange_weak(object, expected, desired, \ + order_succ, order_fail) \ + std::atomic_compare_exchange_weak_explicit((object), (expected), (desired), \ + (std::memory_order)(order_succ), \ + (std::memory_order)(order_fail)) + +} // extern "C" + +#else + +// When compiling in C we can use _Generic to automatically route to the +// builtins that change their name based on the atomic type. This implementation +// is not good: it ignores memory order entirely and uses the full barrier +// implied by any of the _Interlocked* builtins. There are some variants of the +// builtins that we could use based on the order but their support across +// targets differs. Hopefully ~soon we can use C11 atomics directly and drop +// this code path. 
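// Standalone toy of the _Generic dispatch used in the macros that follow
// (hypothetical toy_* names, with non-atomic stand-ins for the _Interlocked*
// builtins): selection happens on the static type of `object`, so one macro
// spelling routes 32-bit and 64-bit pointers to different implementations.
// Each branch carries its own cast so that every association type-checks no
// matter which one is selected; the memory order parameter is accepted but
// ignored, matching the full-barrier behavior described above.
#include <stdint.h>

static inline int32_t toy_fetch_add_i32(volatile int32_t* object,
                                        int32_t operand) {
  int32_t original = *object;  // stand-in for _InterlockedExchangeAdd
  *object += operand;
  return original;
}

static inline int64_t toy_fetch_add_i64(volatile int64_t* object,
                                        int64_t operand) {
  int64_t original = *object;  // stand-in for _InterlockedExchangeAdd64
  *object += operand;
  return original;
}

#define toy_fetch_add(object, operand, order)                    \
  _Generic((object),                                             \
      int32_t *: toy_fetch_add_i32((volatile int32_t*)(object),  \
                                   (int32_t)(operand)),          \
      int64_t *: toy_fetch_add_i64((volatile int64_t*)(object),  \
                                   (int64_t)(operand)))

// Example: toy_fetch_add(&value32, 1, 0) selects the 32-bit branch when
// value32 is an int32_t, and the 64-bit branch for an int64_t value64.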
+ +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = 0u, iree_memory_order_consume, iree_memory_order_acquire, iree_memory_order_release, @@ -29,72 +158,131 @@ typedef enum iree_memory_order_e { iree_memory_order_seq_cst, } iree_memory_order_t; -#define IREE_ATOMIC_VAR_INIT(value) \ - { (value) } - -typedef struct { - int32_t __val; -} iree_atomic_int32_t; -typedef struct { - int64_t __val; -} iree_atomic_int64_t; -// typedef __declspec(align(16)) struct { -// uint64_t __val[2]; -// } iree_atomic_int128_t; -typedef struct { - intptr_t __val; -} iree_atomic_intptr_t; - -#define iree_atomic_load_int32(object, order) \ - InterlockedExchangeAdd((volatile LONG*)object, 0) -#define iree_atomic_store_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define iree_atomic_fetch_add_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, operand) -#define iree_atomic_fetch_sub_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, -((int32_t)(operand))) -#define iree_atomic_fetch_and_int32(object, operand, order) \ - InterlockedAnd((volatile LONG*)object, operand) -#define iree_atomic_fetch_or_int32(object, operand, order) \ - InterlockedOr((volatile LONG*)object, operand) -#define iree_atomic_fetch_xor_int32(object, operand, order) \ - InterlockedXor((volatile LONG*)object, operand) -#define iree_atomic_exchange_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_strong_int32 - -#define iree_atomic_load_int64(object, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, 0) -#define iree_atomic_store_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, (LONG64)desired) -#define iree_atomic_fetch_add_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, (LONG64)operand) -#define iree_atomic_fetch_sub_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, -(operand)) -#define iree_atomic_fetch_and_int64(object, operand, order) \ - InterlockedAnd64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_or_int64(object, operand, order) \ - InterlockedOr64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_xor_int64(object, operand, order) \ - InterlockedXor64((volatile LONG64*)object, operand) -#define iree_atomic_exchange_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, desired) -#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_strong_int64 +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef int32_t iree_atomic_int32_t; +typedef int64_t iree_atomic_int64_t; +typedef uint32_t iree_atomic_uint32_t; +typedef uint64_t iree_atomic_uint64_t; +typedef intptr_t iree_atomic_intptr_t; #define iree_atomic_thread_fence(order) MemoryBarrier() +#define 
iree_atomic_load(object, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), 0), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), 0), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), 0), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), 0)) +#define iree_atomic_store(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchange((volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_int64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: _InterlockedExchange( \ + (volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_fetch_add(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_sub(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + -((int32_t)(operand))), \ + iree_atomic_int64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + -((int64_t)(operand))), \ + iree_atomic_uint32_t *: _InterlockedExchangeAdd( \ + (volatile int32_t*)(object), \ + -((int32_t)(operand))), \ + iree_atomic_uint64_t *: _InterlockedExchangeAdd64( \ + (volatile int64_t*)(object), \ + -((int64_t)(operand)))) +#define iree_atomic_fetch_and(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedAnd((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedAnd64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedAnd((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedAnd64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_or(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedOr((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedOr64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedOr((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedOr64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_fetch_xor(object, operand, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedXor((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_int64_t *: _InterlockedXor64((volatile int64_t*)(object), \ + (int64_t)(operand)), \ + iree_atomic_uint32_t *: _InterlockedXor((volatile int32_t*)(object), \ + (int32_t)(operand)), \ + iree_atomic_uint64_t *: _InterlockedXor64((volatile int64_t*)(object), \ + (int64_t)(operand))) +#define iree_atomic_exchange(object, desired, order) \ + _Generic((object), \ + iree_atomic_int32_t *: _InterlockedExchange((volatile int32_t*)(object), \ + 
(int32_t)(desired)), \ + iree_atomic_int64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired)), \ + iree_atomic_uint32_t *: _InterlockedExchange( \ + (volatile int32_t*)(object), \ + (int32_t)(desired)), \ + iree_atomic_uint64_t *: _InterlockedExchange64( \ + (volatile int64_t*)(object), \ + (int64_t)(desired))) +#define iree_atomic_compare_exchange_strong(object, expected, desired, \ + order_succ, order_fail) \ + _Generic((object), \ + iree_atomic_int32_t *: iree_atomic_compare_exchange_strong_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_int64_t *: iree_atomic_compare_exchange_strong_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_uint32_t *: iree_atomic_compare_exchange_strong_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), \ + (int32_t*)(expected), (int32_t)(desired), \ + (order_succ), (order_fail)), \ + iree_atomic_uint64_t *: iree_atomic_compare_exchange_strong_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), \ + (int64_t*)(expected), (int64_t)(desired), \ + (order_succ), (order_fail))) +#define iree_atomic_compare_exchange_weak iree_atomic_compare_exchange_strong + static inline bool iree_atomic_compare_exchange_strong_int32_impl( volatile iree_atomic_int32_t* object, int32_t* expected, int32_t desired, iree_memory_order_t order_succ, iree_memory_order_t order_fail) { @@ -123,59 +311,7 @@ static inline bool iree_atomic_compare_exchange_strong_int64_impl( } } -#define iree_atomic_thread_fence(order) MemoryBarrier() - -// There are no pointer-width atomic ops in MSVC so we need to specialize based -// on the pointer size. 
-#if defined(IREE_PTR_SIZE_32) -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int32((iree_atomic_int32_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int32((iree_atomic_int32_t*)(object), \ - (int32_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32( \ - (iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#else -#define iree_atomic_load_intptr(object, order) \ - (intptr_t) iree_atomic_load_int64((iree_atomic_int64_t*)(object), (order)) -#define iree_atomic_store_intptr(object, desired, order) \ - (intptr_t) iree_atomic_store_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_fetch_add_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_add_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_fetch_sub_intptr(object, operand, order) \ - (intptr_t) iree_atomic_fetch_sub_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(operand), (order)) -#define iree_atomic_exchange_intptr(object, desired, order) \ - (intptr_t) iree_atomic_exchange_int64((iree_atomic_int64_t*)(object), \ - (int64_t)(desired), (order)) -#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64( \ - (iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_intptr \ - iree_atomic_compare_exchange_strong_intptr -#endif // IREE_PTR_SIZE_32 - -#ifdef __cplusplus -} // extern "C" -#endif +#endif // IREE_ATOMIC_USE_MSVC_C11 #endif // IREE_COMPILER_MSVC diff --git a/runtime/src/iree/base/internal/atomics_test.cc b/runtime/src/iree/base/internal/atomics_test.cc index a9fce2f2173e..d78890c674a7 100644 --- a/runtime/src/iree/base/internal/atomics_test.cc +++ b/runtime/src/iree/base/internal/atomics_test.cc @@ -21,9 +21,9 @@ TEST(AtomicPtr, LoadStore) { intptr_t ptr_0 = 0x0; intptr_t ptr_1 = 0x1; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); - iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); - EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); + iree_atomic_store(&value, ptr_1, iree_memory_order_seq_cst); + EXPECT_EQ(ptr_1, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, AddSub) { @@ -31,15 +31,15 @@ TEST(AtomicPtr, AddSub) { intptr_t ptr_1 = 0x1; intptr_t ptr_2 = 0x2; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, 
iree_atomic_fetch_add_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_fetch_add_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_2, iree_atomic_fetch_sub_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_fetch_sub_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, + iree_atomic_fetch_add(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_fetch_add(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, + iree_atomic_fetch_sub(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_fetch_sub(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, Exchange) { @@ -47,11 +47,11 @@ TEST(AtomicPtr, Exchange) { intptr_t ptr_1 = 0x1; intptr_t ptr_2 = 0x2; iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); - EXPECT_EQ(ptr_0, iree_atomic_exchange_intptr(&value, ptr_1, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_1, iree_atomic_exchange_intptr(&value, ptr_2, - iree_memory_order_seq_cst)); - EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, + iree_atomic_exchange(&value, ptr_1, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, + iree_atomic_exchange(&value, ptr_2, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicPtr, CompareExchange) { @@ -62,31 +62,31 @@ TEST(AtomicPtr, CompareExchange) { intptr_t ptr_expected = 0; // OK: value == ptr_0, CAS(ptr_0 -> ptr_1) - iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_0, iree_memory_order_seq_cst); ptr_expected = ptr_0; - EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_1, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_TRUE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_1, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_0, ptr_expected); - EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, iree_atomic_load(&value, iree_memory_order_seq_cst)); // OK: value == ptr_1, CAS(ptr_1 -> ptr_2) - iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_1, iree_memory_order_seq_cst); ptr_expected = ptr_1; - EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_TRUE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_2, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_1, ptr_expected); - EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_load(&value, iree_memory_order_seq_cst)); // FAIL: value == ptr_0, CAS(ptr_1 -> ptr_2) - iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + iree_atomic_store(&value, ptr_0, iree_memory_order_seq_cst); ptr_expected = ptr_1; - EXPECT_FALSE(iree_atomic_compare_exchange_strong_intptr( - &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, - iree_memory_order_seq_cst)); + EXPECT_FALSE(iree_atomic_compare_exchange_strong(&value, &ptr_expected, ptr_2, + iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); EXPECT_EQ(ptr_0, ptr_expected); - EXPECT_EQ(ptr_0, 
iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load(&value, iree_memory_order_seq_cst)); } TEST(AtomicRefCount, IncDec) { diff --git a/runtime/src/iree/base/internal/dynamic_library_win32.c b/runtime/src/iree/base/internal/dynamic_library_win32.c index af6e4e80b8ef..2cbdd07f6416 100644 --- a/runtime/src/iree/base/internal/dynamic_library_win32.c +++ b/runtime/src/iree/base/internal/dynamic_library_win32.c @@ -91,7 +91,7 @@ static iree_status_t iree_dynamic_library_make_temp_file_path( static iree_atomic_int32_t next_unique_id = IREE_ATOMIC_VAR_INIT(0); // relaxed because we only care about uniqueness, we don't care about ordering // of accesses to unique_id w.r.t. other memory operations. - uint32_t unique_id = (uint32_t)iree_atomic_fetch_add_int32( + uint32_t unique_id = (uint32_t)iree_atomic_fetch_add( &next_unique_id, 1, iree_memory_order_relaxed); // Allocate storage for the full file path and format it in. diff --git a/runtime/src/iree/base/internal/synchronization.c b/runtime/src/iree/base/internal/synchronization.c index 65fb0d1a93e8..960a70c3be9b 100644 --- a/runtime/src/iree/base/internal/synchronization.c +++ b/runtime/src/iree/base/internal/synchronization.c @@ -447,8 +447,7 @@ void iree_slim_mutex_initialize(iree_slim_mutex_t* out_mutex) { void iree_slim_mutex_deinitialize(iree_slim_mutex_t* mutex) { // Assert unlocked (callers must ensure the mutex is no longer in use). - SYNC_ASSERT( - iree_atomic_load_int32(&mutex->value, iree_memory_order_acquire) == 0); + SYNC_ASSERT(iree_atomic_load(&mutex->value, iree_memory_order_acquire) == 0); } // Helper to perform a compare_exchange operation on mutex->value, internally @@ -467,9 +466,9 @@ static bool iree_slim_mutex_try_lock_compare_exchange( // more about efficiency in the uncontended case than we care about avoiding // spurious failure. Also, some callers are calling this in a loop, where they // would want the weak form anyway. - return iree_atomic_compare_exchange_weak_int32( - &mutex->value, expected, desired, iree_memory_order_acquire, - iree_memory_order_relaxed); + return iree_atomic_compare_exchange_weak(&mutex->value, expected, desired, + iree_memory_order_acquire, + iree_memory_order_relaxed); } void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) @@ -490,8 +489,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) // This uses relaxed order because this is an internal intermediate step and // we only need atomicity here. value = - iree_atomic_fetch_add_int32(&mutex->value, 1, iree_memory_order_relaxed) + - 1; + iree_atomic_fetch_add(&mutex->value, 1, iree_memory_order_relaxed) + 1; while (true) { // While the lock is available: try to acquire it for this thread. @@ -513,8 +511,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) int spin_count = 100; for (int i = 0; i < spin_count && iree_slim_mutex_is_locked(value); ++i) { iree_processor_yield(); - value = - iree_atomic_load_int32(&mutex->value, iree_memory_order_relaxed); + value = iree_atomic_load(&mutex->value, iree_memory_order_relaxed); } } @@ -523,7 +520,7 @@ void iree_slim_mutex_lock(iree_slim_mutex_t* mutex) // NOTE: we don't care about wait failure here as we are going to loop // and check again anyway. 
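// Hedged sketch of the weak compare-exchange retry idiom the slim mutex code
// here relies on (hypothetical my_spinlock_* names, not IREE's mutex): the
// weak form may fail spuriously, which is fine inside a loop, and on failure
// the macro stores the observed value back into |expected|.
typedef iree_atomic_int32_t my_spinlock_t;  // 0 = unlocked, 1 = locked

static inline void my_spinlock_lock(my_spinlock_t* lock) {
  int32_t expected = 0;
  while (!iree_atomic_compare_exchange_weak(lock, &expected, 1,
                                            iree_memory_order_acquire,
                                            iree_memory_order_relaxed)) {
    expected = 0;  // reset; the failed exchange overwrote it with *lock
    iree_processor_yield();  // same backoff helper used by the code above
  }
}

static inline void my_spinlock_unlock(my_spinlock_t* lock) {
  iree_atomic_store(lock, 0, iree_memory_order_release);
}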
iree_futex_wait(&mutex->value, value, IREE_TIME_INFINITE_FUTURE); - value = iree_atomic_load_int32(&mutex->value, iree_memory_order_relaxed); + value = iree_atomic_load(&mutex->value, iree_memory_order_relaxed); } } } @@ -541,8 +538,8 @@ void iree_slim_mutex_unlock(iree_slim_mutex_t* mutex) IREE_DISABLE_THREAD_SAFETY_ANALYSIS { // Refer to the iree_slim_mutex_t struct comment, "Notes on atomics". // Transition 1->0 (unlocking with no waiters) or 2->1 (with waiters). - if (iree_atomic_fetch_sub_int32(&mutex->value, iree_slim_mutex_value(1), - iree_memory_order_release) != + if (iree_atomic_fetch_sub(&mutex->value, iree_slim_mutex_value(1), + iree_memory_order_release) != iree_slim_mutex_value(1)) { // One (or more) waiters; wake a single one to avoid a thundering herd of // multiple threads all waking and trying to grab the lock (as only one will @@ -749,14 +746,14 @@ void iree_notification_initialize(iree_notification_t* out_notification) { void iree_notification_deinitialize(iree_notification_t* notification) { // Assert no more waiters (callers must tear down waiters first). SYNC_ASSERT( - (iree_atomic_load_int64(¬ification->value, iree_memory_order_acquire) & + (iree_atomic_load(¬ification->value, iree_memory_order_acquire) & IREE_NOTIFICATION_WAITER_MASK) == 0); } void iree_notification_post(iree_notification_t* notification, int32_t count) { - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_EPOCH_INC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_EPOCH_INC, + iree_memory_order_acq_rel); // Ensure we have at least one waiter; wake up to |count| of them. if (IREE_UNLIKELY(previous_value & IREE_NOTIFICATION_WAITER_MASK)) { iree_futex_wake(iree_notification_epoch_address(notification), count); @@ -765,9 +762,9 @@ void iree_notification_post(iree_notification_t* notification, int32_t count) { iree_wait_token_t iree_notification_prepare_wait( iree_notification_t* notification) { - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_INC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_INC, + iree_memory_order_acq_rel); return (iree_wait_token_t)(previous_value >> IREE_NOTIFICATION_EPOCH_SHIFT); } @@ -779,8 +776,7 @@ typedef enum iree_notification_result_e { static iree_notification_result_t iree_notification_test_wait_condition( iree_notification_t* notification, iree_wait_token_t wait_token) { - return (iree_atomic_load_int64(¬ification->value, - iree_memory_order_acquire) >> + return (iree_atomic_load(¬ification->value, iree_memory_order_acquire) >> IREE_NOTIFICATION_EPOCH_SHIFT) != wait_token ? IREE_NOTIFICATION_RESULT_RESOLVED : IREE_NOTIFICATION_RESULT_UNRESOLVED; @@ -830,9 +826,9 @@ bool iree_notification_commit_wait(iree_notification_t* notification, // TODO(benvanik): benchmark under real workloads. // iree_memory_order_relaxed would suffice for correctness but the faster // the waiter count gets to 0 the less likely we'll wake on the futex. 
- uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_DEC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_DEC, + iree_memory_order_acq_rel); SYNC_ASSERT((previous_value & IREE_NOTIFICATION_WAITER_MASK) != 0); return result == IREE_NOTIFICATION_RESULT_RESOLVED; @@ -842,9 +838,9 @@ void iree_notification_cancel_wait(iree_notification_t* notification) { // TODO(benvanik): benchmark under real workloads. // iree_memory_order_relaxed would suffice for correctness but the faster // the waiter count gets to 0 the less likely we'll wake on the futex. - uint64_t previous_value = iree_atomic_fetch_add_int64( - ¬ification->value, IREE_NOTIFICATION_WAITER_DEC, - iree_memory_order_acq_rel); + uint64_t previous_value = + iree_atomic_fetch_add(¬ification->value, IREE_NOTIFICATION_WAITER_DEC, + iree_memory_order_acq_rel); SYNC_ASSERT((previous_value & IREE_NOTIFICATION_WAITER_MASK) != 0); } diff --git a/runtime/src/iree/base/internal/threading_darwin.c b/runtime/src/iree/base/internal/threading_darwin.c index 52932f848816..dc4b5f8ef81e 100644 --- a/runtime/src/iree/base/internal/threading_darwin.c +++ b/runtime/src/iree/base/internal/threading_darwin.c @@ -104,9 +104,8 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, thread->entry_arg = entry_arg; iree_strncpy_s(thread->name, IREE_ARRAYSIZE(thread->name), params.name.data, iree_min(params.name.size, IREE_ARRAYSIZE(thread->name) - 1)); - iree_atomic_store_int32(&thread->is_suspended, - params.create_suspended ? 1 : 0, - iree_memory_order_relaxed); + iree_atomic_store(&thread->is_suspended, params.create_suspended ? 1 : 0, + iree_memory_order_relaxed); pthread_attr_t thread_attr; pthread_attr_init(&thread_attr); @@ -239,7 +238,7 @@ void iree_thread_resume(iree_thread_t* thread) { // always balance suspend/resume or else we'll mess with any // debuggers/profilers that may be suspending threads for their own uses. int32_t expected = 1; - if (iree_atomic_compare_exchange_strong_int32( + if (iree_atomic_compare_exchange_strong( &thread->is_suspended, &expected, 0, iree_memory_order_acq_rel, iree_memory_order_relaxed /* expected is unused */)) { thread_resume(thread->mach_port); diff --git a/runtime/src/iree/base/internal/threading_pthreads.c b/runtime/src/iree/base/internal/threading_pthreads.c index 1686fd16a060..3f15987be768 100644 --- a/runtime/src/iree/base/internal/threading_pthreads.c +++ b/runtime/src/iree/base/internal/threading_pthreads.c @@ -51,8 +51,8 @@ static void iree_thread_set_priority_class( static bool iree_thread_resumed_predicate(void* arg) { iree_thread_t* thread = (iree_thread_t*)arg; - return iree_atomic_load_int32(&thread->suspend_count, - iree_memory_order_acquire) == 0; + return iree_atomic_load(&thread->suspend_count, iree_memory_order_acquire) == + 0; } #if defined(IREE_PLATFORM_EMSCRIPTEN) @@ -99,8 +99,8 @@ static void* iree_thread_start_routine(void* param) { IREE_TRACE_SET_THREAD_NAME(thread->name); // Wait until we resume if we were created suspended. 
- while (iree_atomic_load_int32(&thread->suspend_count, - iree_memory_order_acquire) > 0) { + while (iree_atomic_load(&thread->suspend_count, iree_memory_order_acquire) > + 0) { iree_notification_await(&thread->suspend_barrier, iree_thread_resumed_predicate, thread, iree_infinite_timeout()); @@ -335,8 +335,8 @@ void iree_thread_request_affinity(iree_thread_t* thread, void iree_thread_resume(iree_thread_t* thread) { IREE_TRACE_ZONE_BEGIN(z0); - if (iree_atomic_exchange_int32(&thread->suspend_count, 0, - iree_memory_order_acq_rel) == 1) { + if (iree_atomic_exchange(&thread->suspend_count, 0, + iree_memory_order_acq_rel) == 1) { iree_notification_post(&thread->suspend_barrier, IREE_ALL_WAITERS); } diff --git a/runtime/src/iree/base/internal/threading_test.cc b/runtime/src/iree/base/internal/threading_test.cc index 8ee5a96b7fa6..1fd973083e22 100644 --- a/runtime/src/iree/base/internal/threading_test.cc +++ b/runtime/src/iree/base/internal/threading_test.cc @@ -34,12 +34,11 @@ TEST(ThreadTest, Lifetime) { iree_atomic_int32_t value; iree_notification_t barrier; } entry_data; - iree_atomic_store_int32(&entry_data.value, 123, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 123, iree_memory_order_relaxed); iree_notification_initialize(&entry_data.barrier); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_acq_rel); iree_notification_post(&entry_data->barrier, IREE_ALL_WAITERS); return 0; }; @@ -55,8 +54,8 @@ TEST(ThreadTest, Lifetime) { &entry_data.barrier, +[](void* entry_arg) -> bool { auto* entry_data = reinterpret_cast(entry_arg); - return iree_atomic_load_int32(&entry_data->value, - iree_memory_order_relaxed) == (123 + 1); + return iree_atomic_load(&entry_data->value, + iree_memory_order_relaxed) == (123 + 1); }, &entry_data, iree_infinite_timeout()); @@ -76,12 +75,11 @@ TEST(ThreadTest, CreateSuspended) { iree_atomic_int32_t value; iree_notification_t barrier; } entry_data; - iree_atomic_store_int32(&entry_data.value, 123, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 123, iree_memory_order_relaxed); iree_notification_initialize(&entry_data.barrier); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_acq_rel); iree_notification_post(&entry_data->barrier, IREE_ALL_WAITERS); return 0; }; @@ -95,11 +93,11 @@ TEST(ThreadTest, CreateSuspended) { // the value. I can't think of a good way to test this, though, so we'll just // wait a moment here and assume that if the thread was able to run it would // have during this wait. - ASSERT_EQ(123, iree_atomic_load_int32(&entry_data.value, - iree_memory_order_seq_cst)); + ASSERT_EQ(123, + iree_atomic_load(&entry_data.value, iree_memory_order_seq_cst)); std::this_thread::sleep_for(std::chrono::milliseconds(150)); - ASSERT_EQ(123, iree_atomic_load_int32(&entry_data.value, - iree_memory_order_seq_cst)); + ASSERT_EQ(123, + iree_atomic_load(&entry_data.value, iree_memory_order_seq_cst)); // Resume the thread and wait for it to finish its work. 
iree_thread_resume(thread); @@ -107,8 +105,8 @@ TEST(ThreadTest, CreateSuspended) { &entry_data.barrier, +[](void* entry_arg) -> bool { auto* entry_data = reinterpret_cast(entry_arg); - return iree_atomic_load_int32(&entry_data->value, - iree_memory_order_relaxed) == (123 + 1); + return iree_atomic_load(&entry_data->value, + iree_memory_order_relaxed) == (123 + 1); }, &entry_data, iree_infinite_timeout()); iree_thread_release(thread); @@ -126,11 +124,10 @@ TEST(ThreadTest, PriorityOverride) { struct entry_data_t { iree_atomic_int32_t value; } entry_data; - iree_atomic_store_int32(&entry_data.value, 0, iree_memory_order_relaxed); + iree_atomic_store(&entry_data.value, 0, iree_memory_order_relaxed); iree_thread_entry_t entry_fn = +[](void* entry_arg) -> int { auto* entry_data = reinterpret_cast(entry_arg); - iree_atomic_fetch_add_int32(&entry_data->value, 1, - iree_memory_order_release); + iree_atomic_fetch_add(&entry_data->value, 1, iree_memory_order_release); return 0; }; @@ -150,8 +147,7 @@ TEST(ThreadTest, PriorityOverride) { thread, IREE_THREAD_PRIORITY_CLASS_LOWEST); // Wait for the thread to finish. - while (iree_atomic_load_int32(&entry_data.value, iree_memory_order_acquire) != - 1) { + while (iree_atomic_load(&entry_data.value, iree_memory_order_acquire) != 1) { iree_thread_yield(); } diff --git a/runtime/src/iree/base/internal/threading_win32.c b/runtime/src/iree/base/internal/threading_win32.c index 6166ce288175..64ddca614da2 100644 --- a/runtime/src/iree/base/internal/threading_win32.c +++ b/runtime/src/iree/base/internal/threading_win32.c @@ -143,9 +143,8 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, thread->entry_arg = entry_arg; strncpy_s(thread->name, IREE_ARRAYSIZE(thread->name), params.name.data, min(params.name.size, IREE_ARRAYSIZE(thread->name) - 1)); - iree_atomic_store_int32(&thread->is_suspended, - params.create_suspended ? 1 : 0, - iree_memory_order_relaxed); + iree_atomic_store(&thread->is_suspended, params.create_suspended ? 1 : 0, + iree_memory_order_relaxed); iree_thread_override_list_initialize(iree_thread_set_priority_class, params.priority_class, thread->allocator, &thread->qos_override_list); @@ -304,7 +303,7 @@ void iree_thread_resume(iree_thread_t* thread) { // always balance suspend/resume or else we'll mess with any // debuggers/profilers that may be suspending threads for their own uses. 
int32_t expected = 1; - if (iree_atomic_compare_exchange_strong_int32( + if (iree_atomic_compare_exchange_strong( &thread->is_suspended, &expected, 0, iree_memory_order_acq_rel, iree_memory_order_relaxed /* expected is unused */)) { ResumeThread(thread->handle); diff --git a/runtime/src/iree/base/internal/wait_handle_inproc.c b/runtime/src/iree/base/internal/wait_handle_inproc.c index e3192595e177..7f92797b1bc8 100644 --- a/runtime/src/iree/base/internal/wait_handle_inproc.c +++ b/runtime/src/iree/base/internal/wait_handle_inproc.c @@ -240,7 +240,7 @@ static bool iree_wait_set_check(const iree_wait_set_check_params_t* params) { iree_wait_handle_t* wait_handle = ¶ms->set->handles[i]; iree_futex_handle_t* futex = (iree_futex_handle_t*)wait_handle->value.local_futex; - if (iree_atomic_load_int64(&futex->value, iree_memory_order_acquire) != 0) { + if (iree_atomic_load(&futex->value, iree_memory_order_acquire) != 0) { ++ready_count; if (params->wake_handle) { *params->wake_handle = *wait_handle; @@ -292,7 +292,7 @@ iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns, } static bool iree_futex_handle_check(iree_futex_handle_t* futex) { - return iree_atomic_load_int64(&futex->value, iree_memory_order_acquire) != 0; + return iree_atomic_load(&futex->value, iree_memory_order_acquire) != 0; } iree_status_t iree_wait_one(iree_wait_handle_t* handle, @@ -335,8 +335,8 @@ iree_status_t iree_event_initialize(bool initial_state, if (iree_status_is_ok(status)) { out_event->type = IREE_WAIT_PRIMITIVE_TYPE_LOCAL_FUTEX; out_event->value.local_futex = (void*)futex; - iree_atomic_store_int64(&futex->value, initial_state ? 1 : 0, - iree_memory_order_release); + iree_atomic_store(&futex->value, initial_state ? 1 : 0, + iree_memory_order_release); iree_notification_initialize(&futex->notification); } @@ -358,8 +358,7 @@ void iree_event_set(iree_event_t* event) { // Try to transition from unset -> set. // No-op if already set and otherwise we successfully signaled the event and // need to notify all waiters. - if (iree_atomic_exchange_int64(&futex->value, 1, iree_memory_order_release) == - 0) { + if (iree_atomic_exchange(&futex->value, 1, iree_memory_order_release) == 0) { // Notify those waiting on just this event. iree_notification_post(&futex->notification, IREE_ALL_WAITERS); // Notify any multi-waits that may have this event as part of their set. @@ -371,7 +370,7 @@ void iree_event_reset(iree_event_t* event) { if (!event) return; iree_futex_handle_t* futex = (iree_futex_handle_t*)event->value.local_futex; if (!futex) return; - iree_atomic_store_int64(&futex->value, 0, iree_memory_order_release); + iree_atomic_store(&futex->value, 0, iree_memory_order_release); } #endif // IREE_WAIT_API == IREE_WAIT_API_INPROC diff --git a/runtime/src/iree/hal/drivers/cuda/memory_pools.c b/runtime/src/iree/hal/drivers/cuda/memory_pools.c index 236ffaac840b..1e34422478f5 100644 --- a/runtime/src/iree/hal/drivers/cuda/memory_pools.c +++ b/runtime/src/iree/hal/drivers/cuda/memory_pools.c @@ -121,8 +121,8 @@ static void iree_hal_cuda_memory_pool_track_alloc( iree_atomic_int64_t* bytes_allocated = is_device_local ? 
&pools->statistics.device_bytes_allocated : &pools->statistics.host_bytes_allocated; - iree_atomic_fetch_add_int64(bytes_allocated, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_allocated, allocation_size, + iree_memory_order_relaxed); }); } @@ -141,8 +141,8 @@ static void iree_hal_cuda_memory_pool_track_free( : &pools->statistics.host_bytes_freed; iree_device_size_t allocation_size = iree_hal_buffer_allocation_size(buffer); - iree_atomic_fetch_add_int64(bytes_freed, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_freed, allocation_size, + iree_memory_order_relaxed); }); } @@ -150,13 +150,13 @@ void iree_hal_cuda_memory_pools_merge_statistics( iree_hal_cuda_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { IREE_STATISTICS({ - statistics->device_bytes_allocated = iree_atomic_load_int64( + statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); - statistics->host_bytes_allocated = iree_atomic_load_int64( + statistics->host_bytes_allocated = iree_atomic_load( &pools->statistics.host_bytes_allocated, iree_memory_order_relaxed); - statistics->device_bytes_freed = iree_atomic_load_int64( + statistics->device_bytes_freed = iree_atomic_load( &pools->statistics.device_bytes_freed, iree_memory_order_relaxed); - statistics->host_bytes_freed = iree_atomic_load_int64( + statistics->host_bytes_freed = iree_atomic_load( &pools->statistics.host_bytes_freed, iree_memory_order_relaxed); if (pools->device_local) { cuuint64_t pool_peak = 0; diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index e599cf62daa0..89e27fafdfd1 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -121,8 +121,8 @@ static void iree_hal_hip_memory_pool_track_alloc( iree_atomic_int64_t* bytes_allocated = is_device_local ? 
&pools->statistics.device_bytes_allocated : &pools->statistics.host_bytes_allocated; - iree_atomic_fetch_add_int64(bytes_allocated, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_allocated, allocation_size, + iree_memory_order_relaxed); }); } @@ -141,8 +141,8 @@ static void iree_hal_hip_memory_pool_track_free( : &pools->statistics.host_bytes_freed; iree_device_size_t allocation_size = iree_hal_buffer_allocation_size(buffer); - iree_atomic_fetch_add_int64(bytes_freed, allocation_size, - iree_memory_order_relaxed); + iree_atomic_fetch_add(bytes_freed, allocation_size, + iree_memory_order_relaxed); }); } @@ -150,13 +150,13 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { IREE_STATISTICS({ - statistics->device_bytes_allocated = iree_atomic_load_int64( + statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); - statistics->host_bytes_allocated = iree_atomic_load_int64( + statistics->host_bytes_allocated = iree_atomic_load( &pools->statistics.host_bytes_allocated, iree_memory_order_relaxed); - statistics->device_bytes_freed = iree_atomic_load_int64( + statistics->device_bytes_freed = iree_atomic_load( &pools->statistics.device_bytes_freed, iree_memory_order_relaxed); - statistics->host_bytes_freed = iree_atomic_load_int64( + statistics->host_bytes_freed = iree_atomic_load( &pools->statistics.host_bytes_freed, iree_memory_order_relaxed); if (pools->device_local) { diff --git a/runtime/src/iree/hal/drivers/metal/shared_event.m b/runtime/src/iree/hal/drivers/metal/shared_event.m index f741f2ea3a63..716306c215bb 100644 --- a/runtime/src/iree/hal/drivers/metal/shared_event.m +++ b/runtime/src/iree/hal/drivers/metal/shared_event.m @@ -231,7 +231,7 @@ iree_status_t iree_hal_metal_shared_event_multi_wait( // Create an atomic to count how many semaphores have signaled. Mark it as `__block` so different // threads are sharing the same data via reference. __block iree_atomic_int32_t wait_count; - iree_atomic_store_int32(&wait_count, 0, iree_memory_order_release); + iree_atomic_store(&wait_count, 0, iree_memory_order_release); // The total count we are expecting to see. iree_host_size_t total_count = (wait_mode == IREE_HAL_WAIT_MODE_ALL) ? semaphore_list->count : 1; // Theoretically we don't really need to mark the semaphore handle as __block given that the @@ -253,7 +253,7 @@ iree_status_t iree_hal_metal_shared_event_multi_wait( // Fail as a whole if any participating semaphore failed. if (v >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) did_fail = true; - int32_t old_value = iree_atomic_fetch_add_int32( + int32_t old_value = iree_atomic_fetch_add( &wait_count, 1, iree_memory_order_release); // The last signaled semaphore send out the notification. // Atomic fetch add returns the old value, so need to +1. 
diff --git a/runtime/src/iree/hal/drivers/metal/staging_buffer.m b/runtime/src/iree/hal/drivers/metal/staging_buffer.m index ca0128f78890..e83e622e868b 100644 --- a/runtime/src/iree/hal/drivers/metal/staging_buffer.m +++ b/runtime/src/iree/hal/drivers/metal/staging_buffer.m @@ -37,8 +37,7 @@ iree_status_t iree_hal_metal_staging_buffer_initialize( out_staging_buffer->host_buffer = metal_buffer.contents; iree_slim_mutex_initialize(&out_staging_buffer->offset_mutex); out_staging_buffer->offset = 0; - iree_atomic_store_int32(&out_staging_buffer->pending_command_buffers, 0, - iree_memory_order_relaxed); + iree_atomic_store(&out_staging_buffer->pending_command_buffers, 0, iree_memory_order_relaxed); IREE_TRACE_ZONE_END(z0); return iree_ok_status(); @@ -97,14 +96,13 @@ void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* stagin void iree_hal_metal_staging_buffer_increase_command_buffer_refcount( iree_hal_metal_staging_buffer_t* staging_buffer) { - iree_atomic_fetch_add_int32(&staging_buffer->pending_command_buffers, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&staging_buffer->pending_command_buffers, 1, iree_memory_order_relaxed); } void iree_hal_metal_staging_buffer_decrease_command_buffer_refcount( iree_hal_metal_staging_buffer_t* staging_buffer) { - if (iree_atomic_fetch_sub_int32(&staging_buffer->pending_command_buffers, 1, - iree_memory_order_acq_rel) == 1) { + if (iree_atomic_fetch_sub(&staging_buffer->pending_command_buffers, 1, + iree_memory_order_acq_rel) == 1) { iree_hal_metal_staging_buffer_reset(staging_buffer); } } diff --git a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc index f75b2c0bbdb1..631f138a1c26 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_semaphore.cc @@ -68,8 +68,7 @@ iree_status_t iree_hal_vulkan_native_semaphore_create( &semaphore->base); semaphore->logical_device = logical_device; semaphore->handle = handle; - iree_atomic_store_intptr(&semaphore->failure_status, 0, - iree_memory_order_release); + iree_atomic_store(&semaphore->failure_status, 0, iree_memory_order_release); *out_semaphore = &semaphore->base; } else { logical_device->syms()->vkDestroySemaphore(*logical_device, handle, @@ -87,7 +86,7 @@ static void iree_hal_vulkan_native_semaphore_destroy( iree_allocator_t host_allocator = semaphore->logical_device->host_allocator(); IREE_TRACE_ZONE_BEGIN(z0); - iree_status_ignore((iree_status_t)iree_atomic_load_intptr( + iree_status_ignore((iree_status_t)iree_atomic_load( &semaphore->failure_status, iree_memory_order_acquire)); semaphore->logical_device->syms()->vkDestroySemaphore( @@ -127,7 +126,7 @@ static iree_status_t iree_hal_vulkan_native_semaphore_query( // If the semaphore failed then clone the status so we can report it. if (value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { - iree_status_t failure_status = (iree_status_t)iree_atomic_load_intptr( + iree_status_t failure_status = (iree_status_t)iree_atomic_load( &semaphore->failure_status, iree_memory_order_acquire); if (iree_status_is_ok(failure_status)) { return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, @@ -178,7 +177,7 @@ static void iree_hal_vulkan_native_semaphore_fail( // Try to set our local status - we only preserve the first failure so only // do this if we are going from a valid semaphore to a failed one. 
iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( &semaphore->failure_status, (intptr_t*)&old_status, (intptr_t)status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { diff --git a/runtime/src/iree/hal/local/executable_plugin_manager.c b/runtime/src/iree/hal/local/executable_plugin_manager.c index 6d41c76df5d0..2739aa9f26c6 100644 --- a/runtime/src/iree/hal/local/executable_plugin_manager.c +++ b/runtime/src/iree/hal/local/executable_plugin_manager.c @@ -432,8 +432,8 @@ static iree_status_t iree_hal_executable_plugin_manager_register( // Get the next provider slot. Note that we don't yet increment it as we need // to put the provider in there first. - int32_t slot = iree_atomic_load_int32(&manager->provider_count, - iree_memory_order_acquire); + int32_t slot = + iree_atomic_load(&manager->provider_count, iree_memory_order_acquire); if (slot >= manager->capacity) { iree_slim_mutex_unlock(&manager->mutex); return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, @@ -449,8 +449,7 @@ static iree_status_t iree_hal_executable_plugin_manager_register( } // Mark the slot as valid now that the provider is in it. - iree_atomic_fetch_add_int32(&manager->provider_count, 1, - iree_memory_order_release); + iree_atomic_fetch_add(&manager->provider_count, 1, iree_memory_order_release); iree_slim_mutex_unlock(&manager->mutex); return iree_ok_status(); @@ -506,8 +505,8 @@ static iree_status_t iree_hal_executable_plugin_manager_resolve( // but that's ok: multithreaded registration/resolution is non-deterministic // by nature. Not holding the lock here means we allow multiple threads to // resolve imports at the same time. - int32_t provider_count = iree_atomic_load_int32(&manager->provider_count, - iree_memory_order_acquire); + int32_t provider_count = + iree_atomic_load(&manager->provider_count, iree_memory_order_acquire); // Scan in reverse registration order so that more recently registered // providers get queried first. 
try_resolve will populate any function diff --git a/runtime/src/iree/hal/utils/deferred_work_queue.c b/runtime/src/iree/hal/utils/deferred_work_queue.c index b4b2285c972f..e41fe3523778 100644 --- a/runtime/src/iree/hal/utils/deferred_work_queue.c +++ b/runtime/src/iree/hal/utils/deferred_work_queue.c @@ -393,9 +393,9 @@ static void iree_hal_deferred_work_queue_working_area_initialize( iree_notification_initialize(&working_area->state_notification); iree_hal_deferred_work_queue_ready_action_list_deinitialize( &working_area->ready_worklist, host_allocator); - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); } static void iree_hal_deferred_work_queue_working_area_deinitialize( @@ -413,9 +413,9 @@ static void iree_hal_deferred_work_queue_completion_area_initialize( iree_notification_initialize(&completion_area->state_notification); iree_hal_deferred_work_queue_completion_list_initialize( &completion_area->completion_list); - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); } static void iree_hal_deferred_work_queue_completion_area_deinitialize( @@ -557,17 +557,17 @@ static iree_hal_deferred_work_queue_t* iree_hal_deferred_work_queue_cast( static void iree_hal_deferred_work_queue_notify_worker_thread( iree_hal_deferred_work_queue_working_area_t* working_area) { - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, + iree_memory_order_release); iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS); } static void iree_hal_deferred_work_queue_notify_completion_thread( iree_hal_deferred_work_queue_completion_area_t* completion_area) { - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_WORKLOAD_PENDING, + iree_memory_order_release); iree_notification_post(&completion_area->state_notification, IREE_ALL_WAITERS); } @@ -1236,14 +1236,14 @@ iree_status_t iree_hal_deferred_work_queue_issue( static bool iree_hal_deferred_work_queue_worker_has_incoming_request( iree_hal_deferred_work_queue_working_area_t* working_area) { - iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( - &working_area->worker_state, iree_memory_order_acquire); + iree_hal_deferred_work_queue_worker_state_t value = + iree_atomic_load(&working_area->worker_state, iree_memory_order_acquire); return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } static bool iree_hal_deferred_work_queue_completion_has_incoming_request( iree_hal_deferred_work_queue_completion_area_t* completion_area) { - iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( + iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load( &completion_area->worker_state, iree_memory_order_acquire); return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } @@ -1369,9 +1369,9 @@ static int iree_hal_deferred_work_queue_completion_execute( // sure that we don't accidentally ignore new workload pushed after done // 
ready list processing but before overwriting the state from this worker // thread. - iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&completion_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); iree_hal_deferred_work_queue_worker_process_completion(actions); iree_slim_mutex_lock(&actions->action_mutex); @@ -1424,9 +1424,9 @@ static int iree_hal_deferred_work_queue_worker_execute( // sure that we don't accidentally ignore new workload pushed after done // ready list processing but before overwriting the state from this worker // thread. - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); + iree_atomic_store(&working_area->worker_state, + IREE_HAL_WORKER_STATE_IDLE_WAITING, + iree_memory_order_release); iree_hal_deferred_work_queue_worker_process_ready_list(actions); diff --git a/runtime/src/iree/hal/utils/file_transfer.c b/runtime/src/iree/hal/utils/file_transfer.c index cee1df6ebe2c..2bc8decf2f9a 100644 --- a/runtime/src/iree/hal/utils/file_transfer.c +++ b/runtime/src/iree/hal/utils/file_transfer.c @@ -242,8 +242,8 @@ static iree_status_t iree_hal_transfer_operation_create( // steps are part of this transfer. IREE_TRACE({ static iree_atomic_int32_t next_trace_id = IREE_ATOMIC_VAR_INIT(0); - operation->trace_id = iree_atomic_fetch_add_int32( - &next_trace_id, 1, iree_memory_order_seq_cst); + operation->trace_id = + iree_atomic_fetch_add(&next_trace_id, 1, iree_memory_order_seq_cst); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, operation->trace_id); }); diff --git a/runtime/src/iree/task/affinity_set.h b/runtime/src/iree/task/affinity_set.h index 3dbf756d7519..dfe6a7a5293e 100644 --- a/runtime/src/iree/task/affinity_set.h +++ b/runtime/src/iree/task/affinity_set.h @@ -61,25 +61,25 @@ typedef iree_atomic_int64_t iree_atomic_task_affinity_set_t; static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_load( iree_atomic_task_affinity_set_t* set, iree_memory_order_t order) { - return iree_atomic_load_int64(set, order); + return iree_atomic_load(set, order); } static inline void iree_atomic_task_affinity_set_store( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - iree_atomic_store_int64(set, value, order); + iree_atomic_store(set, value, order); } static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_fetch_and( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - return iree_atomic_fetch_and_int64(set, value, order); + return iree_atomic_fetch_and(set, value, order); } static inline iree_task_affinity_set_t iree_atomic_task_affinity_set_fetch_or( iree_atomic_task_affinity_set_t* set, iree_task_affinity_set_t value, iree_memory_order_t order) { - return iree_atomic_fetch_or_int64(set, value, order); + return iree_atomic_fetch_or(set, value, order); } #ifdef __cplusplus diff --git a/runtime/src/iree/task/executor.c b/runtime/src/iree/task/executor.c index ff3280aaf4d2..6fc98e279e4c 100644 --- a/runtime/src/iree/task/executor.c +++ b/runtime/src/iree/task/executor.c @@ -103,10 +103,9 @@ iree_status_t iree_task_executor_create(iree_task_executor_options_t options, IREE_TRACE({ static iree_atomic_int32_t executor_id = IREE_ATOMIC_VAR_INIT(0); char trace_name[32]; - int trace_name_length = - snprintf(trace_name, sizeof(trace_name), "iree-executor-%d", - 
iree_atomic_fetch_add_int32(&executor_id, 1, - iree_memory_order_seq_cst)); + int trace_name_length = snprintf( + trace_name, sizeof(trace_name), "iree-executor-%d", + iree_atomic_fetch_add(&executor_id, 1, iree_memory_order_seq_cst)); IREE_LEAK_CHECK_DISABLE_PUSH(); executor->trace_name = malloc(trace_name_length + 1); memcpy((void*)executor->trace_name, trace_name, trace_name_length + 1); @@ -540,8 +539,7 @@ static iree_task_t* iree_task_executor_try_steal_task_from_affinity_set( worker_index += offset + 1; mask = iree_shr(mask, offset + 1); iree_task_worker_t* victim_worker = &executor->workers[victim_index]; - if (iree_atomic_load_int32(&victim_worker->state, - iree_memory_order_acquire) != + if (iree_atomic_load(&victim_worker->state, iree_memory_order_acquire) != IREE_TASK_WORKER_STATE_RUNNING) { return NULL; } diff --git a/runtime/src/iree/task/executor_demo.cc b/runtime/src/iree/task/executor_demo.cc index 63dba4ce0192..972d16b114a7 100644 --- a/runtime/src/iree/task/executor_demo.cc +++ b/runtime/src/iree/task/executor_demo.cc @@ -89,8 +89,8 @@ extern "C" int main(int argc, char* argv[]) { IREE_TRACE_SCOPE_NAMED("tile0"); IREE_ASSERT_EQ(0, user_context); simulate_work(tile_context); - iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&tile_context->statistics->reserved, 1, + iree_memory_order_relaxed); return iree_ok_status(); }, 0), @@ -107,8 +107,8 @@ extern "C" int main(int argc, char* argv[]) { IREE_TRACE_SCOPE_NAMED("tile1"); IREE_ASSERT_EQ(0, user_context); simulate_work(tile_context); - iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, - iree_memory_order_relaxed); + iree_atomic_fetch_add(&tile_context->statistics->reserved, 1, + iree_memory_order_relaxed); return iree_ok_status(); }, 0), diff --git a/runtime/src/iree/task/poller.c b/runtime/src/iree/task/poller.c index e314379dc3be..e04aa3bcf162 100644 --- a/runtime/src/iree/task/poller.c +++ b/runtime/src/iree/task/poller.c @@ -32,8 +32,8 @@ iree_status_t iree_task_poller_initialize( // thread as it performs the initial resume of the wait thread. We'll need to // check in enqueue to see if the wait thread needs to be resumed. // initial_state = IREE_TASK_POLLER_STATE_SUSPENDED; - iree_atomic_store_int32(&out_poller->state, initial_state, - iree_memory_order_release); + iree_atomic_store(&out_poller->state, initial_state, + iree_memory_order_release); // Acquire an event we can use to wake the wait thread from other threads. iree_status_t status = iree_event_pool_acquire( @@ -83,7 +83,7 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { // If the thread is already in the exiting/zombie state we don't need to do // anything. iree_task_poller_state_t prev_state = - (iree_task_poller_state_t)iree_atomic_exchange_int32( + (iree_task_poller_state_t)iree_atomic_exchange( &poller->state, IREE_TASK_POLLER_STATE_EXITING, iree_memory_order_acq_rel); switch (prev_state) { @@ -93,8 +93,8 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { break; case IREE_TASK_POLLER_STATE_ZOMBIE: // Poller already exited; reset state to ZOMBIE. - iree_atomic_store_int32(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, + iree_memory_order_release); break; default: // Poller now set to EXITING and should exit soon. 
@@ -111,7 +111,7 @@ void iree_task_poller_request_exit(iree_task_poller_t* poller) { // Returns true if the wait thread is in the zombie state (exited and awaiting // teardown). static bool iree_task_poller_is_zombie(iree_task_poller_t* poller) { - return iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + return iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_ZOMBIE; } @@ -240,8 +240,8 @@ static iree_task_poller_prepare_result_t iree_task_poller_prepare_task( // scan of tasks. wait_status_code = IREE_STATUS_OK; } else if (task->cancellation_flag != NULL && - iree_atomic_load_int32(task->cancellation_flag, - iree_memory_order_acquire) != 0) { + iree_atomic_load(task->cancellation_flag, + iree_memory_order_acquire) != 0) { // Task was cancelled by the user (or a wait-any). These retire without // failure and it's up to the user to handle what happens to them. wait_status_code = IREE_STATUS_CANCELLED; @@ -313,8 +313,8 @@ static iree_task_poller_prepare_result_t iree_task_poller_prepare_task( // If this was part of a wait-any operation then set the cancellation flag // such that other waits are cancelled. if (iree_any_bit_set(task->header.flags, IREE_TASK_FLAG_WAIT_ANY)) { - if (iree_atomic_fetch_add_int32(task->cancellation_flag, 1, - iree_memory_order_release) == 0) { + if (iree_atomic_fetch_add(task->cancellation_flag, 1, + iree_memory_order_release) == 0) { // Ensure we scan again to clean up any potentially cancelled tasks. // If this was task 4 in a wait-any list then tasks 0-3 need to be // retired. @@ -429,7 +429,7 @@ static void iree_task_poller_wake_task(iree_task_poller_t* poller, // wait handles were resolved. static void iree_task_poller_commit_wait(iree_task_poller_t* poller, iree_time_t deadline_ns) { - if (iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + if (iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_EXITING) { // Thread exit requested - don't block shutdown. return; @@ -486,7 +486,7 @@ static void iree_task_poller_commit_wait(iree_task_poller_t* poller, static void iree_task_poller_pump_until_exit(iree_task_poller_t* poller) { while (true) { // Check state to see if we've been asked to exit. - if (iree_atomic_load_int32(&poller->state, iree_memory_order_acquire) == + if (iree_atomic_load(&poller->state, iree_memory_order_acquire) == IREE_TASK_POLLER_STATE_EXITING) { // Thread exit requested - cancel pumping. break; @@ -536,8 +536,8 @@ static int iree_task_poller_main(iree_task_poller_t* poller) { // to exit while suspended/still starting up, so check that here before we // mess with any data structures. 
const bool should_run = - iree_atomic_exchange_int32(&poller->state, IREE_TASK_POLLER_STATE_RUNNING, - iree_memory_order_acq_rel) != + iree_atomic_exchange(&poller->state, IREE_TASK_POLLER_STATE_RUNNING, + iree_memory_order_acq_rel) != IREE_TASK_POLLER_STATE_EXITING; if (IREE_LIKELY(should_run)) { // << work happens here >> @@ -545,8 +545,8 @@ static int iree_task_poller_main(iree_task_poller_t* poller) { } IREE_TRACE_ZONE_END(thread_zone); - iree_atomic_store_int32(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&poller->state, IREE_TASK_POLLER_STATE_ZOMBIE, + iree_memory_order_release); iree_notification_post(&poller->state_notification, IREE_ALL_WAITERS); return 0; } diff --git a/runtime/src/iree/task/scope.c b/runtime/src/iree/task/scope.c index 3ccf6ae5dfea..a777d3dc6067 100644 --- a/runtime/src/iree/task/scope.c +++ b/runtime/src/iree/task/scope.c @@ -49,12 +49,12 @@ void iree_task_scope_deinitialize(iree_task_scope_t* scope) { memset(scope->name, 0xCD, sizeof(scope->name)); // In most cases the status will have been consumed by the scope owner. - iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + iree_status_t status = (iree_status_t)iree_atomic_exchange( &scope->permanent_status, (intptr_t)NULL, iree_memory_order_acquire); IREE_IGNORE_ERROR(status); - while (iree_atomic_load_int32(&scope->pending_idle_notification_posts, - iree_memory_order_acquire)) { + while (iree_atomic_load(&scope->pending_idle_notification_posts, + iree_memory_order_acquire)) { iree_thread_yield(); } iree_notification_deinitialize(&scope->idle_notification); @@ -74,14 +74,14 @@ iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( } bool iree_task_scope_has_failed(iree_task_scope_t* scope) { - return iree_atomic_load_intptr(&scope->permanent_status, - iree_memory_order_acquire) != 0; + return iree_atomic_load(&scope->permanent_status, + iree_memory_order_acquire) != 0; } iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope) { iree_status_t old_status = iree_ok_status(); iree_status_t new_status = iree_ok_status(); - while (!iree_atomic_compare_exchange_strong_intptr( + while (!iree_atomic_compare_exchange_strong( &scope->permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_acquire /* old_status is actually used */)) { @@ -114,7 +114,7 @@ static void iree_task_scope_try_set_status(iree_task_scope_t* scope, } iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( &scope->permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { @@ -140,16 +140,16 @@ void iree_task_scope_begin(iree_task_scope_t* scope) { // relaxed because this 'begin' call will be paired with a 'end' call that // will perform the release-store, and this value is only read by // 'deinitialize'. - iree_atomic_store_int32(&scope->pending_idle_notification_posts, 1, - iree_memory_order_relaxed); + iree_atomic_store(&scope->pending_idle_notification_posts, 1, + iree_memory_order_relaxed); } void iree_task_scope_end(iree_task_scope_t* scope) { if (iree_atomic_ref_count_dec(&scope->pending_submissions) == 1) { // All submissions have completed in this scope - notify any waiters. 
iree_notification_post(&scope->idle_notification, IREE_ALL_WAITERS); - iree_atomic_store_int32(&scope->pending_idle_notification_posts, 0, - iree_memory_order_release); + iree_atomic_store(&scope->pending_idle_notification_posts, 0, + iree_memory_order_release); } } diff --git a/runtime/src/iree/task/task.c b/runtime/src/iree/task/task.c index ae4fbf99d5b3..d0e40103e814 100644 --- a/runtime/src/iree/task/task.c +++ b/runtime/src/iree/task/task.c @@ -39,13 +39,13 @@ void iree_task_set_completion_task(iree_task_t* task, iree_task_t* completion_task) { IREE_ASSERT(!task->completion_task); task->completion_task = completion_task; - iree_atomic_fetch_add_int32(&completion_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } bool iree_task_is_ready(iree_task_t* task) { - if (iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire) > 0) { + if (iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire) > 0) { // At least one dependency is still pending. return false; } @@ -62,7 +62,7 @@ static void iree_task_try_set_status(iree_atomic_intptr_t* permanent_status, z0, iree_status_code_string(iree_status_code(new_status))); iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_intptr( + if (!iree_atomic_compare_exchange_strong( permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, iree_memory_order_acq_rel, iree_memory_order_relaxed /* old_status is unused */)) { @@ -102,16 +102,15 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { // tasks in the appropriate order: if we had a DAG of A -> B, C -> D we must // discard respecting the same topological ordering. - IREE_ASSERT_EQ(0, iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire)); + IREE_ASSERT_EQ(0, iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire)); // Almost all tasks will have a completion task; some may have additional // dependent tasks (like barriers) that will be handled below. const bool completion_task_ready = task->completion_task && - iree_atomic_fetch_sub_int32( - &task->completion_task->pending_dependency_count, 1, - iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&task->completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { iree_task_list_push_back(discard_worklist, task->completion_task); } @@ -147,8 +146,8 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { static void iree_task_retire(iree_task_t* task, iree_task_submission_t* pending_submission, iree_status_t status) { - IREE_ASSERT_EQ(0, iree_atomic_load_int32(&task->pending_dependency_count, - iree_memory_order_acquire)); + IREE_ASSERT_EQ(0, iree_atomic_load(&task->pending_dependency_count, + iree_memory_order_acquire)); // Decrement the pending count on the completion task, if any. 
iree_task_t* completion_task = task->completion_task; @@ -159,8 +158,8 @@ static void iree_task_retire(iree_task_t* task, iree_task_cleanup(task, IREE_STATUS_OK); bool completion_task_ready = completion_task && - iree_atomic_fetch_sub_int32(&completion_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { // This was the last pending dependency and the completion task is ready // to run. @@ -180,8 +179,8 @@ static void iree_task_retire(iree_task_t* task, bool completion_task_ready = completion_task && - iree_atomic_fetch_sub_int32(&completion_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&completion_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (completion_task_ready) { // This was the last pending dependency and we know that we can safely // abort the completion task by discarding. @@ -239,7 +238,7 @@ void iree_task_call_initialize(iree_task_scope_t* scope, iree_task_call_t* out_task) { iree_task_initialize(IREE_TASK_TYPE_CALL, scope, &out_task->header); out_task->closure = closure; - iree_atomic_store_intptr(&out_task->status, 0, iree_memory_order_release); + iree_atomic_store(&out_task->status, 0, iree_memory_order_release); } void iree_task_call_execute(iree_task_call_t* task, @@ -272,9 +271,9 @@ void iree_task_call_execute(iree_task_call_t* task, // Check to see if there are no pending dependencies before retiring; the // dependency count can go up if new nested tasks were enqueued. - if (iree_atomic_load_int32(&task->header.pending_dependency_count, - iree_memory_order_acquire) == 0) { - iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + if (iree_atomic_load(&task->header.pending_dependency_count, + iree_memory_order_acquire) == 0) { + iree_status_t status = (iree_status_t)iree_atomic_exchange( &task->status, 0, iree_memory_order_acq_rel); iree_task_retire(&task->header, pending_submission, status); } @@ -295,8 +294,8 @@ void iree_task_barrier_initialize(iree_task_scope_t* scope, out_task->dependent_tasks = dependent_tasks; for (iree_host_size_t i = 0; i < out_task->dependent_task_count; ++i) { iree_task_t* dependent_task = out_task->dependent_tasks[i]; - iree_atomic_fetch_add_int32(&dependent_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } } @@ -314,8 +313,8 @@ void iree_task_barrier_set_dependent_tasks( task->dependent_tasks = dependent_tasks; for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[i]; - iree_atomic_fetch_add_int32(&dependent_task->pending_dependency_count, 1, - iree_memory_order_acq_rel); + iree_atomic_fetch_add(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel); } } @@ -329,8 +328,8 @@ static void iree_task_barrier_discard(iree_task_barrier_t* task, for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[i]; const bool dependent_task_ready = - iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1; + iree_atomic_fetch_sub(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1; if (dependent_task_ready) { // The dependent task has retired and can now be discard. 
iree_task_list_push_back(discard_worklist, dependent_task); @@ -348,8 +347,8 @@ void iree_task_barrier_retire(iree_task_barrier_t* task, for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { iree_task_t* dependent_task = task->dependent_tasks[task->dependent_task_count - i - 1]; - if (iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count, - 1, iree_memory_order_acq_rel) == 1) { + if (iree_atomic_fetch_sub(&dependent_task->pending_dependency_count, 1, + iree_memory_order_acq_rel) == 1) { // The dependent task has retired and can now be made ready. iree_task_submission_enqueue(pending_submission, dependent_task); } @@ -530,13 +529,13 @@ static void iree_task_dispatch_initialize_base( memcpy(out_task->workgroup_size, workgroup_size, sizeof(out_task->workgroup_size)); out_task->local_memory_size = 0; - iree_atomic_store_intptr(&out_task->status, 0, iree_memory_order_release); + iree_atomic_store(&out_task->status, 0, iree_memory_order_release); memset(&out_task->statistics, 0, sizeof(out_task->statistics)); IREE_TRACE({ static iree_atomic_int64_t next_dispatch_id = IREE_ATOMIC_VAR_INIT(0); - out_task->dispatch_id = iree_atomic_fetch_add_int64( - &next_dispatch_id, 1ll, iree_memory_order_acq_rel); + out_task->dispatch_id = iree_atomic_fetch_add(&next_dispatch_id, 1ll, + iree_memory_order_acq_rel); }); } @@ -597,8 +596,7 @@ void iree_task_dispatch_issue(iree_task_dispatch_t* dispatch_task, #endif // IREE_HAL_VERBOSE_TRACING_ENABLE // Setup the iteration space for shards to pull work from the complete grid. - iree_atomic_store_int32(&dispatch_task->tile_index, 0, - iree_memory_order_relaxed); + iree_atomic_store(&dispatch_task->tile_index, 0, iree_memory_order_relaxed); dispatch_task->tile_count = workgroup_count[0] * workgroup_count[1] * workgroup_count[2]; @@ -672,7 +670,7 @@ void iree_task_dispatch_retire(iree_task_dispatch_t* dispatch_task, // any other has hit an error; failure in a dispatch should be so exceedingly // rare that allowing some shards to complete after one encounters an error is // not a problem. - iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + iree_status_t status = (iree_status_t)iree_atomic_exchange( &dispatch_task->status, 0, iree_memory_order_acq_rel); iree_task_retire(&dispatch_task->header, pending_submission, status); @@ -763,9 +761,9 @@ void iree_task_dispatch_shard_execute( const uint32_t tiles_per_reservation = dispatch_task->tiles_per_reservation; // relaxed order because we only care about atomic increments, not about // ordering of tile_index accesses w.r.t. other memory accesses. - uint32_t tile_base = iree_atomic_fetch_add_int32(&dispatch_task->tile_index, - tiles_per_reservation, - iree_memory_order_relaxed); + uint32_t tile_base = + iree_atomic_fetch_add(&dispatch_task->tile_index, tiles_per_reservation, + iree_memory_order_relaxed); while (tile_base < tile_count) { const uint32_t tile_range = iree_min(tile_base + tiles_per_reservation, tile_count); @@ -813,9 +811,9 @@ void iree_task_dispatch_shard_execute( } // Try to grab the next slice of tiles. 
- tile_base = iree_atomic_fetch_add_int32(&dispatch_task->tile_index, - tiles_per_reservation, - iree_memory_order_relaxed); + tile_base = + iree_atomic_fetch_add(&dispatch_task->tile_index, tiles_per_reservation, + iree_memory_order_relaxed); } abort_shard: diff --git a/runtime/src/iree/task/task_test_dispatch.cc b/runtime/src/iree/task/task_test_dispatch.cc index 3324b6cc464e..b18c26e790ec 100644 --- a/runtime/src/iree/task/task_test_dispatch.cc +++ b/runtime/src/iree/task/task_test_dispatch.cc @@ -35,8 +35,7 @@ class GridCoverage { bool Verify() { fflush(stdout); for (iree_host_size_t i = 0; i < workgroup_count_; ++i) { - if (iree_atomic_load_int32(&storage_[i], iree_memory_order_seq_cst) != - 1) { + if (iree_atomic_load(&storage_[i], iree_memory_order_seq_cst) != 1) { return false; } } @@ -52,8 +51,8 @@ class GridCoverage { tile_context->workgroup_count[0]) + tile_context->workgroup_xyz[1] * tile_context->workgroup_count[0] + tile_context->workgroup_xyz[0]; - iree_atomic_fetch_add_int32(&coverage->storage_[slot], 1, - iree_memory_order_seq_cst); + iree_atomic_fetch_add(&coverage->storage_[slot], 1, + iree_memory_order_seq_cst); // Useful when testing large grids: // printf("%u, %u, %u\n", tile_context->workgroup_xyz[0], diff --git a/runtime/src/iree/task/worker.c b/runtime/src/iree/task/worker.c index 5bebaa50fc09..e0e1efd82085 100644 --- a/runtime/src/iree/task/worker.c +++ b/runtime/src/iree/task/worker.c @@ -48,8 +48,8 @@ iree_status_t iree_task_worker_initialize( iree_task_queue_initialize(&out_worker->local_task_queue); iree_task_worker_state_t initial_state = IREE_TASK_WORKER_STATE_RUNNING; - iree_atomic_store_int32(&out_worker->state, initial_state, - iree_memory_order_release); + iree_atomic_store(&out_worker->state, initial_state, + iree_memory_order_release); iree_thread_create_params_t thread_params; memset(&thread_params, 0, sizeof(thread_params)); @@ -78,14 +78,14 @@ void iree_task_worker_request_exit(iree_task_worker_t* worker) { // If the thread is already in the exiting/zombie state we don't need to do // anything. iree_task_worker_state_t prev_state = - (iree_task_worker_state_t)iree_atomic_exchange_int32( + (iree_task_worker_state_t)iree_atomic_exchange( &worker->state, IREE_TASK_WORKER_STATE_EXITING, iree_memory_order_acq_rel); switch (prev_state) { case IREE_TASK_WORKER_STATE_ZOMBIE: // Worker already exited; reset state to ZOMBIE. - iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, + iree_memory_order_release); break; default: // Worker now set to EXITING and should exit soon. @@ -101,7 +101,7 @@ void iree_task_worker_request_exit(iree_task_worker_t* worker) { // Returns true if the worker is in the zombie state (exited and awaiting // teardown). static bool iree_task_worker_is_zombie(iree_task_worker_t* worker) { - return iree_atomic_load_int32(&worker->state, iree_memory_order_acquire) == + return iree_atomic_load(&worker->state, iree_memory_order_acquire) == IREE_TASK_WORKER_STATE_ZOMBIE; } @@ -310,7 +310,7 @@ static void iree_task_worker_pump_until_exit(iree_task_worker_t* worker) { iree_task_worker_mark_active(worker); // Check state to see if we've been asked to exit. - if (iree_atomic_load_int32(&worker->state, iree_memory_order_acquire) == + if (iree_atomic_load(&worker->state, iree_memory_order_acquire) == IREE_TASK_WORKER_STATE_EXITING) { // Thread exit requested - cancel pumping. 
iree_notification_cancel_wait(&worker->wake_notification); @@ -395,8 +395,8 @@ static int iree_task_worker_main(iree_task_worker_t* worker) { // to exit while suspended/still starting up, so check that here before we // mess with any data structures. const bool should_run = - iree_atomic_exchange_int32(&worker->state, IREE_TASK_WORKER_STATE_RUNNING, - iree_memory_order_acq_rel) != + iree_atomic_exchange(&worker->state, IREE_TASK_WORKER_STATE_RUNNING, + iree_memory_order_acq_rel) != IREE_TASK_WORKER_STATE_EXITING; if (IREE_LIKELY(should_run)) { // << work happens here >> @@ -407,8 +407,8 @@ static int iree_task_worker_main(iree_task_worker_t* worker) { iree_task_worker_mark_idle(worker); IREE_TRACE_ZONE_END(thread_zone); - iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_release); + iree_atomic_store(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, + iree_memory_order_release); iree_notification_post(&worker->state_notification, IREE_ALL_WAITERS); return 0; } diff --git a/runtime/src/iree/vm/context.c b/runtime/src/iree/vm/context.c index d55e67fb99f3..3a1fc239e999 100644 --- a/runtime/src/iree/vm/context.c +++ b/runtime/src/iree/vm/context.c @@ -51,8 +51,8 @@ static iree_vm_context_id_t iree_vm_context_allocate_id(void) { static iree_atomic_int32_t next_context_id = IREE_ATOMIC_VAR_INIT(1); // relaxed because we only care about atomic increments, not ordering w.r.t. // other memory accesses. - uint32_t context_id = iree_atomic_fetch_add_int32(&next_context_id, 1, - iree_memory_order_relaxed); + uint32_t context_id = + iree_atomic_fetch_add(&next_context_id, 1, iree_memory_order_relaxed); #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_FIBERS // This is what we pass to Tracy as the fiber name. // The string must remain live for the lifetime of the process. diff --git a/runtime/src/iree/vm/invocation.c b/runtime/src/iree/vm/invocation.c index 2ba5bab75ab3..d3fe20ac0f12 100644 --- a/runtime/src/iree/vm/invocation.c +++ b/runtime/src/iree/vm/invocation.c @@ -226,8 +226,8 @@ static iree_vm_invocation_id_t iree_vm_invoke_allocate_id( // The string must remain live for the lifetime of the process. // TODO(benvanik): name it based on the function? 
static iree_atomic_int32_t next_invocation_id = IREE_ATOMIC_VAR_INIT(1); - uint32_t invocation_id = iree_atomic_fetch_add_int32( - &next_invocation_id, 1, iree_memory_order_relaxed); + uint32_t invocation_id = iree_atomic_fetch_add(&next_invocation_id, 1, + iree_memory_order_relaxed); IREE_LEAK_CHECK_DISABLE_PUSH(); char* name = (char*)malloc(32); snprintf(name, 32, "invoke-%04d", invocation_id - 1); diff --git a/runtime/src/iree/vm/ref.c b/runtime/src/iree/vm/ref.c index 3d5f2552b585..fe3313620075 100644 --- a/runtime/src/iree/vm/ref.c +++ b/runtime/src/iree/vm/ref.c @@ -12,15 +12,15 @@ // Useful debugging tool: #if 0 -static inline volatile iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( void* ptr, iree_vm_ref_type_t type); -static inline volatile iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( iree_vm_ref_t* ref); static void iree_vm_ref_trace(const char* msg, iree_vm_ref_t* ref) { if (!ref->ptr) return; - volatile iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_string_view_t name = iree_vm_ref_type_name(ref->type); fprintf(stderr, "%s %.*s 0x%p %d\n", msg, (int)name.size, name.data, ref->ptr, iree_atomic_ref_count_load(counter)); @@ -28,7 +28,7 @@ static void iree_vm_ref_trace(const char* msg, iree_vm_ref_t* ref) { static void iree_vm_ref_ptr_trace(const char* msg, void* ptr, iree_vm_ref_type_t type) { if (!ptr) return; - volatile iree_atomic_ref_count_t* counter = + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); iree_string_view_t name = iree_vm_ref_type_name(type); fprintf(stderr, "%s %.*s 0x%p %d\n", msg, (int)name.size, name.data, ptr, @@ -45,19 +45,18 @@ iree_vm_ref_type_name(iree_vm_ref_type_t type) { return iree_vm_ref_type_descriptor(type)->type_name; } -static inline volatile iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_raw_counter_ptr( void* ptr, iree_vm_ref_type_t type) { IREE_VM_REF_ASSERT(ptr); IREE_VM_REF_ASSERT(type_descriptor); - return (volatile iree_atomic_ref_count_t*)ptr + - (type & IREE_VM_REF_TYPE_TAG_BIT_MASK); + return (iree_atomic_ref_count_t*)ptr + (type & IREE_VM_REF_TYPE_TAG_BIT_MASK); } -static inline volatile iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( +static inline iree_atomic_ref_count_t* iree_vm_get_ref_counter_ptr( iree_vm_ref_t* ref) { IREE_VM_REF_ASSERT(ref); IREE_VM_REF_ASSERT(ref->ptr); - return (volatile iree_atomic_ref_count_t*)ref->ptr + + return (iree_atomic_ref_count_t*)ref->ptr + (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK); } @@ -65,8 +64,7 @@ IREE_API_EXPORT void iree_vm_ref_object_retain(void* ptr, iree_vm_ref_type_t type) { if (!ptr) return; IREE_VM_REF_ASSERT(type); - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_raw_counter_ptr(ptr, type); + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); iree_atomic_ref_count_inc(counter); iree_vm_ref_ptr_trace("RETAIN", ptr, type); } @@ -76,8 +74,7 @@ IREE_API_EXPORT void iree_vm_ref_object_release(void* ptr, if (!ptr) return; IREE_VM_REF_ASSERT(type); iree_vm_ref_ptr_trace("RELEASE", ptr, type); - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_raw_counter_ptr(ptr, type); + iree_atomic_ref_count_t* counter = iree_vm_get_raw_counter_ptr(ptr, type); if (iree_atomic_ref_count_dec(counter) == 1) { const 
iree_vm_ref_type_descriptor_t* descriptor = iree_vm_ref_type_descriptor(type); @@ -130,8 +127,7 @@ IREE_API_EXPORT iree_status_t iree_vm_ref_wrap_retain(void* ptr, out_ref->ptr = ptr; out_ref->type = type; if (out_ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(out_ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(out_ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("WRAP RETAIN", out_ref); } @@ -142,8 +138,7 @@ IREE_API_EXPORT void iree_vm_ref_retain_inplace(iree_vm_ref_t* ref) { IREE_VM_REF_ASSERT(ref); if (ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("RETAIN", ref); } @@ -157,8 +152,7 @@ IREE_API_EXPORT void iree_vm_ref_retain(iree_vm_ref_t* ref, IREE_VM_REF_ASSERT(out_ref); iree_vm_ref_t temp_ref = *ref; if (ref->ptr) { - volatile iree_atomic_ref_count_t* counter = - iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); iree_atomic_ref_count_inc(counter); iree_vm_ref_trace("RETAIN", ref); } @@ -217,7 +211,7 @@ IREE_API_EXPORT void iree_vm_ref_release(iree_vm_ref_t* ref) { if (ref->type == IREE_VM_REF_TYPE_NULL || ref->ptr == NULL) return; iree_vm_ref_trace("RELEASE", ref); - volatile iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); + iree_atomic_ref_count_t* counter = iree_vm_get_ref_counter_ptr(ref); if (iree_atomic_ref_count_dec(counter) == 1) { const iree_vm_ref_type_descriptor_t* descriptor = iree_vm_ref_type_descriptor(ref->type); diff --git a/runtime/src/iree/vm/ref_test.cc b/runtime/src/iree/vm/ref_test.cc index 68eaa5eb5dc5..5260749b31aa 100644 --- a/runtime/src/iree/vm/ref_test.cc +++ b/runtime/src/iree/vm/ref_test.cc @@ -73,9 +73,9 @@ static iree_vm_ref_t MakeRef(InstancePtr& instance, const char* type_name) { // WARNING: this is an implementation detail and must never be relied on - it's // only here to test the expected behavior. static int32_t ReadCounter(iree_vm_ref_t* ref) { - return iree_atomic_load_int32((iree_atomic_ref_count_t*)ref->ptr + - (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK), - iree_memory_order_seq_cst); + return iree_atomic_load((iree_atomic_ref_count_t*)ref->ptr + + (ref->type & IREE_VM_REF_TYPE_TAG_BIT_MASK), + iree_memory_order_seq_cst); } } // namespace From 88061739ffed74d2992561e43205a61ddb366e85 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:31:25 -0700 Subject: [PATCH 23/45] Revert "[DispatchCreation] Extend multi-use producer fusion" (#18917) The reverted commit does not handle when the "consumer" uses a value defined above. See https://github.com/iree-org/iree/issues/18879 for the original issue. This is causing issues with ~15 ONNX models. I have a PR (https://github.com/iree-org/iree/pull/18855) to fix this by including values used in an op's region in the backwards slice, but it is waiting on upstream changes to `getBackwardSlice`.
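For illustration only, the kind of region-aware backward slice described above could look roughly like the sketch below. This is a minimal, hypothetical sketch assuming MLIR's `getBackwardSlice`/`BackwardSliceOptions` API (as used in the files that follow); the helper name `getBackwardSliceIncludingRegions` is invented for the example and is not the code in the linked PRs.

  // Hypothetical sketch (assumes mlir/Analysis/SliceAnalysis.h and the mlir
  // namespace): extend a backward slice so that producers of values used
  // inside the regions of sliced ops, but defined above them, are included.
  static void getBackwardSliceIncludingRegions(
      Operation *op, llvm::SetVector<Operation *> &slice,
      const BackwardSliceOptions &options) {
    // Regular backward slice rooted at `op`.
    getBackwardSlice(op, &slice, options);
    // For each op pulled into the slice, walk its regions and also slice the
    // defining ops of any value captured from above.
    for (unsigned i = 0; i < slice.size(); ++i) {
      Operation *sliced = slice[i];
      sliced->walk([&](Operation *nested) {
        for (Value operand : nested->getOperands()) {
          Operation *def = operand.getDefiningOp();
          if (def && !sliced->isAncestor(def) &&
              (!options.filter || options.filter(def))) {
            slice.insert(def);
            getBackwardSlice(def, &slice, options);
          }
        }
      });
    }
  }

The only difference from a plain backward slice is the extra walk over region-carrying ops (e.g. a `linalg.generic` body); that walk is what would catch the "consumer uses a value defined above" case described in this message.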
Currently, the PR is using a wrapper around `getBackwardSlice` to achieve the same effect, but this will be updated once the upstream change lands (https://github.com/llvm/llvm-project/pull/113478) Reverts iree-org/iree#18551 --------- Signed-off-by: Ian Wood --- .github/workflows/pkgci_regression_test.yml | 4 +- .../FuseHorizontalContractions.cpp | 61 +++++++++++++-- .../FuseMultiUseElementwiseProducer.cpp | 76 ++++--------------- .../compiler/DispatchCreation/FusionUtils.cpp | 33 -------- .../compiler/DispatchCreation/FusionUtils.h | 44 ----------- .../fuse_multiuse_elementwise_producer.mlir | 25 ------ 6 files changed, 74 insertions(+), 169 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index fb94905c1b29..9849c574dd72 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -220,7 +220,7 @@ jobs: --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 337.0 \ - --goldendispatch-rocm-unet 1527 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2280000 \ @@ -241,7 +241,7 @@ jobs: --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 80.0 \ - --goldendispatch-rocm-unet 1527 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2270000 \ diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index 845485667d38..a78b6b83876b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp( return true; } +/// Check that a given operation is "horizontal" to the group. The operation +/// is horizontal if the `slice` of the operation does not contain any op +/// from the group. +static bool isHorizontalToGroup(Operation *op, + const llvm::SetVector &currGroup, + const DominanceInfo &dominanceInfo, + Operation *seedOp) { + BackwardSliceOptions options; + // Limit the slice to the seed to make sure the slice is small. + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, seedOp); + }; + llvm::SetVector slice; + getBackwardSlice(op, &slice, options); + return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return slice.contains(groupedOp); + }); +} + /// Get user of operation that is a truncate operation.
static std::optional getTruncateOp(Operation *op, @@ -131,8 +149,8 @@ getTruncateOp(Operation *op, if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) { return std::nullopt; } - if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(), - dominanceInfo, seedTruncateOp.value())) { + if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo, + seedTruncateOp.value())) { return std::nullopt; } } @@ -208,8 +226,7 @@ static std::optional getHorizontalFusionGroupMembers( if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) { return false; } - if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo, - seedOp)) { + if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) { return false; } return true; @@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter, return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0); } +/// During horizontal fusion, there might be operands of the fused operations +/// whose definitions are interspersed between the fused operations. For groups +/// chosen to fuse horizontally, such operations can be moved before the +/// seed contraction operation (where the fused operation is generated). +template +static LogicalResult +moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, + Operation *insertionPoint, DominanceInfo &dominanceInfo, + ArrayRef ignoreOperations = {}) { + BackwardSliceOptions options; + llvm::DenseSet ignoreOperationsSet; + ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, insertionPoint) && + !ignoreOperationsSet.contains(op); + }; + // Set inclusive to true cause the slice is computed from the operand, and + // we want to include the defining op (which is the point here) + options.inclusive = true; + + llvm::SetVector slice; + for (auto op : operations) { + for (auto operand : op->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + } + + mlir::topologicalSort(slice); + for (auto op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + /// On finding this pattern /// ``` /// %0 = linalg.matmul ins(%arg0, %arg1) diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index d79d5145e77d..9d9d477c9a57 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -16,13 +16,9 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -49,55 +45,25 @@ static llvm::cl::opt clLinalgMaxConstantFoldElements( llvm::cl::desc("Maximum number of elements to try to constant fold."), llvm::cl::init(0)); -static Operation *getMostDominantUse(Operation *op, - const DominanceInfo &dominanceInfo) { - auto uses = 
op->getUses(); - auto it = llvm::find_if(uses, [&](OpOperand &source) { - Operation *sourceOp = source.getOwner(); - - return llvm::all_of(uses, [&](OpOperand &target) { - Operation *targetOp = target.getOwner(); - return dominanceInfo.dominates(sourceOp, targetOp); - }); - }); - if (it != uses.end()) { - return it->getOwner(); - } - return nullptr; -} - /// Check if any of the use dominates all other uses of the operation. -static Operation *getFusableUse(Operation *op, - const DominanceInfo &dominanceInfo) { +static std::optional getFusableUse(Operation *op, + DominanceInfo &dominanceInfo) { auto uses = op->getUses(); - Operation *fusableUse = nullptr; for (OpOperand &source : uses) { Operation *sourceOp = source.getOwner(); - - bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) { + bool dominatesAllUsers = true; + for (OpOperand &target : uses) { Operation *targetOp = target.getOwner(); - return !isa(targetOp) || - dominanceInfo.dominates(sourceOp, targetOp); - }); - if (dominatesAllFusableOps) { - fusableUse = sourceOp; - break; + if (!dominanceInfo.dominates(sourceOp, targetOp)) { + dominatesAllUsers = false; + break; + } + } + if (dominatesAllUsers) { + return &source; } } - Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo); - if (!fusableUse || !mostDominantOp) { - return nullptr; - } - - // If `fusableUse` dominates all other users, there's nothing else to do. - if (fusableUse == mostDominantOp) { - return fusableUse; - } - - SmallVector users(op->getUsers().begin(), op->getUsers().end()); - return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp) - ? fusableUse - : nullptr; + return std::nullopt; } static OpOperand *getFirstUseInConsumer(Operation *producer, @@ -125,7 +91,6 @@ static SmallVector getAllUsesInConsumer(Operation *producer, /// using elementwise fusion. static LogicalResult doMultiUseFusion(Operation *rootOp, llvm::SetVector &fusableOps, - const DominanceInfo &dominanceInfo, RewriterBase &rewriter) { assert(rootOp && "root op cant be null"); @@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp, Operation *consumerOp = rootOp; OpBuilder::InsertionGuard g(rewriter); for (Operation *producerOp : llvm::reverse(fusedOpsVec)) { - Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo); // Fuse all uses from producer -> consumer. It has been checked // before that all uses are fusable. while (OpOperand *fusedOperand = getFirstUseInConsumer(producerOp, consumerOp)) { rewriter.setInsertionPoint(consumerOp); - - if (consumerOp != mostDominantUser && - failed(moveOperandDefs(rewriter, ArrayRef{consumerOp}, - mostDominantUser, dominanceInfo))) { - return rewriter.notifyMatchFailure(consumerOp, - "failed to move operand defs"); - } - rewriter.moveOpBefore(consumerOp, mostDominantUser); FailureOr fusionResult = linalg::fuseElementwiseOps(rewriter, fusedOperand); if (failed(fusionResult)) { @@ -234,8 +190,9 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, } // 6. Check that the `genericOp` dominates all uses of `producer`. 
- Operation *fusableUse = getFusableUse(producer, dominanceInfo); - if (!fusableUse || fusableUse != genericOp) { + std::optional fusableUse = + getFusableUse(producer, dominanceInfo); + if (!fusableUse || fusableUse.value()->getOwner() != genericOp) { continue; } @@ -275,8 +232,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, IRRewriter rewriter(context); for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) { - if (failed( - doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) { + if (failed(doMultiUseFusion(it->first, it->second, rewriter))) { return funcOp->emitOpError("failed multi use fusion"); } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 238c866fe461..c428091f6cf8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -10,11 +10,7 @@ #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h" #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/RegionUtils.h" namespace mlir::iree_compiler::DispatchCreation { @@ -101,33 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, - Operation *seedOp) { - assert(dominanceInfo.properlyDominates(seedOp, op) && - op->getParentRegion() == seedOp->getParentRegion()); - BackwardSliceOptions options; - // Limit the slice to the seed to make sure the slice is small. - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, seedOp); - }; - llvm::SetVector slice; - getBackwardSlice(op, &slice, options); - - // `getBackwardSlice` doesnt track uses from within an ops region, so make - // sure there are no values defined above. - for (Operation *sliceOp : slice) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - if (usesValuesFromAbove) { - return false; - } - } - - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { - return slice.contains(groupedOp); - }); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index 6526badfea31..1d9c9306f7ae 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -10,10 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" namespace mlir::iree_compiler::DispatchCreation { @@ -23,44 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation { bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool fuseMultiReduction); -/// Check that a given operation is "horizontal" to the group. 
The operation -/// is horizontal if the program slice of the operation (from op back to seedOp) -/// does not contain any op from the group. -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, Operation *seedOp); - -/// Moves the operands and transitive defs for each op in `operations` directly -/// after `insertionPoint`. Note: this does not check if it is legal to move the -/// operands. -template -static LogicalResult -moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, - Operation *insertionPoint, const DominanceInfo &dominanceInfo, - ArrayRef ignoreOperations = {}) { - BackwardSliceOptions options; - llvm::DenseSet ignoreOperationsSet; - ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, insertionPoint) && - !ignoreOperationsSet.contains(op); - }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; - - llvm::SetVector slice; - for (auto op : operations) { - assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } - } - - mlir::topologicalSort(slice); - for (auto op : slice) { - rewriter.moveOpBefore(op, insertionPoint); - } - return success(); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c76fa0653635..cc3e159ca943 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -139,28 +139,3 @@ util.func public @math_sin() { // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0, // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1, - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.addf %arg2, %cst : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} - %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> -} -// CHECK-LABEL: util.func public @fuse_by_moving_consumer -// CHECK: linalg.generic -// CHECK-NOT: linalg.generic From a04179893f9546de807bbc1055b933e96aea1353 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 28 Oct 2024 14:53:44 -0400 Subject: [PATCH 24/45] [VectorDistribution] Add vector distribution support multi-dim reduction with scalars (#18800) 
Splitting https://github.com/iree-org/iree/pull/18519 into four patches. Depends #18784 This is the second one, adding the corresponding layout analysis and especially supporting the case where reduction is performed inside scf.for operation. Also, the relevant tests are added. Since patch 2 includes changes from patch #18784, the necessary updates from the first patch have also been included here. --------- Signed-off-by: Bangtian Liu --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Common/GPU/GPUDistributionPatterns.cpp | 26 ++++-- .../GPUNestedLayoutDistributionPatterns.cpp | 58 ++++++++---- .../Common/GPU/GPUVectorDistribution.cpp | 18 ++-- ...gpu_nested_layout_vector_distribution.mlir | 92 +++++++++++++++++++ .../Codegen/Common/VectorLayoutAnalysis.cpp | 56 ++++++++++- .../Common/test/vector_layout_analysis.mlir | 33 +++++++ .../iree/compiler/Codegen/Utils/GPUUtils.cpp | 4 + .../iree/compiler/Codegen/Utils/GPUUtils.h | 3 + 10 files changed, 257 insertions(+), 35 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index d9d23b22dc31..d6cc75d9cefa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -78,6 +78,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:Analysis", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:VectorDialect", ], ) diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index ee7c406d51c8..8f729de2f714 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -58,6 +58,7 @@ iree_cc_library( LLVMSupport MLIRAnalysis MLIRIR + MLIRSCFDialect MLIRVectorDialect iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect PUBLIC diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 5133d9dfaa4a..0ef6e64d2c26 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -107,11 +107,11 @@ struct DistributeConstants final : OpDistributionPattern { Type elementType = constant.getType().getElementType(); auto vectorType = VectorType::get(layout.getDistributedShape(), elementType); - Operation *distirbutedOp = rewriter.create( + auto distributedOp = rewriter.create( constantOp.getLoc(), vectorType, SplatElementsAttr::get(vectorType, attr.getSplatValue())); replaceOpWithDistributedValues(rewriter, constantOp, - distirbutedOp->getResult(0)); + distributedOp->getResult(0)); return success(); } }; @@ -536,8 +536,10 @@ struct DistributeScfFor final : OpDistributionPattern { SmallVector newInitArgs; for (Value initArg : forOp.getInitArgs()) { if (auto vectorInitArg = dyn_cast(initArg)) { - initArg = - getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]); + if (isNonZeroRank(vectorInitArg)) { + initArg = + getDistributed(rewriter, vectorInitArg, signature[vectorInitArg]); + } } newInitArgs.push_back(initArg); } @@ -582,8 +584,14 @@ struct DistributeScfFor final : OpDistributionPattern { SmallVector operands; for (Value operand : yieldOp->getOperands()) { if (auto vectorOperand = dyn_cast(operand)) { - 
operand = DistributionPattern::getDistributed(rewriter, vectorOperand, - signature[vectorOperand]); + // Distributing the operand requires it to have a non-zero rank, meaning + // it must have at least one dimension. If the vector has a non-zero + // rank, the operand is distributed according to the provided layout + // signature. + if (isNonZeroRank(vectorOperand)) { + operand = DistributionPattern::getDistributed( + rewriter, vectorOperand, signature[vectorOperand]); + } } operands.push_back(operand); } @@ -606,8 +614,10 @@ struct DistributeScfFor final : OpDistributionPattern { for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) { Value val = bbArg; if (auto oldVectorInit = dyn_cast(oldInit)) { - val = rewriter.create( - oldVectorInit.getLoc(), oldVectorInit.getType(), val); + if (isNonZeroRank(oldVectorInit)) { + val = rewriter.create( + oldVectorInit.getLoc(), oldVectorInit.getType(), val); + } } replacements.push_back(val); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index e36ad993684f..c8b2edef15d2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -305,7 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern { auto vectorType = VectorType::get(distShape, elementType); VectorValue srcVector = dyn_cast(broadcastOp.getSource()); - if (!srcVector) { + // If the srcVector is a scalar (like f32) or a rank-0 vector (like + // vector), we proceed with the scalar distribution branch. + if (!srcVector || !isNonZeroRank(srcVector)) { // The way distribution currently works, there is no partial thread // distribution, so a scalar is available to all threads. Scalar // distribution is simply a broadcast from scalar to the distributed @@ -413,16 +415,10 @@ struct DistributeMultiReduction final DistributionSignature &signature, PatternRewriter &rewriter) const override { VectorValue srcVector = multiReduceOp.getSource(); - auto accVector = dyn_cast(multiReduceOp.getAcc()); - if (!accVector) { - return rewriter.notifyMatchFailure( - multiReduceOp, "unimplemented: scalar accumulator distribution"); - } - auto resVector = dyn_cast(multiReduceOp.getResult()); - if (!resVector) { - return rewriter.notifyMatchFailure( - multiReduceOp, "unimplemented: scalar result distribution"); - } + Value acc = multiReduceOp.getAcc(); + Value res = multiReduceOp.getResult(); + auto accVector = dyn_cast(acc); + auto resVector = dyn_cast(res); auto srcLayout = dyn_cast_or_null(signature[srcVector]); if (!srcLayout) { @@ -440,8 +436,14 @@ struct DistributeMultiReduction final VectorValue disSrc = getDistributed(rewriter, srcVector, signature[srcVector]); - VectorValue disAcc = - getDistributed(rewriter, accVector, signature[accVector]); + + Value disAcc; + if (accVector) { + disAcc = getDistributed(rewriter, accVector, signature[accVector]); + } else { + // Scalars are always distributed to all threads already. 
+ disAcc = multiReduceOp.getAcc(); + } Location loc = multiReduceOp.getLoc(); @@ -462,7 +464,16 @@ struct DistributeMultiReduction final auto localReduction = rewriter.create( loc, disSrc, localInit, distributedReductionMask, multiReduceOp.getKind()); - auto locallyReduced = dyn_cast(localReduction.getResult()); + + VectorValue locallyReduced; + if (accVector) { + locallyReduced = dyn_cast(localReduction.getResult()); + } else { + // Broadcast scalar accumulator to vector. + VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, elemTy); + locallyReduced = rewriter.create( + loc, vecType, localReduction.getResult()); + } assert(locallyReduced && "result should have been a vector"); @@ -485,15 +496,30 @@ struct DistributeMultiReduction final // reduction. VectorValue unflattened = rewriter.create( loc, shaped, threadReduced.value()); + + if (!accVector) { + // Broadcast the scalar (e.g., f32) to a vector type (e.g., vector) + // because the following implementation requires the operand to be a + // vector. + disAcc = rewriter.create(loc, shaped, disAcc); + } + Value accReduction = vector::makeArithReduction( rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc); auto accReduced = dyn_cast(accReduction); if (!accReduced) { return failure(); } - replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); - return failure(); + if (resVector) { + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); + } else { + Value accReducedVal = rewriter.create( + loc, accReduction, ArrayRef{int64_t(0)}); + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal); + } + + return success(); } FailureOr doThreadReduction(RewriterBase &rewriter, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index a8831809e25b..7e927b499077 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -132,14 +132,16 @@ void DistributionPattern::replaceOpWithDistributedValues( for (auto [opResult, replacement] : llvm::zip_equal(op->getOpResults(), values)) { // If this value is a vector type, it must be converted back to simd. - if (isa(replacement.getType())) { - auto oldResult = cast(opResult); - // Create a toSIMD op to convert the value back to the simd. - rewriter.setInsertionPointAfterValue(oldResult); - Value toSIMD = rewriter.create( - oldResult.getLoc(), oldResult.getType(), replacement); - // Add to replacements. - replacement = toSIMD; + if (auto replacementType = dyn_cast(replacement.getType())) { + if (replacementType.getRank() != 0) { + auto oldResult = cast(opResult); + // Create a toSIMD op to convert the value back to the simd. + rewriter.setInsertionPointAfterValue(oldResult); + Value toSIMD = rewriter.create( + oldResult.getLoc(), oldResult.getType(), replacement); + // Add to replacements. 
+ replacement = toSIMD; + } } replacements.push_back(replacement); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir index f05b9925cd6c..1fd7682b58e6 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir @@ -1047,3 +1047,95 @@ builtin.module attributes { transform.with_named_sequence } { // CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32 // Accumulator reduction // CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32> + +// ----- + +#nested = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [2, 2], + outer_tile = [1, 1], + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32 + return %0 : f32 +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims +// Local reduction +// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 +// Global reduction +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32 +// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32> + +// ----- + +#layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [2, 2], + outer_tile = [1, 1], + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @distribute_scf_for(%arr: memref<32x32xf16>, %a: vector<32x32xf16>) -> vector { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %cst = arith.constant dense<0.000000e+00> : vector + %cst_0 = arith.constant 0.0 : f16 + %out = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg0 = %cst) -> (vector) { + %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> + %rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<32x32xf16> + %b = arith.addf %rootl, %a : vector<32x32xf16> + %c = arith.extf %b : vector<32x32xf16> to vector<32x32xf32> + %init = vector.extractelement %arg0[] : vector + %root_red = vector.multi_reduction, %c, %init [0, 1] : vector<32x32xf32> to f32 + %d = vector.broadcast %root_red : f32 to vector + scf.yield %d : vector + } + return %out : vector +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in 
%variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @distribute_scf_for +// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector +// CHECK: iter_args(%[[ARG0:.*]] = %[[ROOT]]) -> (vector) +// CHECK: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf16> -> vector<2x2x1x1x1x4xf16> +// CHECK: %[[B:.*]] = arith.addf %{{.*}}, %[[A]] +// CHECK: %[[C:.*]] = arith.extf %[[B]] +// CHECK-NEXT: %[[D:.*]] = vector.extractelement %[[ARG0]][] : vector +// Local reduction +// CHECK: vector.multi_reduction , %[[C]], %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 +// Global reduction +// CHECK: gpu.subgroup_reduce add %{{.*}} cluster(size = 16) : (f32) -> f32 +// CHECK-NEXT: gpu.subgroup_reduce add %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: vector.broadcast %[[D]] : f32 to vector<1xf32> +// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<1xf32> diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index deabc58165fb..28b75d1f7ef8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -13,6 +13,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Diagnostics.h" @@ -135,6 +136,9 @@ class EnforceLayout : public DataFlowAnalysis { RegionBranchPoint branchPoint, MutableArrayRef operands); + void visitRegionBranchTerminatorOpInterface(RegionBranchOpInterface branch, + RegionBranchPoint branchPoint); + DistributionLayout *getLatticeElement(Value val); MLIRContext *ctx; @@ -662,6 +666,9 @@ static void enforceLayoutToMultiReductionOp( ArrayRef operandLattices, ArrayRef resultLattices, std::function update) { + if (resultLattices.empty()) { + return; + } // Reductions should always propagate value layout to result. Result can // enforce it's layout on init. const DistributionLayout *result = resultLattices[0]; @@ -727,9 +734,12 @@ static void enforceLayoutToBroadcastOp( auto resultShape = broadcast.getResultVectorType().getShape(); auto inputType = broadcast.getSourceType(); - assert(isa(inputType) && - "Scalar broadcast not supported for now."); - auto inputShape = cast(inputType).getShape(); + + VectorType inputVectorType = dyn_cast(inputType); + if (!inputVectorType) + return; + + auto inputShape = inputVectorType.getShape(); SmallVector reductionMask(resultShape.size(), false); // Set the trailing dimensions to be reduced. @@ -994,6 +1004,9 @@ void EnforceLayout::visitOperation(Operation *op) { if (auto branch = dyn_cast(op)) { visitRegionSuccessors(branch, RegionBranchPoint::parent(), branch->getOpOperands()); + + // Handle the propagation from scf.for to yield op. 
+ visitRegionBranchTerminatorOpInterface(branch, RegionBranchPoint::parent()); return; } @@ -1086,6 +1099,43 @@ void EnforceLayout::visitRegionSuccessors(RegionBranchOpInterface branch, } } +void EnforceLayout::visitRegionBranchTerminatorOpInterface( + RegionBranchOpInterface branch, RegionBranchPoint branchPoint) { + SmallVector successors; + branch.getSuccessorRegions(branchPoint, successors); + if (!branch.hasLoop()) + return; + SmallVector resultLattices; + for (Value result : branch->getResults()) { + DistributionLayout *resultLattice = getLatticeElement(result); + if (resultLattice->isUninitialized()) + continue; + resultLattices.push_back(resultLattice); + } + + // We do not support multiple results yet. + if (resultLattices.size() != 1) + return; + + for (RegionSuccessor successor : successors) { + if (Region *succ = successor.getSuccessor()) { + Operation *terminator = succ->back().getTerminator(); + if (scf::YieldOp yieldOp = dyn_cast(terminator)) { + for (Value operand : yieldOp.getOperands()) { + if (!isa(operand.getType())) { + continue; + } + DistributionLayout *forwardLattice = getLatticeElement(operand); + ChangeResult changed = forwardLattice->resolve(resultLattices[0]); + propagateIfChanged(forwardLattice, changed); + } + } + } + } + + return; +} + DistributionLayout *EnforceLayout::getLatticeElement(Value val) { // Add dependency of operation on the analysis state. assert(isa(val.getType()) && "Lattice value should be a vector"); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir index 28e8ab01d89f..6533a09e6d5a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/vector_layout_analysis.mlir @@ -562,3 +562,36 @@ builtin.module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +#layout = #iree_vector_ext.layout<<[VECTORY], [16]>, <[BATCHY, VECTORX], [2, 8]>> + +// Propagate and enforce through scf.for +builtin.module attributes { transform.with_named_sequence } { + func.func @scffor(%arr: memref<16x16xf16>, %arr2: memref<16xf16>, %a: vector<16xf16>, %b: vector<16xf16>) -> vector { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cst = arith.constant dense<0.000000e+00> : vector + %cst_0 = arith.constant 0.0 : f16 + + %out = scf.for %iv = %c0 to %c1024 step %c1 iter_args(%arg1 = %cst) -> (vector) { + %root = vector.transfer_read %arr[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ VECTORY], [16]>, <[ BATCHY, VECTORX], [2, 8]>>}} + %rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<16x16xf16> + %init = vector.extractelement %arg1[] : vector + %root_red = vector.multi_reduction, %rootl, %init [0, 1] : vector<16x16xf16> to f16 + %c = vector.broadcast %root_red : f16 to vector + scf.yield %c : vector + } + + func.return %out : vector + } + + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_vector_layout_analysis %top_level_func : !transform.any_op + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp 
b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 5eb4519ec8ce..e996aba997b1 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -212,6 +212,10 @@ getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) { return computeFn; } +bool isNonZeroRank(TypedValue val) { + return val.getType().getRank() != 0; +} + //===----------------------------------------------------------------------===// // GPU workgroup memory //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h index cdbc297cb4c1..4e7c108f7c19 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h @@ -105,6 +105,9 @@ FailureOr> getGPUTileSize(mlir::FunctionOpInterface funcOp, FailureOr getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel); +/// Returns true iff the rank of the input value 'val' is non-zero. +bool isNonZeroRank(TypedValue val); + //===----------------------------------------------------------------------===// // GPU workgroup memory //===----------------------------------------------------------------------===// From e66171aa4c928727a589ad016134f009140c8a03 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 28 Oct 2024 18:58:00 +0000 Subject: [PATCH 25/45] [LinalgExt] Generalize attribute setting for attention decomposition (#18780) This PR teaches attention decomposition to set attributes for attention matmuls by passing attribute dictionaries to iree_linalg_ext.online_attention operation. This allows us to further control codegen of matmuls (generally the root operations) after decomposition (for example, setting lowering config on the decompose matmuls). 
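A minimal sketch of how a configuration pass can drive the new hook, using only the accessors added in this patch (setDecompositionConfigAttr, getQKAttrStr, getPVAttrStr). The helper name and include paths are illustrative assumptions; the attribute names mirror the ones KernelConfig.cpp sets below:

    #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"  // assumed path
    #include "mlir/IR/Builders.h"

    using namespace mlir;
    using namespace mlir::iree_compiler;

    // Hypothetical helper: build per-matmul attribute dictionaries and attach
    // them to the attention op; decomposeOperation() later copies qk_attrs and
    // pv_attrs onto the QK and PV matmuls via setDiscardableAttrs().
    static void annotateAttentionForDecomposition(
        OpBuilder &b, IREE::LinalgExt::AttentionOp op) {
      auto qkAttrs = b.getDictionaryAttr(
          {b.getNamedAttr("attention_qk_matmul", b.getUnitAttr())});
      auto pvAttrs = b.getDictionaryAttr(
          {b.getNamedAttr("attention_pv_matmul", b.getUnitAttr())});
      auto config = b.getDictionaryAttr(
          {b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrs),
           b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrs)});
      op.setDecompositionConfigAttr(config);
    }

Because the dictionaries are forwarded unchanged, any discardable attribute (not just the unit markers above) can be routed onto the decomposed matmuls this way.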
--- .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 23 +++++++++++++++++- .../pipeline_vector_distribute_gfx940.mlir | 24 ++++++++++++++++--- .../IR/AggregatedOpInterfaceImpl.cpp | 20 ++++++++++------ .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 4 ++-- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 18 ++++++++++++-- .../LinalgExt/Transforms/TileAttention.cpp | 3 ++- 6 files changed, 76 insertions(+), 16 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index 4b64cda3adc9..0d9c7f9ad2e6 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -825,7 +825,25 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, attrs.emplace_back(StringAttr::get(context, "reduction"), b.getI64ArrayAttr(reductionTileSizes)); - auto configDict = DictionaryAttr::get(context, attrs); + SmallVector qkAttrs; + SmallVector pvAttrs; + + qkAttrs.emplace_back(b.getNamedAttr("attention_qk_matmul", b.getUnitAttr())); + pvAttrs.emplace_back(b.getNamedAttr("attention_pv_matmul", b.getUnitAttr())); + + auto qkAttrDict = b.getDictionaryAttr(qkAttrs); + auto pvAttrDict = b.getDictionaryAttr(pvAttrs); + + SmallVector decompositionConfig; + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict)); + decompositionConfig.emplace_back( + b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict)); + + DictionaryAttr decompositionConfigDict = + b.getDictionaryAttr(decompositionConfig); + + auto configDict = b.getDictionaryAttr(attrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); // Attach the MMA schedule as an attribute to the entry point export function @@ -843,6 +861,9 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs); + // Set attention decomposition control config. 
+ op.setDecompositionConfigAttr(decompositionConfigDict); + return setOpConfigAndEntryPointFnTranslation( entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute, workgroupSize, targetSubgroupSize, pipelineConfig); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir index d21faf8867b1..4334e79d6f88 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir @@ -688,7 +688,9 @@ hal.executable private @attention_20x4096x64x4096x64 { affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>], - lowering_config = #config} + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 @@ -753,7 +755,15 @@ hal.executable private @attention_multiple_m_transpose { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> @@ -811,7 +821,15 @@ hal.executable private @attention_mfma_32x32x8 { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) 
{ + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], + lowering_config = #config, + decomposition_config = {qk_attrs = {attention_qk_matmul}, + pv_attrs = {attention_pv_matmul}}} + ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) { ^bb0(%score: f32): iree_linalg_ext.yield %score : f32 } -> tensor<24x64x4608x128xf16> diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 02d1e71e423c..204ae3533c7b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -313,6 +313,13 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { Value oldMax = getMax(); Value oldSum = getSum(); Type elementType = getElementTypeOrSelf(getOutput().getType()); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } FailureOr maybeOpInfo = AttentionOpDetail::get(getIndexingMapsArray()); @@ -368,10 +375,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { Value s = b.create(loc, sZero, emptyS).getResult(0); s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s); - - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. - s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr()); + if (qkAttrs) { + s.getDefiningOp()->setDiscardableAttrs(qkAttrs); + } s = applyPostQKMatmulElementwise(b, loc, getRegion(), s); @@ -448,9 +454,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // newAcc = P @ V + newAcc newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc); - // TODO: We shouldn't be relying on such attributes. We need a better - // mechanism to identify attention matmuls. 
- newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr()); + if (pvAttrs) { + newAcc.getDefiningOp()->setDiscardableAttrs(pvAttrs); + } return SmallVector{newAcc, newMax, newSum}; } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index 6abaec41f91a..77a2d518acb2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1213,7 +1213,7 @@ void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output, - indexingMaps); + indexingMaps, DictionaryAttr()); } LogicalResult AttentionOp::verify() { @@ -1388,7 +1388,7 @@ void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState, std::optional mask) { Value maskIn = mask.value_or(Value()); build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output, - max, sum, indexingMaps); + max, sum, indexingMaps, DictionaryAttr()); } LogicalResult OnlineAttentionOp::verify() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index e097ce5a9089..329c79ca5297 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -501,7 +501,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", AnyFloat:$scale, Optional:$mask, AnyShaped:$output, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -558,6 +559,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV matmul after decomposition. + static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } @@ -612,7 +619,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", AnyShaped:$output, AnyShaped:$max, AnyShaped:$sum, - AffineMapArrayAttr:$indexing_maps + AffineMapArrayAttr:$indexing_maps, + OptionalAttr:$decomposition_config ); let regions = (region SizedRegion<1>:$region); @@ -679,6 +687,12 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention", int64_t getIterationDomainRank() { return getQueryMap().getNumDims(); } + + /* Decomposition control attributes */ + + // Attributes to set on QK and PV matmul after decomposition. 
+ static StringRef getQKAttrStr() { return "qk_attrs"; } + static StringRef getPVAttrStr() { return "pv_attrs"; } }]; } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index 0aa3a37aa5fe..d9a48736fdd4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -106,7 +106,8 @@ void convertToOnlineAttention(IREE::LinalgExt::AttentionOp attnOp, loc, TypeRange{accFill.getType(), maxFill.getType(), sumFill.getType()}, attnOp.getQuery(), attnOp.getKey(), attnOp.getValue(), attnOp.getScale(), mask, accFill, maxFill, sumFill, - rewriter.getAffineMapArrayAttr(indexingMaps)); + rewriter.getAffineMapArrayAttr(indexingMaps), + attnOp.getDecompositionConfigAttr()); rewriter.cloneRegionBefore(attnOp.getRegion(), onlineAttn.getRegion(), onlineAttn.getRegion().begin()); From 9d36cfa0a95a606387b65a551cf33ba5d1fb91ee Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:30:37 -0700 Subject: [PATCH 26/45] [Codegen] Don't require full slice to decompose boundary pack and unpack ops (#18906) This PR loosens the restrictions on decomposing boundary pack and unpack ops. The current restriction is that the dispatch.tensor.load/store ops are full slices, but this is not necessary for the current use case in the TileAndFuse pipeline. Instead, it is better for the time being to decompose non-padded pack/unpack ops at function boundaries regardless of the dispatch.tensor.load/store ops being full slices, because decomposing such ops later on can cause issues with DPS. The DPS issues are tracked in https://github.com/iree-org/iree/issues/18902, but we can loosen the restrictions regardless, since it does not pose any issues to decompose in such cases. Signed-off-by: Max Dawkins --- .../Codegen/Common/DecomposePackUnPackOps.cpp | 44 +++++-------------- .../decompose_boundary_pack_unpack_ops.mlir | 4 +- 2 files changed, 14 insertions(+), 34 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index f8169411fc22..fed4470e8580 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -314,51 +314,31 @@ static bool hasPadding(Operation *op) { } /// Control function for decomposing pack and unpack ops. Returns true if the -/// op is a pack or unpack op, and its reshapes can be folded with a producer -/// or consumer interface tensor op. To be foldable, the following conditions -/// must be met: -/// +/// op is an unpadded pack or unpack op, and it is at the boundary of a +/// dispatch. The following conditions need to be met: /// 1. The PackOp or UnPackOp must have no padding. /// 2. If the op is a PackOp, then its producer must be a dispatch tensor load. /// 3. If the op is an UnPackOp, then all of its consumers must be dispatch /// tensor stores. -/// 4. Any dispatch tensor load producers or dispatch tensor store consumers -/// must be full slices. -static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { - // Full slice means zero offsets, unit strides, and sizes match full tensor - // shape. 
- auto isFullSlice = - [](ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, ArrayRef fullTensorShape) { - return areAllConstantIntValue(offsets, 0) && - areAllConstantIntValue(strides, 1) && - areConstantIntValues(sizes, fullTensorShape); - }; - if (!isa(op)) { +static LogicalResult isUnpaddedAndAtBoundary(Operation *op) { + if (!isa(op) && !isa(op)) { return failure(); } if (hasPadding(op)) { return failure(); } - // If the producer is a full slice dispatch tensor load, then the `op` is - // foldable if it is a PackOp. - auto load = dyn_cast( - op->getOperand(0).getDefiningOp()); - if (isa(op) && load && - isFullSlice(load.getMixedOffsets(), load.getMixedSizes(), - load.getMixedStrides(), load.getSourceType().getShape())) { + // If the producer is a dispatch tensor load, then the `op` is decomposable + // if it is a PackOp. + if (isa(op) && + op->getOperand(0).getDefiningOp()) { return success(); } - // If all consumers are full slice dispatch tensor stores, then the `op` is - // foldable if it is an UnPackOp. + // If all consumers are dispatch tensor stores, then the `op` is decomposable + // if it is an UnPackOp. if (isa(op) && llvm::all_of(op->getUsers(), [&](Operation *user) { - auto store = dyn_cast(user); - return store && - isFullSlice(store.getMixedOffsets(), store.getMixedSizes(), - store.getMixedStrides(), - store.getTargetType().getShape()); + return isa(user); })) { return success(); } @@ -368,7 +348,7 @@ static LogicalResult isFoldableIntoInterfaceTensor(Operation *op) { void DecomposeBoundaryPackUnPackOpsPass::runOnOperation() { if (failed(commonRunOnOperation(&getContext(), getOperation(), /*useOnlyReshapes=*/true, tileOuterToOne, - isFoldableIntoInterfaceTensor))) { + isUnpaddedAndAtBoundary))) { return signalPassFailure(); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir index 096043ba8897..6ff5bed59060 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_boundary_pack_unpack_ops.mlir @@ -133,7 +133,7 @@ func.func @load_non_full_slice() { return } // CHECK-LABEL: func.func @load_non_full_slice -// CHECK: tensor.pack +// CHECK-NOT: tensor.pack // ----- @@ -152,7 +152,7 @@ func.func @store_non_full_slice() { return } // CHECK-LABEL: func.func @store_non_full_slice -// CHECK: tensor.unpack +// CHECK-NOT: tensor.unpack // ----- From 67ba1c45424d5cedc7baf7bfe8a998ee86e510af Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 28 Oct 2024 16:42:55 -0700 Subject: [PATCH 27/45] Fixing missing parameters module fork_state vtable entry. 
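Context on why this one-liner was needed and easy to miss: in C, members omitted from a designated initializer are zero-initialized, so a vtable that gains a new interface entry (here fork_state) keeps compiling cleanly while silently holding a NULL function pointer, presumably only surfacing when the VM tries to fork module state. A standalone sketch of that failure mode with hypothetical names (not the real IREE vtable layout):

    #include <stdio.h>

    typedef struct {
      void (*destroy)(void);
      int (*fork_state)(void);  // entry added to the interface later
    } module_vtable_t;

    static void my_destroy(void) {}

    // .fork_state is omitted, so it is implicitly NULL; this still compiles
    // without warnings, which is why the real fix below is just adding the
    // missing designated initializer.
    static const module_vtable_t vtable = {
        .destroy = my_destroy,
    };

    int main(void) {
      printf("fork_state is %s\n", vtable.fork_state ? "set" : "NULL");
      return 0;
    }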
--- runtime/src/iree/modules/io/parameters/module.c | 1 + 1 file changed, 1 insertion(+) diff --git a/runtime/src/iree/modules/io/parameters/module.c b/runtime/src/iree/modules/io/parameters/module.c index 655c1bae9f8f..c5dfffdd0614 100644 --- a/runtime/src/iree/modules/io/parameters/module.c +++ b/runtime/src/iree/modules/io/parameters/module.c @@ -489,6 +489,7 @@ IREE_API_EXPORT iree_status_t iree_io_parameters_module_create( .destroy = iree_io_parameters_module_destroy, .alloc_state = iree_io_parameters_module_alloc_state, .free_state = iree_io_parameters_module_free_state, + .fork_state = iree_io_parameters_module_fork_state, .notify = iree_io_parameters_module_notify, }; From f4a5f130ca18391db6dc0208168ab4c46a54ba94 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 28 Oct 2024 21:36:46 -0400 Subject: [PATCH 28/45] Use workgroup_count_from_slice in Stream builtins (#18924) `workgroup_count_from_dag_root` is planned to be replaced in the future and is not supported by all codegen paths. Switch to `workgroup_count_from_slice`. --- .../iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir | 9 +++++---- .../iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir | 9 +++++---- .../iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir | 9 +++++---- .../iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir | 9 +++++---- .../iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir | 9 +++++---- .../iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir | 9 +++++---- 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir index 81f683f26cdf..af2286a55139 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i16.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i16 { stream.executable.export public @__builtin_fill_i16 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i16(%value: i16, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i16) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir index 43b0829e99e9..758591f4159e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i32.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i32 { stream.executable.export public @__builtin_fill_i32 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = 
flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i32(%value: i32, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i32) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir index 7005ded9aee4..c2c642dd53bb 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/fill_i8.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_fill_i8 { stream.executable.export public @__builtin_fill_i8 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_fill_i8(%value: i8, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i8) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir index a94cdf1d6cf7..139788921a8a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i16.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i16 { stream.executable.export public @__builtin_splat_i16 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i16(%value: i16, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> 
!flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i16) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir index 07f3b4cb1b54..a1f19b894e7a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i32.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i32 { stream.executable.export public @__builtin_splat_i32 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i32(%value: i32, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i32) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir index 5e5f8cb261d7..d0c6dc046f1e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Builtins/splat_i8.mlir @@ -9,16 +9,17 @@ stream.executable private @__builtin_splat_i8 { stream.executable.export public @__builtin_splat_i8 workgroups(%arg0: index) -> (index, index, index) { - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { func.func @__builtin_splat_i8(%value: i8, %count: index, %out_binding: !stream.binding) { %c0 = arith.constant 0 : index - %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count} - %0 = tensor.empty(%count) : tensor + %count0 = flow.dispatch.workload.ordinal %count, 0 : index + %out = stream.binding.subspan %out_binding[%c0] : !stream.binding -> !flow.dispatch.tensor>{%count0} + %0 = tensor.empty(%count0) : tensor %1 = linalg.fill ins(%value : i8) outs(%0 : tensor) -> tensor - flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} + flow.dispatch.tensor.store %1, %out, offsets = [0], sizes = [%count0], strides = [1] : tensor -> !flow.dispatch.tensor>{%count} return } } From 3b6967990b4161422c9545f11c50e52a430b1b4c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 28 Oct 2024 22:35:27 -0700 Subject: [PATCH 29/45] Enable the MLIR 
debug actions CL options in the compiler driver. (#18928) Signed-off-by: Stella Laurenzo --- compiler/src/iree/compiler/API/Internal/BUILD.bazel | 1 + .../src/iree/compiler/API/Internal/CMakeLists.txt | 1 + .../src/iree/compiler/API/Internal/CompilerDriver.cpp | 11 +++++++++++ 3 files changed, 13 insertions(+) diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index c8ac4551dd1d..2413bed54150 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -38,6 +38,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:Debug", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index 61631e148162..191ea93a1cbe 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -23,6 +23,7 @@ iree_cc_library( MLIRBuiltinToLLVMIRTranslation MLIRBytecodeWriter MLIRCAPIIR + MLIRDebug MLIRIR MLIRParser MLIRSupport diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp index 488555af6640..7f83a5e3b3fe 100644 --- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp +++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp @@ -67,6 +67,7 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/Debug/CLOptionsSetup.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" @@ -274,6 +275,7 @@ void GlobalInit::registerCommandLineOptions() { // Register pass manager command-line options like -mlir-print-ir-*. mlir::registerPassManagerCLOptions(); mlir::registerDefaultTimingManagerCLOptions(); + mlir::tracing::DebugConfig::registerCLOptions(); // Bind session options to the command line environment. clPluginManagerOptions = &PluginManagerOptions::FromFlags::get(); @@ -366,6 +368,11 @@ struct Session { // All user access to the context is done via this reference. MLIRContext &context; OptionsBinder binder; + + // Debug configuration. + mlir::tracing::DebugConfig debugConfig; + std::optional debugHandlerInstall; + // PluginManagerOptions must initialize first because the session depends on // it. PluginManagerOptions pluginManagerOptions; @@ -402,6 +409,7 @@ Session::Session(GlobalInit &globalInit) // Bootstrap session options from the cl environment, if enabled. if (globalInit.usesCommandLine) { + debugConfig = mlir::tracing::DebugConfig::createFromCLOptions(); pluginManagerOptions = *globalInit.clPluginManagerOptions; bindingOptions = *globalInit.clBindingOptions; inputOptions = *globalInit.clInputOptions; @@ -417,6 +425,9 @@ Session::Session(GlobalInit &globalInit) #endif } + // Enable debug integration. + debugHandlerInstall.emplace(context, debugConfig); + // Register each options struct with the binder so we can manipulate // mnemonically via the API. 
bindingOptions.bindOptions(binder); From fa752ae1e491a1f8fde8967bf04473c6a6c1ca18 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Tue, 29 Oct 2024 00:42:13 -0700 Subject: [PATCH 30/45] [DispatchCreation] Run preprocessing before elementwise fusion (#18920) I think it makes sense to run `FusionPreprocessingPass` before `ElementwiseOpFusionPass` because it helps put the IR in a better state for fusion (e.g. interchanging `linalg.generic` indexing maps). But also, reshapes have been propagated to the edges of the program, which allows the `GatherFusionPattern` to be more effective. Fixes compilation error from https://github.com/iree-org/iree/issues/17226#issuecomment-2441200369. --------- Signed-off-by: Ian Wood --- .../src/iree/compiler/DispatchCreation/Passes.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index afee21cbbcd8..9cf5732962fd 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -127,9 +127,12 @@ static void addCleanupPatterns(OpPassManager &passManager) { //===----------------------------------------------------------------------===// void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { - // 1. Do some simple elementwise op fusion. This could be skipped, - // but could reduce the surface area of ops to handle later. FunctionLikeNest(passManager) + .addPass(IREE::Flow::createCanonicalizerPass) + .addPass(mlir::createCSEPass) + .addPass(DispatchCreation::createFusionPreprocessingPass) + // 1. Do some simple elementwise op fusion. This could be skipped, + // but could reduce the surface area of ops to handle later. .addPass([]() { return DispatchCreation::createElementwiseOpFusionPass( ElementwiseOpFusionPassOptions{ @@ -148,6 +151,7 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { // 3. Perform elementwise operation fusion again (now with higher // dimensionality). + .addPass(DispatchCreation::createFusionPreprocessingPass) .addPass([]() { return DispatchCreation::createElementwiseOpFusionPass( ElementwiseOpFusionPassOptions{ @@ -294,12 +298,6 @@ void buildDispatchCreationPassPipeline( IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline))); } - FunctionLikeNest(passManager) - // Preprocess the input to a form more amenable for fusion. - .addPass(DispatchCreation::createFusionPreprocessingPass) - .addPass(IREE::Flow::createCanonicalizerPass) - .addPass(mlir::createCSEPass); - addDispatchRegionCreationPreprocessingPasses(passManager); addDispatchRegionCreationPasses(passManager); From 36caa0596f2ab5d53f21126ee0f123b08cf3a032 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 29 Oct 2024 12:13:07 -0400 Subject: [PATCH 31/45] [ROCM] Add flag to enable GlobalISel (#18922) This exposes a flag to control the use of GlobalISel in the ROCm LLVM backend. There are observed correctness issues with SelectionDAG and the plan is to turn this on by default once all issues have been triaged. 
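For reference, a compile invocation that opts into the new path might look like the sketch below. Only `--iree-hip-llvm-global-isel` comes from this patch (a boolean option defaulting to off); the surrounding backend/target flags are the usual HIP options and are shown only as assumed context, not as part of this change.

    # Hedged example: enable GlobalISel in the ROCm LLVM backend.
    # The input file, gfx942 architecture, and output name are placeholders.
    iree-compile model.mlir \
      --iree-hal-target-backends=rocm \
      --iree-hip-target=gfx942 \
      --iree-hip-llvm-global-isel=true \
      -o model_gisel.vmfb
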
--- compiler/plugins/target/ROCM/ROCMTarget.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 37565aa5709d..e384dd79c405 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -62,6 +62,7 @@ struct ROCmOptions { std::string enableROCMUkernels = "none"; bool legacySync = true; bool slpVectorization = false; + bool globalISel = false; /// List of LLVM opt pass pluggins to be loaded during GPU code /// generation. The pluggins are paths to dynamic libraries that @@ -114,6 +115,8 @@ struct ROCmOptions { cl::desc( "Enable slp vectorization in llvm opt. This can have an impact on " "performance/numerics so its turned off by default currently.")); + binder.opt("iree-hip-llvm-global-isel", globalISel, cl::cat(category), + cl::desc("Enable global instruction selection in llvm.")); } LogicalResult verify(mlir::Builder &builder) const { @@ -466,6 +469,7 @@ class ROCMTargetBackend final : public TargetBackend { opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; + opt.EnableGlobalISel = options.globalISel; SmallVector features; if (targetArch.starts_with("gfx10") || targetArch.starts_with("gfx11")) { From 3eeea7fec719b16cca22bf50f1b080c22be5f512 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 29 Oct 2024 17:26:26 +0100 Subject: [PATCH 32/45] Work around missing pybind in BYO LLVM build (#18916) Installs pybind via pip to work around #18884. Closes #18884. --- build_tools/cmake/build_and_test_byo_llvm.sh | 6 ++++++ build_tools/llvm/byo_llvm.sh | 3 +++ 2 files changed, 9 insertions(+) diff --git a/build_tools/cmake/build_and_test_byo_llvm.sh b/build_tools/cmake/build_and_test_byo_llvm.sh index 043bc98c8784..d233664be2f1 100755 --- a/build_tools/cmake/build_and_test_byo_llvm.sh +++ b/build_tools/cmake/build_and_test_byo_llvm.sh @@ -24,6 +24,12 @@ echo "Setting up venv at $VENV_DIR" python3 -m venv "$VENV_DIR" source "$VENV_DIR/bin/activate" python -m pip install -r runtime/bindings/python/iree/runtime/build_requirements.txt +python -m pip install -r third_party/llvm-project/mlir/python/requirements.txt +# Note: IREE's Python bindings for Python 3.13 are build with support for +# free-threading for which support was added to pybind with version 2.13.0. +# Therefore, we upgrade to a more recent version and avoid mixing of different +# pybind versions. +python -m pip install pybind11==2.13.6 # Note: by using the `build_llvm` action here, we are exercising byo_llvm.sh's # ability to build LLVM... from our own third_party/llvm-project. That's not diff --git a/build_tools/llvm/byo_llvm.sh b/build_tools/llvm/byo_llvm.sh index 0f3d0fda3d4d..d88fb57caf45 100755 --- a/build_tools/llvm/byo_llvm.sh +++ b/build_tools/llvm/byo_llvm.sh @@ -113,6 +113,9 @@ do_build_mlir() { cmake_options="-DLLVM_DIR='${main_install_dir}/lib/cmake/llvm'" cmake_options="${cmake_options} -DPython3_EXECUTABLE='$(which $python3_command)'" + # Note: Building the MLIR Python bindings requires the installation of + # dependencies as specified in `mlir/python/requirements.txt`, which among + # others include pybind11. 
cmake_options="${cmake_options} -DMLIR_ENABLE_BINDINGS_PYTHON=ON" cmake_options="${cmake_options} -DCMAKE_INSTALL_PREFIX=${mlir_install_dir}" cmake_options="${cmake_options} -C $TD/mlir_config.cmake" From b31b0335a999d0c299f1d2b5567f37476e1339a3 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Tue, 29 Oct 2024 09:29:12 -0700 Subject: [PATCH 33/45] Revert "[DispatchCreation] Run preprocessing before..." (#18934) This PR got merged before I was able to resolve the perf regressions in VAE decode on MI250. See @ScottTodd's comment on the original PR. I need time to resolve the regressions but this can be relanded once resolved Reverts iree-org/iree#18920 --- .../src/iree/compiler/DispatchCreation/Passes.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index 9cf5732962fd..afee21cbbcd8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -127,12 +127,9 @@ static void addCleanupPatterns(OpPassManager &passManager) { //===----------------------------------------------------------------------===// void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { + // 1. Do some simple elementwise op fusion. This could be skipped, + // but could reduce the surface area of ops to handle later. FunctionLikeNest(passManager) - .addPass(IREE::Flow::createCanonicalizerPass) - .addPass(mlir::createCSEPass) - .addPass(DispatchCreation::createFusionPreprocessingPass) - // 1. Do some simple elementwise op fusion. This could be skipped, - // but could reduce the surface area of ops to handle later. .addPass([]() { return DispatchCreation::createElementwiseOpFusionPass( ElementwiseOpFusionPassOptions{ @@ -151,7 +148,6 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { // 3. Perform elementwise operation fusion again (now with higher // dimensionality). - .addPass(DispatchCreation::createFusionPreprocessingPass) .addPass([]() { return DispatchCreation::createElementwiseOpFusionPass( ElementwiseOpFusionPassOptions{ @@ -298,6 +294,12 @@ void buildDispatchCreationPassPipeline( IREE::Util::createFixedPointIteratorPass(std::move(ipoPipeline))); } + FunctionLikeNest(passManager) + // Preprocess the input to a form more amenable for fusion. 
+ .addPass(DispatchCreation::createFusionPreprocessingPass) + .addPass(IREE::Flow::createCanonicalizerPass) + .addPass(mlir::createCSEPass); + addDispatchRegionCreationPreprocessingPasses(passManager); addDispatchRegionCreationPasses(passManager); From 3cf5b65f736ce50c9890190b80e6343c0b929d56 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 29 Oct 2024 16:46:52 +0000 Subject: [PATCH 34/45] [LinalgExt] Implement AggregateOpInterface for AttentionOp (#18890) - Adds AggregateOpInterface for AttentionOp - Move all aggregate interface tests to IR/test/decompose_aggregate_op --- .../IR/AggregatedOpInterfaceImpl.cpp | 235 ++++++++++++++---- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 1 + .../Dialect/LinalgExt/IR/test/BUILD.bazel | 1 + .../Dialect/LinalgExt/IR/test/CMakeLists.txt | 1 + .../test/decompose_aggregate_op.mlir} | 188 +++++++++++++- .../LinalgExt/Transforms/test/BUILD.bazel | 2 - .../LinalgExt/Transforms/test/CMakeLists.txt | 2 - .../test/decompose_aggregate_op.mlir | 62 ----- 8 files changed, 374 insertions(+), 118 deletions(-) rename compiler/src/iree/compiler/Dialect/LinalgExt/{Transforms/test/decompose_online_attention.mlir => IR/test/decompose_aggregate_op.mlir} (51%) delete mode 100644 compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 204ae3533c7b..7fc985bf67ab 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -299,50 +299,29 @@ static bool willBeContiguousSlice(OpFoldResult inputSize, OpFoldResult tileSize, } //===----------------------------------------------------------------------===// -// OnlineAttentionOp +// Attention Helpers //===----------------------------------------------------------------------===// -FailureOr> -OnlineAttentionOp::decomposeOperation(OpBuilder &b) { - Location loc = getLoc(); - Value query = getQuery(); - Value key = getKey(); - Value value = getValue(); - std::optional mask = getMask(); - Value oldAcc = getOutput(); - Value oldMax = getMax(); - Value oldSum = getSum(); - Type elementType = getElementTypeOrSelf(getOutput().getType()); - DictionaryAttr config = getDecompositionConfigAttr(); - - DictionaryAttr qkAttrs, pvAttrs; - if (config) { - qkAttrs = config.getAs(getQKAttrStr()); - pvAttrs = config.getAs(getPVAttrStr()); - } - - FailureOr maybeOpInfo = - AttentionOpDetail::get(getIndexingMapsArray()); - assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); - AttentionOpDetail opInfo = maybeOpInfo.value(); - - SmallVector sizes = llvm::map_to_vector( - getIterationDomain(b), [](Range x) { return x.size; }); - +Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query, + Value key, Value scale, std::optional mask, + AffineMap qMap, AffineMap kMap, AffineMap sMap, + std::optional maskMap, + SmallVector iterationDomain, + Type sElementType, Region &elementwiseRegion, + DictionaryAttr qkAttrs, bool lowPrecision) { + MLIRContext *ctx = b.getContext(); // Since we use exp2 for attention instead of the original exp, we have to // multiply the scale by log2(e). We use exp2 instead of exp as most platforms // have better support for exp2 (we verified that we gain some speedup on // some GPUs). 
- Value scale = getScale(); Value log2e = b.create( loc, b.getFloatAttr(scale.getType(), M_LOG2E)); scale = b.create(loc, scale, log2e); auto qETy = getElementTypeOrSelf(query.getType()); - auto vETy = getElementTypeOrSelf(value.getType()); - AffineMap scaleMap = AffineMap::get(/*dimCount=*/getQueryMap().getNumInputs(), - /*symbolCount=*/0, getContext()); + AffineMap scaleMap = AffineMap::get(/*dimCount=*/qMap.getNumInputs(), + /*symbolCount=*/0, ctx); // In the original algorithm, the scaling is done after the softmax: // softmax(Q @ K.T * scale) @ V @@ -352,43 +331,40 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { // iteration of the loop. This is only valid for f16 or f32 types as f8 // is extremely limited on its dynamic range therefore this would // significantly affect numerics. - if (qETy.getIntOrFloatBitWidth() > 8) { - AffineMap qMap = getQueryMap(); + if (!lowPrecision) { query = elementwiseValueInPlace(b, loc, qMap, scaleMap, query, scale); } - // ---- Matmul 1 ---- + // ---- QK Matmul ---- // Get sizes for S. - AffineMap sMap = opInfo.getSMap(); SmallVector sSizes; for (AffineExpr dimExpr : sMap.getResults()) { int dim = cast(dimExpr).getPosition(); - sSizes.push_back(sizes[dim]); + sSizes.push_back(iterationDomain[dim]); } // S = Q @ K // SMap = QMap @ KMap - Value emptyS = b.create(loc, sSizes, elementType); - Value sZero = b.create(loc, b.getZeroAttr(elementType)); + Value emptyS = b.create(loc, sSizes, sElementType); + Value sZero = b.create(loc, b.getZeroAttr(sElementType)); Value s = b.create(loc, sZero, emptyS).getResult(0); - s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s); + s = computeMatmul(b, loc, qMap, kMap, sMap, query, key, s); if (qkAttrs) { - s.getDefiningOp()->setDiscardableAttrs(qkAttrs); + s.getDefiningOp()->setAttrs(qkAttrs); } - s = applyPostQKMatmulElementwise(b, loc, getRegion(), s); + s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s); - bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8; if (lowPrecision) { // For low bit-depth types we perform post Q @ K scaling. This is to avoid // losing numerical precision due to the low dynamic range of fp8 types when // pre applying the sclaing. 
AffineMap sMap = b.getMultiDimIdentityMap(sSizes.size()); AffineMap scaleMap = AffineMap::get(/*dimCount=*/sMap.getNumInputs(), - /*symbolCount=*/0, getContext()); + /*symbolCount=*/0, ctx); s = elementwiseValueInPlace(b, loc, sMap, scaleMap, s, scale); @@ -401,16 +377,176 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false) .convertToDouble(); Value offset = b.create( - loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / mx)); + loc, b.getFloatAttr(sElementType, clAttentionSoftmaxMax / mx)); s = elementwiseValueInPlace(b, loc, sMap, scaleMap, s, offset); } // S += mask if (mask != nullptr) { - s = applyMask(b, loc, sMap, *getMaskMap(), s, mask.value()); + s = applyMask(b, loc, sMap, *maskMap, s, mask.value()); + } + + return s; +} + +//===----------------------------------------------------------------------===// +// AttentionOp +//===----------------------------------------------------------------------===// + +FailureOr> AttentionOp::decomposeOperation(OpBuilder &b) { + Location loc = getLoc(); + Value query = getQuery(); + Value key = getKey(); + Value value = getValue(); + std::optional mask = getMask(); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } + Value output = getOutput(); + + FailureOr maybeOpInfo = + AttentionOpDetail::get(getIndexingMapsArray()); + assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); + AttentionOpDetail opInfo = maybeOpInfo.value(); + + SmallVector sizes = llvm::map_to_vector( + getIterationDomain(b), [](Range x) { return x.size; }); + + AffineMap qMap = getQueryMap(); + AffineMap kMap = getKeyMap(); + AffineMap sMap = opInfo.getSMap(); + + auto qETy = getElementTypeOrSelf(query.getType()); + bool lowPrecision = qETy.getIntOrFloatBitWidth() <= 8; + + // We compute output of first matmul in f32. 
+ Type f32Type = b.getF32Type(); + + // ---- QK Matmul + elementwise math ---- + Value s = computeQKAndElementwise(loc, b, query, key, getScale(), mask, qMap, + kMap, sMap, getMaskMap(), sizes, f32Type, + getRegion(), qkAttrs, lowPrecision); + + // ---- Softmax ---- + + AffineMap accMap = getOutputMap(); + + llvm::SmallBitVector projectedK2Dims(opInfo.getDomainRank(), false); + for (auto dim : opInfo.getK2Dims()) { + projectedK2Dims.set(dim); } + AffineMap maxMap = projectDims(sMap, projectedK2Dims).dropZeroResults(); + AffineMap sumMap = maxMap; + + SmallVector rowRedSize = + applyPermutationMap(maxMap, sizes); + + Value rowRedEmpty = b.create(loc, rowRedSize, f32Type); + + Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, + getElementTypeOrSelf(output), b, loc, + /*useOnlyFiniteValue=*/true); + Value maxInit = + arith::getIdentityValue(arith::AtomicRMWKind::maximumf, f32Type, b, loc, + /*useOnlyFiniteValue=*/true); + Value sumInit = + arith::getIdentityValue(arith::AtomicRMWKind::addf, f32Type, b, loc); + + Value accFill = + b.create(loc, ValueRange{accInit}, output).getResult(0); + Value maxFill = + b.create(loc, ValueRange{maxInit}, rowRedEmpty) + .getResult(0); + Value sumFill = + b.create(loc, ValueRange{sumInit}, rowRedEmpty) + .getResult(0); + + // max = rowMax(S) + Value max = reduce(b, loc, sMap, maxMap, s, maxFill); + + // P = exp2(S - max) + AffineMap pMap = sMap; + Value p = computeSubAndExp2(b, loc, maxMap, sMap, max, s); + + // sum = rowSum(P) + Value sum = reduce(b, loc, pMap, sumMap, p, sumFill); + + // P = P / sum + p = elementwiseValueInPlace(b, loc, pMap, sumMap, p, sum); + + // ---- Scale and truncate LHS to match RHS ---- + SmallVector sSizes; + for (AffineExpr dimExpr : sMap.getResults()) { + int dim = cast(dimExpr).getPosition(); + sSizes.push_back(sizes[dim]); + } + + auto pETy = getElementTypeOrSelf(p.getType()); + auto vETy = getElementTypeOrSelf(value.getType()); + if (pETy != vETy && isa(vETy)) { + Value convertP = b.create(loc, sSizes, vETy); + p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision); + } + + // result = P @ V + acc + Value result = + computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, accFill); + if (pvAttrs) { + result.getDefiningOp()->setAttrs(pvAttrs); + } + + return SmallVector{result}; +} + +//===----------------------------------------------------------------------===// +// OnlineAttentionOp +//===----------------------------------------------------------------------===// + +FailureOr> +OnlineAttentionOp::decomposeOperation(OpBuilder &b) { + Location loc = getLoc(); + Value query = getQuery(); + Value key = getKey(); + Value value = getValue(); + std::optional mask = getMask(); + Value oldAcc = getOutput(); + Value oldMax = getMax(); + Value oldSum = getSum(); + Type elementType = getElementTypeOrSelf(getOutput().getType()); + DictionaryAttr config = getDecompositionConfigAttr(); + + DictionaryAttr qkAttrs, pvAttrs; + if (config) { + qkAttrs = config.getAs(getQKAttrStr()); + pvAttrs = config.getAs(getPVAttrStr()); + } + + FailureOr maybeOpInfo = + AttentionOpDetail::get(getIndexingMapsArray()); + assert(succeeded(maybeOpInfo) && "Invalid attention indexing maps"); + AttentionOpDetail opInfo = maybeOpInfo.value(); + + SmallVector sizes = llvm::map_to_vector( + getIterationDomain(b), [](Range x) { return x.size; }); + + AffineMap qMap = getQueryMap(); + AffineMap kMap = getKeyMap(); + AffineMap sMap = opInfo.getSMap(); + + auto qETy = getElementTypeOrSelf(query.getType()); + bool lowPrecision = 
qETy.getIntOrFloatBitWidth() <= 8; + + // ---- QK Matmul + elementwise math ---- + Value s = computeQKAndElementwise( + loc, b, query, key, getScale(), mask, qMap, kMap, sMap, getMaskMap(), + sizes, elementType, getRegion(), qkAttrs, lowPrecision); + // TODO: This decomposition should be in a seperate op called // "online softmax". // ---- Online Softmax ---- @@ -441,7 +577,14 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) { AffineMap accMap = getOutputMap(); // ---- Scale and truncate LHS to match RHS ---- + SmallVector sSizes; + for (AffineExpr dimExpr : sMap.getResults()) { + int dim = cast(dimExpr).getPosition(); + sSizes.push_back(sizes[dim]); + } + auto pETy = getElementTypeOrSelf(p.getType()); + auto vETy = getElementTypeOrSelf(value.getType()); if (pETy != vETy && isa(vETy)) { Value convertP = b.create(loc, sSizes, vETy); p = truncateFloat(b, loc, pMap, pMap, p, convertP, lowPrecision); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index 329c79ca5297..3b46114abe5e 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -475,6 +475,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", ["getIndexingMapsForResults", "getIndexingMapsForOperands", "getStaticLoopRanges"]>, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>, + %rhs1 : tensor, %rhs2 : tensor, %scalar : f32, + %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>) + -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + %0:2 = iree_linalg_ext.custom_op { + indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>, + affine_map<(d0, d1)[s0, s1] -> (s0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (s1, d1)>, + affine_map<(d0, d1)[s0, s1] -> ()>, + affine_map<(d0, d1)[s0, s1] -> (d0, s1)>, + affine_map<(d0, d1)[s0, s1] -> (d0, d1)>], + iterator_types = [#iree_linalg_ext.iterator_type, + #iree_linalg_ext.iterator_type]} + ins(%lhs1, %rhs1, %rhs2, %scalar + : tensor<1000000x?xf32>, tensor, tensor, f32) + outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) { + ^bb0(%t0 : tensor, %t1 : tensor, %t2 : tensor, + %s : f32, %t3 : tensor, %t4 : tensor) : + %0 = linalg.matmul ins(%t0, %t1 : tensor, tensor) + outs(%t3 : tensor) -> tensor + %1 = linalg.matmul ins(%0, %t2 : tensor, tensor) + outs(%t4 : tensor) -> tensor + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%1, %s : tensor, f32) outs(%1 : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 :f32): + %3 = arith.addf %b0, %b2 : f32 + linalg.yield %3 : f32 + } -> tensor + iree_linalg_ext.yield %0, %2 : tensor, tensor + } -> tensor<1000000x?xf32>, tensor<1000000x?xf32> + return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32> +} + +// CHECK-LABEL: func @custom_op_decomposition( +// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> +// 
CHECK: %[[MATMUL1:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] : +// CHECK-SAME: outs(%[[INIT1]] : +// CHECK: %[[MATMUL2:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] : +// CHECK-SAME: outs(%[[INIT2]] : +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] : +// CHECK-SAME: outs(%[[MATMUL2]] : +// CHECK: return %[[MATMUL1]], %[[GENERIC]] + +// ----- + +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> @@ -8,6 +83,89 @@ #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> func.func @attention_f16(%query: tensor<192x1024x64xf16>, + %key: tensor<192x1024x64xf16>, + %value: tensor<192x1024x64xf16>, + %output: tensor<192x1024x64xf32>) + -> (tensor<192x1024x64xf32>) { + %scale = arith.constant 1.0 : f16 + + %out = iree_linalg_ext.attention + { indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO] } + ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) + outs(%output : tensor<192x1024x64xf32>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score: f32 + } + -> tensor<192x1024x64xf32> + + return %out : tensor<192x1024x64xf32> +} + +// We just want to check if we are using the correct algorithm +// CHECK-LABEL: @attention_f16 +// Q = Q * scale +// CHECK: linalg.generic +// CHECK: arith.mulf +// S = Q @ K +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// max = rowMax(S) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.maximumf +// CHECK: linalg.yield +// P = exp2(S - max) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.subf +// CHECK: math.exp2 +// CHECK: linalg.yield +// sum = rowSum(P) +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.addf +// CHECK: linalg.yield +// P = P /= sum +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.divf +// CHECK: linalg.yield +// truncf P : f32 to f16 +// CHECK: linalg.generic +// CHECK-NOT: arith.extf +// CHECK: arith.truncf +// CHECK: linalg.yield +// newAcc = P @ V +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK: arith.extf +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield + +// ----- + +// Spec to decompose online attention op. 
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + +#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> +#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> +#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> +#mapS = affine_map<(batch, m, k1, k2, n) -> ()> +#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> +#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> + +func.func @online_attention_f16(%query: tensor<192x1024x64xf16>, %key: tensor<192x1024x64xf16>, %value: tensor<192x1024x64xf16>, %output: tensor<192x1024x64xf32>, @@ -30,7 +188,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, // We just want to check if we are using the correct algorithm and the // correct number of extf/truncfs are emitted. -// CHECK-LABEL: @attention_f16 +// CHECK-LABEL: @online_attention_f16 // Q = Q * scale // CHECK: linalg.generic // CHECK: arith.mulf @@ -83,6 +241,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, // ----- +// Spec to decompose online attention op. +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> @@ -90,7 +257,7 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>, #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> -func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, +func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, %key: tensor<192x1024x64xf8E4M3FNUZ>, %value: tensor<192x1024x64xf8E4M3FNUZ>, %output: tensor<192x1024x64xf32>, @@ -111,7 +278,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> } -// CHECK-LABEL: @attention_f8 +// CHECK-LABEL: @online_attention_f8 // S = Q @ K // CHECK: linalg.generic // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 @@ -176,6 +343,15 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, // ----- +// Spec to decompose online attention op. 
+module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.online_attention"]} in %module_op : (!transform.any_op) -> !transform.any_op + transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () + transform.yield + } +} + #mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)> #mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)> #mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)> @@ -184,7 +360,7 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>, #mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)> #mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)> -func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, +func.func @online_attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, %key: tensor<192x1024x64xf8E4M3FNUZ>, %value: tensor<192x1024x64xf8E4M3FNUZ>, %mask: tensor<192x1024x1024xf8E4M3FNUZ>, @@ -205,7 +381,7 @@ func.func @attention_f8_masked(%query: tensor<192x1024x64xf8E4M3FNUZ>, return %out#0, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32> } -// CHECK-LABEL: @attention_f8_masked +// CHECK-LABEL: @online_attention_f8_masked // S = Q @ K // CHECK: linalg.generic // CHECK: arith.extf %[[A:.+]] : f8E4M3FNUZ to f32 diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel index efe463a65949..6ba9d5cd801d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD.bazel @@ -20,9 +20,7 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir", "convert_to_loops.mlir", "convert_to_online_attention.mlir", - "decompose_aggregate_op.mlir", "decompose_im2col.mlir", - "decompose_online_attention.mlir", "decompose_winograd.mlir", "distribution.mlir", "pad_contraction_to_block_size.mlir", diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt index 3288c1443dfd..a912973cb2f7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt @@ -18,9 +18,7 @@ iree_lit_test_suite( "conv2d_to_winograd.mlir" "convert_to_loops.mlir" "convert_to_online_attention.mlir" - "decompose_aggregate_op.mlir" "decompose_im2col.mlir" - "decompose_online_attention.mlir" "decompose_winograd.mlir" "distribution.mlir" "pad_contraction_to_block_size.mlir" diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir deleted file mode 100644 index 80b0b7a693e3..000000000000 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_aggregate_op.mlir +++ /dev/null @@ -1,62 +0,0 @@ -// RUN: iree-opt --iree-transform-dialect-interpreter --canonicalize --mlir-print-local-scope --split-input-file %s | FileCheck %s - -func.func @custom_op_decomposition(%lhs1 : tensor<1000000x?xf32>, - %rhs1 : tensor, %rhs2 : tensor, %scalar : f32, - %outs1 : tensor<1000000x?xf32>, %outs2 : tensor<1000000x?xf32>) - -> (tensor<1000000x?xf32>, tensor<1000000x?xf32>) { - %0:2 = iree_linalg_ext.custom_op { - indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>, - 
affine_map<(d0, d1)[s0, s1] -> (s0, s1)>, - affine_map<(d0, d1)[s0, s1] -> (s1, d1)>, - affine_map<(d0, d1)[s0, s1] -> ()>, - affine_map<(d0, d1)[s0, s1] -> (d0, s1)>, - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>], - iterator_types = [#iree_linalg_ext.iterator_type, - #iree_linalg_ext.iterator_type]} - ins(%lhs1, %rhs1, %rhs2, %scalar - : tensor<1000000x?xf32>, tensor, tensor, f32) - outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) { - ^bb0(%t0 : tensor, %t1 : tensor, %t2 : tensor, - %s : f32, %t3 : tensor, %t4 : tensor) : - %0 = linalg.matmul ins(%t0, %t1 : tensor, tensor) - outs(%t3 : tensor) -> tensor - %1 = linalg.matmul ins(%0, %t2 : tensor, tensor) - outs(%t4 : tensor) -> tensor - %2 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> ()>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%1, %s : tensor, f32) outs(%1 : tensor) { - ^bb0(%b0 : f32, %b1 : f32, %b2 :f32): - %3 = arith.addf %b0, %b2 : f32 - linalg.yield %3 : f32 - } -> tensor - iree_linalg_ext.yield %0, %2 : tensor, tensor - } -> tensor<1000000x?xf32>, tensor<1000000x?xf32> - return %0#0, %0#1 : tensor<1000000x?xf32>, tensor<1000000x?xf32> -} -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["iree_linalg_ext.custom_op"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.iree.decompose_aggregate_op %0 : (!transform.any_op) -> () - transform.yield - } -} -// CHECK-LABEL: func @custom_op_decomposition( -// CHECK-SAME: %[[LHS1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[RHS2:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[SCALAR:[a-zA-Z0-9]+]]: f32 -// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1000000x?xf32> -// CHECK: %[[MATMUL1:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS1]], %[[RHS1]] : -// CHECK-SAME: outs(%[[INIT1]] : -// CHECK: %[[MATMUL2:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[MATMUL1]], %[[RHS2]] : -// CHECK-SAME: outs(%[[INIT2]] : -// CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[MATMUL2]], %[[SCALAR]] : -// CHECK-SAME: outs(%[[MATMUL2]] : -// CHECK: return %[[MATMUL1]], %[[GENERIC]] From 437611752055a0f3af168a8d20f7e35979927460 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 29 Oct 2024 16:59:36 +0000 Subject: [PATCH 35/45] [GPU] Do not treat pad as a tilable producer for operand promotion (#18918) PadOp doesn't have an implementation for deriving thread configuration from derived_thread_config, so ignore promoting it until an implementation is added. 
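In other words, a `tensor.pad` feeding a promoted matmul operand is no longer tagged with `#iree_gpu.derived_thread_config` itself; the pass instead falls back to materializing an explicit `linalg.copy` (which does get the derived config) between the pad and the matmul. A condensed sketch of the IR expected after the pass for the `promote_pad` test added below, with the pad body elided and the destination tensor and attribute spelling approximated rather than taken from verbatim pass output:

    // Sketch only: shapes come from the promote_pad test case below.
    %padded = tensor.pad %a low[0, 0] high[0, 1] { ... } : tensor<4x127xf32> to tensor<4x128xf32>
    %empty = tensor.empty() : tensor<4x128xf32>
    %promoted = linalg.copy {lowering_config = #iree_gpu.derived_thread_config}
        ins(%padded : tensor<4x128xf32>) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32>
    %mm = linalg.matmul {lowering_config = #lowering_config}
        ins(%promoted, %b : tensor<4x128xf32>, tensor<128x128xf32>)
        outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32>
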
--- .../Common/GPU/GPUPromoteMatmulOperands.cpp | 12 +++++++--- .../GPU/test/gpu_promote_matmul_operands.mlir | 24 +++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp index dd498fad50e8..5e50a956bd82 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp @@ -53,9 +53,15 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) { return; } } - setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get( - builder.getContext())); - return; + + // We only support thread tile size derivation of linalgOp and Im2colOp for + // now. + if (isa( + producer.getOperation())) { + setLoweringConfig(producer, IREE::GPU::DerivedThreadConfigAttr::get( + builder.getContext())); + return; + } } auto tensorType = dyn_cast(operand.getType()); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir index f05cf7b1890b..643b12c01e39 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir @@ -82,3 +82,27 @@ func.func @no_promote_fill(%b: tensor<128x128xf32>) -> tensor<4x128xf32> { // CHECK-LABEL: func.func @no_promote_fill // CHECK-NOT: iree_gpu.derived_thread_config // CHECK: return + +// ----- + +#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0]}> + +func.func @promote_pad(%a : tensor<4x127xf32>, %b: tensor<128x128xf32>) -> tensor<4x128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x128xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32> + %padded = tensor.pad %a low[0, 0] high[0, 1] { + ^bb0(%arg0: index, %arg1: index): + tensor.yield %cst : f32 + } : tensor<4x127xf32> to tensor<4x128xf32> + %mm = linalg.matmul {lowering_config = #lowering_config} + ins(%padded, %b : tensor<4x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32> + return %mm : tensor<4x128xf32> +} + +// Verify that pad is promoted with linalg.copy +// CHECK-LABEL: func.func @promote_pad +// CHECK: tensor.pad +// CHECK: linalg.copy +// CHECK-SAME: derived_thread_config +// CHECK: return From a321be20c6efa40e128bb277db629a40cdeefb5e Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 29 Oct 2024 13:23:11 -0700 Subject: [PATCH 36/45] Adding 'amdgpu' target device and flatbuffer for HAL executables. (#18933) The schema may change as the branch gets closer to merging but the refactoring in the compiler for serializing multiple ABIs will remain the same. 
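For orientation, the single `rocm` compiler backend now keys the serialized ABI off the device ID: `hip` keeps the existing `rocm-hsaco-fb` flatbuffer, while `amdgpu` uses the new container and carries the architecture in the executable format string. A rough sketch of the resulting target attributes (configuration dictionaries trimmed and `gfx942` used as a stand-in architecture, so not verbatim compiler output):

    // Existing HIP path: fixed flatbuffer format, abi = "hip".
    #hal.device.target<"hip", {...}, [
      #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", ...}>
    ]>
    // New amdgpu path: the format string is the target architecture,
    // and the abi config entry selects the AMDGPU flatbuffer container.
    #hal.device.target<"amdgpu", {...}, [
      #hal.executable.target<"rocm", "gfx942", {abi = "amdgpu", ...}>
    ]>
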
--- compiler/plugins/target/ROCM/BUILD.bazel | 1 + compiler/plugins/target/ROCM/CMakeLists.txt | 1 + compiler/plugins/target/ROCM/ROCMTarget.cpp | 305 +++++++++++++----- .../plugins/target/ROCM/ROCMTargetUtils.cpp | 4 +- runtime/src/iree/schemas/BUILD.bazel | 8 + runtime/src/iree/schemas/CMakeLists.txt | 15 + .../iree/schemas/amdgpu_executable_def.fbs | 63 ++++ 7 files changed, 321 insertions(+), 76 deletions(-) create mode 100644 runtime/src/iree/schemas/amdgpu_executable_def.fbs diff --git a/compiler/plugins/target/ROCM/BUILD.bazel b/compiler/plugins/target/ROCM/BUILD.bazel index 7962cf8e6073..9692d1aafd26 100644 --- a/compiler/plugins/target/ROCM/BUILD.bazel +++ b/compiler/plugins/target/ROCM/BUILD.bazel @@ -39,6 +39,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils", "//compiler/src/iree/compiler/PluginAPI", "//compiler/src/iree/compiler/Utils", + "//runtime/src/iree/schemas:amdgpu_executable_def_c_fbs", "//runtime/src/iree/schemas:executable_debug_info_c_fbs", "//runtime/src/iree/schemas:hip_executable_def_c_fbs", "@llvm-project//llvm:AMDGPUCodeGen", diff --git a/compiler/plugins/target/ROCM/CMakeLists.txt b/compiler/plugins/target/ROCM/CMakeLists.txt index 9430dca4fc16..938261acd14e 100644 --- a/compiler/plugins/target/ROCM/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/CMakeLists.txt @@ -64,6 +64,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils iree::compiler::PluginAPI iree::compiler::Utils + iree::schemas::amdgpu_executable_def_c_fbs iree::schemas::executable_debug_info_c_fbs iree::schemas::hip_executable_def_c_fbs PUBLIC diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index e384dd79c405..c860b630fe77 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -23,6 +23,7 @@ #include "iree/compiler/PluginAPI/Client.h" #include "iree/compiler/Utils/FlatbufferUtils.h" #include "iree/compiler/Utils/ToolUtils.h" +#include "iree/schemas/amdgpu_executable_def_builder.h" #include "iree/schemas/hip_executable_def_builder.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -54,7 +55,9 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { -struct ROCmOptions { +// TODO(#18792): rename flags back to iree-rocm- as they are not HIP-specific. +// Only iree-hip-legacy-sync applies uniquely to HIP. +struct ROCMOptions { std::string target = ""; std::string targetFeatures = ""; std::string bitcodeDirectory = getDefaultBitcodeDirectory(); @@ -196,45 +199,9 @@ static std::string translateModuleToISA(llvm::Module &module, } } // namespace -class ROCMTargetDevice final : public TargetDevice { -public: - ROCMTargetDevice(const ROCmOptions &options) : options(options) {} - - IREE::HAL::DeviceTargetAttr - getDefaultDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - - SmallVector deviceConfigAttrs; - if (options.legacySync) { - // Indicates that the runtime HAL driver operates only in the legacy - // synchronous mode. - deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), - b.getUnitAttr()); - } - auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); - - SmallVector executableConfigAttrs; - auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. 
- SmallVector executableTargetAttrs; - targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( - context, "rocm", executableConfigAttr, executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), - deviceConfigAttr, - executableTargetAttrs); - } - -private: - const ROCmOptions &options; -}; - class ROCMTargetBackend final : public TargetBackend { public: - ROCMTargetBackend(const ROCmOptions &options) : options(options) {} + ROCMTargetBackend(const ROCMOptions &options) : options(options) {} std::string getLegacyDefaultDeviceID() const override { return "hip"; } @@ -242,31 +209,43 @@ class ROCMTargetBackend final : public TargetBackend { MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr, SmallVectorImpl &executableTargetAttrs) const override { - if (auto target = getExecutableTarget(context)) + if (auto target = getExecutableTarget(deviceID, context)) { executableTargetAttrs.push_back(target); + } } IREE::HAL::ExecutableTargetAttr - getExecutableTarget(MLIRContext *context) const { + getExecutableTarget(StringRef deviceID, MLIRContext *context) const { Builder b(context); SmallVector configItems; auto addConfig = [&](StringRef name, Attribute value) { configItems.emplace_back(b.getStringAttr(name), value); }; - if (failed(options.verify(b))) + if (failed(options.verify(b))) { return nullptr; + } + + addConfig("abi", b.getStringAttr(deviceID)); + std::string format; + if (deviceID == "amdgpu") { + format = options.target; + } else { + format = "rocm-hsaco-fb"; // legacy HIP + } - if (auto target = GPU::getHIPTargetDetails(options.target, - options.targetFeatures, context)) + if (auto target = GPU::getHIPTargetDetails( + options.target, options.targetFeatures, context)) { addConfig("iree.gpu.target", target); + } addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels)); - if (options.wavesPerEu > 0) + if (options.wavesPerEu > 0) { addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu)); + } return b.getAttr( - b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"), + b.getStringAttr("rocm"), b.getStringAttr(format), b.getDictionaryAttr(configItems)); } @@ -356,9 +335,10 @@ class ROCMTargetBackend final : public TargetBackend { return success(); } - LogicalResult serializeExecutable(const SerializationOptions &serOptions, - IREE::HAL::ExecutableVariantOp variantOp, - OpBuilder &executableBuilder) override { + LogicalResult + serializeExecutable(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + OpBuilder &executableBuilder) override { ModuleOp innerModuleOp = variantOp.getInnerModule(); auto targetAttr = variantOp.getTargetAttr(); StringRef targetArch = options.target; @@ -552,18 +532,18 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".linked.ll", *llvmModule); } // Run LLVM optimization passes. 
optimizeModule(*llvmModule, *targetMachine, options.passPlugins, options.slpVectorization); - if (!serOptions.dumpIntermediatesPath.empty()) { - dumpModuleToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), + if (!serializationOptions.dumpIntermediatesPath.empty()) { + dumpModuleToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), ".optimized.ll", *llvmModule); } @@ -572,7 +552,7 @@ class ROCMTargetBackend final : public TargetBackend { } // Dump the assembly output. - if (!serOptions.dumpIntermediatesPath.empty()) { + if (!serializationOptions.dumpIntermediatesPath.empty()) { auto moduleCopy = llvm::CloneModule(*llvmModule); if (!moduleCopy) { llvm::errs() << "Error: cloning LLVM IR failed\n"; @@ -580,9 +560,9 @@ class ROCMTargetBackend final : public TargetBackend { } std::string targetISA = translateModuleToISA(*moduleCopy.get(), *targetMachine); - dumpDataToPath(serOptions.dumpIntermediatesPath, - serOptions.dumpBaseName, variantOp.getName(), ".rocmasm", - targetISA); + dumpDataToPath(serializationOptions.dumpIntermediatesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".rocmasm", targetISA); } // Serialize hsaco kernel into the binary that we will embed in the @@ -593,23 +573,136 @@ class ROCMTargetBackend final : public TargetBackend { return failure(); } - if (!serOptions.dumpBinariesPath.empty()) { - dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName, - variantOp.getName(), ".hsaco", targetHSACO); + if (!serializationOptions.dumpBinariesPath.empty()) { + dumpDataToPath(serializationOptions.dumpBinariesPath, + serializationOptions.dumpBaseName, variantOp.getName(), + ".hsaco", targetHSACO); + } + + // Wrap the HSACO ELF binary in a Flatbuffers container. + FailureOr binaryContainer; + if (targetAttr.getConfiguration() && + targetAttr.getConfiguration().getAs("abi") == "amdgpu") { + binaryContainer = serializeAMDGPUBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } else { + binaryContainer = serializeHIPBinaryContainer( + serializationOptions, variantOp, exportOps, targetHSACO); + } + if (failed(binaryContainer) || !binaryContainer.value()) { + return failure(); + } + + // Add the binary data to the target executable. + executableBuilder.create( + variantOp.getLoc(), variantOp.getSymName(), + variantOp.getTarget().getFormat(), binaryContainer.value()); + + return success(); + } + +protected: + FailureOr serializeAMDGPUBinaryContainer( + const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { + iree_compiler::FlatbufferBuilder builder; + iree_hal_amdgpu_ExecutableDef_start_as_root(builder); + + // Attach embedded source file contents. + auto sourceFilesRef = createSourceFilesVec( + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); + + // Only a single module today. + SmallVector moduleRefs; + { + auto hsacoImageRef = flatbuffers_string_create( + builder, hsacoModule.data(), hsacoModule.size()); + moduleRefs.push_back( + iree_hal_amdgpu_ModuleDef_create(builder, hsacoImageRef)); + } + auto modulesRef = builder.createOffsetVecDestructive(moduleRefs); + + // Generate optional per-export debug information. + // May be empty if no debug information was requested. 
+ auto exportDebugInfos = + createExportDefs(serializationOptions.debugLevel, exportOps, builder); + + SmallVector exportRefs; + exportRefs.resize(exportOps.size(), 0); + for (auto exportOp : exportOps) { + auto ordinalAttr = exportOp.getOrdinalAttr(); + if (!ordinalAttr) { + return mlir::emitError(exportOp.getLoc()) + << "could not compile rocm binary: export op is missing ordinal"; + } + int64_t ordinal = ordinalAttr.getInt(); + + auto symbolNameRef = builder.createString(exportOp.getName()); + + iree_hal_amdgpu_Dims_t workgroupSize = {0}; + if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) { + auto workgroupSizeDims = workgroupSizeAttr->getValue(); + workgroupSize.x = cast(workgroupSizeDims[0]).getInt(); + workgroupSize.y = cast(workgroupSizeDims[1]).getInt(); + workgroupSize.z = cast(workgroupSizeDims[2]).getInt(); + } + + auto layoutAttr = exportOp.getLayoutAttr(); + uint32_t constantCount = static_cast(layoutAttr.getConstants()); + SmallVector bindingFlags; + for (auto bindingAttr : layoutAttr.getBindings()) { + iree_hal_amdgpu_BindingBits_enum_t flags = 0; + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::ReadOnly)) { + flags |= iree_hal_amdgpu_BindingBits_READ_ONLY; + } + if (allEnumBitsSet(bindingAttr.getFlags(), + IREE::HAL::DescriptorFlags::Indirect)) { + flags |= iree_hal_amdgpu_BindingBits_INDIRECT; + } + bindingFlags.push_back(flags); + } + auto bindingFlagsRef = iree_hal_amdgpu_BindingBits_vec_create( + builder, bindingFlags.data(), bindingFlags.size()); + + iree_hal_amdgpu_ExportDef_start(builder); + iree_hal_amdgpu_ExportDef_symbol_name_add(builder, symbolNameRef); + iree_hal_amdgpu_ExportDef_workgroup_size_add(builder, &workgroupSize); + iree_hal_amdgpu_ExportDef_constant_count_add(builder, constantCount); + iree_hal_amdgpu_ExportDef_binding_flags_add(builder, bindingFlagsRef); + iree_hal_amdgpu_ExportDef_debug_info_add(builder, + exportDebugInfos[ordinal]); + exportRefs[ordinal] = iree_hal_amdgpu_ExportDef_end(builder); } + auto exportsRef = builder.createOffsetVecDestructive(exportRefs); + + iree_hal_amdgpu_ExecutableDef_exports_add(builder, exportsRef); + iree_hal_amdgpu_ExecutableDef_modules_add(builder, modulesRef); + iree_hal_amdgpu_ExecutableDef_source_files_add(builder, sourceFilesRef); + iree_hal_amdgpu_ExecutableDef_end_as_root(builder); + return builder.getBufferAttr(variantOp.getContext()); + } + + FailureOr + serializeHIPBinaryContainer(const SerializationOptions &serializationOptions, + IREE::HAL::ExecutableVariantOp variantOp, + ArrayRef exportOps, + StringRef hsacoModule) { iree_compiler::FlatbufferBuilder builder; iree_hal_hip_ExecutableDef_start_as_root(builder); // Attach embedded source file contents. auto sourceFilesRef = createSourceFilesVec( - serOptions.debugLevel, variantOp.getSourcesAttr(), builder); + serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder); // Only a single module today. SmallVector moduleRefs; { auto hsacoImageRef = flatbuffers_string_create( - builder, targetHSACO.c_str(), targetHSACO.size()); + builder, hsacoModule.data(), hsacoModule.size()); moduleRefs.push_back( iree_hal_hip_ModuleDef_create(builder, hsacoImageRef)); } @@ -618,7 +711,7 @@ class ROCMTargetBackend final : public TargetBackend { // Generate optional per-export debug information. // May be empty if no debug information was requested. 
auto exportDebugInfos = - createExportDefs(serOptions.debugLevel, exportOps, builder); + createExportDefs(serializationOptions.debugLevel, exportOps, builder); SmallVector exportRefs; exportRefs.resize(exportOps.size(), 0); @@ -682,27 +775,91 @@ class ROCMTargetBackend final : public TargetBackend { iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef); iree_hal_hip_ExecutableDef_end_as_root(builder); - // Add the binary data to the target executable. - executableBuilder.create( - variantOp.getLoc(), variantOp.getSymName(), - variantOp.getTarget().getFormat(), - builder.getBufferAttr(executableBuilder.getContext())); + return builder.getBufferAttr(variantOp.getContext()); + } - return success(); +private: + const ROCMOptions &options; +}; + +class AMDGPUTargetDevice final : public TargetDevice { +public: + AMDGPUTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "amdgpu", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("amdgpu"), + deviceConfigAttr, + executableTargetAttrs); + } + +private: + const ROCMOptions &options; +}; + +class HIPTargetDevice final : public TargetDevice { +public: + HIPTargetDevice(const ROCMOptions &options) : options(options) {} + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override { + Builder b(context); + + SmallVector deviceConfigAttrs; + if (options.legacySync) { + // Indicates that the runtime HAL driver operates only in the legacy + // synchronous mode. + deviceConfigAttrs.emplace_back(b.getStringAttr("legacy_sync"), + b.getUnitAttr()); + } + auto deviceConfigAttr = b.getDictionaryAttr(deviceConfigAttrs); + + SmallVector executableConfigAttrs; + auto executableConfigAttr = b.getDictionaryAttr(executableConfigAttrs); + + // If we had multiple target environments we would generate one target attr + // per environment, with each setting its own environment attribute. + SmallVector executableTargetAttrs; + targetRegistry.getTargetBackend("rocm")->getDefaultExecutableTargets( + context, "hip", executableConfigAttr, executableTargetAttrs); + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("hip"), + deviceConfigAttr, + executableTargetAttrs); } private: - const ROCmOptions &options; + const ROCMOptions &options; }; namespace { struct ROCMSession final - : PluginSession { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { + // #hal.device.target<"amdgpu", ... + targets.add("amdgpu", [&]() { + return std::make_shared(options); + }); // #hal.device.target<"hip", ... targets.add("hip", - [&]() { return std::make_shared(options); }); + [&]() { return std::make_shared(options); }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { // #hal.executable.target<"rocm", ... 
@@ -728,4 +885,4 @@ extern "C" bool iree_register_compiler_plugin_hal_target_rocm( return true; } -IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCmOptions); +IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCMOptions); diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp index 7453af749b80..a1757afd75f1 100644 --- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp @@ -35,7 +35,7 @@ loadIRModule(Location loc, const std::string &filename, diagnostic, *llvm_context)); if (!module) { - mlir::emitError(loc) << "error loading HIP LLVM module: " + mlir::emitError(loc) << "error loading ROCM LLVM module: " << diagnostic.getFilename().str() << ":" << diagnostic.getLineNo() << ":" << diagnostic.getColumnNo() << ": " @@ -90,7 +90,7 @@ static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker, auto setAlwaysInline = [&](llvm::Module &module) { if (targetMachine.getTargetCPU().contains("gfx10") || targetMachine.getTargetCPU().contains("gfx11")) { - // some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if + // Some ROCM/HIP functions for gfx10 or gfx11 has accuracy issue if // inlined. return; } diff --git a/runtime/src/iree/schemas/BUILD.bazel b/runtime/src/iree/schemas/BUILD.bazel index a8fbfcab8b12..e98a425424ee 100644 --- a/runtime/src/iree/schemas/BUILD.bazel +++ b/runtime/src/iree/schemas/BUILD.bazel @@ -20,6 +20,13 @@ FLATCC_ARGS = [ "--json", ] +iree_flatbuffer_c_library( + name = "amdgpu_executable_def_c_fbs", + srcs = ["amdgpu_executable_def.fbs"], + flatcc_args = FLATCC_ARGS, + includes = ["executable_debug_info.fbs"], +) + iree_flatbuffer_c_library( name = "bytecode_module_def_c_fbs", srcs = ["bytecode_module_def.fbs"], @@ -70,6 +77,7 @@ iree_flatbuffer_c_library( iree_build_test( name = "schema_build_test", targets = [ + ":amdgpu_executable_def_c_fbs", ":bytecode_module_def_c_fbs", ":cuda_executable_def_c_fbs", ":executable_debug_info_c_fbs", diff --git a/runtime/src/iree/schemas/CMakeLists.txt b/runtime/src/iree/schemas/CMakeLists.txt index 574b2cac4578..f30430df0789 100644 --- a/runtime/src/iree/schemas/CMakeLists.txt +++ b/runtime/src/iree/schemas/CMakeLists.txt @@ -10,6 +10,21 @@ iree_add_all_subdirs() +flatbuffer_c_library( + NAME + amdgpu_executable_def_c_fbs + SRCS + "amdgpu_executable_def.fbs" + FLATCC_ARGS + "--reader" + "--builder" + "--verifier" + "--json" + INCLUDES + "executable_debug_info.fbs" + PUBLIC +) + flatbuffer_c_library( NAME bytecode_module_def_c_fbs diff --git a/runtime/src/iree/schemas/amdgpu_executable_def.fbs b/runtime/src/iree/schemas/amdgpu_executable_def.fbs new file mode 100644 index 000000000000..43efdb0a34dc --- /dev/null +++ b/runtime/src/iree/schemas/amdgpu_executable_def.fbs @@ -0,0 +1,63 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include "iree/schemas/executable_debug_info.fbs"; + +namespace iree.hal.amdgpu; + +// 'AMDGPU v1 Executable'. +file_identifier "AMD1"; +file_extension "amd1"; + +// A struct for the kernel block size along each dimension. +struct Dims { + x:uint32; + y:uint32; + z:uint32; +} + +// Describes the behavior of each binding. 
+enum BindingBits:uint64 (bit_flags) {
+  READ_ONLY = 0,  // 1u << 0
+  INDIRECT = 1,   // 1u << 1
+}
+
+// Information about an exported function on the executable.
+table ExportDef {
+  // String name of the exported function symbol in the module.
+  symbol_name:string;
+
+  // Workgroup size for the export.
+  workgroup_size:Dims;
+
+  // Total number of 32-bit push constants used by the export.
+  constant_count:uint32;
+
+  // Binding count and flags for each binding.
+  binding_flags:[BindingBits];
+
+  // Optional debug information related to the export.
+  debug_info:iree.hal.debug.ExportDef;
+}
+
+// A library containing one or more exported functions.
+table ModuleDef {
+  // AMD ELF image for loading an hsa_executable_t.
+  image:string;
+}
+
+table ExecutableDef {
+  // Exported functions in canonical executable entry point order.
+  exports:[ExportDef];
+
+  // Modules containing executable code.
+  modules:[ModuleDef];
+
+  // Embedded source files sorted ascending by path.
+  source_files:[iree.hal.debug.SourceFileDef];
+}
+
+root_type ExecutableDef;

From 49ffdac66b5ac8014b29077350584675d467a6f9 Mon Sep 17 00:00:00 2001
From: Ben Vanik
Date: Tue, 29 Oct 2024 16:24:56 -0700
Subject: [PATCH 37/45] Enabling linking in the ROCM/CUDA compiler targets. (#18936)

This does exactly what the LLVMCPU side does - which is bad for compile time
(serializes LLVM codegen) but much better for runtime. Future improvements
should move LLVM codegen to the linking phase so it can happen in parallel and
then perform the linking using LLVM's linker (each executable turned into a .o
and then combined into a .so, or last-level bitcode if we then just want
serialization to go from bitcode to machine code). This is definitely a
compile-time regression, but we can't keep pessimizing runtime.
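For reference, a minimal sketch of the per-backend hook this change wires up
(the real overrides below pass "cuda" and "rocm"; "my_gpu" here is a
placeholder backend name, not part of the patch):

    // Sketch only: a TargetBackend subclass opts into linking by overriding
    // buildLinkingPassPipeline and delegating to the shared LLVMGPU pipeline,
    // restricted to its own backend name ("my_gpu" is a placeholder).
    void buildLinkingPassPipeline(OpPassManager &passManager) override {
      // Runs iree-llvmgpu-link-executables (filtered to this backend), a
      // canonicalization pass to clean up duplicated IR, and then
      // iree-llvmgpu-assign-constant-ordinals on each linked variant.
      buildLLVMGPULinkingPassPipeline(passManager, "my_gpu");
    }

The same linking step can be exercised standalone via the
iree-codegen-llvmgpu-linking-pipeline registration added below, or by running
iree-llvmgpu-link-executables directly with its optional target option
(e.g. target="rocm"), which is how the new link_executables.mlir test drives it.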
--- compiler/plugins/target/CUDA/CUDATarget.cpp | 4 + .../plugins/target/CUDA/test/smoketest.mlir | 44 +++- .../plugins/target/LLVMCPU/LLVMCPUTarget.cpp | 2 +- compiler/plugins/target/ROCM/ROCMTarget.cpp | 4 + .../plugins/target/ROCM/test/smoketest.mlir | 48 +++- .../API/Internal/IREEGPUDialectCAPI.cpp | 240 +++++++++--------- .../LLVMCPU/LLVMCPULinkExecutables.cpp | 25 +- .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 7 +- .../iree/compiler/Codegen/LLVMCPU/Passes.h | 6 +- .../iree/compiler/Codegen/LLVMCPU/Passes.td | 7 + .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 2 + .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 2 + .../LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp | 53 ++++ .../LLVMGPU/LLVMGPULinkExecutables.cpp | 123 +++++++++ .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 26 ++ .../iree/compiler/Codegen/LLVMGPU/Passes.h | 15 +- .../iree/compiler/Codegen/LLVMGPU/Passes.td | 17 ++ .../compiler/Codegen/LLVMGPU/test/BUILD.bazel | 2 + .../Codegen/LLVMGPU/test/CMakeLists.txt | 2 + .../test/assign_constant_ordinals.mlir | 22 ++ .../LLVMGPU/test/link_executables.mlir | 150 +++++++++++ .../compiler/Codegen/Utils/LinkingUtils.cpp | 18 +- .../compiler/Codegen/Utils/LinkingUtils.h | 9 +- .../web/sample_static/device_multithreaded.c | 2 +- experimental/web/sample_static/device_sync.c | 2 +- tests/e2e/stablehlo_models/CMakeLists.txt | 4 +- 26 files changed, 669 insertions(+), 167 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp index 18896f2bb0fe..ffc49b57fa7d 100644 --- a/compiler/plugins/target/CUDA/CUDATarget.cpp +++ b/compiler/plugins/target/CUDA/CUDATarget.cpp @@ -461,6 +461,10 @@ class CUDATargetBackend final : public TargetBackend { buildLLVMGPUCodegenPassPipeline(passManager, false); } + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildLLVMGPULinkingPassPipeline(passManager, "cuda"); + } + LogicalResult serializeExecutable(const SerializationOptions &serOptions, IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) override { diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir index 6e6fa946fcd9..0c12f0652e84 100644 --- a/compiler/plugins/target/CUDA/test/smoketest.mlir +++ b/compiler/plugins/target/CUDA/test/smoketest.mlir @@ -1,8 +1,6 @@ // RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s // RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX -#map = affine_map<(d0) -> (d0)> - module attributes { hal.device.targets = [ #hal.device.target<"cuda", [ @@ -11,13 +9,13 @@ module attributes { ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @add_dispatch_executable { + stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : 
index, index, index } builtin.module { - func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { %c0 = arith.constant 0 : index %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> @@ -26,7 +24,7 @@ stream.executable public @add_dispatch_0 { %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): %4 = arith.addf %arg3, %arg4 : f32 linalg.yield %4 : f32 } -> tensor<16xf32> @@ -36,12 +34,42 @@ stream.executable public @add_dispatch_0 { } } +stream.executable public @mul_dispatch_executable { + stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + %c0 = arith.constant 0 : index + %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %0 = tensor.empty() : tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = arith.mulf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor> + return + } + } +} + } -// PTX: .entry add_dispatch_0 +// PTX: .entry add_dispatch // PTX: .maxntid 64, 1, 1 // PTX: add.rn.f32 -// CHECK: hal.executable.binary public @cuda_nvptx_fb attributes { +// PTX: .entry mul_dispatch +// PTX: .maxntid 64, 1, 1 +// PTX: mul.rn.f32 + +// CHECK: hal.executable public @smoketest_linked +// CHECK-NEXT: hal.executable.binary public @cuda_nvptx_fb attributes { // CHECK-SAME: data = dense // CHECK-SAME: format = "cuda-nvptx-fb" diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index 7db50acd0033..ee8e256321a5 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp @@ -241,7 +241,7 @@ class LLVMCPUTargetBackend final : public 
TargetBackend { } void buildLinkingPassPipeline(OpPassManager &passManager) override { - buildLLVMCPULinkingPassPipeline(passManager); + buildLLVMCPULinkingPassPipeline(passManager, "llvm-cpu"); } // Gets the LLVM target from |variantOp|. diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index c860b630fe77..05ab66779271 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -269,6 +269,10 @@ class ROCMTargetBackend final : public TargetBackend { buildLLVMGPUCodegenPassPipeline(passManager, true); } + void buildLinkingPassPipeline(OpPassManager &passManager) override { + buildLLVMGPULinkingPassPipeline(passManager, "rocm"); + } + // Performs optimizations on |module| (including LTO-style whole-program // ones). Inspired by code section in // https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir index 1afe688467ee..a25547b387e2 100644 --- a/compiler/plugins/target/ROCM/test/smoketest.mlir +++ b/compiler/plugins/target/ROCM/test/smoketest.mlir @@ -2,19 +2,19 @@ module attributes { hal.device.targets = [ - #hal.device.target<"hip", [ - #hal.executable.target<"rocm", "rocm-hsaco-fb"> + #hal.device.target<"amdgpu", [ + #hal.executable.target<"rocm", "amdgcn-amd-amdhsa"> ]> : !hal.device ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @add_dispatch_executable { + stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : index, index, index } builtin.module { - func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { %c0 = arith.constant 0 : index %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> @@ -23,7 +23,7 @@ stream.executable public @add_dispatch_0 { %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): %4 = arith.addf %arg3, %arg4 : f32 linalg.yield %4 : f32 } -> tensor<16xf32> @@ -33,11 +33,37 @@ stream.executable public @add_dispatch_0 { } } +stream.executable public @mul_dispatch_executable { + stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) { + %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 + stream.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) { + %c0 = arith.constant 0 : index 
+ %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor> + %0 = tensor.empty() : tensor<16xf32> + %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor> -> tensor<16xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): + %4 = arith.mulf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> tensor<16xf32> + flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor> + return + } + } +} + } -// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes { +// CHECK: hal.executable public @smoketest_linked +// CHECK: hal.executable.binary public @amdgcn_amd_amdhsa attributes { // CHECK-SAME: data = dense -// CHECK-SAME: format = "rocm-hsaco-fb" +// CHECK-SAME: format = "amdgcn-amd-amdhsa" // ----- @@ -52,13 +78,13 @@ module attributes { ] } { -stream.executable public @add_dispatch_0 { - stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) { +stream.executable public @executable { + stream.executable.export @export workgroups(%arg0 : index) -> (index, index, index) { %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0 stream.return %x, %y, %z : index, index, index } loc(#loc) builtin.module { - func.func @add_dispatch_0() { + func.func @export() { return } loc(#loc) } loc(#loc) diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp index 9b4639bb0cc2..555601c4bcc4 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp @@ -1,120 +1,120 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" -#include "iree/compiler/dialects/iree_gpu.h" -#include "mlir-c/IR.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" - -bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) { - return llvm::isa( - unwrap(attr)); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory, - bool *noReduceSharedMemoryBankConflicts, - MlirAttribute *reorderWorkgroupsStrategy) { - mlir::MLIRContext *ctx = unwrap(mlirCtx); - mlir::Builder b(ctx); - auto prefetchSharedMemoryAttr = mlir::BoolAttr(); - if (prefetchSharedMemory) { - prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory); - } - auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr(); - if (noReduceSharedMemoryBankConflicts) { - noReduceSharedMemoryBankConflictsAttr = - b.getBoolAttr(*noReduceSharedMemoryBankConflicts); - } - auto strategyAttr = - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr(); - if (reorderWorkgroupsStrategy) { - strategyAttr = llvm::dyn_cast< - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( - unwrap(*reorderWorkgroupsStrategy)); - } - return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get( - ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr, - strategyAttr)); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getPrefetchSharedMemory()); -} - -MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( - MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts()); -} - -MlirAttribute -ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) { - auto gpuAttr = - llvm::cast( - unwrap(attr)); - return wrap(gpuAttr.getReorderWorkgroupsStrategy()); -} - -MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() { - return wrap( - mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID()); -} - -static_assert( - static_cast(ireeGPUReorderWorkgroupsStrategyEnumNone) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::None) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::Swizzle) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == - static_cast(mlir::iree_compiler::IREE::GPU:: - ReorderWorkgroupsStrategy::Transpose) && - static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == - mlir::iree_compiler::IREE::GPU:: - getMaxEnumValForReorderWorkgroupsStrategy(), - "ireeGPUReorderWorkgroupsStrategyEnum and " - "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions " - "have diverged"); - -bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) { - return llvm::isa< - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( - unwrap(attr)); -} - -MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() { - return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr:: - getTypeID()); -} - -MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet( - MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) { - mlir::MLIRContext *ctx = unwrap(mlirCtx); - return wrap( - mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get( - ctx, static_cast< - 
mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>( - value))); -} - -ireeGPUReorderWorkgroupsStrategyEnum -ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) { - assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) && - "attr is not a GPUReorderWorkgroupsStrategyAttr"); - return static_cast( - llvm::cast( - unwrap(attr)) - .getValue()); -} +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/dialects/iree_gpu.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" + +bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) { + return llvm::isa( + unwrap(attr)); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory, + bool *noReduceSharedMemoryBankConflicts, + MlirAttribute *reorderWorkgroupsStrategy) { + mlir::MLIRContext *ctx = unwrap(mlirCtx); + mlir::Builder b(ctx); + auto prefetchSharedMemoryAttr = mlir::BoolAttr(); + if (prefetchSharedMemory) { + prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory); + } + auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr(); + if (noReduceSharedMemoryBankConflicts) { + noReduceSharedMemoryBankConflictsAttr = + b.getBoolAttr(*noReduceSharedMemoryBankConflicts); + } + auto strategyAttr = + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr(); + if (reorderWorkgroupsStrategy) { + strategyAttr = llvm::dyn_cast< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( + unwrap(*reorderWorkgroupsStrategy)); + } + return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get( + ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr, + strategyAttr)); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getPrefetchSharedMemory()); +} + +MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts( + MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts()); +} + +MlirAttribute +ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) { + auto gpuAttr = + llvm::cast( + unwrap(attr)); + return wrap(gpuAttr.getReorderWorkgroupsStrategy()); +} + +MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() { + return wrap( + mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID()); +} + +static_assert( + static_cast(ireeGPUReorderWorkgroupsStrategyEnumNone) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::None) && + static_cast(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::Swizzle) && + static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == + static_cast(mlir::iree_compiler::IREE::GPU:: + ReorderWorkgroupsStrategy::Transpose) && + static_cast(ireeGPUReorderWorkgroupsStrategyEnumTranspose) == + mlir::iree_compiler::IREE::GPU:: + getMaxEnumValForReorderWorkgroupsStrategy(), + "ireeGPUReorderWorkgroupsStrategyEnum and " + "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions " + "have diverged"); + +bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) { + return 
llvm::isa< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>( + unwrap(attr)); +} + +MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() { + return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr:: + getTypeID()); +} + +MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet( + MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) { + mlir::MLIRContext *ctx = unwrap(mlirCtx); + return wrap( + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get( + ctx, static_cast< + mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>( + value))); +} + +ireeGPUReorderWorkgroupsStrategyEnum +ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) { + assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) && + "attr is not a GPUReorderWorkgroupsStrategyAttr"); + return static_cast( + llvm::cast( + unwrap(attr)) + .getValue()); +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp index 8a2e91c6a646..7bfe586beec5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp @@ -19,7 +19,8 @@ namespace { struct LLVMCPULinkExecutablesPass : public impl::LLVMCPULinkExecutablesPassBase { - LLVMCPULinkExecutablesPass() = default; + using impl::LLVMCPULinkExecutablesPassBase< + LLVMCPULinkExecutablesPass>::LLVMCPULinkExecutablesPassBase; void runOnOperation() override { auto moduleOp = getOperation(); auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); @@ -30,29 +31,36 @@ struct LLVMCPULinkExecutablesPass return; // Guess a module name, if needed, to make the output files readable. - auto moduleName = guessModuleName(moduleOp, "llvm_module"); + auto moduleName = guessModuleName(moduleOp, "module"); // Create our new "linked" hal.executable. - std::string linkedExecutableName = - llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu"); + SymbolTable moduleTable(moduleOp); + std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName); auto linkedExecutableOp = moduleBuilder.create( moduleOp.getLoc(), linkedExecutableName); linkedExecutableOp.setVisibility( sourceExecutableOps.front().getVisibility()); + moduleTable.insert(linkedExecutableOp); auto executableBuilder = OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); // Gather all unique executable targets - we may have multiple. auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); - for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) { + for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) { + // Only link the target specified. If none specified link all. + if (!target.empty() && targetAttr.getBackend().getValue() != target) { + continue; // not linking this target + } + // Add our hal.executable.variant with an empty module. std::string linkedVariantName = executableTargetAttrs.size() == 1 - ? attr.getSymbolNameFragment() - : llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index); + ? 
targetAttr.getSymbolNameFragment() + : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(), + index); auto linkedTargetOp = executableBuilder.create( - moduleOp.getLoc(), linkedVariantName, attr); + moduleOp.getLoc(), linkedVariantName, targetAttr); auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); targetBuilder.create(moduleOp.getLoc()); @@ -71,5 +79,6 @@ struct LLVMCPULinkExecutablesPass } } }; + } // namespace } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 71b3aec7389f..9ef65e28e94f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -827,9 +827,12 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager, // NOTE: this runs on the top-level program module containing all // hal.executable ops. -void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager) { +void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager, + std::optional target) { // Link together executables. This may produce some IR duplication. - modulePassManager.addPass(createLLVMCPULinkExecutablesPass()); + LLVMCPULinkExecutablesPassOptions linkOptions; + linkOptions.target = target.value_or(""); + modulePassManager.addPass(createLLVMCPULinkExecutablesPass(linkOptions)); // Cleanup IR duplication. modulePassManager.addNestedPass( diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 42d4035260db..4696bc808118 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -12,6 +12,8 @@ #ifndef IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_ #define IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_ +#include + #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "mlir/Pass/Pass.h" @@ -156,7 +158,9 @@ void buildLLVMCPUCodegenPassPipeline(OpPassManager &variantPassManager, //----------------------------------------------------------------------------// /// Populates passes needed to link HAL executables across LLVMCPU targets. -void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager); +void buildLLVMCPULinkingPassPipeline( + OpPassManager &modulePassManager, + std::optional target = std::nullopt); //----------------------------------------------------------------------------// // Register LLVMCPU Passes diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index c9aec6740923..12f90be95ee0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -69,6 +69,13 @@ def LLVMCPUEmitVectorizationRemarksPass : def LLVMCPULinkExecutablesPass : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; + let options = [ + Option< + "target", "target", + "std::string", "", + "Target backend name whose executables will be linked by this pass." 
+ >, + ]; } def LLVMCPULowerExecutableTargetPass : diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index 3d8c7a2088b0..19af0c4155c3 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -91,11 +91,13 @@ iree_compiler_cc_library( "ConvertToROCDL.cpp", "ExtractAddressComputationGPUPass.cpp", "KernelConfig.cpp", + "LLVMGPUAssignConstantOrdinals.cpp", "LLVMGPUCastAddressSpaceFunction.cpp", "LLVMGPUCastTypeToFitMMA.cpp", "LLVMGPUConfigureTensorLayouts.cpp", "LLVMGPUConfigureVectorLayouts.cpp", "LLVMGPUConvolutionToIGEMM.cpp", + "LLVMGPULinkExecutables.cpp", "LLVMGPULowerExecutableTarget.cpp", "LLVMGPUPackSharedMemoryAlloc.cpp", "LLVMGPUPrefetching.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 9016d63b6f24..aa2c5a56bea5 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -76,11 +76,13 @@ iree_cc_library( "ConvertToROCDL.cpp" "ExtractAddressComputationGPUPass.cpp" "KernelConfig.cpp" + "LLVMGPUAssignConstantOrdinals.cpp" "LLVMGPUCastAddressSpaceFunction.cpp" "LLVMGPUCastTypeToFitMMA.cpp" "LLVMGPUConfigureTensorLayouts.cpp" "LLVMGPUConfigureVectorLayouts.cpp" "LLVMGPUConvolutionToIGEMM.cpp" + "LLVMGPULinkExecutables.cpp" "LLVMGPULowerExecutableTarget.cpp" "LLVMGPUPackSharedMemoryAlloc.cpp" "LLVMGPUPrefetching.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp new file mode 100644 index 000000000000..c789b92a644c --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp @@ -0,0 +1,53 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPUASSIGNCONSTANTORDINALSPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +struct LLVMGPUAssignConstantOrdinalsPass + : public impl::LLVMGPUAssignConstantOrdinalsPassBase< + LLVMGPUAssignConstantOrdinalsPass> { + void runOnOperation() override { + auto variantOp = getOperation(); + + // Get a constant key -> ordinal mapping. + auto keyOrdinals = variantOp.gatherConstantOrdinals(); + if (keyOrdinals.empty()) + return; + + // Update placeholders to hold the concrete ordinal values. + // Eventually MLIR or LLVM will inline them. 
+ auto moduleOp = variantOp.getInnerModule(); + for (auto globalOp : + llvm::make_early_inc_range(moduleOp.getOps())) { + auto keyAttr = globalOp->getAttr( + IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); + if (!keyAttr) + continue; + auto it = keyOrdinals.find(keyAttr); + if (it == keyOrdinals.end()) { + globalOp.emitOpError() + << "no constant block providing key '" << keyAttr << "'"; + return signalPassFailure(); + } + globalOp->removeAttr( + IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName()); + globalOp.setConstantAttr(UnitAttr::get(globalOp.getContext())); + globalOp.setValueAttr(IntegerAttr::get( + IntegerType::get(globalOp.getContext(), 32), it->second)); + } + } +}; +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp new file mode 100644 index 000000000000..5ffaff984b98 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp @@ -0,0 +1,123 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMGPU/Passes.h" +#include "iree/compiler/Codegen/Utils/LinkingUtils.h" +#include "iree/compiler/Utils/ModuleUtils.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_LLVMGPULINKEXECUTABLESPASS +#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" + +namespace { + +// Returns true if the address space of a global symbol is private to the module +// scope it originates in. AMD and NVIDIA disagree on the naming but the values +// match. LLVM is a mess here. +static bool isSymbolAddressSpacePrivate(uint32_t addressSpace) { + return addressSpace == /*local*/ 3 || addressSpace == /*private*/ 5; +} + +static SymbolTable::Visibility +convertLinkageToVisibility(LLVM::Linkage linkage) { + switch (linkage) { + case LLVM::Linkage::Private: + return SymbolTable::Visibility::Private; + case LLVM::Linkage::External: + return SymbolTable::Visibility::Public; + default: + return SymbolTable::Visibility::Public; + } +} + +// Returns true if we are allowed to rename |op| as part of merging. +// The LLVMGPU lowering is super careful about assigning linkage so we err on +// the side of renaming (as 100% of usage today does not reference external +// things). 
+static bool allowRenamingPrivateLLVMSymbols(Operation *op) { + if (auto globalOp = dyn_cast(op)) { + if (isSymbolAddressSpacePrivate(globalOp.getAddrSpace())) { + return true; + } + return convertLinkageToVisibility(globalOp.getLinkage()) == + SymbolTable::Visibility::Private; + } else if (auto funcOp = dyn_cast(op)) { + return convertLinkageToVisibility(funcOp.getLinkage()) == + SymbolTable::Visibility::Private; + } + return SymbolTable::getSymbolVisibility(op) == + SymbolTable::Visibility::Private; +} + +struct LLVMGPULinkExecutablesPass + : public impl::LLVMGPULinkExecutablesPassBase { + using impl::LLVMGPULinkExecutablesPassBase< + LLVMGPULinkExecutablesPass>::LLVMGPULinkExecutablesPassBase; + void runOnOperation() override { + auto moduleOp = getOperation(); + auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); + + auto sourceExecutableOps = + llvm::to_vector<8>(moduleOp.getOps()); + if (sourceExecutableOps.size() <= 1) + return; + + // Guess a module name, if needed, to make the output files readable. + auto moduleName = guessModuleName(moduleOp, "module"); + + // Create our new "linked" hal.executable. + SymbolTable moduleTable(moduleOp); + std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName); + auto linkedExecutableOp = moduleBuilder.create( + moduleOp.getLoc(), linkedExecutableName); + linkedExecutableOp.setVisibility( + sourceExecutableOps.front().getVisibility()); + moduleTable.insert(linkedExecutableOp); + auto executableBuilder = + OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); + + // Gather all unique executable targets - we may have multiple. + auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); + for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) { + // Only link the target specified. If none specified link all. + if (!target.empty() && targetAttr.getBackend().getValue() != target) { + continue; // not linking this target + } + + // Add our hal.executable.variant with an empty module. + std::string linkedVariantName = + executableTargetAttrs.size() == 1 + ? targetAttr.getSymbolNameFragment() + : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(), + index); + auto linkedTargetOp = + executableBuilder.create( + moduleOp.getLoc(), linkedVariantName, targetAttr); + auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); + targetBuilder.create(moduleOp.getLoc()); + + auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule, + mlir::ModuleOp linkedInnerModule, + DenseMap &symbolMap) { + return mergeModuleInto(sourceInnerModule, linkedInnerModule, symbolMap, + allowRenamingPrivateLLVMSymbols); + }; + + // Try linking together all executables in moduleOp. + if (failed(linkExecutablesInto(moduleOp, sourceExecutableOps, + linkedExecutableOp, linkedTargetOp, + mergeModuleFn))) { + return signalPassFailure(); + } + } + } +}; +} // namespace +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 86f65e1b0cc9..3c7eaf88eb46 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1220,6 +1220,25 @@ void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager, }); } +// NOTE: this runs on the top-level program module containing all +// hal.executable ops. +void buildLLVMGPULinkingPassPipeline(OpPassManager &modulePassManager, + std::optional target) { + // Link together executables. 
This may produce some IR duplication. + LLVMGPULinkExecutablesPassOptions linkOptions; + linkOptions.target = target.value_or(""); + modulePassManager.addPass(createLLVMGPULinkExecutablesPass(linkOptions)); + + // Cleanup IR duplication. + modulePassManager.addNestedPass( + mlir::createCanonicalizerPass()); + + // Assign final executable constant and import ordinals. + auto &variantPassManager = modulePassManager.nest() + .nest(); + variantPassManager.addPass(createLLVMGPUAssignConstantOrdinalsPass()); +} + //===----------------------------------------------------------------------===// // ROCDL Pass Pipelines //===----------------------------------------------------------------------===// @@ -1298,6 +1317,13 @@ void registerCodegenLLVMGPUPasses() { [](OpPassManager &passManager) { buildLLVMGPUCodegenPassPipeline(passManager, true); }); + + static PassPipelineRegistration<> LLVMGPULinkingPipeline( + "iree-codegen-llvmgpu-linking-pipeline", + "Runs the LLVMGPU HAL executable linking pipeline", + [](OpPassManager &modulePassManager) { + buildLLVMGPULinkingPassPipeline(modulePassManager); + }); } //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h index d9325647a50d..e7132c7bbd08 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h @@ -12,6 +12,8 @@ #ifndef IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_ #define IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_ +#include + #include "iree/compiler/Codegen/Common/GPU/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h" @@ -22,7 +24,7 @@ namespace mlir::iree_compiler { using IREE::GPU::GPUPipelineOptions; //----------------------------------------------------------------------------// -// LLVMGPU backend Pass Pipelines. +// LLVMGPU Backend Pass Pipelines //----------------------------------------------------------------------------// /// Lowering using SIMT CUDA core operations. @@ -99,8 +101,17 @@ verifyGPUMatmulPipeline(Operation *op, IREE::Codegen::TranslationInfoAttr translationInfo, ArrayRef workgroupSize); +//----------------------------------------------------------------------------// +// LLVMGPU Linking Passes and Pipelines +//----------------------------------------------------------------------------// + +/// Populates passes needed to link HAL executables across LLVMGPU targets. +void buildLLVMGPULinkingPassPipeline( + OpPassManager &modulePassManager, + std::optional target = std::nullopt); + //------------------------------------------------------------------------------ -// Wrappers that not use tablegen options. 
+// Wrappers that do not use tablegen options //------------------------------------------------------------------------------ enum class GPUTensorCoreType { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td index aa6b55253734..0b8df811a628 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td @@ -66,6 +66,11 @@ def ExtractAddressComputationGPUPass: Pass<"extract-address-computation-gpu"> { ]; } +def LLVMGPUAssignConstantOrdinalsPass : + Pass<"iree-llvmgpu-assign-constant-ordinals", "IREE::HAL::ExecutableVariantOp"> { + let summary = "Assigns executable constant ordinals across all LLVMGPU variants."; +} + def LLVMGPUCastAddressSpaceFunctionPass : Pass<"iree-llvmgpu-cast-address-space-function", "ModuleOp"> { let summary = "Cast address space to generic in CallOp and FuncOp"; @@ -98,6 +103,18 @@ def LLVMGPUConvolutionToIGEMMPass : ]; } +def LLVMGPULinkExecutablesPass : + Pass<"iree-llvmgpu-link-executables", "mlir::ModuleOp"> { + let summary = "Links LLVMGPU HAL executables within the top-level program module."; + let options = [ + Option< + "target", "target", + "std::string", "", + "Target backend name whose executables will be linked by this pass." + >, + ]; +} + def LLVMGPULowerExecutableTargetPass : InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> { let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel index 40973205380e..1088035a5697 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( "amdgpu_chained_matmul.mlir", "amdgpu_contraction_distribution.mlir", "amdgpu_set_anchor_layouts.mlir", + "assign_constant_ordinals.mlir", "conv_pipeline_test_cuda.mlir", "conv_pipeline_test_rocm.mlir", "convert_to_nvvm.mlir", @@ -38,6 +39,7 @@ iree_lit_test_suite( "gpu_set_num_workgroups.mlir", "gpu_pipeline_generalize_named_ops.mlir", "gpu_pipeline_igemm.mlir", + "link_executables.mlir", "nvvm_extract_address_computation.mlir", "nvvm_pipeline_test.mlir", "nvvm_mma_sync_pipeline_test.mlir", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt index 2a86fd3507f4..795ee25f3303 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt @@ -17,6 +17,7 @@ iree_lit_test_suite( "amdgpu_chained_matmul.mlir" "amdgpu_contraction_distribution.mlir" "amdgpu_set_anchor_layouts.mlir" + "assign_constant_ordinals.mlir" "cast_address_space_function.mlir" "cast_type_to_fit_mma.mlir" "config_custom_op.mlir" @@ -39,6 +40,7 @@ iree_lit_test_suite( "illegal_configuration.mlir" "legalize.mlir" "linalg_transform.mlir" + "link_executables.mlir" "llvmgpu_bufferize.mlir" "llvmgpu_convolution_to_igemm.mlir" "nvvm_extract_address_computation.mlir" diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir new file mode 100644 index 000000000000..8a133f91a8fb --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir @@ -0,0 
+1,22 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-assign-constant-ordinals)))" --split-input-file %s | FileCheck %s + +hal.executable private @executable { + hal.executable.variant public @variant target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) { + hal.executable.constant.block(%device: !hal.device) -> i32 as "foo" { + %c0 = arith.constant 0 : i32 + hal.return %c0 : i32 + } + hal.executable.constant.block(%device: !hal.device) -> i32 as "bar" { + %c1 = arith.constant 1 : i32 + hal.return %c1 : i32 + } + builtin.module { + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_a(0 : i32) + llvm.mlir.global internal @__constant_ordinal_foo_a() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32 + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_b(0 : i32) + llvm.mlir.global internal @__constant_ordinal_foo_b() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32 + // CHECK: llvm.mlir.global internal constant @__constant_ordinal_bar(1 : i32) + llvm.mlir.global internal @__constant_ordinal_bar() {addr_space = 4 : i32, hal.executable.constant.key = "bar", sym_visibility = "private"} : i32 + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir new file mode 100644 index 000000000000..5655992d1d8f --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir @@ -0,0 +1,150 @@ +// RUN: iree-opt --iree-llvmgpu-link-executables --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-TARGET +// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="cuda"},iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-MULTI + +#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb"> + +// Expect a single executable with both exports and correct ordinals. +// CHECK: hal.executable private @link_executables_linked +// CHECK: hal.executable.variant public @rocm_hsaco_fb +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) + +// Expect one LLVM module with all globals and functions. +// Note that shared memory is duplicated but dynamic shared memory is not. 
+// CHECK: builtin.module +// CHECK-NEXT: llvm.mlir.global external @__dynamic_shared_memory__ +// CHECK-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>> +// CHECK-NEXT: llvm.func @export0 +// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> +// CHECK-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> +// CHECK: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>> +// CHECK-NEXT: llvm.func @export1 +// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> +// CHECK-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3> + +hal.executable private @executable0 { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<64 x i32>> + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> + %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> + llvm.return + } + } + } +} +hal.executable private @executable1 { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<128 x i32>> + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3> + %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3> + llvm.return + } + } + } +} + +// ----- + +#executable_target_cuda = #hal.executable.target<"cuda", "cuda-nvptx-fb"> +#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb"> + +// Expect a single executable with multiple variants when not specifying target. +// CHECK: hal.executable private @link_executables_linked +// CHECK: hal.executable.variant public @cuda_nvptx_fb_0 +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) +// CHECK: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK: hal.executable.export public @export0 ordinal(0) +// CHECK: hal.executable.export public @export1 ordinal(1) + +// Expect only one target be linked when specified. 
+// CHECK-TARGET: hal.executable private @link_executables_linked +// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK-TARGET: hal.executable.export public @export0 ordinal(0) +// CHECK-TARGET: hal.executable.export public @export1 ordinal(1) +// CHECK-TARGET: hal.executable private @executable0 +// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb +// CHECK-TARGET: hal.executable.export public @export0 ordinal(0) +// CHECK-TARGET: hal.executable private @executable1 +// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb +// CHECK-TARGET: hal.executable.export public @export1 ordinal(0) + +// Multiple applications of the pass per target should not conflict. +// CHECK-MULTI: hal.executable private @link_executables_linked_0 +// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb_1 +// CHECK-MULTI: hal.executable.export public @export0 ordinal(0) +// CHECK-MULTI: hal.executable.export public @export1 ordinal(1) +// CHECK-MULTI: hal.executable private @link_executables_linked +// CHECK-MULTI: hal.executable.variant public @cuda_nvptx_fb_0 +// CHECK-MULTI: hal.executable.export public @export0 ordinal(0) +// CHECK-MULTI: hal.executable.export public @export1 ordinal(1) + +hal.executable private @executable0 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } +} +hal.executable private @executable1 { + hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } + hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) { + hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout]>) { + ^bb0(%arg0: !hal.device): + %c1 = arith.constant 1 : index + hal.return %c1, %c1, %c1 : index, index, index + } + builtin.module { + llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) { + llvm.return + } + } + } +} diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp index ad4e543d70b3..003d3f759d0e 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp @@ -67,7 +67,8 @@ renameWithDisambiguatedName(Operation *op, Operation *moduleOp, // symbol tracked in |targetSymbolMap|. 
LogicalResult mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, - DenseMap &targetSymbolMap) { + DenseMap &targetSymbolMap, + std::function canRenameSymbol) { auto &sourceBlock = sourceModuleOp->getRegion(0).front(); auto &targetBlock = targetModuleOp->getRegion(0).front(); SymbolTable sourceSymbolTable(sourceModuleOp); @@ -90,15 +91,19 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, // use the existing target op. continue; } - if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) { + if (canRenameSymbol(symbolOp)) { // Since the source symbol is private we can rename it as all uses // are known to be local to the source module. renameWithDisambiguatedName(sourceOp, sourceModuleOp, targetSymbolMap, &sourceSymbolTable); } else { // The source symbol has 'nested' or 'public' visibility. - if (SymbolTable::getSymbolVisibility(targetOp) != - SymbolTable::Visibility::Private) { + if (canRenameSymbol(targetOp)) { + // Keep the original name for our new op, rename the target op. + renameWithDisambiguatedName(targetOp, targetModuleOp, + targetSymbolMap, + /*optionalSymbolTable=*/nullptr); + } else { // Oops! Both symbols are public and we can't safely rename either. // If you hit this with ops that you think are safe to rename, mark // them private. @@ -109,11 +114,6 @@ mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, // where that isn't true. return sourceOp->emitError() << "multiple public symbols with the name: " << symbolName; - } else { - // Keep the original name for our new op, rename the target op. - renameWithDisambiguatedName(targetOp, targetModuleOp, - targetSymbolMap, - /*optionalSymbolTable=*/nullptr); } } } diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h index cf4ca4db47b5..a33f168d08c2 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h @@ -19,6 +19,11 @@ gatherExecutableTargets(ArrayRef executableOps); // TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version. // Only difference is one has the symbol map that we don't even need. +static inline bool allowRenamingPrivateSymbols(Operation *op) { + return SymbolTable::getSymbolVisibility(op) == + SymbolTable::Visibility::Private; +} + // Destructively merges |sourceModuleOp| into |targetModuleOp|. // |targetSymbolMap| is updated with the new symbols. // @@ -29,7 +34,9 @@ gatherExecutableTargets(ArrayRef executableOps); // symbol tracked in |targetSymbolMap|. LogicalResult mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp, - DenseMap &targetSymbolMap); + DenseMap &targetSymbolMap, + std::function canRenameSymbol = + allowRenamingPrivateSymbols); // Links all executables for the current target found in |moduleOp| into // |linkedExecutableOp|. Functions will be moved into |linkedModuleOp|. diff --git a/experimental/web/sample_static/device_multithreaded.c b/experimental/web/sample_static/device_multithreaded.c index c70924bdc4bd..8b5ba39f6c78 100644 --- a/experimental/web/sample_static/device_multithreaded.c +++ b/experimental/web/sample_static/device_multithreaded.c @@ -18,7 +18,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator, // Register the statically linked executable library. 
const iree_hal_executable_library_query_fn_t libraries[] = { - mnist_linked_llvm_cpu_library_query, + mnist_linked_library_query, }; iree_hal_executable_loader_t* library_loader = NULL; iree_status_t status = iree_hal_static_library_loader_create( diff --git a/experimental/web/sample_static/device_sync.c b/experimental/web/sample_static/device_sync.c index 3fbe3eed0bf6..f072903b963f 100644 --- a/experimental/web/sample_static/device_sync.c +++ b/experimental/web/sample_static/device_sync.c @@ -15,7 +15,7 @@ iree_status_t create_device_with_static_loader(iree_allocator_t host_allocator, // Register the statically linked executable library. const iree_hal_executable_library_query_fn_t libraries[] = { - mnist_linked_llvm_cpu_library_query, + mnist_linked_library_query, }; iree_hal_executable_loader_t* library_loader = NULL; iree_status_t status = iree_hal_static_library_loader_create( diff --git a/tests/e2e/stablehlo_models/CMakeLists.txt b/tests/e2e/stablehlo_models/CMakeLists.txt index f12f2fa970f2..896a852e4640 100644 --- a/tests/e2e/stablehlo_models/CMakeLists.txt +++ b/tests/e2e/stablehlo_models/CMakeLists.txt @@ -42,7 +42,7 @@ iree_static_linker_test( SRC "mnist_fake_weights.mlir" STATIC_LIB_PREFIX - mnist_fake_weights_linked_llvm_cpu + mnist_fake_weights_linked ENTRY_FUNCTION "predict" FUNCTION_INPUTS @@ -57,7 +57,7 @@ iree_static_linker_test( SRC "mnist_fake_weights.mlir" STATIC_LIB_PREFIX - mnist_fake_weights_linked_llvm_cpu + mnist_fake_weights_linked ENTRY_FUNCTION "predict" FUNCTION_INPUTS From d1dd3e377e1e5835f8537d0c4052781a833e12e3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 29 Oct 2024 19:17:29 -0700 Subject: [PATCH 38/45] Add integer range inference to hal.buffer_view.dim and rank ops. (#18943) This matches that default range behavior of runtime dimensions we get from frontends. --------- Signed-off-by: Stella Laurenzo --- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 45 +++++++++++++++++++ .../src/iree/compiler/Dialect/HAL/IR/HALOps.h | 1 + .../iree/compiler/Dialect/HAL/IR/HALOps.td | 11 ++++- .../test/optimize_int_arithmetic.mlir | 30 +++++++++++++ 4 files changed, 85 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index f1a820fb245b..81f8da846eb1 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -22,6 +22,27 @@ namespace mlir::iree_compiler::IREE::HAL { +namespace { + +// We aribtrarily say that unbounded dimensions in a torch program cannot +// exceed 53bits, making the maximum safe dimension 9007199254740991. The +// astute reader will note that this is also the maximum safe value in +// JavaScript, which also "happens" to be the largest mantissa value in a +// 64bit double. We need a maximum and in the absence of a better choice, +// with this one we are at least in good company. This limit is also used +// in the frontends. +static constexpr uint64_t MAX_DIM_VALUE = (static_cast(1) << 53) - 1; + +// Similarly we use a very conservative maximum rank value for specifying +// ranges of runtime rank resolution functions. Various frameworks have hard +// and practical limits ranging from 32 (numpy) to hundreds. At the time of +// writing, PyTorch throws weird errors if trying to print a tensor with a rank +// greater than 992. We really just want a smallish integer value to bound +// arithmetic, so we use an arbitrary maximum. 
+static constexpr uint64_t MAX_RANK_VALUE = 4096; + +} // namespace + //===----------------------------------------------------------------------===// // custom($descriptor_type) //===----------------------------------------------------------------------===// @@ -1024,6 +1045,30 @@ void BufferViewBufferOp::getAsmResultNames( setNameFn(getResult(), "buffer"); } +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewDimOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_DIM_VALUE)))); +} + +//===----------------------------------------------------------------------===// +// hal.buffer_view.dim +//===----------------------------------------------------------------------===// + +void BufferViewRankOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + const unsigned indexTypeNumBits = 64; + setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned( + APInt::getZero(indexTypeNumBits), + APInt(indexTypeNumBits, MAX_RANK_VALUE)))); +} + //===----------------------------------------------------------------------===// // hal.channel.create //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h index ae58127959bb..16dd46bc5e17 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h @@ -20,6 +20,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index fdd43b7a5e72..9e370a10c22b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -18,6 +18,7 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1010,7 +1011,10 @@ def HAL_BufferViewEncodingTypeOp : HAL_PureOp<"buffer_view.encoding_type"> { }]; } -def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { +def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view rank query}]; let description = [{ Returns the rank of the buffer view. @@ -1030,7 +1034,10 @@ def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { }]; } -def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> { +def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [ + DeclareOpInterfaceMethods, +]> { let summary = [{buffer view dimension value query}]; let description = [{ Returns the value of the given dimension. 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir index 1924f423ef66..f78817cb03af 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir @@ -493,3 +493,33 @@ util.func @util_align_zero(%arg0 : i64) -> i64 { %rem16 = arith.remui %0, %c16 : i64 util.return %rem16 : i64 } + +// ----- + +util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 9007199254740991 : index + %0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +} + +// ----- + +util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) { + %zero = arith.constant 0 : index + %max = arith.constant 4096 : index + %0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index + %1 = arith.cmpi slt, %0, %zero : index + %2 = arith.cmpi uge, %0, %zero : index + %3 = arith.cmpi ugt, %0, %max : index + // CHECK-DAG: %[[FALSE:.*]] = arith.constant false + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]] + util.return %1, %2, %3 : i1, i1, i1 +} From 5fc340d77bd4f07f4a99f43d9354a9ead62ae1e5 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 29 Oct 2024 22:06:53 -0700 Subject: [PATCH 39/45] Registering the ROCDL dialect in init_mlir_dialects. (#18944) We're emitting those ops now (directly or indirectly) and in order to parse IR that contains them (from e.g. resuming --compile-to=executable-targets) we must have all dialects registered. --- compiler/src/iree/compiler/Tools/BUILD.bazel | 1 + compiler/src/iree/compiler/Tools/CMakeLists.txt | 2 ++ compiler/src/iree/compiler/Tools/init_mlir_dialects.h | 2 ++ 3 files changed, 5 insertions(+) diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 7c717d641ebf..7a813c2a51cd 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -114,6 +114,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToGPU", "@llvm-project//mlir:SCFTransforms", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index d9033ba99cee..8ad934245c3f 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Doesn't use bazel_to_cmake because of various special logic throughout. +# That there's various special logic throughout is _bad_. Don't replicate this. # Enable compiler targets based on options. 
set(IREE_COMPILER_TARGETS "") @@ -95,6 +96,7 @@ iree_cc_library( MLIRLinalgTransforms MLIRMLProgramDialect MLIRQuantDialect + MLIRROCDLDialect MLIRSCFDialect MLIRSCFToGPU MLIRSCFTransforms diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index e399e63dffe8..ce2b67134989 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -29,6 +29,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" @@ -82,6 +83,7 @@ inline void registerMlirDialects(DialectRegistry ®istry) { pdl_interp::PDLInterpDialect, scf::SCFDialect, quant::QuantDialect, + ROCDL::ROCDLDialect, spirv::SPIRVDialect, arm_neon::ArmNeonDialect, arm_sve::ArmSVEDialect, From 15ea0dc4a04b3ce792498ccd7183c8ce15b56cfa Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Wed, 30 Oct 2024 08:12:16 +0000 Subject: [PATCH 40/45] [GlobalOpt] Prevent fusing transposed extend in RaiseSpecialOps (#18901) NamedImplicitCastOpConversion pattern is incorrectly fusing transposed element-wise extend into Linalg op. --------- Signed-off-by: Cullen Rhodes --- .../GlobalOptimization/RaiseSpecialOps.cpp | 4 +++ .../test/raise_special_ops.mlir | 27 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp index b7008690b6ab..a1b579e8ba3e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp @@ -304,6 +304,10 @@ class NamedImplicitCastOpConversion : public OpInterfaceRewritePattern { return false; } + if (!llvm::all_of(producer.getIndexingMapsArray(), + [](AffineMap map) { return map.isIdentity(); })) + return false; + std::optional castOp = getDefiningNonI1ExtendingCastOp(operand.get()); if (!castOp) { diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir index a1cd2d63216e..c84f128ed15e 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir @@ -566,6 +566,33 @@ util.func public @matmul_extsi(%arg0 : tensor<10x20xi32>, // CHECK: util.return %[[RESULT]] // ----- +// Regression test. extsi is transposed, dont't fuse into matmul. 
+util.func public @matmul_extsi_transposed(%arg0 : tensor<10x20xi32>, + %arg1 : tensor<40x20xi16>) -> tensor<10x40xi32> { + %0 = tensor.empty() : tensor<20x40xi32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : tensor<40x20xi16>) outs(%0 : tensor<20x40xi32>) { + ^bb0(%b0 : i16, %b1 : i32): + %e = arith.extsi %b0 : i16 to i32 + linalg.yield %e : i32 + } -> tensor<20x40xi32> + %2 = tensor.empty() : tensor<10x40xi32> + %3 = arith.constant 0 : i32 + %4 = linalg.fill ins(%3 : i32) outs(%2 : tensor<10x40xi32>) -> tensor<10x40xi32> + %5 = linalg.matmul ins(%arg0, %1 : tensor<10x20xi32>, tensor<20x40xi32>) + outs(%4 : tensor<10x40xi32>) -> tensor<10x40xi32> + util.return %5 : tensor<10x40xi32> +} +// CHECK-LABEL: util.func public @matmul_extsi_transposed +// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xi16> +// CHECK: %[[GEN:.+]] = linalg.generic +// CHECK: %[[RESULT:.+]] = linalg.matmul ins(%[[ARG0]], %[[GEN]] +// CHECK: util.return %[[RESULT]] +// ----- + util.func public @matmul_extsi_a(%arg0 : tensor<10x20xi16>, %arg1 : tensor<20x40xi32>) -> tensor<10x40xi32> { %0 = tensor.empty() : tensor<10x20xi32> From 26ba4fd1717af8f9cdd2552044d25633ce180a11 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 30 Oct 2024 09:50:25 -0700 Subject: [PATCH 41/45] Switching VM's EraseUnusedCallOp pattern to a pass. (#18950) This lets it use a symbol table to speed things up _a lot_. Includes various fixes found during testing. --- .../iree/compiler/API/Internal/BUILD.bazel | 2 + .../iree/compiler/API/Internal/CMakeLists.txt | 2 + .../API/Internal/IREEOptToolEntryPoint.cpp | 4 + .../compiler/Dialect/VM/IR/VMOpFolders.cpp | 46 +------- .../VM/IR/test/control_flow_folding.mlir | 13 +-- .../Dialect/VM/Transforms/BUILD.bazel | 1 + .../Dialect/VM/Transforms/CMakeLists.txt | 1 + .../Dialect/VM/Transforms/DropUnusedCalls.cpp | 104 ++++++++++++++++++ .../compiler/Dialect/VM/Transforms/Passes.cpp | 5 + .../compiler/Dialect/VM/Transforms/Passes.h | 4 + .../Dialect/VM/Transforms/test/BUILD.bazel | 1 + .../Dialect/VM/Transforms/test/CMakeLists.txt | 1 + .../VM/Transforms/test/drop_unused_calls.mlir | 36 ++++++ 13 files changed, 167 insertions(+), 53 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp create mode 100644 compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel index 2413bed54150..883f7fb525b0 100644 --- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel +++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel @@ -78,7 +78,9 @@ iree_compiler_cc_library( deps = [ "//compiler/bindings/c:headers", "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets", "//compiler/src/iree/compiler/PluginAPI:PluginManager", + "//compiler/src/iree/compiler/Tools:init_llvmir_translations", "//compiler/src/iree/compiler/Tools:init_passes_and_dialects", "@llvm-project//llvm:Support", "@llvm-project//mlir:Debug", diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt index 191ea93a1cbe..3ee76d97d89f 100644 --- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt @@ -80,7 +80,9 @@ 
iree_cc_library( MLIRPass MLIRSupport iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::VM::Target::init_targets iree::compiler::PluginAPI::PluginManager + iree::compiler::Tools::init_llvmir_translations iree::compiler::Tools::init_passes_and_dialects iree::compiler::bindings::c::headers PUBLIC diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp index b621724ae038..4d3b07c268bd 100644 --- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp +++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp @@ -9,8 +9,10 @@ // Based on mlir-opt but registers the passes and dialects we care about. #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Dialect/VM/Target/init_targets.h" #include "iree/compiler/PluginAPI/PluginManager.h" #include "iree/compiler/Tools/init_dialects.h" +#include "iree/compiler/Tools/init_llvmir_translations.h" #include "iree/compiler/Tools/init_passes.h" #include "iree/compiler/tool_entry_points_api.h" #include "llvm/Support/InitLLVM.h" @@ -145,6 +147,8 @@ int ireeOptRunMain(int argc, char **argv) { mlir::DialectRegistry registry; mlir::iree_compiler::registerAllDialects(registry); mlir::iree_compiler::registerAllPasses(); + mlir::iree_compiler::registerVMTargets(); + mlir::iree_compiler::registerLLVMIRTranslations(registry); // Register the pass to drop embedded transform dialect IR. // TODO: this should be upstreamed. diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index bac70cdca819..bec93d4c43b5 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -3142,49 +3142,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, SwapInvertedCondBranchOpTargets>(context); } -namespace { - -/// Removes vm.call ops to functions that are marked as having no side-effects -/// if the results are unused. -template -struct EraseUnusedCallOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(T op, - PatternRewriter &rewriter) const override { - // First check if the call is unused - this ensures we only do the symbol - // lookup if we are actually going to use it. - for (auto result : op.getResults()) { - if (!result.use_empty()) { - return failure(); - } - } - - auto *calleeOp = SymbolTable::lookupSymbolIn( - op->template getParentOfType(), op.getCallee()); - - bool hasNoSideEffects = false; - if (calleeOp->getAttr("nosideeffects")) { - hasNoSideEffects = true; - } else if (auto import = dyn_cast(calleeOp)) { - hasNoSideEffects = !import.hasSideEffects(); - } - if (!hasNoSideEffects) { - // Op has side-effects (or may have them); can't remove. - return failure(); - } - - // Erase op as it is unused. 
- rewriter.eraseOp(op); - return success(); - } -}; - -} // namespace - void CallOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert>(context); -} + MLIRContext *context) {} namespace { @@ -3210,8 +3169,7 @@ struct ConvertNonVariadicToCallOp : public OpRewritePattern { void CallVariadicOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert, ConvertNonVariadicToCallOp>( - context); + results.insert(context); } namespace { diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir index 17f2a8997b47..82d16d3fb800 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir @@ -55,17 +55,12 @@ vm.module @cond_br_folds { // ^bb2(%1 : i32): // vm.return %1 : i32 // } +} - // CHECK-LABEL: @erase_unused_pure_call - vm.func @erase_unused_pure_call(%arg0 : i32) { - %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32 - %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32 - // CHECK-NEXT: vm.return - vm.return - } - vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects} - vm.import private @variadic_pure_func(%arg0 : i32 ...) -> i32 attributes {nosideeffects} +// ----- +// CHECK-LABEL: @call_folds +vm.module @call_folds { // CHECK-LABEL: @convert_nonvariadic_to_call vm.func @convert_nonvariadic_to_call(%arg0 : i32) -> (i32, i32) { // CHECK-NEXT: vm.call @nonvariadic_func(%arg0) : (i32) -> i32 diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel index bf1131072731..381e4528bd59 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel @@ -18,6 +18,7 @@ iree_compiler_cc_library( "Conversion.cpp", "DeduplicateRodata.cpp", "DropEmptyModuleInitializers.cpp", + "DropUnusedCalls.cpp", "GlobalInitialization.cpp", "HoistInlinedRodata.cpp", "OrdinalAllocation.cpp", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt index 71c5aa44dc51..c3c89682bb0f 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ iree_cc_library( "Conversion.cpp" "DeduplicateRodata.cpp" "DropEmptyModuleInitializers.cpp" + "DropUnusedCalls.cpp" "GlobalInitialization.cpp" "HoistInlinedRodata.cpp" "OrdinalAllocation.cpp" diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp new file mode 100644 index 000000000000..9690bca8b5ad --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp @@ -0,0 +1,104 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/VM/IR/VMOps.h" +#include "iree/compiler/Dialect/VM/Transforms/Passes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::IREE::VM { + +namespace { + +/// Removes vm.call ops to functions that are marked as having no side-effects +/// if the results are unused. +template +struct EraseUnusedCallOp : public OpRewritePattern { + DenseSet &noSideEffectsSymbols; + EraseUnusedCallOp(MLIRContext *context, + DenseSet &noSideEffectsSymbols, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + noSideEffectsSymbols(noSideEffectsSymbols) {} + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + // First check if the call is unused - this ensures we only do the symbol + // lookup if we are actually going to use it. + for (auto result : op.getResults()) { + if (!result.use_empty()) { + return failure(); + } + } + + // Check that + bool hasNoSideEffects = noSideEffectsSymbols.contains(op.getCallee()); + if (!hasNoSideEffects) { + // Op has side-effects (or may have them); can't remove. + return failure(); + } + + // Erase op as it is unused. + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +class DropUnusedCallsPass + : public PassWrapper> { +public: + StringRef getArgument() const override { return "iree-vm-drop-unused-calls"; } + + StringRef getDescription() const override { + return "Drops vm.call ops that have no side effects and are unused."; + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + SymbolTable symbolTable(moduleOp); + + // Find all top-level symbols that have no side effects. + DenseSet noSideEffectsSymbols; + for (auto symbolOp : moduleOp.getOps()) { + if (symbolOp->getAttr("nosideeffects")) { + noSideEffectsSymbols.insert(symbolOp.getName()); + } else if (auto importOp = + dyn_cast(symbolOp.getOperation())) { + if (!importOp.hasSideEffects()) { + noSideEffectsSymbols.insert(symbolOp.getName()); + } + } + } + + // Remove all unused calls. + // Note that we want to remove entire chains of unused calls and run this + // as a pattern application. + RewritePatternSet patterns(&getContext()); + // patterns + patterns.insert, + EraseUnusedCallOp>( + &getContext(), noSideEffectsSymbols); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +std::unique_ptr> createDropUnusedCallsPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::iree_compiler::IREE::VM diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp index 05852aa07fe2..665d4f09cf3b 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp @@ -38,6 +38,11 @@ static void addCleanupPatterns(OpPassManager &passManager) { passManager.addPass(mlir::createCanonicalizerPass()); passManager.addPass(mlir::createCSEPass()); + // Aggressive MLIR cleanup. 
+ passManager.addNestedPass( + IREE::VM::createDropUnusedCallsPass()); + passManager.addPass(mlir::createSymbolDCEPass()); + // Simplify util.global accesses; this can help with data flow tracking as // redundant store-loads are removed. FunctionLikeNest(passManager) diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h index 9b569e657990..88c0ade8a8cc 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h @@ -86,6 +86,9 @@ createOrdinalAllocationPass(); std::unique_ptr> createDropEmptyModuleInitializersPass(); +// Drops unused calls to functions marked as having no side effects. +std::unique_ptr> createDropUnusedCallsPass(); + // Sinks defining ops with few uses to their use-sites to reduce the total // number of live registers at the cost of additional storage requirements. std::unique_ptr> createSinkDefiningOpsPass(); @@ -101,6 +104,7 @@ inline void registerVMPasses() { createHoistInlinedRodataPass(); createDeduplicateRodataPass(); createDropEmptyModuleInitializersPass(); + createDropUnusedCallsPass(); createGlobalInitializationPass(); createOrdinalAllocationPass(); createResolveRodataLoadsPass(); diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel index 6f0bdd9a23ee..937957d21fbf 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( [ "deduplicate_rodata.mlir", "drop_empty_module_initializers.mlir", + "drop_unused_calls.mlir", "global_initialization.mlir", "hoist_inlined_rodata.mlir", "ordinal_allocation.mlir", diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt index 7e2336ffe684..e854e956756c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "deduplicate_rodata.mlir" "drop_empty_module_initializers.mlir" + "drop_unused_calls.mlir" "global_initialization.mlir" "hoist_inlined_rodata.mlir" "ordinal_allocation.mlir" diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir new file mode 100644 index 000000000000..bbd8b0b14211 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir @@ -0,0 +1,36 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(vm.module(iree-vm-drop-unused-calls))" %s | FileCheck %s + +// CHECK-LABEL: @drop_calls +vm.module public @drop_calls { + // CHECK: vm.func @fn + vm.func @fn(%arg0 : i32) { + // CHECK-NOT: vm.call @nonvariadic_pure_func + %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32 + // CHECK-NOT: vm.call.variadic @variadic_pure_func + %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32 + // CHECK-NEXT: vm.return + vm.return + } + vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects} + vm.import private @variadic_pure_func(%arg0 : i32 ...) 
-> i32 attributes {nosideeffects} +} + +// ----- + +// CHECK-LABEL: @drop_call_trees +vm.module public @drop_call_trees { + // CHECK: vm.func @fn + vm.func @fn(%arg0 : i32) { + // CHECK: vm.call @impure_func + %0 = vm.call @impure_func(%arg0) : (i32) -> i32 + // CHECK-NOT: vm.call @pure_func_a + %1 = vm.call @pure_func_a(%0) : (i32) -> i32 + // CHECK-NOT: vm.call @pure_func_b + %2 = vm.call @pure_func_b(%1) : (i32) -> i32 + // CHECK-NEXT: vm.return + vm.return + } + vm.import private @impure_func(%arg0 : i32) -> i32 + vm.import private @pure_func_a(%arg0 : i32) -> i32 attributes {nosideeffects} + vm.import private @pure_func_b(%arg0 : i32) -> i32 attributes {nosideeffects} +} From 0bb6d92496c252c51f057fddd8e01826f2a1f241 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Wed, 30 Oct 2024 10:11:42 -0700 Subject: [PATCH 42/45] Add `ReifyRankedShapedTypeOpInterface` to `hal.interface.binding.subspan` (#18946) Fixes #18942 Signed-off-by: MaheshRavishankar --- .../iree/compiler/Dialect/HAL/IR/BUILD.bazel | 2 ++ .../compiler/Dialect/HAL/IR/CMakeLists.txt | 1 + .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 13 +++++++++++++ .../iree/compiler/Dialect/HAL/IR/HALOps.td | 14 ++------------ .../Dialect/HAL/Transforms/test/BUILD.bazel | 1 + .../HAL/Transforms/test/CMakeLists.txt | 1 + .../test/resolve_ranked_shaped_type.mlir | 19 +++++++++++++++++++ 7 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_ranked_shaped_type.mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel index 6841565357a6..d9d6a92ef71c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel @@ -35,6 +35,7 @@ iree_td_library( "//compiler/src/iree/compiler/Dialect/Util/IR:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", ], @@ -80,6 +81,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Parser", "@llvm-project//mlir:SCFDialect", diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 16b490ea3fc4..837855157e90 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -45,6 +45,7 @@ iree_cc_library( MLIRFuncDialect MLIRFunctionInterfaces MLIRIR + MLIRInferTypeOpInterface MLIRMemRefDialect MLIRParser MLIRSCFDialect diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 81f8da846eb1..9f77e8e7f156 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" namespace mlir::iree_compiler::IREE::HAL { @@ -2039,6 +2040,18 @@ llvm::Align InterfaceBindingSubspanOp::calculateAlignment() { offsetOrAlignment.value()); } +LogicalResult 
InterfaceBindingSubspanOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + auto resultShapedType = dyn_cast(getResult().getType()); + if (!resultShapedType) { + return failure(); + } + SmallVector resultShape = mlir::getMixedValues( + resultShapedType.getShape(), getDynamicDims(), builder); + reifiedReturnShapes.emplace_back(std::move(resultShape)); + return success(); +} + //===----------------------------------------------------------------------===// // hal.interface.workgroup.* //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 9e370a10c22b..e3f05fcd13ad 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -19,6 +19,7 @@ include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -811,12 +812,6 @@ def HAL_ElementTypeOp : HAL_PureOp<"element_type", [ `:` type($result) }]; - let builders = [ - OpBuilder<(ins "Type":$type), [{ - build($_builder, $_state, $_builder.getI32Type(), TypeAttr::get(type)); - }]> - ]; - let extraClassDeclaration = [{ // Returns a stable identifier for the MLIR element type or nullopt if the // type is unsupported in the ABI. @@ -848,12 +843,6 @@ def HAL_EncodingTypeOp : HAL_PureOp<"encoding_type", [ `:` type($result) }]; - let builders = [ - OpBuilder<(ins "Attribute":$encoding), [{ - build($_builder, $_state, $_builder.getI32Type(), encoding); - }]> - ]; - let extraClassDeclaration = [{ // Returns a stable identifier for the MLIR encoding type or 0 (opaque) if // the type is unsupported in the ABI. 
@@ -3051,6 +3040,7 @@ def HAL_InterfaceConstantLoadOp : HAL_PureOp<"interface.constant.load"> { def HAL_InterfaceBindingSubspanOp : HAL_PureOp<"interface.binding.subspan", [ AttrSizedOperandSegments, + DeclareOpInterfaceMethods, Util_ShapeAwareOp, ]> { let summary = [{returns an alias to a subspan of interface binding data}]; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel index 97fbaff18c69..b949fda38b55 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel @@ -38,6 +38,7 @@ iree_lit_test_suite( "resolve_device_aliases.mlir", "resolve_device_promises.mlir", "resolve_export_ordinals.mlir", + "resolve_ranked_shaped_type.mlir", "strip_executable_contents.mlir", "substitute_executables.mlir", "verify_devices.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index def07d67293c..1b9d35ec747c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -36,6 +36,7 @@ iree_lit_test_suite( "resolve_device_aliases.mlir" "resolve_device_promises.mlir" "resolve_export_ordinals.mlir" + "resolve_ranked_shaped_type.mlir" "strip_executable_contents.mlir" "substitute_executables.mlir" "verify_devices.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_ranked_shaped_type.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_ranked_shaped_type.mlir new file mode 100644 index 000000000000..4037dad52451 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/resolve_ranked_shaped_type.mlir @@ -0,0 +1,19 @@ +// RUN: iree-opt -resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s + +util.func public @hal_interface_binding_subspan_op(%arg0 : index, %arg1 : index) -> (index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = hal.interface.binding.subspan layout(< + constants = 0, bindings = [#hal.pipeline.binding], flags = Indirect>) + binding(0) : memref<64x?x?xf16>{%arg0, %arg1} + %d0 = memref.dim %0, %c0 : memref<64x?x?xf16> + %d1 = memref.dim %0, %c1 : memref<64x?x?xf16> + %d2 = memref.dim %0, %c2 : memref<64x?x?xf16> + util.return %d0, %d1, %d2 : index, index, index +} +// CHECK-LABEL: func public @hal_interface_binding_subspan_op( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[C64:.+]] = arith.constant 64 : index +// CHECK: return %[[C64]], %[[ARG0]], %[[ARG1]] From 1f76cb77e7cd03c336cc6ad25314c5e6ffe2bf05 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Wed, 30 Oct 2024 13:35:02 -0400 Subject: [PATCH 43/45] GPU data tiling: reimplement getConcreteMFMALayout (#18953) This `getConcreteMFMALayout` function had its own independent database of layout details for MMA intrinsics. Instead, reimplement it using the tile swizzle infra (which gets its information about layout of MMA intrinsics from `getSingleSubgroupLayout`). Since this function is exercised by lit tests covering RDNA3 WMMA intrinsics with zero strides, this forced solving the issue about that in GPUTileSwizzleUtils.cpp, a blocker for GPU data tiling on RDNA3. 
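For illustration only (not part of this patch): the swizzle-based reimplementation depends on
the convention that a tstride of 0 marks an outer dimension, so zero strides are rewritten to a
large sentinel before the sorting permutation is computed. The standalone C++ sketch below
demonstrates just that ordering trick under assumed semantics (descending-stride order,
hypothetical stride values); it is not the actual GPUTileSwizzleUtils.cpp helper.

    // Hedged sketch: shows how tstrides == 0 (outer dims) order first once they
    // are rewritten to a large sentinel value. Stride values are hypothetical.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    static std::vector<int64_t>
    getSortingPermutation(const std::vector<int64_t> &v) {
      std::vector<int64_t> perm(v.size());
      std::iota(perm.begin(), perm.end(), 0);
      // Assumed semantics: indices ordered from largest stride (outermost)
      // down to smallest (innermost).
      std::stable_sort(perm.begin(), perm.end(),
                       [&](int64_t a, int64_t b) { return v[a] > v[b]; });
      return perm;
    }

    int main() {
      std::vector<int64_t> tstrides = {1, 0, 16}; // 0 == outer dimension.
      for (auto &s : tstrides)
        s = (s == 0) ? INT64_MAX : s;
      for (int64_t i : getSortingPermutation(tstrides))
        std::printf("%lld ", static_cast<long long>(i)); // prints: 1 2 0
      return 0;
    }
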
Signed-off-by: Benoit Jacob --- .../Dialect/GPU/IR/GPUTileSwizzleUtils.cpp | 14 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 267 +++++------------- .../test/amdgpu_set_anchor_layouts.mlir | 2 +- 3 files changed, 84 insertions(+), 199 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp index 7ef46a6c0d9a..ae9d5d9b6188 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp @@ -105,15 +105,13 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic, swizzle.expandShape.push_back({dim}); } // The layout strides decide the initial swizzle.permutation. - // Some WMMA intrinsics have tstrides=0 values, assert on that as that - // would defeat this algorithm. - // TODO(bjacob): Resolve that to support WMMA intrinsics. - for (auto s : layout.tstrides) { - (void)s; - assert(s != 0); + // Some WMMA intrinsics have tstrides=0 value. That always indicates an outer + // dimension, so overwrite 0 with a large value to get the right order. + SmallVector order = layout.tstrides; + for (auto &val : order) { + val = (val == 0) ? INT64_MAX : val; } - swizzle.permutation = - getSortingPermutation(layout.tstrides); + swizzle.permutation = getSortingPermutation(order); // Deal with any element size greater than 1 by inserting it innermost. // Notice that this is similar to the unroll() function, just creating an // inner dimension instead of an outer dimension. diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 41c099f12809..93a2ca762b51 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -268,185 +268,72 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, return OpaqueMmaLayout{}; } -static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, - MMAIntrinsic type) { - auto opaqueLayout = getOpaqueMFMALayout(context, type); - - LayoutDimensionAttr laneX = - LayoutDimensionAttr::get(context, LayoutDimension::LANEX); - LayoutDimensionAttr laneY = - LayoutDimensionAttr::get(context, LayoutDimension::LANEY); - LayoutDimensionAttr laneZ = - LayoutDimensionAttr::get(context, LayoutDimension::LANEZ); - LayoutDimensionAttr vectorX = - LayoutDimensionAttr::get(context, LayoutDimension::VECTORX); - LayoutDimensionAttr vectorY = - LayoutDimensionAttr::get(context, LayoutDimension::VECTORY); - LayoutDimensionAttr vectorZ = - LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ); - (void)laneZ, (void)vectorZ; - switch (type) { - case MMAIntrinsic::MFMA_F32_16x16x4_F32: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - // #layout_c = #iree_vector_ext.layout<#inner, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, 
aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_F32_16x16x16_F16: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - // #layout_c = #iree_vector_ext.layout<#inner, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = inner; - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_F32_32x32x8_F16: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]> - // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]> - // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX], - // [4, 2, 4]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner1> - // #layout_b = #iree_vector_ext.layout<#inner1, #outer> - // #layout_c = #iree_vector_ext.layout<#inner2, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {32}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = - PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4}); - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_F32_16x16x16_BF16: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - // #layout_c = #iree_vector_ext.layout<#inner, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = inner; - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_F32_32x32x8_BF16: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]> - // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]> - // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX], - // [4, 2, 4]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner1> - // #layout_b = #iree_vector_ext.layout<#inner1, #outer> - // #layout_c = #iree_vector_ext.layout<#inner2, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {32}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 4}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = - PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4}); - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: - case MMAIntrinsic::MFMA_I32_16x16x32_I8: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = 
#iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 8}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - case MMAIntrinsic::MFMA_I32_32x32x16_I8: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {32}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {2, 8}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = - PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, {4, 2, 4}); - auto cNLayout = outer; - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; +static std::tuple +getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) { + // Step 1: obtain the swizzled tile shape, but keeping track of the source + // dimension indices. + struct SrcIndexAndSwizzleDim { + size_t srcIndex; + TileSwizzle::Dim dim; + }; + SmallVector swizzledShape; + for (auto [i, e] : llvm::enumerate(swizzle.expandShape)) { + for (TileSwizzle::Dim d : e) { + swizzledShape.push_back(SrcIndexAndSwizzleDim{i, d}); + } } - case MMAIntrinsic::WMMA_F32_16x16x16_F16: - case MMAIntrinsic::WMMA_F16_16x16x16_F16: - case MMAIntrinsic::WMMA_I32_16x16x16_I8: { - // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> - // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [1, 16]> - // #layout_a = #iree_vector_ext.layout<#outer, #inner> - // #layout_b = #iree_vector_ext.layout<#inner, #outer> - - int64_t vecYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 16 : 8; - int64_t laneYShape = type == MMAIntrinsic::WMMA_F16_16x16x16_F16 ? 1 : 2; - - auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); - auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {1, 16}); - auto aMLayout = outer; - auto aKLayout = inner; - auto bKLayout = inner; - auto bNLayout = outer; - auto cMLayout = PerDimLayoutAttr::get(context, {vectorY, laneY, vectorX}, - {vecYShape, laneYShape, 1}); - auto cNLayout = PerDimLayoutAttr::get(context, {laneX}, {16}); - return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, - bNLayout, cMLayout, cNLayout}; - } - default: { - break; + applyPermutationToVector(swizzledShape, swizzle.permutation); + + // Step 2: collect the appropriate labels to use for the swizzled dims. 
+ LayoutDimension internalLabels[] = {LayoutDimension::VECTORZ, + LayoutDimension::VECTORY, + LayoutDimension::VECTORX}; + LayoutDimension crossThreadLabels[] = { + LayoutDimension::LANEZ, LayoutDimension::LANEY, LayoutDimension::LANEX}; + auto internalLabelIter = std::end(internalLabels); + auto crossThreadLabelIter = std::end(crossThreadLabels); + for (SrcIndexAndSwizzleDim d : swizzledShape) { + if (d.dim.kind == TileSwizzle::Dim::Kind::Internal) { + assert(internalLabelIter != std::begin(internalLabels)); + --internalLabelIter; + } else if (d.dim.kind == TileSwizzle::Dim::Kind::CrossThread) { + assert(crossThreadLabelIter != std::begin(crossThreadLabels)); + --crossThreadLabelIter; + } else { + assert(false && "unexpected dimension kind in intrinsic swizzle"); + } } + + // Step 3: put together the result PerDimLayoutAttr'd for the two source dims. + SmallVector labels[2]; + SmallVector shape[2]; + for (SrcIndexAndSwizzleDim d : swizzledShape) { + shape[d.srcIndex].push_back(d.dim.size); + auto &labelIterRef = (d.dim.kind == TileSwizzle::Dim::Kind::Internal) + ? internalLabelIter + : crossThreadLabelIter; + labels[d.srcIndex].push_back(LayoutDimensionAttr::get( + context, static_cast(*labelIterRef++))); } - llvm_unreachable("unhandled concrete mma type"); - return ConcreteMmaLayout{}; + return {PerDimLayoutAttr::get(context, labels[0], shape[0]), + PerDimLayoutAttr::get(context, labels[1], shape[1])}; +}; + +static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, + MMAIntrinsic intrinsic) { + auto opaque = getOpaqueMFMALayout(context, intrinsic); + ConcreteMmaLayout concreteLayout; + concreteLayout.base = opaque; + auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs); + auto rhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Rhs); + auto accSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Acc); + std::tie(concreteLayout.aMLayout, concreteLayout.aKLayout) = + getPerDimLayoutAttrs(context, lhsSwizzle); + std::tie(concreteLayout.bNLayout, concreteLayout.bKLayout) = + getPerDimLayoutAttrs(context, rhsSwizzle); + std::tie(concreteLayout.cMLayout, concreteLayout.cNLayout) = + getPerDimLayoutAttrs(context, accSwizzle); + return concreteLayout; } //===----------------------------------------------------------------------===// @@ -960,20 +847,6 @@ LogicalResult MMAAttr::materializeOperandConcreteShape( // DataTiledMMA Attributes //===----------------------------------------------------------------------===// -std::tuple DataTiledMMAAttr::getABCElementTypes() const { - MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); - return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; -} - -std::tuple DataTiledMMAAttr::getMNKShape() const { - MLIRContext *ctx = getContext(); - auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); - return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(), - opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(), - opaqueLayout.kSize * getUnrollK()}; -} - /// Returns the swizzled tile shape, but with dim sizes overwritten with 1 if /// `predicate` returns false. 
static SmallVector @@ -989,6 +862,20 @@ sliceSwizzledShape(const TileSwizzle &swizzle, return shape; } +std::tuple DataTiledMMAAttr::getABCElementTypes() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType}; +} + +std::tuple DataTiledMMAAttr::getMNKShape() const { + MLIRContext *ctx = getContext(); + auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue()); + return {opaqueLayout.mSize * getUnrollM() * getUnrollMToSubgroups(), + opaqueLayout.nSize * getUnrollN() * getUnrollNToSubgroups(), + opaqueLayout.kSize * getUnrollK()}; +} + std::tuple DataTiledMMAAttr::getABCVectorTypes() const { auto [A, B, C] = getABCElementTypes(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir index 972143537dec..da7a3b338ac1 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir @@ -80,7 +80,7 @@ builtin.module attributes { transform.with_named_sequence } { %rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}} %output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32> - // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORY, LANEY, VECTORX], [1, 8, 2, 1]>, <[ BATCHY, LANEX], [1, 16]>>}} + // expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORX, LANEY], [1, 8, 2]>, <[ BATCHY, LANEX], [1, 16]>>}} return %output : vector<16x16xf32> } transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { From 554f31f6cd1520ec0691fd7a01efa64bd873e71f Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 30 Oct 2024 10:51:35 -0700 Subject: [PATCH 44/45] Adding a flag to force indirect command buffers on in non-reusable cases. (#18945) Includes various fixes found during testing. --- compiler/plugins/target/ROCM/ROCMTarget.cpp | 14 ++++++++++++-- .../compiler/API/Internal/CompilerDriver.cpp | 8 ++++++-- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 13 ++++++++++++- .../iree/compiler/Dialect/HAL/IR/HALBase.td | 2 -- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 2 +- .../HAL/Transforms/OutlineMemoizeRegions.cpp | 19 +++++++++++++++---- .../Dialect/VM/Transforms/DropUnusedCalls.cpp | 1 - tests/compiler_driver/streams.mlir | 8 ++++---- 8 files changed, 50 insertions(+), 17 deletions(-) diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp index 05ab66779271..308c88097cd7 100644 --- a/compiler/plugins/target/ROCM/ROCMTarget.cpp +++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp @@ -160,6 +160,17 @@ struct ROCMOptions { } }; +// Returns the ABI or an empty string if unspecified. +static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) { + if (targetAttr) { + if (auto config = targetAttr.getConfiguration()) { + auto abiAttr = targetAttr.getConfiguration().getAs("abi"); + return abiAttr ? 
+    }
+  }
+  return "";
+}
+
 static void dumpModuleToPath(StringRef path, StringRef baseName,
                              StringRef suffix, StringRef extension,
                              llvm::Module &module) {
@@ -585,8 +596,7 @@ class ROCMTargetBackend final : public TargetBackend {
 
     // Wrap the HSACO ELF binary in a Flatbuffers container.
     FailureOr<DenseIntElementsAttr> binaryContainer;
-    if (targetAttr.getConfiguration() &&
-        targetAttr.getConfiguration().getAs<StringAttr>("abi") == "amdgpu") {
+    if (getABI(targetAttr) == "amdgpu") {
       binaryContainer = serializeAMDGPUBinaryContainer(
          serializationOptions, variantOp, exportOps, targetHSACO);
     } else {
diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
index 7f83a5e3b3fe..d12dbc5c6d4b 100644
--- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
+++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
@@ -1100,8 +1100,12 @@ Error *Invocation::outputHALExecutable(Output &output) {
     return new Error("not a valid HAL executable");
   }
   auto binaryOp = binaryOps.front();
-  auto rawData = binaryOp.getData().getRawData();
-  output.outputStream->write(rawData.data(), rawData.size());
+  if (failed(cast<IREE::Util::SerializableAttrInterface>(binaryOp.getData())
+                 .serializeToStream(binaryOp.getLoc(), llvm::endianness::little,
+                                    *output.outputStream))) {
+    return new Error(
+        "data attribute failed to serialize: unsupported format or encoding");
+  }
   output.outputStream->flush();
   return output.getWriteError();
 }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 4ca2fdbde223..491bd876ecb1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -32,6 +32,15 @@ static llvm::cl::opt<bool> clIndirectCommandBuffers{
     llvm::cl::init(true),
 };
 
+// TODO(benvanik): remove when we support capturing dynamic values for reuse.
+static llvm::cl::opt<bool> clForceIndirectCommandBuffers{
+    "iree-hal-force-indirect-command-buffers",
+    llvm::cl::desc("Forces indirect command buffers when they would otherwise "
+                   "not be chosen due to the values they capture. They may not "
+                   "be reusable but will still be outlined."),
+    llvm::cl::init(false),
+};
+
 struct ContextResolveOpPattern : public StreamConversionPattern {
   using StreamConversionPattern::StreamConversionPattern;
@@ -1002,7 +1011,9 @@ struct CmdExecuteOpPattern
     // changes dispatches to use them for any dispatch we can - note that there
     // may still be some that slip through due to custom executables.
     const bool capturesDynamicUniformValues =
-        regionCapturesDynamicUniformValues(executeOp);
+        clForceIndirectCommandBuffers
+            ? false
+            : regionCapturesDynamicUniformValues(executeOp);
 
     // Calculate the indirect buffer references used within the command buffer
     // by analyzing captured resources. This analysis will be used by subsequent
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index 2b2f23c8dd17..3f1e8110b83e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -150,8 +150,6 @@ def HAL_Ordinal : TypeAlias<Index>;
 def HAL_OrdinalAttr : Util_IndexAttrBase<"size_t">;
 def HAL_OrdinalArrayAttr : TypedArrayAttrBase<HAL_OrdinalAttr, "HAL ordinal array attribute">;
 
-def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>;
-
 def HAL_ElementType : TypeAlias<I32>;
 def HAL_ElementTypeAttr : SignlessIntegerAttrBase<
   I32, "element type attribute">;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index e3f05fcd13ad..c1d9a4ea5b56 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2593,7 +2593,7 @@ def HAL_ExecutableBinaryOp : HAL_Op<"executable.binary", [
     OptionalAttr<StrAttr>:$sym_visibility,
     SymbolNameAttr:$sym_name,
     StrAttr:$format,
-    HAL_ExecutableDataAttr:$data,
+    Util_AnySerializableAttr:$data,
    OptionalAttr<StrAttr>:$mime_type
     // TODO(benvanik): add compatibility and versioning attributes.
   );
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
index 30447862945f..19b8d490a5ab 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/OutlineMemoizeRegions.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Utils/StringUtils.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,6 +25,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
 
+#define DEBUG_TYPE "iree-hal-outline-memoize-regions"
+
 namespace mlir::iree_compiler::IREE::HAL {
 
 #define GEN_PASS_DEF_OUTLINEMEMOIZEREGIONSPASS
@@ -153,6 +156,8 @@ static IREE::Util::FuncOp outlineMemoizeRegionBody(
       name, funcType);
   moduleSymbolTable.insert(funcOp);
   funcOp.setVisibility(SymbolTable::Visibility::Private);
+  funcOp.setInliningPolicyAttr(
+      moduleBuilder.getAttr<IREE::Util::InlineNeverAttr>());
   auto funcBuilder = OpBuilder::atBlockBegin(funcOp.addEntryBlock());
 
   // Remap any captured operands that have corresponding function arguments.
@@ -521,8 +526,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp,
 
   // If we can't memoize the resources at initialization time then we need
   // to do it on-demand.
   if (!memoizeAnalysis.canRunAtInitializationTime()) {
-    memoizeOp.emitWarning(
-        "memoization failed: dynamic values captured at the call site");
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "memoization failed: dynamic values captured at the call site\n";
+      memoizeOp.dump();
+    });
     replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
     return;
   }
@@ -532,8 +540,11 @@ static void memoizeRegionOp(IREE::HAL::DeviceMemoizeOp memoizeOp,
   auto deviceGlobals =
       deviceAnalysis.lookupDeviceGlobals(memoizeOp.getDevice());
   if (!deviceGlobals) {
-    memoizeOp.emitWarning("memoization failed: unable to analyze devices "
-                          "that may be used with memoized region");
+    LLVM_DEBUG({
+      llvm::dbgs() << "memoization failed: unable to analyze devices that may "
+                      "be used with memoized region\n";
+      memoizeOp.dump();
+    });
     replaceMemoizeOpWithApply(memoizeOp, memoizeAnalysis, applyFuncOp);
     return;
   }
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
index 9690bca8b5ad..a71011b2b45c 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
@@ -84,7 +84,6 @@ class DropUnusedCallsPass
     // Note that we want to remove entire chains of unused calls and run this
     // as a pattern application.
     RewritePatternSet patterns(&getContext());
-    // patterns
     patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
                     EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
         &getContext(), noSideEffectsSymbols);
diff --git a/tests/compiler_driver/streams.mlir b/tests/compiler_driver/streams.mlir
index 03ebbc331527..9e30a2e12f8f 100644
--- a/tests/compiler_driver/streams.mlir
+++ b/tests/compiler_driver/streams.mlir
@@ -51,10 +51,10 @@ stream.executable private @executable_0 {
     }
   }
 }
-// CHECK: vm.func private @simple_mul
+// CHECK: vm.func private @__simple_mul_memoize_apply
+// CHECK: vm.call.variadic @hal.command_buffer.dispatch
 func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = arith.constant 4 : index
-  // CHECK: vm.call.variadic @hal.command_buffer.dispatch
   %ret0 = flow.dispatch @executable_0::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %ret0 : tensor<4xf32>
 }
@@ -98,10 +98,10 @@ stream.executable private @executable_1 {
     }
   }
 }
-// CHECK: vm.func private @simple_mul_inplace
+// CHECK: vm.func private @__simple_mul_inplace_memoize_apply
+// CHECK: vm.call.variadic @hal.command_buffer.dispatch
 func.func @simple_mul_inplace(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   %c4 = arith.constant 4 : index
-  // CHECK: vm.call.variadic @hal.command_buffer.dispatch
   %ret0 = flow.dispatch @executable_1::@dispatch[%c4](%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> %arg0
   return %ret0 : tensor<4xf32>
 }

From 14f58e0b6c117b493860fd39339d3e129afbb390 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com>
Date: Wed, 30 Oct 2024 13:07:22 -0500
Subject: [PATCH 45/45] [ROCM] Turn on SLP vectorization (#18949)

After reviewing the benchmarks, we see that this change does not cause any
model regressions. We also ran a benchmarking suite of attention kernels that
we previously thought could regress, and found no significant perf change.
See the results [here](https://docs.google.com/spreadsheets/d/102hYwdOGehmi_HhnLxHevwAVzMraKgKEjdTdUlJ4_TU/edit?usp=sharing).
Since SLP vectorization is needed for the correctness issue found in
https://github.com/iree-org/iree/issues/18798, we are turning it on.

Signed-off-by: Nirvedh
---
 compiler/plugins/target/ROCM/ROCMTarget.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 308c88097cd7..b56b9bbcd722 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -64,7 +64,7 @@ struct ROCMOptions {
   int wavesPerEu = 0;
   std::string enableROCMUkernels = "none";
   bool legacySync = true;
-  bool slpVectorization = false;
+  bool slpVectorization = true;
   bool globalISel = false;
 
   /// List of LLVM opt pass pluggins to be loaded during GPU code
@@ -113,11 +113,9 @@ struct ROCMOptions {
                  "to be passed to the target backend compiler during HIP "
                  "executable serialization"),
         cl::ZeroOrMore, cl::cat(category));
-    binder.opt<bool>(
-        "iree-hip-llvm-slp-vec", slpVectorization, cl::cat(category),
-        cl::desc(
-            "Enable slp vectorization in llvm opt. This can have an impact on "
-            "performance/numerics so its turned off by default currently."));
+    binder.opt<bool>("iree-hip-llvm-slp-vec", slpVectorization,
+                     cl::cat(category),
+                     cl::desc("Enable slp vectorization in llvm opt."));
     binder.opt<bool>("iree-hip-llvm-global-isel", globalISel, cl::cat(category),
                      cl::desc("Enable global instruction selection in llvm."));
   }
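A note for readers tracing the `slpVectorization` option through the build: the sketch below shows the conventional way such a toggle reaches LLVM's optimizer via `PipelineTuningOptions` on the new pass manager. It is an illustrative sketch only; the `runLLVMOptPipeline` wrapper and its surrounding setup are assumptions made for the example and are not copied from ROCMTarget.cpp.

// Sketch only: how a bool like slpVectorization (now defaulting to true)
// typically feeds LLVM's default pipelines. Not the actual ROCMTarget code.
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Target/TargetMachine.h"

static void runLLVMOptPipeline(llvm::Module &module,
                               llvm::TargetMachine *targetMachine,
                               bool slpVectorization) {
  // PipelineTuningOptions carries the SLP toggle into the default pipelines.
  llvm::PipelineTuningOptions pto;
  pto.SLPVectorization = slpVectorization;

  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager cgam;
  llvm::ModuleAnalysisManager mam;

  llvm::PassBuilder pb(targetMachine, pto);
  pb.registerModuleAnalyses(mam);
  pb.registerCGSCCAnalyses(cgam);
  pb.registerFunctionAnalyses(fam);
  pb.registerLoopAnalyses(lam);
  pb.crossRegisterProxies(lam, fam, cgam, mam);

  // Build and run the O3 module pipeline; SLP vectorization runs or is
  // skipped based on the tuning options above.
  llvm::ModulePassManager mpm =
      pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
  mpm.run(module, mam);
}

Since the flag itself is kept and only its default flips to true, the previous behavior should presumably still be reachable with `--iree-hip-llvm-slp-vec=false`; likewise, `--iree-hal-force-indirect-command-buffers=true` from the previous patch is an opt-in override rather than a new default.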