diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir index e8a8488c3787..f7954d0483b4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir @@ -2566,3 +2566,28 @@ func.func @micro_kernel_op() { // CHECK-SAME: outs(%[[ARG1]], %[[ARG2]] : // CHECK-SAME: (%[[S0]], %[[ARG3]], %[[S1]] : // CHECK: return + +// ----- + +func.func @sub_byte_bufferize_with_offset() { + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x] + %3 = flow.dispatch.tensor.load %1, offsets = [%2], sizes = [64], strides = [1] : !flow.dispatch.tensor> -> tensor<64xf32> + %4 = flow.dispatch.tensor.load %0, offsets = [%2], sizes = [64], strides = [1] : !flow.dispatch.tensor> -> tensor<64xi4> + %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4 : tensor<64xi4>) outs(%3 : tensor<64xf32>) { + ^bb0(%in: i4, %out: f32): + %6 = arith.extui %in : i4 to i32 + %7 = arith.uitofp %6 : i32 to f32 + linalg.yield %7 : f32 + } -> tensor<64xf32> + flow.dispatch.tensor.store %5, %1, offsets = [%2], sizes = [64], strides = [1] : tensor<64xf32> -> !flow.dispatch.tensor> + return +} +// CHECK-LABEL: func.func @sub_byte_bufferize_with_offset() +// CHECK: %[[C64:.+]] = arith.constant 64 : index +// CHECK: hal.interface.binding.subspan set(0) binding(0) +// CHECK-SAME: memref<64xi4, strided<[1], offset: 128> diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 370cbd71cf74..241070e16bcc 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -711,29 +711,32 @@ OpFoldResult convertByteOffsetToElementOffset(RewriterBase &rewriter, Location loc, OpFoldResult byteOffset, Type elementType) { - OpFoldResult elementWidth = - TypeSwitch(elementType) - .Case( - [&](auto type) -> OpFoldResult { - return rewriter.getIndexAttr( - IREE::Util::getRoundedElementByteWidth(elementType)); - }) - .Default([&](Type t) -> OpFoldResult { - return rewriter.create(loc, elementType) - .getResult(); - }); - AffineExpr s0, s1; - bindSymbols(rewriter.getContext(), s0, s1); - return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1), - {byteOffset, elementWidth}); + if (isa(elementType)) { + unsigned typeBitWidth = IREE::Util::getTypeBitWidth(elementType); + assert(llvm::isPowerOf2_32(typeBitWidth) && + "unhandled non powers of 2 bit width while converting byte offset " + "to element offset"); + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + return affine::makeComposedFoldedAffineApply( + rewriter, loc, (s0 * 8).floorDiv(typeBitWidth), + {byteOffset, rewriter.getIndexAttr(typeBitWidth)}); + } else { + OpFoldResult elementByteSize = + rewriter.create(loc, elementType).getResult(); + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1), + {byteOffset, elementByteSize}); + } } //===---------------------------------------------------------------------===// // Replace Memref users (transitively) //===---------------------------------------------------------------------===// -/// Replaces a `use` with the `replacement` for cases where a simple substition -/// might lead to verification errors. +/// Replaces a `use` with the `replacement` for cases where a simple +/// substition might lead to verification errors. static std::optional> replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use, Value replacement) { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h index 362b0584bbf7..76bb1c1f8209 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h @@ -196,6 +196,10 @@ static inline unsigned getTypeBitWidth(Type type) { if (auto complexType = type.dyn_cast()) { return 2 * complexType.getElementType().getIntOrFloatBitWidth(); } + if (auto vectorType = type.dyn_cast()) { + return vectorType.getNumElements() * + getTypeBitWidth(vectorType.getElementType()); + } return type.getIntOrFloatBitWidth(); }