Skip to content

Commit

Permalink
Fix calculation of byte offset to element offset during bufferization. (
Browse files Browse the repository at this point in the history
#14647)

The logic was not adapted to handle sub-byte element sizes. This change handles power-of-2 sub-byte sizes only; handling the more general case requires upstream MLIR to provide a good representation of how such types are packed. This should be revisited and adapted once that is available.

Fixes #14642
  • Loading branch information
MaheshRavishankar authored Aug 12, 2023
1 parent 1466742 commit 6b0930c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2566,3 +2566,28 @@ func.func @micro_kernel_op() {
// CHECK-SAME: outs(%[[ARG1]], %[[ARG2]] :
// CHECK-SAME: (%[[S0]], %[[ARG3]], %[[S1]] :
// CHECK: return

// -----

// Regression test for byte-offset -> element-offset conversion during
// bufferization with sub-byte element types (issue #14642): binding 0 has a
// byte offset of 64 and an i4 element type, so the element offset must be
// 64 bytes * (8 / 4 bits) = 128 elements, not 64 / byte-width.
func.func @sub_byte_bufferize_with_offset() {
  %c64 = arith.constant 64 : index
  %c0 = arith.constant 0 : index
  // i4 input binding with a non-zero byte offset: the case under test.
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c64) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64xi4>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64xf32>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
  %3 = flow.dispatch.tensor.load %1, offsets = [%2], sizes = [64], strides = [1] : !flow.dispatch.tensor<writeonly:tensor<64xf32>> -> tensor<64xf32>
  %4 = flow.dispatch.tensor.load %0, offsets = [%2], sizes = [64], strides = [1] : !flow.dispatch.tensor<readonly:tensor<64xi4>> -> tensor<64xi4>
  // Elementwise i4 -> f32 conversion so the i4 tensor is actually read.
  %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4 : tensor<64xi4>) outs(%3 : tensor<64xf32>) {
  ^bb0(%in: i4, %out: f32):
    %6 = arith.extui %in : i4 to i32
    %7 = arith.uitofp %6 : i32 to f32
    linalg.yield %7 : f32
  } -> tensor<64xf32>
  flow.dispatch.tensor.store %5, %1, offsets = [%2], sizes = [64], strides = [1] : tensor<64xf32> -> !flow.dispatch.tensor<writeonly:tensor<64xf32>>
  return
}
// The memref offset is expressed in elements: 64 bytes of i4 == 128 elements.
// CHECK-LABEL: func.func @sub_byte_bufferize_with_offset()
//       CHECK:   %[[C64:.+]] = arith.constant 64 : index
//       CHECK:   hal.interface.binding.subspan set(0) binding(0)
//  CHECK-SAME:       memref<64xi4, strided<[1], offset: 128>
37 changes: 20 additions & 17 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,29 +711,32 @@ OpFoldResult convertByteOffsetToElementOffset(RewriterBase &rewriter,
Location loc,
OpFoldResult byteOffset,
Type elementType) {
OpFoldResult elementWidth =
TypeSwitch<Type, OpFoldResult>(elementType)
.Case<ComplexType, FloatType, IntegerType, VectorType>(
[&](auto type) -> OpFoldResult {
return rewriter.getIndexAttr(
IREE::Util::getRoundedElementByteWidth(elementType));
})
.Default([&](Type t) -> OpFoldResult {
return rewriter.create<IREE::Util::SizeOfOp>(loc, elementType)
.getResult();
});
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1),
{byteOffset, elementWidth});
if (isa<ComplexType, FloatType, IntegerType, VectorType>(elementType)) {
unsigned typeBitWidth = IREE::Util::getTypeBitWidth(elementType);
assert(llvm::isPowerOf2_32(typeBitWidth) &&
"unhandled non powers of 2 bit width while converting byte offset "
"to element offset");
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(
rewriter, loc, (s0 * 8).floorDiv(typeBitWidth),
{byteOffset, rewriter.getIndexAttr(typeBitWidth)});
} else {
OpFoldResult elementByteSize =
rewriter.create<IREE::Util::SizeOfOp>(loc, elementType).getResult();
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
return affine::makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1),
{byteOffset, elementByteSize});
}
}

//===---------------------------------------------------------------------===//
// Replace Memref users (transitively)
//===---------------------------------------------------------------------===//

/// Replaces a `use` with the `replacement` for cases where a simple
/// substitution might lead to verification errors.
static std::optional<SmallVector<Value>>
replaceNonTrivialUse(RewriterBase &rewriter, Location loc, OpOperand &use,
Value replacement) {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ static inline unsigned getTypeBitWidth(Type type) {
if (auto complexType = type.dyn_cast<ComplexType>()) {
return 2 * complexType.getElementType().getIntOrFloatBitWidth();
}
if (auto vectorType = type.dyn_cast<VectorType>()) {
return vectorType.getNumElements() *
getTypeBitWidth(vectorType.getElementType());
}
return type.getIntOrFloatBitWidth();
}

Expand Down

0 comments on commit 6b0930c

Please sign in to comment.