From d9b7f6d5c0c95c6c8e9c40ef66dc68b432b23f2a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 8 Aug 2023 09:32:26 -0700 Subject: [PATCH] [HAL] Remove incorrect buffer_view.buffer folding (#14590) It is unsafe to fold `buffer_view.buffer` of `buffer_view.create` because create operation can create a buffer subspan, and to do the correct folding we have to know the buffer size. --- .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 26 ------------- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 2 - .../HAL/IR/test/buffer_view_folding.mlir | 38 ------------------- 3 files changed, 66 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index f55251c673e0..6d52a2009986 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -162,32 +162,6 @@ void BufferViewCreateOp::getCanonicalizationPatterns(RewritePatternSet &results, results.insert(context); } -namespace { - -/// Skips a hal.buffer_view.buffer accessor when the buffer view was created in -/// the same scope at zero offset and we know the origin buffer. -struct SkipBufferViewBufferOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BufferViewBufferOp op, - PatternRewriter &rewriter) const override { - auto createOp = dyn_cast_or_null( - op.getBufferView().getDefiningOp()); - if (!createOp || !matchPattern(createOp.getSourceOffset(), m_Zero())) - return failure(); - - rewriter.replaceOp(op, createOp.getSourceBuffer()); - return success(); - } -}; - -} // namespace - -void BufferViewBufferOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // hal.channel.create //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 7c422e6e07eb..6c4844f1e5e2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -623,8 +623,6 @@ def HAL_BufferViewBufferOp : HAL_PureOp<"buffer_view.buffer", [ `:` type($result) attr-dict-with-keyword }]; - - let hasCanonicalizer = 1; } def HAL_BufferViewElementTypeOp : HAL_PureOp<"buffer_view.element_type"> { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir index 4052725c82d9..b2ec3ced0e8f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir @@ -20,41 +20,3 @@ func.func @FoldBufferViewCreateSubspan(%base_buffer: !hal.buffer, %subspan_offse encoding(%encoding) : !hal.buffer_view return %view : !hal.buffer_view } - -// ----- - -// CHECK-LABEL: func.func @SkipBufferViewBufferOp -// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer -func.func @SkipBufferViewBufferOp(%buffer : !hal.buffer) -> !hal.buffer { - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - %c11 = arith.constant 11 : index - %c32 = arith.constant 32 : i32 - %encoding = arith.constant 1 : i32 - %view = hal.buffer_view.create buffer(%buffer : !hal.buffer)[%c0, %c10] - shape([%c10, %c11]) - type(%c32) - encoding(%encoding) : !hal.buffer_view - %view_buffer = hal.buffer_view.buffer<%view : !hal.buffer_view> : !hal.buffer - // CHECK: return %[[BUFFER]] - return %view_buffer : !hal.buffer -} - -// ----- - -// CHECK-LABEL: func.func @DoNotSkipBufferViewBufferOp -func.func @DoNotSkipBufferViewBufferOp(%buffer : !hal.buffer) -> !hal.buffer { - %c5 = arith.constant 5 : index - %c10 = arith.constant 10 : index - %c11 = arith.constant 11 : index - %c32 = arith.constant 32 : i32 - %encoding = arith.constant 1 : i32 - %view = hal.buffer_view.create buffer(%buffer : !hal.buffer)[%c5, %c10] - shape([%c10, %c11]) - type(%c32) - encoding(%encoding) : !hal.buffer_view - // CHECK: %[[BUFFER:.+]] = hal.buffer_view.buffer - %view_buffer = hal.buffer_view.buffer<%view : !hal.buffer_view> : !hal.buffer - // CHECK: return %[[BUFFER]] - return %view_buffer : !hal.buffer -}