diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
index 071e07bbd192..0cd451da4cf1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
@@ -119,6 +119,10 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
         vector::VectorTransformsOptions().setVectorTransformsOptions(
             vector::VectorContractLowering::OuterProduct));
     vector::populateVectorMaskOpLoweringPatterns(patterns);
+    // We currently always use 64-bit indices, so ensure the bit width of the
+    // mask compare is consistent.
+    vector::populateVectorMaskMaterializationPatterns(
+        patterns, /*force32BitVectorIndices=*/false);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     // TODO: doubtful that the "default" does what one want here, it is likely
     // better to use something else.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
index f7d33f8cd1bf..a43cbe4a711b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
@@ -129,6 +129,10 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
         vector::VectorTransformsOptions().setVectorTransformsOptions(
            vector::VectorContractLowering::OuterProduct));
     vector::populateVectorMaskOpLoweringPatterns(patterns);
+    // We currently always use 64-bit indices, so ensure the bit width of the
+    // mask compare is consistent.
+    vector::populateVectorMaskMaterializationPatterns(
+        patterns, /*force32BitVectorIndices=*/false);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     // TODO: doubtful that the "default" does what one want here, it is likely
     // better to use something else.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 232ba442c807..2dc54cea9fde 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -27,7 +27,7 @@
 using namespace mlir;
 using namespace mlir::iree_compiler;
 
-static constexpr unsigned cudaWarpSize = 32;
+static constexpr unsigned kCudaWarpSize = 32;
 static constexpr StringLiteral kCudaTarget = "cuda";
 static constexpr StringLiteral kRocmTarget = "rocm";
 namespace mlir {
@@ -395,7 +395,7 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
     }
   }
   // Special case for very small matrices.
-  if (sizeM * sizeN <= cudaWarpSize) {
+  if (sizeM * sizeN <= kCudaWarpSize) {
     return setMatmulConfig(
         sizeN, sizeM, 4, {sizeM, sizeN, 1}, softwarePipelineDepthSimt,
         IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt);
   }
@@ -448,7 +448,7 @@ static LogicalResult setFftConfig(func::FuncOp entryPoint,
       interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
   unsigned loopDepth = partitionedLoops.back() + 1;
   SmallVector<int64_t> workgroupTileSize(loopDepth, 0);
-  SmallVector<int64_t, 3> workgroupSize = {cudaWarpSize, 1, 1};
+  SmallVector<int64_t, 3> workgroupSize = {kCudaWarpSize, 1, 1};
   // Tiling along partitioned loops with size 1.
   for (int64_t loopIndex : partitionedLoops) {
@@ -485,7 +485,7 @@ static LogicalResult setSortConfig(func::FuncOp entryPoint, Operation *op) {
   }
   size_t numLoops = partitionedLoops.back() + 1;
   // To get peak occupancy we need a workgroup size of at least two warps
-  std::array<int64_t, 3> workgroupSize = {2 * cudaWarpSize, 1, 1};
+  std::array<int64_t, 3> workgroupSize = {2 * kCudaWarpSize, 1, 1};
   SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
   // Set all non-parallel loops to zero tile size.
   llvm::DenseSet<unsigned> partitionedLoopsSet(partitionedLoops.begin(),
@@ -531,7 +531,7 @@ getDefaultWorkgroupTileSizesForPackUnPack(TilingInterface op,
 static LogicalResult setPackConfig(func::FuncOp entryPoint,
                                    tensor::PackOp packOp) {
   SmallVector<int64_t> tileSizes = getDefaultWorkgroupTileSizesForPackUnPack(
-      cast<TilingInterface>(packOp.getOperation()), cudaWarpSize);
+      cast<TilingInterface>(packOp.getOperation()), kCudaWarpSize);
 
   // The default function aims to returns the number of workload per workgroup,
   // but it does not know that it is working on packed domain. We need to take
@@ -546,7 +546,7 @@ static LogicalResult setPackConfig(func::FuncOp entryPoint,
   }
 
   TileSizesListType tileSizesList = {tileSizes};
-  std::array<int64_t, 3> workgroupSizes = {cudaWarpSize, 1, 1};
+  std::array<int64_t, 3> workgroupSizes = {kCudaWarpSize, 1, 1};
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, packOp, tileSizesList,
       IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUPackUnPack,
@@ -569,7 +569,7 @@ static LogicalResult setRootDefaultConfig(func::FuncOp entryPoint,
   size_t numLoops = partitionedLoops.back() + 1;
 
   // To get peak occupancy we need a workgroup size of at least two warps
-  std::array<int64_t, 3> workgroupSize = {2 * cudaWarpSize, 1, 1};
+  std::array<int64_t, 3> workgroupSize = {2 * kCudaWarpSize, 1, 1};
   unsigned vectorSize = 4;
   SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
   // Set all non-parallel loops to zero tile size.
@@ -606,7 +606,7 @@ static LogicalResult setRootDefaultConfig(func::FuncOp entryPoint,
       int64_t problemSize = std::accumulate(
           shape.begin(), shape.end(), 1,
           [](const int64_t &a, const int64_t &b) { return a * b; });
-      if ((problemSize / (cudaWarpSize * vectorSize)) < 64) {
+      if ((problemSize / (kCudaWarpSize * vectorSize)) < 64) {
         vectorSize = 1;
         break;
       }
@@ -750,11 +750,19 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
     return failure();
 
   // Make sure reduction dimensions are static and innermost ones.
+  int64_t numDynamicReductionDims = 0;
   for (unsigned dim : reductionDims) {
-    if (ShapedType::isDynamic(bounds[dim]))
-      return failure();
-    if (dim < numParallelDims)
+    if (ShapedType::isDynamic(bounds[dim])) {
+      numDynamicReductionDims++;
+    }
+    if (dim < numParallelDims) {
       return failure();
+    }
+  }
+
+  // Distribution of multi-dim masked writes isn't fully supported yet.
+  if (numDynamicReductionDims > 1) {
+    return failure();
   }
 
   if (op.getRegionOutputArgs().size() != 1)
     return failure();
@@ -784,10 +784,36 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
   if (!foundSingleReductionOutput)
     return failure();
 
+  // Tile all the parallel dimensions to 1.
+  SmallVector<unsigned> partitionedLoops =
+      cast<PartitionableLoopsInterface>(op.getOperation())
+          .getPartitionableLoops(kNumMaxParallelDims);
+  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+  SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
+
+  // Without any bounds on dynamic reduction dims, we need specialization to
+  // get peak performance. For now, just use the warp size.
+  if (numDynamicReductionDims) {
+    SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
+    // TODO: Don't hard code this.
+    reductionTileSizes[reductionDims[0]] = kCudaWarpSize;
+    TileSizesListType tileSizes;
+    tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
+    tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
+    std::array<int64_t, 3> workgroupSize = {kCudaWarpSize, 1, 1};
+    if (failed(setOpConfigAndEntryPointFnTranslation(
+            entryPoint, op, tileSizes,
+            IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWarpReduction,
+            workgroupSize))) {
+      return failure();
+    }
+    return success();
+  }
+
   int64_t reductionSize = 1;
   for (int64_t dim : reductionDims)
     reductionSize *= bounds[dim];
-  if (reductionSize % cudaWarpSize != 0)
+  if (reductionSize % kCudaWarpSize != 0)
     return failure();
 
   const Type elementType =
@@ -802,7 +836,7 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
 
   const unsigned largestLoadSizeInBits = 128;
   unsigned vectorSize = largestLoadSizeInBits / bitWidth;
-  while ((reductionSize / vectorSize) % cudaWarpSize != 0)
+  while ((reductionSize / vectorSize) % kCudaWarpSize != 0)
     vectorSize /= 2;
 
   // Deduce the workgroup size we should use for reduction. Currently a
@@ -839,7 +873,7 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
   // How many 128-bit vectors each thread should at least read.
   const int targetVectorCount = 8;
   while (parallelSize && *parallelSize > parallelThreshold &&
-         (groupSize / 2) % cudaWarpSize == 0 &&
+         (groupSize / 2) % kCudaWarpSize == 0 &&
          reductionSize / (groupSize * vectorSize) < targetVectorCount) {
     // Use less subgroups per workgroup..
     groupSize /= 2;
@@ -851,29 +885,23 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
   // First, do warp reductions along multiple subgroups.
   // Second, reduce results from multiple subgroups using single warp reduce.
   // The final warp reduce requires subgroup count <= subgroup size to work.
-  if ((groupSize / cudaWarpSize) > cudaWarpSize)
+  if ((groupSize / kCudaWarpSize) > kCudaWarpSize)
     return failure();
 
   std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
-  SmallVector<unsigned> partitionedLoops =
-      cast<PartitionableLoopsInterface>(op.getOperation())
-          .getPartitionableLoops(kNumMaxParallelDims);
-  size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
-  // Tile all the parallel dimension to 1.
-  SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
   SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
-  int64_t remaingGroupSize = groupSize;
+  int64_t remainingGroupSize = groupSize;
   for (int i = reductionDims.size() - 1; i >= 0; --i) {
     int64_t dim = reductionDims[i];
     int64_t bound = bounds[dim];
     if (i == reductionDims.size() - 1)
       bound /= vectorSize;
     APInt size = llvm::APIntOps::GreatestCommonDivisor(
-        {64, uint64_t(remaingGroupSize)}, {64, uint64_t(bound)});
+        {64, uint64_t(remainingGroupSize)}, {64, uint64_t(bound)});
     reductionTileSizes[dim] = size.getSExtValue();
     if (i == reductionDims.size() - 1)
       reductionTileSizes[dim] *= vectorSize;
-    remaingGroupSize /= size.getSExtValue();
+    remainingGroupSize /= size.getSExtValue();
   }
   TileSizesListType tileSizes;
   tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 247f08331ce4..4e5b97867ac1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -394,6 +394,8 @@ void addGPUWarpReductionPassPipeline(OpPassManager &pm) {
   // Linalg -> vector
   {
     GenericVectorizationPassOptions options;
+    options.enableVectorMasking = true;
+    options.useConfiguredVectorSizes = false;
     options.vectorizePadding = true;
     options.vectorizeGatherAccesses = true;
     options.enableCleanup = false;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
index 3a63bac228fc..e7137aad4745 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir
@@ -396,3 +396,33 @@ hal.executable @shared_memory_lowering_index {
 // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : i64) : i64
 // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : i64) : i64
 // CHECK-NEXT: %{{.*}} = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i64, i64) -> !llvm.ptr<3>
+
+// -----
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @masked_load_store {
+  hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
+    hal.executable.export @masked_load_store layout(#pipeline_layout)
+    builtin.module {
+      func.func @masked_load_store() {
+        %c0 = arith.constant 0 : index
+        %idx = gpu.thread_id x
+        %pass_thru = arith.constant dense<0.000000e+00> : vector<1xf32>
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<64xf32, #gpu.address_space<global>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<64xf32, #gpu.address_space<global>>
+        %mask = vector.create_mask %idx : vector<1xi1>
+        %ld = vector.maskedload %0[%idx], %mask, %pass_thru : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+        vector.maskedstore %1[%idx], %mask, %ld : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: llvm.func @masked_load_store
+// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
+// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
index fd52a4529d18..b33cf0ede17a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir
@@ -78,23 +78,49 @@ hal.executable @abs_ex_dispatch_0 {
 // Test that gpu barriers be lowered to `s_waitcnt lgkmcnt(0)\0As_barrier` on rocm
 #pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
   #hal.descriptor_set.layout<0, bindings = [
     #hal.descriptor_set.binding<1, storage_buffer>,
-    #hal.descriptor_set.binding<4, storage_buffer>
-  ]>,
-  #hal.descriptor_set.layout<1, bindings = [
-    #hal.descriptor_set.binding<2, storage_buffer>
+    #hal.descriptor_set.binding<0, storage_buffer>
   ]>
 ]>
-hal.executable @matmul_dispatch_0 {
+hal.executable @simple_barrier {
   hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export @matmul_dispatch_0 layout(#pipeline_layout)
+    hal.executable.export @simple_barrier layout(#pipeline_layout)
     builtin.module {
-      func.func @matmul_dispatch_0() {
+      func.func @simple_barrier() {
         gpu.barrier
         return
       }
     }
   }
 }
-// CHECK-LABEL: llvm.func @matmul_dispatch_0
+// CHECK-LABEL: llvm.func @simple_barrier
 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "s_waitcnt lgkmcnt(0)\0As_barrier", "" : () -> ()
+
+// -----
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @masked_load_store {
+  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export @masked_load_store layout(#pipeline_layout)
+    builtin.module {
+      func.func @masked_load_store() {
+        %c0 = arith.constant 0 : index
+        %idx = gpu.thread_id x
+        %pass_thru = arith.constant dense<0.000000e+00> : vector<1xf32>
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<64xf32, #gpu.address_space<global>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<64xf32, #gpu.address_space<global>>
+        %mask = vector.create_mask %idx : vector<1xi1>
+        %ld = vector.maskedload %0[%idx], %mask, %pass_thru : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32> into vector<1xf32>
+        vector.maskedstore %1[%idx], %mask, %ld : memref<64xf32, #gpu.address_space<global>>, vector<1xi1>, vector<1xf32>
+        return
+      }
+    }
+  }
+}
+// CHECK-LABEL: llvm.func @masked_load_store
+// CHECK: %[[MASK_BIT:.+]] = llvm.icmp "sgt" {{.*}} : vector<1xi64>
+// CHECK: llvm.intr.masked.load %{{.*}}, %[[MASK_BIT]]
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[MASK_BIT]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
index 9aadf2be312e..0f3bdd5c8920 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
@@ -33,3 +33,48 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
 
 // CHECK-LABEL: func.func @softmax
 // CHECK-COUNT-20: gpu.shuffle xor{{.*}}{{[[:space:]].*}}{{.*}}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 2, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+
+hal.executable private @dynamic_softmax {
+  hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100"}>) {
+    hal.executable.export public @dynamic_softmax ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @dynamic_softmax() {
+        %c32_i64 = arith.constant 32 : i64
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.constant.load[0] : i32
+        %1 = hal.interface.constant.load[1] : i32
+        %2 = arith.extui %0 : i32 to i64
+        %3 = arith.extui %1 : i32 to i64
+        %4 = arith.shli %3, %c32_i64 : i64
+        %5 = arith.ori %2, %4 : i64
+        %6 = arith.index_castui %5 : i64 to index
+        %7 = flow.dispatch.workload.ordinal %6, 0 : index
+        %8 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x?xf16>>{%7}
+        %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<32x?xf16>>{%7}
+        %10 = flow.dispatch.tensor.load %8, offsets = [0, 0], sizes = [32, %7], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32x?xf16>>{%7} -> tensor<32x?xf16>
+        %11 = tensor.empty(%7) : tensor<32x?xf16>
+        %12 = linalg.softmax dimension(1) ins(%10 : tensor<32x?xf16>) outs(%11 : tensor<32x?xf16>) -> tensor<32x?xf16>
+        flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [32, %7], strides = [1, 1] : tensor<32x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<32x?xf16>>{%7}
+        return
+      }
+    }
+  }
+}
+
+// Finer details of this lowering are captured by the SPIR-V pipeline test. Just
+// verify that the warp reduction pipeline triggers.
+// CHECK-LABEL: func.func @dynamic_softmax
+// CHECK-COUNT-10: gpu.shuffle xor {{.*}} : i32