From 97224402dae8d80a903813dcbd135cc1ade1d965 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Tue, 13 Feb 2024 02:42:03 +0530 Subject: [PATCH] [Codegen] Implement getDistributedShape for NestedLayout (#16377) --- .../GPU/test/gpu_vector_distribution.mlir | 31 +++++++++++++++++++ .../Dialect/VectorExt/IR/VectorExtAttrs.cpp | 8 ++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index 3a7c2087f37b..bacdf77d1ad4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -34,6 +34,37 @@ func.func @distribute_elementwise_i32(%a: vector<16x16xi32>, %b: vector<16x16xi3 return %d : vector<16x16xi32> } +#nested = #iree_vector_ext.nested_layout< + subgroups_per_workgroup = [2, 1, 1], + batches_per_subgroup = [8, 2, 4], + outers_per_batch = [1, 4, 4], + threads_per_outer = [8, 2, 4], + elements_per_thread = [1, 8, 2], + + subgroup_order = [0, 1, 2], + batch_order = [0, 1, 2], + outer_order = [0, 1, 2], + thread_order = [0, 1, 2], + element_order = [0, 2, 1] +> + +// CHECK-LABEL: @distribute_elementwise_nested_layout_f16 +func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>, %b: vector<128x128x128xf16>) -> vector<128x128x128xf16> { + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.0 : f16 + // CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x2x8xf16> + %root = arith.constant {"__vector_layout_test_anchor_result_0" = #nested} dense<0.0> : vector<128x128x128xf16> + // CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16> + // CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x2x8xf16> + %c = arith.mulf %root, %b : 
vector<128x128x128xf16> + // CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16> + // CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath : vector<8x2x4x1x4x4x1x2x8xf16> + %d = arith.addf %c, %a fastmath : vector<128x128x128xf16> + // CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x2x8xf16> -> vector<128x128x128xf16> + return %d : vector<128x128x128xf16> +} + +// CHECK-LABEL: @distribute_scf_for func.func @distribute_scf_for(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> vector<16x16xi32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp index 3f9b0c866887..6b8052910ffb 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp @@ -206,8 +206,14 @@ NestedLayoutAttr::permute(ArrayRef permutation) const { llvm_unreachable("Not yet implemented"); } +/// We distribute to: +/// <BATCH x OUTER x ELEMENT>, each group's sizes permuted by its own order. SmallVector NestedLayoutAttr::getDistributedShape() const { - llvm_unreachable("Not yet implemented"); + SmallVector shape; + shape.append(applyPermutation(getBatchesPerSubgroup(), getBatchOrder())); + shape.append(applyPermutation(getOutersPerBatch(), getOuterOrder())); + shape.append(applyPermutation(getElementsPerThread(), getElementOrder())); + return shape; } bool NestedLayoutAttr::isValidLayout(ArrayRef shape) const {