Skip to content

Commit

Permalink
[Codegen] Implement getDistributedShape for NestedLayout (iree-org#16377
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Groverkss authored Feb 12, 2024
1 parent 9aabcb3 commit 9722440
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,37 @@ func.func @distribute_elementwise_i32(%a: vector<16x16xi32>, %b: vector<16x16xi3
return %d : vector<16x16xi32>
}

#nested = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [2, 1, 1],
batches_per_subgroup = [8, 2, 4],
outers_per_batch = [1, 4, 4],
threads_per_outer = [8, 2, 4],
elements_per_thread = [1, 8, 2],

subgroup_order = [0, 1, 2],
batch_order = [0, 1, 2],
outer_order = [0, 1, 2],
thread_order = [0, 1, 2],
element_order = [0, 2, 1]
>

// CHECK-LABEL: @distribute_elementwise_nested_layout_f16
func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>, %b: vector<128x128x128xf16>) -> vector<128x128x128xf16> {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x2x8xf16>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #nested} dense<0.0> : vector<128x128x128xf16>
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x2x8xf16>
%c = arith.mulf %root, %b : vector<128x128x128xf16>
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x2x8xf16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x2x8xf16>
%d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<128x128x128xf16>
// CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x2x8xf16> -> vector<128x128x128xf16>
return %d : vector<128x128x128xf16>
}

// CHECK-LABEL: @distribute_scf_for
func.func @distribute_scf_for(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> vector<16x16xi32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,14 @@ NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
llvm_unreachable("Not yet implemented");
}

/// We distribute to:
/// <BATCH x OUTER x ELEMENT>
SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const {
llvm_unreachable("Not yet implemented");
SmallVector<int64_t> shape;
shape.append(applyPermutation(getBatchesPerSubgroup(), getBatchOrder()));
shape.append(applyPermutation(getOutersPerBatch(), getOuterOrder()));
shape.append(applyPermutation(getElementsPerThread(), getElementOrder()));
return shape;
}

bool NestedLayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
Expand Down

0 comments on commit 9722440

Please sign in to comment.