Skip to content

Commit

Permalink
Use getEntryPoint
Browse files Browse the repository at this point in the history
  • Loading branch information
qedawkins committed Aug 25, 2023
1 parent 697e162 commit 4e09303
Showing 1 changed file with 7 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
Expand Down Expand Up @@ -79,24 +80,12 @@ transform_dialect::MapNestedForallToGpuThreadsOp::applyToOne(
transform::TransformRewriter &rewriter, func::FuncOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
"requires HAL::ExecutableOp or HAL::ExecutableVariantOp "
"toplevel to "
"attach the workgroup size information to a nested "
"ExecutableExportOp");
return emitDefaultDefiniteFailure(target);
}

IREE::HAL::ExecutableExportOp exportOp;
state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) {
if (op.getSymName() == target.getName())
exportOp = op;
});
if (!exportOp) {
FailureOr<IREE::HAL::ExecutableExportOp> maybeExportOp = getEntryPoint(target);
if (failed(maybeExportOp)) {
state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
return emitDefaultDefiniteFailure(target);
}
IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp;

auto transformOp = cast<transform::TransformOpInterface>(getOperation());

Expand Down Expand Up @@ -609,22 +598,12 @@ transform_dialect::VectorWarpDistributionOp::applyToOne(
transform::TransformRewriter &rewriter, func::FuncOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
"requires HAL::ExecutableOp or HAL::ExecutableVariantOp "
"toplevel to extract subgroup size information");
return emitDefaultDefiniteFailure(target);
}

IREE::HAL::ExecutableExportOp exportOp;
state.getTopLevel()->walk([&](IREE::HAL::ExecutableExportOp op) {
if (op.getSymName() == target.getName())
exportOp = op;
});
if (!exportOp) {
FailureOr<IREE::HAL::ExecutableExportOp> maybeExportOp = getEntryPoint(target);
if (failed(maybeExportOp)) {
state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
return emitDefaultDefiniteFailure(target);
}
IREE::HAL::ExecutableExportOp exportOp = *maybeExportOp;

std::optional<llvm::APInt> subgroupSize = exportOp.getSubgroupSize();
if (!subgroupSize) {
Expand Down

0 comments on commit 4e09303

Please sign in to comment.