Skip to content

Commit

Permalink
Add ReifyRankedShapedTypeOpInterface to `hal.interface.binding.subs…
Browse files Browse the repository at this point in the history
…pan`.

Fixes #18942

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar committed Oct 30, 2024
1 parent 5fc340d commit 20f668f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 12 deletions.
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

namespace mlir::iree_compiler::IREE::HAL {

Expand Down Expand Up @@ -2039,6 +2040,18 @@ llvm::Align InterfaceBindingSubspanOp::calculateAlignment() {
offsetOrAlignment.value());
}

LogicalResult InterfaceBindingSubspanOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto resultShapedType = dyn_cast<ShapedType>(getResult().getType());
if (!resultShapedType) {
return failure();
}
SmallVector<OpFoldResult> resultShape = mlir::getMixedValues(
resultShapedType.getShape(), getDynamicDims(), builder);
reifiedReturnShapes.emplace_back(std::move(resultShape));
return success();
}

//===----------------------------------------------------------------------===//
// hal.interface.workgroup.*
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 2 additions & 12 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"

Expand Down Expand Up @@ -811,12 +812,6 @@ def HAL_ElementTypeOp : HAL_PureOp<"element_type", [
`:` type($result)
}];

let builders = [
OpBuilder<(ins "Type":$type), [{
build($_builder, $_state, $_builder.getI32Type(), TypeAttr::get(type));
}]>
];

let extraClassDeclaration = [{
// Returns a stable identifier for the MLIR element type or nullopt if the
// type is unsupported in the ABI.
Expand Down Expand Up @@ -848,12 +843,6 @@ def HAL_EncodingTypeOp : HAL_PureOp<"encoding_type", [
`:` type($result)
}];

let builders = [
OpBuilder<(ins "Attribute":$encoding), [{
build($_builder, $_state, $_builder.getI32Type(), encoding);
}]>
];

let extraClassDeclaration = [{
// Returns a stable identifier for the MLIR encoding type or 0 (opaque) if
// the type is unsupported in the ABI.
Expand Down Expand Up @@ -3051,6 +3040,7 @@ def HAL_InterfaceConstantLoadOp : HAL_PureOp<"interface.constant.load"> {

def HAL_InterfaceBindingSubspanOp : HAL_PureOp<"interface.binding.subspan", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
Util_ShapeAwareOp,
]> {
let summary = [{returns an alias to a subspan of interface binding data}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ iree_lit_test_suite(
"resolve_device_aliases.mlir",
"resolve_device_promises.mlir",
"resolve_export_ordinals.mlir",
"resolve_ranked_shaped_type.mlir",
"strip_executable_contents.mlir",
"substitute_executables.mlir",
"verify_devices.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ iree_lit_test_suite(
"resolve_device_aliases.mlir"
"resolve_device_promises.mlir"
"resolve_export_ordinals.mlir"
"resolve_ranked_shaped_type.mlir"
"strip_executable_contents.mlir"
"substitute_executables.mlir"
"verify_devices.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: iree-opt -resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s

util.func public @hal_interface_binding_subspan_op(%arg0 : index, %arg1 : index) -> (index, index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = hal.interface.binding.subspan layout(<
constants = 0, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>)
binding(0) : memref<64x?x?xf16>{%arg0, %arg1}
%d0 = memref.dim %0, %c0 : memref<64x?x?xf16>
%d1 = memref.dim %0, %c1 : memref<64x?x?xf16>
%d2 = memref.dim %0, %c2 : memref<64x?x?xf16>
util.return %d0, %d1, %d2 : index, index, index
}
// CHECK-LABEL: func public @hal_interface_binding_subspan_op(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK: %[[C64:.+]] = arith.constant 64 : index
// CHECK: return %[[C64]], %[[ARG0]], %[[ARG1]]

0 comments on commit 20f668f

Please sign in to comment.