Skip to content

Commit

Permalink
Fix vectorization lit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Oct 9, 2024
1 parent 481a5fc commit 6b9b478
Showing 1 changed file with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,15 @@ func.func @fillAndCopy() -> tensor<8xbf16> {


// CHECK-LABEL: @matmul_elementwise
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4240x160xi8>, %[[ELE:.*]]: tensor<160xi8>)
func.func @matmul_elementwise(%arg0: tensor<4240x160xi8>, %ele: tensor<160xi8>) -> tensor<4240x160xi8> {
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0, %ele : tensor<4240x160xi8>, tensor<160xi8>) outs(%arg0 : tensor<4240x160xi8>) {
^bb0(%in: i8, %in_1: i8, %out: i8):
%1 = arith.addi %in, %in_1 : i8
linalg.yield %1 : i8
} -> tensor<4240x160xi8>
return %0 : tensor<4240x160xi8>
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4240x160xf32>, %[[ARG1:.*]]: tensor<4240x160xbf16>)
func.func @matmul_elementwise(%arg0: tensor<4240x160xf32>, %arg1: tensor<4240x160xbf16>) -> tensor<4240x160xbf16> {
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0: tensor<4240x160xf32>) outs(%arg1 : tensor<4240x160xbf16>) {
^bb0(%in: f32, %out: bf16):
%1 = arith.truncf %in : f32 to bf16
linalg.yield %1 : bf16
} -> tensor<4240x160xbf16>
return %0 : tensor<4240x160xbf16>
}
// CHECK: %[[VEC_OPERAND_0:.*]] = vector.transfer_read %[[ARG0]]{{.*}} vector<4240x160xi8>
// CHECK: %[[VEC_ELE:.*]] = vector.transfer_read %[[ELE]]{{.*}} vector<160xi8>
// CHECK: %[[VEC_OPERAND_1:.*]] = vector.broadcast %[[VEC_ELE]]{{.*}} vector<4240x160xi8>
// CHECK: %[[ADD:.*]] = arith.addi %[[VEC_OPERAND_0]], %[[VEC_OPERAND_1]]
// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]]
// CHECK: %[[VEC_OPERAND_0:.*]] = vector.transfer_read %[[ARG0]]{{.*}} vector<4240x160xf32>
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[VEC_OPERAND_0]]
// CHECK: vector.transfer_write %[[TRUNCF]], %[[ARG1]]

0 comments on commit 6b9b478

Please sign in to comment.