Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Break SPIRVVectorize pass into GenericVectorization and SPIRVVectorLowering passes. #15009

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ iree_compiler_cc_library(
"SPIRVTileAndDistribute.cpp",
"SPIRVTileAndPromote.cpp",
"SPIRVTileAndVectorizeToCooperativeOps.cpp",
"SPIRVVectorLowering.cpp",
"SPIRVVectorToGPUSubgroupMMAOps.cpp",
"SPIRVVectorize.cpp",
"SPIRVVectorizeLoadStore.cpp",
"Utils.cpp",
"Verifiers.cpp",
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ iree_cc_library(
"SPIRVTileAndDistribute.cpp"
"SPIRVTileAndPromote.cpp"
"SPIRVTileAndVectorizeToCooperativeOps.cpp"
"SPIRVVectorLowering.cpp"
"SPIRVVectorToGPUSubgroupMMAOps.cpp"
"SPIRVVectorize.cpp"
"SPIRVVectorizeLoadStore.cpp"
"Utils.cpp"
"Verifiers.cpp"
Expand Down
26 changes: 20 additions & 6 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,13 @@ void addSPIRVBaseVectorizePassPipeline(OpPassManager &pm) {
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVTilePass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
{
GenericVectorizationPassOptions options;
options.vectorizeGatherAccesses = true;
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass(options));
}
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
Expand Down Expand Up @@ -384,7 +390,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(OpPassManager &pm,
createSPIRVVectorToGPUSubgroupMMAOpsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());

if (pipelineDepth > 0) {
PipeliningSchedulingStrategy schedule =
Expand Down Expand Up @@ -446,7 +452,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &topPM,
nestedPM.addNestedPass<func::FuncOp>(createGPUReduceSharedMemoryBankConflicts(
detail::bankConflictReductionPaddingBits));

nestedPM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
nestedPM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
nestedPM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedPM.addPass(createCanonicalizerPass());
nestedPM.addPass(createCSEPass());
Expand Down Expand Up @@ -574,7 +580,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &pm) {
createConvertVectorReductionToGPUPass(getWarpSize));
// Perform normal vector unrolling and lowering transformations. This breaks
// vectors down to native machine size.
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
}
Expand All @@ -598,7 +604,14 @@ void addSPIRVWinogradVectorizePassPipeline(OpPassManager &pm) {
createSPIRVAnnotateWinogradLoopsPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
{
GenericVectorizationPassOptions options;
options.vectorizeGatherAccesses = true;
options.enableCleanup = true;
nestedModulePM.addNestedPass<func::FuncOp>(
createGenericVectorizationPass(options));
}
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
nestedModulePM.addPass(createCanonicalizerPass());
nestedModulePM.addPass(createCSEPass());
Expand All @@ -621,7 +634,8 @@ void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {
// Run SPIRVVectorize pass additionally to convert vectors into forms needed
// for SPIR-V.
auto &nestedModulePM = pm.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorizePass());
nestedModulePM.addNestedPass<func::FuncOp>(createGenericVectorizationPass());
nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ createSPIRVVectorToGPUSubgroupMMAOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> createSPIRVVectorizeLoadStore();

/// Pass to vectorize Linalg ops with buffer semantics.
std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorizePass();
std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass();

/// Pass to do vectorization suitable for lowering to SPIR-V cooperative ops.
std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def SPIRVTileToCooperativeOps : Pass<
"mlir::iree_compiler::createSPIRVTileToCooperativeOpsPass()";
}

def SPIRVVectorize : Pass<"iree-spirv-vectorize", "func::FuncOp"> {
def SPIRVVectorLowering : Pass<"iree-spirv-vector-lowering", "func::FuncOp"> {
let summary = "Vectorize Linalg ops with buffer semantics";
let constructor = "mlir::iree_compiler::createSPIRVVectorizePass()";
let constructor = "mlir::iree_compiler::createSPIRVVectorLoweringPass()";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably remove these constructors in the future and let tablegen autogenerate all of this code. But this is as a side note only, let's keep things as close to NFC as possible in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes we should. I split up the TD file to be per backend to do that (someday).

}

def SPIRVVectorizeLoadStore :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- SPIRVVectorize.cpp -------------------------------------------------===//
//===- SPIRVVectorLoweringPass.cpp
//-------------------------------------------------===//
//
// This pass vectorizes Linalg ops with buffer semantics.
//
Expand Down Expand Up @@ -41,10 +42,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;

#define DEBUG_TYPE "iree-spirv-vectorize"
#define DEBUG_TYPE "iree-spirv-vector-lowering"

namespace mlir {
namespace iree_compiler {
Expand Down Expand Up @@ -252,20 +250,6 @@ getNativeVectorShape(Operation *op, bool targetSupportsDotProd) {
.Default([](Operation *) { return std::nullopt; });
}

/// Add patterns to vectorize any supported Linalg ops.
void populateVectorizationPatterns(RewritePatternSet &patterns) {
IREE::LinalgExt::LinalgTransformationFilter filter;
IREE::LinalgExt::LinalgVectorizationOptions options;
// Enable vectorizing tensor.extract in Linalg ops.
options.vectorizeGatherAccesses = true;
VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
patterns, options, filter);
linalg::populateConvolutionVectorizationPatterns(patterns);
patterns.add<LinalgVectorizationPattern>(
patterns.getContext(), options,
filter.addOpFilter<linalg::ContractionOpInterface>());
}

/// Adds patterns to unroll vector ops to SPIR-V native vector size.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
bool targetSupportsDotProd) {
Expand Down Expand Up @@ -303,10 +287,11 @@ bool supportsIntegerDotProductOps(func::FuncOp fn) {
}

/// Vectorizes Linalg ops on buffer semantics.
class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
class SPIRVVectorLoweringPass
: public SPIRVVectorLoweringBase<SPIRVVectorLoweringPass> {
public:
SPIRVVectorizePass() = default;
SPIRVVectorizePass(const SPIRVVectorizePass &pass) = default;
SPIRVVectorLoweringPass() = default;
SPIRVVectorLoweringPass(const SPIRVVectorLoweringPass &pass) = default;

void getDependentDialects(DialectRegistry &registry) const override {
// vector.gather lowering patterns target scf ops.
Expand All @@ -321,10 +306,9 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
bool emitIntegerDotProdOps = supportsIntegerDotProductOps(funcOp);

// First apply vectorization to generate vectors of the original tensor
// shape.
// shape for tensor.pad ops.
{
RewritePatternSet patterns(context);
populateVectorizationPatterns(patterns);
// Pull in additional vectorization patterns in IREE.
populateVectorizePadPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
Expand All @@ -333,26 +317,11 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
}

LLVM_DEBUG({
llvm::dbgs() << "--- After vectorization ---\n";
llvm::dbgs() << "--- After IREE tensor.pad vectorization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});

{
auto result = funcOp.walk([&](linalg::LinalgOp op) {
// linalg.generic ops for copy are fine to not vectorize; they will be
// handled in later steps.
if (isa<linalg::YieldOp>(op.getBlock()->begin())) {
return WalkResult::advance();
}
// Other ones should error out.
op.emitOpError("should not remain after vectorization");
return WalkResult::interrupt();
});
if (result.wasInterrupted())
return signalPassFailure();
}

// Special peephole optimizations to clean up IR before further processing.
{
RewritePatternSet patterns(context);
Expand Down Expand Up @@ -657,8 +626,8 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorizePass() {
return std::make_unique<SPIRVVectorizePass>();
std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass() {
return std::make_unique<SPIRVVectorLoweringPass>();
}

} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-spirv-vectorize,canonicalize,cse)))))' \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s

#config = #iree_codegen.lowering_config<tile_sizes = [[1, 8, 64], [1, 8, 4], [0, 0, 0, 4]]>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-spirv-vectorize,canonicalize,cse)))))' \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s

#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-spirv-vectorize,canonicalize,cse)))))' %s | FileCheck %s
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s

#config = #iree_codegen.lowering_config<tile_sizes = [[8, 64], [8, 4], [0, 0, 4]]>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-spirv-vectorize,canonicalize,cse)))))' \
// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s

#config = #iree_codegen.lowering_config<tile_sizes = [[0, 2, 2, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0]]>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --split-input-file --iree-spirv-vectorize %s | FileCheck %s
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
// RUN: %s | FileCheck %s

func.func @ncw_conv_1d(%input: tensor<2x4x4xf32>, %filter: tensor<4x4x1xf32>, %init: tensor<2x4x4xf32>) -> tensor<2x4x4xf32> {
%0 = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --split-input-file --iree-spirv-vectorize %s | FileCheck %s
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
// RUN: %s | FileCheck %s

func.func @add(%lhs: tensor<2x8xf32>, %rhs: tensor<2x8xf32>) -> tensor<2x8xf32> {
%init = tensor.empty() : tensor<2x8xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --split-input-file --iree-spirv-vectorize -canonicalize %s | FileCheck %s
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-gather-accesses},iree-spirv-vector-lowering))' \
// RUN: %s | FileCheck %s

func.func @tensor_extract(%arg0: tensor<6x4xf32>, %arg1: tensor<6xi32>, %data: tensor<1x2x512xf32>, %init: tensor<6x4xf32>, %i : index, %j: index) -> tensor<6x4xf32> {
%c0 = arith.constant 0 : index
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --split-input-file --iree-spirv-vectorize --canonicalize %s | FileCheck %s
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
// RUN: %s | FileCheck %s

func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
%0 = linalg.matmul ins(%lhs, %rhs : tensor<1x4xf32>, tensor<4x4xf32>) outs(%init : tensor<1x4xf32>) -> tensor<1x4xf32>
Expand Down Expand Up @@ -175,31 +177,31 @@ func.func @matmul_2x8x128_fp16(%a: tensor<2x128xf16>, %b: tensor<128x8xf16>, %x:
// CHECK-LABEL: func.func @matmul_2x8x128_fp16
// CHECK-SAME: (%[[LHS:.+]]: tensor<2x128xf16>, %[[RHS:.+]]: tensor<128x8xf16>, %[[X:.+]]: tensor<2x8xf16>, %[[Y:.+]]: tensor<2x8xf16>)
// CHECK: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<8xf16>
// CHECK: %[[FOR:.+]]:2 = scf.for %arg4 = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%arg5 = %[[ZERO]], %arg6 = %[[ZERO]])
// CHECK: %[[FOR:.+]]:3 = scf.for %arg4 = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%arg5 = %{{.+}}, %arg6 = %[[ZERO]], %arg7 = %[[ZERO]])
// CHECK-COUNT-2: vector.transfer_read %[[LHS]]{{.+}} : tensor<2x128xf16>, vector<8xf16>
// CHECK-COUNT-8: vector.transfer_read %[[RHS]]{{.+}} : tensor<128x8xf16>, vector<8xf16>
// CHECK-COUNT-32: vector.fma {{.+}} : vector<4xf16>
// CHECK: %[[ISS0:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS1:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS2:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[ISS3:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: scf.yield %[[ISS3]], %[[ISS1]] : vector<8xf16>, vector<8xf16>
// CHECK: scf.yield %arg5, %[[ISS3]], %[[ISS1]] : tensor<2x8xf16>, vector<8xf16>, vector<8xf16>
// CHECK: }
// CHECK: %[[X0:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
// CHECK: %[[X1:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
// CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS0:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV0:.+]] = arith.divf %[[LHS0]], %[[RHS0]]
// CHECK: %[[ISS0:.+]] = vector.insert_strided_slice %[[DIV0]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS1:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV1:.+]] = arith.divf %[[LHS1]], %[[RHS1]]
// CHECK: %[[ISS1:.+]] = vector.insert_strided_slice %[[DIV1]], %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#0 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS2:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV2:.+]] = arith.divf %[[LHS2]], %[[RHS2]]
// CHECK: %[[ISS2:.+]] = vector.insert_strided_slice %[[DIV2]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
// CHECK: %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#0 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[RHS3:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[DIV3:.+]] = arith.divf %[[LHS3]], %[[RHS3]]
// CHECK: %[[ISS3:.+]] = vector.insert_strided_slice %[[DIV3]], %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt -split-input-file -iree-spirv-vectorize %s | FileCheck %s
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
// RUN: %s | FileCheck %s

func.func @reduce_outmost_dim(%input: tensor<4x1x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
%f0 = arith.constant 0.0 : f32
Expand Down
Loading